isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -1,912 +0,0 @@
1
- /* Isolation forests and variations thereof, with adjustments for incorporation
2
- * of categorical variables and missing values.
3
- * Written for the C++11 standard and aimed at being used in R and Python.
4
- *
5
- * This library is based on the following works:
6
- * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
- * "Isolation forest."
8
- * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
- * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
- * "Isolation-based anomaly detection."
11
- * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
- * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
- * "Extended Isolation Forest."
14
- * arXiv preprint arXiv:1811.02141 (2018).
15
- * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
- * "On detecting clustered anomalies using SCiForest."
17
- * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
- * [5] https://sourceforge.net/projects/iforest/
19
- * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
- * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
- * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
- * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
- *
24
- * BSD 2-Clause License
25
- * Copyright (c) 2020, David Cortes
26
- * All rights reserved.
27
- * Redistribution and use in source and binary forms, with or without
28
- * modification, are permitted provided that the following conditions are met:
29
- * * Redistributions of source code must retain the above copyright notice, this
30
- * list of conditions and the following disclaimer.
31
- * * Redistributions in binary form must reproduce the above copyright notice,
32
- * this list of conditions and the following disclaimer in the documentation
33
- * and/or other materials provided with the distribution.
34
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
- */
45
- #include "isotree.hpp"
46
-
47
- #define pw1(x) ((x))
48
- #define pw2(x) ((x) * (x))
49
- #define pw3(x) ((x) * (x) * (x))
50
- #define pw4(x) ((x) * (x) * (x) * (x))
51
-
52
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, double x[], MissingAction missing_action)
53
- {
54
- long double m = 0;
55
- long double M2 = 0, M3 = 0, M4 = 0;
56
- long double delta, delta_s, delta_div;
57
- long double diff, n;
58
-
59
- if (missing_action == Fail)
60
- {
61
- for (size_t row = st; row <= end; row++)
62
- {
63
- n = (long double)(row - st + 1);
64
-
65
- delta = x[ix_arr[row]] - m;
66
- delta_div = delta / n;
67
- delta_s = delta_div * delta_div;
68
- diff = delta * (delta_div * (long double)(row - st));
69
-
70
- m += delta_div;
71
- M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
72
- M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
73
- M2 += diff;
74
- }
75
-
76
- return ( M4 / M2 ) * ( (long double)(end - st + 1) / M2 );
77
- }
78
-
79
- else
80
- {
81
- size_t cnt = 0;
82
- for (size_t row = st; row <= end; row++)
83
- {
84
- if (!is_na_or_inf(x[ix_arr[row]]))
85
- {
86
- cnt++;
87
- n = (long double) cnt;
88
-
89
- delta = x[ix_arr[row]] - m;
90
- delta_div = delta / n;
91
- delta_s = delta_div * delta_div;
92
- diff = delta * (delta_div * (long double)(cnt - 1));
93
-
94
- m += delta_div;
95
- M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
96
- M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
97
- M2 += diff;
98
- }
99
- }
100
-
101
- return ( M4 / M2 ) * ( (long double)cnt / M2 );
102
- }
103
- }
104
-
105
-
106
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, size_t col_num,
107
- double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
108
- MissingAction missing_action)
109
- {
110
- /* ix_arr must be already sorted beforehand */
111
- if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
112
- return 0;
113
-
114
- long double s1 = 0;
115
- long double s2 = 0;
116
- long double s3 = 0;
117
- long double s4 = 0;
118
- size_t cnt = end - st + 1;
119
-
120
- if (cnt <= 1) return 0;
121
-
122
- size_t st_col = Xc_indptr[col_num];
123
- size_t end_col = Xc_indptr[col_num + 1] - 1;
124
- size_t curr_pos = st_col;
125
- size_t ind_end_col = Xc_ind[end_col];
126
- size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
127
-
128
- if (missing_action != Fail)
129
- {
130
- for (size_t *row = ptr_st;
131
- row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
132
- )
133
- {
134
- if (Xc_ind[curr_pos] == *row)
135
- {
136
- if (is_na_or_inf(Xc[curr_pos]))
137
- {
138
- cnt--;
139
- }
140
-
141
- else
142
- {
143
- s1 += pw1(Xc[curr_pos]);
144
- s2 += pw2(Xc[curr_pos]);
145
- s3 += pw3(Xc[curr_pos]);
146
- s4 += pw4(Xc[curr_pos]);
147
- }
148
-
149
- if (row == ix_arr + end || curr_pos == end_col) break;
150
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
151
- }
152
-
153
- else
154
- {
155
- if (Xc_ind[curr_pos] > *row)
156
- row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
157
- else
158
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
159
- }
160
- }
161
- }
162
-
163
- else
164
- {
165
- for (size_t *row = ptr_st;
166
- row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
167
- )
168
- {
169
- if (Xc_ind[curr_pos] == *row)
170
- {
171
- s1 += pw1(Xc[curr_pos]);
172
- s2 += pw2(Xc[curr_pos]);
173
- s3 += pw3(Xc[curr_pos]);
174
- s4 += pw4(Xc[curr_pos]);
175
-
176
- if (row == ix_arr + end || curr_pos == end_col) break;
177
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
178
- }
179
-
180
- else
181
- {
182
- if (Xc_ind[curr_pos] > *row)
183
- row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
184
- else
185
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
186
- }
187
- }
188
- }
189
-
190
- if (cnt <= 1 || s2 == 0 || s2 == pw2(s1)) return 0;
191
- long double cnt_l = (long double) cnt;
192
- long double sn = s1 / cnt_l;
193
- long double v = s2 / cnt_l - pw2(sn);
194
- if (v <= 0) return 0;
195
- return (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
196
- }
197
-
198
-
199
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
200
- MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
201
- {
202
- /* This calculation proceeds as follows:
203
- - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
204
- each category, and approximate kurtosis by sampling from such distribution
205
- with the same probabilities as given by the current counts.
206
- - If splitting by isolating one category, will binarize at each categorical level,
207
- assume the values are zero or one, and output the average assuming each categorical
208
- level has equal probability of being picked.
209
- (Note that both are misleading heuristics, but might be better than random)
210
- */
211
- size_t cnt = end - st + 1;
212
- std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
213
- double sum_kurt = 0;
214
-
215
- if (missing_action == Fail)
216
- {
217
- for (size_t row = st; row <= end; row++)
218
- buffer_cnt[x[ix_arr[row]]]++;
219
- }
220
-
221
- else
222
- {
223
- for (size_t row = st; row <= end; row++)
224
- {
225
- if (x[ix_arr[row]] >= 0)
226
- buffer_cnt[x[ix_arr[row]]]++;
227
- else
228
- buffer_cnt[ncat]++;
229
- }
230
- }
231
-
232
- cnt -= buffer_cnt[ncat];
233
- if (cnt <= 1) return 0;
234
- long double cnt_l = (long double) cnt;
235
- for (int cat = 0; cat < ncat; cat++)
236
- buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
237
-
238
- switch(cat_split_type)
239
- {
240
- case SubSet:
241
- {
242
- long double temp_v;
243
- long double s1, s2, s3, s4;
244
- long double coef;
245
- std::uniform_real_distribution<double> runif(0, 1);
246
- size_t ntry = 50;
247
- for (size_t iternum = 0; iternum < 50; iternum++)
248
- {
249
- s1 = 0; s2 = 0; s3 = 0; s4 = 0;
250
- for (int cat = 0; cat < ncat; cat++)
251
- {
252
- coef = runif(rnd_generator);
253
- s1 += buffer_prob[cat] * pw1(coef);
254
- s2 += buffer_prob[cat] * pw2(coef);
255
- s3 += buffer_prob[cat] * pw3(coef);
256
- s4 += buffer_prob[cat] * pw4(coef);
257
- }
258
- temp_v = s2 - pw2(s1);
259
- if (temp_v <= 0)
260
- ntry--;
261
- else
262
- sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
263
- }
264
- if (!ntry)
265
- return 0;
266
- else
267
- return sum_kurt / (long double)ntry;
268
- }
269
-
270
- case SingleCateg:
271
- {
272
- double p;
273
- int ncat_present = ncat;
274
- for (int cat = 0; cat < ncat; cat++)
275
- {
276
- p = buffer_prob[cat];
277
- if (p == 0)
278
- ncat_present--;
279
- else
280
- sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
281
- }
282
- if (ncat_present <= 1)
283
- return 0;
284
- else
285
- return sum_kurt / (double) ncat_present;
286
- }
287
- }
288
-
289
- return -1; /* this will never be reached, but CRAN complains otherwise */
290
- }
291
-
292
-
293
- double expected_sd_cat(double p[], size_t n, size_t pos[])
294
- {
295
- if (n <= 1) return 0;
296
-
297
- long double cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
298
- for (size_t cat1 = 2; cat1 < n; cat1++)
299
- {
300
- cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
301
- for (size_t cat2 = 0; cat2 < cat1; cat2++)
302
- cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
303
- }
304
- return sqrt(fmax(cum_var, 1e-8));
305
- }
306
-
307
- double expected_sd_cat(size_t counts[], double p[], size_t n, size_t pos[])
308
- {
309
- if (n <= 1) return 0;
310
-
311
- size_t tot = std::accumulate(pos, pos + n, (size_t)0, [&counts](size_t tot, const size_t ix){return tot + counts[ix];});
312
- long double cnt_div = (long double) tot;
313
- for (size_t cat = 0; cat < n; cat++)
314
- p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
315
-
316
- return expected_sd_cat(p, n, pos);
317
- }
318
-
319
- double expected_sd_cat_single(size_t counts[], double p[], size_t n, size_t pos[], size_t cat_exclude, size_t cnt)
320
- {
321
- if (cat_exclude == 0)
322
- return expected_sd_cat(counts, p, n-1, pos + 1);
323
-
324
- else if (cat_exclude == (n-1))
325
- return expected_sd_cat(counts, p, n-1, pos);
326
-
327
- size_t ix_exclude = pos[cat_exclude];
328
-
329
- long double cnt_div = (long double) (cnt - counts[ix_exclude]);
330
- for (size_t cat = 0; cat < n; cat++)
331
- p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
332
-
333
- double cum_var;
334
- if (cat_exclude != 1)
335
- cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
336
- else
337
- cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
338
- for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
339
- {
340
- if (pos[cat1] == ix_exclude) continue;
341
- cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
342
- for (size_t cat2 = 0; cat2 < cat1; cat2++)
343
- {
344
- if (pos[cat2] == ix_exclude) continue;
345
- cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
346
- }
347
-
348
- }
349
- return sqrt(fmax(cum_var, 1e-8));
350
- }
351
-
352
- double numeric_gain(size_t cnt_left, size_t cnt_right,
353
- long double sum_left, long double sum_right,
354
- long double sum_sq_left, long double sum_sq_right,
355
- double sd_full, long double cnt)
356
- {
357
- long double residual =
358
- (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
359
- (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
360
- return 1 - residual / (cnt * sd_full);
361
- }
362
-
363
- double numeric_gain_no_div(size_t cnt_left, size_t cnt_right,
364
- long double sum_left, long double sum_right,
365
- long double sum_sq_left, long double sum_sq_right,
366
- double sd_full, long double cnt)
367
- {
368
- long double residual =
369
- (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
370
- (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
371
- return sd_full - residual / cnt;
372
- }
373
-
374
- double categ_gain(size_t cnt_left, size_t cnt_right,
375
- long double s_left, long double s_right,
376
- long double base_info, long double cnt)
377
- {
378
- return (
379
- base_info -
380
- (((cnt_left <= 1)? 0 : ((long double)cnt_left * logl((long double)cnt_left))) - s_left) -
381
- (((cnt_right <= 1)? 0 : ((long double)cnt_right * logl((long double)cnt_right))) - s_right)
382
- ) / cnt;
383
- }
384
-
385
-
386
- #define avg_between(a, b) (((a) + (b)) / 2)
387
- #define sd_gain(sd, sd_left, sd_right) (1.0 - ((sd_left) + (sd_right)) / (2.0 * (sd)))
388
-
389
- /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
390
- double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion, double min_gain,
391
- double &split_point, double &xmin, double &xmax)
392
- {
393
- /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
394
- all numbers are assumed to be small and in the same scale */
395
-
396
- /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
397
- if (n == 2)
398
- {
399
- split_point = avg_between(x[0], x[1]);
400
- return 0;
401
- }
402
-
403
- /* sort in ascending order */
404
- std::sort(x, x + n);
405
- if (x[0] == x[n-1]) return -HUGE_VAL;
406
- xmin = x[0]; xmax = x[n-1];
407
-
408
- /* compute sum - sum_sq - sd in one pass */
409
- long double sum = 0;
410
- long double sum_sq = 0;
411
- double sd_full;
412
- for (size_t row = 0; row < n; row++)
413
- {
414
- sum += x[row];
415
- sum_sq += square(x[row]);
416
- }
417
- sd_full = calc_sd_raw(n, sum, sum_sq);
418
-
419
- /* try splits by moving observations one at a time from right to left */
420
- long double sum_left = 0;
421
- long double sum_sq_left = 0;
422
- long double sum_right = sum;
423
- long double sum_sq_right = sum_sq;
424
- double this_gain = -HUGE_VAL;
425
- double best_gain = -HUGE_VAL;
426
-
427
- switch(criterion)
428
- {
429
- case Averaged:
430
- {
431
- for (size_t row = 0; row < n-1; row++)
432
- {
433
- sum_left += x[row];
434
- sum_sq_left += square(x[row]);
435
- sum_right -= x[row];
436
- sum_sq_right -= square(x[row]);
437
-
438
- if (x[row] == x[row + 1]) continue;
439
-
440
- this_gain = sd_gain(sd_full,
441
- calc_sd_raw(row + 1, sum_left, sum_sq_left),
442
- calc_sd_raw(n - row - 1, sum_right, sum_sq_right)
443
- );
444
- if (this_gain > min_gain && this_gain > best_gain)
445
- {
446
- best_gain = this_gain;
447
- split_point = avg_between(x[row], x[row + 1]);
448
- }
449
- }
450
- break;
451
- }
452
-
453
- case Pooled:
454
- {
455
- long double cnt = (long double) n;
456
- for (size_t row = 0; row < n-1; row++)
457
- {
458
- sum_left += x[row];
459
- sum_sq_left += square(x[row]);
460
- sum_right -= x[row];
461
- sum_sq_right -= square(x[row]);
462
-
463
- if (x[row] == x[row + 1]) continue;
464
-
465
- this_gain = numeric_gain(row + 1, n - row - 1,
466
- sum_left, sum_right,
467
- sum_sq_left, sum_sq_right,
468
- sd_full, cnt
469
- );
470
-
471
- if (this_gain > min_gain && this_gain > best_gain)
472
- {
473
- best_gain = this_gain;
474
- split_point = avg_between(x[row], x[row + 1]);
475
- }
476
- }
477
- break;
478
- }
479
- }
480
-
481
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
482
- return 0;
483
- else
484
- return best_gain;
485
- }
486
-
487
- /* for split-criterion in single-variable splits */
488
- #define std_val(x, m, sd) ( ((x) - (m)) / (sd) )
489
- double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x,
490
- size_t &split_ix, double &split_point, double &xmin, double &xmax,
491
- GainCriterion criterion, double min_gain, MissingAction missing_action)
492
- {
493
- /* move NAs to the front if there are any, and exclude them from calculations */
494
- if (missing_action != Fail)
495
- st = move_NAs_to_front(ix_arr, st, end, x);
496
-
497
- if (st >= end) return -HUGE_VAL;
498
- else if (st == (end-1))
499
- {
500
- split_point = avg_between(x[ix_arr[st]], x[ix_arr[end]]);
501
- split_ix = st;
502
- return 0;
503
- }
504
-
505
- /* sort in ascending order */
506
- std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
507
- if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
508
- xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
509
-
510
- /* Note: these variables are not standardized beforehand, so a single-pass gain
511
- calculation for both branches would suffer from numerical instability and perhaps give
512
- negative standard deviations if the sample size is large or the values have different
513
- orders of magnitude */
514
-
515
- /* first get mean and sd */
516
- double x_mean, x_sd;
517
- calc_mean_and_sd(ix_arr, st, end, x,
518
- Fail, x_sd, x_mean);
519
-
520
- /* compute sum - sum_sq - sd in one pass, on the standardized values */
521
- double zval;
522
- long double sum = 0;
523
- long double sum_sq = 0;
524
- double sd_full;
525
- for (size_t row = st; row <= end; row++)
526
- {
527
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
528
- sum += zval;
529
- sum_sq += square(zval);
530
- }
531
- sd_full = calc_sd_raw(end - st + 1, sum, sum_sq);
532
-
533
- /* try splits by moving observations one at a time from right to left */
534
- long double sum_left = 0;
535
- long double sum_sq_left = 0;
536
- long double sum_right = sum;
537
- long double sum_sq_right = sum_sq;
538
- double this_gain = -HUGE_VAL;
539
- double best_gain = -HUGE_VAL;
540
-
541
- switch(criterion)
542
- {
543
- case Averaged:
544
- {
545
- for (size_t row = st; row < end; row++)
546
- {
547
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
548
- sum_left += zval;
549
- sum_sq_left += square(zval);
550
- sum_right -= zval;
551
- sum_sq_right -= square(zval);
552
-
553
- if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
554
-
555
- this_gain = sd_gain(sd_full,
556
- calc_sd_raw(row - st + 1, sum_left, sum_sq_left),
557
- calc_sd_raw(end - row, sum_right, sum_sq_right)
558
- );
559
- if (this_gain > min_gain && this_gain > best_gain)
560
- {
561
- best_gain = this_gain;
562
- split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
563
- split_ix = row;
564
- }
565
- }
566
- break;
567
- }
568
-
569
- case Pooled:
570
- {
571
- long double cnt = (long double)(end - st + 1);
572
- for (size_t row = st; row < end; row++)
573
- {
574
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
575
- sum_left += zval;
576
- sum_sq_left += square(zval);
577
- sum_right -= zval;
578
- sum_sq_right -= square(zval);
579
-
580
- if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
581
-
582
- this_gain = numeric_gain_no_div(row - st + 1, end - row,
583
- sum_left, sum_right,
584
- sum_sq_left, sum_sq_right,
585
- sd_full, cnt
586
- );
587
-
588
- if (this_gain > min_gain && this_gain > best_gain)
589
- {
590
- best_gain = this_gain;
591
- split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
592
- split_ix = row;
593
- }
594
- }
595
- break;
596
- }
597
- }
598
-
599
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
600
- return 0;
601
- else
602
- return best_gain;
603
- }
604
-
605
- double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
606
- size_t col_num, double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
607
- double buffer_arr[], size_t buffer_pos[],
608
- double &split_point, double &xmin, double &xmax,
609
- GainCriterion criterion, double min_gain, MissingAction missing_action)
610
- {
611
- todense(ix_arr, st, end,
612
- col_num, Xc, Xc_ind, Xc_indptr,
613
- buffer_arr);
614
- std::iota(buffer_pos, buffer_pos + (end - st + 1), (size_t)0);
615
- size_t temp;
616
- return eval_guided_crit(buffer_pos, 0, end - st, buffer_arr, temp, split_point,
617
- xmin, xmax, criterion, min_gain, missing_action);
618
- }
619
-
620
- /* How this works:
621
- - For Averaged criterion, will take the expected standard deviation that would be obtained with the category counts
622
- if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
623
- numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
624
- branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
625
- splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
626
- splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
627
- smallest or the largest category to assign to one branch, but sometimes picks groups too.
628
- - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
629
- by a single category, it always puts the largest category in a separate branch. In the case of subsets,
630
- it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
631
- or look up for splits in sorted order just like for Averaged criterion.
632
- Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
633
- while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
634
- Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
635
- */
636
- /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
637
- double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
638
- size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
639
- int &chosen_cat, char *restrict split_categ, char *restrict buffer_split,
640
- GainCriterion criterion, double min_gain, bool all_perm, MissingAction missing_action, CategSplit cat_split_type)
641
- {
642
- /* move NAs to the front if there's any, exclude them from calculations */
643
- if (missing_action != Fail)
644
- st = move_NAs_to_front(ix_arr, st, end, x);
645
-
646
- if (st >= end) return -HUGE_VAL;
647
-
648
- /* count categories */
649
- memset(buffer_cnt, 0, sizeof(size_t) * ncat);
650
- for (size_t row = st; row <= end; row++)
651
- buffer_cnt[x[ix_arr[row]]]++;
652
-
653
- double this_gain = -HUGE_VAL;
654
- double best_gain = -HUGE_VAL;
655
- std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
656
- size_t st_pos = 0;
657
-
658
- switch(cat_split_type)
659
- {
660
- case SingleCateg:
661
- {
662
- size_t cnt = end - st + 1;
663
- size_t ncat_present = 0;
664
-
665
- switch(criterion)
666
- {
667
- case Averaged:
668
- {
669
- /* move zero-counts to the beginning */
670
- size_t temp;
671
- for (int cat = 0; cat < ncat; cat++)
672
- {
673
- if (buffer_cnt[cat])
674
- {
675
- ncat_present++;
676
- buffer_prob[cat] = (long double) buffer_cnt[cat] / (long double) cnt;
677
- }
678
-
679
- else
680
- {
681
- temp = buffer_pos[st_pos];
682
- buffer_pos[st_pos] = buffer_pos[cat];
683
- buffer_pos[cat] = temp;
684
- st_pos++;
685
- }
686
- }
687
-
688
- if (ncat_present <= 1) return -HUGE_VAL;
689
-
690
- double sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
691
-
692
- /* try isolating each category one at a time */
693
- for (size_t pos = st_pos; (int)pos < ncat; pos++)
694
- {
695
- this_gain = sd_gain(sd_full,
696
- 0.0,
697
- expected_sd_cat_single(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)
698
- );
699
- if (this_gain > min_gain && this_gain > best_gain)
700
- {
701
- best_gain = this_gain;
702
- chosen_cat = buffer_pos[pos];
703
- }
704
- }
705
- break;
706
- }
707
-
708
- case Pooled:
709
- {
710
- /* here it will always pick the largest one */
711
- size_t ncat_present = 0;
712
- size_t cnt_max = 0;
713
- for (int cat = 0; cat < ncat; cat++)
714
- {
715
- if (buffer_cnt[cat])
716
- {
717
- ncat_present++;
718
- if (cnt_max < buffer_cnt[cat])
719
- {
720
- cnt_max = buffer_cnt[cat];
721
- chosen_cat = cat;
722
- }
723
- }
724
- }
725
-
726
- if (ncat_present <= 1) return -HUGE_VAL;
727
-
728
- long double cnt_left = (long double)((end - st + 1) - cnt_max);
729
- this_gain = (
730
- (long double)cnt * logl((long double)cnt)
731
- - cnt_left * logl(cnt_left)
732
- - (long double)cnt_max * logl((long double)cnt_max)
733
- ) / cnt;
734
- best_gain = (this_gain > min_gain)? this_gain : best_gain;
735
- break;
736
- }
737
- }
738
- break;
739
- }
740
-
741
- case SubSet:
742
- {
743
- /* sort by counts */
744
- std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
745
-
746
- /* set split as: (1):left (0):right (-1):not_present */
747
- memset(buffer_split, 0, ncat * sizeof(char));
748
-
749
- long double cnt = (long double)(end - st + 1);
750
-
751
- switch(criterion)
752
- {
753
- case Averaged:
754
- {
755
- /* determine first non-zero and convert to probabilities */
756
- double sd_full;
757
- for (int cat = 0; cat < ncat; cat++)
758
- {
759
- if (buffer_cnt[buffer_pos[cat]])
760
- {
761
- buffer_prob[buffer_pos[cat]] = (long double)buffer_cnt[buffer_pos[cat]] / cnt;
762
- }
763
-
764
- else
765
- {
766
- buffer_split[buffer_pos[cat]] = -1;
767
- st_pos++;
768
- }
769
- }
770
-
771
- if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
772
-
773
- /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
774
- size_t ncat_present = (size_t)ncat - st_pos;
775
- sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
776
- if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
777
-
778
- /* move categories one at a time */
779
- for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
780
- {
781
- buffer_split[buffer_pos[pos]] = 1;
782
- this_gain = sd_gain(sd_full,
783
- expected_sd_cat(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos),
784
- expected_sd_cat(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1)
785
- );
786
- if (this_gain > min_gain && this_gain > best_gain)
787
- {
788
- best_gain = this_gain;
789
- memcpy(split_categ, buffer_split, ncat * sizeof(char));
790
- }
791
- }
792
-
793
- break;
794
- }
795
-
796
- case Pooled:
797
- {
798
- long double s = 0;
799
-
800
- /* determine first non-zero and get base info */
801
- for (int cat = 0; cat < ncat; cat++)
802
- {
803
- if (buffer_cnt[buffer_pos[cat]])
804
- {
805
- s += (buffer_cnt[buffer_pos[cat]] <= 1)?
806
- 0 : ((long double) buffer_cnt[buffer_pos[cat]] * logl((long double)buffer_cnt[buffer_pos[cat]]));
807
- }
808
-
809
- else
810
- {
811
- buffer_split[buffer_pos[cat]] = -1;
812
- st_pos++;
813
- }
814
- }
815
-
816
- if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
817
-
818
- /* calculate base info */
819
- long double base_info = cnt * logl(cnt) - s;
820
-
821
- if (all_perm)
822
- {
823
- size_t cnt_left, cnt_right;
824
- double s_left, s_right;
825
- size_t ncat_present = (size_t)ncat - st_pos;
826
- size_t ncomb = pow2(ncat_present) - 1;
827
- size_t best_combin;
828
-
829
- for (size_t combin = 1; combin < ncomb; combin++)
830
- {
831
- cnt_left = 0; cnt_right = 0;
832
- s_left = 0; s_right = 0;
833
- for (size_t pos = st_pos; (int)pos < ncat; pos++)
834
- {
835
- if (extract_bit(combin, pos))
836
- {
837
- cnt_left += buffer_cnt[buffer_pos[pos]];
838
- s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
839
- 0 : ((long double) buffer_cnt[buffer_pos[pos]]
840
- * logl((long double) buffer_cnt[buffer_pos[pos]]));
841
- }
842
-
843
- else
844
- {
845
- cnt_right += buffer_cnt[buffer_pos[pos]];
846
- s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
847
- 0 : ((long double) buffer_cnt[buffer_pos[pos]]
848
- * logl((long double) buffer_cnt[buffer_pos[pos]]));
849
- }
850
- }
851
-
852
- this_gain = categ_gain(cnt_left, cnt_right,
853
- s_left, s_right,
854
- base_info, cnt);
855
-
856
- if (this_gain > min_gain && this_gain > best_gain)
857
- {
858
- best_gain = this_gain;
859
- best_combin = combin;
860
- }
861
-
862
- }
863
-
864
- if (best_gain > min_gain)
865
- for (size_t pos = 0; pos < ncat_present; pos++)
866
- split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
867
-
868
- }
869
-
870
- else
871
- {
872
- /* try moving the categories one at a time */
873
- size_t cnt_left = 0;
874
- size_t cnt_right = end - st + 1;
875
- double s_left = 0;
876
- double s_right = s;
877
-
878
- for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
879
- {
880
- buffer_split[buffer_pos[pos]] = 1;
881
- s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
882
- 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
883
- s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
884
- 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
885
- cnt_left += buffer_cnt[buffer_pos[pos]];
886
- cnt_right -= buffer_cnt[buffer_pos[pos]];
887
-
888
- this_gain = categ_gain(cnt_left, cnt_right,
889
- s_left, s_right,
890
- base_info, cnt);
891
-
892
- if (this_gain > min_gain && this_gain > best_gain)
893
- {
894
- best_gain = this_gain;
895
- memcpy(split_categ, buffer_split, ncat * sizeof(char));
896
- }
897
- }
898
- }
899
-
900
- break;
901
- }
902
- }
903
- }
904
- }
905
-
906
- if (st == (end-1)) return 0;
907
-
908
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
909
- return 0;
910
- else
911
- return best_gain;
912
- }