isotree 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -1,912 +0,0 @@
1
- /* Isolation forests and variations thereof, with adjustments for incorporation
2
- * of categorical variables and missing values.
3
- * Writen for C++11 standard and aimed at being used in R and Python.
4
- *
5
- * This library is based on the following works:
6
- * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
- * "Isolation forest."
8
- * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
- * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
- * "Isolation-based anomaly detection."
11
- * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
- * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
- * "Extended Isolation Forest."
14
- * arXiv preprint arXiv:1811.02141 (2018).
15
- * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
- * "On detecting clustered anomalies using SCiForest."
17
- * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
- * [5] https://sourceforge.net/projects/iforest/
19
- * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
- * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
- * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
- * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
- *
24
- * BSD 2-Clause License
25
- * Copyright (c) 2020, David Cortes
26
- * All rights reserved.
27
- * Redistribution and use in source and binary forms, with or without
28
- * modification, are permitted provided that the following conditions are met:
29
- * * Redistributions of source code must retain the above copyright notice, this
30
- * list of conditions and the following disclaimer.
31
- * * Redistributions in binary form must reproduce the above copyright notice,
32
- * this list of conditions and the following disclaimer in the documentation
33
- * and/or other materials provided with the distribution.
34
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
- */
45
- #include "isotree.hpp"
46
-
47
- #define pw1(x) ((x))
48
- #define pw2(x) ((x) * (x))
49
- #define pw3(x) ((x) * (x) * (x))
50
- #define pw4(x) ((x) * (x) * (x) * (x))
51
-
52
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, double x[], MissingAction missing_action)
53
- {
54
- long double m = 0;
55
- long double M2 = 0, M3 = 0, M4 = 0;
56
- long double delta, delta_s, delta_div;
57
- long double diff, n;
58
-
59
- if (missing_action == Fail)
60
- {
61
- for (size_t row = st; row <= end; row++)
62
- {
63
- n = (long double)(row - st + 1);
64
-
65
- delta = x[ix_arr[row]] - m;
66
- delta_div = delta / n;
67
- delta_s = delta_div * delta_div;
68
- diff = delta * (delta_div * (long double)(row - st));
69
-
70
- m += delta_div;
71
- M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
72
- M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
73
- M2 += diff;
74
- }
75
-
76
- return ( M4 / M2 ) * ( (long double)(end - st + 1) / M2 );
77
- }
78
-
79
- else
80
- {
81
- size_t cnt = 0;
82
- for (size_t row = st; row <= end; row++)
83
- {
84
- if (!is_na_or_inf(x[ix_arr[row]]))
85
- {
86
- cnt++;
87
- n = (long double) cnt;
88
-
89
- delta = x[ix_arr[row]] - m;
90
- delta_div = delta / n;
91
- delta_s = delta_div * delta_div;
92
- diff = delta * (delta_div * (long double)(cnt - 1));
93
-
94
- m += delta_div;
95
- M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
96
- M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
97
- M2 += diff;
98
- }
99
- }
100
-
101
- return ( M4 / M2 ) * ( (long double)cnt / M2 );
102
- }
103
- }
104
-
105
-
106
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, size_t col_num,
107
- double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
108
- MissingAction missing_action)
109
- {
110
- /* ix_arr must be already sorted beforehand */
111
- if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
112
- return 0;
113
-
114
- long double s1 = 0;
115
- long double s2 = 0;
116
- long double s3 = 0;
117
- long double s4 = 0;
118
- size_t cnt = end - st + 1;
119
-
120
- if (cnt <= 1) return 0;
121
-
122
- size_t st_col = Xc_indptr[col_num];
123
- size_t end_col = Xc_indptr[col_num + 1] - 1;
124
- size_t curr_pos = st_col;
125
- size_t ind_end_col = Xc_ind[end_col];
126
- size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
127
-
128
- if (missing_action != Fail)
129
- {
130
- for (size_t *row = ptr_st;
131
- row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
132
- )
133
- {
134
- if (Xc_ind[curr_pos] == *row)
135
- {
136
- if (is_na_or_inf(Xc[curr_pos]))
137
- {
138
- cnt--;
139
- }
140
-
141
- else
142
- {
143
- s1 += pw1(Xc[curr_pos]);
144
- s2 += pw2(Xc[curr_pos]);
145
- s3 += pw3(Xc[curr_pos]);
146
- s4 += pw4(Xc[curr_pos]);
147
- }
148
-
149
- if (row == ix_arr + end || curr_pos == end_col) break;
150
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
151
- }
152
-
153
- else
154
- {
155
- if (Xc_ind[curr_pos] > *row)
156
- row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
157
- else
158
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
159
- }
160
- }
161
- }
162
-
163
- else
164
- {
165
- for (size_t *row = ptr_st;
166
- row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
167
- )
168
- {
169
- if (Xc_ind[curr_pos] == *row)
170
- {
171
- s1 += pw1(Xc[curr_pos]);
172
- s2 += pw2(Xc[curr_pos]);
173
- s3 += pw3(Xc[curr_pos]);
174
- s4 += pw4(Xc[curr_pos]);
175
-
176
- if (row == ix_arr + end || curr_pos == end_col) break;
177
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
178
- }
179
-
180
- else
181
- {
182
- if (Xc_ind[curr_pos] > *row)
183
- row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
184
- else
185
- curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
186
- }
187
- }
188
- }
189
-
190
- if (cnt <= 1 || s2 == 0 || s2 == pw2(s1)) return 0;
191
- long double cnt_l = (long double) cnt;
192
- long double sn = s1 / cnt_l;
193
- long double v = s2 / cnt_l - pw2(sn);
194
- if (v <= 0) return 0;
195
- return (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
196
- }
197
-
198
-
199
- double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
200
- MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
201
- {
202
- /* This calculation proceeds as follows:
203
- - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
204
- each category, and approximate kurtosis by sampling from such distribution
205
- with the same probabilities as given by the current counts.
206
- - If splitting by isolating one category, will binarize at each categorical level,
207
- assume the values are zero or one, and output the average assuming each categorical
208
- level has equal probability of being picked.
209
- (Note that both are misleading heuristics, but might be better than random)
210
- */
211
- size_t cnt = end - st + 1;
212
- std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
213
- double sum_kurt = 0;
214
-
215
- if (missing_action == Fail)
216
- {
217
- for (size_t row = st; row <= end; row++)
218
- buffer_cnt[x[ix_arr[row]]]++;
219
- }
220
-
221
- else
222
- {
223
- for (size_t row = st; row <= end; row++)
224
- {
225
- if (x[ix_arr[row]] >= 0)
226
- buffer_cnt[x[ix_arr[row]]]++;
227
- else
228
- buffer_cnt[ncat]++;
229
- }
230
- }
231
-
232
- cnt -= buffer_cnt[ncat];
233
- if (cnt <= 1) return 0;
234
- long double cnt_l = (long double) cnt;
235
- for (int cat = 0; cat < ncat; cat++)
236
- buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
237
-
238
- switch(cat_split_type)
239
- {
240
- case SubSet:
241
- {
242
- long double temp_v;
243
- long double s1, s2, s3, s4;
244
- long double coef;
245
- std::uniform_real_distribution<double> runif(0, 1);
246
- size_t ntry = 50;
247
- for (size_t iternum = 0; iternum < 50; iternum++)
248
- {
249
- s1 = 0; s2 = 0; s3 = 0; s4 = 0;
250
- for (int cat = 0; cat < ncat; cat++)
251
- {
252
- coef = runif(rnd_generator);
253
- s1 += buffer_prob[cat] * pw1(coef);
254
- s2 += buffer_prob[cat] * pw2(coef);
255
- s3 += buffer_prob[cat] * pw3(coef);
256
- s4 += buffer_prob[cat] * pw4(coef);
257
- }
258
- temp_v = s2 - pw2(s1);
259
- if (temp_v <= 0)
260
- ntry--;
261
- else
262
- sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
263
- }
264
- if (!ntry)
265
- return 0;
266
- else
267
- return sum_kurt / (long double)ntry;
268
- }
269
-
270
- case SingleCateg:
271
- {
272
- double p;
273
- int ncat_present = ncat;
274
- for (int cat = 0; cat < ncat; cat++)
275
- {
276
- p = buffer_prob[cat];
277
- if (p == 0)
278
- ncat_present--;
279
- else
280
- sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
281
- }
282
- if (ncat_present <= 1)
283
- return 0;
284
- else
285
- return sum_kurt / (double) ncat_present;
286
- }
287
- }
288
-
289
- return -1; /* this will never be reached, but CRAN complains otherwise */
290
- }
291
-
292
-
293
- double expected_sd_cat(double p[], size_t n, size_t pos[])
294
- {
295
- if (n <= 1) return 0;
296
-
297
- long double cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
298
- for (size_t cat1 = 2; cat1 < n; cat1++)
299
- {
300
- cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
301
- for (size_t cat2 = 0; cat2 < cat1; cat2++)
302
- cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
303
- }
304
- return sqrt(fmax(cum_var, 1e-8));
305
- }
306
-
307
- double expected_sd_cat(size_t counts[], double p[], size_t n, size_t pos[])
308
- {
309
- if (n <= 1) return 0;
310
-
311
- size_t tot = std::accumulate(pos, pos + n, (size_t)0, [&counts](size_t tot, const size_t ix){return tot + counts[ix];});
312
- long double cnt_div = (long double) tot;
313
- for (size_t cat = 0; cat < n; cat++)
314
- p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
315
-
316
- return expected_sd_cat(p, n, pos);
317
- }
318
-
319
- double expected_sd_cat_single(size_t counts[], double p[], size_t n, size_t pos[], size_t cat_exclude, size_t cnt)
320
- {
321
- if (cat_exclude == 0)
322
- return expected_sd_cat(counts, p, n-1, pos + 1);
323
-
324
- else if (cat_exclude == (n-1))
325
- return expected_sd_cat(counts, p, n-1, pos);
326
-
327
- size_t ix_exclude = pos[cat_exclude];
328
-
329
- long double cnt_div = (long double) (cnt - counts[ix_exclude]);
330
- for (size_t cat = 0; cat < n; cat++)
331
- p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
332
-
333
- double cum_var;
334
- if (cat_exclude != 1)
335
- cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
336
- else
337
- cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
338
- for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
339
- {
340
- if (pos[cat1] == ix_exclude) continue;
341
- cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
342
- for (size_t cat2 = 0; cat2 < cat1; cat2++)
343
- {
344
- if (pos[cat2] == ix_exclude) continue;
345
- cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
346
- }
347
-
348
- }
349
- return sqrt(fmax(cum_var, 1e-8));
350
- }
351
-
352
- double numeric_gain(size_t cnt_left, size_t cnt_right,
353
- long double sum_left, long double sum_right,
354
- long double sum_sq_left, long double sum_sq_right,
355
- double sd_full, long double cnt)
356
- {
357
- long double residual =
358
- (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
359
- (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
360
- return 1 - residual / (cnt * sd_full);
361
- }
362
-
363
- double numeric_gain_no_div(size_t cnt_left, size_t cnt_right,
364
- long double sum_left, long double sum_right,
365
- long double sum_sq_left, long double sum_sq_right,
366
- double sd_full, long double cnt)
367
- {
368
- long double residual =
369
- (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
370
- (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
371
- return sd_full - residual / cnt;
372
- }
373
-
374
- double categ_gain(size_t cnt_left, size_t cnt_right,
375
- long double s_left, long double s_right,
376
- long double base_info, long double cnt)
377
- {
378
- return (
379
- base_info -
380
- (((cnt_left <= 1)? 0 : ((long double)cnt_left * logl((long double)cnt_left))) - s_left) -
381
- (((cnt_right <= 1)? 0 : ((long double)cnt_right * logl((long double)cnt_right))) - s_right)
382
- ) / cnt;
383
- }
384
-
385
-
386
- #define avg_between(a, b) (((a) + (b)) / 2)
387
- #define sd_gain(sd, sd_left, sd_right) (1.0 - ((sd_left) + (sd_right)) / (2.0 * (sd)))
388
-
389
- /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
390
- double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion, double min_gain,
391
- double &split_point, double &xmin, double &xmax)
392
- {
393
- /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
394
- all numbers are assumed to be small and in the same scale */
395
-
396
- /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
397
- if (n == 2)
398
- {
399
- split_point = avg_between(x[0], x[1]);
400
- return 0;
401
- }
402
-
403
- /* sort in ascending order */
404
- std::sort(x, x + n);
405
- if (x[0] == x[n-1]) return -HUGE_VAL;
406
- xmin = x[0]; xmax = x[n-1];
407
-
408
- /* compute sum - sum_sq - sd in one pass */
409
- long double sum = 0;
410
- long double sum_sq = 0;
411
- double sd_full;
412
- for (size_t row = 0; row < n; row++)
413
- {
414
- sum += x[row];
415
- sum_sq += square(x[row]);
416
- }
417
- sd_full = calc_sd_raw(n, sum, sum_sq);
418
-
419
- /* try splits by moving observations one at a time from right to left */
420
- long double sum_left = 0;
421
- long double sum_sq_left = 0;
422
- long double sum_right = sum;
423
- long double sum_sq_right = sum_sq;
424
- double this_gain = -HUGE_VAL;
425
- double best_gain = -HUGE_VAL;
426
-
427
- switch(criterion)
428
- {
429
- case Averaged:
430
- {
431
- for (size_t row = 0; row < n-1; row++)
432
- {
433
- sum_left += x[row];
434
- sum_sq_left += square(x[row]);
435
- sum_right -= x[row];
436
- sum_sq_right -= square(x[row]);
437
-
438
- if (x[row] == x[row + 1]) continue;
439
-
440
- this_gain = sd_gain(sd_full,
441
- calc_sd_raw(row + 1, sum_left, sum_sq_left),
442
- calc_sd_raw(n - row - 1, sum_right, sum_sq_right)
443
- );
444
- if (this_gain > min_gain && this_gain > best_gain)
445
- {
446
- best_gain = this_gain;
447
- split_point = avg_between(x[row], x[row + 1]);
448
- }
449
- }
450
- break;
451
- }
452
-
453
- case Pooled:
454
- {
455
- long double cnt = (long double) n;
456
- for (size_t row = 0; row < n-1; row++)
457
- {
458
- sum_left += x[row];
459
- sum_sq_left += square(x[row]);
460
- sum_right -= x[row];
461
- sum_sq_right -= square(x[row]);
462
-
463
- if (x[row] == x[row + 1]) continue;
464
-
465
- this_gain = numeric_gain(row + 1, n - row - 1,
466
- sum_left, sum_right,
467
- sum_sq_left, sum_sq_right,
468
- sd_full, cnt
469
- );
470
-
471
- if (this_gain > min_gain && this_gain > best_gain)
472
- {
473
- best_gain = this_gain;
474
- split_point = avg_between(x[row], x[row + 1]);
475
- }
476
- }
477
- break;
478
- }
479
- }
480
-
481
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
482
- return 0;
483
- else
484
- return best_gain;
485
- }
486
-
487
- /* for split-criterion in single-variable splits */
488
- #define std_val(x, m, sd) ( ((x) - (m)) / (sd) )
489
- double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x,
490
- size_t &split_ix, double &split_point, double &xmin, double &xmax,
491
- GainCriterion criterion, double min_gain, MissingAction missing_action)
492
- {
493
- /* move NAs to the front if there's any, exclude them from calculations */
494
- if (missing_action != Fail)
495
- st = move_NAs_to_front(ix_arr, st, end, x);
496
-
497
- if (st >= end) return -HUGE_VAL;
498
- else if (st == (end-1))
499
- {
500
- split_point = avg_between(x[ix_arr[st]], x[ix_arr[end]]);
501
- split_ix = st;
502
- return 0;
503
- }
504
-
505
- /* sort in ascending order */
506
- std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
507
- if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
508
- xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
509
-
510
- /* Note: these variables are not standardized beforehand, so a single-pass gain
511
- calculation for both branches would suffer from numerical instability and perhaps give
512
- negative standard deviations if the sample size is large or the values have different
513
- orders of magnitude */
514
-
515
- /* first get mean and sd */
516
- double x_mean, x_sd;
517
- calc_mean_and_sd(ix_arr, st, end, x,
518
- Fail, x_sd, x_mean);
519
-
520
- /* compute sum - sum_sq - sd in one pass, on the standardized values */
521
- double zval;
522
- long double sum = 0;
523
- long double sum_sq = 0;
524
- double sd_full;
525
- for (size_t row = st; row <= end; row++)
526
- {
527
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
528
- sum += zval;
529
- sum_sq += square(zval);
530
- }
531
- sd_full = calc_sd_raw(end - st + 1, sum, sum_sq);
532
-
533
- /* try splits by moving observations one at a time from right to left */
534
- long double sum_left = 0;
535
- long double sum_sq_left = 0;
536
- long double sum_right = sum;
537
- long double sum_sq_right = sum_sq;
538
- double this_gain = -HUGE_VAL;
539
- double best_gain = -HUGE_VAL;
540
-
541
- switch(criterion)
542
- {
543
- case Averaged:
544
- {
545
- for (size_t row = st; row < end; row++)
546
- {
547
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
548
- sum_left += zval;
549
- sum_sq_left += square(zval);
550
- sum_right -= zval;
551
- sum_sq_right -= square(zval);
552
-
553
- if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
554
-
555
- this_gain = sd_gain(sd_full,
556
- calc_sd_raw(row - st + 1, sum_left, sum_sq_left),
557
- calc_sd_raw(end - row, sum_right, sum_sq_right)
558
- );
559
- if (this_gain > min_gain && this_gain > best_gain)
560
- {
561
- best_gain = this_gain;
562
- split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
563
- split_ix = row;
564
- }
565
- }
566
- break;
567
- }
568
-
569
- case Pooled:
570
- {
571
- long double cnt = (long double)(end - st + 1);
572
- for (size_t row = st; row < end; row++)
573
- {
574
- zval = std_val(x[ix_arr[row]], x_mean, x_sd);
575
- sum_left += zval;
576
- sum_sq_left += square(zval);
577
- sum_right -= zval;
578
- sum_sq_right -= square(zval);
579
-
580
- if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
581
-
582
- this_gain = numeric_gain_no_div(row - st + 1, end - row,
583
- sum_left, sum_right,
584
- sum_sq_left, sum_sq_right,
585
- sd_full, cnt
586
- );
587
-
588
- if (this_gain > min_gain && this_gain > best_gain)
589
- {
590
- best_gain = this_gain;
591
- split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
592
- split_ix = row;
593
- }
594
- }
595
- break;
596
- }
597
- }
598
-
599
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
600
- return 0;
601
- else
602
- return best_gain;
603
- }
604
-
605
- double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
606
- size_t col_num, double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
607
- double buffer_arr[], size_t buffer_pos[],
608
- double &split_point, double &xmin, double &xmax,
609
- GainCriterion criterion, double min_gain, MissingAction missing_action)
610
- {
611
- todense(ix_arr, st, end,
612
- col_num, Xc, Xc_ind, Xc_indptr,
613
- buffer_arr);
614
- std::iota(buffer_pos, buffer_pos + (end - st + 1), (size_t)0);
615
- size_t temp;
616
- return eval_guided_crit(buffer_pos, 0, end - st, buffer_arr, temp, split_point,
617
- xmin, xmax, criterion, min_gain, missing_action);
618
- }
619
-
620
- /* How this works:
621
- - For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
622
- if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
623
- numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
624
- branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
625
- splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
626
- splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
627
- smallest or the largest category to assign to one branch, but sometimes picks groups too.
628
- - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
629
- by a single category, it always puts the largest category in a separate branch. In the case of subsets,
630
- it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
631
- or search for splits in sorted order just like for Averaged criterion.
632
- Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
633
- while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
634
- Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
635
- */
636
- /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
637
- double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
638
- size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
639
- int &chosen_cat, char *restrict split_categ, char *restrict buffer_split,
640
- GainCriterion criterion, double min_gain, bool all_perm, MissingAction missing_action, CategSplit cat_split_type)
641
- {
642
- /* move NAs to the front if there's any, exclude them from calculations */
643
- if (missing_action != Fail)
644
- st = move_NAs_to_front(ix_arr, st, end, x);
645
-
646
- if (st >= end) return -HUGE_VAL;
647
-
648
- /* count categories */
649
- memset(buffer_cnt, 0, sizeof(size_t) * ncat);
650
- for (size_t row = st; row <= end; row++)
651
- buffer_cnt[x[ix_arr[row]]]++;
652
-
653
- double this_gain = -HUGE_VAL;
654
- double best_gain = -HUGE_VAL;
655
- std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
656
- size_t st_pos = 0;
657
-
658
- switch(cat_split_type)
659
- {
660
- case SingleCateg:
661
- {
662
- size_t cnt = end - st + 1;
663
- size_t ncat_present = 0;
664
-
665
- switch(criterion)
666
- {
667
- case Averaged:
668
- {
669
- /* move zero-counts to the beginning */
670
- size_t temp;
671
- for (int cat = 0; cat < ncat; cat++)
672
- {
673
- if (buffer_cnt[cat])
674
- {
675
- ncat_present++;
676
- buffer_prob[cat] = (long double) buffer_cnt[cat] / (long double) cnt;
677
- }
678
-
679
- else
680
- {
681
- temp = buffer_pos[st_pos];
682
- buffer_pos[st_pos] = buffer_pos[cat];
683
- buffer_pos[cat] = temp;
684
- st_pos++;
685
- }
686
- }
687
-
688
- if (ncat_present <= 1) return -HUGE_VAL;
689
-
690
- double sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
691
-
692
- /* try isolating each category one at a time */
693
- for (size_t pos = st_pos; (int)pos < ncat; pos++)
694
- {
695
- this_gain = sd_gain(sd_full,
696
- 0.0,
697
- expected_sd_cat_single(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)
698
- );
699
- if (this_gain > min_gain && this_gain > best_gain)
700
- {
701
- best_gain = this_gain;
702
- chosen_cat = buffer_pos[pos];
703
- }
704
- }
705
- break;
706
- }
707
-
708
- case Pooled:
709
- {
710
- /* here it will always pick the largest one */
711
- size_t ncat_present = 0;
712
- size_t cnt_max = 0;
713
- for (int cat = 0; cat < ncat; cat++)
714
- {
715
- if (buffer_cnt[cat])
716
- {
717
- ncat_present++;
718
- if (cnt_max < buffer_cnt[cat])
719
- {
720
- cnt_max = buffer_cnt[cat];
721
- chosen_cat = cat;
722
- }
723
- }
724
- }
725
-
726
- if (ncat_present <= 1) return -HUGE_VAL;
727
-
728
- long double cnt_left = (long double)((end - st + 1) - cnt_max);
729
- this_gain = (
730
- (long double)cnt * logl((long double)cnt)
731
- - cnt_left * logl(cnt_left)
732
- - (long double)cnt_max * logl((long double)cnt_max)
733
- ) / cnt;
734
- best_gain = (this_gain > min_gain)? this_gain : best_gain;
735
- break;
736
- }
737
- }
738
- break;
739
- }
740
-
741
- case SubSet:
742
- {
743
- /* sort by counts */
744
- std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
745
-
746
- /* set split as: (1):left (0):right (-1):not_present */
747
- memset(buffer_split, 0, ncat * sizeof(char));
748
-
749
- long double cnt = (long double)(end - st + 1);
750
-
751
- switch(criterion)
752
- {
753
- case Averaged:
754
- {
755
- /* determine first non-zero and convert to probabilities */
756
- double sd_full;
757
- for (int cat = 0; cat < ncat; cat++)
758
- {
759
- if (buffer_cnt[buffer_pos[cat]])
760
- {
761
- buffer_prob[buffer_pos[cat]] = (long double)buffer_cnt[buffer_pos[cat]] / cnt;
762
- }
763
-
764
- else
765
- {
766
- buffer_split[buffer_pos[cat]] = -1;
767
- st_pos++;
768
- }
769
- }
770
-
771
- if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
772
-
773
- /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
774
- size_t ncat_present = (size_t)ncat - st_pos;
775
- sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
776
- if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
777
-
778
- /* move categories one at a time */
779
- for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
780
- {
781
- buffer_split[buffer_pos[pos]] = 1;
782
- this_gain = sd_gain(sd_full,
783
- expected_sd_cat(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos),
784
- expected_sd_cat(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1)
785
- );
786
- if (this_gain > min_gain && this_gain > best_gain)
787
- {
788
- best_gain = this_gain;
789
- memcpy(split_categ, buffer_split, ncat * sizeof(char));
790
- }
791
- }
792
-
793
- break;
794
- }
795
-
796
- case Pooled:
797
- {
798
- long double s = 0;
799
-
800
- /* determine first non-zero and get base info */
801
- for (int cat = 0; cat < ncat; cat++)
802
- {
803
- if (buffer_cnt[buffer_pos[cat]])
804
- {
805
- s += (buffer_cnt[buffer_pos[cat]] <= 1)?
806
- 0 : ((long double) buffer_cnt[buffer_pos[cat]] * logl((long double)buffer_cnt[buffer_pos[cat]]));
807
- }
808
-
809
- else
810
- {
811
- buffer_split[buffer_pos[cat]] = -1;
812
- st_pos++;
813
- }
814
- }
815
-
816
- if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
817
-
818
- /* calculate base info */
819
- long double base_info = cnt * logl(cnt) - s;
820
-
821
- if (all_perm)
822
- {
823
- size_t cnt_left, cnt_right;
824
- double s_left, s_right;
825
- size_t ncat_present = (size_t)ncat - st_pos;
826
- size_t ncomb = pow2(ncat_present) - 1;
827
- size_t best_combin;
828
-
829
- for (size_t combin = 1; combin < ncomb; combin++)
830
- {
831
- cnt_left = 0; cnt_right = 0;
832
- s_left = 0; s_right = 0;
833
- for (size_t pos = st_pos; (int)pos < ncat; pos++)
834
- {
835
- if (extract_bit(combin, pos))
836
- {
837
- cnt_left += buffer_cnt[buffer_pos[pos]];
838
- s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
839
- 0 : ((long double) buffer_cnt[buffer_pos[pos]]
840
- * logl((long double) buffer_cnt[buffer_pos[pos]]));
841
- }
842
-
843
- else
844
- {
845
- cnt_right += buffer_cnt[buffer_pos[pos]];
846
- s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
847
- 0 : ((long double) buffer_cnt[buffer_pos[pos]]
848
- * logl((long double) buffer_cnt[buffer_pos[pos]]));
849
- }
850
- }
851
-
852
- this_gain = categ_gain(cnt_left, cnt_right,
853
- s_left, s_right,
854
- base_info, cnt);
855
-
856
- if (this_gain > min_gain && this_gain > best_gain)
857
- {
858
- best_gain = this_gain;
859
- best_combin = combin;
860
- }
861
-
862
- }
863
-
864
- if (best_gain > min_gain)
865
- for (size_t pos = 0; pos < ncat_present; pos++)
866
- split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
867
-
868
- }
869
-
870
- else
871
- {
872
- /* try moving the categories one at a time */
873
- size_t cnt_left = 0;
874
- size_t cnt_right = end - st + 1;
875
- double s_left = 0;
876
- double s_right = s;
877
-
878
- for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
879
- {
880
- buffer_split[buffer_pos[pos]] = 1;
881
- s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
882
- 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
883
- s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
884
- 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
885
- cnt_left += buffer_cnt[buffer_pos[pos]];
886
- cnt_right -= buffer_cnt[buffer_pos[pos]];
887
-
888
- this_gain = categ_gain(cnt_left, cnt_right,
889
- s_left, s_right,
890
- base_info, cnt);
891
-
892
- if (this_gain > min_gain && this_gain > best_gain)
893
- {
894
- best_gain = this_gain;
895
- memcpy(split_categ, buffer_split, ncat * sizeof(char));
896
- }
897
- }
898
- }
899
-
900
- break;
901
- }
902
- }
903
- }
904
- }
905
-
906
- if (st == (end-1)) return 0;
907
-
908
- if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
909
- return 0;
910
- else
911
- return best_gain;
912
- }