isotree 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -1
- data/LICENSE.txt +2 -2
- data/README.md +32 -14
- data/ext/isotree/ext.cpp +144 -31
- data/ext/isotree/extconf.rb +7 -7
- data/lib/isotree/isolation_forest.rb +110 -30
- data/lib/isotree/version.rb +1 -1
- data/vendor/isotree/LICENSE +1 -1
- data/vendor/isotree/README.md +165 -27
- data/vendor/isotree/include/isotree.hpp +2111 -0
- data/vendor/isotree/include/isotree_oop.hpp +394 -0
- data/vendor/isotree/inst/COPYRIGHTS +62 -0
- data/vendor/isotree/src/RcppExports.cpp +525 -52
- data/vendor/isotree/src/Rwrapper.cpp +1931 -268
- data/vendor/isotree/src/c_interface.cpp +953 -0
- data/vendor/isotree/src/crit.hpp +4232 -0
- data/vendor/isotree/src/dist.hpp +1886 -0
- data/vendor/isotree/src/exp_depth_table.hpp +134 -0
- data/vendor/isotree/src/extended.hpp +1444 -0
- data/vendor/isotree/src/external_facing_generic.hpp +399 -0
- data/vendor/isotree/src/fit_model.hpp +2401 -0
- data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
- data/vendor/isotree/src/helpers_iforest.hpp +813 -0
- data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
- data/vendor/isotree/src/indexer.cpp +515 -0
- data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
- data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
- data/vendor/isotree/src/isoforest.hpp +1659 -0
- data/vendor/isotree/src/isotree.hpp +1804 -392
- data/vendor/isotree/src/isotree_exportable.hpp +99 -0
- data/vendor/isotree/src/merge_models.cpp +159 -16
- data/vendor/isotree/src/mult.hpp +1321 -0
- data/vendor/isotree/src/oop_interface.cpp +842 -0
- data/vendor/isotree/src/oop_interface.hpp +278 -0
- data/vendor/isotree/src/other_helpers.hpp +219 -0
- data/vendor/isotree/src/predict.hpp +1932 -0
- data/vendor/isotree/src/python_helpers.hpp +134 -0
- data/vendor/isotree/src/ref_indexer.hpp +154 -0
- data/vendor/isotree/src/robinmap/LICENSE +21 -0
- data/vendor/isotree/src/robinmap/README.md +483 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
- data/vendor/isotree/src/serialize.cpp +4300 -139
- data/vendor/isotree/src/sql.cpp +141 -59
- data/vendor/isotree/src/subset_models.cpp +174 -0
- data/vendor/isotree/src/utils.hpp +3808 -0
- data/vendor/isotree/src/xoshiro.hpp +467 -0
- data/vendor/isotree/src/ziggurat.hpp +405 -0
- metadata +38 -104
- data/vendor/cereal/LICENSE +0 -24
- data/vendor/cereal/README.md +0 -85
- data/vendor/cereal/include/cereal/access.hpp +0 -351
- data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
- data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
- data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
- data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
- data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
- data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
- data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
- data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
- data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
- data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
- data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
- data/vendor/cereal/include/cereal/details/util.hpp +0 -84
- data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
- data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
- data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
- data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
- data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
- data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
- data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
- data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
- data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
- data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
- data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
- data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
- data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
- data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
- data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
- data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
- data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
- data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
- data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
- data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
- data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
- data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
- data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
- data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
- data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
- data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
- data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
- data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
- data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
- data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
- data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
- data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
- data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
- data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
- data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
- data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
- data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
- data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
- data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
- data/vendor/cereal/include/cereal/macros.hpp +0 -154
- data/vendor/cereal/include/cereal/specialize.hpp +0 -139
- data/vendor/cereal/include/cereal/types/array.hpp +0 -79
- data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
- data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
- data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
- data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
- data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
- data/vendor/cereal/include/cereal/types/common.hpp +0 -129
- data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
- data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
- data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
- data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
- data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
- data/vendor/cereal/include/cereal/types/list.hpp +0 -62
- data/vendor/cereal/include/cereal/types/map.hpp +0 -36
- data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
- data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
- data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
- data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
- data/vendor/cereal/include/cereal/types/set.hpp +0 -103
- data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
- data/vendor/cereal/include/cereal/types/string.hpp +0 -61
- data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
- data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
- data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
- data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
- data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
- data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
- data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
- data/vendor/cereal/include/cereal/version.hpp +0 -52
- data/vendor/isotree/src/Makevars +0 -4
- data/vendor/isotree/src/crit.cpp +0 -912
- data/vendor/isotree/src/dist.cpp +0 -749
- data/vendor/isotree/src/extended.cpp +0 -790
- data/vendor/isotree/src/fit_model.cpp +0 -1090
- data/vendor/isotree/src/helpers_iforest.cpp +0 -324
- data/vendor/isotree/src/isoforest.cpp +0 -771
- data/vendor/isotree/src/mult.cpp +0 -607
- data/vendor/isotree/src/predict.cpp +0 -853
- data/vendor/isotree/src/utils.cpp +0 -1566
|
@@ -0,0 +1,4232 @@
|
|
|
1
|
+
/* Isolation forests and variations thereof, with adjustments for incorporation
|
|
2
|
+
* of categorical variables and missing values.
|
|
3
|
+
* Written for the C++11 standard and aimed at being used in R and Python.
|
|
4
|
+
*
|
|
5
|
+
* This library is based on the following works:
|
|
6
|
+
* [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
7
|
+
* "Isolation forest."
|
|
8
|
+
* 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
|
|
9
|
+
* [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
10
|
+
* "Isolation-based anomaly detection."
|
|
11
|
+
* ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
|
|
12
|
+
* [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
|
|
13
|
+
* "Extended Isolation Forest."
|
|
14
|
+
* arXiv preprint arXiv:1811.02141 (2018).
|
|
15
|
+
* [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
16
|
+
* "On detecting clustered anomalies using SCiForest."
|
|
17
|
+
* Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
|
|
18
|
+
* [5] https://sourceforge.net/projects/iforest/
|
|
19
|
+
* [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
|
|
20
|
+
* [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
|
|
21
|
+
* [8] Cortes, David.
|
|
22
|
+
* "Distance approximation using Isolation Forests."
|
|
23
|
+
* arXiv preprint arXiv:1910.12362 (2019).
|
|
24
|
+
* [9] Cortes, David.
|
|
25
|
+
* "Imputing missing values with unsupervised random trees."
|
|
26
|
+
* arXiv preprint arXiv:1911.06646 (2019).
|
|
27
|
+
* [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
|
|
28
|
+
* [11] Cortes, David.
|
|
29
|
+
* "Revisiting randomized choices in isolation forests."
|
|
30
|
+
* arXiv preprint arXiv:2110.13402 (2021).
|
|
31
|
+
* [12] Guha, Sudipto, et al.
|
|
32
|
+
* "Robust random cut forest based anomaly detection on streams."
|
|
33
|
+
* International conference on machine learning. PMLR, 2016.
|
|
34
|
+
* [13] Cortes, David.
|
|
35
|
+
* "Isolation forests: looking beyond tree depth."
|
|
36
|
+
* arXiv preprint arXiv:2111.11639 (2021).
|
|
37
|
+
* [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
|
|
38
|
+
* "Isolation kernel and its effect on SVM"
|
|
39
|
+
* Proceedings of the 24th ACM SIGKDD
|
|
40
|
+
* International Conference on Knowledge Discovery & Data Mining. 2018.
|
|
41
|
+
*
|
|
42
|
+
* BSD 2-Clause License
|
|
43
|
+
* Copyright (c) 2019-2022, David Cortes
|
|
44
|
+
* All rights reserved.
|
|
45
|
+
* Redistribution and use in source and binary forms, with or without
|
|
46
|
+
* modification, are permitted provided that the following conditions are met:
|
|
47
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
|
48
|
+
* list of conditions and the following disclaimer.
|
|
49
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
|
50
|
+
* this list of conditions and the following disclaimer in the documentation
|
|
51
|
+
* and/or other materials provided with the distribution.
|
|
52
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
53
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
54
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
55
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
56
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
57
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
58
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
59
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
60
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
61
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
62
|
+
*/
|
|
63
|
+
#include "isotree.hpp"
|
|
64
|
+
|
|
65
|
+
/* TODO: should the kurtosis calculation impute values when using ndim=1 + missing_action=Impute?
|
|
66
|
+
It should be the theoretically correct approach, but will cause the kurtosis to increase
|
|
67
|
+
significantly if there is a large number of missing values, which would lead to preferring
|
|
68
|
+
splitting on columns with mostly missing values. */
|
|
69
|
+
|
|
70
|
+
/* TODO: this kurtosis caps the minimum values to zero, but the minimum achievable value is 1,
|
|
71
|
+
see how imprecise results are used outside of the function in the different kinds of calculations
|
|
72
|
+
that use kurtosis and maybe change the logic. */
|
|
73
|
+
|
|
74
|
+
/* Small integer powers spelled out as plain products (much cheaper than std::pow
   for these exponents). Each higher power multiplies the previous one by (x) once
   more, so the evaluation order is ((x*x)*x)*x. These are macros: the argument is
   evaluated once per factor, so only pass side-effect-free expressions. */
#define pw1(x) ((x))
#define pw2(x) ((x) * (x))
#define pw3(x) (pw2(x) * (x))
#define pw4(x) (pw3(x) * (x))
|
|
78
|
+
|
|
79
|
+
/* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics */
|
|
80
|
+
template <class real_t, class ldouble_safe>
|
|
81
|
+
double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
|
|
82
|
+
{
|
|
83
|
+
ldouble_safe m = 0;
|
|
84
|
+
ldouble_safe M2 = 0, M3 = 0, M4 = 0;
|
|
85
|
+
ldouble_safe delta, delta_s, delta_div;
|
|
86
|
+
ldouble_safe diff, n;
|
|
87
|
+
ldouble_safe out;
|
|
88
|
+
|
|
89
|
+
if (missing_action == Fail)
|
|
90
|
+
{
|
|
91
|
+
for (size_t row = st; row <= end; row++)
|
|
92
|
+
{
|
|
93
|
+
n = (ldouble_safe)(row - st + 1);
|
|
94
|
+
|
|
95
|
+
delta = x[ix_arr[row]] - m;
|
|
96
|
+
delta_div = delta / n;
|
|
97
|
+
delta_s = delta_div * delta_div;
|
|
98
|
+
diff = delta * (delta_div * (ldouble_safe)(row - st));
|
|
99
|
+
|
|
100
|
+
m += delta_div;
|
|
101
|
+
M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
|
|
102
|
+
M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
|
|
103
|
+
M2 += diff;
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
|
|
107
|
+
{
|
|
108
|
+
if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
|
|
109
|
+
return -HUGE_VAL;
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
out = ( M4 / M2 ) * ( (ldouble_safe)(end - st + 1) / M2 );
|
|
113
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
else
|
|
117
|
+
{
|
|
118
|
+
size_t cnt = 0;
|
|
119
|
+
for (size_t row = st; row <= end; row++)
|
|
120
|
+
{
|
|
121
|
+
if (likely(!is_na_or_inf(x[ix_arr[row]])))
|
|
122
|
+
{
|
|
123
|
+
cnt++;
|
|
124
|
+
n = (ldouble_safe) cnt;
|
|
125
|
+
|
|
126
|
+
delta = x[ix_arr[row]] - m;
|
|
127
|
+
delta_div = delta / n;
|
|
128
|
+
delta_s = delta_div * delta_div;
|
|
129
|
+
diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
|
|
130
|
+
|
|
131
|
+
m += delta_div;
|
|
132
|
+
M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
|
|
133
|
+
M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
|
|
134
|
+
M2 += diff;
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
if (unlikely(cnt == 0)) return -HUGE_VAL;
|
|
139
|
+
if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
|
|
140
|
+
{
|
|
141
|
+
if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
|
|
142
|
+
return -HUGE_VAL;
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
|
|
146
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
template <class real_t, class ldouble_safe>
|
|
151
|
+
double calc_kurtosis(real_t x[], size_t n, MissingAction missing_action)
|
|
152
|
+
{
|
|
153
|
+
ldouble_safe m = 0;
|
|
154
|
+
ldouble_safe M2 = 0, M3 = 0, M4 = 0;
|
|
155
|
+
ldouble_safe delta, delta_s, delta_div;
|
|
156
|
+
ldouble_safe diff, n_;
|
|
157
|
+
ldouble_safe out;
|
|
158
|
+
|
|
159
|
+
if (missing_action == Fail)
|
|
160
|
+
{
|
|
161
|
+
for (size_t row = 0; row < n; row++)
|
|
162
|
+
{
|
|
163
|
+
n_ = (ldouble_safe)(row + 1);
|
|
164
|
+
|
|
165
|
+
delta = x[row] - m;
|
|
166
|
+
delta_div = delta / n_;
|
|
167
|
+
delta_s = delta_div * delta_div;
|
|
168
|
+
diff = delta * (delta_div * (ldouble_safe)row);
|
|
169
|
+
|
|
170
|
+
m += delta_div;
|
|
171
|
+
M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
|
|
172
|
+
M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
|
|
173
|
+
M2 += diff;
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
out = ( M4 / M2 ) * ( (ldouble_safe)n / M2 );
|
|
177
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
else
|
|
181
|
+
{
|
|
182
|
+
size_t cnt = 0;
|
|
183
|
+
for (size_t row = 0; row < n; row++)
|
|
184
|
+
{
|
|
185
|
+
if (likely(!is_na_or_inf(x[row])))
|
|
186
|
+
{
|
|
187
|
+
cnt++;
|
|
188
|
+
n_ = (ldouble_safe) cnt;
|
|
189
|
+
|
|
190
|
+
delta = x[row] - m;
|
|
191
|
+
delta_div = delta / n_;
|
|
192
|
+
delta_s = delta_div * delta_div;
|
|
193
|
+
diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
|
|
194
|
+
|
|
195
|
+
m += delta_div;
|
|
196
|
+
M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
|
|
197
|
+
M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
|
|
198
|
+
M2 += diff;
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
if (unlikely(cnt == 0)) return -HUGE_VAL;
|
|
203
|
+
|
|
204
|
+
out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
|
|
205
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
/* TODO: is this algorithm correct? */
|
|
210
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
211
|
+
double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, real_t x[],
|
|
212
|
+
MissingAction missing_action, mapping &restrict w)
|
|
213
|
+
{
|
|
214
|
+
ldouble_safe m = 0;
|
|
215
|
+
ldouble_safe M2 = 0, M3 = 0, M4 = 0;
|
|
216
|
+
ldouble_safe delta, delta_s, delta_div;
|
|
217
|
+
ldouble_safe diff;
|
|
218
|
+
ldouble_safe n = 0;
|
|
219
|
+
ldouble_safe out;
|
|
220
|
+
ldouble_safe n_prev = 0.;
|
|
221
|
+
ldouble_safe w_this;
|
|
222
|
+
|
|
223
|
+
for (size_t row = st; row <= end; row++)
|
|
224
|
+
{
|
|
225
|
+
if (likely(!is_na_or_inf(x[ix_arr[row]])))
|
|
226
|
+
{
|
|
227
|
+
w_this = w[ix_arr[row]];
|
|
228
|
+
n += w_this;
|
|
229
|
+
|
|
230
|
+
delta = x[ix_arr[row]] - m;
|
|
231
|
+
delta_div = delta / n;
|
|
232
|
+
delta_s = delta_div * delta_div;
|
|
233
|
+
diff = delta * (delta_div * n_prev);
|
|
234
|
+
n_prev = n;
|
|
235
|
+
|
|
236
|
+
m += w_this * (delta_div);
|
|
237
|
+
M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
|
|
238
|
+
M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
|
|
239
|
+
M2 += w_this * (diff);
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
if (unlikely(n <= 0)) return -HUGE_VAL;
|
|
244
|
+
if (unlikely(!is_na_or_inf(M2) && M2 <= std::numeric_limits<double>::epsilon()))
|
|
245
|
+
{
|
|
246
|
+
if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
|
|
247
|
+
return -HUGE_VAL;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
out = ( M4 / M2 ) * ( n / M2 );
|
|
251
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
template <class real_t, class ldouble_safe>
|
|
255
|
+
double calc_kurtosis_weighted(real_t *restrict x, size_t n_, MissingAction missing_action, real_t *restrict w)
|
|
256
|
+
{
|
|
257
|
+
ldouble_safe m = 0;
|
|
258
|
+
ldouble_safe M2 = 0, M3 = 0, M4 = 0;
|
|
259
|
+
ldouble_safe delta, delta_s, delta_div;
|
|
260
|
+
ldouble_safe diff;
|
|
261
|
+
ldouble_safe n = 0;
|
|
262
|
+
ldouble_safe out;
|
|
263
|
+
ldouble_safe n_prev = 0.;
|
|
264
|
+
ldouble_safe w_this;
|
|
265
|
+
|
|
266
|
+
for (size_t row = 0; row < n_; row++)
|
|
267
|
+
{
|
|
268
|
+
if (likely(!is_na_or_inf(x[row])))
|
|
269
|
+
{
|
|
270
|
+
w_this = w[row];
|
|
271
|
+
n += w_this;
|
|
272
|
+
|
|
273
|
+
delta = x[row] - m;
|
|
274
|
+
delta_div = delta / n;
|
|
275
|
+
delta_s = delta_div * delta_div;
|
|
276
|
+
diff = delta * (delta_div * n_prev);
|
|
277
|
+
n_prev = n;
|
|
278
|
+
|
|
279
|
+
m += w_this * (delta_div);
|
|
280
|
+
M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
|
|
281
|
+
M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
|
|
282
|
+
M2 += w_this * (diff);
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
if (unlikely(n <= 0)) return -HUGE_VAL;
|
|
287
|
+
|
|
288
|
+
out = ( M4 / M2 ) * ( n / M2 );
|
|
289
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
/* TODO: make these compensated sums */
|
|
294
|
+
/* TODO: can this use the same algorithm as above but with a correction at the end,
|
|
295
|
+
like it was done for the variance? */
|
|
296
|
+
template <class real_t, class sparse_ix, class ldouble_safe>
|
|
297
|
+
double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
298
|
+
real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
299
|
+
MissingAction missing_action)
|
|
300
|
+
{
|
|
301
|
+
/* ix_arr must be already sorted beforehand */
|
|
302
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
303
|
+
return -HUGE_VAL;
|
|
304
|
+
|
|
305
|
+
ldouble_safe s1 = 0;
|
|
306
|
+
ldouble_safe s2 = 0;
|
|
307
|
+
ldouble_safe s3 = 0;
|
|
308
|
+
ldouble_safe s4 = 0;
|
|
309
|
+
ldouble_safe x_sq;
|
|
310
|
+
size_t cnt = end - st + 1;
|
|
311
|
+
|
|
312
|
+
if (unlikely(cnt <= 1)) return -HUGE_VAL;
|
|
313
|
+
|
|
314
|
+
size_t st_col = Xc_indptr[col_num];
|
|
315
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
316
|
+
size_t curr_pos = st_col;
|
|
317
|
+
size_t ind_end_col = Xc_ind[end_col];
|
|
318
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
|
|
319
|
+
|
|
320
|
+
ldouble_safe xval;
|
|
321
|
+
|
|
322
|
+
if (missing_action != Fail)
|
|
323
|
+
{
|
|
324
|
+
for (size_t *row = ptr_st;
|
|
325
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
326
|
+
)
|
|
327
|
+
{
|
|
328
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
329
|
+
{
|
|
330
|
+
xval = Xc[curr_pos];
|
|
331
|
+
if (unlikely(is_na_or_inf(xval)))
|
|
332
|
+
{
|
|
333
|
+
cnt--;
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
else
|
|
337
|
+
{
|
|
338
|
+
/* TODO: is it safe to use FMA here? some calculations rely on assuming that
|
|
339
|
+
some of these 's' are larger than the others. Would this procedure be guaranteed
|
|
340
|
+
to preserve such differences if done with a mixture of sums and FMAs? */
|
|
341
|
+
x_sq = square(xval);
|
|
342
|
+
s1 += xval;
|
|
343
|
+
s2 = std::fma(xval, xval, s2);
|
|
344
|
+
s3 = std::fma(x_sq, xval, s3);
|
|
345
|
+
s4 = std::fma(x_sq, x_sq, s4);
|
|
346
|
+
// s1 += pw1(xval);
|
|
347
|
+
// s2 += pw2(xval);
|
|
348
|
+
// s3 += pw3(xval);
|
|
349
|
+
// s4 += pw4(xval);
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
353
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
else
|
|
357
|
+
{
|
|
358
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
359
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
360
|
+
else
|
|
361
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
362
|
+
}
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
if (unlikely(cnt <= (end - st + 1) - (Xc_indptr[col_num+1] - Xc_indptr[col_num]))) return -HUGE_VAL;
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
else
|
|
369
|
+
{
|
|
370
|
+
for (size_t *row = ptr_st;
|
|
371
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
372
|
+
)
|
|
373
|
+
{
|
|
374
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
375
|
+
{
|
|
376
|
+
xval = Xc[curr_pos];
|
|
377
|
+
x_sq = square(xval);
|
|
378
|
+
s1 += xval;
|
|
379
|
+
s2 = std::fma(xval, xval, s2);
|
|
380
|
+
s3 = std::fma(x_sq, xval, s3);
|
|
381
|
+
s4 = std::fma(x_sq, x_sq, s4);
|
|
382
|
+
// s1 += pw1(xval);
|
|
383
|
+
// s2 += pw2(xval);
|
|
384
|
+
// s3 += pw3(xval);
|
|
385
|
+
// s4 += pw4(xval);
|
|
386
|
+
|
|
387
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
388
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
else
|
|
392
|
+
{
|
|
393
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
394
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
395
|
+
else
|
|
396
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
397
|
+
}
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
|
|
402
|
+
ldouble_safe cnt_l = (ldouble_safe) cnt;
|
|
403
|
+
ldouble_safe sn = s1 / cnt_l;
|
|
404
|
+
ldouble_safe v = s2 / cnt_l - pw2(sn);
|
|
405
|
+
if (unlikely(std::isnan(v))) return -HUGE_VAL;
|
|
406
|
+
if (
|
|
407
|
+
v <= std::numeric_limits<double>::epsilon() &&
|
|
408
|
+
!check_more_than_two_unique_values(ix_arr, st, end, col_num,
|
|
409
|
+
Xc_indptr, Xc_ind, Xc,
|
|
410
|
+
missing_action)
|
|
411
|
+
)
|
|
412
|
+
return -HUGE_VAL;
|
|
413
|
+
if (unlikely(v <= 0)) return 0.;
|
|
414
|
+
ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
|
|
415
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
template <class real_t, class sparse_ix, class ldouble_safe>
|
|
419
|
+
double calc_kurtosis(size_t col_num, size_t nrows,
|
|
420
|
+
real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
421
|
+
MissingAction missing_action)
|
|
422
|
+
{
|
|
423
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
424
|
+
return -HUGE_VAL;
|
|
425
|
+
|
|
426
|
+
ldouble_safe s1 = 0;
|
|
427
|
+
ldouble_safe s2 = 0;
|
|
428
|
+
ldouble_safe s3 = 0;
|
|
429
|
+
ldouble_safe s4 = 0;
|
|
430
|
+
ldouble_safe x_sq;
|
|
431
|
+
size_t cnt = nrows;
|
|
432
|
+
|
|
433
|
+
if (unlikely(cnt <= 1)) return -HUGE_VAL;
|
|
434
|
+
|
|
435
|
+
ldouble_safe xval;
|
|
436
|
+
|
|
437
|
+
if (missing_action != Fail)
|
|
438
|
+
{
|
|
439
|
+
for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
|
|
440
|
+
{
|
|
441
|
+
xval = Xc[ix];
|
|
442
|
+
if (unlikely(is_na_or_inf(xval)))
|
|
443
|
+
{
|
|
444
|
+
cnt--;
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
else
|
|
448
|
+
{
|
|
449
|
+
x_sq = square(xval);
|
|
450
|
+
s1 += xval;
|
|
451
|
+
s2 = std::fma(xval, xval, s2);
|
|
452
|
+
s3 = std::fma(x_sq, xval, s3);
|
|
453
|
+
s4 = std::fma(x_sq, x_sq, s4);
|
|
454
|
+
// s1 += pw1(xval);
|
|
455
|
+
// s2 += pw2(xval);
|
|
456
|
+
// s3 += pw3(xval);
|
|
457
|
+
// s4 += pw4(xval);
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
if (cnt <= (nrows) - (Xc_indptr[col_num+1] - Xc_indptr[col_num])) return -HUGE_VAL;
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
else
|
|
465
|
+
{
|
|
466
|
+
for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
|
|
467
|
+
{
|
|
468
|
+
xval = Xc[ix];
|
|
469
|
+
x_sq = square(xval);
|
|
470
|
+
s1 += xval;
|
|
471
|
+
s2 = std::fma(xval, xval, s2);
|
|
472
|
+
s3 = std::fma(x_sq, xval, s3);
|
|
473
|
+
s4 = std::fma(x_sq, x_sq, s4);
|
|
474
|
+
// s1 += pw1(xval);
|
|
475
|
+
// s2 += pw2(xval);
|
|
476
|
+
// s3 += pw3(xval);
|
|
477
|
+
// s4 += pw4(xval);
|
|
478
|
+
}
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
|
|
482
|
+
ldouble_safe cnt_l = (ldouble_safe) cnt;
|
|
483
|
+
ldouble_safe sn = s1 / cnt_l;
|
|
484
|
+
ldouble_safe v = s2 / cnt_l - pw2(sn);
|
|
485
|
+
if (unlikely(std::isnan(v))) return -HUGE_VAL;
|
|
486
|
+
if (
|
|
487
|
+
v <= std::numeric_limits<double>::epsilon() &&
|
|
488
|
+
!check_more_than_two_unique_values(nrows, col_num,
|
|
489
|
+
Xc_indptr, Xc_ind, Xc,
|
|
490
|
+
missing_action)
|
|
491
|
+
)
|
|
492
|
+
return -HUGE_VAL;
|
|
493
|
+
if (unlikely(v <= 0)) return 0.;
|
|
494
|
+
ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
|
|
495
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
template <class real_t, class sparse_ix, class mapping, class ldouble_safe>
|
|
500
|
+
double calc_kurtosis_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
501
|
+
real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
502
|
+
MissingAction missing_action, mapping &restrict w)
|
|
503
|
+
{
|
|
504
|
+
/* ix_arr must be already sorted beforehand */
|
|
505
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
506
|
+
return -HUGE_VAL;
|
|
507
|
+
|
|
508
|
+
ldouble_safe s1 = 0;
|
|
509
|
+
ldouble_safe s2 = 0;
|
|
510
|
+
ldouble_safe s3 = 0;
|
|
511
|
+
ldouble_safe s4 = 0;
|
|
512
|
+
ldouble_safe x_sq;
|
|
513
|
+
ldouble_safe w_this;
|
|
514
|
+
ldouble_safe cnt = 0;
|
|
515
|
+
for (size_t row = st; row <= end; row++)
|
|
516
|
+
cnt += w[ix_arr[row]];
|
|
517
|
+
|
|
518
|
+
if (unlikely(cnt <= 0)) return -HUGE_VAL;
|
|
519
|
+
|
|
520
|
+
size_t st_col = Xc_indptr[col_num];
|
|
521
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
522
|
+
size_t curr_pos = st_col;
|
|
523
|
+
size_t ind_end_col = Xc_ind[end_col];
|
|
524
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
|
|
525
|
+
|
|
526
|
+
ldouble_safe xval;
|
|
527
|
+
|
|
528
|
+
if (missing_action != Fail)
|
|
529
|
+
{
|
|
530
|
+
for (size_t *row = ptr_st;
|
|
531
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
532
|
+
)
|
|
533
|
+
{
|
|
534
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
535
|
+
{
|
|
536
|
+
w_this = w[*row];
|
|
537
|
+
xval = Xc[curr_pos];
|
|
538
|
+
|
|
539
|
+
if (unlikely(is_na_or_inf(xval)))
|
|
540
|
+
{
|
|
541
|
+
cnt -= w_this;
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
else
|
|
545
|
+
{
|
|
546
|
+
x_sq = xval * xval;
|
|
547
|
+
s1 = std::fma(w_this, xval, s1);
|
|
548
|
+
s2 = std::fma(w_this, x_sq, s2);
|
|
549
|
+
s3 = std::fma(w_this, x_sq*xval, s3);
|
|
550
|
+
s4 = std::fma(w_this, x_sq*x_sq, s4);
|
|
551
|
+
// s1 += w_this * pw1(xval);
|
|
552
|
+
// s2 += w_this * pw2(xval);
|
|
553
|
+
// s3 += w_this * pw3(xval);
|
|
554
|
+
// s4 += w_this * pw4(xval);
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
558
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
else
|
|
562
|
+
{
|
|
563
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
564
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
565
|
+
else
|
|
566
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
if (unlikely(cnt <= 0)) return -HUGE_VAL;
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
else
|
|
574
|
+
{
|
|
575
|
+
for (size_t *row = ptr_st;
|
|
576
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
577
|
+
)
|
|
578
|
+
{
|
|
579
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
580
|
+
{
|
|
581
|
+
w_this = w[*row];
|
|
582
|
+
xval = Xc[curr_pos];
|
|
583
|
+
|
|
584
|
+
x_sq = xval * xval;
|
|
585
|
+
s1 = std::fma(w_this, xval, s1);
|
|
586
|
+
s2 = std::fma(w_this, x_sq, s2);
|
|
587
|
+
s3 = std::fma(w_this, x_sq*xval, s3);
|
|
588
|
+
s4 = std::fma(w_this, x_sq*x_sq, s4);
|
|
589
|
+
// s1 += w_this * pw1(xval);
|
|
590
|
+
// s2 += w_this * pw2(xval);
|
|
591
|
+
// s3 += w_this * pw3(xval);
|
|
592
|
+
// s4 += w_this * pw4(xval);
|
|
593
|
+
|
|
594
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
595
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
596
|
+
}
|
|
597
|
+
|
|
598
|
+
else
|
|
599
|
+
{
|
|
600
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
601
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
602
|
+
else
|
|
603
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
604
|
+
}
|
|
605
|
+
}
|
|
606
|
+
}
|
|
607
|
+
|
|
608
|
+
if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
|
|
609
|
+
ldouble_safe sn = s1 / cnt;
|
|
610
|
+
ldouble_safe v = s2 / cnt - pw2(sn);
|
|
611
|
+
if (unlikely(std::isnan(v))) return -HUGE_VAL;
|
|
612
|
+
if (
|
|
613
|
+
v <= std::numeric_limits<double>::epsilon() &&
|
|
614
|
+
!check_more_than_two_unique_values(ix_arr, st, end, col_num,
|
|
615
|
+
Xc_indptr, Xc_ind, Xc,
|
|
616
|
+
missing_action)
|
|
617
|
+
)
|
|
618
|
+
return -HUGE_VAL;
|
|
619
|
+
if (v <= 0) return 0.;
|
|
620
|
+
ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
|
|
621
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
622
|
+
}
|
|
623
|
+
|
|
624
|
+
template <class real_t, class sparse_ix, class ldouble_safe>
|
|
625
|
+
double calc_kurtosis_weighted(size_t col_num, size_t nrows,
|
|
626
|
+
real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
627
|
+
MissingAction missing_action, real_t *restrict w)
|
|
628
|
+
{
|
|
629
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
630
|
+
return -HUGE_VAL;
|
|
631
|
+
|
|
632
|
+
ldouble_safe s1 = 0;
|
|
633
|
+
ldouble_safe s2 = 0;
|
|
634
|
+
ldouble_safe s3 = 0;
|
|
635
|
+
ldouble_safe s4 = 0;
|
|
636
|
+
ldouble_safe x_sq;
|
|
637
|
+
ldouble_safe w_this;
|
|
638
|
+
ldouble_safe cnt = nrows - (Xc_indptr[col_num + 1] - Xc_indptr[col_num]);
|
|
639
|
+
for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
|
|
640
|
+
cnt += w[Xc_ind[ix]];
|
|
641
|
+
|
|
642
|
+
if (unlikely(cnt <= 0)) return -HUGE_VAL;
|
|
643
|
+
|
|
644
|
+
ldouble_safe xval;
|
|
645
|
+
|
|
646
|
+
if (missing_action != Fail)
|
|
647
|
+
{
|
|
648
|
+
for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
|
|
649
|
+
{
|
|
650
|
+
w_this = w[Xc_ind[ix]];
|
|
651
|
+
xval = Xc[ix];
|
|
652
|
+
|
|
653
|
+
if (unlikely(is_na_or_inf(xval)))
|
|
654
|
+
{
|
|
655
|
+
cnt -= w_this;
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
else
|
|
659
|
+
{
|
|
660
|
+
x_sq = xval * xval;
|
|
661
|
+
s1 = std::fma(w_this, xval, s1);
|
|
662
|
+
s2 = std::fma(w_this, x_sq, s2);
|
|
663
|
+
s3 = std::fma(w_this, x_sq*xval, s3);
|
|
664
|
+
s4 = std::fma(w_this, x_sq*x_sq, s4);
|
|
665
|
+
// s1 += w_this * pw1(xval);
|
|
666
|
+
// s2 += w_this * pw2(xval);
|
|
667
|
+
// s3 += w_this * pw3(xval);
|
|
668
|
+
// s4 += w_this * pw4(xval);
|
|
669
|
+
}
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
if (cnt <= 0) return -HUGE_VAL;
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
else
|
|
676
|
+
{
|
|
677
|
+
for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
|
|
678
|
+
{
|
|
679
|
+
w_this = w[Xc_ind[ix]];
|
|
680
|
+
xval = Xc[ix];
|
|
681
|
+
|
|
682
|
+
x_sq = xval * xval;
|
|
683
|
+
s1 = std::fma(w_this, xval, s1);
|
|
684
|
+
s2 = std::fma(w_this, x_sq, s2);
|
|
685
|
+
s3 = std::fma(w_this, x_sq*xval, s3);
|
|
686
|
+
s4 = std::fma(w_this, x_sq*x_sq, s4);
|
|
687
|
+
// s1 += w_this * pw1(xval);
|
|
688
|
+
// s2 += w_this * pw2(xval);
|
|
689
|
+
// s3 += w_this * pw3(xval);
|
|
690
|
+
// s4 += w_this * pw4(xval);
|
|
691
|
+
}
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
|
|
695
|
+
ldouble_safe sn = s1 / cnt;
|
|
696
|
+
ldouble_safe v = s2 / cnt - pw2(sn);
|
|
697
|
+
if (unlikely(std::isnan(v))) return -HUGE_VAL;
|
|
698
|
+
if (
|
|
699
|
+
v <= std::numeric_limits<double>::epsilon() &&
|
|
700
|
+
!check_more_than_two_unique_values(nrows, col_num,
|
|
701
|
+
Xc_indptr, Xc_ind, Xc,
|
|
702
|
+
missing_action)
|
|
703
|
+
)
|
|
704
|
+
return -HUGE_VAL;
|
|
705
|
+
if (unlikely(v <= 0)) return -HUGE_VAL;
|
|
706
|
+
ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
|
|
707
|
+
return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
template <class ldouble_safe>
|
|
712
|
+
double calc_kurtosis_internal(size_t cnt, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
|
|
713
|
+
MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
|
|
714
|
+
{
|
|
715
|
+
/* This calculation proceeds as follows:
|
|
716
|
+
- If splitting by subsets, it will assign a random weight ~Unif(0,1) to
|
|
717
|
+
each category, and approximate kurtosis by sampling from such distribution
|
|
718
|
+
with the same probabilities as given by the current counts.
|
|
719
|
+
- If splitting by isolating one category, will binarize at each categorical level,
|
|
720
|
+
assume the values are zero or one, and output the average assuming each categorical
|
|
721
|
+
level has equal probability of being picked.
|
|
722
|
+
(Note that both are misleading heuristics, but might be better than random)
|
|
723
|
+
*/
|
|
724
|
+
double sum_kurt = 0;
|
|
725
|
+
|
|
726
|
+
cnt -= buffer_cnt[ncat];
|
|
727
|
+
if (cnt <= 1) return -HUGE_VAL;
|
|
728
|
+
ldouble_safe cnt_l = (ldouble_safe) cnt;
|
|
729
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
730
|
+
buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
|
|
731
|
+
|
|
732
|
+
switch (cat_split_type)
|
|
733
|
+
{
|
|
734
|
+
case SubSet:
|
|
735
|
+
{
|
|
736
|
+
ldouble_safe temp_v;
|
|
737
|
+
ldouble_safe s1, s2, s3, s4;
|
|
738
|
+
ldouble_safe coef;
|
|
739
|
+
ldouble_safe coef2;
|
|
740
|
+
ldouble_safe w_this;
|
|
741
|
+
UniformUnitInterval runif(0, 1);
|
|
742
|
+
size_t ntry = 50;
|
|
743
|
+
for (size_t iternum = 0; iternum < 50; iternum++)
|
|
744
|
+
{
|
|
745
|
+
s1 = 0; s2 = 0; s3 = 0; s4 = 0;
|
|
746
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
747
|
+
{
|
|
748
|
+
coef = runif(rnd_generator);
|
|
749
|
+
coef2 = coef * coef;
|
|
750
|
+
w_this = buffer_prob[cat];
|
|
751
|
+
s1 = std::fma(w_this, coef, s1);
|
|
752
|
+
s2 = std::fma(w_this, coef2, s2);
|
|
753
|
+
s3 = std::fma(w_this, coef2*coef, s3);
|
|
754
|
+
s4 = std::fma(w_this, coef2*coef2, s4);
|
|
755
|
+
// s1 += buffer_prob[cat] * pw1(coef);
|
|
756
|
+
// s2 += buffer_prob[cat] * pw2(coef);
|
|
757
|
+
// s3 += buffer_prob[cat] * pw3(coef);
|
|
758
|
+
// s4 += buffer_prob[cat] * pw4(coef);
|
|
759
|
+
}
|
|
760
|
+
temp_v = s2 - pw2(s1);
|
|
761
|
+
if (temp_v <= 0)
|
|
762
|
+
ntry--;
|
|
763
|
+
else
|
|
764
|
+
sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
|
|
765
|
+
}
|
|
766
|
+
if (unlikely(!ntry))
|
|
767
|
+
return -HUGE_VAL;
|
|
768
|
+
else if (unlikely(is_na_or_inf(sum_kurt)))
|
|
769
|
+
return -HUGE_VAL;
|
|
770
|
+
else
|
|
771
|
+
return std::fmax(sum_kurt, 0.) / (double)ntry;
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
case SingleCateg:
|
|
775
|
+
{
|
|
776
|
+
double p;
|
|
777
|
+
int ncat_present = ncat;
|
|
778
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
779
|
+
{
|
|
780
|
+
p = buffer_prob[cat];
|
|
781
|
+
if (p == 0)
|
|
782
|
+
ncat_present--;
|
|
783
|
+
else
|
|
784
|
+
sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
|
|
785
|
+
}
|
|
786
|
+
if (ncat_present <= 1)
|
|
787
|
+
return -HUGE_VAL;
|
|
788
|
+
else if (unlikely(is_na_or_inf(sum_kurt)))
|
|
789
|
+
return -HUGE_VAL;
|
|
790
|
+
else
|
|
791
|
+
return std::fmax(sum_kurt, 0.) / (double)ncat_present;
|
|
792
|
+
}
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
return -1; /* this will never be reached, but CRAN complains otherwise */
|
|
796
|
+
}
|
|
797
|
+
|
|
798
|
+
template <class ldouble_safe>
|
|
799
|
+
double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict buffer_cnt, double buffer_prob[],
|
|
800
|
+
MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
|
|
801
|
+
{
|
|
802
|
+
/* This calculation proceeds as follows:
|
|
803
|
+
- If splitting by subsets, it will assign a random weight ~Unif(0,1) to
|
|
804
|
+
each category, and approximate kurtosis by sampling from such distribution
|
|
805
|
+
with the same probabilities as given by the current counts.
|
|
806
|
+
- If splitting by isolating one category, will binarize at each categorical level,
|
|
807
|
+
assume the values are zero or one, and output the average assuming each categorical
|
|
808
|
+
level has equal probability of being picked.
|
|
809
|
+
(Note that both are misleading heuristics, but might be better than random)
|
|
810
|
+
*/
|
|
811
|
+
size_t cnt = end - st + 1;
|
|
812
|
+
std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
|
|
813
|
+
|
|
814
|
+
if (missing_action == Fail)
|
|
815
|
+
{
|
|
816
|
+
for (size_t row = st; row <= end; row++)
|
|
817
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
else
|
|
821
|
+
{
|
|
822
|
+
for (size_t row = st; row <= end; row++)
|
|
823
|
+
{
|
|
824
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
825
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
826
|
+
else
|
|
827
|
+
buffer_cnt[ncat]++;
|
|
828
|
+
}
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
return calc_kurtosis_internal<ldouble_safe>(
|
|
832
|
+
cnt, x, ncat, buffer_cnt, buffer_prob,
|
|
833
|
+
missing_action, cat_split_type, rnd_generator);
|
|
834
|
+
}
|
|
835
|
+
|
|
836
|
+
template <class ldouble_safe>
|
|
837
|
+
double calc_kurtosis(size_t nrows, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
|
|
838
|
+
MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
|
|
839
|
+
{
|
|
840
|
+
size_t cnt = nrows;
|
|
841
|
+
std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
|
|
842
|
+
|
|
843
|
+
if (missing_action == Fail)
|
|
844
|
+
{
|
|
845
|
+
for (size_t row = 0; row < nrows; row++)
|
|
846
|
+
buffer_cnt[x[row]]++;
|
|
847
|
+
}
|
|
848
|
+
|
|
849
|
+
else
|
|
850
|
+
{
|
|
851
|
+
for (size_t row = 0; row < nrows; row++)
|
|
852
|
+
{
|
|
853
|
+
if (likely(x[row] >= 0))
|
|
854
|
+
buffer_cnt[x[row]]++;
|
|
855
|
+
else
|
|
856
|
+
buffer_cnt[ncat]++;
|
|
857
|
+
}
|
|
858
|
+
}
|
|
859
|
+
|
|
860
|
+
return calc_kurtosis_internal<ldouble_safe>(
|
|
861
|
+
cnt, x, ncat, buffer_cnt, buffer_prob,
|
|
862
|
+
missing_action, cat_split_type, rnd_generator);
|
|
863
|
+
}
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
/* TODO: this one should get a buffer preallocated from outside */
|
|
867
|
+
template <class mapping, class ldouble_safe>
|
|
868
|
+
double calc_kurtosis_weighted_internal(std::vector<ldouble_safe> &buffer_cnt, int x[], int ncat,
|
|
869
|
+
double buffer_prob[], MissingAction missing_action, CategSplit cat_split_type,
|
|
870
|
+
RNG_engine &rnd_generator, mapping &restrict w)
|
|
871
|
+
{
|
|
872
|
+
double sum_kurt = 0;
|
|
873
|
+
|
|
874
|
+
ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
|
|
875
|
+
|
|
876
|
+
cnt -= buffer_cnt[ncat];
|
|
877
|
+
if (unlikely(cnt <= 1)) return -HUGE_VAL;
|
|
878
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
879
|
+
buffer_prob[cat] = buffer_cnt[cat] / cnt;
|
|
880
|
+
|
|
881
|
+
switch (cat_split_type)
|
|
882
|
+
{
|
|
883
|
+
case SubSet:
|
|
884
|
+
{
|
|
885
|
+
ldouble_safe temp_v;
|
|
886
|
+
ldouble_safe s1, s2, s3, s4;
|
|
887
|
+
ldouble_safe coef, coef2;
|
|
888
|
+
ldouble_safe w_this;
|
|
889
|
+
UniformUnitInterval runif(0, 1);
|
|
890
|
+
size_t ntry = 50;
|
|
891
|
+
for (size_t iternum = 0; iternum < 50; iternum++)
|
|
892
|
+
{
|
|
893
|
+
s1 = 0; s2 = 0; s3 = 0; s4 = 0;
|
|
894
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
895
|
+
{
|
|
896
|
+
coef = runif(rnd_generator);
|
|
897
|
+
coef2 = coef * coef;
|
|
898
|
+
w_this = buffer_prob[cat];
|
|
899
|
+
s1 = std::fma(w_this, coef, s1);
|
|
900
|
+
s2 = std::fma(w_this, coef2, s2);
|
|
901
|
+
s3 = std::fma(w_this, coef2*coef, s3);
|
|
902
|
+
s4 = std::fma(w_this, coef2*coef2, s4);
|
|
903
|
+
// s1 += buffer_prob[cat] * pw1(coef);
|
|
904
|
+
// s2 += buffer_prob[cat] * pw2(coef);
|
|
905
|
+
// s3 += buffer_prob[cat] * pw3(coef);
|
|
906
|
+
// s4 += buffer_prob[cat] * pw4(coef);
|
|
907
|
+
}
|
|
908
|
+
temp_v = s2 - pw2(s1);
|
|
909
|
+
if (unlikely(temp_v <= 0))
|
|
910
|
+
ntry--;
|
|
911
|
+
else
|
|
912
|
+
sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
|
|
913
|
+
}
|
|
914
|
+
if (unlikely(!ntry))
|
|
915
|
+
return -HUGE_VAL;
|
|
916
|
+
else if (unlikely(is_na_or_inf(sum_kurt)))
|
|
917
|
+
return -HUGE_VAL;
|
|
918
|
+
else
|
|
919
|
+
return std::fmax(sum_kurt, 0.) / (double)ntry;
|
|
920
|
+
}
|
|
921
|
+
|
|
922
|
+
case SingleCateg:
|
|
923
|
+
{
|
|
924
|
+
double p;
|
|
925
|
+
int ncat_present = ncat;
|
|
926
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
927
|
+
{
|
|
928
|
+
p = buffer_prob[cat];
|
|
929
|
+
if (p == 0)
|
|
930
|
+
ncat_present--;
|
|
931
|
+
else
|
|
932
|
+
sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
|
|
933
|
+
}
|
|
934
|
+
if (ncat_present <= 1)
|
|
935
|
+
return -HUGE_VAL;
|
|
936
|
+
else if (unlikely(is_na_or_inf(sum_kurt)))
|
|
937
|
+
return -HUGE_VAL;
|
|
938
|
+
else
|
|
939
|
+
return std::fmax(sum_kurt, 0.) / (double)ncat_present;
|
|
940
|
+
}
|
|
941
|
+
}
|
|
942
|
+
|
|
943
|
+
return -1; /* this will never be reached, but CRAN complains otherwise */
|
|
944
|
+
}
|
|
945
|
+
|
|
946
|
+
template <class mapping, class ldouble_safe>
|
|
947
|
+
double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, double buffer_prob[],
|
|
948
|
+
MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator,
|
|
949
|
+
mapping &restrict w)
|
|
950
|
+
{
|
|
951
|
+
std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
|
|
952
|
+
ldouble_safe w_this;
|
|
953
|
+
|
|
954
|
+
for (size_t row = st; row <= end; row++)
|
|
955
|
+
{
|
|
956
|
+
w_this = w[ix_arr[row]];
|
|
957
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
958
|
+
buffer_cnt[x[ix_arr[row]]] += w_this;
|
|
959
|
+
else
|
|
960
|
+
buffer_cnt[ncat] += w_this;
|
|
961
|
+
}
|
|
962
|
+
|
|
963
|
+
return calc_kurtosis_weighted_internal<mapping, ldouble_safe>(
|
|
964
|
+
buffer_cnt, x, ncat,
|
|
965
|
+
buffer_prob, missing_action, cat_split_type,
|
|
966
|
+
rnd_generator, w);
|
|
967
|
+
}
|
|
968
|
+
|
|
969
|
+
template <class real_t, class ldouble_safe>
|
|
970
|
+
double calc_kurtosis_weighted(size_t nrows, int x[], int ncat, double *restrict buffer_prob,
|
|
971
|
+
MissingAction missing_action, CategSplit cat_split_type,
|
|
972
|
+
RNG_engine &rnd_generator, real_t *restrict w)
|
|
973
|
+
{
|
|
974
|
+
std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
|
|
975
|
+
ldouble_safe w_this;
|
|
976
|
+
|
|
977
|
+
for (size_t row = 0; row < nrows; row++)
|
|
978
|
+
{
|
|
979
|
+
w_this = w[row];
|
|
980
|
+
if (likely(x[row] >= 0))
|
|
981
|
+
buffer_cnt[x[row]] += w_this;
|
|
982
|
+
else
|
|
983
|
+
buffer_cnt[ncat] += w_this;
|
|
984
|
+
}
|
|
985
|
+
|
|
986
|
+
return calc_kurtosis_weighted_internal<real_t *restrict, ldouble_safe>(
|
|
987
|
+
buffer_cnt, x, ncat,
|
|
988
|
+
buffer_prob, missing_action, cat_split_type,
|
|
989
|
+
rnd_generator, w);
|
|
990
|
+
}
|
|
991
|
+
|
|
992
|
+
template <class int_t, class ldouble_safe>
|
|
993
|
+
double expected_sd_cat(double p[], size_t n, int_t pos[])
|
|
994
|
+
{
|
|
995
|
+
if (n <= 1) return 0;
|
|
996
|
+
|
|
997
|
+
ldouble_safe cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
|
|
998
|
+
for (size_t cat1 = 2; cat1 < n; cat1++)
|
|
999
|
+
{
|
|
1000
|
+
cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
|
|
1001
|
+
for (size_t cat2 = 0; cat2 < cat1; cat2++)
|
|
1002
|
+
cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
|
|
1003
|
+
}
|
|
1004
|
+
return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
|
|
1005
|
+
}
|
|
1006
|
+
|
|
1007
|
+
template <class number, class int_t, class ldouble_safe>
|
|
1008
|
+
double expected_sd_cat(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos)
|
|
1009
|
+
{
|
|
1010
|
+
if (n <= 1) return 0;
|
|
1011
|
+
|
|
1012
|
+
number tot = std::accumulate(pos, pos + n, (number)0, [&counts](number tot, const size_t ix){return tot + counts[ix];});
|
|
1013
|
+
ldouble_safe cnt_div = (ldouble_safe) tot;
|
|
1014
|
+
for (size_t cat = 0; cat < n; cat++)
|
|
1015
|
+
p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
|
|
1016
|
+
|
|
1017
|
+
return expected_sd_cat<int_t, ldouble_safe>(p, n, pos);
|
|
1018
|
+
}
|
|
1019
|
+
|
|
1020
|
+
template <class number, class int_t, class ldouble_safe>
|
|
1021
|
+
double expected_sd_cat_single(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos, size_t cat_exclude, number cnt)
|
|
1022
|
+
{
|
|
1023
|
+
if (cat_exclude == 0)
|
|
1024
|
+
return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos + 1);
|
|
1025
|
+
|
|
1026
|
+
else if (cat_exclude == (n-1))
|
|
1027
|
+
return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos);
|
|
1028
|
+
|
|
1029
|
+
size_t ix_exclude = pos[cat_exclude];
|
|
1030
|
+
|
|
1031
|
+
ldouble_safe cnt_div = (ldouble_safe) (cnt - counts[ix_exclude]);
|
|
1032
|
+
for (size_t cat = 0; cat < n; cat++)
|
|
1033
|
+
p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
|
|
1034
|
+
|
|
1035
|
+
ldouble_safe cum_var;
|
|
1036
|
+
if (cat_exclude != 1)
|
|
1037
|
+
cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
|
|
1038
|
+
else
|
|
1039
|
+
cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
|
|
1040
|
+
for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
|
|
1041
|
+
{
|
|
1042
|
+
if (pos[cat1] == ix_exclude) continue;
|
|
1043
|
+
cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
|
|
1044
|
+
for (size_t cat2 = 0; cat2 < cat1; cat2++)
|
|
1045
|
+
{
|
|
1046
|
+
if (pos[cat2] == ix_exclude) continue;
|
|
1047
|
+
cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
|
|
1048
|
+
}
|
|
1049
|
+
|
|
1050
|
+
}
|
|
1051
|
+
return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
template <class number, class int_t, class ldouble_safe>
|
|
1055
|
+
double expected_sd_cat_internal(int ncat, number *restrict buffer_cnt, ldouble_safe cnt_l,
|
|
1056
|
+
int_t *restrict buffer_pos, double *restrict buffer_prob)
|
|
1057
|
+
{
|
|
1058
|
+
/* move zero-valued to the beginning */
|
|
1059
|
+
std::iota(buffer_pos, buffer_pos + ncat, (int_t)0);
|
|
1060
|
+
int_t st_pos = 0;
|
|
1061
|
+
int ncat_present = 0;
|
|
1062
|
+
int_t temp;
|
|
1063
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1064
|
+
{
|
|
1065
|
+
if (buffer_cnt[cat])
|
|
1066
|
+
{
|
|
1067
|
+
ncat_present++;
|
|
1068
|
+
buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1071
|
+
else
|
|
1072
|
+
{
|
|
1073
|
+
temp = buffer_pos[st_pos];
|
|
1074
|
+
buffer_pos[st_pos] = buffer_pos[cat];
|
|
1075
|
+
buffer_pos[cat] = temp;
|
|
1076
|
+
st_pos++;
|
|
1077
|
+
}
|
|
1078
|
+
}
|
|
1079
|
+
|
|
1080
|
+
if (ncat_present <= 1) return 0;
|
|
1081
|
+
return expected_sd_cat<int_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
|
|
1082
|
+
}
|
|
1083
|
+
|
|
1084
|
+
|
|
1085
|
+
template <class int_t, class ldouble_safe>
|
|
1086
|
+
double expected_sd_cat(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
|
|
1087
|
+
MissingAction missing_action,
|
|
1088
|
+
size_t *restrict buffer_cnt, int_t *restrict buffer_pos, double buffer_prob[])
|
|
1089
|
+
{
|
|
1090
|
+
/* generate counts */
|
|
1091
|
+
std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
|
|
1092
|
+
size_t cnt = end - st + 1;
|
|
1093
|
+
|
|
1094
|
+
if (missing_action != Fail)
|
|
1095
|
+
{
|
|
1096
|
+
int xval;
|
|
1097
|
+
for (size_t row = st; row <= end; row++)
|
|
1098
|
+
{
|
|
1099
|
+
xval = x[ix_arr[row]];
|
|
1100
|
+
if (unlikely(xval < 0))
|
|
1101
|
+
buffer_cnt[ncat]++;
|
|
1102
|
+
else
|
|
1103
|
+
buffer_cnt[xval]++;
|
|
1104
|
+
}
|
|
1105
|
+
cnt -= buffer_cnt[ncat];
|
|
1106
|
+
if (cnt == 0) return 0;
|
|
1107
|
+
}
|
|
1108
|
+
|
|
1109
|
+
else
|
|
1110
|
+
{
|
|
1111
|
+
for (size_t row = st; row <= end; row++)
|
|
1112
|
+
{
|
|
1113
|
+
if (likely(x[ix_arr[row]] >= 0)) buffer_cnt[x[ix_arr[row]]]++;
|
|
1114
|
+
}
|
|
1115
|
+
}
|
|
1116
|
+
|
|
1117
|
+
return expected_sd_cat_internal<size_t, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
|
|
1118
|
+
}
|
|
1119
|
+
|
|
1120
|
+
template <class mapping, class int_t, class ldouble_safe>
|
|
1121
|
+
double expected_sd_cat_weighted(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
|
|
1122
|
+
MissingAction missing_action, mapping &restrict w,
|
|
1123
|
+
double *restrict buffer_cnt, int_t *restrict buffer_pos, double *restrict buffer_prob)
|
|
1124
|
+
{
|
|
1125
|
+
/* generate counts */
|
|
1126
|
+
std::fill(buffer_cnt, buffer_cnt + ncat + 1, 0.);
|
|
1127
|
+
ldouble_safe cnt = 0;
|
|
1128
|
+
|
|
1129
|
+
if (missing_action != Fail)
|
|
1130
|
+
{
|
|
1131
|
+
int xval;
|
|
1132
|
+
double w_this;
|
|
1133
|
+
for (size_t row = st; row <= end; row++)
|
|
1134
|
+
{
|
|
1135
|
+
xval = x[ix_arr[row]];
|
|
1136
|
+
w_this = w[ix_arr[row]];
|
|
1137
|
+
|
|
1138
|
+
if (unlikely(xval < 0)) {
|
|
1139
|
+
buffer_cnt[ncat] += w_this;
|
|
1140
|
+
}
|
|
1141
|
+
else {
|
|
1142
|
+
buffer_cnt[xval] += w_this;
|
|
1143
|
+
cnt += w_this;
|
|
1144
|
+
}
|
|
1145
|
+
}
|
|
1146
|
+
if (cnt == 0) return 0;
|
|
1147
|
+
}
|
|
1148
|
+
|
|
1149
|
+
else
|
|
1150
|
+
{
|
|
1151
|
+
for (size_t row = st; row <= end; row++)
|
|
1152
|
+
{
|
|
1153
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
1154
|
+
{
|
|
1155
|
+
buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
|
|
1156
|
+
}
|
|
1157
|
+
}
|
|
1158
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1159
|
+
cnt += buffer_cnt[cat];
|
|
1160
|
+
if (unlikely(cnt == 0)) return 0;
|
|
1161
|
+
}
|
|
1162
|
+
|
|
1163
|
+
return expected_sd_cat_internal<double, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
|
|
1164
|
+
}
|
|
1165
|
+
|
|
1166
|
+
/* Note: this isn't exactly comparable to the pooled gain from numeric variables,
|
|
1167
|
+
but among all the possible options, this is what happens to end up in the most
|
|
1168
|
+
similar scale when considering standardized gain. */
|
|
1169
|
+
template <class number, class ldouble_safe>
|
|
1170
|
+
double categ_gain(number cnt_left, number cnt_right,
|
|
1171
|
+
ldouble_safe s_left, ldouble_safe s_right,
|
|
1172
|
+
ldouble_safe base_info, ldouble_safe cnt)
|
|
1173
|
+
{
|
|
1174
|
+
return (
|
|
1175
|
+
base_info -
|
|
1176
|
+
(((cnt_left <= 1)? 0 : ((ldouble_safe)cnt_left * std::log((ldouble_safe)cnt_left))) - s_left) -
|
|
1177
|
+
(((cnt_right <= 1)? 0 : ((ldouble_safe)cnt_right * std::log((ldouble_safe)cnt_right))) - s_right)
|
|
1178
|
+
) / cnt;
|
|
1179
|
+
}
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
/* A couple notes about gain calculation:
|
|
1183
|
+
|
|
1184
|
+
Here one wants to find the best split point, maximizing either:
|
|
1185
|
+
(1/sigma) * (sigma - (1/n)*(n_left*sigma_left + n_right*sigma_right))
|
|
1186
|
+
or:
|
|
1187
|
+
(1/sigma) * (sigma - (1/2)*(sigma_left + sigma_right))
|
|
1188
|
+
|
|
1189
|
+
All the algorithms here use the sorted-indices approach, which is
|
|
1190
|
+
an exact method (note that there's still room for optimization by adding the
|
|
1191
|
+
unsorted approach for small sample sizes and for sparse matrices).
|
|
1192
|
+
|
|
1193
|
+
A naive approach would move observations one at a time from right
|
|
1194
|
+
to left using this formula:
|
|
1195
|
+
sigma = (ssq - s^2/n) / n
|
|
1196
|
+
ssq = sum(x^2)
|
|
1197
|
+
s = sum(x)
|
|
1198
|
+
But such an approach has poor numerical precision, and this library is
|
|
1199
|
+
aimed precisely at cases in which there are outliers in the data.
|
|
1200
|
+
It's possible to improve the numerical precision by standardizing the
|
|
1201
|
+
data beforehand, but this library uses instead a more exact two-pass
|
|
1202
|
+
sigma calculation observation-by-observation (from left to right and
|
|
1203
|
+
from right to left, keeping the calculations of the first pass in an
|
|
1204
|
+
array and calculating gain in the second pass), but there are
|
|
1205
|
+
other methods too.
|
|
1206
|
+
|
|
1207
|
+
If one is aiming at maximizing the pooled gain, it's possible to
|
|
1208
|
+
simplify either the gain or the increase in gain without involving
|
|
1209
|
+
'ssq'. Assuming one already has 'ssq' and 's' calculated for the left and
|
|
1210
|
+
right partitions, and one wants to move one observation from right to left,
|
|
1211
|
+
the following hold:
|
|
1212
|
+
s_right = s - s_left
|
|
1213
|
+
ssq_right = ssq - ssq_left
|
|
1214
|
+
n_right = n - n_left
|
|
1215
|
+
If one then moves observation x, these are updated as follows:
|
|
1216
|
+
s_left_new = s_left + x
|
|
1217
|
+
s_right_new = s - s_left - x
|
|
1218
|
+
ssq_left_new = ssq_left + x^2
|
|
1219
|
+
ssq_right_new = ssq - ssq_left - x^2
|
|
1220
|
+
n_left_new = n_left + 1
|
|
1221
|
+
n_right_new = n - n_left - 1
|
|
1222
|
+
Gain is then:
|
|
1223
|
+
(1/sigma) * (sigma - (1/n)*({ssq_left_new - s_left_new^2/n_left_new} + {ssq_right_new - s_right_new^2/n_right_new}))
|
|
1224
|
+
Which simplifies to:
|
|
1225
|
+
1 - (1/(sigma*n))(ssq - ( (s_left + x)^2/(n_left+1) + (s - (s_left + x))^2/(n - (n_left+1)) ))
|
|
1226
|
+
Since 'sigma', 'n', and 'ssq' are constant, they can be ignored when determining the
|
|
1227
|
+
maximum gain - that is, one is interested in finding the point that maximizes:
|
|
1228
|
+
(s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1))
|
|
1229
|
+
And this calculation will be robust-enough when dealing with numbers that were
|
|
1230
|
+
already standardized beforehand, as the extended model does at each step.
|
|
1231
|
+
Note however that, when fitting this model, one is usually interested in evaluating
|
|
1232
|
+
the actual gain, standardized by the standard deviation, as it will try different
|
|
1233
|
+
linear combinations which will give different standard deviations, so this simpler
|
|
1234
|
+
formula cannot be applied unless only one linear combination is probed.
|
|
1235
|
+
|
|
1236
|
+
One can also look at:
|
|
1237
|
+
diff_gain = (1/sigma) * (gain_new - gain)
|
|
1238
|
+
Which can be simplified to something that doesn't include sums of squares:
|
|
1239
|
+
(1/(sigma*n))*( -s_left^2/n_left - (s-s_left)^2/(n-n_left) + (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1)) )
|
|
1240
|
+
And this calculation would in theory allow getting the actual standardized gain.
|
|
1241
|
+
In practice however, this calculation can have poor numerical precision when the
|
|
1242
|
+
sample size is large, so the functions here do not even attempt at calculating it,
|
|
1243
|
+
and this is the reason why the two-pass approach is preferred.
|
|
1244
|
+
|
|
1245
|
+
The averaged SD formula unfortunately doesn't reduce to something that would involve
|
|
1246
|
+
only sums.
|
|
1247
|
+
*/
|
|
1248
|
+
|
|
1249
|
+
/* TODO: maybe it's not a good idea to use the two-pass approach with un-standardized
|
|
1250
|
+
variables at large sample sizes (ndim=1), considering that they come in sorted order.
|
|
1251
|
+
Maybe it should instead use sums of centered squares: sigma = sqrt(sum((x-mean(x))^2)/n)
|
|
1252
|
+
The sums of centered squares method is also likely to be more precise. */
|
|
1253
|
+
|
|
1254
|
+
template <class real_t>
double midpoint(real_t x, real_t y)
{
    /* Split point between two values x < y, computed in real_t.

       Normally this is the rounded halfway point x + (y-x)/2. If rounding pushes
       that point up to (or past) y, the next representable value towards y is
       tried instead. If even that is not strictly inside (x, y), x itself is
       returned.

       The result p therefore satisfies x <= p < y whenever x < y. */
    const real_t half_way = x + (y - x) / (real_t)2;
    if ((double)half_way < (double)y)
        return half_way;

    const real_t nudged = std::nextafter(half_way, y);
    return (nudged > x && nudged < y)? nudged : x;
}
|
|
1268
|
+
|
|
1269
|
+
template <class real_t>
double midpoint_with_reorder(real_t x, real_t y)
{
    /* Same as 'midpoint', but accepts the two values in either order
       (the smaller one is passed first whenever x < y). */
    return (x < y)? midpoint(x, y) : midpoint(y, x);
}
|
|
1277
|
+
|
|
1278
|
+
/* Relative gain from the plain average of the two child standard deviations,
   compared to the parent standard deviation 'sd'. */
#define sd_gain(sd, sd_left, sd_right) \
    (1. - ((sd_left) + (sd_right)) / (2. * (sd)))
/* Relative gain from the count-weighted ("pooled") child standard deviations,
   compared to the parent standard deviation 'sd' over 'cnt' observations.
   NOTE: expands to a cast to 'real_t', which must name a type at the expansion site. */
#define pooled_gain(sd, cnt, sd_left, sd_right, cnt_left, cnt_right) \
    (1. - (1./(sd)) * ( (((real_t)(cnt_left))/(cnt)) * (sd_left) + (((real_t)(cnt_right))/(cnt)) * (sd_right) ))
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
/* TODO: make this a compensated sum */
|
|
1284
|
+
template <class real_t, class real_t_>
|
|
1285
|
+
double find_split_rel_gain_t(real_t_ *restrict x, size_t n, double &restrict split_point)
|
|
1286
|
+
{
|
|
1287
|
+
real_t this_gain;
|
|
1288
|
+
real_t best_gain = -HUGE_VAL;
|
|
1289
|
+
real_t x1 = 0, x2 = 0;
|
|
1290
|
+
real_t sum_left = 0, sum_right = 0, sum_tot = 0;
|
|
1291
|
+
for (size_t row = 0; row < n; row++)
|
|
1292
|
+
sum_tot += x[row];
|
|
1293
|
+
for (size_t row = 0; row < n-1; row++)
|
|
1294
|
+
{
|
|
1295
|
+
sum_left += x[row];
|
|
1296
|
+
if (x[row] == x[row+1])
|
|
1297
|
+
continue;
|
|
1298
|
+
|
|
1299
|
+
sum_right = sum_tot - sum_left;
|
|
1300
|
+
this_gain = sum_left * (sum_left / (real_t)(row+1))
|
|
1301
|
+
+ sum_right * (sum_right / (real_t)(n-row-1));
|
|
1302
|
+
if (this_gain > best_gain)
|
|
1303
|
+
{
|
|
1304
|
+
best_gain = this_gain;
|
|
1305
|
+
x1 = x[row]; x2 = x[row+1];
|
|
1306
|
+
}
|
|
1307
|
+
}
|
|
1308
|
+
|
|
1309
|
+
if (best_gain <= -HUGE_VAL)
|
|
1310
|
+
return best_gain;
|
|
1311
|
+
split_point = midpoint(x1, x2);
|
|
1312
|
+
return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
|
|
1313
|
+
}
|
|
1314
|
+
|
|
1315
|
+
template <class real_t_, class ldouble_safe>
|
|
1316
|
+
double find_split_rel_gain(real_t_ *restrict x, size_t n, double &restrict split_point)
|
|
1317
|
+
{
|
|
1318
|
+
if (n < THRESHOLD_LONG_DOUBLE)
|
|
1319
|
+
return find_split_rel_gain_t<double, real_t_>(x, n, split_point);
|
|
1320
|
+
else
|
|
1321
|
+
return find_split_rel_gain_t<ldouble_safe, real_t_>(x, n, split_point);
|
|
1322
|
+
}
|
|
1323
|
+
|
|
1324
|
+
/* Note: there is no 'weighted' version of 'find_split_rel_gain' with unindexed 'x', because calling it would
|
|
1325
|
+
imply having to argsort the 'x' values in order to sort the weights, which is less efficient. */
|
|
1326
|
+
|
|
1327
|
+
template <class real_t, class real_t_>
|
|
1328
|
+
double find_split_rel_gain_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix)
|
|
1329
|
+
{
|
|
1330
|
+
real_t this_gain;
|
|
1331
|
+
real_t best_gain = -HUGE_VAL;
|
|
1332
|
+
split_ix = 0; /* <- avoid out-of-bounds at the end */
|
|
1333
|
+
real_t sum_left = 0, sum_right = 0, sum_tot = 0;
|
|
1334
|
+
for (size_t row = st; row <= end; row++)
|
|
1335
|
+
sum_tot += x[ix_arr[row]] - xmean;
|
|
1336
|
+
for (size_t row = st; row < end; row++)
|
|
1337
|
+
{
|
|
1338
|
+
sum_left += x[ix_arr[row]] - xmean;
|
|
1339
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]])
|
|
1340
|
+
continue;
|
|
1341
|
+
|
|
1342
|
+
sum_right = sum_tot - sum_left;
|
|
1343
|
+
this_gain = sum_left * (sum_left / (real_t)(row - st + 1))
|
|
1344
|
+
+ sum_right * (sum_right / (real_t)(end - row));
|
|
1345
|
+
if (this_gain > best_gain)
|
|
1346
|
+
{
|
|
1347
|
+
best_gain = this_gain;
|
|
1348
|
+
split_ix = row;
|
|
1349
|
+
}
|
|
1350
|
+
}
|
|
1351
|
+
|
|
1352
|
+
if (best_gain <= -HUGE_VAL)
|
|
1353
|
+
return best_gain;
|
|
1354
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
1355
|
+
return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
|
|
1356
|
+
}
|
|
1357
|
+
|
|
1358
|
+
template <class real_t_, class ldouble_safe>
|
|
1359
|
+
double find_split_rel_gain(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix)
|
|
1360
|
+
{
|
|
1361
|
+
if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
|
|
1362
|
+
return find_split_rel_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
|
|
1363
|
+
else
|
|
1364
|
+
return find_split_rel_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
template <class real_t, class real_t_, class mapping>
|
|
1368
|
+
double find_split_rel_gain_weighted_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
1369
|
+
{
|
|
1370
|
+
real_t this_gain;
|
|
1371
|
+
real_t best_gain = -HUGE_VAL;
|
|
1372
|
+
split_ix = 0; /* <- avoid out-of-bounds at the end */
|
|
1373
|
+
real_t sum_left = 0, sum_right = 0, sum_tot = 0, sumw = 0, sumw_tot = 0;
|
|
1374
|
+
|
|
1375
|
+
for (size_t row = st; row <= end; row++)
|
|
1376
|
+
sumw_tot += w[ix_arr[row]];
|
|
1377
|
+
for (size_t row = st; row <= end; row++)
|
|
1378
|
+
sum_tot += x[ix_arr[row]] - xmean;
|
|
1379
|
+
for (size_t row = st; row < end; row++)
|
|
1380
|
+
{
|
|
1381
|
+
sumw += w[ix_arr[row]];
|
|
1382
|
+
sum_left += x[ix_arr[row]] - xmean;
|
|
1383
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]])
|
|
1384
|
+
continue;
|
|
1385
|
+
|
|
1386
|
+
sum_right = sum_tot - sum_left;
|
|
1387
|
+
this_gain = sum_left * (sum_left / sumw)
|
|
1388
|
+
+ sum_right * (sum_right / (sumw_tot - sumw));
|
|
1389
|
+
if (this_gain > best_gain)
|
|
1390
|
+
{
|
|
1391
|
+
best_gain = this_gain;
|
|
1392
|
+
split_ix = row;
|
|
1393
|
+
}
|
|
1394
|
+
}
|
|
1395
|
+
|
|
1396
|
+
if (best_gain <= -HUGE_VAL)
|
|
1397
|
+
return best_gain;
|
|
1398
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
1399
|
+
return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
|
|
1400
|
+
}
|
|
1401
|
+
|
|
1402
|
+
template <class real_t_, class mapping, class ldouble_safe>
|
|
1403
|
+
double find_split_rel_gain_weighted(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
1404
|
+
{
|
|
1405
|
+
if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
|
|
1406
|
+
return find_split_rel_gain_weighted_t<double, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
|
|
1407
|
+
else
|
|
1408
|
+
return find_split_rel_gain_weighted_t<ldouble_safe, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
|
|
1409
|
+
}
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
/* Population standard deviations of every suffix of 'x' ('n' entries), using Welford's
   running update while walking from the last entry back to the first.
   For each 1 <= k <= n-1, sd_arr[k] receives the standard deviation of x[k..n-1]
   (sd_arr[n-1], covering a single value, is stored as exactly 0); sd_arr[0] is not written.
   Returns the standard deviation of the whole array. */
template <class real_t, class real_t_>
real_t calc_sd_right_to_left(real_t_ *restrict x, size_t n, double *restrict sd_arr)
{
    real_t mean = 0;
    real_t m2 = 0;
    real_t prev_mean = x[n-1];
    for (size_t pos = n; pos-- > 0; )
    {
        const real_t cnt = (real_t)(n - pos);
        mean += (x[pos] - mean) / cnt;
        m2 += (x[pos] - mean) * (x[pos] - prev_mean);
        prev_mean = mean;
        if (pos > 0)
            sd_arr[pos] = (pos == n-1)? 0. : std::sqrt(m2 / cnt);
    }
    return std::sqrt(m2 / (real_t)n);
}
|
|
1429
|
+
|
|
1430
|
+
/* Weighted population standard deviations of every suffix of the ordering given by
   'sorted_ix'. Entries are visited from sorted position n-1 down to 0, and each value
   v = x[sorted_ix[k]] with weight w = w[sorted_ix[k]] goes through the weighted Welford update
       W += w;  mean += w * (v - mean) / W;  M2 += w * (v - mean_new) * (v - mean_old)
   For each 1 <= k <= n-1, sd_arr[k] receives the weighted standard deviation of sorted
   positions k..n-1 (sd_arr[n-1] is stored as exactly 0); sd_arr[0] is not written.
   'cumw' receives the total weight. Returns the weighted standard deviation of all n entries. */
template <class real_t_, class ldouble_safe>
ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, size_t n, double *restrict sd_arr,
                                            double *restrict w, ldouble_safe &cumw, size_t *restrict sorted_ix)
{
    ldouble_safe running_mean = 0;
    ldouble_safe running_ssq = 0;
    ldouble_safe mean_prev = x[sorted_ix[n-1]];
    ldouble_safe cnt = 0;
    double w_this;
    for (size_t row = 0; row < n; row++)
    {
        const size_t pos = n - row - 1;
        w_this = w[sorted_ix[pos]];
        cnt += w_this;
        /* The weight scales the mean update for every entry, including the last one
           (sorted position 0); dropping it there corrupts the final mean and M2
           whenever that weight differs from 1. */
        running_mean += w_this * (x[sorted_ix[pos]] - running_mean) / cnt;
        running_ssq += w_this * ((x[sorted_ix[pos]] - running_mean) * (x[sorted_ix[pos]] - mean_prev));
        mean_prev = running_mean;
        if (pos > 0)
            sd_arr[pos] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
    }
    cumw = cnt;
    return std::sqrt(running_ssq / cnt);
}
|
|
1455
|
+
|
|
1456
|
+
/* Indexed variant: population standard deviations of every suffix of the rows
   ix_arr[st..end], using the centered values x[ix_arr[i]] - xmean and walking from
   'end' back to 'st'. With n = end - st + 1, for each 1 <= k <= n-1, sd_arr[k] receives
   the standard deviation of rows ix_arr[st+k..end] (sd_arr[n-1] is stored as exactly 0);
   sd_arr[0] is not written. Returns the standard deviation over all n rows. */
template <class real_t, class real_t_>
real_t calc_sd_right_to_left(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr)
{
    const size_t n = end - st + 1;
    real_t mean = 0;
    real_t m2 = 0;
    real_t prev_mean = x[ix_arr[end]] - xmean;
    for (size_t rel = n; rel-- > 0; )
    {
        const size_t row = st + rel;
        const real_t cnt = (real_t)(n - rel);
        mean += ((x[ix_arr[row]] - xmean) - mean) / cnt;
        m2 += ((x[ix_arr[row]] - xmean) - mean) * ((x[ix_arr[row]] - xmean) - prev_mean);
        prev_mean = mean;
        if (rel > 0)
            sd_arr[rel] = (rel == n-1)? 0. : std::sqrt(m2 / cnt);
    }
    return std::sqrt(m2 / (real_t)n);
}
|
|
1474
|
+
|
|
1475
|
+
/* Weighted, indexed variant: suffix standard deviations of the rows ix_arr[st..end]
   using centered values x[ix_arr[i]] - xmean and weights w[ix_arr[i]], walking from
   'end' back to 'st' with the weighted Welford update
       W += w;  mean += w * (v - mean) / W;  M2 += w * (v - mean_new) * (v - mean_old)
   With n = end - st + 1, for each 1 <= k <= n-1, sd_arr[k] receives the weighted
   standard deviation of rows ix_arr[st+k..end] (sd_arr[n-1] is stored as exactly 0);
   sd_arr[0] is not written. 'cumw' receives the total weight. Returns the weighted
   standard deviation over all n rows. */
template <class real_t_, class mapping, class ldouble_safe>
ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end,
                                            double *restrict sd_arr, mapping &restrict w, ldouble_safe &cumw)
{
    ldouble_safe running_mean = 0;
    ldouble_safe running_ssq = 0;
    /* Kept in the accumulator precision: storing it as 'real_t_' (possibly float)
       would round the running mean at every step. */
    ldouble_safe mean_prev = x[ix_arr[end]] - xmean;
    const size_t n = end - st + 1;
    ldouble_safe cnt = 0;
    double w_this;
    for (size_t rel = n; rel-- > 0; )
    {
        const size_t row = st + rel;
        w_this = w[ix_arr[row]];
        cnt += w_this;
        /* The weight scales the mean update for every row, including the last one
           visited (position 'st'); without it the final mean and M2 are wrong
           whenever that row's weight differs from 1. */
        running_mean += w_this * ((x[ix_arr[row]] - xmean) - running_mean) / cnt;
        running_ssq += w_this * (((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev));
        mean_prev = running_mean;
        if (rel > 0)
            sd_arr[rel] = (rel == n-1)? 0. : std::sqrt(running_ssq / cnt);
    }
    cumw = cnt;
    return std::sqrt(running_ssq / cnt);
}
|
|
1501
|
+
|
|
1502
|
+
template <class real_t, class real_t_>
|
|
1503
|
+
double find_split_std_gain_t(real_t_ *restrict x, size_t n, double *restrict sd_arr,
|
|
1504
|
+
GainCriterion criterion, double min_gain, double &restrict split_point)
|
|
1505
|
+
{
|
|
1506
|
+
real_t full_sd = calc_sd_right_to_left<real_t>(x, n, sd_arr);
|
|
1507
|
+
real_t running_mean = 0;
|
|
1508
|
+
real_t running_ssq = 0;
|
|
1509
|
+
real_t mean_prev = x[0];
|
|
1510
|
+
real_t best_gain = -HUGE_VAL;
|
|
1511
|
+
real_t this_sd, this_gain;
|
|
1512
|
+
real_t n_ = (real_t)n;
|
|
1513
|
+
size_t best_ix = 0;
|
|
1514
|
+
for (size_t row = 0; row < n-1; row++)
|
|
1515
|
+
{
|
|
1516
|
+
running_mean += (x[row] - running_mean) / (real_t)(row+1);
|
|
1517
|
+
running_ssq += (x[row] - running_mean) * (x[row] - mean_prev);
|
|
1518
|
+
mean_prev = running_mean;
|
|
1519
|
+
if (x[row] == x[row+1])
|
|
1520
|
+
continue;
|
|
1521
|
+
|
|
1522
|
+
this_sd = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
|
|
1523
|
+
this_gain = (criterion == Pooled)?
|
|
1524
|
+
pooled_gain(full_sd, n_, this_sd, sd_arr[row+1], row+1, n-row-1)
|
|
1525
|
+
:
|
|
1526
|
+
sd_gain(full_sd, this_sd, sd_arr[row+1]);
|
|
1527
|
+
if (this_gain > best_gain && this_gain > min_gain)
|
|
1528
|
+
{
|
|
1529
|
+
best_gain = this_gain;
|
|
1530
|
+
best_ix = row;
|
|
1531
|
+
}
|
|
1532
|
+
}
|
|
1533
|
+
|
|
1534
|
+
if (best_gain > -HUGE_VAL)
|
|
1535
|
+
split_point = midpoint(x[best_ix], x[best_ix+1]);
|
|
1536
|
+
|
|
1537
|
+
return best_gain;
|
|
1538
|
+
}
|
|
1539
|
+
|
|
1540
|
+
template <class real_t_, class ldouble_safe>
|
|
1541
|
+
double find_split_std_gain(real_t_ *restrict x, size_t n, double *restrict sd_arr,
|
|
1542
|
+
GainCriterion criterion, double min_gain, double &restrict split_point)
|
|
1543
|
+
{
|
|
1544
|
+
if (n < THRESHOLD_LONG_DOUBLE)
|
|
1545
|
+
return find_split_std_gain_t<double, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
|
|
1546
|
+
else
|
|
1547
|
+
return find_split_std_gain_t<ldouble_safe, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
|
|
1548
|
+
}
|
|
1549
|
+
|
|
1550
|
+
template <class real_t, class ldouble_safe>
|
|
1551
|
+
double find_split_std_gain_weighted(real_t *restrict x, size_t n, double *restrict sd_arr,
|
|
1552
|
+
GainCriterion criterion, double min_gain, double &restrict split_point,
|
|
1553
|
+
double *restrict w, size_t *restrict sorted_ix)
|
|
1554
|
+
{
|
|
1555
|
+
ldouble_safe cumw;
|
|
1556
|
+
double full_sd = calc_sd_right_to_left_weighted(x, n, sd_arr, w, cumw, sorted_ix);
|
|
1557
|
+
ldouble_safe running_mean = 0;
|
|
1558
|
+
ldouble_safe running_ssq = 0;
|
|
1559
|
+
ldouble_safe mean_prev = x[sorted_ix[0]];
|
|
1560
|
+
double best_gain = -HUGE_VAL;
|
|
1561
|
+
double this_sd, this_gain;
|
|
1562
|
+
double w_this;
|
|
1563
|
+
ldouble_safe currw = 0;
|
|
1564
|
+
size_t best_ix = 0;
|
|
1565
|
+
|
|
1566
|
+
for (size_t row = 0; row < n-1; row++)
|
|
1567
|
+
{
|
|
1568
|
+
w_this = w[sorted_ix[row]];
|
|
1569
|
+
currw += w_this;
|
|
1570
|
+
running_mean += w_this * (x[sorted_ix[row]] - running_mean) / currw;
|
|
1571
|
+
running_ssq += w_this * ((x[sorted_ix[row]] - running_mean) * (x[sorted_ix[row]] - mean_prev));
|
|
1572
|
+
mean_prev = running_mean;
|
|
1573
|
+
if (x[sorted_ix[row]] == x[sorted_ix[row+1]])
|
|
1574
|
+
continue;
|
|
1575
|
+
|
|
1576
|
+
this_sd = (row == 0)? 0. : std::sqrt(running_ssq / currw);
|
|
1577
|
+
this_gain = (criterion == Pooled)?
|
|
1578
|
+
pooled_gain(full_sd, cumw, this_sd, sd_arr[row+1], currw, cumw-currw)
|
|
1579
|
+
:
|
|
1580
|
+
sd_gain(full_sd, this_sd, sd_arr[row+1]);
|
|
1581
|
+
if (this_gain > best_gain && this_gain > min_gain)
|
|
1582
|
+
{
|
|
1583
|
+
best_gain = this_gain;
|
|
1584
|
+
best_ix = row;
|
|
1585
|
+
}
|
|
1586
|
+
}
|
|
1587
|
+
|
|
1588
|
+
if (best_gain > -HUGE_VAL)
|
|
1589
|
+
split_point = midpoint(x[sorted_ix[best_ix]], x[sorted_ix[best_ix+1]]);
|
|
1590
|
+
|
|
1591
|
+
return best_gain;
|
|
1592
|
+
}
|
|
1593
|
+
|
|
1594
|
+
template <class real_t, class real_t_>
|
|
1595
|
+
double find_split_std_gain_t(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
|
|
1596
|
+
GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
|
|
1597
|
+
{
|
|
1598
|
+
real_t full_sd = calc_sd_right_to_left<real_t>(x, xmean, ix_arr, st, end, sd_arr);
|
|
1599
|
+
real_t running_mean = 0;
|
|
1600
|
+
real_t running_ssq = 0;
|
|
1601
|
+
real_t mean_prev = x[ix_arr[st]] - xmean;
|
|
1602
|
+
real_t best_gain = -HUGE_VAL;
|
|
1603
|
+
real_t n = (real_t)(end - st + 1);
|
|
1604
|
+
real_t this_sd, this_gain;
|
|
1605
|
+
split_ix = st;
|
|
1606
|
+
for (size_t row = st; row < end; row++)
|
|
1607
|
+
{
|
|
1608
|
+
running_mean += ((x[ix_arr[row]] - xmean) - running_mean) / (real_t)(row-st+1);
|
|
1609
|
+
running_ssq += ((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev);
|
|
1610
|
+
mean_prev = running_mean;
|
|
1611
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]])
|
|
1612
|
+
continue;
|
|
1613
|
+
|
|
1614
|
+
this_sd = (row == st)? 0. : std::sqrt(running_ssq / (real_t)(row-st+1));
|
|
1615
|
+
this_gain = (criterion == Pooled)?
|
|
1616
|
+
pooled_gain(full_sd, n, this_sd, sd_arr[row-st+1], row-st+1, end-row)
|
|
1617
|
+
:
|
|
1618
|
+
sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
|
|
1619
|
+
if (this_gain > best_gain && this_gain > min_gain)
|
|
1620
|
+
{
|
|
1621
|
+
best_gain = this_gain;
|
|
1622
|
+
split_ix = row;
|
|
1623
|
+
}
|
|
1624
|
+
}
|
|
1625
|
+
|
|
1626
|
+
if (best_gain > -HUGE_VAL)
|
|
1627
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
1628
|
+
|
|
1629
|
+
return best_gain;
|
|
1630
|
+
}
|
|
1631
|
+
|
|
1632
|
+
template <class real_t_, class ldouble_safe>
|
|
1633
|
+
double find_split_std_gain(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
|
|
1634
|
+
GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
|
|
1635
|
+
{
|
|
1636
|
+
if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
|
|
1637
|
+
return find_split_std_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
|
|
1638
|
+
else
|
|
1639
|
+
return find_split_std_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
|
|
1640
|
+
}
|
|
1641
|
+
|
|
1642
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
1643
|
+
double find_split_std_gain_weighted(real_t *restrict x, real_t xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
|
|
1644
|
+
GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
1645
|
+
{
|
|
1646
|
+
ldouble_safe cumw;
|
|
1647
|
+
double full_sd = calc_sd_right_to_left_weighted(x, xmean, ix_arr, st, end, sd_arr, w, cumw);
|
|
1648
|
+
ldouble_safe running_mean = 0;
|
|
1649
|
+
ldouble_safe running_ssq = 0;
|
|
1650
|
+
ldouble_safe mean_prev = x[ix_arr[st]] - xmean;
|
|
1651
|
+
double best_gain = -HUGE_VAL;
|
|
1652
|
+
ldouble_safe currw = 0;
|
|
1653
|
+
double this_sd, this_gain;
|
|
1654
|
+
double w_this;
|
|
1655
|
+
split_ix = st;
|
|
1656
|
+
|
|
1657
|
+
for (size_t row = st; row < end; row++)
|
|
1658
|
+
{
|
|
1659
|
+
w_this = w[ix_arr[row]];
|
|
1660
|
+
currw += w_this;
|
|
1661
|
+
running_mean += w_this * ((x[ix_arr[row]] - xmean) - running_mean) / currw;
|
|
1662
|
+
running_ssq += w_this * (((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev));
|
|
1663
|
+
mean_prev = running_mean;
|
|
1664
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]])
|
|
1665
|
+
continue;
|
|
1666
|
+
|
|
1667
|
+
this_sd = (row == st)? 0. : std::sqrt(running_ssq / currw);
|
|
1668
|
+
this_gain = (criterion == Pooled)?
|
|
1669
|
+
pooled_gain(full_sd, cumw, this_sd, sd_arr[row-st+1], currw, cumw-currw)
|
|
1670
|
+
:
|
|
1671
|
+
sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
|
|
1672
|
+
if (this_gain > best_gain && this_gain > min_gain)
|
|
1673
|
+
{
|
|
1674
|
+
best_gain = this_gain;
|
|
1675
|
+
split_ix = row;
|
|
1676
|
+
}
|
|
1677
|
+
}
|
|
1678
|
+
|
|
1679
|
+
if (best_gain > -HUGE_VAL)
|
|
1680
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
1681
|
+
|
|
1682
|
+
return best_gain;
|
|
1683
|
+
}
|
|
1684
|
+
|
|
1685
|
+
/* Outside of R builds, keep clang quiet about the GCC-specific 'optimize'
   attribute placed on the small vector helpers that follow. */
#if !defined(_FOR_R) && defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#endif
|
|
1691
|
+
|
|
1692
|
+
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void xpy1(double *restrict x, double *restrict y, size_t n)
{
    /* dense y := y + x over the first 'n' entries */
    for (double *y_end = y + n; y < y_end; ++x, ++y)
        *y += *x;
}
|
|
1699
|
+
|
|
1700
|
+
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void axpy1(const double a, double *restrict x, double *restrict y, size_t n)
{
    /* dense y := a*x + y over the first 'n' entries, each via one fused multiply-add */
    size_t i = 0;
    while (i < n)
    {
        y[i] = std::fma(a, x[i], y[i]);
        i++;
    }
}
|
|
1707
|
+
|
|
1708
|
+
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void xpy1(double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
{
    /* sparse scatter-add: y[ind[k]] += xval[k] for each of the 'nnz' stored entries */
    for (size_t k = 0; k < nnz; ++k)
    {
        const size_t dest = ind[k];
        y[dest] += xval[k];
    }
}
|
|
1715
|
+
|
|
1716
|
+
#ifndef _FOR_R
[[gnu::optimize("Ofast")]]
#endif
static inline void axpy1(const double a, double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
{
    /* sparse scatter: y[ind[k]] := a*xval[k] + y[ind[k]] using a fused multiply-add */
    for (size_t k = 0; k < nnz; ++k)
    {
        const size_t dest = ind[k];
        y[dest] = std::fma(a, xval[k], y[dest]);
    }
}
|
|
1723
|
+
|
|
1724
|
+
/* Restore clang's diagnostic state (non-R builds only). */
#if !defined(_FOR_R) && defined(__clang__)
#pragma clang diagnostic pop
#endif
|
|
1729
|
+
|
|
1730
|
+
template <class real_t, class ldouble_safe>
|
|
1731
|
+
double find_split_full_gain(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
|
|
1732
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
1733
|
+
double *restrict X_row_major, size_t ncols,
|
|
1734
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
|
|
1735
|
+
double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
|
|
1736
|
+
size_t &restrict split_ix, double &restrict split_point,
|
|
1737
|
+
bool x_uses_ix_arr)
|
|
1738
|
+
{
|
|
1739
|
+
if (end <= st) return -HUGE_VAL;
|
|
1740
|
+
if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
|
|
1741
|
+
force_cols_use = true;
|
|
1742
|
+
|
|
1743
|
+
memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
|
|
1744
|
+
if (Xr_indptr == NULL)
|
|
1745
|
+
{
|
|
1746
|
+
if (force_cols_use)
|
|
1747
|
+
{
|
|
1748
|
+
double *restrict ptr_row;
|
|
1749
|
+
for (size_t row = st; row <= end; row++)
|
|
1750
|
+
{
|
|
1751
|
+
ptr_row = X_row_major + ix_arr[row]*ncols;
|
|
1752
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
1753
|
+
buffer_sum_tot[col] += ptr_row[cols_use[col]];
|
|
1754
|
+
}
|
|
1755
|
+
}
|
|
1756
|
+
|
|
1757
|
+
else
|
|
1758
|
+
{
|
|
1759
|
+
for (size_t row = st; row <= end; row++)
|
|
1760
|
+
xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
|
|
1761
|
+
}
|
|
1762
|
+
}
|
|
1763
|
+
|
|
1764
|
+
else
|
|
1765
|
+
{
|
|
1766
|
+
if (force_cols_use)
|
|
1767
|
+
{
|
|
1768
|
+
size_t *curr_begin;
|
|
1769
|
+
size_t *row_end;
|
|
1770
|
+
size_t *curr_col;
|
|
1771
|
+
double *restrict Xr_this;
|
|
1772
|
+
size_t *cols_end = cols_use + ncols_use;
|
|
1773
|
+
for (size_t row = st; row <= end; row++)
|
|
1774
|
+
{
|
|
1775
|
+
curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
|
|
1776
|
+
row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
|
|
1777
|
+
if (curr_begin == row_end) continue;
|
|
1778
|
+
curr_col = cols_use;
|
|
1779
|
+
Xr_this = Xr + Xr_indptr[ix_arr[row]];
|
|
1780
|
+
|
|
1781
|
+
while (curr_col < cols_end && curr_begin < row_end)
|
|
1782
|
+
{
|
|
1783
|
+
if (*curr_begin == *curr_col)
|
|
1784
|
+
{
|
|
1785
|
+
buffer_sum_tot[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
|
|
1786
|
+
curr_col++;
|
|
1787
|
+
curr_begin++;
|
|
1788
|
+
}
|
|
1789
|
+
|
|
1790
|
+
else
|
|
1791
|
+
{
|
|
1792
|
+
if (*curr_begin > *curr_col)
|
|
1793
|
+
curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
|
|
1794
|
+
else
|
|
1795
|
+
curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
|
|
1796
|
+
}
|
|
1797
|
+
}
|
|
1798
|
+
}
|
|
1799
|
+
}
|
|
1800
|
+
|
|
1801
|
+
else
|
|
1802
|
+
{
|
|
1803
|
+
size_t ptr_this;
|
|
1804
|
+
for (size_t row = st; row <= end; row++)
|
|
1805
|
+
{
|
|
1806
|
+
ptr_this = Xr_indptr[ix_arr[row]];
|
|
1807
|
+
xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
|
|
1808
|
+
}
|
|
1809
|
+
}
|
|
1810
|
+
}
|
|
1811
|
+
|
|
1812
|
+
double best_gain = -HUGE_VAL;
|
|
1813
|
+
double this_gain;
|
|
1814
|
+
double sl, sr;
|
|
1815
|
+
double dl, dr;
|
|
1816
|
+
double vleft, vright;
|
|
1817
|
+
memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
|
|
1818
|
+
if (Xr_indptr == NULL)
|
|
1819
|
+
{
|
|
1820
|
+
if (!force_cols_use)
|
|
1821
|
+
{
|
|
1822
|
+
for (size_t row = st; row < end; row++)
|
|
1823
|
+
{
|
|
1824
|
+
xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
|
|
1825
|
+
if (x_uses_ix_arr) {
|
|
1826
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
1827
|
+
}
|
|
1828
|
+
else {
|
|
1829
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
1830
|
+
}
|
|
1831
|
+
|
|
1832
|
+
vleft = 0;
|
|
1833
|
+
vright = 0;
|
|
1834
|
+
dl = (double)(row-st+1);
|
|
1835
|
+
dr = (double)(end-row);
|
|
1836
|
+
for (size_t col = 0; col < ncols; col++)
|
|
1837
|
+
{
|
|
1838
|
+
sl = buffer_sum_left[col];
|
|
1839
|
+
vleft += sl * (sl / dl);
|
|
1840
|
+
sr = buffer_sum_tot[col] - sl;
|
|
1841
|
+
vright += sr * (sr / dr);
|
|
1842
|
+
}
|
|
1843
|
+
|
|
1844
|
+
this_gain = vleft + vright;
|
|
1845
|
+
if (this_gain > best_gain)
|
|
1846
|
+
{
|
|
1847
|
+
best_gain = this_gain;
|
|
1848
|
+
split_ix = row;
|
|
1849
|
+
}
|
|
1850
|
+
}
|
|
1851
|
+
}
|
|
1852
|
+
|
|
1853
|
+
else
|
|
1854
|
+
{
|
|
1855
|
+
double *restrict ptr_row;
|
|
1856
|
+
for (size_t row = st; row < end; row++)
|
|
1857
|
+
{
|
|
1858
|
+
ptr_row = X_row_major + ix_arr[row]*ncols;
|
|
1859
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
1860
|
+
buffer_sum_left[col] += ptr_row[cols_use[col]];
|
|
1861
|
+
if (x_uses_ix_arr) {
|
|
1862
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
1863
|
+
}
|
|
1864
|
+
else {
|
|
1865
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
1866
|
+
}
|
|
1867
|
+
|
|
1868
|
+
vleft = 0;
|
|
1869
|
+
vright = 0;
|
|
1870
|
+
dl = (double)(row-st+1);
|
|
1871
|
+
dr = (double)(end-row);
|
|
1872
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
1873
|
+
{
|
|
1874
|
+
sl = buffer_sum_left[col];
|
|
1875
|
+
vleft += sl * (sl / dl);
|
|
1876
|
+
sr = buffer_sum_tot[col] - sl;
|
|
1877
|
+
vright += sr * (sr / dr);
|
|
1878
|
+
}
|
|
1879
|
+
|
|
1880
|
+
this_gain = vleft + vright;
|
|
1881
|
+
if (this_gain > best_gain)
|
|
1882
|
+
{
|
|
1883
|
+
best_gain = this_gain;
|
|
1884
|
+
split_ix = row;
|
|
1885
|
+
}
|
|
1886
|
+
}
|
|
1887
|
+
}
|
|
1888
|
+
}
|
|
1889
|
+
|
|
1890
|
+
else
|
|
1891
|
+
{
|
|
1892
|
+
if (!force_cols_use)
|
|
1893
|
+
{
|
|
1894
|
+
size_t ptr_this;
|
|
1895
|
+
for (size_t row = st; row < end; row++)
|
|
1896
|
+
{
|
|
1897
|
+
ptr_this = Xr_indptr[ix_arr[row]];
|
|
1898
|
+
xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
|
|
1899
|
+
if (x_uses_ix_arr) {
|
|
1900
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
1901
|
+
}
|
|
1902
|
+
else {
|
|
1903
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
1904
|
+
}
|
|
1905
|
+
|
|
1906
|
+
vleft = 0;
|
|
1907
|
+
vright = 0;
|
|
1908
|
+
dl = (double)(row-st+1);
|
|
1909
|
+
dr = (double)(end-row);
|
|
1910
|
+
for (size_t col = 0; col < ncols; col++)
|
|
1911
|
+
{
|
|
1912
|
+
sl = buffer_sum_left[col];
|
|
1913
|
+
vleft += sl * (sl / dl);
|
|
1914
|
+
sr = buffer_sum_tot[col] - sl;
|
|
1915
|
+
vright += sr * (sr / dr);
|
|
1916
|
+
}
|
|
1917
|
+
|
|
1918
|
+
this_gain = vleft + vright;
|
|
1919
|
+
if (this_gain > best_gain)
|
|
1920
|
+
{
|
|
1921
|
+
best_gain = this_gain;
|
|
1922
|
+
split_ix = row;
|
|
1923
|
+
}
|
|
1924
|
+
}
|
|
1925
|
+
}
|
|
1926
|
+
|
|
1927
|
+
else
|
|
1928
|
+
{
|
|
1929
|
+
size_t *curr_begin;
|
|
1930
|
+
size_t *row_end;
|
|
1931
|
+
size_t *curr_col;
|
|
1932
|
+
double *restrict Xr_this;
|
|
1933
|
+
size_t *cols_end = cols_use + ncols_use;
|
|
1934
|
+
for (size_t row = st; row < end; row++)
|
|
1935
|
+
{
|
|
1936
|
+
curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
|
|
1937
|
+
row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
|
|
1938
|
+
if (curr_begin == row_end) goto skip_sum;
|
|
1939
|
+
curr_col = cols_use;
|
|
1940
|
+
Xr_this = Xr + Xr_indptr[ix_arr[row]];
|
|
1941
|
+
while (curr_col < cols_end && curr_begin < row_end)
|
|
1942
|
+
{
|
|
1943
|
+
if (*curr_begin == *curr_col)
|
|
1944
|
+
{
|
|
1945
|
+
buffer_sum_left[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
|
|
1946
|
+
curr_col++;
|
|
1947
|
+
curr_begin++;
|
|
1948
|
+
}
|
|
1949
|
+
|
|
1950
|
+
else
|
|
1951
|
+
{
|
|
1952
|
+
if (*curr_begin > *curr_col)
|
|
1953
|
+
curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
|
|
1954
|
+
else
|
|
1955
|
+
curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
|
|
1956
|
+
}
|
|
1957
|
+
}
|
|
1958
|
+
|
|
1959
|
+
skip_sum:
|
|
1960
|
+
if (x_uses_ix_arr) {
|
|
1961
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
1962
|
+
}
|
|
1963
|
+
else {
|
|
1964
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
1965
|
+
}
|
|
1966
|
+
|
|
1967
|
+
vleft = 0;
|
|
1968
|
+
vright = 0;
|
|
1969
|
+
dl = (double)(row-st+1);
|
|
1970
|
+
dr = (double)(end-row);
|
|
1971
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
1972
|
+
{
|
|
1973
|
+
sl = buffer_sum_left[col];
|
|
1974
|
+
vleft += sl * (sl / dl);
|
|
1975
|
+
sr = buffer_sum_tot[col] - sl;
|
|
1976
|
+
vright += sr * (sr / dr);
|
|
1977
|
+
}
|
|
1978
|
+
|
|
1979
|
+
this_gain = vleft + vright;
|
|
1980
|
+
if (this_gain > best_gain)
|
|
1981
|
+
{
|
|
1982
|
+
best_gain = this_gain;
|
|
1983
|
+
split_ix = row;
|
|
1984
|
+
}
|
|
1985
|
+
}
|
|
1986
|
+
}
|
|
1987
|
+
}
|
|
1988
|
+
|
|
1989
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
1990
|
+
|
|
1991
|
+
if (x_uses_ix_arr)
|
|
1992
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
1993
|
+
else
|
|
1994
|
+
split_point = midpoint(x[split_ix], x[split_ix+1]);
|
|
1995
|
+
return best_gain / (ldouble_safe)(end - st + 1);
|
|
1996
|
+
}
|
|
1997
|
+
|
|
1998
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
1999
|
+
double find_split_full_gain_weighted(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
|
|
2000
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
2001
|
+
double *restrict X_row_major, size_t ncols,
|
|
2002
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
|
|
2003
|
+
double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
|
|
2004
|
+
size_t &restrict split_ix, double &restrict split_point,
|
|
2005
|
+
bool x_uses_ix_arr,
|
|
2006
|
+
mapping &restrict w)
|
|
2007
|
+
{
|
|
2008
|
+
if (end <= st) return -HUGE_VAL;
|
|
2009
|
+
if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
|
|
2010
|
+
force_cols_use = true;
|
|
2011
|
+
|
|
2012
|
+
double wtot = 0;
|
|
2013
|
+
if (x_uses_ix_arr)
|
|
2014
|
+
{
|
|
2015
|
+
for (size_t row = st; row <= end; row++)
|
|
2016
|
+
wtot += w[ix_arr[row]];
|
|
2017
|
+
}
|
|
2018
|
+
|
|
2019
|
+
else
|
|
2020
|
+
{
|
|
2021
|
+
for (size_t row = st; row <= end; row++)
|
|
2022
|
+
wtot += w[row];
|
|
2023
|
+
}
|
|
2024
|
+
|
|
2025
|
+
memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
|
|
2026
|
+
if (Xr_indptr == NULL)
|
|
2027
|
+
{
|
|
2028
|
+
if (!force_cols_use)
|
|
2029
|
+
{
|
|
2030
|
+
if (x_uses_ix_arr)
|
|
2031
|
+
{
|
|
2032
|
+
for (size_t row = st; row <= end; row++)
|
|
2033
|
+
axpy1(w[ix_arr[row]], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
|
|
2034
|
+
}
|
|
2035
|
+
|
|
2036
|
+
else
|
|
2037
|
+
{
|
|
2038
|
+
for (size_t row = st; row <= end; row++)
|
|
2039
|
+
axpy1(w[row], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
|
|
2040
|
+
}
|
|
2041
|
+
}
|
|
2042
|
+
|
|
2043
|
+
else
|
|
2044
|
+
{
|
|
2045
|
+
double *restrict ptr_row;
|
|
2046
|
+
double w_row;
|
|
2047
|
+
|
|
2048
|
+
if (x_uses_ix_arr)
|
|
2049
|
+
{
|
|
2050
|
+
for (size_t row = st; row <= end; row++)
|
|
2051
|
+
{
|
|
2052
|
+
ptr_row = X_row_major + ix_arr[row]*ncols;
|
|
2053
|
+
w_row = w[ix_arr[row]];
|
|
2054
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
2055
|
+
buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
|
|
2056
|
+
}
|
|
2057
|
+
}
|
|
2058
|
+
|
|
2059
|
+
else
|
|
2060
|
+
{
|
|
2061
|
+
for (size_t row = st; row <= end; row++)
|
|
2062
|
+
{
|
|
2063
|
+
ptr_row = X_row_major + ix_arr[row]*ncols;
|
|
2064
|
+
w_row = w[row];
|
|
2065
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
2066
|
+
buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
|
|
2067
|
+
}
|
|
2068
|
+
}
|
|
2069
|
+
}
|
|
2070
|
+
}
|
|
2071
|
+
|
|
2072
|
+
else
|
|
2073
|
+
{
|
|
2074
|
+
if (!force_cols_use)
|
|
2075
|
+
{
|
|
2076
|
+
size_t ptr_this;
|
|
2077
|
+
if (x_uses_ix_arr)
|
|
2078
|
+
{
|
|
2079
|
+
for (size_t row = st; row <= end; row++)
|
|
2080
|
+
{
|
|
2081
|
+
ptr_this = Xr_indptr[ix_arr[row]];
|
|
2082
|
+
axpy1(w[ix_arr[row]], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
|
|
2083
|
+
}
|
|
2084
|
+
}
|
|
2085
|
+
|
|
2086
|
+
else
|
|
2087
|
+
{
|
|
2088
|
+
for (size_t row = st; row <= end; row++)
|
|
2089
|
+
{
|
|
2090
|
+
ptr_this = Xr_indptr[ix_arr[row]];
|
|
2091
|
+
axpy1(w[row], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
|
|
2092
|
+
}
|
|
2093
|
+
}
|
|
2094
|
+
}
|
|
2095
|
+
|
|
2096
|
+
else
|
|
2097
|
+
{
|
|
2098
|
+
size_t *curr_begin;
|
|
2099
|
+
size_t *row_end;
|
|
2100
|
+
size_t *curr_col;
|
|
2101
|
+
double *restrict Xr_this;
|
|
2102
|
+
size_t *cols_end = cols_use + ncols_use;
|
|
2103
|
+
double w_row;
|
|
2104
|
+
for (size_t row = st; row <= end; row++)
|
|
2105
|
+
{
|
|
2106
|
+
curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
|
|
2107
|
+
row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
|
|
2108
|
+
if (curr_begin == row_end) continue;
|
|
2109
|
+
curr_col = cols_use;
|
|
2110
|
+
Xr_this = Xr + Xr_indptr[ix_arr[row]];
|
|
2111
|
+
w_row = w[x_uses_ix_arr? ix_arr[row] : row];
|
|
2112
|
+
size_t dtemp;
|
|
2113
|
+
|
|
2114
|
+
while (curr_col < cols_end && curr_begin < row_end)
|
|
2115
|
+
{
|
|
2116
|
+
if (*curr_begin == *curr_col)
|
|
2117
|
+
{
|
|
2118
|
+
dtemp = std::distance(cols_use, curr_col);
|
|
2119
|
+
buffer_sum_tot[dtemp]
|
|
2120
|
+
=
|
|
2121
|
+
std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_tot[dtemp]);
|
|
2122
|
+
curr_col++;
|
|
2123
|
+
curr_begin++;
|
|
2124
|
+
}
|
|
2125
|
+
|
|
2126
|
+
else
|
|
2127
|
+
{
|
|
2128
|
+
if (*curr_begin > *curr_col)
|
|
2129
|
+
curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
|
|
2130
|
+
else
|
|
2131
|
+
curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
|
|
2132
|
+
}
|
|
2133
|
+
}
|
|
2134
|
+
}
|
|
2135
|
+
}
|
|
2136
|
+
}
|
|
2137
|
+
|
|
2138
|
+
double best_gain = -HUGE_VAL;
|
|
2139
|
+
double this_gain;
|
|
2140
|
+
double sl, sr;
|
|
2141
|
+
double vleft, vright;
|
|
2142
|
+
double wleft = 0;
|
|
2143
|
+
double w_row;
|
|
2144
|
+
double wright;
|
|
2145
|
+
memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
|
|
2146
|
+
if (Xr_indptr == NULL)
|
|
2147
|
+
{
|
|
2148
|
+
if (!force_cols_use)
|
|
2149
|
+
{
|
|
2150
|
+
for (size_t row = st; row < end; row++)
|
|
2151
|
+
{
|
|
2152
|
+
w_row = w[x_uses_ix_arr? ix_arr[row] : row];
|
|
2153
|
+
wleft += w_row;
|
|
2154
|
+
axpy1(w_row, X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
|
|
2155
|
+
if (x_uses_ix_arr) {
|
|
2156
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
2157
|
+
}
|
|
2158
|
+
else {
|
|
2159
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
2160
|
+
}
|
|
2161
|
+
|
|
2162
|
+
vleft = 0;
|
|
2163
|
+
vright = 0;
|
|
2164
|
+
wright = wtot - wleft;
|
|
2165
|
+
for (size_t col = 0; col < ncols; col++)
|
|
2166
|
+
{
|
|
2167
|
+
sl = buffer_sum_left[col];
|
|
2168
|
+
vleft += sl * (sl / wleft);
|
|
2169
|
+
sr = buffer_sum_tot[col] - sl;
|
|
2170
|
+
vright += sr * (sr / wright);
|
|
2171
|
+
}
|
|
2172
|
+
|
|
2173
|
+
this_gain = vleft + vright;
|
|
2174
|
+
if (this_gain > best_gain)
|
|
2175
|
+
{
|
|
2176
|
+
best_gain = this_gain;
|
|
2177
|
+
split_ix = row;
|
|
2178
|
+
}
|
|
2179
|
+
}
|
|
2180
|
+
}
|
|
2181
|
+
|
|
2182
|
+
else
|
|
2183
|
+
{
|
|
2184
|
+
double *restrict ptr_row;
|
|
2185
|
+
double w_row;
|
|
2186
|
+
for (size_t row = st; row < end; row++)
|
|
2187
|
+
{
|
|
2188
|
+
w_row = w[x_uses_ix_arr? ix_arr[row] : row];
|
|
2189
|
+
wleft += w_row;
|
|
2190
|
+
|
|
2191
|
+
ptr_row = X_row_major + ix_arr[row]*ncols;
|
|
2192
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
2193
|
+
buffer_sum_left[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_left[col]);
|
|
2194
|
+
if (x_uses_ix_arr) {
|
|
2195
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
2196
|
+
}
|
|
2197
|
+
else {
|
|
2198
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
2199
|
+
}
|
|
2200
|
+
|
|
2201
|
+
vleft = 0;
|
|
2202
|
+
vright = 0;
|
|
2203
|
+
wright = wtot - wleft;
|
|
2204
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
2205
|
+
{
|
|
2206
|
+
sl = buffer_sum_left[col];
|
|
2207
|
+
vleft += sl * (sl / wleft);
|
|
2208
|
+
sr = buffer_sum_tot[col] - sl;
|
|
2209
|
+
vright += sr * (sr / wright);
|
|
2210
|
+
}
|
|
2211
|
+
|
|
2212
|
+
this_gain = vleft + vright;
|
|
2213
|
+
if (this_gain > best_gain)
|
|
2214
|
+
{
|
|
2215
|
+
best_gain = this_gain;
|
|
2216
|
+
split_ix = row;
|
|
2217
|
+
}
|
|
2218
|
+
}
|
|
2219
|
+
}
|
|
2220
|
+
}
|
|
2221
|
+
|
|
2222
|
+
else
|
|
2223
|
+
{
|
|
2224
|
+
if (!force_cols_use)
|
|
2225
|
+
{
|
|
2226
|
+
size_t ptr_this;
|
|
2227
|
+
double w_row;
|
|
2228
|
+
for (size_t row = st; row < end; row++)
|
|
2229
|
+
{
|
|
2230
|
+
w_row= w[x_uses_ix_arr? ix_arr[row] : row];
|
|
2231
|
+
wleft += w_row;
|
|
2232
|
+
ptr_this = Xr_indptr[ix_arr[row]];
|
|
2233
|
+
axpy1(w_row, Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
|
|
2234
|
+
if (x_uses_ix_arr) {
|
|
2235
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
2236
|
+
}
|
|
2237
|
+
else {
|
|
2238
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
2239
|
+
}
|
|
2240
|
+
|
|
2241
|
+
vleft = 0;
|
|
2242
|
+
vright = 0;
|
|
2243
|
+
wright = wtot - wleft;
|
|
2244
|
+
for (size_t col = 0; col < ncols; col++)
|
|
2245
|
+
{
|
|
2246
|
+
sl = buffer_sum_left[col];
|
|
2247
|
+
vleft += sl * (sl / wleft);
|
|
2248
|
+
sr = buffer_sum_tot[col] - sl;
|
|
2249
|
+
vright += sr * (sr / wright);
|
|
2250
|
+
}
|
|
2251
|
+
|
|
2252
|
+
this_gain = vleft + vright;
|
|
2253
|
+
if (this_gain > best_gain)
|
|
2254
|
+
{
|
|
2255
|
+
best_gain = this_gain;
|
|
2256
|
+
split_ix = row;
|
|
2257
|
+
}
|
|
2258
|
+
}
|
|
2259
|
+
}
|
|
2260
|
+
|
|
2261
|
+
else
|
|
2262
|
+
{
|
|
2263
|
+
size_t *curr_begin;
|
|
2264
|
+
size_t *row_end;
|
|
2265
|
+
size_t *curr_col;
|
|
2266
|
+
double *restrict Xr_this;
|
|
2267
|
+
size_t *cols_end = cols_use + ncols_use;
|
|
2268
|
+
double w_row;
|
|
2269
|
+
size_t dtemp;
|
|
2270
|
+
for (size_t row = st; row < end; row++)
|
|
2271
|
+
{
|
|
2272
|
+
w_row = w[x_uses_ix_arr? ix_arr[row] : row];
|
|
2273
|
+
wleft += w_row;
|
|
2274
|
+
|
|
2275
|
+
curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
|
|
2276
|
+
row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
|
|
2277
|
+
if (curr_begin == row_end) goto skip_sum;
|
|
2278
|
+
curr_col = cols_use;
|
|
2279
|
+
Xr_this = Xr + Xr_indptr[ix_arr[row]];
|
|
2280
|
+
while (curr_col < cols_end && curr_begin < row_end)
|
|
2281
|
+
{
|
|
2282
|
+
if (*curr_begin == *curr_col)
|
|
2283
|
+
{
|
|
2284
|
+
dtemp = std::distance(cols_use, curr_col);
|
|
2285
|
+
buffer_sum_left[dtemp]
|
|
2286
|
+
=
|
|
2287
|
+
std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_left[dtemp]);
|
|
2288
|
+
curr_col++;
|
|
2289
|
+
curr_begin++;
|
|
2290
|
+
}
|
|
2291
|
+
|
|
2292
|
+
else
|
|
2293
|
+
{
|
|
2294
|
+
if (*curr_begin > *curr_col)
|
|
2295
|
+
curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
|
|
2296
|
+
else
|
|
2297
|
+
curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
|
|
2298
|
+
}
|
|
2299
|
+
}
|
|
2300
|
+
|
|
2301
|
+
skip_sum:
|
|
2302
|
+
if (x_uses_ix_arr) {
|
|
2303
|
+
if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
|
|
2304
|
+
}
|
|
2305
|
+
else {
|
|
2306
|
+
if (unlikely(x[row] == x[row+1])) continue;
|
|
2307
|
+
}
|
|
2308
|
+
|
|
2309
|
+
vleft = 0;
|
|
2310
|
+
vright = 0;
|
|
2311
|
+
wright = wtot - wleft;
|
|
2312
|
+
for (size_t col = 0; col < ncols_use; col++)
|
|
2313
|
+
{
|
|
2314
|
+
sl = buffer_sum_left[col];
|
|
2315
|
+
vleft += sl * (sl / wleft);
|
|
2316
|
+
sr = buffer_sum_tot[col] - sl;
|
|
2317
|
+
vright += sr * (sr / wright);
|
|
2318
|
+
}
|
|
2319
|
+
|
|
2320
|
+
this_gain = vleft + vright;
|
|
2321
|
+
if (this_gain > best_gain)
|
|
2322
|
+
{
|
|
2323
|
+
best_gain = this_gain;
|
|
2324
|
+
split_ix = row;
|
|
2325
|
+
}
|
|
2326
|
+
}
|
|
2327
|
+
}
|
|
2328
|
+
}
|
|
2329
|
+
|
|
2330
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2331
|
+
|
|
2332
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
2333
|
+
return best_gain / wtot;
|
|
2334
|
+
}
|
|
2335
|
+
|
|
2336
|
+
template <class real_t_, class real_t>
|
|
2337
|
+
double find_split_dens_shortform_t(real_t *restrict x, size_t n, double &restrict split_point)
|
|
2338
|
+
{
|
|
2339
|
+
double best_gain = -HUGE_VAL;
|
|
2340
|
+
size_t n_minus_one = n - 1;
|
|
2341
|
+
real_t_ xmin = x[0];
|
|
2342
|
+
real_t_ xmax = x[n-1];
|
|
2343
|
+
real_t_ xleft, xright;
|
|
2344
|
+
real_t_ xmid;
|
|
2345
|
+
double this_gain;
|
|
2346
|
+
size_t split_ix = 0;
|
|
2347
|
+
|
|
2348
|
+
for (size_t ix = 0; ix < n_minus_one; ix++)
|
|
2349
|
+
{
|
|
2350
|
+
if (x[ix] == x[ix+1]) continue;
|
|
2351
|
+
xmid = (real_t_)x[ix] + ((real_t_)x[ix+1] - (real_t_)x[ix]) / (real_t_)2;
|
|
2352
|
+
xleft = xmid - xmin;
|
|
2353
|
+
xright = xmax - xmid;
|
|
2354
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2355
|
+
this_gain = (real_t_)square(ix+1) / xleft + (real_t_)square(n_minus_one - ix) / xright;
|
|
2356
|
+
if (this_gain > best_gain)
|
|
2357
|
+
{
|
|
2358
|
+
best_gain = this_gain;
|
|
2359
|
+
split_ix = ix;
|
|
2360
|
+
}
|
|
2361
|
+
}
|
|
2362
|
+
|
|
2363
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2364
|
+
|
|
2365
|
+
real_t_ xtot = (real_t_)xmax - (real_t_)xmin;
|
|
2366
|
+
real_t_ nleft = (real_t_)(split_ix+1);
|
|
2367
|
+
real_t_ nright = (real_t_)(n_minus_one - split_ix);
|
|
2368
|
+
split_point = midpoint(x[split_ix], x[split_ix+1]);
|
|
2369
|
+
real_t_ rpct_left = split_point / xtot;
|
|
2370
|
+
rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
|
|
2371
|
+
real_t_ rpct_right = (real_t_)1 - rpct_left;
|
|
2372
|
+
rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
|
|
2373
|
+
|
|
2374
|
+
real_t_ nl_sq = nleft / (real_t_)n; nl_sq = square(nl_sq);
|
|
2375
|
+
real_t_ nr_sq = nright / (real_t_)n; nl_sq = square(nr_sq);
|
|
2376
|
+
|
|
2377
|
+
return nl_sq / rpct_left + nr_sq / rpct_right;
|
|
2378
|
+
}
|
|
2379
|
+
|
|
2380
|
+
template <class real_t, class ldouble_safe>
|
|
2381
|
+
double find_split_dens_shortform(real_t *restrict x, size_t n, double &restrict split_point)
|
|
2382
|
+
{
|
|
2383
|
+
if (n < INT32_MAX)
|
|
2384
|
+
return find_split_dens_shortform_t<double, real_t>(x, n, split_point);
|
|
2385
|
+
else
|
|
2386
|
+
return find_split_dens_shortform_t<ldouble_safe, real_t>(x, n, split_point);
|
|
2387
|
+
}
|
|
2388
|
+
|
|
2389
|
+
template <class real_t_, class real_t, class mapping>
|
|
2390
|
+
double find_split_dens_shortform_weighted_t(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
|
|
2391
|
+
{
|
|
2392
|
+
double best_gain = -HUGE_VAL;
|
|
2393
|
+
size_t n_minus_one = n - 1;
|
|
2394
|
+
real_t_ xmin = x[buffer_indices[0]];
|
|
2395
|
+
real_t_ xmax = x[buffer_indices[n-1]];
|
|
2396
|
+
real_t_ xleft, xright;
|
|
2397
|
+
real_t_ xmid;
|
|
2398
|
+
double this_gain;
|
|
2399
|
+
|
|
2400
|
+
real_t_ wtot = 0;
|
|
2401
|
+
for (size_t ix = 0; ix < n; ix++)
|
|
2402
|
+
wtot += w[buffer_indices[ix]];
|
|
2403
|
+
real_t_ w_left = 0;
|
|
2404
|
+
real_t_ w_right;
|
|
2405
|
+
real_t_ best_w = 0;
|
|
2406
|
+
size_t split_ix = 0;
|
|
2407
|
+
|
|
2408
|
+
for (size_t ix = 0; ix < n_minus_one; ix++)
|
|
2409
|
+
{
|
|
2410
|
+
w_left += w[buffer_indices[ix]];
|
|
2411
|
+
if (x[buffer_indices[ix]] == x[buffer_indices[ix+1]]) continue;
|
|
2412
|
+
xmid = (real_t_)x[buffer_indices[ix]] + ((real_t_)x[buffer_indices[ix+1]] - (real_t_)x[buffer_indices[ix]]) / (real_t_)2;
|
|
2413
|
+
xleft = xmid - xmin;
|
|
2414
|
+
xright = xmax - xmid;
|
|
2415
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2416
|
+
|
|
2417
|
+
w_right = wtot - w_left;
|
|
2418
|
+
this_gain = square(w_left) / xleft + square(w_right) / xright;
|
|
2419
|
+
if (this_gain > best_gain)
|
|
2420
|
+
{
|
|
2421
|
+
best_gain = this_gain;
|
|
2422
|
+
best_w = w_left;
|
|
2423
|
+
split_ix = xmid;
|
|
2424
|
+
}
|
|
2425
|
+
}
|
|
2426
|
+
|
|
2427
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2428
|
+
|
|
2429
|
+
real_t_ xtot = xmax - xmin;
|
|
2430
|
+
w_left = best_w;
|
|
2431
|
+
w_right = wtot - w_left;
|
|
2432
|
+
w_left = std::fmax(w_left, std::numeric_limits<double>::min());
|
|
2433
|
+
w_right = std::fmax(w_right, std::numeric_limits<double>::min());
|
|
2434
|
+
split_point = midpoint(x[buffer_indices[split_ix]], x[buffer_indices[split_ix+1]]);
|
|
2435
|
+
real_t_ rpct_left = split_point / xtot;
|
|
2436
|
+
rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
|
|
2437
|
+
real_t_ rpct_right = (real_t_)1 - rpct_left;
|
|
2438
|
+
rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
|
|
2439
|
+
|
|
2440
|
+
real_t_ wl_sq = w_left / wtot; wl_sq = square(wl_sq);
|
|
2441
|
+
real_t_ wr_sq = w_right / wtot; wl_sq = square(wr_sq);
|
|
2442
|
+
|
|
2443
|
+
return wl_sq / rpct_left + wr_sq / rpct_right;
|
|
2444
|
+
}
|
|
2445
|
+
|
|
2446
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
2447
|
+
double find_split_dens_shortform_weighted(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
|
|
2448
|
+
{
|
|
2449
|
+
if (n < INT32_MAX)
|
|
2450
|
+
return find_split_dens_shortform_weighted_t<double, real_t, mapping>(x, n, split_point, w, buffer_indices);
|
|
2451
|
+
else
|
|
2452
|
+
return find_split_dens_shortform_weighted_t<ldouble_safe, real_t, mapping>(x, n, split_point, w, buffer_indices);
|
|
2453
|
+
}
|
|
2454
|
+
|
|
2455
|
+
template <class real_t>
|
|
2456
|
+
double find_split_dens_shortform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2457
|
+
double &restrict split_point, size_t &restrict split_ix)
|
|
2458
|
+
{
|
|
2459
|
+
double best_gain = -HUGE_VAL;
|
|
2460
|
+
real_t xmin = x[ix_arr[st]];
|
|
2461
|
+
real_t xmax = x[ix_arr[end]];
|
|
2462
|
+
real_t xleft, xright;
|
|
2463
|
+
real_t xmid;
|
|
2464
|
+
double this_gain;
|
|
2465
|
+
|
|
2466
|
+
for (size_t row = st; row < end; row++)
|
|
2467
|
+
{
|
|
2468
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
|
|
2469
|
+
xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
|
|
2470
|
+
xleft = xmid - xmin;
|
|
2471
|
+
xright = xmax - xmid;
|
|
2472
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2473
|
+
this_gain = square(row-st+1) / xleft + square(end-row) / xright;
|
|
2474
|
+
if (this_gain > best_gain)
|
|
2475
|
+
{
|
|
2476
|
+
best_gain = this_gain;
|
|
2477
|
+
split_ix = row;
|
|
2478
|
+
}
|
|
2479
|
+
}
|
|
2480
|
+
|
|
2481
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2482
|
+
|
|
2483
|
+
double xtot = (double)xmax - (double)xmin;
|
|
2484
|
+
double nleft = (double)(split_ix-st+1);
|
|
2485
|
+
double nright = (double)(end - split_ix);
|
|
2486
|
+
split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
|
|
2487
|
+
double rpct_left = split_point / xtot;
|
|
2488
|
+
rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
|
|
2489
|
+
double rpct_right = 1. - rpct_left;
|
|
2490
|
+
rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
|
|
2491
|
+
double ntot = (double)(end - st + 1);
|
|
2492
|
+
|
|
2493
|
+
double nl_sq = nleft / ntot; nl_sq = square(nl_sq);
|
|
2494
|
+
double nr_sq = nright / ntot; nl_sq = square(nr_sq);
|
|
2495
|
+
|
|
2496
|
+
return nl_sq / rpct_left + nr_sq / rpct_right;
|
|
2497
|
+
}
|
|
2498
|
+
|
|
2499
|
+
template <class real_t, class mapping>
|
|
2500
|
+
double find_split_dens_shortform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2501
|
+
double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
2502
|
+
{
|
|
2503
|
+
double best_gain = -HUGE_VAL;
|
|
2504
|
+
real_t xmin = x[ix_arr[st]];
|
|
2505
|
+
real_t xmax = x[ix_arr[end]];
|
|
2506
|
+
real_t xleft, xright;
|
|
2507
|
+
real_t xmid;
|
|
2508
|
+
double this_gain;
|
|
2509
|
+
|
|
2510
|
+
double wtot = 0;
|
|
2511
|
+
for (size_t row = st; row <= end; row++)
|
|
2512
|
+
wtot += w[ix_arr[row]];
|
|
2513
|
+
double w_left = 0;
|
|
2514
|
+
double w_right;
|
|
2515
|
+
double best_w = 0;
|
|
2516
|
+
|
|
2517
|
+
for (size_t row = st; row < end; row++)
|
|
2518
|
+
{
|
|
2519
|
+
w_left += w[ix_arr[row]];
|
|
2520
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
|
|
2521
|
+
xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
|
|
2522
|
+
xleft = xmid - xmin;
|
|
2523
|
+
xright = xmax - xmid;
|
|
2524
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2525
|
+
|
|
2526
|
+
w_right = wtot - w_left;
|
|
2527
|
+
this_gain = square(w_left) / xleft + square(w_right) / xright;
|
|
2528
|
+
if (this_gain > best_gain)
|
|
2529
|
+
{
|
|
2530
|
+
best_gain = this_gain;
|
|
2531
|
+
best_w = w_left;
|
|
2532
|
+
split_ix = row;
|
|
2533
|
+
}
|
|
2534
|
+
}
|
|
2535
|
+
|
|
2536
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2537
|
+
|
|
2538
|
+
double xtot = (double)xmax - (double)xmin;
|
|
2539
|
+
w_left = best_w;
|
|
2540
|
+
w_right = wtot - w_left;
|
|
2541
|
+
w_left = std::fmax(w_left, std::numeric_limits<double>::min());
|
|
2542
|
+
w_right = std::fmax(w_right, std::numeric_limits<double>::min());
|
|
2543
|
+
split_point = midpoint(x[split_ix], x[split_ix+1]);
|
|
2544
|
+
double rpct_left = split_point / xtot;
|
|
2545
|
+
rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
|
|
2546
|
+
double rpct_right = 1. - rpct_left;
|
|
2547
|
+
rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
|
|
2548
|
+
|
|
2549
|
+
double wl_sq = w_left / wtot; wl_sq = square(wl_sq);
|
|
2550
|
+
double wr_sq = w_right / wtot; wl_sq = square(wr_sq);
|
|
2551
|
+
|
|
2552
|
+
return wl_sq / rpct_left + wr_sq / rpct_right;
|
|
2553
|
+
}
|
|
2554
|
+
|
|
2555
|
+
/* This is a slower but more numerically-robust form */
|
|
2556
|
+
template <class real_t, class ldouble_safe>
|
|
2557
|
+
double find_split_dens_longform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2558
|
+
double &restrict split_point, size_t &restrict split_ix)
|
|
2559
|
+
{
|
|
2560
|
+
double best_gain = -HUGE_VAL;
|
|
2561
|
+
real_t xmin = x[ix_arr[st]];
|
|
2562
|
+
real_t xmax = x[ix_arr[end]];
|
|
2563
|
+
real_t xleft, xright;
|
|
2564
|
+
real_t xmid;
|
|
2565
|
+
ldouble_safe pct_left, pct_right;
|
|
2566
|
+
ldouble_safe rpct_left, rpct_right;
|
|
2567
|
+
ldouble_safe n_tot = end - st + 1;
|
|
2568
|
+
ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
|
|
2569
|
+
ldouble_safe cnt_left;
|
|
2570
|
+
double this_gain;
|
|
2571
|
+
|
|
2572
|
+
for (size_t row = st; row < end; row++)
|
|
2573
|
+
{
|
|
2574
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
|
|
2575
|
+
xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
|
|
2576
|
+
xleft = xmid - xmin;
|
|
2577
|
+
xright = xmax - xmid;
|
|
2578
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2579
|
+
|
|
2580
|
+
cnt_left = (ldouble_safe)(row-st+1);
|
|
2581
|
+
|
|
2582
|
+
xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
|
|
2583
|
+
xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
|
|
2584
|
+
pct_left = cnt_left / n_tot;
|
|
2585
|
+
pct_right = (ldouble_safe)1 - pct_left;
|
|
2586
|
+
rpct_left = (ldouble_safe)xleft / xtot;
|
|
2587
|
+
rpct_right = (ldouble_safe)xright / xtot;
|
|
2588
|
+
|
|
2589
|
+
this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
|
|
2590
|
+
if (unlikely(is_na_or_inf(this_gain))) continue;
|
|
2591
|
+
if (this_gain > best_gain)
|
|
2592
|
+
{
|
|
2593
|
+
best_gain = this_gain;
|
|
2594
|
+
split_point = xmid;
|
|
2595
|
+
split_ix = row;
|
|
2596
|
+
}
|
|
2597
|
+
}
|
|
2598
|
+
|
|
2599
|
+
return best_gain;
|
|
2600
|
+
}
|
|
2601
|
+
|
|
2602
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
2603
|
+
double find_split_dens_longform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2604
|
+
double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
2605
|
+
{
|
|
2606
|
+
double best_gain = -HUGE_VAL;
|
|
2607
|
+
real_t xmin = x[ix_arr[st]];
|
|
2608
|
+
real_t xmax = x[ix_arr[end]];
|
|
2609
|
+
real_t xleft, xright;
|
|
2610
|
+
real_t xmid;
|
|
2611
|
+
ldouble_safe pct_left, pct_right;
|
|
2612
|
+
ldouble_safe rpct_left, rpct_right;
|
|
2613
|
+
ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
|
|
2614
|
+
double this_gain;
|
|
2615
|
+
|
|
2616
|
+
ldouble_safe wtot = 0;
|
|
2617
|
+
for (size_t row = st; row <= end; row++)
|
|
2618
|
+
wtot += w[ix_arr[row]];
|
|
2619
|
+
ldouble_safe w_left = 0;
|
|
2620
|
+
|
|
2621
|
+
for (size_t row = st; row < end; row++)
|
|
2622
|
+
{
|
|
2623
|
+
w_left += w[ix_arr[row]];
|
|
2624
|
+
if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
|
|
2625
|
+
xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
|
|
2626
|
+
xleft = xmid - xmin;
|
|
2627
|
+
xright = xmax - xmid;
|
|
2628
|
+
if (unlikely(!xleft || !xright)) continue;
|
|
2629
|
+
|
|
2630
|
+
xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
|
|
2631
|
+
xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
|
|
2632
|
+
pct_left = w_left / wtot;
|
|
2633
|
+
pct_right = (ldouble_safe)1 - pct_left;
|
|
2634
|
+
rpct_left = (ldouble_safe)xleft / xtot;
|
|
2635
|
+
rpct_right = (ldouble_safe)xright / xtot;
|
|
2636
|
+
|
|
2637
|
+
this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
|
|
2638
|
+
if (unlikely(is_na_or_inf(this_gain))) continue;
|
|
2639
|
+
if (this_gain > best_gain)
|
|
2640
|
+
{
|
|
2641
|
+
best_gain = this_gain;
|
|
2642
|
+
split_point = xmid;
|
|
2643
|
+
split_ix = row;
|
|
2644
|
+
}
|
|
2645
|
+
}
|
|
2646
|
+
|
|
2647
|
+
return best_gain;
|
|
2648
|
+
}
|
|
2649
|
+
|
|
2650
|
+
template <class real_t, class ldouble_safe>
|
|
2651
|
+
double find_split_dens(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2652
|
+
double &restrict split_point, size_t &restrict split_ix)
|
|
2653
|
+
{
|
|
2654
|
+
if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
|
|
2655
|
+
return find_split_dens_shortform<real_t>(x, ix_arr, st, end, split_point, split_ix);
|
|
2656
|
+
else
|
|
2657
|
+
return find_split_dens_longform<real_t, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
|
|
2658
|
+
}
|
|
2659
|
+
|
|
2660
|
+
template <class real_t, class mapping, class ldouble_safe>
|
|
2661
|
+
double find_split_dens_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2662
|
+
double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
|
|
2663
|
+
{
|
|
2664
|
+
if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
|
|
2665
|
+
return find_split_dens_shortform_weighted<real_t, mapping>(x, ix_arr, st, end, split_point, split_ix, w);
|
|
2666
|
+
else
|
|
2667
|
+
return find_split_dens_longform_weighted<real_t, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
|
|
2668
|
+
}
|
|
2669
|
+
|
|
2670
|
+
template <class int_t, class ldouble_safe>
|
|
2671
|
+
double find_split_dens_longform(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2672
|
+
CategSplit cat_split_type, MissingAction missing_action,
|
|
2673
|
+
int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
|
|
2674
|
+
size_t *restrict buffer_cnt, int_t *restrict buffer_indices)
|
|
2675
|
+
{
|
|
2676
|
+
if (st >= end || ncat <= 1) return -HUGE_VAL;
|
|
2677
|
+
size_t n_nas = 0;
|
|
2678
|
+
int xval;
|
|
2679
|
+
|
|
2680
|
+
/* count categories */
|
|
2681
|
+
memset(buffer_cnt, 0, sizeof(size_t) * ncat);
|
|
2682
|
+
if (missing_action == Fail)
|
|
2683
|
+
{
|
|
2684
|
+
for (size_t row = st; row <= end; row++)
|
|
2685
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
2686
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
2687
|
+
}
|
|
2688
|
+
|
|
2689
|
+
else if (missing_action == Impute)
|
|
2690
|
+
{
|
|
2691
|
+
for (size_t row = st; row <= end; row++)
|
|
2692
|
+
{
|
|
2693
|
+
xval = x[ix_arr[row]];
|
|
2694
|
+
if (unlikely(xval < 0))
|
|
2695
|
+
n_nas++;
|
|
2696
|
+
else
|
|
2697
|
+
buffer_cnt[xval]++;
|
|
2698
|
+
}
|
|
2699
|
+
|
|
2700
|
+
if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
|
|
2701
|
+
|
|
2702
|
+
if (n_nas)
|
|
2703
|
+
{
|
|
2704
|
+
auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
|
|
2705
|
+
*idxmax += n_nas;
|
|
2706
|
+
*saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
|
|
2707
|
+
}
|
|
2708
|
+
}
|
|
2709
|
+
|
|
2710
|
+
else
|
|
2711
|
+
{
|
|
2712
|
+
for (size_t row = st; row <= end; row++)
|
|
2713
|
+
{
|
|
2714
|
+
xval = x[ix_arr[row]];
|
|
2715
|
+
if (likely(xval >= 0)) buffer_cnt[xval]++;
|
|
2716
|
+
}
|
|
2717
|
+
}
|
|
2718
|
+
|
|
2719
|
+
std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
|
|
2720
|
+
std::sort(buffer_indices, buffer_indices + ncat,
|
|
2721
|
+
[&buffer_cnt](const int_t a, const int_t b)
|
|
2722
|
+
{return buffer_cnt[a] < buffer_cnt[b];});
|
|
2723
|
+
|
|
2724
|
+
int curr = 0;
|
|
2725
|
+
if (split_categ != NULL)
|
|
2726
|
+
{
|
|
2727
|
+
while (buffer_cnt[buffer_indices[curr]] == 0)
|
|
2728
|
+
{
|
|
2729
|
+
split_categ[buffer_indices[curr]] = -1;
|
|
2730
|
+
curr++;
|
|
2731
|
+
}
|
|
2732
|
+
}
|
|
2733
|
+
|
|
2734
|
+
else
|
|
2735
|
+
{
|
|
2736
|
+
while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
|
|
2737
|
+
}
|
|
2738
|
+
|
|
2739
|
+
int ncat_present = ncat - curr;
|
|
2740
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
2741
|
+
if (ncat_present == 2)
|
|
2742
|
+
{
|
|
2743
|
+
switch (cat_split_type)
|
|
2744
|
+
{
|
|
2745
|
+
case SingleCateg:
|
|
2746
|
+
{
|
|
2747
|
+
chosen_cat = buffer_indices[curr];
|
|
2748
|
+
break;
|
|
2749
|
+
}
|
|
2750
|
+
|
|
2751
|
+
case SubSet:
|
|
2752
|
+
{
|
|
2753
|
+
split_categ[buffer_indices[curr]] = 1;
|
|
2754
|
+
split_categ[buffer_indices[curr+1]] = 0;
|
|
2755
|
+
break;
|
|
2756
|
+
}
|
|
2757
|
+
}
|
|
2758
|
+
|
|
2759
|
+
ldouble_safe pct_left
|
|
2760
|
+
=
|
|
2761
|
+
(ldouble_safe)buffer_cnt[buffer_indices[curr]]
|
|
2762
|
+
/
|
|
2763
|
+
(ldouble_safe)(
|
|
2764
|
+
buffer_cnt[buffer_indices[curr]]
|
|
2765
|
+
+
|
|
2766
|
+
buffer_cnt[buffer_indices[curr+1]]
|
|
2767
|
+
);
|
|
2768
|
+
|
|
2769
|
+
return ((ldouble_safe)buffer_cnt[buffer_indices[curr]] * (2. * pct_left)
|
|
2770
|
+
+
|
|
2771
|
+
(ldouble_safe)buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
|
|
2772
|
+
/
|
|
2773
|
+
(ldouble_safe)(buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
|
|
2774
|
+
}
|
|
2775
|
+
|
|
2776
|
+
size_t ntot;
|
|
2777
|
+
if (missing_action == Impute)
|
|
2778
|
+
ntot = end - st + 1;
|
|
2779
|
+
else
|
|
2780
|
+
ntot = std::accumulate(buffer_cnt, buffer_cnt + ncat, (size_t)0);
|
|
2781
|
+
if (unlikely(ntot <= 1)) unexpected_error();
|
|
2782
|
+
ldouble_safe ntot_ = (ldouble_safe)ntot;
|
|
2783
|
+
|
|
2784
|
+
switch (cat_split_type)
|
|
2785
|
+
{
|
|
2786
|
+
case SingleCateg:
|
|
2787
|
+
{
|
|
2788
|
+
double pct_one_cat = 1. / (double)ncat_present;
|
|
2789
|
+
double pct_left_smallest = (ldouble_safe)buffer_cnt[buffer_indices[curr]] / ntot_;
|
|
2790
|
+
double gain_smallest
|
|
2791
|
+
=
|
|
2792
|
+
(ldouble_safe)buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
|
|
2793
|
+
+
|
|
2794
|
+
(ldouble_safe)(ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
|
|
2795
|
+
;
|
|
2796
|
+
|
|
2797
|
+
double pct_left_biggest = (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] / ntot_;
|
|
2798
|
+
double gain_biggest
|
|
2799
|
+
=
|
|
2800
|
+
(ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
|
|
2801
|
+
+
|
|
2802
|
+
(ldouble_safe)(ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
|
|
2803
|
+
;
|
|
2804
|
+
|
|
2805
|
+
if (gain_smallest >= gain_biggest)
|
|
2806
|
+
{
|
|
2807
|
+
chosen_cat = buffer_indices[curr];
|
|
2808
|
+
return gain_smallest / ntot_;
|
|
2809
|
+
}
|
|
2810
|
+
|
|
2811
|
+
else
|
|
2812
|
+
{
|
|
2813
|
+
chosen_cat = buffer_indices[ncat-1];
|
|
2814
|
+
return gain_biggest / ntot_;
|
|
2815
|
+
}
|
|
2816
|
+
break;
|
|
2817
|
+
}
|
|
2818
|
+
|
|
2819
|
+
case SubSet:
|
|
2820
|
+
{
|
|
2821
|
+
size_t cnt_left = 0;
|
|
2822
|
+
size_t cnt_right;
|
|
2823
|
+
int st_cat = curr - 1;
|
|
2824
|
+
double this_gain;
|
|
2825
|
+
double best_gain = -HUGE_VAL;
|
|
2826
|
+
int best_cat = 0;
|
|
2827
|
+
ldouble_safe pct_left;
|
|
2828
|
+
double pct_cat_left;
|
|
2829
|
+
double ncat_present_ = (double)ncat_present;
|
|
2830
|
+
for (; curr < ncat; curr++)
|
|
2831
|
+
{
|
|
2832
|
+
cnt_left += buffer_cnt[buffer_indices[curr]];
|
|
2833
|
+
cnt_right = ntot - cnt_left;
|
|
2834
|
+
pct_left = (ldouble_safe)cnt_left / ntot_;
|
|
2835
|
+
pct_cat_left = (double)(curr - st_cat) / ncat_present_;
|
|
2836
|
+
this_gain
|
|
2837
|
+
=
|
|
2838
|
+
(ldouble_safe)cnt_left * (pct_left / pct_cat_left)
|
|
2839
|
+
+
|
|
2840
|
+
(ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
|
|
2841
|
+
;
|
|
2842
|
+
if (this_gain > best_gain)
|
|
2843
|
+
{
|
|
2844
|
+
best_gain = this_gain;
|
|
2845
|
+
best_cat = curr;
|
|
2846
|
+
}
|
|
2847
|
+
}
|
|
2848
|
+
|
|
2849
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
2850
|
+
st_cat++;
|
|
2851
|
+
for (; st_cat <= best_cat; st_cat++)
|
|
2852
|
+
split_categ[buffer_indices[st_cat]] = 1;
|
|
2853
|
+
for (; st_cat < ncat; st_cat++)
|
|
2854
|
+
split_categ[buffer_indices[st_cat]] = 0;
|
|
2855
|
+
return best_gain / ntot_;
|
|
2856
|
+
break;
|
|
2857
|
+
}
|
|
2858
|
+
}
|
|
2859
|
+
|
|
2860
|
+
/* This will not be reached, but CRAN might complain otherwise */
|
|
2861
|
+
return -HUGE_VAL;
|
|
2862
|
+
}
|
|
2863
|
+
|
|
2864
|
+
template <class mapping, class int_t, class ldouble_safe>
|
|
2865
|
+
double find_split_dens_longform_weighted(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
|
|
2866
|
+
CategSplit cat_split_type, MissingAction missing_action,
|
|
2867
|
+
int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
|
|
2868
|
+
int_t *restrict buffer_indices, mapping &restrict w)
|
|
2869
|
+
{
|
|
2870
|
+
if (st >= end || ncat <= 1) return -HUGE_VAL;
|
|
2871
|
+
ldouble_safe w_missing = 0;
|
|
2872
|
+
int xval;
|
|
2873
|
+
size_t ix_;
|
|
2874
|
+
|
|
2875
|
+
/* count categories */
|
|
2876
|
+
/* TODO: allocate this buffer externally */
|
|
2877
|
+
std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
|
|
2878
|
+
if (missing_action == Fail)
|
|
2879
|
+
{
|
|
2880
|
+
for (size_t row = st; row <= end; row++)
|
|
2881
|
+
{
|
|
2882
|
+
ix_ = ix_arr[row];
|
|
2883
|
+
if (unlikely(x[ix_]) < 0) continue;
|
|
2884
|
+
buffer_cnt[x[ix_]] += w[ix_];
|
|
2885
|
+
}
|
|
2886
|
+
}
|
|
2887
|
+
|
|
2888
|
+
else if (missing_action == Impute)
|
|
2889
|
+
{
|
|
2890
|
+
for (size_t row = st; row <= end; row++)
|
|
2891
|
+
{
|
|
2892
|
+
ix_ = ix_arr[row];
|
|
2893
|
+
xval = x[ix_];
|
|
2894
|
+
if (unlikely(xval < 0))
|
|
2895
|
+
w_missing += w[ix_];
|
|
2896
|
+
else
|
|
2897
|
+
buffer_cnt[xval] += w[ix_];
|
|
2898
|
+
}
|
|
2899
|
+
|
|
2900
|
+
if (w_missing)
|
|
2901
|
+
{
|
|
2902
|
+
auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
|
|
2903
|
+
*idxmax += w_missing;
|
|
2904
|
+
*saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
|
|
2905
|
+
}
|
|
2906
|
+
}
|
|
2907
|
+
|
|
2908
|
+
else
|
|
2909
|
+
{
|
|
2910
|
+
for (size_t row = st; row <= end; row++)
|
|
2911
|
+
{
|
|
2912
|
+
ix_ = ix_arr[row];
|
|
2913
|
+
xval = x[ix_];
|
|
2914
|
+
if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
|
|
2915
|
+
}
|
|
2916
|
+
}
|
|
2917
|
+
|
|
2918
|
+
std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
|
|
2919
|
+
std::sort(buffer_indices, buffer_indices + ncat,
|
|
2920
|
+
[&buffer_cnt](const int_t a, const int_t b)
|
|
2921
|
+
{return buffer_cnt[a] < buffer_cnt[b];});
|
|
2922
|
+
|
|
2923
|
+
int curr = 0;
|
|
2924
|
+
if (split_categ != NULL)
|
|
2925
|
+
{
|
|
2926
|
+
while (buffer_cnt[buffer_indices[curr]] == 0)
|
|
2927
|
+
{
|
|
2928
|
+
split_categ[buffer_indices[curr]] = -1;
|
|
2929
|
+
curr++;
|
|
2930
|
+
}
|
|
2931
|
+
}
|
|
2932
|
+
|
|
2933
|
+
else
|
|
2934
|
+
{
|
|
2935
|
+
while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
|
|
2936
|
+
}
|
|
2937
|
+
|
|
2938
|
+
int ncat_present = ncat - curr;
|
|
2939
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
2940
|
+
if (ncat_present == 2)
|
|
2941
|
+
{
|
|
2942
|
+
switch (cat_split_type)
|
|
2943
|
+
{
|
|
2944
|
+
case SingleCateg:
|
|
2945
|
+
{
|
|
2946
|
+
chosen_cat = buffer_indices[curr];
|
|
2947
|
+
break;
|
|
2948
|
+
}
|
|
2949
|
+
|
|
2950
|
+
case SubSet:
|
|
2951
|
+
{
|
|
2952
|
+
split_categ[buffer_indices[curr]] = 1;
|
|
2953
|
+
split_categ[buffer_indices[curr+1]] = 0;
|
|
2954
|
+
break;
|
|
2955
|
+
}
|
|
2956
|
+
}
|
|
2957
|
+
|
|
2958
|
+
ldouble_safe pct_left
|
|
2959
|
+
=
|
|
2960
|
+
buffer_cnt[buffer_indices[curr]]
|
|
2961
|
+
/
|
|
2962
|
+
(
|
|
2963
|
+
buffer_cnt[buffer_indices[curr]]
|
|
2964
|
+
+
|
|
2965
|
+
buffer_cnt[buffer_indices[curr+1]]
|
|
2966
|
+
);
|
|
2967
|
+
|
|
2968
|
+
return (buffer_cnt[buffer_indices[curr]] * (pct_left * 2.)
|
|
2969
|
+
+
|
|
2970
|
+
buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
|
|
2971
|
+
/
|
|
2972
|
+
(buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
|
|
2973
|
+
}
|
|
2974
|
+
|
|
2975
|
+
ldouble_safe ntot = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
|
|
2976
|
+
if (unlikely(ntot <= 0)) unexpected_error();
|
|
2977
|
+
|
|
2978
|
+
switch (cat_split_type)
|
|
2979
|
+
{
|
|
2980
|
+
case SingleCateg:
|
|
2981
|
+
{
|
|
2982
|
+
double pct_one_cat = 1. / (double)ncat_present;
|
|
2983
|
+
double pct_left_smallest = buffer_cnt[buffer_indices[curr]] / ntot;
|
|
2984
|
+
double gain_smallest
|
|
2985
|
+
=
|
|
2986
|
+
buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
|
|
2987
|
+
+
|
|
2988
|
+
(ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
|
|
2989
|
+
;
|
|
2990
|
+
|
|
2991
|
+
double pct_left_biggest = buffer_cnt[buffer_indices[ncat-1]] / ntot;
|
|
2992
|
+
double gain_biggest
|
|
2993
|
+
=
|
|
2994
|
+
buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
|
|
2995
|
+
+
|
|
2996
|
+
(ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
|
|
2997
|
+
;
|
|
2998
|
+
|
|
2999
|
+
if (gain_smallest >= gain_biggest)
|
|
3000
|
+
{
|
|
3001
|
+
chosen_cat = buffer_indices[curr];
|
|
3002
|
+
return gain_smallest / ntot;
|
|
3003
|
+
}
|
|
3004
|
+
|
|
3005
|
+
else
|
|
3006
|
+
{
|
|
3007
|
+
chosen_cat = buffer_indices[ncat-1];
|
|
3008
|
+
return gain_biggest / ntot;
|
|
3009
|
+
}
|
|
3010
|
+
break;
|
|
3011
|
+
}
|
|
3012
|
+
|
|
3013
|
+
case SubSet:
|
|
3014
|
+
{
|
|
3015
|
+
ldouble_safe cnt_left = 0;
|
|
3016
|
+
ldouble_safe cnt_right;
|
|
3017
|
+
int st_cat = curr - 1;
|
|
3018
|
+
double this_gain;
|
|
3019
|
+
double best_gain = -HUGE_VAL;
|
|
3020
|
+
int best_cat = 0;
|
|
3021
|
+
ldouble_safe pct_left;
|
|
3022
|
+
double pct_cat_left;
|
|
3023
|
+
double ncat_present_ = (double)ncat_present;
|
|
3024
|
+
for (; curr < ncat; curr++)
|
|
3025
|
+
{
|
|
3026
|
+
cnt_left += buffer_cnt[buffer_indices[curr]];
|
|
3027
|
+
cnt_right = ntot - cnt_left;
|
|
3028
|
+
pct_left = cnt_left / ntot;
|
|
3029
|
+
pct_cat_left = (double)(curr - st_cat) / ncat_present_;
|
|
3030
|
+
this_gain
|
|
3031
|
+
=
|
|
3032
|
+
(ldouble_safe)cnt_left * (pct_left / pct_cat_left)
|
|
3033
|
+
+
|
|
3034
|
+
(ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
|
|
3035
|
+
;
|
|
3036
|
+
if (this_gain > best_gain)
|
|
3037
|
+
{
|
|
3038
|
+
best_gain = this_gain;
|
|
3039
|
+
best_cat = curr;
|
|
3040
|
+
}
|
|
3041
|
+
}
|
|
3042
|
+
|
|
3043
|
+
if (best_gain <= -HUGE_VAL) return best_gain;
|
|
3044
|
+
st_cat++;
|
|
3045
|
+
for (; st_cat <= best_cat; st_cat++)
|
|
3046
|
+
split_categ[buffer_indices[st_cat]] = 1;
|
|
3047
|
+
for (; st_cat < ncat; st_cat++)
|
|
3048
|
+
split_categ[buffer_indices[st_cat]] = 0;
|
|
3049
|
+
return best_gain / ntot;
|
|
3050
|
+
break;
|
|
3051
|
+
}
|
|
3052
|
+
}
|
|
3053
|
+
|
|
3054
|
+
/* This will not be reached, but CRAN might complain otherwise */
|
|
3055
|
+
return -HUGE_VAL;
|
|
3056
|
+
}
|
|
3057
|
+
|
|
3058
|
+
/* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
|
|
3059
|
+
template <class ldouble_safe>
|
|
3060
|
+
double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion,
|
|
3061
|
+
double min_gain, bool as_relative_gain, double *restrict buffer_sd,
|
|
3062
|
+
double &restrict split_point, double &restrict xmin, double &restrict xmax,
|
|
3063
|
+
size_t *restrict ix_arr_plus_st,
|
|
3064
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3065
|
+
double *restrict X_row_major, size_t ncols,
|
|
3066
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
|
|
3067
|
+
{
|
|
3068
|
+
/* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
|
|
3069
|
+
all numbers are assumed to be small and in the same scale */
|
|
3070
|
+
double gain = 0;
|
|
3071
|
+
if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
|
|
3072
|
+
|
|
3073
|
+
/* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
|
|
3074
|
+
if (unlikely(n == 2))
|
|
3075
|
+
{
|
|
3076
|
+
if (x[0] == x[1]) return -HUGE_VAL;
|
|
3077
|
+
split_point = midpoint_with_reorder(x[0], x[1]);
|
|
3078
|
+
gain = 1.;
|
|
3079
|
+
if (gain > min_gain)
|
|
3080
|
+
return gain;
|
|
3081
|
+
else
|
|
3082
|
+
return 0.;
|
|
3083
|
+
}
|
|
3084
|
+
|
|
3085
|
+
if (criterion == FullGain)
|
|
3086
|
+
{
|
|
3087
|
+
/* TODO: these buffers should be allocated externally */
|
|
3088
|
+
std::vector<size_t> argsorted(n);
|
|
3089
|
+
std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
|
|
3090
|
+
std::sort(argsorted.begin(), argsorted.end(),
|
|
3091
|
+
[&x](const size_t a, const size_t b){return x[a] < x[b];});
|
|
3092
|
+
if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
|
|
3093
|
+
std::vector<double> temp_buffer(n + mult2(ncols));
|
|
3094
|
+
for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
|
|
3095
|
+
for (size_t ix = 0; ix < n; ix++)
|
|
3096
|
+
argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
|
|
3097
|
+
size_t ignored;
|
|
3098
|
+
return find_split_full_gain<double, ldouble_safe>(
|
|
3099
|
+
temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
|
|
3100
|
+
cols_use, ncols_use, force_cols_use,
|
|
3101
|
+
X_row_major, ncols,
|
|
3102
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3103
|
+
temp_buffer.data() + n, temp_buffer.data() + n + ncols,
|
|
3104
|
+
ignored, split_point,
|
|
3105
|
+
false);
|
|
3106
|
+
}
|
|
3107
|
+
|
|
3108
|
+
/* sort in ascending order */
|
|
3109
|
+
std::sort(x, x + n);
|
|
3110
|
+
xmin = x[0]; xmax = x[n-1];
|
|
3111
|
+
if (x[0] == x[n-1]) return -HUGE_VAL;
|
|
3112
|
+
|
|
3113
|
+
if (criterion == Pooled && as_relative_gain && min_gain <= 0)
|
|
3114
|
+
gain = find_split_rel_gain<double, ldouble_safe>(x, n, split_point);
|
|
3115
|
+
else if (criterion == Pooled || criterion == Averaged)
|
|
3116
|
+
gain = find_split_std_gain<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point);
|
|
3117
|
+
else if (criterion == DensityCrit)
|
|
3118
|
+
gain = find_split_dens_shortform<double, ldouble_safe>(x, n, split_point);
|
|
3119
|
+
/* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
|
|
3120
|
+
return std::fmax(0., gain);
|
|
3121
|
+
}
|
|
3122
|
+
|
|
3123
|
+
template <class ldouble_safe>
|
|
3124
|
+
double eval_guided_crit_weighted(double *restrict x, size_t n, GainCriterion criterion,
|
|
3125
|
+
double min_gain, bool as_relative_gain, double *restrict buffer_sd,
|
|
3126
|
+
double &restrict split_point, double &restrict xmin, double &restrict xmax,
|
|
3127
|
+
double *restrict w, size_t *restrict buffer_indices,
|
|
3128
|
+
size_t *restrict ix_arr_plus_st,
|
|
3129
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3130
|
+
double *restrict X_row_major, size_t ncols,
|
|
3131
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
|
|
3132
|
+
{
|
|
3133
|
+
/* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
|
|
3134
|
+
all numbers are assumed to be small and in the same scale */
|
|
3135
|
+
double gain = 0;
|
|
3136
|
+
if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
|
|
3137
|
+
|
|
3138
|
+
/* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
|
|
3139
|
+
if (unlikely(n == 2))
|
|
3140
|
+
{
|
|
3141
|
+
if (x[0] == x[1]) return -HUGE_VAL;
|
|
3142
|
+
split_point = midpoint_with_reorder(x[0], x[1]);
|
|
3143
|
+
gain = 1.;
|
|
3144
|
+
if (gain > min_gain)
|
|
3145
|
+
return gain;
|
|
3146
|
+
else
|
|
3147
|
+
return 0.;
|
|
3148
|
+
}
|
|
3149
|
+
|
|
3150
|
+
/* sort in ascending order */
|
|
3151
|
+
std::iota(buffer_indices, buffer_indices + n, (size_t)0);
|
|
3152
|
+
std::sort(buffer_indices, buffer_indices + n,
|
|
3153
|
+
[&x](const size_t a, const size_t b){return x[a] < x[b];});
|
|
3154
|
+
xmin = x[buffer_indices[0]]; xmax = x[buffer_indices[n-1]];
|
|
3155
|
+
if (xmin == xmax) return -HUGE_VAL;
|
|
3156
|
+
|
|
3157
|
+
if (criterion == Pooled || criterion == Averaged)
|
|
3158
|
+
gain = find_split_std_gain_weighted<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point, w, buffer_indices);
|
|
3159
|
+
else if (criterion == DensityCrit)
|
|
3160
|
+
gain = find_split_dens_shortform_weighted<double, double *restrict, ldouble_safe>(x, n, split_point, w, buffer_indices);
|
|
3161
|
+
else if (criterion == FullGain)
|
|
3162
|
+
{
|
|
3163
|
+
std::vector<size_t> argsorted(n);
|
|
3164
|
+
std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
|
|
3165
|
+
std::sort(argsorted.begin(), argsorted.end(),
|
|
3166
|
+
[&x](const size_t a, const size_t b){return x[a] < x[b];});
|
|
3167
|
+
if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
|
|
3168
|
+
std::vector<double> temp_buffer(n + mult2(ncols));
|
|
3169
|
+
for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
|
|
3170
|
+
for (size_t ix = 0; ix < n; ix++)
|
|
3171
|
+
argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
|
|
3172
|
+
size_t ignored;
|
|
3173
|
+
gain = find_split_full_gain_weighted<double, double *restrict, ldouble_safe>(
|
|
3174
|
+
temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
|
|
3175
|
+
cols_use, ncols_use, force_cols_use,
|
|
3176
|
+
X_row_major, ncols,
|
|
3177
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3178
|
+
temp_buffer.data() + n, temp_buffer.data() + n + ncols,
|
|
3179
|
+
ignored, split_point,
|
|
3180
|
+
false,
|
|
3181
|
+
w);
|
|
3182
|
+
}
|
|
3183
|
+
/* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
|
|
3184
|
+
return std::fmax(0., gain);
|
|
3185
|
+
}
|
|
3186
|
+
|
|
3187
|
+
/* for split-criterion in single-variable splits */
|
|
3188
|
+
template <class real_t_, class ldouble_safe>
|
|
3189
|
+
double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
|
|
3190
|
+
double *restrict buffer_sd, bool as_relative_gain,
|
|
3191
|
+
double *restrict buffer_imputed_x, double *restrict saved_xmedian,
|
|
3192
|
+
size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
|
|
3193
|
+
GainCriterion criterion, double min_gain, MissingAction missing_action,
|
|
3194
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3195
|
+
double *restrict X_row_major, size_t ncols,
|
|
3196
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
|
|
3197
|
+
{
|
|
3198
|
+
size_t st_orig = st;
|
|
3199
|
+
double gain = 0;
|
|
3200
|
+
if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
|
|
3201
|
+
|
|
3202
|
+
/* move NAs to the front if there's any, exclude them from calculations */
|
|
3203
|
+
if (missing_action != Fail)
|
|
3204
|
+
st = move_NAs_to_front(ix_arr, st, end, x);
|
|
3205
|
+
|
|
3206
|
+
if (unlikely(st >= end)) return -HUGE_VAL;
|
|
3207
|
+
else if (unlikely(st == (end-1)))
|
|
3208
|
+
{
|
|
3209
|
+
if (x[ix_arr[st]] == x[ix_arr[end]])
|
|
3210
|
+
return -HUGE_VAL;
|
|
3211
|
+
split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
|
|
3212
|
+
split_ix = st;
|
|
3213
|
+
gain = 1.;
|
|
3214
|
+
if (gain > min_gain)
|
|
3215
|
+
return gain;
|
|
3216
|
+
else
|
|
3217
|
+
return 0.;
|
|
3218
|
+
}
|
|
3219
|
+
|
|
3220
|
+
/* sort in ascending order */
|
|
3221
|
+
std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
|
|
3222
|
+
if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
|
|
3223
|
+
xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
|
|
3224
|
+
|
|
3225
|
+
/* unlike the previous case for the extended model, the data here has not been centered,
|
|
3226
|
+
which could make the standard deviations have poor precision. It's nevertheless not
|
|
3227
|
+
necessary for this mean to have good precision, since it's only meant for centering,
|
|
3228
|
+
so it can be calculated inexactly with simd instructions. */
|
|
3229
|
+
real_t_ xmean = 0;
|
|
3230
|
+
if (criterion == Pooled || criterion == Averaged)
|
|
3231
|
+
{
|
|
3232
|
+
for (size_t ix = st; ix <= end; ix++)
|
|
3233
|
+
xmean += x[ix_arr[ix]];
|
|
3234
|
+
xmean /= (real_t_)(end - st + 1);
|
|
3235
|
+
}
|
|
3236
|
+
|
|
3237
|
+
if (missing_action == Impute && st > st_orig)
|
|
3238
|
+
{
|
|
3239
|
+
missing_action = Fail;
|
|
3240
|
+
fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
|
|
3241
|
+
if (criterion == Pooled && as_relative_gain && min_gain <= 0)
|
|
3242
|
+
gain = find_split_rel_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix);
|
|
3243
|
+
else if (criterion == Pooled || criterion == Averaged)
|
|
3244
|
+
gain = find_split_std_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix);
|
|
3245
|
+
else if (criterion == DensityCrit)
|
|
3246
|
+
gain = find_split_dens<double, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix);
|
|
3247
|
+
else if (criterion == FullGain)
|
|
3248
|
+
{
|
|
3249
|
+
/* TODO: this buffer should be allocated from outside */
|
|
3250
|
+
std::vector<double> temp_buffer(mult2(ncols));
|
|
3251
|
+
gain = find_split_full_gain<double, ldouble_safe>(
|
|
3252
|
+
buffer_imputed_x, st_orig, end, ix_arr,
|
|
3253
|
+
cols_use, ncols_use, force_cols_use,
|
|
3254
|
+
X_row_major, ncols,
|
|
3255
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3256
|
+
temp_buffer.data(), temp_buffer.data() + ncols,
|
|
3257
|
+
split_ix, split_point, true);
|
|
3258
|
+
}
|
|
3259
|
+
|
|
3260
|
+
/* Note: in theory, it should be possible to use a faster version assuming a contiguous array for 'x',
|
|
3261
|
+
but such an approach might give inexact split points. Better to avoid such inexactness at the
|
|
3262
|
+
expense of more computations. */
|
|
3263
|
+
}
|
|
3264
|
+
|
|
3265
|
+
else
|
|
3266
|
+
{
|
|
3267
|
+
if (criterion == Pooled && as_relative_gain && min_gain <= 0)
|
|
3268
|
+
gain = find_split_rel_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix);
|
|
3269
|
+
else if (criterion == Pooled || criterion == Averaged)
|
|
3270
|
+
gain = find_split_std_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix);
|
|
3271
|
+
else if (criterion == DensityCrit)
|
|
3272
|
+
gain = find_split_dens<real_t_, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
|
|
3273
|
+
else if (criterion == FullGain)
|
|
3274
|
+
{
|
|
3275
|
+
/* TODO: this buffer should be allocated from outside */
|
|
3276
|
+
std::vector<double> temp_buffer(mult2(ncols));
|
|
3277
|
+
gain = find_split_full_gain<real_t_, ldouble_safe>(
|
|
3278
|
+
x, st, end, ix_arr,
|
|
3279
|
+
cols_use, ncols_use, force_cols_use,
|
|
3280
|
+
X_row_major, ncols,
|
|
3281
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3282
|
+
temp_buffer.data(), temp_buffer.data() + ncols,
|
|
3283
|
+
split_ix, split_point, true);
|
|
3284
|
+
}
|
|
3285
|
+
}
|
|
3286
|
+
|
|
3287
|
+
/* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
|
|
3288
|
+
return std::fmax(0., gain);
|
|
3289
|
+
}
|
|
3290
|
+
|
|
3291
|
+
template <class real_t_, class mapping, class ldouble_safe>
|
|
3292
|
+
double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
|
|
3293
|
+
double *restrict buffer_sd, bool as_relative_gain,
|
|
3294
|
+
double *restrict buffer_imputed_x, double *restrict saved_xmedian,
|
|
3295
|
+
size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
|
|
3296
|
+
GainCriterion criterion, double min_gain, MissingAction missing_action,
|
|
3297
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3298
|
+
double *restrict X_row_major, size_t ncols,
|
|
3299
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
|
|
3300
|
+
mapping &restrict w)
|
|
3301
|
+
{
|
|
3302
|
+
size_t st_orig = st;
|
|
3303
|
+
double gain = 0;
|
|
3304
|
+
if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
|
|
3305
|
+
|
|
3306
|
+
/* move NAs to the front if there's any, exclude them from calculations */
|
|
3307
|
+
if (missing_action != Fail)
|
|
3308
|
+
st = move_NAs_to_front(ix_arr, st, end, x);
|
|
3309
|
+
|
|
3310
|
+
if (unlikely(st >= end)) return -HUGE_VAL;
|
|
3311
|
+
else if (unlikely(st == (end-1)))
|
|
3312
|
+
{
|
|
3313
|
+
if (x[ix_arr[st]] == x[ix_arr[end]])
|
|
3314
|
+
return -HUGE_VAL;
|
|
3315
|
+
split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
|
|
3316
|
+
split_ix = st;
|
|
3317
|
+
gain = 1.;
|
|
3318
|
+
if (gain > min_gain)
|
|
3319
|
+
return gain;
|
|
3320
|
+
else
|
|
3321
|
+
return 0.;
|
|
3322
|
+
}
|
|
3323
|
+
|
|
3324
|
+
/* sort in ascending order */
|
|
3325
|
+
std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
|
|
3326
|
+
if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
|
|
3327
|
+
xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
|
|
3328
|
+
|
|
3329
|
+
/* unlike the previous case for the extended model, the data here has not been centered,
|
|
3330
|
+
which could make the standard deviations have poor precision. It's nevertheless not
|
|
3331
|
+
necessary for this mean to have good precision, since it's only meant for centering,
|
|
3332
|
+
so it can be calculated inexactly with simd instructions. */
|
|
3333
|
+
real_t_ xmean = 0;
|
|
3334
|
+
real_t_ cnt = 0;
|
|
3335
|
+
if (criterion == Pooled || criterion == Averaged)
|
|
3336
|
+
{
|
|
3337
|
+
for (size_t ix = st; ix <= end; ix++)
|
|
3338
|
+
{
|
|
3339
|
+
xmean += x[ix_arr[ix]];
|
|
3340
|
+
cnt += w[ix_arr[ix]];
|
|
3341
|
+
}
|
|
3342
|
+
xmean /= cnt;
|
|
3343
|
+
}
|
|
3344
|
+
|
|
3345
|
+
if (missing_action == Impute && st > st_orig)
|
|
3346
|
+
{
|
|
3347
|
+
missing_action = Fail;
|
|
3348
|
+
fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
|
|
3349
|
+
if (criterion == Pooled && as_relative_gain && min_gain <= 0)
|
|
3350
|
+
gain = find_split_rel_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix, w);
|
|
3351
|
+
else if (criterion == Pooled || criterion == Averaged)
|
|
3352
|
+
gain = find_split_std_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
|
|
3353
|
+
else if (criterion == DensityCrit)
|
|
3354
|
+
gain = find_split_dens_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix, w);
|
|
3355
|
+
else if (criterion == FullGain)
|
|
3356
|
+
{
|
|
3357
|
+
std::vector<double> temp_buffer(mult2(ncols));
|
|
3358
|
+
gain = find_split_full_gain_weighted<double, mapping, ldouble_safe>(
|
|
3359
|
+
buffer_imputed_x, st_orig, end, ix_arr,
|
|
3360
|
+
cols_use, ncols_use, force_cols_use,
|
|
3361
|
+
X_row_major, ncols,
|
|
3362
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3363
|
+
temp_buffer.data(), temp_buffer.data() + ncols,
|
|
3364
|
+
split_ix, split_point, true,
|
|
3365
|
+
w);
|
|
3366
|
+
}
|
|
3367
|
+
}
|
|
3368
|
+
|
|
3369
|
+
else
|
|
3370
|
+
{
|
|
3371
|
+
if (criterion == Pooled && as_relative_gain && min_gain <= 0)
|
|
3372
|
+
gain = find_split_rel_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
|
|
3373
|
+
else if (criterion == Pooled || criterion == Averaged)
|
|
3374
|
+
gain = find_split_std_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
|
|
3375
|
+
else if (criterion == DensityCrit)
|
|
3376
|
+
gain = find_split_dens_weighted<real_t_, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
|
|
3377
|
+
else if (criterion == FullGain)
|
|
3378
|
+
{
|
|
3379
|
+
std::vector<double> temp_buffer(mult2(ncols));
|
|
3380
|
+
gain = find_split_full_gain_weighted<real_t_, mapping, ldouble_safe>(
|
|
3381
|
+
x, st, end, ix_arr,
|
|
3382
|
+
cols_use, ncols_use, force_cols_use,
|
|
3383
|
+
X_row_major, ncols,
|
|
3384
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3385
|
+
temp_buffer.data(), temp_buffer.data() + ncols,
|
|
3386
|
+
split_ix, split_point, true,
|
|
3387
|
+
w);
|
|
3388
|
+
}
|
|
3389
|
+
}
|
|
3390
|
+
|
|
3391
|
+
/* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
|
|
3392
|
+
return std::fmax(0., gain);
|
|
3393
|
+
}
|
|
3394
|
+
|
|
3395
|
+
/* TODO: here it should only need to look at the non-zero entries. It can then use the
|
|
3396
|
+
same algorithm as above, but putting an extra check to see where do the zeros fit in
|
|
3397
|
+
the sorted order of the non-zero entries while calculating gains and SDs, and then
|
|
3398
|
+
call the 'divide_subset' function after-the-fact to reach the same end result.
|
|
3399
|
+
It should be much faster than this if the non-zero entries are few. */
|
|
3400
|
+
template <class real_t_, class sparse_ix, class ldouble_safe>
|
|
3401
|
+
double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
|
|
3402
|
+
size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
|
|
3403
|
+
double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
|
|
3404
|
+
double *restrict saved_xmedian,
|
|
3405
|
+
double &split_point, double &xmin, double &xmax,
|
|
3406
|
+
GainCriterion criterion, double min_gain, MissingAction missing_action,
|
|
3407
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3408
|
+
double *restrict X_row_major, size_t ncols,
|
|
3409
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
|
|
3410
|
+
{
|
|
3411
|
+
size_t ignored;
|
|
3412
|
+
|
|
3413
|
+
|
|
3414
|
+
todense(ix_arr, st, end,
|
|
3415
|
+
col_num, Xc, Xc_ind, Xc_indptr,
|
|
3416
|
+
buffer_arr);
|
|
3417
|
+
size_t tot = end - st + 1;
|
|
3418
|
+
std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
|
|
3419
|
+
|
|
3420
|
+
if (missing_action == Impute)
|
|
3421
|
+
{
|
|
3422
|
+
missing_action = Fail;
|
|
3423
|
+
for (size_t ix = 0; ix < tot; ix++)
|
|
3424
|
+
{
|
|
3425
|
+
if (unlikely(is_na_or_inf(buffer_arr[ix])))
|
|
3426
|
+
{
|
|
3427
|
+
goto fill_missing;
|
|
3428
|
+
}
|
|
3429
|
+
}
|
|
3430
|
+
goto no_nas;
|
|
3431
|
+
|
|
3432
|
+
fill_missing:
|
|
3433
|
+
{
|
|
3434
|
+
size_t idx_half = div2(tot);
|
|
3435
|
+
std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
|
|
3436
|
+
[&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
|
|
3437
|
+
*saved_xmedian = buffer_arr[buffer_pos[idx_half]];
|
|
3438
|
+
|
|
3439
|
+
if ((tot % 2) == 0)
|
|
3440
|
+
{
|
|
3441
|
+
double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
|
|
3442
|
+
*saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
|
|
3443
|
+
}
|
|
3444
|
+
|
|
3445
|
+
for (size_t ix = 0; ix < tot; ix++)
|
|
3446
|
+
buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
|
|
3447
|
+
std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
|
|
3448
|
+
}
|
|
3449
|
+
}
|
|
3450
|
+
|
|
3451
|
+
no_nas:
|
|
3452
|
+
return eval_guided_crit<double, ldouble_safe>(
|
|
3453
|
+
buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
|
|
3454
|
+
as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
|
|
3455
|
+
xmin, xmax, criterion, min_gain, missing_action,
|
|
3456
|
+
cols_use, ncols_use, force_cols_use,
|
|
3457
|
+
X_row_major, ncols,
|
|
3458
|
+
Xr, Xr_ind, Xr_indptr);
|
|
3459
|
+
}
|
|
3460
|
+
|
|
3461
|
+
template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
|
|
3462
|
+
double eval_guided_crit_weighted(size_t ix_arr[], size_t st, size_t end,
|
|
3463
|
+
size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
|
|
3464
|
+
double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
|
|
3465
|
+
double *restrict saved_xmedian,
|
|
3466
|
+
double &restrict split_point, double &restrict xmin, double &restrict xmax,
|
|
3467
|
+
GainCriterion criterion, double min_gain, MissingAction missing_action,
|
|
3468
|
+
size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
|
|
3469
|
+
double *restrict X_row_major, size_t ncols,
|
|
3470
|
+
double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
|
|
3471
|
+
mapping &restrict w)
|
|
3472
|
+
{
|
|
3473
|
+
size_t ignored;
|
|
3474
|
+
|
|
3475
|
+
|
|
3476
|
+
todense(ix_arr, st, end,
|
|
3477
|
+
col_num, Xc, Xc_ind, Xc_indptr,
|
|
3478
|
+
buffer_arr);
|
|
3479
|
+
size_t tot = end - st + 1;
|
|
3480
|
+
std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
|
|
3481
|
+
|
|
3482
|
+
|
|
3483
|
+
if (missing_action == Impute)
|
|
3484
|
+
{
|
|
3485
|
+
missing_action = Fail;
|
|
3486
|
+
for (size_t ix = 0; ix < tot; ix++)
|
|
3487
|
+
{
|
|
3488
|
+
if (unlikely(is_na_or_inf(buffer_arr[ix])))
|
|
3489
|
+
{
|
|
3490
|
+
goto fill_missing;
|
|
3491
|
+
}
|
|
3492
|
+
}
|
|
3493
|
+
goto no_nas;
|
|
3494
|
+
|
|
3495
|
+
fill_missing:
|
|
3496
|
+
{
|
|
3497
|
+
size_t idx_half = div2(tot);
|
|
3498
|
+
std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
|
|
3499
|
+
[&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
|
|
3500
|
+
*saved_xmedian = buffer_arr[buffer_pos[idx_half]];
|
|
3501
|
+
|
|
3502
|
+
if ((tot % 2) == 0)
|
|
3503
|
+
{
|
|
3504
|
+
double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
|
|
3505
|
+
*saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
|
|
3506
|
+
}
|
|
3507
|
+
|
|
3508
|
+
for (size_t ix = 0; ix < tot; ix++)
|
|
3509
|
+
buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
|
|
3510
|
+
std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
|
|
3511
|
+
}
|
|
3512
|
+
}
|
|
3513
|
+
|
|
3514
|
+
|
|
3515
|
+
no_nas:
|
|
3516
|
+
/* TODO: allocate this buffer externally */
|
|
3517
|
+
std::vector<double> buffer_w(tot);
|
|
3518
|
+
for (size_t row = st; row <= end; row++)
|
|
3519
|
+
buffer_w[row-st] = w[ix_arr[row]];
|
|
3520
|
+
/* TODO: in this case, as the weights match with the order of the indices, could use a faster version
|
|
3521
|
+
with a weighted rel_gain function instead (not yet implemented). */
|
|
3522
|
+
return eval_guided_crit_weighted<double, std::vector<double>, ldouble_safe>(
|
|
3523
|
+
buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
|
|
3524
|
+
as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
|
|
3525
|
+
xmin, xmax, criterion, min_gain, missing_action,
|
|
3526
|
+
cols_use, ncols_use, force_cols_use,
|
|
3527
|
+
X_row_major, ncols,
|
|
3528
|
+
Xr, Xr_ind, Xr_indptr,
|
|
3529
|
+
buffer_w);
|
|
3530
|
+
}
|
|
3531
|
+
|
|
3532
|
+
/* How this works:
|
|
3533
|
+
- For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
|
|
3534
|
+
if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
|
|
3535
|
+
numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
|
|
3536
|
+
branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
|
|
3537
|
+
splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
|
|
3538
|
+
splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
|
|
3539
|
+
smallest or the largest category to assign to one branch, but sometimes picks groups too.
|
|
3540
|
+
- For Pooled criterion, will take Shannon entropy, which tends to make a more even split. In the case of splitting
|
|
3541
|
+
by a single category, it always puts the largest category in a separate branch. In the case of subsets,
|
|
3542
|
+
it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
|
|
3543
|
+
or evaluate only splits in which categories are grouped in sorted order, just like for the Averaged criterion.
|
|
3544
|
+
Splitting by averaged Gini gain (like with Averaged) also always selects the second-largest category to put in one branch,
|
|
3545
|
+
while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
|
|
3546
|
+
Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
|
|
3547
|
+
*/
|
|
3548
|
+
/* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
|
|
3549
|
+
/* TODO: 'buffer_pos' doesn't need to be 'size_t', 'int' would suffice */
|
|
3550
|
+
template <class ldouble_safe>
|
|
3551
|
+
double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
|
|
3552
|
+
int *restrict saved_cat_mode,
|
|
3553
|
+
size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
|
|
3554
|
+
int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
|
|
3555
|
+
GainCriterion criterion, double min_gain, bool all_perm,
|
|
3556
|
+
MissingAction missing_action, CategSplit cat_split_type)
|
|
3557
|
+
{
|
|
3558
|
+
if (criterion == DensityCrit)
|
|
3559
|
+
return find_split_dens_longform<size_t, ldouble_safe>(
|
|
3560
|
+
x, ncat, ix_arr, st, end,
|
|
3561
|
+
cat_split_type, missing_action,
|
|
3562
|
+
chosen_cat, split_categ, saved_cat_mode,
|
|
3563
|
+
buffer_cnt, buffer_pos);
|
|
3564
|
+
if (st >= end) return -HUGE_VAL;
|
|
3565
|
+
size_t n_nas = 0;
|
|
3566
|
+
int xval;
|
|
3567
|
+
|
|
3568
|
+
/* count categories */
|
|
3569
|
+
memset(buffer_cnt, 0, sizeof(size_t) * ncat);
|
|
3570
|
+
if (missing_action == Fail)
|
|
3571
|
+
{
|
|
3572
|
+
for (size_t row = st; row <= end; row++)
|
|
3573
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
3574
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
3575
|
+
}
|
|
3576
|
+
|
|
3577
|
+
else if (missing_action == Impute)
|
|
3578
|
+
{
|
|
3579
|
+
for (size_t row = st; row <= end; row++)
|
|
3580
|
+
{
|
|
3581
|
+
xval = x[ix_arr[row]];
|
|
3582
|
+
if (unlikely(xval < 0))
|
|
3583
|
+
n_nas++;
|
|
3584
|
+
else
|
|
3585
|
+
buffer_cnt[xval]++;
|
|
3586
|
+
}
|
|
3587
|
+
|
|
3588
|
+
if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
|
|
3589
|
+
|
|
3590
|
+
if (n_nas)
|
|
3591
|
+
{
|
|
3592
|
+
auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
|
|
3593
|
+
*idxmax += n_nas;
|
|
3594
|
+
*saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
|
|
3595
|
+
}
|
|
3596
|
+
}
|
|
3597
|
+
|
|
3598
|
+
else
|
|
3599
|
+
{
|
|
3600
|
+
for (size_t row = st; row <= end; row++)
|
|
3601
|
+
{
|
|
3602
|
+
xval = x[ix_arr[row]];
|
|
3603
|
+
if (likely(xval >= 0)) buffer_cnt[xval]++;
|
|
3604
|
+
}
|
|
3605
|
+
}
|
|
3606
|
+
|
|
3607
|
+
double this_gain = -HUGE_VAL;
|
|
3608
|
+
double best_gain = -HUGE_VAL;
|
|
3609
|
+
std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
|
|
3610
|
+
size_t st_pos = 0;
|
|
3611
|
+
|
|
3612
|
+
switch (cat_split_type)
|
|
3613
|
+
{
|
|
3614
|
+
case SingleCateg:
|
|
3615
|
+
{
|
|
3616
|
+
size_t cnt = end - st + 1;
|
|
3617
|
+
ldouble_safe cnt_l = (ldouble_safe) cnt;
|
|
3618
|
+
size_t ncat_present = 0;
|
|
3619
|
+
|
|
3620
|
+
switch(criterion)
|
|
3621
|
+
{
|
|
3622
|
+
case Averaged:
|
|
3623
|
+
{
|
|
3624
|
+
/* move zero-counts to the beginning */
|
|
3625
|
+
size_t temp;
|
|
3626
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
3627
|
+
{
|
|
3628
|
+
if (buffer_cnt[cat])
|
|
3629
|
+
{
|
|
3630
|
+
ncat_present++;
|
|
3631
|
+
buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
|
|
3632
|
+
}
|
|
3633
|
+
|
|
3634
|
+
else
|
|
3635
|
+
{
|
|
3636
|
+
temp = buffer_pos[st_pos];
|
|
3637
|
+
buffer_pos[st_pos] = buffer_pos[cat];
|
|
3638
|
+
buffer_pos[cat] = temp;
|
|
3639
|
+
st_pos++;
|
|
3640
|
+
}
|
|
3641
|
+
}
|
|
3642
|
+
|
|
3643
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
3644
|
+
|
|
3645
|
+
double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
|
|
3646
|
+
|
|
3647
|
+
/* try isolating each category one at a time */
|
|
3648
|
+
for (size_t pos = st_pos; (int)pos < ncat; pos++)
|
|
3649
|
+
{
|
|
3650
|
+
this_gain = sd_gain(sd_full,
|
|
3651
|
+
0.0,
|
|
3652
|
+
(expected_sd_cat_single<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)));
|
|
3653
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
3654
|
+
{
|
|
3655
|
+
best_gain = this_gain;
|
|
3656
|
+
chosen_cat = buffer_pos[pos];
|
|
3657
|
+
}
|
|
3658
|
+
}
|
|
3659
|
+
break;
|
|
3660
|
+
}
|
|
3661
|
+
|
|
3662
|
+
case Pooled:
|
|
3663
|
+
{
|
|
3664
|
+
/* here it will always pick the largest one */
|
|
3665
|
+
size_t ncat_present = 0;
|
|
3666
|
+
size_t cnt_max = 0;
|
|
3667
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
3668
|
+
{
|
|
3669
|
+
if (buffer_cnt[cat])
|
|
3670
|
+
{
|
|
3671
|
+
ncat_present++;
|
|
3672
|
+
if (cnt_max < buffer_cnt[cat])
|
|
3673
|
+
{
|
|
3674
|
+
cnt_max = buffer_cnt[cat];
|
|
3675
|
+
chosen_cat = cat;
|
|
3676
|
+
}
|
|
3677
|
+
}
|
|
3678
|
+
}
|
|
3679
|
+
|
|
3680
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
3681
|
+
|
|
3682
|
+
ldouble_safe cnt_left = (ldouble_safe)((end - st + 1) - cnt_max);
|
|
3683
|
+
this_gain = (
|
|
3684
|
+
(ldouble_safe)cnt * std::log((ldouble_safe)cnt)
|
|
3685
|
+
- cnt_left * std::log(cnt_left)
|
|
3686
|
+
- (ldouble_safe)cnt_max * std::log((ldouble_safe)cnt_max)
|
|
3687
|
+
) / cnt;
|
|
3688
|
+
best_gain = (this_gain > min_gain)? this_gain : best_gain;
|
|
3689
|
+
break;
|
|
3690
|
+
}
|
|
3691
|
+
|
|
3692
|
+
default:
|
|
3693
|
+
{
|
|
3694
|
+
unexpected_error();
|
|
3695
|
+
break;
|
|
3696
|
+
}
|
|
3697
|
+
}
|
|
3698
|
+
break;
|
|
3699
|
+
}
|
|
3700
|
+
|
|
3701
|
+
case SubSet:
|
|
3702
|
+
{
|
|
3703
|
+
/* sort by counts */
|
|
3704
|
+
std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
|
|
3705
|
+
|
|
3706
|
+
/* set split as: (1):left (0):right (-1):not_present */
|
|
3707
|
+
memset(buffer_split, 0, ncat * sizeof(signed char));
|
|
3708
|
+
|
|
3709
|
+
ldouble_safe cnt = (ldouble_safe)(end - st + 1);
|
|
3710
|
+
|
|
3711
|
+
switch(criterion)
|
|
3712
|
+
{
|
|
3713
|
+
case Averaged:
|
|
3714
|
+
{
|
|
3715
|
+
/* determine first non-zero and convert to probabilities */
|
|
3716
|
+
double sd_full;
|
|
3717
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
3718
|
+
{
|
|
3719
|
+
if (buffer_cnt[buffer_pos[cat]])
|
|
3720
|
+
{
|
|
3721
|
+
buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
|
|
3722
|
+
}
|
|
3723
|
+
|
|
3724
|
+
else
|
|
3725
|
+
{
|
|
3726
|
+
buffer_split[buffer_pos[cat]] = -1;
|
|
3727
|
+
st_pos++;
|
|
3728
|
+
}
|
|
3729
|
+
}
|
|
3730
|
+
|
|
3731
|
+
if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
|
|
3732
|
+
|
|
3733
|
+
/* calculate full SD assuming they take values randomly ~Unif(0, 1) */
|
|
3734
|
+
size_t ncat_present = (size_t)ncat - st_pos;
|
|
3735
|
+
sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
|
|
3736
|
+
if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
|
|
3737
|
+
|
|
3738
|
+
/* move categories one at a time */
|
|
3739
|
+
for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
|
|
3740
|
+
{
|
|
3741
|
+
buffer_split[buffer_pos[pos]] = 1;
|
|
3742
|
+
this_gain = sd_gain(sd_full,
|
|
3743
|
+
(expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
|
|
3744
|
+
(expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
|
|
3745
|
+
);
|
|
3746
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
3747
|
+
{
|
|
3748
|
+
best_gain = this_gain;
|
|
3749
|
+
memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
|
|
3750
|
+
}
|
|
3751
|
+
}
|
|
3752
|
+
|
|
3753
|
+
break;
|
|
3754
|
+
}
|
|
3755
|
+
|
|
3756
|
+
case Pooled:
|
|
3757
|
+
{
|
|
3758
|
+
ldouble_safe s = 0;
|
|
3759
|
+
|
|
3760
|
+
/* determine first non-zero and get base info */
|
|
3761
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
3762
|
+
{
|
|
3763
|
+
if (buffer_cnt[buffer_pos[cat]])
|
|
3764
|
+
{
|
|
3765
|
+
s += (buffer_cnt[buffer_pos[cat]] <= 1)?
|
|
3766
|
+
0 : ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
|
|
3767
|
+
}
|
|
3768
|
+
|
|
3769
|
+
else
|
|
3770
|
+
{
|
|
3771
|
+
buffer_split[buffer_pos[cat]] = -1;
|
|
3772
|
+
st_pos++;
|
|
3773
|
+
}
|
|
3774
|
+
}
|
|
3775
|
+
|
|
3776
|
+
if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
|
|
3777
|
+
|
|
3778
|
+
/* calculate base info */
|
|
3779
|
+
ldouble_safe base_info = cnt * std::log(cnt) - s;
|
|
3780
|
+
|
|
3781
|
+
if (all_perm)
|
|
3782
|
+
{
|
|
3783
|
+
size_t cnt_left, cnt_right;
|
|
3784
|
+
double s_left, s_right;
|
|
3785
|
+
size_t ncat_present = (size_t)ncat - st_pos;
|
|
3786
|
+
size_t ncomb = pow2(ncat_present) - 1;
|
|
3787
|
+
size_t best_combin;
|
|
3788
|
+
|
|
3789
|
+
for (size_t combin = 1; combin < ncomb; combin++)
|
|
3790
|
+
{
|
|
3791
|
+
cnt_left = 0; cnt_right = 0;
|
|
3792
|
+
s_left = 0; s_right = 0;
|
|
3793
|
+
for (size_t pos = st_pos; (int)pos < ncat; pos++)
|
|
3794
|
+
{
|
|
3795
|
+
if (extract_bit(combin, pos))
|
|
3796
|
+
{
|
|
3797
|
+
cnt_left += buffer_cnt[buffer_pos[pos]];
|
|
3798
|
+
s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
3799
|
+
0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
|
|
3800
|
+
* std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
3801
|
+
}
|
|
3802
|
+
|
|
3803
|
+
else
|
|
3804
|
+
{
|
|
3805
|
+
cnt_right += buffer_cnt[buffer_pos[pos]];
|
|
3806
|
+
s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
3807
|
+
0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
|
|
3808
|
+
* std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
3809
|
+
}
|
|
3810
|
+
}
|
|
3811
|
+
|
|
3812
|
+
this_gain = categ_gain<size_t, ldouble_safe>(
|
|
3813
|
+
cnt_left, cnt_right,
|
|
3814
|
+
s_left, s_right,
|
|
3815
|
+
base_info, cnt);
|
|
3816
|
+
|
|
3817
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
3818
|
+
{
|
|
3819
|
+
best_gain = this_gain;
|
|
3820
|
+
best_combin = combin;
|
|
3821
|
+
}
|
|
3822
|
+
|
|
3823
|
+
}
|
|
3824
|
+
|
|
3825
|
+
if (best_gain > min_gain)
|
|
3826
|
+
for (size_t pos = 0; pos < ncat_present; pos++)
|
|
3827
|
+
split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
|
|
3828
|
+
|
|
3829
|
+
}
|
|
3830
|
+
|
|
3831
|
+
else
|
|
3832
|
+
{
|
|
3833
|
+
/* try moving the categories one at a time */
|
|
3834
|
+
size_t cnt_left = 0;
|
|
3835
|
+
size_t cnt_right = end - st + 1;
|
|
3836
|
+
double s_left = 0;
|
|
3837
|
+
double s_right = s;
|
|
3838
|
+
|
|
3839
|
+
for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
|
|
3840
|
+
{
|
|
3841
|
+
buffer_split[buffer_pos[pos]] = 1;
|
|
3842
|
+
s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
3843
|
+
0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
|
|
3844
|
+
s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
3845
|
+
0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
|
|
3846
|
+
cnt_left += buffer_cnt[buffer_pos[pos]];
|
|
3847
|
+
cnt_right -= buffer_cnt[buffer_pos[pos]];
|
|
3848
|
+
|
|
3849
|
+
this_gain = categ_gain<size_t, ldouble_safe>(
|
|
3850
|
+
cnt_left, cnt_right,
|
|
3851
|
+
s_left, s_right,
|
|
3852
|
+
base_info, cnt);
|
|
3853
|
+
|
|
3854
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
3855
|
+
{
|
|
3856
|
+
best_gain = this_gain;
|
|
3857
|
+
memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
|
|
3858
|
+
}
|
|
3859
|
+
}
|
|
3860
|
+
}
|
|
3861
|
+
|
|
3862
|
+
break;
|
|
3863
|
+
}
|
|
3864
|
+
|
|
3865
|
+
default:
|
|
3866
|
+
{
|
|
3867
|
+
unexpected_error();
|
|
3868
|
+
break;
|
|
3869
|
+
}
|
|
3870
|
+
}
|
|
3871
|
+
}
|
|
3872
|
+
}
|
|
3873
|
+
|
|
3874
|
+
if (st == (end-1)) return 0;
|
|
3875
|
+
|
|
3876
|
+
if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
|
|
3877
|
+
return 0;
|
|
3878
|
+
else
|
|
3879
|
+
return best_gain;
|
|
3880
|
+
}
|
|
3881
|
+
|
|
3882
|
+
|
|
3883
|
+
template <class mapping, class ldouble_safe>
|
|
3884
|
+
double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
|
|
3885
|
+
int *restrict saved_cat_mode,
|
|
3886
|
+
size_t *restrict buffer_pos, double *restrict buffer_prob,
|
|
3887
|
+
int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
|
|
3888
|
+
GainCriterion criterion, double min_gain, bool all_perm,
|
|
3889
|
+
MissingAction missing_action, CategSplit cat_split_type,
|
|
3890
|
+
mapping &restrict w)
|
|
3891
|
+
{
|
|
3892
|
+
if (criterion == DensityCrit)
|
|
3893
|
+
return find_split_dens_longform_weighted<mapping, size_t, ldouble_safe>(
|
|
3894
|
+
x, ncat, ix_arr, st, end,
|
|
3895
|
+
cat_split_type, missing_action,
|
|
3896
|
+
chosen_cat, split_categ, saved_cat_mode,
|
|
3897
|
+
buffer_pos, w);
|
|
3898
|
+
if (st >= end) return -HUGE_VAL;
|
|
3899
|
+
ldouble_safe w_missing = 0;
|
|
3900
|
+
int xval;
|
|
3901
|
+
size_t ix_;
|
|
3902
|
+
|
|
3903
|
+
/* count categories */
|
|
3904
|
+
/* TODO: allocate this buffer externally */
|
|
3905
|
+
std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
|
|
3906
|
+
if (missing_action == Fail)
|
|
3907
|
+
{
|
|
3908
|
+
for (size_t row = st; row <= end; row++)
|
|
3909
|
+
{
|
|
3910
|
+
ix_ = ix_arr[row];
|
|
3911
|
+
if (unlikely(x[ix_]) < 0) continue;
|
|
3912
|
+
buffer_cnt[x[ix_]] += w[ix_];
|
|
3913
|
+
}
|
|
3914
|
+
}
|
|
3915
|
+
|
|
3916
|
+
else if (missing_action == Impute)
|
|
3917
|
+
{
|
|
3918
|
+
for (size_t row = st; row <= end; row++)
|
|
3919
|
+
{
|
|
3920
|
+
ix_ = ix_arr[row];
|
|
3921
|
+
xval = x[ix_];
|
|
3922
|
+
if (unlikely(xval < 0))
|
|
3923
|
+
w_missing += w[ix_];
|
|
3924
|
+
else
|
|
3925
|
+
buffer_cnt[xval] += w[ix_];
|
|
3926
|
+
}
|
|
3927
|
+
|
|
3928
|
+
if (w_missing)
|
|
3929
|
+
{
|
|
3930
|
+
auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
|
|
3931
|
+
*idxmax += w_missing;
|
|
3932
|
+
*saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
|
|
3933
|
+
}
|
|
3934
|
+
}
|
|
3935
|
+
|
|
3936
|
+
else
|
|
3937
|
+
{
|
|
3938
|
+
for (size_t row = st; row <= end; row++)
|
|
3939
|
+
{
|
|
3940
|
+
ix_ = ix_arr[row];
|
|
3941
|
+
xval = x[ix_];
|
|
3942
|
+
if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
|
|
3943
|
+
}
|
|
3944
|
+
}
|
|
3945
|
+
|
|
3946
|
+
ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
|
|
3947
|
+
|
|
3948
|
+
double this_gain = -HUGE_VAL;
|
|
3949
|
+
double best_gain = -HUGE_VAL;
|
|
3950
|
+
std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
|
|
3951
|
+
size_t st_pos = 0;
|
|
3952
|
+
|
|
3953
|
+
switch(cat_split_type)
|
|
3954
|
+
{
|
|
3955
|
+
case SingleCateg:
|
|
3956
|
+
{
|
|
3957
|
+
size_t ncat_present = 0;
|
|
3958
|
+
|
|
3959
|
+
switch(criterion)
|
|
3960
|
+
{
|
|
3961
|
+
case Averaged:
|
|
3962
|
+
{
|
|
3963
|
+
/* move zero-counts to the beginning */
|
|
3964
|
+
size_t temp;
|
|
3965
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
3966
|
+
{
|
|
3967
|
+
if (buffer_cnt[cat])
|
|
3968
|
+
{
|
|
3969
|
+
ncat_present++;
|
|
3970
|
+
buffer_prob[cat] = buffer_cnt[cat] / cnt;
|
|
3971
|
+
}
|
|
3972
|
+
|
|
3973
|
+
else
|
|
3974
|
+
{
|
|
3975
|
+
temp = buffer_pos[st_pos];
|
|
3976
|
+
buffer_pos[st_pos] = buffer_pos[cat];
|
|
3977
|
+
buffer_pos[cat] = temp;
|
|
3978
|
+
st_pos++;
|
|
3979
|
+
}
|
|
3980
|
+
}
|
|
3981
|
+
|
|
3982
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
3983
|
+
|
|
3984
|
+
double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
|
|
3985
|
+
|
|
3986
|
+
/* try isolating each category one at a time */
|
|
3987
|
+
for (size_t pos = st_pos; (int)pos < ncat; pos++)
|
|
3988
|
+
{
|
|
3989
|
+
this_gain = sd_gain(sd_full,
|
|
3990
|
+
0.0,
|
|
3991
|
+
(expected_sd_cat_single<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt))
|
|
3992
|
+
);
|
|
3993
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
3994
|
+
{
|
|
3995
|
+
best_gain = this_gain;
|
|
3996
|
+
chosen_cat = buffer_pos[pos];
|
|
3997
|
+
}
|
|
3998
|
+
}
|
|
3999
|
+
break;
|
|
4000
|
+
}
|
|
4001
|
+
|
|
4002
|
+
case Pooled:
|
|
4003
|
+
{
|
|
4004
|
+
/* here it will always pick the largest one */
|
|
4005
|
+
size_t ncat_present = 0;
|
|
4006
|
+
ldouble_safe cnt_max = 0;
|
|
4007
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
4008
|
+
{
|
|
4009
|
+
if (buffer_cnt[cat])
|
|
4010
|
+
{
|
|
4011
|
+
ncat_present++;
|
|
4012
|
+
if (cnt_max < buffer_cnt[cat])
|
|
4013
|
+
{
|
|
4014
|
+
cnt_max = buffer_cnt[cat];
|
|
4015
|
+
chosen_cat = cat;
|
|
4016
|
+
}
|
|
4017
|
+
}
|
|
4018
|
+
}
|
|
4019
|
+
|
|
4020
|
+
if (ncat_present <= 1) return -HUGE_VAL;
|
|
4021
|
+
|
|
4022
|
+
ldouble_safe cnt_left = (ldouble_safe)(cnt - cnt_max);
|
|
4023
|
+
|
|
4024
|
+
/* TODO: think of a better way of dealing with numbers between zero and one */
|
|
4025
|
+
this_gain = (
|
|
4026
|
+
std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt))
|
|
4027
|
+
- std::fmax((ldouble_safe)1, cnt_left) * std::log(std::fmax((ldouble_safe)1, cnt_left))
|
|
4028
|
+
- std::fmax((ldouble_safe)1, cnt_max) * std::log(std::fmax((ldouble_safe)1, cnt_max))
|
|
4029
|
+
) / std::fmax((ldouble_safe)1, cnt);
|
|
4030
|
+
best_gain = (this_gain > min_gain)? this_gain : best_gain;
|
|
4031
|
+
break;
|
|
4032
|
+
}
|
|
4033
|
+
|
|
4034
|
+
default:
|
|
4035
|
+
{
|
|
4036
|
+
unexpected_error();
|
|
4037
|
+
break;
|
|
4038
|
+
}
|
|
4039
|
+
}
|
|
4040
|
+
break;
|
|
4041
|
+
}
|
|
4042
|
+
|
|
4043
|
+
case SubSet:
|
|
4044
|
+
{
|
|
4045
|
+
/* sort by counts */
|
|
4046
|
+
std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
|
|
4047
|
+
|
|
4048
|
+
/* set split as: (1):left (0):right (-1):not_present */
|
|
4049
|
+
memset(buffer_split, 0, ncat * sizeof(signed char));
|
|
4050
|
+
|
|
4051
|
+
|
|
4052
|
+
switch(criterion)
|
|
4053
|
+
{
|
|
4054
|
+
case Averaged:
|
|
4055
|
+
{
|
|
4056
|
+
/* determine first non-zero and convert to probabilities */
|
|
4057
|
+
double sd_full;
|
|
4058
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
4059
|
+
{
|
|
4060
|
+
if (buffer_cnt[buffer_pos[cat]])
|
|
4061
|
+
{
|
|
4062
|
+
buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
|
|
4063
|
+
}
|
|
4064
|
+
|
|
4065
|
+
else
|
|
4066
|
+
{
|
|
4067
|
+
buffer_split[buffer_pos[cat]] = -1;
|
|
4068
|
+
st_pos++;
|
|
4069
|
+
}
|
|
4070
|
+
}
|
|
4071
|
+
|
|
4072
|
+
if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
|
|
4073
|
+
|
|
4074
|
+
/* calculate full SD assuming they take values randomly ~Unif(0, 1) */
|
|
4075
|
+
size_t ncat_present = (size_t)ncat - st_pos;
|
|
4076
|
+
sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
|
|
4077
|
+
if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
|
|
4078
|
+
|
|
4079
|
+
/* move categories one at a time */
|
|
4080
|
+
for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
|
|
4081
|
+
{
|
|
4082
|
+
buffer_split[buffer_pos[pos]] = 1;
|
|
4083
|
+
/* TODO: is this correct? */
|
|
4084
|
+
this_gain = sd_gain(sd_full,
|
|
4085
|
+
(expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
|
|
4086
|
+
(expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
|
|
4087
|
+
);
|
|
4088
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
4089
|
+
{
|
|
4090
|
+
best_gain = this_gain;
|
|
4091
|
+
memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
|
|
4092
|
+
}
|
|
4093
|
+
}
|
|
4094
|
+
|
|
4095
|
+
break;
|
|
4096
|
+
}
|
|
4097
|
+
|
|
4098
|
+
case Pooled:
|
|
4099
|
+
{
|
|
4100
|
+
ldouble_safe s = 0;
|
|
4101
|
+
|
|
4102
|
+
/* determine first non-zero and get base info */
|
|
4103
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
4104
|
+
{
|
|
4105
|
+
if (buffer_cnt[buffer_pos[cat]])
|
|
4106
|
+
{
|
|
4107
|
+
s += (buffer_cnt[buffer_pos[cat]] <= 1)?
|
|
4108
|
+
(ldouble_safe)0
|
|
4109
|
+
:
|
|
4110
|
+
((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
|
|
4111
|
+
}
|
|
4112
|
+
|
|
4113
|
+
else
|
|
4114
|
+
{
|
|
4115
|
+
buffer_split[buffer_pos[cat]] = -1;
|
|
4116
|
+
st_pos++;
|
|
4117
|
+
}
|
|
4118
|
+
}
|
|
4119
|
+
|
|
4120
|
+
if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
|
|
4121
|
+
|
|
4122
|
+
/* calculate base info */
|
|
4123
|
+
ldouble_safe base_info = std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt)) - s;
|
|
4124
|
+
|
|
4125
|
+
if (all_perm)
|
|
4126
|
+
{
|
|
4127
|
+
size_t cnt_left, cnt_right;
|
|
4128
|
+
double s_left, s_right;
|
|
4129
|
+
size_t ncat_present = (size_t)ncat - st_pos;
|
|
4130
|
+
size_t ncomb = pow2(ncat_present) - 1;
|
|
4131
|
+
size_t best_combin;
|
|
4132
|
+
|
|
4133
|
+
for (size_t combin = 1; combin < ncomb; combin++)
|
|
4134
|
+
{
|
|
4135
|
+
cnt_left = 0; cnt_right = 0;
|
|
4136
|
+
s_left = 0; s_right = 0;
|
|
4137
|
+
for (size_t pos = st_pos; (int)pos < ncat; pos++)
|
|
4138
|
+
{
|
|
4139
|
+
if (extract_bit(combin, pos))
|
|
4140
|
+
{
|
|
4141
|
+
cnt_left += buffer_cnt[buffer_pos[pos]];
|
|
4142
|
+
s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
4143
|
+
(ldouble_safe)0
|
|
4144
|
+
:
|
|
4145
|
+
((ldouble_safe) buffer_cnt[buffer_pos[pos]]
|
|
4146
|
+
* std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
4147
|
+
}
|
|
4148
|
+
|
|
4149
|
+
else
|
|
4150
|
+
{
|
|
4151
|
+
cnt_right += buffer_cnt[buffer_pos[pos]];
|
|
4152
|
+
s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
4153
|
+
(ldouble_safe)0
|
|
4154
|
+
:
|
|
4155
|
+
((ldouble_safe) buffer_cnt[buffer_pos[pos]]
|
|
4156
|
+
* std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
4157
|
+
}
|
|
4158
|
+
}
|
|
4159
|
+
|
|
4160
|
+
this_gain = categ_gain<size_t, ldouble_safe>(
|
|
4161
|
+
cnt_left, cnt_right,
|
|
4162
|
+
s_left, s_right,
|
|
4163
|
+
base_info, cnt);
|
|
4164
|
+
|
|
4165
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
4166
|
+
{
|
|
4167
|
+
best_gain = this_gain;
|
|
4168
|
+
best_combin = combin;
|
|
4169
|
+
}
|
|
4170
|
+
|
|
4171
|
+
}
|
|
4172
|
+
|
|
4173
|
+
if (best_gain > min_gain)
|
|
4174
|
+
for (size_t pos = 0; pos < ncat_present; pos++)
|
|
4175
|
+
split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
|
|
4176
|
+
|
|
4177
|
+
}
|
|
4178
|
+
|
|
4179
|
+
else
|
|
4180
|
+
{
|
|
4181
|
+
/* try moving the categories one at a time */
|
|
4182
|
+
size_t cnt_left = 0;
|
|
4183
|
+
size_t cnt_right = end - st + 1;
|
|
4184
|
+
double s_left = 0;
|
|
4185
|
+
double s_right = s;
|
|
4186
|
+
|
|
4187
|
+
for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
|
|
4188
|
+
{
|
|
4189
|
+
buffer_split[buffer_pos[pos]] = 1;
|
|
4190
|
+
s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
4191
|
+
(ldouble_safe)0
|
|
4192
|
+
:
|
|
4193
|
+
((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
4194
|
+
s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
|
|
4195
|
+
(ldouble_safe)0
|
|
4196
|
+
:
|
|
4197
|
+
((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
|
|
4198
|
+
cnt_left += buffer_cnt[buffer_pos[pos]];
|
|
4199
|
+
cnt_right -= buffer_cnt[buffer_pos[pos]];
|
|
4200
|
+
|
|
4201
|
+
this_gain = categ_gain<size_t, ldouble_safe>(
|
|
4202
|
+
cnt_left, cnt_right,
|
|
4203
|
+
s_left, s_right,
|
|
4204
|
+
base_info, cnt);
|
|
4205
|
+
|
|
4206
|
+
if (this_gain > min_gain && this_gain > best_gain)
|
|
4207
|
+
{
|
|
4208
|
+
best_gain = this_gain;
|
|
4209
|
+
memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
|
|
4210
|
+
}
|
|
4211
|
+
}
|
|
4212
|
+
}
|
|
4213
|
+
|
|
4214
|
+
break;
|
|
4215
|
+
}
|
|
4216
|
+
|
|
4217
|
+
default:
|
|
4218
|
+
{
|
|
4219
|
+
unexpected_error();
|
|
4220
|
+
break;
|
|
4221
|
+
}
|
|
4222
|
+
}
|
|
4223
|
+
}
|
|
4224
|
+
}
|
|
4225
|
+
|
|
4226
|
+
if (st == (end-1)) return 0;
|
|
4227
|
+
|
|
4228
|
+
if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
|
|
4229
|
+
return 0;
|
|
4230
|
+
else
|
|
4231
|
+
return best_gain;
|
|
4232
|
+
}
|