isotree 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,4232 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for the C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ /* TODO: should the kurtosis calculation impute values when using ndim=1 + missing_action=Impute?
66
+ It should be the theoretically correct approach, but will cause the kurtosis to increase
67
+ significantly if there is a large number of missing values, which would lead to preferring
68
+ splits on columns with mostly missing values. */
69
+
70
+ /* TODO: this kurtosis caps the minimum values to zero, but the minimum achievable value is 1,
71
+ check how imprecise results are used outside of the function in the different kinds of calculations
72
+ that use kurtosis and maybe change the logic. */
73
+
74
/* Powers one through four of an expression, expanded as repeated products.
   The argument text is pasted once per factor, so it should be free of side effects. */
#define pw1(val) ((val))
#define pw2(val) ((val) * (val))
#define pw3(val) ((val) * (val) * (val))
#define pw4(val) ((val) * (val) * (val) * (val))
78
+
79
+ /* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics */
80
+ template <class real_t, class ldouble_safe>
81
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
82
+ {
83
+ ldouble_safe m = 0;
84
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
85
+ ldouble_safe delta, delta_s, delta_div;
86
+ ldouble_safe diff, n;
87
+ ldouble_safe out;
88
+
89
+ if (missing_action == Fail)
90
+ {
91
+ for (size_t row = st; row <= end; row++)
92
+ {
93
+ n = (ldouble_safe)(row - st + 1);
94
+
95
+ delta = x[ix_arr[row]] - m;
96
+ delta_div = delta / n;
97
+ delta_s = delta_div * delta_div;
98
+ diff = delta * (delta_div * (ldouble_safe)(row - st));
99
+
100
+ m += delta_div;
101
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
102
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
103
+ M2 += diff;
104
+ }
105
+
106
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
107
+ {
108
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
109
+ return -HUGE_VAL;
110
+ }
111
+
112
+ out = ( M4 / M2 ) * ( (ldouble_safe)(end - st + 1) / M2 );
113
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
114
+ }
115
+
116
+ else
117
+ {
118
+ size_t cnt = 0;
119
+ for (size_t row = st; row <= end; row++)
120
+ {
121
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
122
+ {
123
+ cnt++;
124
+ n = (ldouble_safe) cnt;
125
+
126
+ delta = x[ix_arr[row]] - m;
127
+ delta_div = delta / n;
128
+ delta_s = delta_div * delta_div;
129
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
130
+
131
+ m += delta_div;
132
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
133
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
134
+ M2 += diff;
135
+ }
136
+ }
137
+
138
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
139
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
140
+ {
141
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
142
+ return -HUGE_VAL;
143
+ }
144
+
145
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
146
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
147
+ }
148
+ }
149
+
150
+ template <class real_t, class ldouble_safe>
151
+ double calc_kurtosis(real_t x[], size_t n, MissingAction missing_action)
152
+ {
153
+ ldouble_safe m = 0;
154
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
155
+ ldouble_safe delta, delta_s, delta_div;
156
+ ldouble_safe diff, n_;
157
+ ldouble_safe out;
158
+
159
+ if (missing_action == Fail)
160
+ {
161
+ for (size_t row = 0; row < n; row++)
162
+ {
163
+ n_ = (ldouble_safe)(row + 1);
164
+
165
+ delta = x[row] - m;
166
+ delta_div = delta / n_;
167
+ delta_s = delta_div * delta_div;
168
+ diff = delta * (delta_div * (ldouble_safe)row);
169
+
170
+ m += delta_div;
171
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
172
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
173
+ M2 += diff;
174
+ }
175
+
176
+ out = ( M4 / M2 ) * ( (ldouble_safe)n / M2 );
177
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
178
+ }
179
+
180
+ else
181
+ {
182
+ size_t cnt = 0;
183
+ for (size_t row = 0; row < n; row++)
184
+ {
185
+ if (likely(!is_na_or_inf(x[row])))
186
+ {
187
+ cnt++;
188
+ n_ = (ldouble_safe) cnt;
189
+
190
+ delta = x[row] - m;
191
+ delta_div = delta / n_;
192
+ delta_s = delta_div * delta_div;
193
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
194
+
195
+ m += delta_div;
196
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
197
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
198
+ M2 += diff;
199
+ }
200
+ }
201
+
202
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
203
+
204
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
205
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
206
+ }
207
+ }
208
+
209
+ /* TODO: is this algorithm correct? */
210
+ template <class real_t, class mapping, class ldouble_safe>
211
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, real_t x[],
212
+ MissingAction missing_action, mapping &restrict w)
213
+ {
214
+ ldouble_safe m = 0;
215
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
216
+ ldouble_safe delta, delta_s, delta_div;
217
+ ldouble_safe diff;
218
+ ldouble_safe n = 0;
219
+ ldouble_safe out;
220
+ ldouble_safe n_prev = 0.;
221
+ ldouble_safe w_this;
222
+
223
+ for (size_t row = st; row <= end; row++)
224
+ {
225
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
226
+ {
227
+ w_this = w[ix_arr[row]];
228
+ n += w_this;
229
+
230
+ delta = x[ix_arr[row]] - m;
231
+ delta_div = delta / n;
232
+ delta_s = delta_div * delta_div;
233
+ diff = delta * (delta_div * n_prev);
234
+ n_prev = n;
235
+
236
+ m += w_this * (delta_div);
237
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
238
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
239
+ M2 += w_this * (diff);
240
+ }
241
+ }
242
+
243
+ if (unlikely(n <= 0)) return -HUGE_VAL;
244
+ if (unlikely(!is_na_or_inf(M2) && M2 <= std::numeric_limits<double>::epsilon()))
245
+ {
246
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
247
+ return -HUGE_VAL;
248
+ }
249
+
250
+ out = ( M4 / M2 ) * ( n / M2 );
251
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
252
+ }
253
+
254
+ template <class real_t, class ldouble_safe>
255
+ double calc_kurtosis_weighted(real_t *restrict x, size_t n_, MissingAction missing_action, real_t *restrict w)
256
+ {
257
+ ldouble_safe m = 0;
258
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
259
+ ldouble_safe delta, delta_s, delta_div;
260
+ ldouble_safe diff;
261
+ ldouble_safe n = 0;
262
+ ldouble_safe out;
263
+ ldouble_safe n_prev = 0.;
264
+ ldouble_safe w_this;
265
+
266
+ for (size_t row = 0; row < n_; row++)
267
+ {
268
+ if (likely(!is_na_or_inf(x[row])))
269
+ {
270
+ w_this = w[row];
271
+ n += w_this;
272
+
273
+ delta = x[row] - m;
274
+ delta_div = delta / n;
275
+ delta_s = delta_div * delta_div;
276
+ diff = delta * (delta_div * n_prev);
277
+ n_prev = n;
278
+
279
+ m += w_this * (delta_div);
280
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
281
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
282
+ M2 += w_this * (diff);
283
+ }
284
+ }
285
+
286
+ if (unlikely(n <= 0)) return -HUGE_VAL;
287
+
288
+ out = ( M4 / M2 ) * ( n / M2 );
289
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
290
+ }
291
+
292
+
293
+ /* TODO: make these compensated sums */
294
+ /* TODO: can this use the same algorithm as above but with a correction at the end,
295
+ like it was done for the variance? */
296
+ template <class real_t, class sparse_ix, class ldouble_safe>
297
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
298
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
299
+ MissingAction missing_action)
300
+ {
301
+ /* ix_arr must be already sorted beforehand */
302
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
303
+ return -HUGE_VAL;
304
+
305
+ ldouble_safe s1 = 0;
306
+ ldouble_safe s2 = 0;
307
+ ldouble_safe s3 = 0;
308
+ ldouble_safe s4 = 0;
309
+ ldouble_safe x_sq;
310
+ size_t cnt = end - st + 1;
311
+
312
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
313
+
314
+ size_t st_col = Xc_indptr[col_num];
315
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
316
+ size_t curr_pos = st_col;
317
+ size_t ind_end_col = Xc_ind[end_col];
318
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
319
+
320
+ ldouble_safe xval;
321
+
322
+ if (missing_action != Fail)
323
+ {
324
+ for (size_t *row = ptr_st;
325
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
326
+ )
327
+ {
328
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
329
+ {
330
+ xval = Xc[curr_pos];
331
+ if (unlikely(is_na_or_inf(xval)))
332
+ {
333
+ cnt--;
334
+ }
335
+
336
+ else
337
+ {
338
+ /* TODO: is it safe to use FMA here? some calculations rely on assuming that
339
+ some of these 's' are larger than the others. Would this procedure be guaranteed
340
+ to preserve such differences if done with a mixture of sums and FMAs? */
341
+ x_sq = square(xval);
342
+ s1 += xval;
343
+ s2 = std::fma(xval, xval, s2);
344
+ s3 = std::fma(x_sq, xval, s3);
345
+ s4 = std::fma(x_sq, x_sq, s4);
346
+ // s1 += pw1(xval);
347
+ // s2 += pw2(xval);
348
+ // s3 += pw3(xval);
349
+ // s4 += pw4(xval);
350
+ }
351
+
352
+ if (row == ix_arr + end || curr_pos == end_col) break;
353
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
354
+ }
355
+
356
+ else
357
+ {
358
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
359
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
360
+ else
361
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
362
+ }
363
+ }
364
+
365
+ if (unlikely(cnt <= (end - st + 1) - (Xc_indptr[col_num+1] - Xc_indptr[col_num]))) return -HUGE_VAL;
366
+ }
367
+
368
+ else
369
+ {
370
+ for (size_t *row = ptr_st;
371
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
372
+ )
373
+ {
374
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
375
+ {
376
+ xval = Xc[curr_pos];
377
+ x_sq = square(xval);
378
+ s1 += xval;
379
+ s2 = std::fma(xval, xval, s2);
380
+ s3 = std::fma(x_sq, xval, s3);
381
+ s4 = std::fma(x_sq, x_sq, s4);
382
+ // s1 += pw1(xval);
383
+ // s2 += pw2(xval);
384
+ // s3 += pw3(xval);
385
+ // s4 += pw4(xval);
386
+
387
+ if (row == ix_arr + end || curr_pos == end_col) break;
388
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
389
+ }
390
+
391
+ else
392
+ {
393
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
394
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
395
+ else
396
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
397
+ }
398
+ }
399
+ }
400
+
401
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
402
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
403
+ ldouble_safe sn = s1 / cnt_l;
404
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
405
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
406
+ if (
407
+ v <= std::numeric_limits<double>::epsilon() &&
408
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
409
+ Xc_indptr, Xc_ind, Xc,
410
+ missing_action)
411
+ )
412
+ return -HUGE_VAL;
413
+ if (unlikely(v <= 0)) return 0.;
414
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
415
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
416
+ }
417
+
418
+ template <class real_t, class sparse_ix, class ldouble_safe>
419
+ double calc_kurtosis(size_t col_num, size_t nrows,
420
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
421
+ MissingAction missing_action)
422
+ {
423
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
424
+ return -HUGE_VAL;
425
+
426
+ ldouble_safe s1 = 0;
427
+ ldouble_safe s2 = 0;
428
+ ldouble_safe s3 = 0;
429
+ ldouble_safe s4 = 0;
430
+ ldouble_safe x_sq;
431
+ size_t cnt = nrows;
432
+
433
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
434
+
435
+ ldouble_safe xval;
436
+
437
+ if (missing_action != Fail)
438
+ {
439
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
440
+ {
441
+ xval = Xc[ix];
442
+ if (unlikely(is_na_or_inf(xval)))
443
+ {
444
+ cnt--;
445
+ }
446
+
447
+ else
448
+ {
449
+ x_sq = square(xval);
450
+ s1 += xval;
451
+ s2 = std::fma(xval, xval, s2);
452
+ s3 = std::fma(x_sq, xval, s3);
453
+ s4 = std::fma(x_sq, x_sq, s4);
454
+ // s1 += pw1(xval);
455
+ // s2 += pw2(xval);
456
+ // s3 += pw3(xval);
457
+ // s4 += pw4(xval);
458
+ }
459
+ }
460
+
461
+ if (cnt <= (nrows) - (Xc_indptr[col_num+1] - Xc_indptr[col_num])) return -HUGE_VAL;
462
+ }
463
+
464
+ else
465
+ {
466
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
467
+ {
468
+ xval = Xc[ix];
469
+ x_sq = square(xval);
470
+ s1 += xval;
471
+ s2 = std::fma(xval, xval, s2);
472
+ s3 = std::fma(x_sq, xval, s3);
473
+ s4 = std::fma(x_sq, x_sq, s4);
474
+ // s1 += pw1(xval);
475
+ // s2 += pw2(xval);
476
+ // s3 += pw3(xval);
477
+ // s4 += pw4(xval);
478
+ }
479
+ }
480
+
481
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
482
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
483
+ ldouble_safe sn = s1 / cnt_l;
484
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
485
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
486
+ if (
487
+ v <= std::numeric_limits<double>::epsilon() &&
488
+ !check_more_than_two_unique_values(nrows, col_num,
489
+ Xc_indptr, Xc_ind, Xc,
490
+ missing_action)
491
+ )
492
+ return -HUGE_VAL;
493
+ if (unlikely(v <= 0)) return 0.;
494
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
495
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
496
+ }
497
+
498
+
499
+ template <class real_t, class sparse_ix, class mapping, class ldouble_safe>
500
+ double calc_kurtosis_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
501
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
502
+ MissingAction missing_action, mapping &restrict w)
503
+ {
504
+ /* ix_arr must be already sorted beforehand */
505
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
506
+ return -HUGE_VAL;
507
+
508
+ ldouble_safe s1 = 0;
509
+ ldouble_safe s2 = 0;
510
+ ldouble_safe s3 = 0;
511
+ ldouble_safe s4 = 0;
512
+ ldouble_safe x_sq;
513
+ ldouble_safe w_this;
514
+ ldouble_safe cnt = 0;
515
+ for (size_t row = st; row <= end; row++)
516
+ cnt += w[ix_arr[row]];
517
+
518
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
519
+
520
+ size_t st_col = Xc_indptr[col_num];
521
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
522
+ size_t curr_pos = st_col;
523
+ size_t ind_end_col = Xc_ind[end_col];
524
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
525
+
526
+ ldouble_safe xval;
527
+
528
+ if (missing_action != Fail)
529
+ {
530
+ for (size_t *row = ptr_st;
531
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
532
+ )
533
+ {
534
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
535
+ {
536
+ w_this = w[*row];
537
+ xval = Xc[curr_pos];
538
+
539
+ if (unlikely(is_na_or_inf(xval)))
540
+ {
541
+ cnt -= w_this;
542
+ }
543
+
544
+ else
545
+ {
546
+ x_sq = xval * xval;
547
+ s1 = std::fma(w_this, xval, s1);
548
+ s2 = std::fma(w_this, x_sq, s2);
549
+ s3 = std::fma(w_this, x_sq*xval, s3);
550
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
551
+ // s1 += w_this * pw1(xval);
552
+ // s2 += w_this * pw2(xval);
553
+ // s3 += w_this * pw3(xval);
554
+ // s4 += w_this * pw4(xval);
555
+ }
556
+
557
+ if (row == ix_arr + end || curr_pos == end_col) break;
558
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
559
+ }
560
+
561
+ else
562
+ {
563
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
564
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
565
+ else
566
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
567
+ }
568
+ }
569
+
570
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
571
+ }
572
+
573
+ else
574
+ {
575
+ for (size_t *row = ptr_st;
576
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
577
+ )
578
+ {
579
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
580
+ {
581
+ w_this = w[*row];
582
+ xval = Xc[curr_pos];
583
+
584
+ x_sq = xval * xval;
585
+ s1 = std::fma(w_this, xval, s1);
586
+ s2 = std::fma(w_this, x_sq, s2);
587
+ s3 = std::fma(w_this, x_sq*xval, s3);
588
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
589
+ // s1 += w_this * pw1(xval);
590
+ // s2 += w_this * pw2(xval);
591
+ // s3 += w_this * pw3(xval);
592
+ // s4 += w_this * pw4(xval);
593
+
594
+ if (row == ix_arr + end || curr_pos == end_col) break;
595
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
596
+ }
597
+
598
+ else
599
+ {
600
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
601
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
602
+ else
603
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
604
+ }
605
+ }
606
+ }
607
+
608
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
609
+ ldouble_safe sn = s1 / cnt;
610
+ ldouble_safe v = s2 / cnt - pw2(sn);
611
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
612
+ if (
613
+ v <= std::numeric_limits<double>::epsilon() &&
614
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
615
+ Xc_indptr, Xc_ind, Xc,
616
+ missing_action)
617
+ )
618
+ return -HUGE_VAL;
619
+ if (v <= 0) return 0.;
620
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
621
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
622
+ }
623
+
624
+ template <class real_t, class sparse_ix, class ldouble_safe>
625
+ double calc_kurtosis_weighted(size_t col_num, size_t nrows,
626
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
627
+ MissingAction missing_action, real_t *restrict w)
628
+ {
629
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
630
+ return -HUGE_VAL;
631
+
632
+ ldouble_safe s1 = 0;
633
+ ldouble_safe s2 = 0;
634
+ ldouble_safe s3 = 0;
635
+ ldouble_safe s4 = 0;
636
+ ldouble_safe x_sq;
637
+ ldouble_safe w_this;
638
+ ldouble_safe cnt = nrows - (Xc_indptr[col_num + 1] - Xc_indptr[col_num]);
639
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
640
+ cnt += w[Xc_ind[ix]];
641
+
642
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
643
+
644
+ ldouble_safe xval;
645
+
646
+ if (missing_action != Fail)
647
+ {
648
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
649
+ {
650
+ w_this = w[Xc_ind[ix]];
651
+ xval = Xc[ix];
652
+
653
+ if (unlikely(is_na_or_inf(xval)))
654
+ {
655
+ cnt -= w_this;
656
+ }
657
+
658
+ else
659
+ {
660
+ x_sq = xval * xval;
661
+ s1 = std::fma(w_this, xval, s1);
662
+ s2 = std::fma(w_this, x_sq, s2);
663
+ s3 = std::fma(w_this, x_sq*xval, s3);
664
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
665
+ // s1 += w_this * pw1(xval);
666
+ // s2 += w_this * pw2(xval);
667
+ // s3 += w_this * pw3(xval);
668
+ // s4 += w_this * pw4(xval);
669
+ }
670
+ }
671
+
672
+ if (cnt <= 0) return -HUGE_VAL;
673
+ }
674
+
675
+ else
676
+ {
677
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
678
+ {
679
+ w_this = w[Xc_ind[ix]];
680
+ xval = Xc[ix];
681
+
682
+ x_sq = xval * xval;
683
+ s1 = std::fma(w_this, xval, s1);
684
+ s2 = std::fma(w_this, x_sq, s2);
685
+ s3 = std::fma(w_this, x_sq*xval, s3);
686
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
687
+ // s1 += w_this * pw1(xval);
688
+ // s2 += w_this * pw2(xval);
689
+ // s3 += w_this * pw3(xval);
690
+ // s4 += w_this * pw4(xval);
691
+ }
692
+ }
693
+
694
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
695
+ ldouble_safe sn = s1 / cnt;
696
+ ldouble_safe v = s2 / cnt - pw2(sn);
697
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
698
+ if (
699
+ v <= std::numeric_limits<double>::epsilon() &&
700
+ !check_more_than_two_unique_values(nrows, col_num,
701
+ Xc_indptr, Xc_ind, Xc,
702
+ missing_action)
703
+ )
704
+ return -HUGE_VAL;
705
+ if (unlikely(v <= 0)) return -HUGE_VAL;
706
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
707
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
708
+ }
709
+
710
+
711
+ template <class ldouble_safe>
712
+ double calc_kurtosis_internal(size_t cnt, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
713
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
714
+ {
715
+ /* This calculation proceeds as follows:
716
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
717
+ each category, and approximate kurtosis by sampling from such distribution
718
+ with the same probabilities as given by the current counts.
719
+ - If splitting by isolating one category, will binarize at each categorical level,
720
+ assume the values are zero or one, and output the average assuming each categorical
721
+ level has equal probability of being picked.
722
+ (Note that both are misleading heuristics, but might be better than random)
723
+ */
724
+ double sum_kurt = 0;
725
+
726
+ cnt -= buffer_cnt[ncat];
727
+ if (cnt <= 1) return -HUGE_VAL;
728
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
729
+ for (int cat = 0; cat < ncat; cat++)
730
+ buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
731
+
732
+ switch (cat_split_type)
733
+ {
734
+ case SubSet:
735
+ {
736
+ ldouble_safe temp_v;
737
+ ldouble_safe s1, s2, s3, s4;
738
+ ldouble_safe coef;
739
+ ldouble_safe coef2;
740
+ ldouble_safe w_this;
741
+ UniformUnitInterval runif(0, 1);
742
+ size_t ntry = 50;
743
+ for (size_t iternum = 0; iternum < 50; iternum++)
744
+ {
745
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
746
+ for (int cat = 0; cat < ncat; cat++)
747
+ {
748
+ coef = runif(rnd_generator);
749
+ coef2 = coef * coef;
750
+ w_this = buffer_prob[cat];
751
+ s1 = std::fma(w_this, coef, s1);
752
+ s2 = std::fma(w_this, coef2, s2);
753
+ s3 = std::fma(w_this, coef2*coef, s3);
754
+ s4 = std::fma(w_this, coef2*coef2, s4);
755
+ // s1 += buffer_prob[cat] * pw1(coef);
756
+ // s2 += buffer_prob[cat] * pw2(coef);
757
+ // s3 += buffer_prob[cat] * pw3(coef);
758
+ // s4 += buffer_prob[cat] * pw4(coef);
759
+ }
760
+ temp_v = s2 - pw2(s1);
761
+ if (temp_v <= 0)
762
+ ntry--;
763
+ else
764
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
765
+ }
766
+ if (unlikely(!ntry))
767
+ return -HUGE_VAL;
768
+ else if (unlikely(is_na_or_inf(sum_kurt)))
769
+ return -HUGE_VAL;
770
+ else
771
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
772
+ }
773
+
774
+ case SingleCateg:
775
+ {
776
+ double p;
777
+ int ncat_present = ncat;
778
+ for (int cat = 0; cat < ncat; cat++)
779
+ {
780
+ p = buffer_prob[cat];
781
+ if (p == 0)
782
+ ncat_present--;
783
+ else
784
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
785
+ }
786
+ if (ncat_present <= 1)
787
+ return -HUGE_VAL;
788
+ else if (unlikely(is_na_or_inf(sum_kurt)))
789
+ return -HUGE_VAL;
790
+ else
791
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
792
+ }
793
+ }
794
+
795
+ return -1; /* this will never be reached, but CRAN complains otherwise */
796
+ }
797
+
798
+ template <class ldouble_safe>
799
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict buffer_cnt, double buffer_prob[],
800
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
801
+ {
802
+ /* This calculation proceeds as follows:
803
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
804
+ each category, and approximate kurtosis by sampling from such distribution
805
+ with the same probabilities as given by the current counts.
806
+ - If splitting by isolating one category, will binarize at each categorical level,
807
+ assume the values are zero or one, and output the average assuming each categorical
808
+ level has equal probability of being picked.
809
+ (Note that both are misleading heuristics, but might be better than random)
810
+ */
811
+ size_t cnt = end - st + 1;
812
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
813
+
814
+ if (missing_action == Fail)
815
+ {
816
+ for (size_t row = st; row <= end; row++)
817
+ buffer_cnt[x[ix_arr[row]]]++;
818
+ }
819
+
820
+ else
821
+ {
822
+ for (size_t row = st; row <= end; row++)
823
+ {
824
+ if (likely(x[ix_arr[row]] >= 0))
825
+ buffer_cnt[x[ix_arr[row]]]++;
826
+ else
827
+ buffer_cnt[ncat]++;
828
+ }
829
+ }
830
+
831
+ return calc_kurtosis_internal<ldouble_safe>(
832
+ cnt, x, ncat, buffer_cnt, buffer_prob,
833
+ missing_action, cat_split_type, rnd_generator);
834
+ }
835
+
836
+ template <class ldouble_safe>
837
+ double calc_kurtosis(size_t nrows, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
838
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
839
+ {
840
+ size_t cnt = nrows;
841
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
842
+
843
+ if (missing_action == Fail)
844
+ {
845
+ for (size_t row = 0; row < nrows; row++)
846
+ buffer_cnt[x[row]]++;
847
+ }
848
+
849
+ else
850
+ {
851
+ for (size_t row = 0; row < nrows; row++)
852
+ {
853
+ if (likely(x[row] >= 0))
854
+ buffer_cnt[x[row]]++;
855
+ else
856
+ buffer_cnt[ncat]++;
857
+ }
858
+ }
859
+
860
+ return calc_kurtosis_internal<ldouble_safe>(
861
+ cnt, x, ncat, buffer_cnt, buffer_prob,
862
+ missing_action, cat_split_type, rnd_generator);
863
+ }
864
+
865
+
866
+ /* TODO: this one should get a buffer preallocated from outside */
867
+ template <class mapping, class ldouble_safe>
868
+ double calc_kurtosis_weighted_internal(std::vector<ldouble_safe> &buffer_cnt, int x[], int ncat,
869
+ double buffer_prob[], MissingAction missing_action, CategSplit cat_split_type,
870
+ RNG_engine &rnd_generator, mapping &restrict w)
871
+ {
872
+ double sum_kurt = 0;
873
+
874
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
875
+
876
+ cnt -= buffer_cnt[ncat];
877
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
878
+ for (int cat = 0; cat < ncat; cat++)
879
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
880
+
881
+ switch (cat_split_type)
882
+ {
883
+ case SubSet:
884
+ {
885
+ ldouble_safe temp_v;
886
+ ldouble_safe s1, s2, s3, s4;
887
+ ldouble_safe coef, coef2;
888
+ ldouble_safe w_this;
889
+ UniformUnitInterval runif(0, 1);
890
+ size_t ntry = 50;
891
+ for (size_t iternum = 0; iternum < 50; iternum++)
892
+ {
893
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
894
+ for (int cat = 0; cat < ncat; cat++)
895
+ {
896
+ coef = runif(rnd_generator);
897
+ coef2 = coef * coef;
898
+ w_this = buffer_prob[cat];
899
+ s1 = std::fma(w_this, coef, s1);
900
+ s2 = std::fma(w_this, coef2, s2);
901
+ s3 = std::fma(w_this, coef2*coef, s3);
902
+ s4 = std::fma(w_this, coef2*coef2, s4);
903
+ // s1 += buffer_prob[cat] * pw1(coef);
904
+ // s2 += buffer_prob[cat] * pw2(coef);
905
+ // s3 += buffer_prob[cat] * pw3(coef);
906
+ // s4 += buffer_prob[cat] * pw4(coef);
907
+ }
908
+ temp_v = s2 - pw2(s1);
909
+ if (unlikely(temp_v <= 0))
910
+ ntry--;
911
+ else
912
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
913
+ }
914
+ if (unlikely(!ntry))
915
+ return -HUGE_VAL;
916
+ else if (unlikely(is_na_or_inf(sum_kurt)))
917
+ return -HUGE_VAL;
918
+ else
919
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
920
+ }
921
+
922
+ case SingleCateg:
923
+ {
924
+ double p;
925
+ int ncat_present = ncat;
926
+ for (int cat = 0; cat < ncat; cat++)
927
+ {
928
+ p = buffer_prob[cat];
929
+ if (p == 0)
930
+ ncat_present--;
931
+ else
932
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
933
+ }
934
+ if (ncat_present <= 1)
935
+ return -HUGE_VAL;
936
+ else if (unlikely(is_na_or_inf(sum_kurt)))
937
+ return -HUGE_VAL;
938
+ else
939
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
940
+ }
941
+ }
942
+
943
+ return -1; /* this will never be reached, but CRAN complains otherwise */
944
+ }
945
+
946
+ template <class mapping, class ldouble_safe>
947
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, double buffer_prob[],
948
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator,
949
+ mapping &restrict w)
950
+ {
951
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
952
+ ldouble_safe w_this;
953
+
954
+ for (size_t row = st; row <= end; row++)
955
+ {
956
+ w_this = w[ix_arr[row]];
957
+ if (likely(x[ix_arr[row]] >= 0))
958
+ buffer_cnt[x[ix_arr[row]]] += w_this;
959
+ else
960
+ buffer_cnt[ncat] += w_this;
961
+ }
962
+
963
+ return calc_kurtosis_weighted_internal<mapping, ldouble_safe>(
964
+ buffer_cnt, x, ncat,
965
+ buffer_prob, missing_action, cat_split_type,
966
+ rnd_generator, w);
967
+ }
968
+
969
+ template <class real_t, class ldouble_safe>
970
+ double calc_kurtosis_weighted(size_t nrows, int x[], int ncat, double *restrict buffer_prob,
971
+ MissingAction missing_action, CategSplit cat_split_type,
972
+ RNG_engine &rnd_generator, real_t *restrict w)
973
+ {
974
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
975
+ ldouble_safe w_this;
976
+
977
+ for (size_t row = 0; row < nrows; row++)
978
+ {
979
+ w_this = w[row];
980
+ if (likely(x[row] >= 0))
981
+ buffer_cnt[x[row]] += w_this;
982
+ else
983
+ buffer_cnt[ncat] += w_this;
984
+ }
985
+
986
+ return calc_kurtosis_weighted_internal<real_t *restrict, ldouble_safe>(
987
+ buffer_cnt, x, ncat,
988
+ buffer_prob, missing_action, cat_split_type,
989
+ rnd_generator, w);
990
+ }
991
+
992
+ template <class int_t, class ldouble_safe>
993
+ double expected_sd_cat(double p[], size_t n, int_t pos[])
994
+ {
995
+ if (n <= 1) return 0;
996
+
997
+ ldouble_safe cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
998
+ for (size_t cat1 = 2; cat1 < n; cat1++)
999
+ {
1000
+ cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
1001
+ for (size_t cat2 = 0; cat2 < cat1; cat2++)
1002
+ cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
1003
+ }
1004
+ return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
1005
+ }
1006
+
1007
+ template <class number, class int_t, class ldouble_safe>
1008
+ double expected_sd_cat(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos)
1009
+ {
1010
+ if (n <= 1) return 0;
1011
+
1012
+ number tot = std::accumulate(pos, pos + n, (number)0, [&counts](number tot, const size_t ix){return tot + counts[ix];});
1013
+ ldouble_safe cnt_div = (ldouble_safe) tot;
1014
+ for (size_t cat = 0; cat < n; cat++)
1015
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1016
+
1017
+ return expected_sd_cat<int_t, ldouble_safe>(p, n, pos);
1018
+ }
1019
+
1020
+ template <class number, class int_t, class ldouble_safe>
1021
+ double expected_sd_cat_single(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos, size_t cat_exclude, number cnt)
1022
+ {
1023
+ if (cat_exclude == 0)
1024
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos + 1);
1025
+
1026
+ else if (cat_exclude == (n-1))
1027
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos);
1028
+
1029
+ size_t ix_exclude = pos[cat_exclude];
1030
+
1031
+ ldouble_safe cnt_div = (ldouble_safe) (cnt - counts[ix_exclude]);
1032
+ for (size_t cat = 0; cat < n; cat++)
1033
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1034
+
1035
+ ldouble_safe cum_var;
1036
+ if (cat_exclude != 1)
1037
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
1038
+ else
1039
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
1040
+ for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
1041
+ {
1042
+ if (pos[cat1] == ix_exclude) continue;
1043
+ cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
1044
+ for (size_t cat2 = 0; cat2 < cat1; cat2++)
1045
+ {
1046
+ if (pos[cat2] == ix_exclude) continue;
1047
+ cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
1048
+ }
1049
+
1050
+ }
1051
+ return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
1052
+ }
1053
+
1054
+ template <class number, class int_t, class ldouble_safe>
1055
+ double expected_sd_cat_internal(int ncat, number *restrict buffer_cnt, ldouble_safe cnt_l,
1056
+ int_t *restrict buffer_pos, double *restrict buffer_prob)
1057
+ {
1058
+ /* move zero-valued to the beginning */
1059
+ std::iota(buffer_pos, buffer_pos + ncat, (int_t)0);
1060
+ int_t st_pos = 0;
1061
+ int ncat_present = 0;
1062
+ int_t temp;
1063
+ for (int cat = 0; cat < ncat; cat++)
1064
+ {
1065
+ if (buffer_cnt[cat])
1066
+ {
1067
+ ncat_present++;
1068
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
1069
+ }
1070
+
1071
+ else
1072
+ {
1073
+ temp = buffer_pos[st_pos];
1074
+ buffer_pos[st_pos] = buffer_pos[cat];
1075
+ buffer_pos[cat] = temp;
1076
+ st_pos++;
1077
+ }
1078
+ }
1079
+
1080
+ if (ncat_present <= 1) return 0;
1081
+ return expected_sd_cat<int_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
1082
+ }
1083
+
1084
+
1085
+ template <class int_t, class ldouble_safe>
1086
+ double expected_sd_cat(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1087
+ MissingAction missing_action,
1088
+ size_t *restrict buffer_cnt, int_t *restrict buffer_pos, double buffer_prob[])
1089
+ {
1090
+ /* generate counts */
1091
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
1092
+ size_t cnt = end - st + 1;
1093
+
1094
+ if (missing_action != Fail)
1095
+ {
1096
+ int xval;
1097
+ for (size_t row = st; row <= end; row++)
1098
+ {
1099
+ xval = x[ix_arr[row]];
1100
+ if (unlikely(xval < 0))
1101
+ buffer_cnt[ncat]++;
1102
+ else
1103
+ buffer_cnt[xval]++;
1104
+ }
1105
+ cnt -= buffer_cnt[ncat];
1106
+ if (cnt == 0) return 0;
1107
+ }
1108
+
1109
+ else
1110
+ {
1111
+ for (size_t row = st; row <= end; row++)
1112
+ {
1113
+ if (likely(x[ix_arr[row]] >= 0)) buffer_cnt[x[ix_arr[row]]]++;
1114
+ }
1115
+ }
1116
+
1117
+ return expected_sd_cat_internal<size_t, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1118
+ }
1119
+
1120
+ template <class mapping, class int_t, class ldouble_safe>
1121
+ double expected_sd_cat_weighted(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1122
+ MissingAction missing_action, mapping &restrict w,
1123
+ double *restrict buffer_cnt, int_t *restrict buffer_pos, double *restrict buffer_prob)
1124
+ {
1125
+ /* generate counts */
1126
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, 0.);
1127
+ ldouble_safe cnt = 0;
1128
+
1129
+ if (missing_action != Fail)
1130
+ {
1131
+ int xval;
1132
+ double w_this;
1133
+ for (size_t row = st; row <= end; row++)
1134
+ {
1135
+ xval = x[ix_arr[row]];
1136
+ w_this = w[ix_arr[row]];
1137
+
1138
+ if (unlikely(xval < 0)) {
1139
+ buffer_cnt[ncat] += w_this;
1140
+ }
1141
+ else {
1142
+ buffer_cnt[xval] += w_this;
1143
+ cnt += w_this;
1144
+ }
1145
+ }
1146
+ if (cnt == 0) return 0;
1147
+ }
1148
+
1149
+ else
1150
+ {
1151
+ for (size_t row = st; row <= end; row++)
1152
+ {
1153
+ if (likely(x[ix_arr[row]] >= 0))
1154
+ {
1155
+ buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
1156
+ }
1157
+ }
1158
+ for (int cat = 0; cat < ncat; cat++)
1159
+ cnt += buffer_cnt[cat];
1160
+ if (unlikely(cnt == 0)) return 0;
1161
+ }
1162
+
1163
+ return expected_sd_cat_internal<double, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1164
+ }
1165
+
1166
+ /* Note: this isn't exactly comparable to the pooled gain from numeric variables,
1167
+ but among all the possible options, this is what happens to end up in the most
1168
+ similar scale when considering standardized gain. */
1169
/* Gain of a categorical split, from precomputed terms.

   For each side, takes 'n*log(n)' of its count (zero when the count is at most
   one) minus the supplied 's_left' / 's_right' term, subtracts both from
   'base_info', and divides the result by 'cnt'. */
template <class number, class ldouble_safe>
double categ_gain(number cnt_left, number cnt_right,
                  ldouble_safe s_left, ldouble_safe s_right,
                  ldouble_safe base_info, ldouble_safe cnt)
{
    const ldouble_safe nlogn_left = (cnt_left <= 1)?
        (ldouble_safe)0 : ((ldouble_safe)cnt_left * std::log((ldouble_safe)cnt_left));
    const ldouble_safe nlogn_right = (cnt_right <= 1)?
        (ldouble_safe)0 : ((ldouble_safe)cnt_right * std::log((ldouble_safe)cnt_right));

    const ldouble_safe info_left = nlogn_left - s_left;
    const ldouble_safe info_right = nlogn_right - s_right;

    return (base_info - info_left - info_right) / cnt;
}
1180
+
1181
+
1182
+ /* A couple notes about gain calculation:
1183
+
1184
+ Here one wants to find the best split point, maximizing either:
1185
+ (1/sigma) * (sigma - (1/n)*(n_left*sigma_left + n_right*sigma_right))
1186
+ or:
1187
+ (1/sigma) * (sigma - (1/2)*(sigma_left + sigma_right))
1188
+
1189
+ All the algorithms here use the sorted-indices approach, which is
1190
+ an exact method (note that there's still room for optimization by adding the
1191
+ unsorted approach for small sample sizes and for sparse matrices).
1192
+
1193
+ A naive approach would move observations one at a time from right
1194
+ to left using this formula:
1195
+ sigma = (ssq - s^2/n) / n
1196
+ ssq = sum(x^2)
1197
+ s = sum(x)
1198
+ But such approach has poor numerical precision, and this library is
1199
+ aimed precisely at cases in which there are outliers in the data.
1200
+ It's possible to improve the numerical precision by standardizing the
1201
+ data beforehand, but this library uses instead a more exact two-pass
1202
+ sigma calculation observation-by-observation (from left to right and
1203
+ from right to left, keeping the calculations of the first pass in an
1204
+ array and calculating gain in the second pass), but there's
1205
+ other methods too.
1206
+
1207
+ If one is aiming at maximizing the pooled gain, it's possible to
1208
+ simplify either the gain or the increase in gain without involving
1209
+ 'ssq'. Assuming one already has 'ssq' and 's' calculated for the left and
1210
+ right partitions, and one wants to move one observation from right to left,
1211
+ the following hold:
1212
+ s_right = s - s_left
1213
+ ssq_right = ssq - ssq_left
1214
+ n_right = n - n_left
1215
+ If one then moves observation x, these are updated as follows:
1216
+ s_left_new = s_left + x
1217
+ s_right_new = s - s_left - x
1218
+ ssq_left_new = ssq_left + x^2
1219
+ ssq_right_new = ssq - ssq_left - x^2
1220
+ n_left_new = n_left + 1
1221
+ n_right_new = n - n_left - 1
1222
+ Gain is then:
1223
+ (1/sigma) * (sigma - (1/n)*({ssq_left_new - s_left_new^2/n_left_new} + {ssq_right_new - s_right_new^2/n_right_new}))
1224
+ Which simplifies to:
1225
+ 1 - (1/(sigma*n))(ssq - ( (s_left + x)^2/(n_left+1) + (s - (s_left + x))^2/(n - (n_left+1)) ))
1226
+ Since 'sigma', 'n', and 'ssq' are constant, they can be ignored when determining the
1226
+ maximum gain - that is, one is interested in finding the point that maximizes:
1228
+ (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1))
1229
+ And this calculation will be robust-enough when dealing with numbers that were
1230
+ already standardized beforehand, as the extended model does at each step.
1231
+ Note however that, when fitting this model, one is usually interested in evaluating
1232
+ the actual gain, standardized by the standard deviation, as it will try different
1233
+ linear combinations which will give different standard deviations, so this simpler
1234
+ formula cannot be applied unless only one linear combination is probed.
1235
+
1236
+ One can also look at:
1237
+ diff_gain = (1/sigma) * (gain_new - gain)
1238
+ Which can be simplified to something that doesn't include sums of squares:
1239
+ (1/(sigma*n))*( -s_left^2/n_left - (s-s_left)^2/(n-n_left) + (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1)) )
1240
+ And this calculation would in theory allow getting the actual standardized gain.
1241
+ In practice however, this calculation can have poor numerical precision when the
1242
+ sample size is large, so the functions here do not even attempt at calculating it,
1243
+ and this is the reason why the two-pass approach is preferred.
1244
+
1245
+ The averaged SD formula unfortunately doesn't reduce to something that would involve
1246
+ only sums.
1247
+ */
1248
+
1249
+ /* TODO: maybe it's not a good idea to use the two-pass approach with un-standardized
1250
+ variables at large sample sizes (ndim=1), considering that they come in sorted order.
1251
+ Maybe it should instead use sums of centered squares: sigma = sqrt(sum((x-mean(x))^2)/n)
1252
+ The sums of centered squares method is also likely to be more precise. */
1253
+
1254
+ template <class real_t>
1255
+ double midpoint(real_t x, real_t y)
1256
+ {
1257
+ real_t m = x + (y-x)/(real_t)2;
1258
+ if (likely((double)m < (double)y))
1259
+ return m;
1260
+ else {
1261
+ m = std::nextafter(m, y);
1262
+ if (m > x && m < y)
1263
+ return m;
1264
+ else
1265
+ return x;
1266
+ }
1267
+ }
1268
+
1269
/* Calls 'midpoint' with the arguments as given when x < y, and with the
   arguments swapped otherwise. */
template <class real_t>
double midpoint_with_reorder(real_t x, real_t y)
{
    return (x < y)? midpoint(x, y) : midpoint(y, x);
}
1277
+
1278
/* Averaged gain: one minus the mean of the two branch SDs relative to the full SD. */
#define sd_gain(sd, sd_left, sd_right) \
    (1. - ((sd_left) + (sd_right)) / (2. * (sd)))

/* Pooled gain: one minus the count-weighted mean of the two branch SDs relative
   to the full SD. Expands to a cast to 'real_t', so that name must be a type
   in scope wherever the macro is used. */
#define pooled_gain(sd, cnt, sd_left, sd_right, cnt_left, cnt_right) \
    (1. - (1./(sd))*( (((real_t)(cnt_left))/(cnt))*(sd_left) + (((real_t)(cnt_right))/(cnt))*(sd_right) ))
1281
+
1282
+
1283
+ /* TODO: make this a compensated sum */
1284
+ template <class real_t, class real_t_>
1285
+ double find_split_rel_gain_t(real_t_ *restrict x, size_t n, double &restrict split_point)
1286
+ {
1287
+ real_t this_gain;
1288
+ real_t best_gain = -HUGE_VAL;
1289
+ real_t x1 = 0, x2 = 0;
1290
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1291
+ for (size_t row = 0; row < n; row++)
1292
+ sum_tot += x[row];
1293
+ for (size_t row = 0; row < n-1; row++)
1294
+ {
1295
+ sum_left += x[row];
1296
+ if (x[row] == x[row+1])
1297
+ continue;
1298
+
1299
+ sum_right = sum_tot - sum_left;
1300
+ this_gain = sum_left * (sum_left / (real_t)(row+1))
1301
+ + sum_right * (sum_right / (real_t)(n-row-1));
1302
+ if (this_gain > best_gain)
1303
+ {
1304
+ best_gain = this_gain;
1305
+ x1 = x[row]; x2 = x[row+1];
1306
+ }
1307
+ }
1308
+
1309
+ if (best_gain <= -HUGE_VAL)
1310
+ return best_gain;
1311
+ split_point = midpoint(x1, x2);
1312
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1313
+ }
1314
+
1315
+ template <class real_t_, class ldouble_safe>
1316
+ double find_split_rel_gain(real_t_ *restrict x, size_t n, double &restrict split_point)
1317
+ {
1318
+ if (n < THRESHOLD_LONG_DOUBLE)
1319
+ return find_split_rel_gain_t<double, real_t_>(x, n, split_point);
1320
+ else
1321
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, n, split_point);
1322
+ }
1323
+
1324
+ /* Note: there is no 'weighted' version of 'find_split_rel_gain' with unindexed 'x', because calling it would
1325
+ imply having to argsort the 'x' values in order to sort the weights, which is less efficient. */
1326
+
1327
+ template <class real_t, class real_t_>
1328
+ double find_split_rel_gain_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix)
1329
+ {
1330
+ real_t this_gain;
1331
+ real_t best_gain = -HUGE_VAL;
1332
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1333
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1334
+ for (size_t row = st; row <= end; row++)
1335
+ sum_tot += x[ix_arr[row]] - xmean;
1336
+ for (size_t row = st; row < end; row++)
1337
+ {
1338
+ sum_left += x[ix_arr[row]] - xmean;
1339
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1340
+ continue;
1341
+
1342
+ sum_right = sum_tot - sum_left;
1343
+ this_gain = sum_left * (sum_left / (real_t)(row - st + 1))
1344
+ + sum_right * (sum_right / (real_t)(end - row));
1345
+ if (this_gain > best_gain)
1346
+ {
1347
+ best_gain = this_gain;
1348
+ split_ix = row;
1349
+ }
1350
+ }
1351
+
1352
+ if (best_gain <= -HUGE_VAL)
1353
+ return best_gain;
1354
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1355
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1356
+ }
1357
+
1358
+ template <class real_t_, class ldouble_safe>
1359
+ double find_split_rel_gain(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix)
1360
+ {
1361
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1362
+ return find_split_rel_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1363
+ else
1364
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1365
+ }
1366
+
1367
+ template <class real_t, class real_t_, class mapping>
1368
+ double find_split_rel_gain_weighted_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix, mapping &restrict w)
1369
+ {
1370
+ real_t this_gain;
1371
+ real_t best_gain = -HUGE_VAL;
1372
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1373
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0, sumw = 0, sumw_tot = 0;
1374
+
1375
+ for (size_t row = st; row <= end; row++)
1376
+ sumw_tot += w[ix_arr[row]];
1377
+ for (size_t row = st; row <= end; row++)
1378
+ sum_tot += x[ix_arr[row]] - xmean;
1379
+ for (size_t row = st; row < end; row++)
1380
+ {
1381
+ sumw += w[ix_arr[row]];
1382
+ sum_left += x[ix_arr[row]] - xmean;
1383
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1384
+ continue;
1385
+
1386
+ sum_right = sum_tot - sum_left;
1387
+ this_gain = sum_left * (sum_left / sumw)
1388
+ + sum_right * (sum_right / (sumw_tot - sumw));
1389
+ if (this_gain > best_gain)
1390
+ {
1391
+ best_gain = this_gain;
1392
+ split_ix = row;
1393
+ }
1394
+ }
1395
+
1396
+ if (best_gain <= -HUGE_VAL)
1397
+ return best_gain;
1398
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1399
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1400
+ }
1401
+
1402
+ template <class real_t_, class mapping, class ldouble_safe>
1403
+ double find_split_rel_gain_weighted(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1404
+ {
1405
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1406
+ return find_split_rel_gain_weighted_t<double, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1407
+ else
1408
+ return find_split_rel_gain_weighted_t<ldouble_safe, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1409
+ }
1410
+
1411
+
1412
+ template <class real_t, class real_t_>
1413
+ real_t calc_sd_right_to_left(real_t_ *restrict x, size_t n, double *restrict sd_arr)
1414
+ {
1415
+ real_t running_mean = 0;
1416
+ real_t running_ssq = 0;
1417
+ real_t mean_prev = x[n-1];
1418
+ for (size_t row = 0; row < n-1; row++)
1419
+ {
1420
+ running_mean += (x[n-row-1] - running_mean) / (real_t)(row+1);
1421
+ running_ssq += (x[n-row-1] - running_mean) * (x[n-row-1] - mean_prev);
1422
+ mean_prev = running_mean;
1423
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1424
+ }
1425
+ running_mean += (x[0] - running_mean) / (real_t)n;
1426
+ running_ssq += (x[0] - running_mean) * (x[0] - mean_prev);
1427
+ return std::sqrt(running_ssq / (real_t)n);
1428
+ }
1429
+
1430
+ template <class real_t_, class ldouble_safe>
1431
+ ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1432
+ double *restrict w, ldouble_safe &cumw, size_t *restrict sorted_ix)
1433
+ {
1434
+ ldouble_safe running_mean = 0;
1435
+ ldouble_safe running_ssq = 0;
1436
+ ldouble_safe mean_prev = x[sorted_ix[n-1]];
1437
+ ldouble_safe cnt = 0;
1438
+ double w_this;
1439
+ for (size_t row = 0; row < n-1; row++)
1440
+ {
1441
+ w_this = w[sorted_ix[n-row-1]];
1442
+ cnt += w_this;
1443
+ running_mean += w_this * (x[sorted_ix[n-row-1]] - running_mean) / cnt;
1444
+ running_ssq += w_this * ((x[sorted_ix[n-row-1]] - running_mean) * (x[sorted_ix[n-row-1]] - mean_prev));
1445
+ mean_prev = running_mean;
1446
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
1447
+ }
1448
+ w_this = w[sorted_ix[0]];
1449
+ cnt += w_this;
1450
+ running_mean += (x[sorted_ix[0]] - running_mean) / cnt;
1451
+ running_ssq += w_this * ((x[sorted_ix[0]] - running_mean) * (x[sorted_ix[0]] - mean_prev));
1452
+ cumw = cnt;
1453
+ return std::sqrt(running_ssq / cnt);
1454
+ }
1455
+
1456
/* Computes the (population) standard deviation of 'x - xmean' over the
   observations ix_arr[st..end], visiting them from the last one to the first
   with a one-pass running mean / sum-of-squares update.

   Side output: for every position 'pos' in (st, end], sd_arr[pos - st] receives
   the standard deviation of the elements at positions pos..end (zero for the
   single-element suffix at 'end'). sd_arr[0] is not written.

   Returns the standard deviation over the whole range. */
template <class real_t, class real_t_>
real_t calc_sd_right_to_left(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr)
{
    const size_t n = end - st + 1;
    real_t mean_curr = 0;
    real_t ssq_curr = 0;
    real_t mean_last = x[ix_arr[end]] - xmean;

    /* walk the suffixes end, end-1, ..., st+1 */
    for (size_t pos = end; pos > st; pos--)
    {
        const real_t_ centered = x[ix_arr[pos]] - xmean;
        const real_t n_seen = (real_t)(end - pos + 1);

        mean_curr += (centered - mean_curr) / n_seen;
        ssq_curr  += (centered - mean_curr) * (centered - mean_last);
        mean_last  = mean_curr;

        if (pos == end)
            sd_arr[pos - st] = 0.;
        else
            sd_arr[pos - st] = std::sqrt(ssq_curr / n_seen);
    }

    /* the first element completes the full range; nothing is stored for it */
    const real_t_ centered_first = x[ix_arr[st]] - xmean;
    mean_curr += (centered_first - mean_curr) / (real_t)n;
    ssq_curr  += (centered_first - mean_curr) * (centered_first - mean_last);
    return std::sqrt(ssq_curr / (real_t)n);
}
1474
+
1475
/* Weighted counterpart of 'calc_sd_right_to_left': computes the weighted
   (population) standard deviation of 'x - xmean' over the observations
   ix_arr[st..end], visiting them from last to first with a one-pass weighted
   running mean / sum-of-squares update.

   Parameters
   - x, xmean : values and the constant subtracted from each of them.
   - ix_arr, st, end : observation indices and the inclusive range to use.
   - sd_arr (out) : for each position in (st, end], sd_arr[pos - st] receives
     the weighted standard deviation of the elements at positions pos..end
     (zero for the single-element suffix). sd_arr[0] is not written.
   - w : weights, looked up as w[ix_arr[pos]].
   - cumw (out) : sum of the weights over the whole range.

   Returns the weighted standard deviation over the whole range. */
template <class real_t_, class mapping, class ldouble_safe>
ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end,
                                            double *restrict sd_arr, mapping &restrict w, ldouble_safe &cumw)
{
    ldouble_safe running_mean = 0;
    ldouble_safe running_ssq = 0;
    /* kept in the accumulator type so the previous mean is not rounded back to 'real_t_' */
    ldouble_safe mean_prev = x[ix_arr[end]] - xmean;
    size_t n = end - st + 1;
    ldouble_safe cnt = 0;
    double w_this;
    for (size_t row = 0; row < n-1; row++)
    {
        w_this = w[ix_arr[end-row]];
        cnt += w_this;
        running_mean   += w_this * ((x[ix_arr[end-row]] - xmean) - running_mean) / cnt;
        running_ssq    += w_this * (((x[ix_arr[end-row]] - xmean) - running_mean) * ((x[ix_arr[end-row]] - xmean) - mean_prev));
        mean_prev       =  running_mean;
        sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
    }
    w_this = w[ix_arr[st]];
    cnt += w_this;
    /* the mean update must be scaled by the weight here too, exactly as inside the loop;
       otherwise the result is only correct when the first observation has weight 1 */
    running_mean   += w_this * ((x[ix_arr[st]] - xmean) - running_mean) / cnt;
    running_ssq    += w_this * (((x[ix_arr[st]] - xmean) - running_mean) * ((x[ix_arr[st]] - xmean) - mean_prev));
    cumw = cnt;
    return std::sqrt(running_ssq / cnt);
}
1501
+
1502
+ template <class real_t, class real_t_>
1503
+ double find_split_std_gain_t(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1504
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1505
+ {
1506
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, n, sd_arr);
1507
+ real_t running_mean = 0;
1508
+ real_t running_ssq = 0;
1509
+ real_t mean_prev = x[0];
1510
+ real_t best_gain = -HUGE_VAL;
1511
+ real_t this_sd, this_gain;
1512
+ real_t n_ = (real_t)n;
1513
+ size_t best_ix = 0;
1514
+ for (size_t row = 0; row < n-1; row++)
1515
+ {
1516
+ running_mean += (x[row] - running_mean) / (real_t)(row+1);
1517
+ running_ssq += (x[row] - running_mean) * (x[row] - mean_prev);
1518
+ mean_prev = running_mean;
1519
+ if (x[row] == x[row+1])
1520
+ continue;
1521
+
1522
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1523
+ this_gain = (criterion == Pooled)?
1524
+ pooled_gain(full_sd, n_, this_sd, sd_arr[row+1], row+1, n-row-1)
1525
+ :
1526
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1527
+ if (this_gain > best_gain && this_gain > min_gain)
1528
+ {
1529
+ best_gain = this_gain;
1530
+ best_ix = row;
1531
+ }
1532
+ }
1533
+
1534
+ if (best_gain > -HUGE_VAL)
1535
+ split_point = midpoint(x[best_ix], x[best_ix+1]);
1536
+
1537
+ return best_gain;
1538
+ }
1539
+
1540
+ template <class real_t_, class ldouble_safe>
1541
+ double find_split_std_gain(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1542
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1543
+ {
1544
+ if (n < THRESHOLD_LONG_DOUBLE)
1545
+ return find_split_std_gain_t<double, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1546
+ else
1547
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1548
+ }
1549
+
1550
+ template <class real_t, class ldouble_safe>
1551
+ double find_split_std_gain_weighted(real_t *restrict x, size_t n, double *restrict sd_arr,
1552
+ GainCriterion criterion, double min_gain, double &restrict split_point,
1553
+ double *restrict w, size_t *restrict sorted_ix)
1554
+ {
1555
+ ldouble_safe cumw;
1556
+ double full_sd = calc_sd_right_to_left_weighted(x, n, sd_arr, w, cumw, sorted_ix);
1557
+ ldouble_safe running_mean = 0;
1558
+ ldouble_safe running_ssq = 0;
1559
+ ldouble_safe mean_prev = x[sorted_ix[0]];
1560
+ double best_gain = -HUGE_VAL;
1561
+ double this_sd, this_gain;
1562
+ double w_this;
1563
+ ldouble_safe currw = 0;
1564
+ size_t best_ix = 0;
1565
+
1566
+ for (size_t row = 0; row < n-1; row++)
1567
+ {
1568
+ w_this = w[sorted_ix[row]];
1569
+ currw += w_this;
1570
+ running_mean += w_this * (x[sorted_ix[row]] - running_mean) / currw;
1571
+ running_ssq += w_this * ((x[sorted_ix[row]] - running_mean) * (x[sorted_ix[row]] - mean_prev));
1572
+ mean_prev = running_mean;
1573
+ if (x[sorted_ix[row]] == x[sorted_ix[row+1]])
1574
+ continue;
1575
+
1576
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / currw);
1577
+ this_gain = (criterion == Pooled)?
1578
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row+1], currw, cumw-currw)
1579
+ :
1580
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1581
+ if (this_gain > best_gain && this_gain > min_gain)
1582
+ {
1583
+ best_gain = this_gain;
1584
+ best_ix = row;
1585
+ }
1586
+ }
1587
+
1588
+ if (best_gain > -HUGE_VAL)
1589
+ split_point = midpoint(x[sorted_ix[best_ix]], x[sorted_ix[best_ix+1]]);
1590
+
1591
+ return best_gain;
1592
+ }
1593
+
1594
+ template <class real_t, class real_t_>
1595
+ double find_split_std_gain_t(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1596
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1597
+ {
1598
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, xmean, ix_arr, st, end, sd_arr);
1599
+ real_t running_mean = 0;
1600
+ real_t running_ssq = 0;
1601
+ real_t mean_prev = x[ix_arr[st]] - xmean;
1602
+ real_t best_gain = -HUGE_VAL;
1603
+ real_t n = (real_t)(end - st + 1);
1604
+ real_t this_sd, this_gain;
1605
+ split_ix = st;
1606
+ for (size_t row = st; row < end; row++)
1607
+ {
1608
+ running_mean += ((x[ix_arr[row]] - xmean) - running_mean) / (real_t)(row-st+1);
1609
+ running_ssq += ((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev);
1610
+ mean_prev = running_mean;
1611
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1612
+ continue;
1613
+
1614
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / (real_t)(row-st+1));
1615
+ this_gain = (criterion == Pooled)?
1616
+ pooled_gain(full_sd, n, this_sd, sd_arr[row-st+1], row-st+1, end-row)
1617
+ :
1618
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1619
+ if (this_gain > best_gain && this_gain > min_gain)
1620
+ {
1621
+ best_gain = this_gain;
1622
+ split_ix = row;
1623
+ }
1624
+ }
1625
+
1626
+ if (best_gain > -HUGE_VAL)
1627
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1628
+
1629
+ return best_gain;
1630
+ }
1631
+
1632
+ template <class real_t_, class ldouble_safe>
1633
+ double find_split_std_gain(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1634
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1635
+ {
1636
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1637
+ return find_split_std_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1638
+ else
1639
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1640
+ }
1641
+
1642
+ template <class real_t, class mapping, class ldouble_safe>
1643
+ double find_split_std_gain_weighted(real_t *restrict x, real_t xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1644
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1645
+ {
1646
+ ldouble_safe cumw;
1647
+ double full_sd = calc_sd_right_to_left_weighted(x, xmean, ix_arr, st, end, sd_arr, w, cumw);
1648
+ ldouble_safe running_mean = 0;
1649
+ ldouble_safe running_ssq = 0;
1650
+ ldouble_safe mean_prev = x[ix_arr[st]] - xmean;
1651
+ double best_gain = -HUGE_VAL;
1652
+ ldouble_safe currw = 0;
1653
+ double this_sd, this_gain;
1654
+ double w_this;
1655
+ split_ix = st;
1656
+
1657
+ for (size_t row = st; row < end; row++)
1658
+ {
1659
+ w_this = w[ix_arr[row]];
1660
+ currw += w_this;
1661
+ running_mean += w_this * ((x[ix_arr[row]] - xmean) - running_mean) / currw;
1662
+ running_ssq += w_this * (((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev));
1663
+ mean_prev = running_mean;
1664
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1665
+ continue;
1666
+
1667
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / currw);
1668
+ this_gain = (criterion == Pooled)?
1669
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row-st+1], currw, cumw-currw)
1670
+ :
1671
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1672
+ if (this_gain > best_gain && this_gain > min_gain)
1673
+ {
1674
+ best_gain = this_gain;
1675
+ split_ix = row;
1676
+ }
1677
+ }
1678
+
1679
+ if (best_gain > -HUGE_VAL)
1680
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1681
+
1682
+ return best_gain;
1683
+ }
1684
+
1685
+ #ifndef _FOR_R
1686
+ #if defined(__clang__)
1687
+ #pragma clang diagnostic push
1688
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
1689
+ #endif
1690
+ #endif
1691
+
1692
/* Dense accumulate: adds x[i] into y[i] for the first 'n' entries. */
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void xpy1(double *restrict x, double *restrict y, size_t n)
{
    size_t pos = 0;
    while (pos < n)
    {
        y[pos] += x[pos];
        pos++;
    }
}
1699
+
1700
/* Dense scaled accumulate: y[i] = fma(a, x[i], y[i]) for the first 'n' entries. */
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void axpy1(const double a, double *restrict x, double *restrict y, size_t n)
{
    size_t pos = 0;
    while (pos < n)
    {
        y[pos] = std::fma(a, x[pos], y[pos]);
        pos++;
    }
}
1707
+
1708
/* Sparse accumulate: for each of the 'nnz' entries, adds xval[i] into y[ind[i]]. */
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void xpy1(double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
{
    size_t pos = 0;
    while (pos < nnz)
    {
        y[ind[pos]] += xval[pos];
        pos++;
    }
}
1715
+
1716
/* Sparse scaled accumulate: for each of the 'nnz' entries,
   y[ind[i]] = fma(a, xval[i], y[ind[i]]). */
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void axpy1(const double a, double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
{
    size_t pos = 0;
    while (pos < nnz)
    {
        const size_t target = ind[pos];
        y[target] = std::fma(a, xval[pos], y[target]);
        pos++;
    }
}
1723
+
1724
+ #ifndef _FOR_R
1725
+ #if defined(__clang__)
1726
+ #pragma clang diagnostic pop
1727
+ #endif
1728
+ #endif
1729
+
1730
+ template <class real_t, class ldouble_safe>
1731
+ double find_split_full_gain(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
1732
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
1733
+ double *restrict X_row_major, size_t ncols,
1734
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
1735
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
1736
+ size_t &restrict split_ix, double &restrict split_point,
1737
+ bool x_uses_ix_arr)
1738
+ {
1739
+ if (end <= st) return -HUGE_VAL;
1740
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
1741
+ force_cols_use = true;
1742
+
1743
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1744
+ if (Xr_indptr == NULL)
1745
+ {
1746
+ if (force_cols_use)
1747
+ {
1748
+ double *restrict ptr_row;
1749
+ for (size_t row = st; row <= end; row++)
1750
+ {
1751
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1752
+ for (size_t col = 0; col < ncols_use; col++)
1753
+ buffer_sum_tot[col] += ptr_row[cols_use[col]];
1754
+ }
1755
+ }
1756
+
1757
+ else
1758
+ {
1759
+ for (size_t row = st; row <= end; row++)
1760
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
1761
+ }
1762
+ }
1763
+
1764
+ else
1765
+ {
1766
+ if (force_cols_use)
1767
+ {
1768
+ size_t *curr_begin;
1769
+ size_t *row_end;
1770
+ size_t *curr_col;
1771
+ double *restrict Xr_this;
1772
+ size_t *cols_end = cols_use + ncols_use;
1773
+ for (size_t row = st; row <= end; row++)
1774
+ {
1775
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1776
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1777
+ if (curr_begin == row_end) continue;
1778
+ curr_col = cols_use;
1779
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1780
+
1781
+ while (curr_col < cols_end && curr_begin < row_end)
1782
+ {
1783
+ if (*curr_begin == *curr_col)
1784
+ {
1785
+ buffer_sum_tot[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1786
+ curr_col++;
1787
+ curr_begin++;
1788
+ }
1789
+
1790
+ else
1791
+ {
1792
+ if (*curr_begin > *curr_col)
1793
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1794
+ else
1795
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1796
+ }
1797
+ }
1798
+ }
1799
+ }
1800
+
1801
+ else
1802
+ {
1803
+ size_t ptr_this;
1804
+ for (size_t row = st; row <= end; row++)
1805
+ {
1806
+ ptr_this = Xr_indptr[ix_arr[row]];
1807
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
1808
+ }
1809
+ }
1810
+ }
1811
+
1812
+ double best_gain = -HUGE_VAL;
1813
+ double this_gain;
1814
+ double sl, sr;
1815
+ double dl, dr;
1816
+ double vleft, vright;
1817
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1818
+ if (Xr_indptr == NULL)
1819
+ {
1820
+ if (!force_cols_use)
1821
+ {
1822
+ for (size_t row = st; row < end; row++)
1823
+ {
1824
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
1825
+ if (x_uses_ix_arr) {
1826
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1827
+ }
1828
+ else {
1829
+ if (unlikely(x[row] == x[row+1])) continue;
1830
+ }
1831
+
1832
+ vleft = 0;
1833
+ vright = 0;
1834
+ dl = (double)(row-st+1);
1835
+ dr = (double)(end-row);
1836
+ for (size_t col = 0; col < ncols; col++)
1837
+ {
1838
+ sl = buffer_sum_left[col];
1839
+ vleft += sl * (sl / dl);
1840
+ sr = buffer_sum_tot[col] - sl;
1841
+ vright += sr * (sr / dr);
1842
+ }
1843
+
1844
+ this_gain = vleft + vright;
1845
+ if (this_gain > best_gain)
1846
+ {
1847
+ best_gain = this_gain;
1848
+ split_ix = row;
1849
+ }
1850
+ }
1851
+ }
1852
+
1853
+ else
1854
+ {
1855
+ double *restrict ptr_row;
1856
+ for (size_t row = st; row < end; row++)
1857
+ {
1858
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1859
+ for (size_t col = 0; col < ncols_use; col++)
1860
+ buffer_sum_left[col] += ptr_row[cols_use[col]];
1861
+ if (x_uses_ix_arr) {
1862
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1863
+ }
1864
+ else {
1865
+ if (unlikely(x[row] == x[row+1])) continue;
1866
+ }
1867
+
1868
+ vleft = 0;
1869
+ vright = 0;
1870
+ dl = (double)(row-st+1);
1871
+ dr = (double)(end-row);
1872
+ for (size_t col = 0; col < ncols_use; col++)
1873
+ {
1874
+ sl = buffer_sum_left[col];
1875
+ vleft += sl * (sl / dl);
1876
+ sr = buffer_sum_tot[col] - sl;
1877
+ vright += sr * (sr / dr);
1878
+ }
1879
+
1880
+ this_gain = vleft + vright;
1881
+ if (this_gain > best_gain)
1882
+ {
1883
+ best_gain = this_gain;
1884
+ split_ix = row;
1885
+ }
1886
+ }
1887
+ }
1888
+ }
1889
+
1890
+ else
1891
+ {
1892
+ if (!force_cols_use)
1893
+ {
1894
+ size_t ptr_this;
1895
+ for (size_t row = st; row < end; row++)
1896
+ {
1897
+ ptr_this = Xr_indptr[ix_arr[row]];
1898
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
1899
+ if (x_uses_ix_arr) {
1900
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1901
+ }
1902
+ else {
1903
+ if (unlikely(x[row] == x[row+1])) continue;
1904
+ }
1905
+
1906
+ vleft = 0;
1907
+ vright = 0;
1908
+ dl = (double)(row-st+1);
1909
+ dr = (double)(end-row);
1910
+ for (size_t col = 0; col < ncols; col++)
1911
+ {
1912
+ sl = buffer_sum_left[col];
1913
+ vleft += sl * (sl / dl);
1914
+ sr = buffer_sum_tot[col] - sl;
1915
+ vright += sr * (sr / dr);
1916
+ }
1917
+
1918
+ this_gain = vleft + vright;
1919
+ if (this_gain > best_gain)
1920
+ {
1921
+ best_gain = this_gain;
1922
+ split_ix = row;
1923
+ }
1924
+ }
1925
+ }
1926
+
1927
+ else
1928
+ {
1929
+ size_t *curr_begin;
1930
+ size_t *row_end;
1931
+ size_t *curr_col;
1932
+ double *restrict Xr_this;
1933
+ size_t *cols_end = cols_use + ncols_use;
1934
+ for (size_t row = st; row < end; row++)
1935
+ {
1936
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1937
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1938
+ if (curr_begin == row_end) goto skip_sum;
1939
+ curr_col = cols_use;
1940
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1941
+ while (curr_col < cols_end && curr_begin < row_end)
1942
+ {
1943
+ if (*curr_begin == *curr_col)
1944
+ {
1945
+ buffer_sum_left[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1946
+ curr_col++;
1947
+ curr_begin++;
1948
+ }
1949
+
1950
+ else
1951
+ {
1952
+ if (*curr_begin > *curr_col)
1953
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1954
+ else
1955
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1956
+ }
1957
+ }
1958
+
1959
+ skip_sum:
1960
+ if (x_uses_ix_arr) {
1961
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1962
+ }
1963
+ else {
1964
+ if (unlikely(x[row] == x[row+1])) continue;
1965
+ }
1966
+
1967
+ vleft = 0;
1968
+ vright = 0;
1969
+ dl = (double)(row-st+1);
1970
+ dr = (double)(end-row);
1971
+ for (size_t col = 0; col < ncols_use; col++)
1972
+ {
1973
+ sl = buffer_sum_left[col];
1974
+ vleft += sl * (sl / dl);
1975
+ sr = buffer_sum_tot[col] - sl;
1976
+ vright += sr * (sr / dr);
1977
+ }
1978
+
1979
+ this_gain = vleft + vright;
1980
+ if (this_gain > best_gain)
1981
+ {
1982
+ best_gain = this_gain;
1983
+ split_ix = row;
1984
+ }
1985
+ }
1986
+ }
1987
+ }
1988
+
1989
+ if (best_gain <= -HUGE_VAL) return best_gain;
1990
+
1991
+ if (x_uses_ix_arr)
1992
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1993
+ else
1994
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
1995
+ return best_gain / (ldouble_safe)(end - st + 1);
1996
+ }
1997
+
1998
+ template <class real_t, class mapping, class ldouble_safe>
1999
+ double find_split_full_gain_weighted(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
2000
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
2001
+ double *restrict X_row_major, size_t ncols,
2002
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
2003
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
2004
+ size_t &restrict split_ix, double &restrict split_point,
2005
+ bool x_uses_ix_arr,
2006
+ mapping &restrict w)
2007
+ {
2008
+ if (end <= st) return -HUGE_VAL;
2009
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
2010
+ force_cols_use = true;
2011
+
2012
+ double wtot = 0;
2013
+ if (x_uses_ix_arr)
2014
+ {
2015
+ for (size_t row = st; row <= end; row++)
2016
+ wtot += w[ix_arr[row]];
2017
+ }
2018
+
2019
+ else
2020
+ {
2021
+ for (size_t row = st; row <= end; row++)
2022
+ wtot += w[row];
2023
+ }
2024
+
2025
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2026
+ if (Xr_indptr == NULL)
2027
+ {
2028
+ if (!force_cols_use)
2029
+ {
2030
+ if (x_uses_ix_arr)
2031
+ {
2032
+ for (size_t row = st; row <= end; row++)
2033
+ axpy1(w[ix_arr[row]], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2034
+ }
2035
+
2036
+ else
2037
+ {
2038
+ for (size_t row = st; row <= end; row++)
2039
+ axpy1(w[row], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2040
+ }
2041
+ }
2042
+
2043
+ else
2044
+ {
2045
+ double *restrict ptr_row;
2046
+ double w_row;
2047
+
2048
+ if (x_uses_ix_arr)
2049
+ {
2050
+ for (size_t row = st; row <= end; row++)
2051
+ {
2052
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2053
+ w_row = w[ix_arr[row]];
2054
+ for (size_t col = 0; col < ncols_use; col++)
2055
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2056
+ }
2057
+ }
2058
+
2059
+ else
2060
+ {
2061
+ for (size_t row = st; row <= end; row++)
2062
+ {
2063
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2064
+ w_row = w[row];
2065
+ for (size_t col = 0; col < ncols_use; col++)
2066
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2067
+ }
2068
+ }
2069
+ }
2070
+ }
2071
+
2072
+ else
2073
+ {
2074
+ if (!force_cols_use)
2075
+ {
2076
+ size_t ptr_this;
2077
+ if (x_uses_ix_arr)
2078
+ {
2079
+ for (size_t row = st; row <= end; row++)
2080
+ {
2081
+ ptr_this = Xr_indptr[ix_arr[row]];
2082
+ axpy1(w[ix_arr[row]], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2083
+ }
2084
+ }
2085
+
2086
+ else
2087
+ {
2088
+ for (size_t row = st; row <= end; row++)
2089
+ {
2090
+ ptr_this = Xr_indptr[ix_arr[row]];
2091
+ axpy1(w[row], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2092
+ }
2093
+ }
2094
+ }
2095
+
2096
+ else
2097
+ {
2098
+ size_t *curr_begin;
2099
+ size_t *row_end;
2100
+ size_t *curr_col;
2101
+ double *restrict Xr_this;
2102
+ size_t *cols_end = cols_use + ncols_use;
2103
+ double w_row;
2104
+ for (size_t row = st; row <= end; row++)
2105
+ {
2106
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2107
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2108
+ if (curr_begin == row_end) continue;
2109
+ curr_col = cols_use;
2110
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2111
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2112
+ size_t dtemp;
2113
+
2114
+ while (curr_col < cols_end && curr_begin < row_end)
2115
+ {
2116
+ if (*curr_begin == *curr_col)
2117
+ {
2118
+ dtemp = std::distance(cols_use, curr_col);
2119
+ buffer_sum_tot[dtemp]
2120
+ =
2121
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_tot[dtemp]);
2122
+ curr_col++;
2123
+ curr_begin++;
2124
+ }
2125
+
2126
+ else
2127
+ {
2128
+ if (*curr_begin > *curr_col)
2129
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2130
+ else
2131
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2132
+ }
2133
+ }
2134
+ }
2135
+ }
2136
+ }
2137
+
2138
+ double best_gain = -HUGE_VAL;
2139
+ double this_gain;
2140
+ double sl, sr;
2141
+ double vleft, vright;
2142
+ double wleft = 0;
2143
+ double w_row;
2144
+ double wright;
2145
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2146
+ if (Xr_indptr == NULL)
2147
+ {
2148
+ if (!force_cols_use)
2149
+ {
2150
+ for (size_t row = st; row < end; row++)
2151
+ {
2152
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2153
+ wleft += w_row;
2154
+ axpy1(w_row, X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
2155
+ if (x_uses_ix_arr) {
2156
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2157
+ }
2158
+ else {
2159
+ if (unlikely(x[row] == x[row+1])) continue;
2160
+ }
2161
+
2162
+ vleft = 0;
2163
+ vright = 0;
2164
+ wright = wtot - wleft;
2165
+ for (size_t col = 0; col < ncols; col++)
2166
+ {
2167
+ sl = buffer_sum_left[col];
2168
+ vleft += sl * (sl / wleft);
2169
+ sr = buffer_sum_tot[col] - sl;
2170
+ vright += sr * (sr / wright);
2171
+ }
2172
+
2173
+ this_gain = vleft + vright;
2174
+ if (this_gain > best_gain)
2175
+ {
2176
+ best_gain = this_gain;
2177
+ split_ix = row;
2178
+ }
2179
+ }
2180
+ }
2181
+
2182
+ else
2183
+ {
2184
+ double *restrict ptr_row;
2185
+ double w_row;
2186
+ for (size_t row = st; row < end; row++)
2187
+ {
2188
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2189
+ wleft += w_row;
2190
+
2191
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2192
+ for (size_t col = 0; col < ncols_use; col++)
2193
+ buffer_sum_left[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_left[col]);
2194
+ if (x_uses_ix_arr) {
2195
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2196
+ }
2197
+ else {
2198
+ if (unlikely(x[row] == x[row+1])) continue;
2199
+ }
2200
+
2201
+ vleft = 0;
2202
+ vright = 0;
2203
+ wright = wtot - wleft;
2204
+ for (size_t col = 0; col < ncols_use; col++)
2205
+ {
2206
+ sl = buffer_sum_left[col];
2207
+ vleft += sl * (sl / wleft);
2208
+ sr = buffer_sum_tot[col] - sl;
2209
+ vright += sr * (sr / wright);
2210
+ }
2211
+
2212
+ this_gain = vleft + vright;
2213
+ if (this_gain > best_gain)
2214
+ {
2215
+ best_gain = this_gain;
2216
+ split_ix = row;
2217
+ }
2218
+ }
2219
+ }
2220
+ }
2221
+
2222
+ else
2223
+ {
2224
+ if (!force_cols_use)
2225
+ {
2226
+ size_t ptr_this;
2227
+ double w_row;
2228
+ for (size_t row = st; row < end; row++)
2229
+ {
2230
+ w_row= w[x_uses_ix_arr? ix_arr[row] : row];
2231
+ wleft += w_row;
2232
+ ptr_this = Xr_indptr[ix_arr[row]];
2233
+ axpy1(w_row, Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
2234
+ if (x_uses_ix_arr) {
2235
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2236
+ }
2237
+ else {
2238
+ if (unlikely(x[row] == x[row+1])) continue;
2239
+ }
2240
+
2241
+ vleft = 0;
2242
+ vright = 0;
2243
+ wright = wtot - wleft;
2244
+ for (size_t col = 0; col < ncols; col++)
2245
+ {
2246
+ sl = buffer_sum_left[col];
2247
+ vleft += sl * (sl / wleft);
2248
+ sr = buffer_sum_tot[col] - sl;
2249
+ vright += sr * (sr / wright);
2250
+ }
2251
+
2252
+ this_gain = vleft + vright;
2253
+ if (this_gain > best_gain)
2254
+ {
2255
+ best_gain = this_gain;
2256
+ split_ix = row;
2257
+ }
2258
+ }
2259
+ }
2260
+
2261
+ else
2262
+ {
2263
+ size_t *curr_begin;
2264
+ size_t *row_end;
2265
+ size_t *curr_col;
2266
+ double *restrict Xr_this;
2267
+ size_t *cols_end = cols_use + ncols_use;
2268
+ double w_row;
2269
+ size_t dtemp;
2270
+ for (size_t row = st; row < end; row++)
2271
+ {
2272
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2273
+ wleft += w_row;
2274
+
2275
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2276
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2277
+ if (curr_begin == row_end) goto skip_sum;
2278
+ curr_col = cols_use;
2279
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2280
+ while (curr_col < cols_end && curr_begin < row_end)
2281
+ {
2282
+ if (*curr_begin == *curr_col)
2283
+ {
2284
+ dtemp = std::distance(cols_use, curr_col);
2285
+ buffer_sum_left[dtemp]
2286
+ =
2287
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_left[dtemp]);
2288
+ curr_col++;
2289
+ curr_begin++;
2290
+ }
2291
+
2292
+ else
2293
+ {
2294
+ if (*curr_begin > *curr_col)
2295
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2296
+ else
2297
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2298
+ }
2299
+ }
2300
+
2301
+ skip_sum:
2302
+ if (x_uses_ix_arr) {
2303
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2304
+ }
2305
+ else {
2306
+ if (unlikely(x[row] == x[row+1])) continue;
2307
+ }
2308
+
2309
+ vleft = 0;
2310
+ vright = 0;
2311
+ wright = wtot - wleft;
2312
+ for (size_t col = 0; col < ncols_use; col++)
2313
+ {
2314
+ sl = buffer_sum_left[col];
2315
+ vleft += sl * (sl / wleft);
2316
+ sr = buffer_sum_tot[col] - sl;
2317
+ vright += sr * (sr / wright);
2318
+ }
2319
+
2320
+ this_gain = vleft + vright;
2321
+ if (this_gain > best_gain)
2322
+ {
2323
+ best_gain = this_gain;
2324
+ split_ix = row;
2325
+ }
2326
+ }
2327
+ }
2328
+ }
2329
+
2330
+ if (best_gain <= -HUGE_VAL) return best_gain;
2331
+
2332
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2333
+ return best_gain / wtot;
2334
+ }
2335
+
2336
+ template <class real_t_, class real_t>
2337
+ double find_split_dens_shortform_t(real_t *restrict x, size_t n, double &restrict split_point)
2338
+ {
2339
+ double best_gain = -HUGE_VAL;
2340
+ size_t n_minus_one = n - 1;
2341
+ real_t_ xmin = x[0];
2342
+ real_t_ xmax = x[n-1];
2343
+ real_t_ xleft, xright;
2344
+ real_t_ xmid;
2345
+ double this_gain;
2346
+ size_t split_ix = 0;
2347
+
2348
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2349
+ {
2350
+ if (x[ix] == x[ix+1]) continue;
2351
+ xmid = (real_t_)x[ix] + ((real_t_)x[ix+1] - (real_t_)x[ix]) / (real_t_)2;
2352
+ xleft = xmid - xmin;
2353
+ xright = xmax - xmid;
2354
+ if (unlikely(!xleft || !xright)) continue;
2355
+ this_gain = (real_t_)square(ix+1) / xleft + (real_t_)square(n_minus_one - ix) / xright;
2356
+ if (this_gain > best_gain)
2357
+ {
2358
+ best_gain = this_gain;
2359
+ split_ix = ix;
2360
+ }
2361
+ }
2362
+
2363
+ if (best_gain <= -HUGE_VAL) return best_gain;
2364
+
2365
+ real_t_ xtot = (real_t_)xmax - (real_t_)xmin;
2366
+ real_t_ nleft = (real_t_)(split_ix+1);
2367
+ real_t_ nright = (real_t_)(n_minus_one - split_ix);
2368
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2369
+ real_t_ rpct_left = split_point / xtot;
2370
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2371
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2372
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2373
+
2374
+ real_t_ nl_sq = nleft / (real_t_)n; nl_sq = square(nl_sq);
2375
+ real_t_ nr_sq = nright / (real_t_)n; nl_sq = square(nr_sq);
2376
+
2377
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2378
+ }
2379
+
2380
+ template <class real_t, class ldouble_safe>
2381
+ double find_split_dens_shortform(real_t *restrict x, size_t n, double &restrict split_point)
2382
+ {
2383
+ if (n < INT32_MAX)
2384
+ return find_split_dens_shortform_t<double, real_t>(x, n, split_point);
2385
+ else
2386
+ return find_split_dens_shortform_t<ldouble_safe, real_t>(x, n, split_point);
2387
+ }
2388
+
2389
+ template <class real_t_, class real_t, class mapping>
2390
+ double find_split_dens_shortform_weighted_t(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2391
+ {
2392
+ double best_gain = -HUGE_VAL;
2393
+ size_t n_minus_one = n - 1;
2394
+ real_t_ xmin = x[buffer_indices[0]];
2395
+ real_t_ xmax = x[buffer_indices[n-1]];
2396
+ real_t_ xleft, xright;
2397
+ real_t_ xmid;
2398
+ double this_gain;
2399
+
2400
+ real_t_ wtot = 0;
2401
+ for (size_t ix = 0; ix < n; ix++)
2402
+ wtot += w[buffer_indices[ix]];
2403
+ real_t_ w_left = 0;
2404
+ real_t_ w_right;
2405
+ real_t_ best_w = 0;
2406
+ size_t split_ix = 0;
2407
+
2408
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2409
+ {
2410
+ w_left += w[buffer_indices[ix]];
2411
+ if (x[buffer_indices[ix]] == x[buffer_indices[ix+1]]) continue;
2412
+ xmid = (real_t_)x[buffer_indices[ix]] + ((real_t_)x[buffer_indices[ix+1]] - (real_t_)x[buffer_indices[ix]]) / (real_t_)2;
2413
+ xleft = xmid - xmin;
2414
+ xright = xmax - xmid;
2415
+ if (unlikely(!xleft || !xright)) continue;
2416
+
2417
+ w_right = wtot - w_left;
2418
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2419
+ if (this_gain > best_gain)
2420
+ {
2421
+ best_gain = this_gain;
2422
+ best_w = w_left;
2423
+ split_ix = xmid;
2424
+ }
2425
+ }
2426
+
2427
+ if (best_gain <= -HUGE_VAL) return best_gain;
2428
+
2429
+ real_t_ xtot = xmax - xmin;
2430
+ w_left = best_w;
2431
+ w_right = wtot - w_left;
2432
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2433
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2434
+ split_point = midpoint(x[buffer_indices[split_ix]], x[buffer_indices[split_ix+1]]);
2435
+ real_t_ rpct_left = split_point / xtot;
2436
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2437
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2438
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2439
+
2440
+ real_t_ wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2441
+ real_t_ wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2442
+
2443
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2444
+ }
2445
+
2446
+ template <class real_t, class mapping, class ldouble_safe>
2447
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2448
+ {
2449
+ if (n < INT32_MAX)
2450
+ return find_split_dens_shortform_weighted_t<double, real_t, mapping>(x, n, split_point, w, buffer_indices);
2451
+ else
2452
+ return find_split_dens_shortform_weighted_t<ldouble_safe, real_t, mapping>(x, n, split_point, w, buffer_indices);
2453
+ }
2454
+
2455
+ template <class real_t>
2456
+ double find_split_dens_shortform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2457
+ double &restrict split_point, size_t &restrict split_ix)
2458
+ {
2459
+ double best_gain = -HUGE_VAL;
2460
+ real_t xmin = x[ix_arr[st]];
2461
+ real_t xmax = x[ix_arr[end]];
2462
+ real_t xleft, xright;
2463
+ real_t xmid;
2464
+ double this_gain;
2465
+
2466
+ for (size_t row = st; row < end; row++)
2467
+ {
2468
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2469
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2470
+ xleft = xmid - xmin;
2471
+ xright = xmax - xmid;
2472
+ if (unlikely(!xleft || !xright)) continue;
2473
+ this_gain = square(row-st+1) / xleft + square(end-row) / xright;
2474
+ if (this_gain > best_gain)
2475
+ {
2476
+ best_gain = this_gain;
2477
+ split_ix = row;
2478
+ }
2479
+ }
2480
+
2481
+ if (best_gain <= -HUGE_VAL) return best_gain;
2482
+
2483
+ double xtot = (double)xmax - (double)xmin;
2484
+ double nleft = (double)(split_ix-st+1);
2485
+ double nright = (double)(end - split_ix);
2486
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2487
+ double rpct_left = split_point / xtot;
2488
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2489
+ double rpct_right = 1. - rpct_left;
2490
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2491
+ double ntot = (double)(end - st + 1);
2492
+
2493
+ double nl_sq = nleft / ntot; nl_sq = square(nl_sq);
2494
+ double nr_sq = nright / ntot; nl_sq = square(nr_sq);
2495
+
2496
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2497
+ }
2498
+
2499
+ template <class real_t, class mapping>
2500
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2501
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2502
+ {
2503
+ double best_gain = -HUGE_VAL;
2504
+ real_t xmin = x[ix_arr[st]];
2505
+ real_t xmax = x[ix_arr[end]];
2506
+ real_t xleft, xright;
2507
+ real_t xmid;
2508
+ double this_gain;
2509
+
2510
+ double wtot = 0;
2511
+ for (size_t row = st; row <= end; row++)
2512
+ wtot += w[ix_arr[row]];
2513
+ double w_left = 0;
2514
+ double w_right;
2515
+ double best_w = 0;
2516
+
2517
+ for (size_t row = st; row < end; row++)
2518
+ {
2519
+ w_left += w[ix_arr[row]];
2520
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2521
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2522
+ xleft = xmid - xmin;
2523
+ xright = xmax - xmid;
2524
+ if (unlikely(!xleft || !xright)) continue;
2525
+
2526
+ w_right = wtot - w_left;
2527
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2528
+ if (this_gain > best_gain)
2529
+ {
2530
+ best_gain = this_gain;
2531
+ best_w = w_left;
2532
+ split_ix = row;
2533
+ }
2534
+ }
2535
+
2536
+ if (best_gain <= -HUGE_VAL) return best_gain;
2537
+
2538
+ double xtot = (double)xmax - (double)xmin;
2539
+ w_left = best_w;
2540
+ w_right = wtot - w_left;
2541
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2542
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2543
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2544
+ double rpct_left = split_point / xtot;
2545
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2546
+ double rpct_right = 1. - rpct_left;
2547
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2548
+
2549
+ double wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2550
+ double wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2551
+
2552
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2553
+ }
2554
+
2555
+ /* This is a slower but more numerically-robust form */
2556
+ template <class real_t, class ldouble_safe>
2557
+ double find_split_dens_longform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2558
+ double &restrict split_point, size_t &restrict split_ix)
2559
+ {
2560
+ double best_gain = -HUGE_VAL;
2561
+ real_t xmin = x[ix_arr[st]];
2562
+ real_t xmax = x[ix_arr[end]];
2563
+ real_t xleft, xright;
2564
+ real_t xmid;
2565
+ ldouble_safe pct_left, pct_right;
2566
+ ldouble_safe rpct_left, rpct_right;
2567
+ ldouble_safe n_tot = end - st + 1;
2568
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2569
+ ldouble_safe cnt_left;
2570
+ double this_gain;
2571
+
2572
+ for (size_t row = st; row < end; row++)
2573
+ {
2574
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2575
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2576
+ xleft = xmid - xmin;
2577
+ xright = xmax - xmid;
2578
+ if (unlikely(!xleft || !xright)) continue;
2579
+
2580
+ cnt_left = (ldouble_safe)(row-st+1);
2581
+
2582
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2583
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2584
+ pct_left = cnt_left / n_tot;
2585
+ pct_right = (ldouble_safe)1 - pct_left;
2586
+ rpct_left = (ldouble_safe)xleft / xtot;
2587
+ rpct_right = (ldouble_safe)xright / xtot;
2588
+
2589
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2590
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2591
+ if (this_gain > best_gain)
2592
+ {
2593
+ best_gain = this_gain;
2594
+ split_point = xmid;
2595
+ split_ix = row;
2596
+ }
2597
+ }
2598
+
2599
+ return best_gain;
2600
+ }
2601
+
2602
+ template <class real_t, class mapping, class ldouble_safe>
2603
+ double find_split_dens_longform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2604
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2605
+ {
2606
+ double best_gain = -HUGE_VAL;
2607
+ real_t xmin = x[ix_arr[st]];
2608
+ real_t xmax = x[ix_arr[end]];
2609
+ real_t xleft, xright;
2610
+ real_t xmid;
2611
+ ldouble_safe pct_left, pct_right;
2612
+ ldouble_safe rpct_left, rpct_right;
2613
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2614
+ double this_gain;
2615
+
2616
+ ldouble_safe wtot = 0;
2617
+ for (size_t row = st; row <= end; row++)
2618
+ wtot += w[ix_arr[row]];
2619
+ ldouble_safe w_left = 0;
2620
+
2621
+ for (size_t row = st; row < end; row++)
2622
+ {
2623
+ w_left += w[ix_arr[row]];
2624
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2625
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2626
+ xleft = xmid - xmin;
2627
+ xright = xmax - xmid;
2628
+ if (unlikely(!xleft || !xright)) continue;
2629
+
2630
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2631
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2632
+ pct_left = w_left / wtot;
2633
+ pct_right = (ldouble_safe)1 - pct_left;
2634
+ rpct_left = (ldouble_safe)xleft / xtot;
2635
+ rpct_right = (ldouble_safe)xright / xtot;
2636
+
2637
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2638
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2639
+ if (this_gain > best_gain)
2640
+ {
2641
+ best_gain = this_gain;
2642
+ split_point = xmid;
2643
+ split_ix = row;
2644
+ }
2645
+ }
2646
+
2647
+ return best_gain;
2648
+ }
2649
+
2650
+ template <class real_t, class ldouble_safe>
2651
+ double find_split_dens(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2652
+ double &restrict split_point, size_t &restrict split_ix)
2653
+ {
2654
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2655
+ return find_split_dens_shortform<real_t>(x, ix_arr, st, end, split_point, split_ix);
2656
+ else
2657
+ return find_split_dens_longform<real_t, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
2658
+ }
2659
+
2660
+ template <class real_t, class mapping, class ldouble_safe>
2661
+ double find_split_dens_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2662
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2663
+ {
2664
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2665
+ return find_split_dens_shortform_weighted<real_t, mapping>(x, ix_arr, st, end, split_point, split_ix, w);
2666
+ else
2667
+ return find_split_dens_longform_weighted<real_t, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
2668
+ }
2669
+
2670
+ template <class int_t, class ldouble_safe>
2671
+ double find_split_dens_longform(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2672
+ CategSplit cat_split_type, MissingAction missing_action,
2673
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2674
+ size_t *restrict buffer_cnt, int_t *restrict buffer_indices)
2675
+ {
2676
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2677
+ size_t n_nas = 0;
2678
+ int xval;
2679
+
2680
+ /* count categories */
2681
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
2682
+ if (missing_action == Fail)
2683
+ {
2684
+ for (size_t row = st; row <= end; row++)
2685
+ if (likely(x[ix_arr[row]] >= 0))
2686
+ buffer_cnt[x[ix_arr[row]]]++;
2687
+ }
2688
+
2689
+ else if (missing_action == Impute)
2690
+ {
2691
+ for (size_t row = st; row <= end; row++)
2692
+ {
2693
+ xval = x[ix_arr[row]];
2694
+ if (unlikely(xval < 0))
2695
+ n_nas++;
2696
+ else
2697
+ buffer_cnt[xval]++;
2698
+ }
2699
+
2700
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
2701
+
2702
+ if (n_nas)
2703
+ {
2704
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
2705
+ *idxmax += n_nas;
2706
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
2707
+ }
2708
+ }
2709
+
2710
+ else
2711
+ {
2712
+ for (size_t row = st; row <= end; row++)
2713
+ {
2714
+ xval = x[ix_arr[row]];
2715
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
2716
+ }
2717
+ }
2718
+
2719
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2720
+ std::sort(buffer_indices, buffer_indices + ncat,
2721
+ [&buffer_cnt](const int_t a, const int_t b)
2722
+ {return buffer_cnt[a] < buffer_cnt[b];});
2723
+
2724
+ int curr = 0;
2725
+ if (split_categ != NULL)
2726
+ {
2727
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2728
+ {
2729
+ split_categ[buffer_indices[curr]] = -1;
2730
+ curr++;
2731
+ }
2732
+ }
2733
+
2734
+ else
2735
+ {
2736
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2737
+ }
2738
+
2739
+ int ncat_present = ncat - curr;
2740
+ if (ncat_present <= 1) return -HUGE_VAL;
2741
+ if (ncat_present == 2)
2742
+ {
2743
+ switch (cat_split_type)
2744
+ {
2745
+ case SingleCateg:
2746
+ {
2747
+ chosen_cat = buffer_indices[curr];
2748
+ break;
2749
+ }
2750
+
2751
+ case SubSet:
2752
+ {
2753
+ split_categ[buffer_indices[curr]] = 1;
2754
+ split_categ[buffer_indices[curr+1]] = 0;
2755
+ break;
2756
+ }
2757
+ }
2758
+
2759
+ ldouble_safe pct_left
2760
+ =
2761
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]]
2762
+ /
2763
+ (ldouble_safe)(
2764
+ buffer_cnt[buffer_indices[curr]]
2765
+ +
2766
+ buffer_cnt[buffer_indices[curr+1]]
2767
+ );
2768
+
2769
+ return ((ldouble_safe)buffer_cnt[buffer_indices[curr]] * (2. * pct_left)
2770
+ +
2771
+ (ldouble_safe)buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2772
+ /
2773
+ (ldouble_safe)(buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2774
+ }
2775
+
2776
+ size_t ntot;
2777
+ if (missing_action == Impute)
2778
+ ntot = end - st + 1;
2779
+ else
2780
+ ntot = std::accumulate(buffer_cnt, buffer_cnt + ncat, (size_t)0);
2781
+ if (unlikely(ntot <= 1)) unexpected_error();
2782
+ ldouble_safe ntot_ = (ldouble_safe)ntot;
2783
+
2784
+ switch (cat_split_type)
2785
+ {
2786
+ case SingleCateg:
2787
+ {
2788
+ double pct_one_cat = 1. / (double)ncat_present;
2789
+ double pct_left_smallest = (ldouble_safe)buffer_cnt[buffer_indices[curr]] / ntot_;
2790
+ double gain_smallest
2791
+ =
2792
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2793
+ +
2794
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2795
+ ;
2796
+
2797
+ double pct_left_biggest = (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] / ntot_;
2798
+ double gain_biggest
2799
+ =
2800
+ (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2801
+ +
2802
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
2803
+ ;
2804
+
2805
+ if (gain_smallest >= gain_biggest)
2806
+ {
2807
+ chosen_cat = buffer_indices[curr];
2808
+ return gain_smallest / ntot_;
2809
+ }
2810
+
2811
+ else
2812
+ {
2813
+ chosen_cat = buffer_indices[ncat-1];
2814
+ return gain_biggest / ntot_;
2815
+ }
2816
+ break;
2817
+ }
2818
+
2819
+ case SubSet:
2820
+ {
2821
+ size_t cnt_left = 0;
2822
+ size_t cnt_right;
2823
+ int st_cat = curr - 1;
2824
+ double this_gain;
2825
+ double best_gain = -HUGE_VAL;
2826
+ int best_cat = 0;
2827
+ ldouble_safe pct_left;
2828
+ double pct_cat_left;
2829
+ double ncat_present_ = (double)ncat_present;
2830
+ for (; curr < ncat; curr++)
2831
+ {
2832
+ cnt_left += buffer_cnt[buffer_indices[curr]];
2833
+ cnt_right = ntot - cnt_left;
2834
+ pct_left = (ldouble_safe)cnt_left / ntot_;
2835
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
2836
+ this_gain
2837
+ =
2838
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
2839
+ +
2840
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
2841
+ ;
2842
+ if (this_gain > best_gain)
2843
+ {
2844
+ best_gain = this_gain;
2845
+ best_cat = curr;
2846
+ }
2847
+ }
2848
+
2849
+ if (best_gain <= -HUGE_VAL) return best_gain;
2850
+ st_cat++;
2851
+ for (; st_cat <= best_cat; st_cat++)
2852
+ split_categ[buffer_indices[st_cat]] = 1;
2853
+ for (; st_cat < ncat; st_cat++)
2854
+ split_categ[buffer_indices[st_cat]] = 0;
2855
+ return best_gain / ntot_;
2856
+ break;
2857
+ }
2858
+ }
2859
+
2860
+ /* This will not be reached, but CRAN might complain otherwise */
2861
+ return -HUGE_VAL;
2862
+ }
2863
+
2864
+ template <class mapping, class int_t, class ldouble_safe>
2865
+ double find_split_dens_longform_weighted(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2866
+ CategSplit cat_split_type, MissingAction missing_action,
2867
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2868
+ int_t *restrict buffer_indices, mapping &restrict w)
2869
+ {
2870
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2871
+ ldouble_safe w_missing = 0;
2872
+ int xval;
2873
+ size_t ix_;
2874
+
2875
+ /* count categories */
2876
+ /* TODO: allocate this buffer externally */
2877
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
2878
+ if (missing_action == Fail)
2879
+ {
2880
+ for (size_t row = st; row <= end; row++)
2881
+ {
2882
+ ix_ = ix_arr[row];
2883
+ if (unlikely(x[ix_]) < 0) continue;
2884
+ buffer_cnt[x[ix_]] += w[ix_];
2885
+ }
2886
+ }
2887
+
2888
+ else if (missing_action == Impute)
2889
+ {
2890
+ for (size_t row = st; row <= end; row++)
2891
+ {
2892
+ ix_ = ix_arr[row];
2893
+ xval = x[ix_];
2894
+ if (unlikely(xval < 0))
2895
+ w_missing += w[ix_];
2896
+ else
2897
+ buffer_cnt[xval] += w[ix_];
2898
+ }
2899
+
2900
+ if (w_missing)
2901
+ {
2902
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
2903
+ *idxmax += w_missing;
2904
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
2905
+ }
2906
+ }
2907
+
2908
+ else
2909
+ {
2910
+ for (size_t row = st; row <= end; row++)
2911
+ {
2912
+ ix_ = ix_arr[row];
2913
+ xval = x[ix_];
2914
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
2915
+ }
2916
+ }
2917
+
2918
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2919
+ std::sort(buffer_indices, buffer_indices + ncat,
2920
+ [&buffer_cnt](const int_t a, const int_t b)
2921
+ {return buffer_cnt[a] < buffer_cnt[b];});
2922
+
2923
+ int curr = 0;
2924
+ if (split_categ != NULL)
2925
+ {
2926
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2927
+ {
2928
+ split_categ[buffer_indices[curr]] = -1;
2929
+ curr++;
2930
+ }
2931
+ }
2932
+
2933
+ else
2934
+ {
2935
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2936
+ }
2937
+
2938
+ int ncat_present = ncat - curr;
2939
+ if (ncat_present <= 1) return -HUGE_VAL;
2940
+ if (ncat_present == 2)
2941
+ {
2942
+ switch (cat_split_type)
2943
+ {
2944
+ case SingleCateg:
2945
+ {
2946
+ chosen_cat = buffer_indices[curr];
2947
+ break;
2948
+ }
2949
+
2950
+ case SubSet:
2951
+ {
2952
+ split_categ[buffer_indices[curr]] = 1;
2953
+ split_categ[buffer_indices[curr+1]] = 0;
2954
+ break;
2955
+ }
2956
+ }
2957
+
2958
+ ldouble_safe pct_left
2959
+ =
2960
+ buffer_cnt[buffer_indices[curr]]
2961
+ /
2962
+ (
2963
+ buffer_cnt[buffer_indices[curr]]
2964
+ +
2965
+ buffer_cnt[buffer_indices[curr+1]]
2966
+ );
2967
+
2968
+ return (buffer_cnt[buffer_indices[curr]] * (pct_left * 2.)
2969
+ +
2970
+ buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2971
+ /
2972
+ (buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2973
+ }
2974
+
2975
+ ldouble_safe ntot = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
2976
+ if (unlikely(ntot <= 0)) unexpected_error();
2977
+
2978
+ switch (cat_split_type)
2979
+ {
2980
+ case SingleCateg:
2981
+ {
2982
+ double pct_one_cat = 1. / (double)ncat_present;
2983
+ double pct_left_smallest = buffer_cnt[buffer_indices[curr]] / ntot;
2984
+ double gain_smallest
2985
+ =
2986
+ buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2987
+ +
2988
+ (ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2989
+ ;
2990
+
2991
+ double pct_left_biggest = buffer_cnt[buffer_indices[ncat-1]] / ntot;
2992
+ double gain_biggest
2993
+ =
2994
+ buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2995
+ +
2996
+ (ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
2997
+ ;
2998
+
2999
+ if (gain_smallest >= gain_biggest)
3000
+ {
3001
+ chosen_cat = buffer_indices[curr];
3002
+ return gain_smallest / ntot;
3003
+ }
3004
+
3005
+ else
3006
+ {
3007
+ chosen_cat = buffer_indices[ncat-1];
3008
+ return gain_biggest / ntot;
3009
+ }
3010
+ break;
3011
+ }
3012
+
3013
+ case SubSet:
3014
+ {
3015
+ ldouble_safe cnt_left = 0;
3016
+ ldouble_safe cnt_right;
3017
+ int st_cat = curr - 1;
3018
+ double this_gain;
3019
+ double best_gain = -HUGE_VAL;
3020
+ int best_cat = 0;
3021
+ ldouble_safe pct_left;
3022
+ double pct_cat_left;
3023
+ double ncat_present_ = (double)ncat_present;
3024
+ for (; curr < ncat; curr++)
3025
+ {
3026
+ cnt_left += buffer_cnt[buffer_indices[curr]];
3027
+ cnt_right = ntot - cnt_left;
3028
+ pct_left = cnt_left / ntot;
3029
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
3030
+ this_gain
3031
+ =
3032
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
3033
+ +
3034
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
3035
+ ;
3036
+ if (this_gain > best_gain)
3037
+ {
3038
+ best_gain = this_gain;
3039
+ best_cat = curr;
3040
+ }
3041
+ }
3042
+
3043
+ if (best_gain <= -HUGE_VAL) return best_gain;
3044
+ st_cat++;
3045
+ for (; st_cat <= best_cat; st_cat++)
3046
+ split_categ[buffer_indices[st_cat]] = 1;
3047
+ for (; st_cat < ncat; st_cat++)
3048
+ split_categ[buffer_indices[st_cat]] = 0;
3049
+ return best_gain / ntot;
3050
+ break;
3051
+ }
3052
+ }
3053
+
3054
+ /* This will not be reached, but CRAN might complain otherwise */
3055
+ return -HUGE_VAL;
3056
+ }
3057
+
3058
+ /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
3059
+ template <class ldouble_safe>
3060
+ double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion,
3061
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3062
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3063
+ size_t *restrict ix_arr_plus_st,
3064
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3065
+ double *restrict X_row_major, size_t ncols,
3066
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3067
+ {
3068
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3069
+ all numbers are assumed to be small and in the same scale */
3070
+ double gain = 0;
3071
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3072
+
3073
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3074
+ if (unlikely(n == 2))
3075
+ {
3076
+ if (x[0] == x[1]) return -HUGE_VAL;
3077
+ split_point = midpoint_with_reorder(x[0], x[1]);
3078
+ gain = 1.;
3079
+ if (gain > min_gain)
3080
+ return gain;
3081
+ else
3082
+ return 0.;
3083
+ }
3084
+
3085
+ if (criterion == FullGain)
3086
+ {
3087
+ /* TODO: these buffers should be allocated externally */
3088
+ std::vector<size_t> argsorted(n);
3089
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3090
+ std::sort(argsorted.begin(), argsorted.end(),
3091
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3092
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3093
+ std::vector<double> temp_buffer(n + mult2(ncols));
3094
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3095
+ for (size_t ix = 0; ix < n; ix++)
3096
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3097
+ size_t ignored;
3098
+ return find_split_full_gain<double, ldouble_safe>(
3099
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3100
+ cols_use, ncols_use, force_cols_use,
3101
+ X_row_major, ncols,
3102
+ Xr, Xr_ind, Xr_indptr,
3103
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3104
+ ignored, split_point,
3105
+ false);
3106
+ }
3107
+
3108
+ /* sort in ascending order */
3109
+ std::sort(x, x + n);
3110
+ xmin = x[0]; xmax = x[n-1];
3111
+ if (x[0] == x[n-1]) return -HUGE_VAL;
3112
+
3113
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3114
+ gain = find_split_rel_gain<double, ldouble_safe>(x, n, split_point);
3115
+ else if (criterion == Pooled || criterion == Averaged)
3116
+ gain = find_split_std_gain<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point);
3117
+ else if (criterion == DensityCrit)
3118
+ gain = find_split_dens_shortform<double, ldouble_safe>(x, n, split_point);
3119
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3120
+ return std::fmax(0., gain);
3121
+ }
3122
+
3123
+ template <class ldouble_safe>
3124
+ double eval_guided_crit_weighted(double *restrict x, size_t n, GainCriterion criterion,
3125
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3126
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3127
+ double *restrict w, size_t *restrict buffer_indices,
3128
+ size_t *restrict ix_arr_plus_st,
3129
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3130
+ double *restrict X_row_major, size_t ncols,
3131
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3132
+ {
3133
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3134
+ all numbers are assumed to be small and in the same scale */
3135
+ double gain = 0;
3136
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3137
+
3138
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3139
+ if (unlikely(n == 2))
3140
+ {
3141
+ if (x[0] == x[1]) return -HUGE_VAL;
3142
+ split_point = midpoint_with_reorder(x[0], x[1]);
3143
+ gain = 1.;
3144
+ if (gain > min_gain)
3145
+ return gain;
3146
+ else
3147
+ return 0.;
3148
+ }
3149
+
3150
+ /* sort in ascending order */
3151
+ std::iota(buffer_indices, buffer_indices + n, (size_t)0);
3152
+ std::sort(buffer_indices, buffer_indices + n,
3153
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3154
+ xmin = x[buffer_indices[0]]; xmax = x[buffer_indices[n-1]];
3155
+ if (xmin == xmax) return -HUGE_VAL;
3156
+
3157
+ if (criterion == Pooled || criterion == Averaged)
3158
+ gain = find_split_std_gain_weighted<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point, w, buffer_indices);
3159
+ else if (criterion == DensityCrit)
3160
+ gain = find_split_dens_shortform_weighted<double, double *restrict, ldouble_safe>(x, n, split_point, w, buffer_indices);
3161
+ else if (criterion == FullGain)
3162
+ {
3163
+ std::vector<size_t> argsorted(n);
3164
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3165
+ std::sort(argsorted.begin(), argsorted.end(),
3166
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3167
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3168
+ std::vector<double> temp_buffer(n + mult2(ncols));
3169
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3170
+ for (size_t ix = 0; ix < n; ix++)
3171
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3172
+ size_t ignored;
3173
+ gain = find_split_full_gain_weighted<double, double *restrict, ldouble_safe>(
3174
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3175
+ cols_use, ncols_use, force_cols_use,
3176
+ X_row_major, ncols,
3177
+ Xr, Xr_ind, Xr_indptr,
3178
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3179
+ ignored, split_point,
3180
+ false,
3181
+ w);
3182
+ }
3183
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3184
+ return std::fmax(0., gain);
3185
+ }
3186
+
3187
+ /* for split-criterion in single-variable splits */
3188
+ template <class real_t_, class ldouble_safe>
3189
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3190
+ double *restrict buffer_sd, bool as_relative_gain,
3191
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3192
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3193
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3194
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3195
+ double *restrict X_row_major, size_t ncols,
3196
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3197
+ {
3198
+ size_t st_orig = st;
3199
+ double gain = 0;
3200
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3201
+
3202
+ /* move NAs to the front if there's any, exclude them from calculations */
3203
+ if (missing_action != Fail)
3204
+ st = move_NAs_to_front(ix_arr, st, end, x);
3205
+
3206
+ if (unlikely(st >= end)) return -HUGE_VAL;
3207
+ else if (unlikely(st == (end-1)))
3208
+ {
3209
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3210
+ return -HUGE_VAL;
3211
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3212
+ split_ix = st;
3213
+ gain = 1.;
3214
+ if (gain > min_gain)
3215
+ return gain;
3216
+ else
3217
+ return 0.;
3218
+ }
3219
+
3220
+ /* sort in ascending order */
3221
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3222
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3223
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3224
+
3225
+ /* unlike the previous case for the extended model, the data here has not been centered,
3226
+ which could make the standard deviations have poor precision. It's nevertheless not
3227
+ necessary for this mean to have good precision, since it's only meant for centering,
3228
+ so it can be calculated inexactly with simd instructions. */
3229
+ real_t_ xmean = 0;
3230
+ if (criterion == Pooled || criterion == Averaged)
3231
+ {
3232
+ for (size_t ix = st; ix <= end; ix++)
3233
+ xmean += x[ix_arr[ix]];
3234
+ xmean /= (real_t_)(end - st + 1);
3235
+ }
3236
+
3237
+ if (missing_action == Impute && st > st_orig)
3238
+ {
3239
+ missing_action = Fail;
3240
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3241
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3242
+ gain = find_split_rel_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix);
3243
+ else if (criterion == Pooled || criterion == Averaged)
3244
+ gain = find_split_std_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3245
+ else if (criterion == DensityCrit)
3246
+ gain = find_split_dens<double, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix);
3247
+ else if (criterion == FullGain)
3248
+ {
3249
+ /* TODO: this buffer should be allocated from outside */
3250
+ std::vector<double> temp_buffer(mult2(ncols));
3251
+ gain = find_split_full_gain<double, ldouble_safe>(
3252
+ buffer_imputed_x, st_orig, end, ix_arr,
3253
+ cols_use, ncols_use, force_cols_use,
3254
+ X_row_major, ncols,
3255
+ Xr, Xr_ind, Xr_indptr,
3256
+ temp_buffer.data(), temp_buffer.data() + ncols,
3257
+ split_ix, split_point, true);
3258
+ }
3259
+
3260
+ /* Note: in theory, it should be possible to use a faster version assuming a contiguous array for 'x',
3261
+ but such an approach might give inexact split points. Better to avoid such inexactness at the
3262
+ expense of more computations. */
3263
+ }
3264
+
3265
+ else
3266
+ {
3267
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3268
+ gain = find_split_rel_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix);
3269
+ else if (criterion == Pooled || criterion == Averaged)
3270
+ gain = find_split_std_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3271
+ else if (criterion == DensityCrit)
3272
+ gain = find_split_dens<real_t_, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
3273
+ else if (criterion == FullGain)
3274
+ {
3275
+ /* TODO: this buffer should be allocated from outside */
3276
+ std::vector<double> temp_buffer(mult2(ncols));
3277
+ gain = find_split_full_gain<real_t_, ldouble_safe>(
3278
+ x, st, end, ix_arr,
3279
+ cols_use, ncols_use, force_cols_use,
3280
+ X_row_major, ncols,
3281
+ Xr, Xr_ind, Xr_indptr,
3282
+ temp_buffer.data(), temp_buffer.data() + ncols,
3283
+ split_ix, split_point, true);
3284
+ }
3285
+ }
3286
+
3287
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3288
+ return std::fmax(0., gain);
3289
+ }
3290
+
3291
+ template <class real_t_, class mapping, class ldouble_safe>
3292
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3293
+ double *restrict buffer_sd, bool as_relative_gain,
3294
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3295
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3296
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3297
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3298
+ double *restrict X_row_major, size_t ncols,
3299
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3300
+ mapping &restrict w)
3301
+ {
3302
+ size_t st_orig = st;
3303
+ double gain = 0;
3304
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3305
+
3306
+ /* move NAs to the front if there's any, exclude them from calculations */
3307
+ if (missing_action != Fail)
3308
+ st = move_NAs_to_front(ix_arr, st, end, x);
3309
+
3310
+ if (unlikely(st >= end)) return -HUGE_VAL;
3311
+ else if (unlikely(st == (end-1)))
3312
+ {
3313
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3314
+ return -HUGE_VAL;
3315
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3316
+ split_ix = st;
3317
+ gain = 1.;
3318
+ if (gain > min_gain)
3319
+ return gain;
3320
+ else
3321
+ return 0.;
3322
+ }
3323
+
3324
+ /* sort in ascending order */
3325
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3326
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3327
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3328
+
3329
+ /* unlike the previous case for the extended model, the data here has not been centered,
3330
+ which could make the standard deviations have poor precision. It's nevertheless not
3331
+ necessary for this mean to have good precision, since it's only meant for centering,
3332
+ so it can be calculated inexactly with simd instructions. */
3333
+ real_t_ xmean = 0;
3334
+ real_t_ cnt = 0;
3335
+ if (criterion == Pooled || criterion == Averaged)
3336
+ {
3337
+ for (size_t ix = st; ix <= end; ix++)
3338
+ {
3339
+ xmean += x[ix_arr[ix]];
3340
+ cnt += w[ix_arr[ix]];
3341
+ }
3342
+ xmean /= cnt;
3343
+ }
3344
+
3345
+ if (missing_action == Impute && st > st_orig)
3346
+ {
3347
+ missing_action = Fail;
3348
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3349
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3350
+ gain = find_split_rel_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix, w);
3351
+ else if (criterion == Pooled || criterion == Averaged)
3352
+ gain = find_split_std_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3353
+ else if (criterion == DensityCrit)
3354
+ gain = find_split_dens_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix, w);
3355
+ else if (criterion == FullGain)
3356
+ {
3357
+ std::vector<double> temp_buffer(mult2(ncols));
3358
+ gain = find_split_full_gain_weighted<double, mapping, ldouble_safe>(
3359
+ buffer_imputed_x, st_orig, end, ix_arr,
3360
+ cols_use, ncols_use, force_cols_use,
3361
+ X_row_major, ncols,
3362
+ Xr, Xr_ind, Xr_indptr,
3363
+ temp_buffer.data(), temp_buffer.data() + ncols,
3364
+ split_ix, split_point, true,
3365
+ w);
3366
+ }
3367
+ }
3368
+
3369
+ else
3370
+ {
3371
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3372
+ gain = find_split_rel_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
3373
+ else if (criterion == Pooled || criterion == Averaged)
3374
+ gain = find_split_std_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3375
+ else if (criterion == DensityCrit)
3376
+ gain = find_split_dens_weighted<real_t_, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
3377
+ else if (criterion == FullGain)
3378
+ {
3379
+ std::vector<double> temp_buffer(mult2(ncols));
3380
+ gain = find_split_full_gain_weighted<real_t_, mapping, ldouble_safe>(
3381
+ x, st, end, ix_arr,
3382
+ cols_use, ncols_use, force_cols_use,
3383
+ X_row_major, ncols,
3384
+ Xr, Xr_ind, Xr_indptr,
3385
+ temp_buffer.data(), temp_buffer.data() + ncols,
3386
+ split_ix, split_point, true,
3387
+ w);
3388
+ }
3389
+ }
3390
+
3391
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3392
+ return std::fmax(0., gain);
3393
+ }
3394
+
3395
+ /* TODO: here it should only need to look at the non-zero entries. It can then use the
3396
+ same algorithm as above, but with an extra check to see where the zeros fit in
3397
+ the sorted order of the non-zero entries while calculating gains and SDs, and then
3398
+ call the 'divide_subset' function after-the-fact to reach the same end result.
3399
+ It should be much faster than this if the non-zero entries are few. */
3400
+ template <class real_t_, class sparse_ix, class ldouble_safe>
3401
+ double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
3402
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3403
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3404
+ double *restrict saved_xmedian,
3405
+ double &split_point, double &xmin, double &xmax,
3406
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3407
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3408
+ double *restrict X_row_major, size_t ncols,
3409
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3410
+ {
3411
+ size_t ignored;
3412
+
3413
+
3414
+ todense(ix_arr, st, end,
3415
+ col_num, Xc, Xc_ind, Xc_indptr,
3416
+ buffer_arr);
3417
+ size_t tot = end - st + 1;
3418
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3419
+
3420
+ if (missing_action == Impute)
3421
+ {
3422
+ missing_action = Fail;
3423
+ for (size_t ix = 0; ix < tot; ix++)
3424
+ {
3425
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3426
+ {
3427
+ goto fill_missing;
3428
+ }
3429
+ }
3430
+ goto no_nas;
3431
+
3432
+ fill_missing:
3433
+ {
3434
+ size_t idx_half = div2(tot);
3435
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3436
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3437
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3438
+
3439
+ if ((tot % 2) == 0)
3440
+ {
3441
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3442
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3443
+ }
3444
+
3445
+ for (size_t ix = 0; ix < tot; ix++)
3446
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3447
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3448
+ }
3449
+ }
3450
+
3451
+ no_nas:
3452
+ return eval_guided_crit<double, ldouble_safe>(
3453
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3454
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3455
+ xmin, xmax, criterion, min_gain, missing_action,
3456
+ cols_use, ncols_use, force_cols_use,
3457
+ X_row_major, ncols,
3458
+ Xr, Xr_ind, Xr_indptr);
3459
+ }
3460
+
3461
+ template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
3462
+ double eval_guided_crit_weighted(size_t ix_arr[], size_t st, size_t end,
3463
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3464
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3465
+ double *restrict saved_xmedian,
3466
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3467
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3468
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3469
+ double *restrict X_row_major, size_t ncols,
3470
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3471
+ mapping &restrict w)
3472
+ {
3473
+ size_t ignored;
3474
+
3475
+
3476
+ todense(ix_arr, st, end,
3477
+ col_num, Xc, Xc_ind, Xc_indptr,
3478
+ buffer_arr);
3479
+ size_t tot = end - st + 1;
3480
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3481
+
3482
+
3483
+ if (missing_action == Impute)
3484
+ {
3485
+ missing_action = Fail;
3486
+ for (size_t ix = 0; ix < tot; ix++)
3487
+ {
3488
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3489
+ {
3490
+ goto fill_missing;
3491
+ }
3492
+ }
3493
+ goto no_nas;
3494
+
3495
+ fill_missing:
3496
+ {
3497
+ size_t idx_half = div2(tot);
3498
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3499
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3500
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3501
+
3502
+ if ((tot % 2) == 0)
3503
+ {
3504
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3505
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3506
+ }
3507
+
3508
+ for (size_t ix = 0; ix < tot; ix++)
3509
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3510
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3511
+ }
3512
+ }
3513
+
3514
+
3515
+ no_nas:
3516
+ /* TODO: allocate this buffer externally */
3517
+ std::vector<double> buffer_w(tot);
3518
+ for (size_t row = st; row <= end; row++)
3519
+ buffer_w[row-st] = w[ix_arr[row]];
3520
+ /* TODO: in this case, as the weights match with the order of the indices, could use a faster version
3521
+ with a weighted rel_gain function instead (not yet implemented). */
3522
+ return eval_guided_crit_weighted<double, std::vector<double>, ldouble_safe>(
3523
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3524
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3525
+ xmin, xmax, criterion, min_gain, missing_action,
3526
+ cols_use, ncols_use, force_cols_use,
3527
+ X_row_major, ncols,
3528
+ Xr, Xr_ind, Xr_indptr,
3529
+ buffer_w);
3530
+ }
3531
+
3532
+ /* How this works:
3533
+ - For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
3534
+ if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
3535
+ numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
3536
+ branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
3537
+ splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
3538
+ splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
3539
+ smallest or the largest category to assign to one branch, but sometimes picks groups too.
3540
+ - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
3541
+ by a single category, it always puts the largest category in a separate branch. In the case of subsets,
3542
+ it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
3543
+ or look for splits in sorted order, just like for the Averaged criterion.
3544
+ Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
3545
+ while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
3546
+ Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
3547
+ */
3548
+ /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
3549
+ /* TODO: 'buffer_pos' doesn't need to be 'size_t', 'int' would suffice */
3550
+ template <class ldouble_safe>
3551
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3552
+ int *restrict saved_cat_mode,
3553
+ size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
3554
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3555
+ GainCriterion criterion, double min_gain, bool all_perm,
3556
+ MissingAction missing_action, CategSplit cat_split_type)
3557
+ {
3558
+ if (criterion == DensityCrit)
3559
+ return find_split_dens_longform<size_t, ldouble_safe>(
3560
+ x, ncat, ix_arr, st, end,
3561
+ cat_split_type, missing_action,
3562
+ chosen_cat, split_categ, saved_cat_mode,
3563
+ buffer_cnt, buffer_pos);
3564
+ if (st >= end) return -HUGE_VAL;
3565
+ size_t n_nas = 0;
3566
+ int xval;
3567
+
3568
+ /* count categories */
3569
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
3570
+ if (missing_action == Fail)
3571
+ {
3572
+ for (size_t row = st; row <= end; row++)
3573
+ if (likely(x[ix_arr[row]] >= 0))
3574
+ buffer_cnt[x[ix_arr[row]]]++;
3575
+ }
3576
+
3577
+ else if (missing_action == Impute)
3578
+ {
3579
+ for (size_t row = st; row <= end; row++)
3580
+ {
3581
+ xval = x[ix_arr[row]];
3582
+ if (unlikely(xval < 0))
3583
+ n_nas++;
3584
+ else
3585
+ buffer_cnt[xval]++;
3586
+ }
3587
+
3588
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
3589
+
3590
+ if (n_nas)
3591
+ {
3592
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
3593
+ *idxmax += n_nas;
3594
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
3595
+ }
3596
+ }
3597
+
3598
+ else
3599
+ {
3600
+ for (size_t row = st; row <= end; row++)
3601
+ {
3602
+ xval = x[ix_arr[row]];
3603
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
3604
+ }
3605
+ }
3606
+
3607
+ double this_gain = -HUGE_VAL;
3608
+ double best_gain = -HUGE_VAL;
3609
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3610
+ size_t st_pos = 0;
3611
+
3612
+ switch (cat_split_type)
3613
+ {
3614
+ case SingleCateg:
3615
+ {
3616
+ size_t cnt = end - st + 1;
3617
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
3618
+ size_t ncat_present = 0;
3619
+
3620
+ switch(criterion)
3621
+ {
3622
+ case Averaged:
3623
+ {
3624
+ /* move zero-counts to the beginning */
3625
+ size_t temp;
3626
+ for (int cat = 0; cat < ncat; cat++)
3627
+ {
3628
+ if (buffer_cnt[cat])
3629
+ {
3630
+ ncat_present++;
3631
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
3632
+ }
3633
+
3634
+ else
3635
+ {
3636
+ temp = buffer_pos[st_pos];
3637
+ buffer_pos[st_pos] = buffer_pos[cat];
3638
+ buffer_pos[cat] = temp;
3639
+ st_pos++;
3640
+ }
3641
+ }
3642
+
3643
+ if (ncat_present <= 1) return -HUGE_VAL;
3644
+
3645
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3646
+
3647
+ /* try isolating each category one at a time */
3648
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3649
+ {
3650
+ this_gain = sd_gain(sd_full,
3651
+ 0.0,
3652
+ (expected_sd_cat_single<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)));
3653
+ if (this_gain > min_gain && this_gain > best_gain)
3654
+ {
3655
+ best_gain = this_gain;
3656
+ chosen_cat = buffer_pos[pos];
3657
+ }
3658
+ }
3659
+ break;
3660
+ }
3661
+
3662
+ case Pooled:
3663
+ {
3664
+ /* here it will always pick the largest one */
3665
+ size_t ncat_present = 0;
3666
+ size_t cnt_max = 0;
3667
+ for (int cat = 0; cat < ncat; cat++)
3668
+ {
3669
+ if (buffer_cnt[cat])
3670
+ {
3671
+ ncat_present++;
3672
+ if (cnt_max < buffer_cnt[cat])
3673
+ {
3674
+ cnt_max = buffer_cnt[cat];
3675
+ chosen_cat = cat;
3676
+ }
3677
+ }
3678
+ }
3679
+
3680
+ if (ncat_present <= 1) return -HUGE_VAL;
3681
+
3682
+ ldouble_safe cnt_left = (ldouble_safe)((end - st + 1) - cnt_max);
3683
+ this_gain = (
3684
+ (ldouble_safe)cnt * std::log((ldouble_safe)cnt)
3685
+ - cnt_left * std::log(cnt_left)
3686
+ - (ldouble_safe)cnt_max * std::log((ldouble_safe)cnt_max)
3687
+ ) / cnt;
3688
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
3689
+ break;
3690
+ }
3691
+
3692
+ default:
3693
+ {
3694
+ unexpected_error();
3695
+ break;
3696
+ }
3697
+ }
3698
+ break;
3699
+ }
3700
+
3701
+ case SubSet:
3702
+ {
3703
+ /* sort by counts */
3704
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
3705
+
3706
+ /* set split as: (1):left (0):right (-1):not_present */
3707
+ memset(buffer_split, 0, ncat * sizeof(signed char));
3708
+
3709
+ ldouble_safe cnt = (ldouble_safe)(end - st + 1);
3710
+
3711
+ switch(criterion)
3712
+ {
3713
+ case Averaged:
3714
+ {
3715
+ /* determine first non-zero and convert to probabilities */
3716
+ double sd_full;
3717
+ for (int cat = 0; cat < ncat; cat++)
3718
+ {
3719
+ if (buffer_cnt[buffer_pos[cat]])
3720
+ {
3721
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
3722
+ }
3723
+
3724
+ else
3725
+ {
3726
+ buffer_split[buffer_pos[cat]] = -1;
3727
+ st_pos++;
3728
+ }
3729
+ }
3730
+
3731
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3732
+
3733
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
3734
+ size_t ncat_present = (size_t)ncat - st_pos;
3735
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3736
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
3737
+
3738
+ /* move categories one at a time */
3739
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
3740
+ {
3741
+ buffer_split[buffer_pos[pos]] = 1;
3742
+ this_gain = sd_gain(sd_full,
3743
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
3744
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
3745
+ );
3746
+ if (this_gain > min_gain && this_gain > best_gain)
3747
+ {
3748
+ best_gain = this_gain;
3749
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3750
+ }
3751
+ }
3752
+
3753
+ break;
3754
+ }
3755
+
3756
+ case Pooled:
3757
+ {
3758
+ ldouble_safe s = 0;
3759
+
3760
+ /* determine first non-zero and get base info */
3761
+ for (int cat = 0; cat < ncat; cat++)
3762
+ {
3763
+ if (buffer_cnt[buffer_pos[cat]])
3764
+ {
3765
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
3766
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
3767
+ }
3768
+
3769
+ else
3770
+ {
3771
+ buffer_split[buffer_pos[cat]] = -1;
3772
+ st_pos++;
3773
+ }
3774
+ }
3775
+
3776
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3777
+
3778
+ /* calculate base info */
3779
+ ldouble_safe base_info = cnt * std::log(cnt) - s;
3780
+
3781
+ if (all_perm)
3782
+ {
3783
+ size_t cnt_left, cnt_right;
3784
+ double s_left, s_right;
3785
+ size_t ncat_present = (size_t)ncat - st_pos;
3786
+ size_t ncomb = pow2(ncat_present) - 1;
3787
+ size_t best_combin;
3788
+
3789
+ for (size_t combin = 1; combin < ncomb; combin++)
3790
+ {
3791
+ cnt_left = 0; cnt_right = 0;
3792
+ s_left = 0; s_right = 0;
3793
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3794
+ {
3795
+ if (extract_bit(combin, pos))
3796
+ {
3797
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3798
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3799
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3800
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3801
+ }
3802
+
3803
+ else
3804
+ {
3805
+ cnt_right += buffer_cnt[buffer_pos[pos]];
3806
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
3807
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3808
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3809
+ }
3810
+ }
3811
+
3812
+ this_gain = categ_gain<size_t, ldouble_safe>(
3813
+ cnt_left, cnt_right,
3814
+ s_left, s_right,
3815
+ base_info, cnt);
3816
+
3817
+ if (this_gain > min_gain && this_gain > best_gain)
3818
+ {
3819
+ best_gain = this_gain;
3820
+ best_combin = combin;
3821
+ }
3822
+
3823
+ }
3824
+
3825
+ if (best_gain > min_gain)
3826
+ for (size_t pos = 0; pos < ncat_present; pos++)
3827
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
3828
+
3829
+ }
3830
+
3831
+ else
3832
+ {
3833
+ /* try moving the categories one at a time */
3834
+ size_t cnt_left = 0;
3835
+ size_t cnt_right = end - st + 1;
3836
+ double s_left = 0;
3837
+ double s_right = s;
3838
+
3839
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
3840
+ {
3841
+ buffer_split[buffer_pos[pos]] = 1;
3842
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3843
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3844
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
3845
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3846
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3847
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
3848
+
3849
+ this_gain = categ_gain<size_t, ldouble_safe>(
3850
+ cnt_left, cnt_right,
3851
+ s_left, s_right,
3852
+ base_info, cnt);
3853
+
3854
+ if (this_gain > min_gain && this_gain > best_gain)
3855
+ {
3856
+ best_gain = this_gain;
3857
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3858
+ }
3859
+ }
3860
+ }
3861
+
3862
+ break;
3863
+ }
3864
+
3865
+ default:
3866
+ {
3867
+ unexpected_error();
3868
+ break;
3869
+ }
3870
+ }
3871
+ }
3872
+ }
3873
+
3874
+ if (st == (end-1)) return 0;
3875
+
3876
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
3877
+ return 0;
3878
+ else
3879
+ return best_gain;
3880
+ }
3881
+
3882
+
3883
+ template <class mapping, class ldouble_safe>
3884
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3885
+ int *restrict saved_cat_mode,
3886
+ size_t *restrict buffer_pos, double *restrict buffer_prob,
3887
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3888
+ GainCriterion criterion, double min_gain, bool all_perm,
3889
+ MissingAction missing_action, CategSplit cat_split_type,
3890
+ mapping &restrict w)
3891
+ {
3892
+ if (criterion == DensityCrit)
3893
+ return find_split_dens_longform_weighted<mapping, size_t, ldouble_safe>(
3894
+ x, ncat, ix_arr, st, end,
3895
+ cat_split_type, missing_action,
3896
+ chosen_cat, split_categ, saved_cat_mode,
3897
+ buffer_pos, w);
3898
+ if (st >= end) return -HUGE_VAL;
3899
+ ldouble_safe w_missing = 0;
3900
+ int xval;
3901
+ size_t ix_;
3902
+
3903
+ /* count categories */
3904
+ /* TODO: allocate this buffer externally */
3905
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
3906
+ if (missing_action == Fail)
3907
+ {
3908
+ for (size_t row = st; row <= end; row++)
3909
+ {
3910
+ ix_ = ix_arr[row];
3911
+ if (unlikely(x[ix_]) < 0) continue;
3912
+ buffer_cnt[x[ix_]] += w[ix_];
3913
+ }
3914
+ }
3915
+
3916
+ else if (missing_action == Impute)
3917
+ {
3918
+ for (size_t row = st; row <= end; row++)
3919
+ {
3920
+ ix_ = ix_arr[row];
3921
+ xval = x[ix_];
3922
+ if (unlikely(xval < 0))
3923
+ w_missing += w[ix_];
3924
+ else
3925
+ buffer_cnt[xval] += w[ix_];
3926
+ }
3927
+
3928
+ if (w_missing)
3929
+ {
3930
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
3931
+ *idxmax += w_missing;
3932
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
3933
+ }
3934
+ }
3935
+
3936
+ else
3937
+ {
3938
+ for (size_t row = st; row <= end; row++)
3939
+ {
3940
+ ix_ = ix_arr[row];
3941
+ xval = x[ix_];
3942
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
3943
+ }
3944
+ }
3945
+
3946
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
3947
+
3948
+ double this_gain = -HUGE_VAL;
3949
+ double best_gain = -HUGE_VAL;
3950
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3951
+ size_t st_pos = 0;
3952
+
3953
+ switch(cat_split_type)
3954
+ {
3955
+ case SingleCateg:
3956
+ {
3957
+ size_t ncat_present = 0;
3958
+
3959
+ switch(criterion)
3960
+ {
3961
+ case Averaged:
3962
+ {
3963
+ /* move zero-counts to the beginning */
3964
+ size_t temp;
3965
+ for (int cat = 0; cat < ncat; cat++)
3966
+ {
3967
+ if (buffer_cnt[cat])
3968
+ {
3969
+ ncat_present++;
3970
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
3971
+ }
3972
+
3973
+ else
3974
+ {
3975
+ temp = buffer_pos[st_pos];
3976
+ buffer_pos[st_pos] = buffer_pos[cat];
3977
+ buffer_pos[cat] = temp;
3978
+ st_pos++;
3979
+ }
3980
+ }
3981
+
3982
+ if (ncat_present <= 1) return -HUGE_VAL;
3983
+
3984
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3985
+
3986
+ /* try isolating each category one at a time */
3987
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3988
+ {
3989
+ this_gain = sd_gain(sd_full,
3990
+ 0.0,
3991
+ (expected_sd_cat_single<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt))
3992
+ );
3993
+ if (this_gain > min_gain && this_gain > best_gain)
3994
+ {
3995
+ best_gain = this_gain;
3996
+ chosen_cat = buffer_pos[pos];
3997
+ }
3998
+ }
3999
+ break;
4000
+ }
4001
+
4002
+ case Pooled:
4003
+ {
4004
+ /* here it will always pick the largest one */
4005
+ size_t ncat_present = 0;
4006
+ ldouble_safe cnt_max = 0;
4007
+ for (int cat = 0; cat < ncat; cat++)
4008
+ {
4009
+ if (buffer_cnt[cat])
4010
+ {
4011
+ ncat_present++;
4012
+ if (cnt_max < buffer_cnt[cat])
4013
+ {
4014
+ cnt_max = buffer_cnt[cat];
4015
+ chosen_cat = cat;
4016
+ }
4017
+ }
4018
+ }
4019
+
4020
+ if (ncat_present <= 1) return -HUGE_VAL;
4021
+
4022
+ ldouble_safe cnt_left = (ldouble_safe)(cnt - cnt_max);
4023
+
4024
+ /* TODO: think of a better way of dealing with numbers between zero and one */
4025
+ this_gain = (
4026
+ std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt))
4027
+ - std::fmax((ldouble_safe)1, cnt_left) * std::log(std::fmax((ldouble_safe)1, cnt_left))
4028
+ - std::fmax((ldouble_safe)1, cnt_max) * std::log(std::fmax((ldouble_safe)1, cnt_max))
4029
+ ) / std::fmax((ldouble_safe)1, cnt);
4030
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
4031
+ break;
4032
+ }
4033
+
4034
+ default:
4035
+ {
4036
+ unexpected_error();
4037
+ break;
4038
+ }
4039
+ }
4040
+ break;
4041
+ }
4042
+
4043
+ case SubSet:
4044
+ {
4045
+ /* sort by counts */
4046
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
4047
+
4048
+ /* set split as: (1):left (0):right (-1):not_present */
4049
+ memset(buffer_split, 0, ncat * sizeof(signed char));
4050
+
4051
+
4052
+ switch(criterion)
4053
+ {
4054
+ case Averaged:
4055
+ {
4056
+ /* determine first non-zero and convert to probabilities */
4057
+ double sd_full;
4058
+ for (int cat = 0; cat < ncat; cat++)
4059
+ {
4060
+ if (buffer_cnt[buffer_pos[cat]])
4061
+ {
4062
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
4063
+ }
4064
+
4065
+ else
4066
+ {
4067
+ buffer_split[buffer_pos[cat]] = -1;
4068
+ st_pos++;
4069
+ }
4070
+ }
4071
+
4072
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4073
+
4074
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
4075
+ size_t ncat_present = (size_t)ncat - st_pos;
4076
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
4077
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
4078
+
4079
+ /* move categories one at a time */
4080
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
4081
+ {
4082
+ buffer_split[buffer_pos[pos]] = 1;
4083
+ /* TODO: is this correct? */
4084
+ this_gain = sd_gain(sd_full,
4085
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
4086
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
4087
+ );
4088
+ if (this_gain > min_gain && this_gain > best_gain)
4089
+ {
4090
+ best_gain = this_gain;
4091
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4092
+ }
4093
+ }
4094
+
4095
+ break;
4096
+ }
4097
+
4098
+ case Pooled:
4099
+ {
4100
+ ldouble_safe s = 0;
4101
+
4102
+ /* determine first non-zero and get base info */
4103
+ for (int cat = 0; cat < ncat; cat++)
4104
+ {
4105
+ if (buffer_cnt[buffer_pos[cat]])
4106
+ {
4107
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
4108
+ (ldouble_safe)0
4109
+ :
4110
+ ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
4111
+ }
4112
+
4113
+ else
4114
+ {
4115
+ buffer_split[buffer_pos[cat]] = -1;
4116
+ st_pos++;
4117
+ }
4118
+ }
4119
+
4120
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4121
+
4122
+ /* calculate base info */
4123
+ ldouble_safe base_info = std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt)) - s;
4124
+
4125
+ if (all_perm)
4126
+ {
4127
+ size_t cnt_left, cnt_right;
4128
+ double s_left, s_right;
4129
+ size_t ncat_present = (size_t)ncat - st_pos;
4130
+ size_t ncomb = pow2(ncat_present) - 1;
4131
+ size_t best_combin;
4132
+
4133
+ for (size_t combin = 1; combin < ncomb; combin++)
4134
+ {
4135
+ cnt_left = 0; cnt_right = 0;
4136
+ s_left = 0; s_right = 0;
4137
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
4138
+ {
4139
+ if (extract_bit(combin, pos))
4140
+ {
4141
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4142
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4143
+ (ldouble_safe)0
4144
+ :
4145
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4146
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4147
+ }
4148
+
4149
+ else
4150
+ {
4151
+ cnt_right += buffer_cnt[buffer_pos[pos]];
4152
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
4153
+ (ldouble_safe)0
4154
+ :
4155
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4156
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4157
+ }
4158
+ }
4159
+
4160
+ this_gain = categ_gain<size_t, ldouble_safe>(
4161
+ cnt_left, cnt_right,
4162
+ s_left, s_right,
4163
+ base_info, cnt);
4164
+
4165
+ if (this_gain > min_gain && this_gain > best_gain)
4166
+ {
4167
+ best_gain = this_gain;
4168
+ best_combin = combin;
4169
+ }
4170
+
4171
+ }
4172
+
4173
+ if (best_gain > min_gain)
4174
+ for (size_t pos = 0; pos < ncat_present; pos++)
4175
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
4176
+
4177
+ }
4178
+
4179
+ else
4180
+ {
4181
+ /* try moving the categories one at a time */
4182
+ size_t cnt_left = 0;
4183
+ size_t cnt_right = end - st + 1;
4184
+ double s_left = 0;
4185
+ double s_right = s;
4186
+
4187
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
4188
+ {
4189
+ buffer_split[buffer_pos[pos]] = 1;
4190
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4191
+ (ldouble_safe)0
4192
+ :
4193
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4194
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
4195
+ (ldouble_safe)0
4196
+ :
4197
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4198
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4199
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
4200
+
4201
+ this_gain = categ_gain<size_t, ldouble_safe>(
4202
+ cnt_left, cnt_right,
4203
+ s_left, s_right,
4204
+ base_info, cnt);
4205
+
4206
+ if (this_gain > min_gain && this_gain > best_gain)
4207
+ {
4208
+ best_gain = this_gain;
4209
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4210
+ }
4211
+ }
4212
+ }
4213
+
4214
+ break;
4215
+ }
4216
+
4217
+ default:
4218
+ {
4219
+ unexpected_error();
4220
+ break;
4221
+ }
4222
+ }
4223
+ }
4224
+ }
4225
+
4226
+ if (st == (end-1)) return 0;
4227
+
4228
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
4229
+ return 0;
4230
+ else
4231
+ return best_gain;
4232
+ }