isotree 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -1
- data/LICENSE.txt +2 -2
- data/README.md +32 -14
- data/ext/isotree/ext.cpp +144 -31
- data/ext/isotree/extconf.rb +7 -7
- data/lib/isotree/isolation_forest.rb +110 -30
- data/lib/isotree/version.rb +1 -1
- data/vendor/isotree/LICENSE +1 -1
- data/vendor/isotree/README.md +165 -27
- data/vendor/isotree/include/isotree.hpp +2111 -0
- data/vendor/isotree/include/isotree_oop.hpp +394 -0
- data/vendor/isotree/inst/COPYRIGHTS +62 -0
- data/vendor/isotree/src/RcppExports.cpp +525 -52
- data/vendor/isotree/src/Rwrapper.cpp +1931 -268
- data/vendor/isotree/src/c_interface.cpp +953 -0
- data/vendor/isotree/src/crit.hpp +4232 -0
- data/vendor/isotree/src/dist.hpp +1886 -0
- data/vendor/isotree/src/exp_depth_table.hpp +134 -0
- data/vendor/isotree/src/extended.hpp +1444 -0
- data/vendor/isotree/src/external_facing_generic.hpp +399 -0
- data/vendor/isotree/src/fit_model.hpp +2401 -0
- data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
- data/vendor/isotree/src/helpers_iforest.hpp +813 -0
- data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
- data/vendor/isotree/src/indexer.cpp +515 -0
- data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
- data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
- data/vendor/isotree/src/isoforest.hpp +1659 -0
- data/vendor/isotree/src/isotree.hpp +1804 -392
- data/vendor/isotree/src/isotree_exportable.hpp +99 -0
- data/vendor/isotree/src/merge_models.cpp +159 -16
- data/vendor/isotree/src/mult.hpp +1321 -0
- data/vendor/isotree/src/oop_interface.cpp +842 -0
- data/vendor/isotree/src/oop_interface.hpp +278 -0
- data/vendor/isotree/src/other_helpers.hpp +219 -0
- data/vendor/isotree/src/predict.hpp +1932 -0
- data/vendor/isotree/src/python_helpers.hpp +134 -0
- data/vendor/isotree/src/ref_indexer.hpp +154 -0
- data/vendor/isotree/src/robinmap/LICENSE +21 -0
- data/vendor/isotree/src/robinmap/README.md +483 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
- data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
- data/vendor/isotree/src/serialize.cpp +4300 -139
- data/vendor/isotree/src/sql.cpp +141 -59
- data/vendor/isotree/src/subset_models.cpp +174 -0
- data/vendor/isotree/src/utils.hpp +3808 -0
- data/vendor/isotree/src/xoshiro.hpp +467 -0
- data/vendor/isotree/src/ziggurat.hpp +405 -0
- metadata +38 -104
- data/vendor/cereal/LICENSE +0 -24
- data/vendor/cereal/README.md +0 -85
- data/vendor/cereal/include/cereal/access.hpp +0 -351
- data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
- data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
- data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
- data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
- data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
- data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
- data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
- data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
- data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
- data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
- data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
- data/vendor/cereal/include/cereal/details/util.hpp +0 -84
- data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
- data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
- data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
- data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
- data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
- data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
- data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
- data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
- data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
- data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
- data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
- data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
- data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
- data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
- data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
- data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
- data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
- data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
- data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
- data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
- data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
- data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
- data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
- data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
- data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
- data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
- data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
- data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
- data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
- data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
- data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
- data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
- data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
- data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
- data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
- data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
- data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
- data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
- data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
- data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
- data/vendor/cereal/include/cereal/macros.hpp +0 -154
- data/vendor/cereal/include/cereal/specialize.hpp +0 -139
- data/vendor/cereal/include/cereal/types/array.hpp +0 -79
- data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
- data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
- data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
- data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
- data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
- data/vendor/cereal/include/cereal/types/common.hpp +0 -129
- data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
- data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
- data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
- data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
- data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
- data/vendor/cereal/include/cereal/types/list.hpp +0 -62
- data/vendor/cereal/include/cereal/types/map.hpp +0 -36
- data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
- data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
- data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
- data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
- data/vendor/cereal/include/cereal/types/set.hpp +0 -103
- data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
- data/vendor/cereal/include/cereal/types/string.hpp +0 -61
- data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
- data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
- data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
- data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
- data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
- data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
- data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
- data/vendor/cereal/include/cereal/version.hpp +0 -52
- data/vendor/isotree/src/Makevars +0 -4
- data/vendor/isotree/src/crit.cpp +0 -912
- data/vendor/isotree/src/dist.cpp +0 -749
- data/vendor/isotree/src/extended.cpp +0 -790
- data/vendor/isotree/src/fit_model.cpp +0 -1090
- data/vendor/isotree/src/helpers_iforest.cpp +0 -324
- data/vendor/isotree/src/isoforest.cpp +0 -771
- data/vendor/isotree/src/mult.cpp +0 -607
- data/vendor/isotree/src/predict.cpp +0 -853
- data/vendor/isotree/src/utils.cpp +0 -1566
|
@@ -0,0 +1,1321 @@
|
|
|
1
|
+
/* Isolation forests and variations thereof, with adjustments for incorporation
|
|
2
|
+
* of categorical variables and missing values.
|
|
3
|
+
* Written for C++11 standard and aimed at being used in R and Python.
|
|
4
|
+
*
|
|
5
|
+
* This library is based on the following works:
|
|
6
|
+
* [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
7
|
+
* "Isolation forest."
|
|
8
|
+
* 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
|
|
9
|
+
* [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
10
|
+
* "Isolation-based anomaly detection."
|
|
11
|
+
* ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
|
|
12
|
+
* [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
|
|
13
|
+
* "Extended Isolation Forest."
|
|
14
|
+
* arXiv preprint arXiv:1811.02141 (2018).
|
|
15
|
+
* [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
|
|
16
|
+
* "On detecting clustered anomalies using SCiForest."
|
|
17
|
+
* Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
|
|
18
|
+
* [5] https://sourceforge.net/projects/iforest/
|
|
19
|
+
* [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
|
|
20
|
+
* [7] Quinlan, J. Ross. C4.5: programs for machine learning. Elsevier, 2014.
|
|
21
|
+
* [8] Cortes, David.
|
|
22
|
+
* "Distance approximation using Isolation Forests."
|
|
23
|
+
* arXiv preprint arXiv:1910.12362 (2019).
|
|
24
|
+
* [9] Cortes, David.
|
|
25
|
+
* "Imputing missing values with unsupervised random trees."
|
|
26
|
+
* arXiv preprint arXiv:1911.06646 (2019).
|
|
27
|
+
* [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
|
|
28
|
+
* [11] Cortes, David.
|
|
29
|
+
* "Revisiting randomized choices in isolation forests."
|
|
30
|
+
* arXiv preprint arXiv:2110.13402 (2021).
|
|
31
|
+
* [12] Guha, Sudipto, et al.
|
|
32
|
+
* "Robust random cut forest based anomaly detection on streams."
|
|
33
|
+
* International conference on machine learning. PMLR, 2016.
|
|
34
|
+
* [13] Cortes, David.
|
|
35
|
+
* "Isolation forests: looking beyond tree depth."
|
|
36
|
+
* arXiv preprint arXiv:2111.11639 (2021).
|
|
37
|
+
* [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
|
|
38
|
+
* "Isolation kernel and its effect on SVM"
|
|
39
|
+
* Proceedings of the 24th ACM SIGKDD
|
|
40
|
+
* International Conference on Knowledge Discovery & Data Mining. 2018.
|
|
41
|
+
*
|
|
42
|
+
* BSD 2-Clause License
|
|
43
|
+
* Copyright (c) 2019-2022, David Cortes
|
|
44
|
+
* All rights reserved.
|
|
45
|
+
* Redistribution and use in source and binary forms, with or without
|
|
46
|
+
* modification, are permitted provided that the following conditions are met:
|
|
47
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
|
48
|
+
* list of conditions and the following disclaimer.
|
|
49
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
|
50
|
+
* this list of conditions and the following disclaimer in the documentation
|
|
51
|
+
* and/or other materials provided with the distribution.
|
|
52
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
53
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
54
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
55
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
56
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
57
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
58
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
59
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
60
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
61
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
62
|
+
*/
|
|
63
|
+
#include "isotree.hpp"
|
|
64
|
+
|
|
65
|
+
/* FIXME / TODO: here the calculations of medians do not take weights into account */
|
|
66
|
+
|
|
67
|
+
#define SD_MIN 1e-10
|
|
68
|
+
/* https://www.johndcook.com/blog/standard_deviation/ */
|
|
69
|
+
|
|
70
|
+
/* for regular numerical */
|
|
71
|
+
template <class real_t, class real_t_>
|
|
72
|
+
void calc_mean_and_sd_t(size_t ix_arr[], size_t st, size_t end, real_t_ *restrict x,
|
|
73
|
+
MissingAction missing_action, double &restrict x_sd, double &restrict x_mean)
|
|
74
|
+
{
|
|
75
|
+
real_t m = 0;
|
|
76
|
+
real_t s = 0;
|
|
77
|
+
real_t m_prev = x[ix_arr[st]];
|
|
78
|
+
real_t xval;
|
|
79
|
+
|
|
80
|
+
if (missing_action == Fail)
|
|
81
|
+
{
|
|
82
|
+
m_prev = x[ix_arr[st]];
|
|
83
|
+
for (size_t row = st; row <= end; row++)
|
|
84
|
+
{
|
|
85
|
+
xval = x[ix_arr[row]];
|
|
86
|
+
m += (xval - m) / (real_t)(row - st + 1);
|
|
87
|
+
s = std::fma(xval - m, xval - m_prev, s);
|
|
88
|
+
m_prev = m;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
x_mean = m;
|
|
92
|
+
x_sd = std::sqrt(s / (real_t)(end - st + 1));
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
else
|
|
96
|
+
{
|
|
97
|
+
size_t cnt = 0;
|
|
98
|
+
while (is_na_or_inf(m_prev) && st <= end)
|
|
99
|
+
{
|
|
100
|
+
m_prev = x[ix_arr[++st]];
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
for (size_t row = st; row <= end; row++)
|
|
104
|
+
{
|
|
105
|
+
xval = x[ix_arr[row]];
|
|
106
|
+
if (likely(!is_na_or_inf(xval)))
|
|
107
|
+
{
|
|
108
|
+
cnt++;
|
|
109
|
+
m += (xval - m) / (real_t)cnt;
|
|
110
|
+
s = std::fma(xval - m, xval - m_prev, s);
|
|
111
|
+
m_prev = m;
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
x_mean = m;
|
|
116
|
+
x_sd = std::sqrt(s / (real_t)cnt);
|
|
117
|
+
}
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
template <class real_t_>
|
|
121
|
+
double calc_mean_only(size_t ix_arr[], size_t st, size_t end, real_t_ *restrict x)
|
|
122
|
+
{
|
|
123
|
+
size_t cnt = 0;
|
|
124
|
+
double m = 0;
|
|
125
|
+
real_t_ xval;
|
|
126
|
+
for (size_t row = st; row <= end; row++)
|
|
127
|
+
{
|
|
128
|
+
xval = x[ix_arr[row]];
|
|
129
|
+
if (likely(!is_na_or_inf(xval)))
|
|
130
|
+
{
|
|
131
|
+
cnt++;
|
|
132
|
+
m += (xval - m) / (double)cnt;
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
return m;
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
template <class real_t_, class ldouble_safe>
|
|
140
|
+
void calc_mean_and_sd(size_t ix_arr[], size_t st, size_t end, real_t_ *restrict x,
|
|
141
|
+
MissingAction missing_action, double &restrict x_sd, double &restrict x_mean)
|
|
142
|
+
{
|
|
143
|
+
if (end - st + 1 < THRESHOLD_LONG_DOUBLE)
|
|
144
|
+
calc_mean_and_sd_t<double, real_t_>(ix_arr, st, end, x, missing_action, x_sd, x_mean);
|
|
145
|
+
else
|
|
146
|
+
calc_mean_and_sd_t<ldouble_safe, real_t_>(ix_arr, st, end, x, missing_action, x_sd, x_mean);
|
|
147
|
+
x_sd = std::fmax(x_sd, SD_MIN);
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
template <class real_t_, class mapping, class ldouble_safe>
|
|
151
|
+
void calc_mean_and_sd_weighted(size_t ix_arr[], size_t st, size_t end, real_t_ *restrict x, mapping &restrict w,
|
|
152
|
+
MissingAction missing_action, double &restrict x_sd, double &restrict x_mean)
|
|
153
|
+
{
|
|
154
|
+
ldouble_safe cnt = 0;
|
|
155
|
+
ldouble_safe w_this;
|
|
156
|
+
ldouble_safe m = 0;
|
|
157
|
+
ldouble_safe s = 0;
|
|
158
|
+
ldouble_safe m_prev = x[ix_arr[st]];
|
|
159
|
+
ldouble_safe xval;
|
|
160
|
+
while (is_na_or_inf(m_prev) && st <= end)
|
|
161
|
+
{
|
|
162
|
+
m_prev = x[ix_arr[++st]];
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
for (size_t row = st; row <= end; row++)
|
|
166
|
+
{
|
|
167
|
+
xval = x[ix_arr[row]];
|
|
168
|
+
if (likely(!is_na_or_inf(xval)))
|
|
169
|
+
{
|
|
170
|
+
w_this = w[ix_arr[row]];
|
|
171
|
+
cnt += w_this;
|
|
172
|
+
m = std::fma(w_this, (xval - m) / cnt, m);
|
|
173
|
+
s = std::fma(w_this, (xval - m) * (xval - m_prev), s);
|
|
174
|
+
m_prev = m;
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
x_mean = m;
|
|
179
|
+
x_sd = std::sqrt((ldouble_safe)s / (ldouble_safe)cnt);
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
template <class real_t_, class mapping>
|
|
183
|
+
double calc_mean_only_weighted(size_t ix_arr[], size_t st, size_t end, real_t_ *restrict x, mapping &restrict w)
|
|
184
|
+
{
|
|
185
|
+
double cnt = 0;
|
|
186
|
+
double w_this;
|
|
187
|
+
double m = 0;
|
|
188
|
+
for (size_t row = st; row <= end; row++)
|
|
189
|
+
{
|
|
190
|
+
if (likely(!is_na_or_inf(x[ix_arr[row]])))
|
|
191
|
+
{
|
|
192
|
+
w_this = w[ix_arr[row]];
|
|
193
|
+
cnt += w_this;
|
|
194
|
+
m = std::fma(w_this, (x[ix_arr[row]] - m) / cnt, m);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
return m;
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/* for sparse numerical */
|
|
202
|
+
template <class real_t_, class sparse_ix, class real_t>
|
|
203
|
+
void calc_mean_and_sd_(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
204
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
205
|
+
double &restrict x_sd, double &restrict x_mean)
|
|
206
|
+
{
|
|
207
|
+
/* ix_arr must be already sorted beforehand */
|
|
208
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
209
|
+
{
|
|
210
|
+
x_sd = 0;
|
|
211
|
+
x_mean = 0;
|
|
212
|
+
return;
|
|
213
|
+
}
|
|
214
|
+
size_t st_col = Xc_indptr[col_num];
|
|
215
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
216
|
+
size_t curr_pos = st_col;
|
|
217
|
+
size_t ind_end_col = (size_t) Xc_ind[end_col];
|
|
218
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, (size_t)Xc_ind[st_col]);
|
|
219
|
+
|
|
220
|
+
size_t cnt = end - st + 1;
|
|
221
|
+
size_t added = 0;
|
|
222
|
+
real_t m = 0;
|
|
223
|
+
real_t s = 0;
|
|
224
|
+
real_t m_prev = 0;
|
|
225
|
+
|
|
226
|
+
for (size_t *row = ptr_st;
|
|
227
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
228
|
+
)
|
|
229
|
+
{
|
|
230
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
231
|
+
{
|
|
232
|
+
if (unlikely(is_na_or_inf(Xc[curr_pos])))
|
|
233
|
+
{
|
|
234
|
+
cnt--;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
else
|
|
238
|
+
{
|
|
239
|
+
if (added == 0) m_prev = Xc[curr_pos];
|
|
240
|
+
m += (Xc[curr_pos] - m) / (real_t)(++added);
|
|
241
|
+
s = std::fma(Xc[curr_pos] - m, Xc[curr_pos] - m_prev, s);
|
|
242
|
+
m_prev = m;
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
246
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
else
|
|
250
|
+
{
|
|
251
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
252
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
253
|
+
else
|
|
254
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
if (added == 0)
|
|
259
|
+
{
|
|
260
|
+
x_mean = 0;
|
|
261
|
+
x_sd = 0;
|
|
262
|
+
return;
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
/* Note: up to this point:
|
|
266
|
+
m = sum(x)/nnz
|
|
267
|
+
s = sum(x^2) - (1/nnz)*(sum(x)^2)
|
|
268
|
+
Here the standard deviation is given by:
|
|
269
|
+
sigma = (1/n)*(sum(x^2) - (1/n)*(sum(x)^2))
|
|
270
|
+
The difference can be put to a closed form. */
|
|
271
|
+
if (cnt > added)
|
|
272
|
+
{
|
|
273
|
+
s += square(m) * ((real_t)added * ((real_t)1 - (real_t)added/(real_t)cnt));
|
|
274
|
+
m *= (real_t)added / (real_t)cnt;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
x_mean = m;
|
|
278
|
+
x_sd = std::sqrt(s / (real_t)cnt);
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
template <class real_t_, class sparse_ix, class ldouble_safe>
|
|
282
|
+
void calc_mean_and_sd(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
283
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
284
|
+
double &restrict x_sd, double &restrict x_mean)
|
|
285
|
+
{
|
|
286
|
+
if (end - st + 1 < THRESHOLD_LONG_DOUBLE)
|
|
287
|
+
calc_mean_and_sd_<real_t_, sparse_ix, double>(ix_arr, st, end, col_num, Xc, Xc_ind, Xc_indptr, x_sd, x_mean);
|
|
288
|
+
else
|
|
289
|
+
calc_mean_and_sd_<real_t_, sparse_ix, ldouble_safe>(ix_arr, st, end, col_num, Xc, Xc_ind, Xc_indptr, x_sd, x_mean);
|
|
290
|
+
x_sd = std::fmax(SD_MIN, x_sd);
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
template <class real_t_, class sparse_ix, class ldouble_safe>
|
|
294
|
+
double calc_mean_only(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
295
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr)
|
|
296
|
+
{
|
|
297
|
+
/* ix_arr must be already sorted beforehand */
|
|
298
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
299
|
+
return 0;
|
|
300
|
+
size_t st_col = Xc_indptr[col_num];
|
|
301
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
302
|
+
size_t curr_pos = st_col;
|
|
303
|
+
size_t ind_end_col = (size_t) Xc_ind[end_col];
|
|
304
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, (size_t)Xc_ind[st_col]);
|
|
305
|
+
|
|
306
|
+
size_t cnt = end - st + 1;
|
|
307
|
+
size_t added = 0;
|
|
308
|
+
double m = 0;
|
|
309
|
+
|
|
310
|
+
for (size_t *row = ptr_st;
|
|
311
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
312
|
+
)
|
|
313
|
+
{
|
|
314
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
315
|
+
{
|
|
316
|
+
if (unlikely(is_na_or_inf(Xc[curr_pos])))
|
|
317
|
+
cnt--;
|
|
318
|
+
else
|
|
319
|
+
m += (Xc[curr_pos] - m) / (double)(++added);
|
|
320
|
+
|
|
321
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
322
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
else
|
|
326
|
+
{
|
|
327
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
328
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
329
|
+
else
|
|
330
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
if (added == 0)
|
|
335
|
+
return 0;
|
|
336
|
+
|
|
337
|
+
if (cnt > added)
|
|
338
|
+
m *= ((ldouble_safe)added / (ldouble_safe)cnt);
|
|
339
|
+
|
|
340
|
+
return m;
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
|
|
344
|
+
void calc_mean_and_sd_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
345
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
346
|
+
double &restrict x_sd, double &restrict x_mean, mapping &restrict w)
|
|
347
|
+
{
|
|
348
|
+
/* ix_arr must be already sorted beforehand */
|
|
349
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
350
|
+
{
|
|
351
|
+
x_sd = 0;
|
|
352
|
+
x_mean = 0;
|
|
353
|
+
return;
|
|
354
|
+
}
|
|
355
|
+
size_t st_col = Xc_indptr[col_num];
|
|
356
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
357
|
+
size_t curr_pos = st_col;
|
|
358
|
+
size_t ind_end_col = (size_t) Xc_ind[end_col];
|
|
359
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, (size_t)Xc_ind[st_col]);
|
|
360
|
+
|
|
361
|
+
ldouble_safe cnt = 0.;
|
|
362
|
+
for (size_t row = st; row <= end; row++)
|
|
363
|
+
cnt += w[ix_arr[row]];
|
|
364
|
+
ldouble_safe added = 0;
|
|
365
|
+
ldouble_safe m = 0;
|
|
366
|
+
ldouble_safe s = 0;
|
|
367
|
+
ldouble_safe m_prev = 0;
|
|
368
|
+
ldouble_safe w_this;
|
|
369
|
+
|
|
370
|
+
for (size_t *row = ptr_st;
|
|
371
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
372
|
+
)
|
|
373
|
+
{
|
|
374
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
375
|
+
{
|
|
376
|
+
if (unlikely(is_na_or_inf(Xc[curr_pos])))
|
|
377
|
+
{
|
|
378
|
+
cnt -= w[*row];
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
else
|
|
382
|
+
{
|
|
383
|
+
w_this = w[*row];
|
|
384
|
+
if (added == 0) m_prev = Xc[curr_pos];
|
|
385
|
+
added += w_this;
|
|
386
|
+
m = std::fma(w_this, (Xc[curr_pos] - m) / added, m);
|
|
387
|
+
s = std::fma(w_this, (Xc[curr_pos] - m) * (Xc[curr_pos] - m_prev), s);
|
|
388
|
+
m_prev = m;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
392
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
else
|
|
396
|
+
{
|
|
397
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
398
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
399
|
+
else
|
|
400
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
401
|
+
}
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
if (added == 0)
|
|
405
|
+
{
|
|
406
|
+
x_mean = 0;
|
|
407
|
+
x_sd = 0;
|
|
408
|
+
return;
|
|
409
|
+
}
|
|
410
|
+
|
|
411
|
+
/* Note: up to this point:
|
|
412
|
+
m = sum(x)/nnz
|
|
413
|
+
s = sum(x^2) - (1/nnz)*(sum(x)^2)
|
|
414
|
+
Here the standard deviation is given by:
|
|
415
|
+
sigma = (1/n)*(sum(x^2) - (1/n)*(sum(x)^2))
|
|
416
|
+
The difference can be put to a closed form. */
|
|
417
|
+
if (cnt > added)
|
|
418
|
+
{
|
|
419
|
+
s += square(m) * (added * ((ldouble_safe)1 - (ldouble_safe)added/(ldouble_safe)cnt));
|
|
420
|
+
m *= added / cnt;
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
x_mean = m;
|
|
424
|
+
x_sd = std::sqrt(s / (ldouble_safe)cnt);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
|
|
428
|
+
double calc_mean_only_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
|
|
429
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
430
|
+
mapping &restrict w)
|
|
431
|
+
{
|
|
432
|
+
/* ix_arr must be already sorted beforehand */
|
|
433
|
+
if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
|
|
434
|
+
return 0;
|
|
435
|
+
size_t st_col = Xc_indptr[col_num];
|
|
436
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
437
|
+
size_t curr_pos = st_col;
|
|
438
|
+
size_t ind_end_col = (size_t) Xc_ind[end_col];
|
|
439
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, (size_t)Xc_ind[st_col]);
|
|
440
|
+
|
|
441
|
+
ldouble_safe cnt = 0.;
|
|
442
|
+
for (size_t row = st; row <= end; row++)
|
|
443
|
+
cnt += w[ix_arr[row]];
|
|
444
|
+
ldouble_safe added = 0;
|
|
445
|
+
ldouble_safe m = 0;
|
|
446
|
+
ldouble_safe w_this;
|
|
447
|
+
|
|
448
|
+
for (size_t *row = ptr_st;
|
|
449
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
450
|
+
)
|
|
451
|
+
{
|
|
452
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
453
|
+
{
|
|
454
|
+
if (unlikely(is_na_or_inf(Xc[curr_pos]))) {
|
|
455
|
+
cnt -= w[*row];
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
else {
|
|
459
|
+
w_this = w[*row];
|
|
460
|
+
added += w_this;
|
|
461
|
+
m += w_this * (Xc[curr_pos] - m) / added;
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
465
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
else
|
|
469
|
+
{
|
|
470
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
471
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
472
|
+
else
|
|
473
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
474
|
+
}
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
if (added == 0)
|
|
478
|
+
return 0;
|
|
479
|
+
|
|
480
|
+
if (cnt > added)
|
|
481
|
+
m *= (ldouble_safe)added / (ldouble_safe)cnt;
|
|
482
|
+
|
|
483
|
+
return m;
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
/* Note about these functions: they write into an array that does not need to match to 'ix_arr',
|
|
487
|
+
and instead, the index that is stored in ix_arr[n] will have the value in res[n] */
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
/* for regular numerical */
|
|
491
|
+
template <class real_t_>
|
|
492
|
+
void add_linear_comb(size_t ix_arr[], size_t st, size_t end, double *restrict res,
|
|
493
|
+
real_t_ *restrict x, double &coef, double x_sd, double x_mean, double &restrict fill_val,
|
|
494
|
+
MissingAction missing_action, double *restrict buffer_arr,
|
|
495
|
+
size_t *restrict buffer_NAs, bool first_run)
|
|
496
|
+
{
|
|
497
|
+
/* TODO: here don't need the buffer for NAs */
|
|
498
|
+
|
|
499
|
+
if (first_run)
|
|
500
|
+
coef /= x_sd;
|
|
501
|
+
|
|
502
|
+
size_t cnt = 0;
|
|
503
|
+
size_t cnt_NA = 0;
|
|
504
|
+
double *restrict res_write = res - st;
|
|
505
|
+
|
|
506
|
+
if (missing_action == Fail)
|
|
507
|
+
{
|
|
508
|
+
for (size_t row = st; row <= end; row++)
|
|
509
|
+
res_write[row] = std::fma(x[ix_arr[row]] - x_mean, coef, res_write[row]);
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
else
|
|
513
|
+
{
|
|
514
|
+
if (first_run)
|
|
515
|
+
{
|
|
516
|
+
for (size_t row = st; row <= end; row++)
|
|
517
|
+
{
|
|
518
|
+
if (likely(!is_na_or_inf(x[ix_arr[row]])))
|
|
519
|
+
{
|
|
520
|
+
res_write[row] = std::fma(x[ix_arr[row]] - x_mean, coef, res_write[row]);
|
|
521
|
+
buffer_arr[cnt++] = x[ix_arr[row]];
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
else
|
|
525
|
+
{
|
|
526
|
+
buffer_NAs[cnt_NA++] = row;
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
}
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
else
|
|
533
|
+
{
|
|
534
|
+
for (size_t row = st; row <= end; row++)
|
|
535
|
+
{
|
|
536
|
+
res_write[row] += (is_na_or_inf(x[ix_arr[row]]))? fill_val : ( (x[ix_arr[row]]-x_mean) * coef );
|
|
537
|
+
}
|
|
538
|
+
return;
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
size_t mid_ceil = cnt / 2;
|
|
542
|
+
std::partial_sort(buffer_arr, buffer_arr + mid_ceil + 1, buffer_arr + cnt);
|
|
543
|
+
|
|
544
|
+
if ((cnt % 2) == 0)
|
|
545
|
+
fill_val = buffer_arr[mid_ceil-1] + (buffer_arr[mid_ceil] - buffer_arr[mid_ceil-1]) / 2.0;
|
|
546
|
+
else
|
|
547
|
+
fill_val = buffer_arr[mid_ceil];
|
|
548
|
+
|
|
549
|
+
fill_val = (fill_val - x_mean) * coef;
|
|
550
|
+
if (cnt_NA && fill_val)
|
|
551
|
+
{
|
|
552
|
+
for (size_t row = 0; row < cnt_NA; row++)
|
|
553
|
+
res_write[buffer_NAs[row]] += fill_val;
|
|
554
|
+
}
|
|
555
|
+
|
|
556
|
+
}
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
/* for regular numerical */
|
|
560
|
+
template <class real_t_, class mapping, class ldouble_safe>
|
|
561
|
+
void add_linear_comb_weighted(size_t ix_arr[], size_t st, size_t end, double *restrict res,
|
|
562
|
+
real_t_ *restrict x, double &coef, double x_sd, double x_mean, double &restrict fill_val,
|
|
563
|
+
MissingAction missing_action, double *restrict buffer_arr,
|
|
564
|
+
size_t *restrict buffer_NAs, bool first_run, mapping &restrict w)
|
|
565
|
+
{
|
|
566
|
+
/* TODO: here don't need the buffer for NAs */
|
|
567
|
+
|
|
568
|
+
if (first_run)
|
|
569
|
+
coef /= x_sd;
|
|
570
|
+
|
|
571
|
+
size_t cnt = 0;
|
|
572
|
+
size_t cnt_NA = 0;
|
|
573
|
+
double *restrict res_write = res - st;
|
|
574
|
+
ldouble_safe cumw = 0;
|
|
575
|
+
double w_this;
|
|
576
|
+
/* TODO: these buffers should be allocated externally */
|
|
577
|
+
std::vector<double> obs_weight;
|
|
578
|
+
|
|
579
|
+
if (first_run && missing_action != Fail)
|
|
580
|
+
{
|
|
581
|
+
obs_weight.resize(end - st + 1, 0.);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
if (missing_action == Fail)
|
|
585
|
+
{
|
|
586
|
+
for (size_t row = st; row <= end; row++)
|
|
587
|
+
res_write[row] = std::fma(x[ix_arr[row]] - x_mean, coef, res_write[row]);
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
else
|
|
591
|
+
{
|
|
592
|
+
if (first_run)
|
|
593
|
+
{
|
|
594
|
+
for (size_t row = st; row <= end; row++)
|
|
595
|
+
{
|
|
596
|
+
if (likely(!is_na_or_inf(x[ix_arr[row]])))
|
|
597
|
+
{
|
|
598
|
+
w_this = w[ix_arr[row]];
|
|
599
|
+
res_write[row] = std::fma(x[ix_arr[row]] - x_mean, coef, res_write[row]);
|
|
600
|
+
obs_weight[cnt] = w_this;
|
|
601
|
+
buffer_arr[cnt++] = x[ix_arr[row]];
|
|
602
|
+
cumw += w_this;
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
else
|
|
606
|
+
{
|
|
607
|
+
buffer_NAs[cnt_NA++] = row;
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
}
|
|
611
|
+
}
|
|
612
|
+
|
|
613
|
+
else
|
|
614
|
+
{
|
|
615
|
+
for (size_t row = st; row <= end; row++)
|
|
616
|
+
{
|
|
617
|
+
res_write[row] += (is_na_or_inf(x[ix_arr[row]]))? fill_val : ( (x[ix_arr[row]]-x_mean) * coef );
|
|
618
|
+
}
|
|
619
|
+
return;
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
ldouble_safe mid_point = cumw / (ldouble_safe)2;
|
|
624
|
+
std::vector<size_t> sorted_ix(cnt);
|
|
625
|
+
std::iota(sorted_ix.begin(), sorted_ix.end(), (size_t)0);
|
|
626
|
+
std::sort(sorted_ix.begin(), sorted_ix.end(),
|
|
627
|
+
[&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
|
|
628
|
+
ldouble_safe currw = 0;
|
|
629
|
+
fill_val = buffer_arr[sorted_ix.back()]; /* <- will overwrite later */
|
|
630
|
+
/* TODO: is this median calculation correct? should it do a weighted interpolation? */
|
|
631
|
+
for (size_t ix = 0; ix < cnt; ix++)
|
|
632
|
+
{
|
|
633
|
+
currw += obs_weight[sorted_ix[ix]];
|
|
634
|
+
if (currw >= mid_point)
|
|
635
|
+
{
|
|
636
|
+
if (currw == mid_point && ix < cnt-1)
|
|
637
|
+
fill_val = buffer_arr[sorted_ix[ix]] + (buffer_arr[sorted_ix[ix+1]] - buffer_arr[sorted_ix[ix]]) / 2.0;
|
|
638
|
+
else
|
|
639
|
+
fill_val = buffer_arr[sorted_ix[ix]];
|
|
640
|
+
break;
|
|
641
|
+
}
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
fill_val = (fill_val - x_mean) * coef;
|
|
645
|
+
if (cnt_NA && fill_val)
|
|
646
|
+
{
|
|
647
|
+
for (size_t row = 0; row < cnt_NA; row++)
|
|
648
|
+
res_write[buffer_NAs[row]] += fill_val;
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
}
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
/* for sparse numerical */
|
|
655
|
+
template <class real_t_, class sparse_ix>
|
|
656
|
+
void add_linear_comb(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num, double *restrict res,
|
|
657
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
658
|
+
double &restrict coef, double x_sd, double x_mean, double &restrict fill_val, MissingAction missing_action,
|
|
659
|
+
double *restrict buffer_arr, size_t *restrict buffer_NAs, bool first_run)
|
|
660
|
+
{
|
|
661
|
+
/* ix_arr must be already sorted beforehand */
|
|
662
|
+
|
|
663
|
+
/* if it's all zeros, no need to do anything, but this is not supposed
|
|
664
|
+
to happen while fitting because the range is determined before calling this */
|
|
665
|
+
if (
|
|
666
|
+
Xc_indptr[col_num] == Xc_indptr[col_num + 1] ||
|
|
667
|
+
Xc_ind[Xc_indptr[col_num]] > (sparse_ix)ix_arr[end] ||
|
|
668
|
+
Xc_ind[Xc_indptr[col_num + 1] - 1] < (sparse_ix)ix_arr[st]
|
|
669
|
+
)
|
|
670
|
+
{
|
|
671
|
+
if (first_run)
|
|
672
|
+
{
|
|
673
|
+
coef /= x_sd;
|
|
674
|
+
if (missing_action != Fail)
|
|
675
|
+
fill_val = 0;
|
|
676
|
+
}
|
|
677
|
+
|
|
678
|
+
double *restrict res_write = res - st;
|
|
679
|
+
double offset = x_mean * coef;
|
|
680
|
+
if (offset)
|
|
681
|
+
{
|
|
682
|
+
for (size_t row = st; row <= end; row++)
|
|
683
|
+
res_write[row] -= offset;
|
|
684
|
+
}
|
|
685
|
+
|
|
686
|
+
return;
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
size_t st_col = Xc_indptr[col_num];
|
|
690
|
+
size_t end_col = Xc_indptr[col_num + 1] - 1;
|
|
691
|
+
size_t curr_pos = st_col;
|
|
692
|
+
size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, (size_t)Xc_ind[st_col]);
|
|
693
|
+
|
|
694
|
+
size_t cnt_non_NA = 0; /* when NAs need to be imputed */
|
|
695
|
+
size_t cnt_NA = 0; /* when NAs need to be imputed */
|
|
696
|
+
size_t n_sample = end - st + 1;
|
|
697
|
+
size_t *ix_arr_plus_st = ix_arr + st;
|
|
698
|
+
|
|
699
|
+
if (first_run)
|
|
700
|
+
coef /= x_sd;
|
|
701
|
+
|
|
702
|
+
double *restrict res_write = res - st;
|
|
703
|
+
double offset = x_mean * coef;
|
|
704
|
+
if (offset)
|
|
705
|
+
{
|
|
706
|
+
for (size_t row = st; row <= end; row++)
|
|
707
|
+
res_write[row] -= offset;
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
size_t ind_end_col = Xc_ind[end_col];
|
|
711
|
+
size_t nmatches = 0;
|
|
712
|
+
|
|
713
|
+
if (missing_action != Fail)
|
|
714
|
+
{
|
|
715
|
+
if (first_run)
|
|
716
|
+
{
|
|
717
|
+
for (size_t *row = ptr_st;
|
|
718
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
719
|
+
)
|
|
720
|
+
{
|
|
721
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
722
|
+
{
|
|
723
|
+
if (unlikely(is_na_or_inf(Xc[curr_pos])))
|
|
724
|
+
{
|
|
725
|
+
buffer_NAs[cnt_NA++] = row - ix_arr_plus_st;
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
else
|
|
729
|
+
{
|
|
730
|
+
buffer_arr[cnt_non_NA++] = Xc[curr_pos];
|
|
731
|
+
res[row - ix_arr_plus_st] = std::fma(Xc[curr_pos], coef, res[row - ix_arr_plus_st]);
|
|
732
|
+
}
|
|
733
|
+
|
|
734
|
+
nmatches++;
|
|
735
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
736
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
else
|
|
740
|
+
{
|
|
741
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
742
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
743
|
+
else
|
|
744
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
745
|
+
}
|
|
746
|
+
}
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
else
|
|
750
|
+
{
|
|
751
|
+
/* when impute value for missing has already been determined */
|
|
752
|
+
for (size_t *row = ptr_st;
|
|
753
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
754
|
+
)
|
|
755
|
+
{
|
|
756
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
757
|
+
{
|
|
758
|
+
res[row - ix_arr_plus_st] += is_na_or_inf(Xc[curr_pos])?
|
|
759
|
+
(fill_val + offset) : (Xc[curr_pos] * coef);
|
|
760
|
+
if (row == ix_arr + end) break;
|
|
761
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
else
|
|
765
|
+
{
|
|
766
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
767
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
768
|
+
else
|
|
769
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
770
|
+
}
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
return;
|
|
774
|
+
}
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
/* Determine imputation value */
|
|
778
|
+
std::sort(buffer_arr, buffer_arr + cnt_non_NA);
|
|
779
|
+
size_t mid_ceil = (n_sample - cnt_NA) / 2;
|
|
780
|
+
size_t nzeros = (end - st + 1) - nmatches;
|
|
781
|
+
if (nzeros > mid_ceil && buffer_arr[0] > 0)
|
|
782
|
+
{
|
|
783
|
+
fill_val = 0;
|
|
784
|
+
return;
|
|
785
|
+
}
|
|
786
|
+
|
|
787
|
+
else
|
|
788
|
+
{
|
|
789
|
+
size_t n_neg = (buffer_arr[0] > 0)?
|
|
790
|
+
0 : ((buffer_arr[cnt_non_NA - 1] < 0)?
|
|
791
|
+
cnt_non_NA : std::lower_bound(buffer_arr, buffer_arr + cnt_non_NA, (double)0) - buffer_arr);
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
if (n_neg < (mid_ceil-1) && n_neg + nzeros > mid_ceil)
|
|
795
|
+
{
|
|
796
|
+
fill_val = 0;
|
|
797
|
+
return;
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
else
|
|
801
|
+
{
|
|
802
|
+
/* if the sample size is odd, take the middle, otherwise take a simple average */
|
|
803
|
+
if (((n_sample - cnt_NA) % 2) != 0)
|
|
804
|
+
{
|
|
805
|
+
if (mid_ceil < n_neg)
|
|
806
|
+
fill_val = buffer_arr[mid_ceil];
|
|
807
|
+
else if (mid_ceil < n_neg + nzeros)
|
|
808
|
+
fill_val = 0;
|
|
809
|
+
else
|
|
810
|
+
fill_val = buffer_arr[mid_ceil - nzeros];
|
|
811
|
+
}
|
|
812
|
+
|
|
813
|
+
else
|
|
814
|
+
{
|
|
815
|
+
if (mid_ceil < n_neg)
|
|
816
|
+
{
|
|
817
|
+
fill_val = (buffer_arr[mid_ceil - 1] + buffer_arr[mid_ceil]) / 2;
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
else if (mid_ceil < n_neg + nzeros)
|
|
821
|
+
{
|
|
822
|
+
if (mid_ceil == n_neg)
|
|
823
|
+
fill_val = buffer_arr[mid_ceil - 1] / 2;
|
|
824
|
+
else
|
|
825
|
+
fill_val = 0;
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
else
|
|
829
|
+
{
|
|
830
|
+
if (mid_ceil == n_neg + nzeros && nzeros > 0)
|
|
831
|
+
fill_val = buffer_arr[n_neg] / 2;
|
|
832
|
+
else
|
|
833
|
+
fill_val = (buffer_arr[mid_ceil - nzeros - 1] + buffer_arr[mid_ceil - nzeros]) / 2; /* WRONG!!!! */
|
|
834
|
+
}
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
/* fill missing if any */
|
|
838
|
+
fill_val *= coef;
|
|
839
|
+
if (cnt_NA && fill_val)
|
|
840
|
+
for (size_t ix = 0; ix < cnt_NA; ix++)
|
|
841
|
+
res[buffer_NAs[ix]] += fill_val;
|
|
842
|
+
|
|
843
|
+
/* next time, it will need to have the offset added */
|
|
844
|
+
fill_val -= offset;
|
|
845
|
+
}
|
|
846
|
+
}
|
|
847
|
+
}
|
|
848
|
+
|
|
849
|
+
else /* no NAs */
|
|
850
|
+
{
|
|
851
|
+
for (size_t *row = ptr_st;
|
|
852
|
+
row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
|
|
853
|
+
)
|
|
854
|
+
{
|
|
855
|
+
if (Xc_ind[curr_pos] == (sparse_ix)(*row))
|
|
856
|
+
{
|
|
857
|
+
res[row - ix_arr_plus_st] += Xc[curr_pos] * coef;
|
|
858
|
+
if (row == ix_arr + end || curr_pos == end_col) break;
|
|
859
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
|
|
860
|
+
}
|
|
861
|
+
|
|
862
|
+
else
|
|
863
|
+
{
|
|
864
|
+
if (Xc_ind[curr_pos] > (sparse_ix)(*row))
|
|
865
|
+
row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
|
|
866
|
+
else
|
|
867
|
+
curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
}
|
|
871
|
+
}
|
|
872
|
+
|
|
873
|
+
template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
|
|
874
|
+
void add_linear_comb_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num, double *restrict res,
|
|
875
|
+
real_t_ *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
|
|
876
|
+
double &restrict coef, double x_sd, double x_mean, double &restrict fill_val, MissingAction missing_action,
|
|
877
|
+
double *restrict buffer_arr, size_t *restrict buffer_NAs, bool first_run, mapping &restrict w)
|
|
878
|
+
{
|
|
879
|
+
/* TODO: there's likely a better way of doing this directly with sparse inputs.
|
|
880
|
+
Think about some way of doing it efficiently. */
|
|
881
|
+
if (first_run && missing_action != Fail)
|
|
882
|
+
{
|
|
883
|
+
std::vector<double> denseX(end-st+1, 0.);
|
|
884
|
+
todense(ix_arr, st, end,
|
|
885
|
+
col_num, Xc, Xc_ind, Xc_indptr,
|
|
886
|
+
denseX.data());
|
|
887
|
+
std::vector<double> obs_weight(end-st+1);
|
|
888
|
+
for (size_t row = st; row <= end; row++)
|
|
889
|
+
obs_weight[row - st] = w[ix_arr[row]];
|
|
890
|
+
|
|
891
|
+
size_t end_new = end - st + 1;
|
|
892
|
+
for (size_t ix = 0; ix < end-st+1; ix++)
|
|
893
|
+
{
|
|
894
|
+
if (unlikely(is_na_or_inf(denseX[ix])))
|
|
895
|
+
{
|
|
896
|
+
std::swap(denseX[ix], denseX[--end_new]);
|
|
897
|
+
std::swap(obs_weight[ix], obs_weight[end_new]);
|
|
898
|
+
}
|
|
899
|
+
}
|
|
900
|
+
|
|
901
|
+
ldouble_safe cumw = std::accumulate(obs_weight.begin(), obs_weight.begin() + end_new, (ldouble_safe)0);
|
|
902
|
+
ldouble_safe mid_point = cumw / (ldouble_safe)2;
|
|
903
|
+
std::vector<size_t> sorted_ix(end_new);
|
|
904
|
+
std::iota(sorted_ix.begin(), sorted_ix.end(), (size_t)0);
|
|
905
|
+
std::sort(sorted_ix.begin(), sorted_ix.end(),
|
|
906
|
+
[&denseX](const size_t a, const size_t b){return denseX[a] < denseX[b];});
|
|
907
|
+
ldouble_safe currw = 0;
|
|
908
|
+
fill_val = denseX[sorted_ix.back()]; /* <- will overwrite later */
|
|
909
|
+
/* TODO: is this median calculation correct? should it do a weighted interpolation? */
|
|
910
|
+
for (size_t ix = 0; ix < end_new; ix++)
|
|
911
|
+
{
|
|
912
|
+
currw += obs_weight[sorted_ix[ix]];
|
|
913
|
+
if (currw >= mid_point)
|
|
914
|
+
{
|
|
915
|
+
if (currw == mid_point && ix < end_new-1)
|
|
916
|
+
fill_val = denseX[sorted_ix[ix]] + (denseX[sorted_ix[ix+1]] - denseX[sorted_ix[ix]]) / 2.0;
|
|
917
|
+
else
|
|
918
|
+
fill_val = denseX[sorted_ix[ix]];
|
|
919
|
+
break;
|
|
920
|
+
}
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
fill_val = (fill_val - x_mean) * (coef / x_sd);
|
|
924
|
+
denseX.clear();
|
|
925
|
+
obs_weight.clear();
|
|
926
|
+
sorted_ix.clear();
|
|
927
|
+
|
|
928
|
+
add_linear_comb(ix_arr, st, end, col_num, res,
|
|
929
|
+
Xc, Xc_ind, Xc_indptr,
|
|
930
|
+
coef, x_sd, x_mean, fill_val, missing_action,
|
|
931
|
+
buffer_arr, buffer_NAs, false);
|
|
932
|
+
}
|
|
933
|
+
|
|
934
|
+
else
|
|
935
|
+
{
|
|
936
|
+
add_linear_comb(ix_arr, st, end, col_num, res,
|
|
937
|
+
Xc, Xc_ind, Xc_indptr,
|
|
938
|
+
coef, x_sd, x_mean, fill_val, missing_action,
|
|
939
|
+
buffer_arr, buffer_NAs, first_run);
|
|
940
|
+
}
|
|
941
|
+
}
|
|
942
|
+
|
|
943
|
+
/* for categoricals */
|
|
944
|
+
template <class ldouble_safe>
|
|
945
|
+
void add_linear_comb(size_t *restrict ix_arr, size_t st, size_t end, double *restrict res,
|
|
946
|
+
int x[], int ncat, double *restrict cat_coef, double single_cat_coef, int chosen_cat,
|
|
947
|
+
double &restrict fill_val, double &restrict fill_new, size_t *restrict buffer_cnt, size_t *restrict buffer_pos,
|
|
948
|
+
NewCategAction new_cat_action, MissingAction missing_action, CategSplit cat_split_type, bool first_run)
|
|
949
|
+
{
|
|
950
|
+
double *restrict res_write = res - st;
|
|
951
|
+
switch(cat_split_type)
|
|
952
|
+
{
|
|
953
|
+
case SingleCateg:
|
|
954
|
+
{
|
|
955
|
+
/* in this case there's no need to make-up an impute value for new categories, only for NAs */
|
|
956
|
+
switch(missing_action)
|
|
957
|
+
{
|
|
958
|
+
case Fail:
|
|
959
|
+
{
|
|
960
|
+
for (size_t row = st; row <= end; row++)
|
|
961
|
+
res_write[row] += (x[ix_arr[row]] == chosen_cat)? single_cat_coef : 0;
|
|
962
|
+
return;
|
|
963
|
+
}
|
|
964
|
+
|
|
965
|
+
case Impute:
|
|
966
|
+
{
|
|
967
|
+
size_t cnt_NA = 0;
|
|
968
|
+
size_t cnt_this = 0;
|
|
969
|
+
size_t cnt = end - st + 1;
|
|
970
|
+
if (first_run)
|
|
971
|
+
{
|
|
972
|
+
for (size_t row = st; row <= end; row++)
|
|
973
|
+
{
|
|
974
|
+
if (unlikely(x[ix_arr[row]] < 0))
|
|
975
|
+
{
|
|
976
|
+
cnt_NA++;
|
|
977
|
+
}
|
|
978
|
+
|
|
979
|
+
else if (x[ix_arr[row]] == chosen_cat)
|
|
980
|
+
{
|
|
981
|
+
cnt_this++;
|
|
982
|
+
res_write[row] += single_cat_coef;
|
|
983
|
+
}
|
|
984
|
+
}
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
else
|
|
988
|
+
{
|
|
989
|
+
for (size_t row = st; row <= end; row++)
|
|
990
|
+
res_write[row] += (x[ix_arr[row]] < 0)? fill_val : ((x[ix_arr[row]] == chosen_cat)? single_cat_coef : 0);
|
|
991
|
+
return;
|
|
992
|
+
}
|
|
993
|
+
|
|
994
|
+
fill_val = (cnt_this > (cnt - cnt_NA - cnt_this))? single_cat_coef : 0;
|
|
995
|
+
if (cnt_NA && fill_val)
|
|
996
|
+
{
|
|
997
|
+
for (size_t row = st; row <= end; row++)
|
|
998
|
+
if (x[ix_arr[row]] < 0)
|
|
999
|
+
res_write[row] += fill_val;
|
|
1000
|
+
}
|
|
1001
|
+
return;
|
|
1002
|
+
}
|
|
1003
|
+
|
|
1004
|
+
default:
|
|
1005
|
+
{
|
|
1006
|
+
unexpected_error();
|
|
1007
|
+
break;
|
|
1008
|
+
}
|
|
1009
|
+
}
|
|
1010
|
+
}
|
|
1011
|
+
|
|
1012
|
+
case SubSet:
|
|
1013
|
+
{
|
|
1014
|
+
/* in this case, since the splits are by more than 1 variable, it's not possible to
|
|
1015
|
+
divide missing/new categoricals by assigning weights, so they have to be imputed
|
|
1016
|
+
in both cases, unless using random weights for the new ones, in which case they won't
|
|
1017
|
+
need to be imputed for new, but sill need it for NA */
|
|
1018
|
+
|
|
1019
|
+
if (new_cat_action == Random && missing_action == Fail)
|
|
1020
|
+
{
|
|
1021
|
+
for (size_t row = st; row <= end; row++)
|
|
1022
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1023
|
+
return;
|
|
1024
|
+
}
|
|
1025
|
+
|
|
1026
|
+
if (!first_run)
|
|
1027
|
+
{
|
|
1028
|
+
if (missing_action == Fail)
|
|
1029
|
+
{
|
|
1030
|
+
for (size_t row = st; row <= end; row++)
|
|
1031
|
+
res_write[row] += (x[ix_arr[row]] >= ncat)? fill_new : cat_coef[x[ix_arr[row]]];
|
|
1032
|
+
}
|
|
1033
|
+
|
|
1034
|
+
else
|
|
1035
|
+
{
|
|
1036
|
+
for (size_t row = st; row <= end; row++)
|
|
1037
|
+
res_write[row] += (x[ix_arr[row]] < 0)? fill_val : ((x[ix_arr[row]] >= ncat)? fill_new : cat_coef[x[ix_arr[row]]]);
|
|
1038
|
+
}
|
|
1039
|
+
return;
|
|
1040
|
+
}
|
|
1041
|
+
|
|
1042
|
+
std::fill(buffer_cnt, buffer_cnt + ncat + 1, 0);
|
|
1043
|
+
switch(missing_action)
|
|
1044
|
+
{
|
|
1045
|
+
case Fail:
|
|
1046
|
+
{
|
|
1047
|
+
for (size_t row = st; row <= end; row++)
|
|
1048
|
+
{
|
|
1049
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
1050
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1051
|
+
}
|
|
1052
|
+
break;
|
|
1053
|
+
}
|
|
1054
|
+
|
|
1055
|
+
default:
|
|
1056
|
+
{
|
|
1057
|
+
for (size_t row = st; row <= end; row++)
|
|
1058
|
+
{
|
|
1059
|
+
if (x[ix_arr[row]] >= 0)
|
|
1060
|
+
{
|
|
1061
|
+
buffer_cnt[x[ix_arr[row]]]++;
|
|
1062
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
else
|
|
1066
|
+
{
|
|
1067
|
+
buffer_cnt[ncat]++;
|
|
1068
|
+
}
|
|
1069
|
+
|
|
1070
|
+
}
|
|
1071
|
+
break;
|
|
1072
|
+
}
|
|
1073
|
+
}
|
|
1074
|
+
|
|
1075
|
+
switch(new_cat_action)
|
|
1076
|
+
{
|
|
1077
|
+
case Smallest:
|
|
1078
|
+
{
|
|
1079
|
+
size_t smallest = SIZE_MAX;
|
|
1080
|
+
int cat_smallest = 0;
|
|
1081
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1082
|
+
{
|
|
1083
|
+
if (buffer_cnt[cat] > 0 && buffer_cnt[cat] < smallest)
|
|
1084
|
+
{
|
|
1085
|
+
smallest = buffer_cnt[cat];
|
|
1086
|
+
cat_smallest = cat;
|
|
1087
|
+
}
|
|
1088
|
+
}
|
|
1089
|
+
fill_new = cat_coef[cat_smallest];
|
|
1090
|
+
if (missing_action == Fail) break;
|
|
1091
|
+
}
|
|
1092
|
+
|
|
1093
|
+
default:
|
|
1094
|
+
{
|
|
1095
|
+
/* Determine imputation value as the category in sorted order that gives 50% + 1 */
|
|
1096
|
+
ldouble_safe cnt_l = (ldouble_safe)((end - st + 1) - buffer_cnt[ncat]);
|
|
1097
|
+
std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
|
|
1098
|
+
std::sort(buffer_pos, buffer_pos + ncat, [&cat_coef](const size_t a, const size_t b){return cat_coef[a] < cat_coef[b];});
|
|
1099
|
+
|
|
1100
|
+
double cumprob = 0;
|
|
1101
|
+
int cat;
|
|
1102
|
+
for (cat = 0; cat < ncat; cat++)
|
|
1103
|
+
{
|
|
1104
|
+
cumprob += (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt_l;
|
|
1105
|
+
if (cumprob >= .5) break;
|
|
1106
|
+
}
|
|
1107
|
+
// cat = std::min(cat, ncat); /* in case it picks the last one */
|
|
1108
|
+
fill_val = cat_coef[buffer_pos[cat]];
|
|
1109
|
+
if (new_cat_action != Smallest)
|
|
1110
|
+
fill_new = fill_val;
|
|
1111
|
+
|
|
1112
|
+
if (buffer_cnt[ncat] > 0 && fill_val) /* NAs */
|
|
1113
|
+
for (size_t row = st; row <= end; row++)
|
|
1114
|
+
if (unlikely(x[ix_arr[row]] < 0))
|
|
1115
|
+
res_write[row] += fill_val;
|
|
1116
|
+
}
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
/* now fill unseen categories */
|
|
1120
|
+
if (new_cat_action != Random)
|
|
1121
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1122
|
+
if (!buffer_cnt[cat])
|
|
1123
|
+
cat_coef[cat] = fill_new;
|
|
1124
|
+
|
|
1125
|
+
}
|
|
1126
|
+
}
|
|
1127
|
+
}
|
|
1128
|
+
|
|
1129
|
+
template <class mapping, class ldouble_safe>
|
|
1130
|
+
void add_linear_comb_weighted(size_t *restrict ix_arr, size_t st, size_t end, double *restrict res,
|
|
1131
|
+
int x[], int ncat, double *restrict cat_coef, double single_cat_coef, int chosen_cat,
|
|
1132
|
+
double &restrict fill_val, double &restrict fill_new, size_t *restrict buffer_pos,
|
|
1133
|
+
NewCategAction new_cat_action, MissingAction missing_action, CategSplit cat_split_type,
|
|
1134
|
+
bool first_run, mapping &restrict w)
|
|
1135
|
+
{
|
|
1136
|
+
double *restrict res_write = res - st;
|
|
1137
|
+
/* TODO: this buffer should be allocated externally */
|
|
1138
|
+
|
|
1139
|
+
switch(cat_split_type)
|
|
1140
|
+
{
|
|
1141
|
+
case SingleCateg:
|
|
1142
|
+
{
|
|
1143
|
+
/* in this case there's no need to make-up an impute value for new categories, only for NAs */
|
|
1144
|
+
switch(missing_action)
|
|
1145
|
+
{
|
|
1146
|
+
case Fail:
|
|
1147
|
+
{
|
|
1148
|
+
for (size_t row = st; row <= end; row++)
|
|
1149
|
+
res_write[row] += (x[ix_arr[row]] == chosen_cat)? single_cat_coef : 0;
|
|
1150
|
+
return;
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
case Impute:
|
|
1154
|
+
{
|
|
1155
|
+
bool has_NA = false;
|
|
1156
|
+
ldouble_safe cnt_this = 0;
|
|
1157
|
+
ldouble_safe cnt_other = 0;
|
|
1158
|
+
if (first_run)
|
|
1159
|
+
{
|
|
1160
|
+
for (size_t row = st; row <= end; row++)
|
|
1161
|
+
{
|
|
1162
|
+
if (unlikely(x[ix_arr[row]] < 0))
|
|
1163
|
+
{
|
|
1164
|
+
has_NA = true;
|
|
1165
|
+
}
|
|
1166
|
+
|
|
1167
|
+
else if (x[ix_arr[row]] == chosen_cat)
|
|
1168
|
+
{
|
|
1169
|
+
cnt_this += w[ix_arr[row]];
|
|
1170
|
+
res_write[row] += single_cat_coef;
|
|
1171
|
+
}
|
|
1172
|
+
|
|
1173
|
+
else
|
|
1174
|
+
{
|
|
1175
|
+
cnt_other += w[ix_arr[row]];
|
|
1176
|
+
}
|
|
1177
|
+
}
|
|
1178
|
+
}
|
|
1179
|
+
|
|
1180
|
+
else
|
|
1181
|
+
{
|
|
1182
|
+
for (size_t row = st; row <= end; row++)
|
|
1183
|
+
res_write[row] += (x[ix_arr[row]] < 0)? fill_val : ((x[ix_arr[row]] == chosen_cat)? single_cat_coef : 0);
|
|
1184
|
+
return;
|
|
1185
|
+
}
|
|
1186
|
+
|
|
1187
|
+
fill_val = (cnt_this > cnt_other)? single_cat_coef : 0;
|
|
1188
|
+
if (has_NA && fill_val)
|
|
1189
|
+
{
|
|
1190
|
+
for (size_t row = st; row <= end; row++)
|
|
1191
|
+
if (unlikely(x[ix_arr[row]] < 0))
|
|
1192
|
+
res_write[row] += fill_val;
|
|
1193
|
+
}
|
|
1194
|
+
return;
|
|
1195
|
+
}
|
|
1196
|
+
|
|
1197
|
+
default:
|
|
1198
|
+
{
|
|
1199
|
+
unexpected_error();
|
|
1200
|
+
break;
|
|
1201
|
+
}
|
|
1202
|
+
}
|
|
1203
|
+
}
|
|
1204
|
+
|
|
1205
|
+
case SubSet:
|
|
1206
|
+
{
|
|
1207
|
+
/* in this case, since the splits are by more than 1 variable, it's not possible to
|
|
1208
|
+
divide missing/new categoricals by assigning weights, so they have to be imputed
|
|
1209
|
+
in both cases, unless using random weights for the new ones, in which case they won't
|
|
1210
|
+
need to be imputed for new, but sill need it for NA */
|
|
1211
|
+
|
|
1212
|
+
if (new_cat_action == Random && missing_action == Fail)
|
|
1213
|
+
{
|
|
1214
|
+
for (size_t row = st; row <= end; row++)
|
|
1215
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1216
|
+
return;
|
|
1217
|
+
}
|
|
1218
|
+
|
|
1219
|
+
if (!first_run)
|
|
1220
|
+
{
|
|
1221
|
+
if (missing_action == Fail)
|
|
1222
|
+
{
|
|
1223
|
+
for (size_t row = st; row <= end; row++)
|
|
1224
|
+
res_write[row] += (x[ix_arr[row]] >= ncat)? fill_new : cat_coef[x[ix_arr[row]]];
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
else
|
|
1228
|
+
{
|
|
1229
|
+
for (size_t row = st; row <= end; row++)
|
|
1230
|
+
res_write[row] += (x[ix_arr[row]] < 0)? fill_val : ((x[ix_arr[row]] >= ncat)? fill_new : cat_coef[x[ix_arr[row]]]);
|
|
1231
|
+
}
|
|
1232
|
+
return;
|
|
1233
|
+
}
|
|
1234
|
+
|
|
1235
|
+
/* TODO: this buffer should be allocated externally */
|
|
1236
|
+
std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
|
|
1237
|
+
switch(missing_action)
|
|
1238
|
+
{
|
|
1239
|
+
case Fail:
|
|
1240
|
+
{
|
|
1241
|
+
for (size_t row = st; row <= end; row++)
|
|
1242
|
+
{
|
|
1243
|
+
buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
|
|
1244
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1245
|
+
}
|
|
1246
|
+
break;
|
|
1247
|
+
}
|
|
1248
|
+
|
|
1249
|
+
default:
|
|
1250
|
+
{
|
|
1251
|
+
for (size_t row = st; row <= end; row++)
|
|
1252
|
+
{
|
|
1253
|
+
if (likely(x[ix_arr[row]] >= 0))
|
|
1254
|
+
{
|
|
1255
|
+
buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
|
|
1256
|
+
res_write[row] += cat_coef[x[ix_arr[row]]];
|
|
1257
|
+
}
|
|
1258
|
+
|
|
1259
|
+
else
|
|
1260
|
+
{
|
|
1261
|
+
buffer_cnt[ncat] += w[ix_arr[row]];
|
|
1262
|
+
}
|
|
1263
|
+
|
|
1264
|
+
}
|
|
1265
|
+
break;
|
|
1266
|
+
}
|
|
1267
|
+
}
|
|
1268
|
+
|
|
1269
|
+
switch(new_cat_action)
|
|
1270
|
+
{
|
|
1271
|
+
case Smallest:
|
|
1272
|
+
{
|
|
1273
|
+
ldouble_safe smallest = std::numeric_limits<ldouble_safe>::infinity();
|
|
1274
|
+
int cat_smallest = 0;
|
|
1275
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1276
|
+
{
|
|
1277
|
+
if (buffer_cnt[cat] > 0 && buffer_cnt[cat] < smallest)
|
|
1278
|
+
{
|
|
1279
|
+
smallest = buffer_cnt[cat];
|
|
1280
|
+
cat_smallest = cat;
|
|
1281
|
+
}
|
|
1282
|
+
}
|
|
1283
|
+
fill_new = cat_coef[cat_smallest];
|
|
1284
|
+
if (missing_action == Fail) break;
|
|
1285
|
+
}
|
|
1286
|
+
|
|
1287
|
+
default:
|
|
1288
|
+
{
|
|
1289
|
+
/* Determine imputation value as the category in sorted order that gives 50% + 1 */
|
|
1290
|
+
ldouble_safe cnt_l = std::accumulate(buffer_cnt.begin(), buffer_cnt.begin() + ncat, (ldouble_safe)0);
|
|
1291
|
+
std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
|
|
1292
|
+
std::sort(buffer_pos, buffer_pos + ncat, [&cat_coef](const size_t a, const size_t b){return cat_coef[a] < cat_coef[b];});
|
|
1293
|
+
|
|
1294
|
+
double cumprob = 0;
|
|
1295
|
+
int cat;
|
|
1296
|
+
for (cat = 0; cat < ncat; cat++)
|
|
1297
|
+
{
|
|
1298
|
+
cumprob += buffer_cnt[buffer_pos[cat]] / cnt_l;
|
|
1299
|
+
if (cumprob >= .5) break;
|
|
1300
|
+
}
|
|
1301
|
+
// cat = std::min(cat, ncat); /* in case it picks the last one */
|
|
1302
|
+
fill_val = cat_coef[buffer_pos[cat]];
|
|
1303
|
+
if (new_cat_action != Smallest)
|
|
1304
|
+
fill_new = fill_val;
|
|
1305
|
+
|
|
1306
|
+
if (buffer_cnt[ncat] > 0 && fill_val) /* NAs */
|
|
1307
|
+
for (size_t row = st; row <= end; row++)
|
|
1308
|
+
if (unlikely(x[ix_arr[row]] < 0))
|
|
1309
|
+
res_write[row] += fill_val;
|
|
1310
|
+
}
|
|
1311
|
+
}
|
|
1312
|
+
|
|
1313
|
+
/* now fill unseen categories */
|
|
1314
|
+
if (new_cat_action != Random)
|
|
1315
|
+
for (int cat = 0; cat < ncat; cat++)
|
|
1316
|
+
if (!buffer_cnt[cat])
|
|
1317
|
+
cat_coef[cat] = fill_new;
|
|
1318
|
+
|
|
1319
|
+
}
|
|
1320
|
+
}
|
|
1321
|
+
}
|