isotree 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,912 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
+ * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
+ *
24
+ * BSD 2-Clause License
25
+ * Copyright (c) 2019, David Cortes
26
+ * All rights reserved.
27
+ * Redistribution and use in source and binary forms, with or without
28
+ * modification, are permitted provided that the following conditions are met:
29
+ * * Redistributions of source code must retain the above copyright notice, this
30
+ * list of conditions and the following disclaimer.
31
+ * * Redistributions in binary form must reproduce the above copyright notice,
32
+ * this list of conditions and the following disclaimer in the documentation
33
+ * and/or other materials provided with the distribution.
34
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
+ */
45
+ #include "isotree.hpp"
46
+
47
/* Integer powers 1 through 4 of an expression, expanded as repeated
   multiplication. The argument is evaluated once per factor, so it
   should be free of side effects. */
#define pw1(v) ((v))
#define pw2(v) ((v) * (v))
#define pw3(v) (pw2(v) * (v))
#define pw4(v) (pw3(v) * (v))
51
+
52
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, double x[], MissingAction missing_action)
53
+ {
54
+ long double m = 0;
55
+ long double M2 = 0, M3 = 0, M4 = 0;
56
+ long double delta, delta_s, delta_div;
57
+ long double diff, n;
58
+
59
+ if (missing_action == Fail)
60
+ {
61
+ for (size_t row = st; row <= end; row++)
62
+ {
63
+ n = (long double)(row - st + 1);
64
+
65
+ delta = x[ix_arr[row]] - m;
66
+ delta_div = delta / n;
67
+ delta_s = delta_div * delta_div;
68
+ diff = delta * (delta_div * (long double)(row - st));
69
+
70
+ m += delta_div;
71
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
72
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
73
+ M2 += diff;
74
+ }
75
+
76
+ return ( M4 / M2 ) * ( (long double)(end - st + 1) / M2 );
77
+ }
78
+
79
+ else
80
+ {
81
+ size_t cnt = 0;
82
+ for (size_t row = st; row <= end; row++)
83
+ {
84
+ if (!is_na_or_inf(x[ix_arr[row]]))
85
+ {
86
+ cnt++;
87
+ n = (long double) cnt;
88
+
89
+ delta = x[ix_arr[row]] - m;
90
+ delta_div = delta / n;
91
+ delta_s = delta_div * delta_div;
92
+ diff = delta * (delta_div * (long double)(cnt - 1));
93
+
94
+ m += delta_div;
95
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
96
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
97
+ M2 += diff;
98
+ }
99
+ }
100
+
101
+ return ( M4 / M2 ) * ( (long double)cnt / M2 );
102
+ }
103
+ }
104
+
105
+
106
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, size_t col_num,
107
+ double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
108
+ MissingAction missing_action)
109
+ {
110
+ /* ix_arr must be already sorted beforehand */
111
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
112
+ return 0;
113
+
114
+ long double s1 = 0;
115
+ long double s2 = 0;
116
+ long double s3 = 0;
117
+ long double s4 = 0;
118
+ size_t cnt = end - st + 1;
119
+
120
+ if (cnt <= 1) return 0;
121
+
122
+ size_t st_col = Xc_indptr[col_num];
123
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
124
+ size_t curr_pos = st_col;
125
+ size_t ind_end_col = Xc_ind[end_col];
126
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
127
+
128
+ if (missing_action != Fail)
129
+ {
130
+ for (size_t *row = ptr_st;
131
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
132
+ )
133
+ {
134
+ if (Xc_ind[curr_pos] == *row)
135
+ {
136
+ if (is_na_or_inf(Xc[curr_pos]))
137
+ {
138
+ cnt--;
139
+ }
140
+
141
+ else
142
+ {
143
+ s1 += pw1(Xc[curr_pos]);
144
+ s2 += pw2(Xc[curr_pos]);
145
+ s3 += pw3(Xc[curr_pos]);
146
+ s4 += pw4(Xc[curr_pos]);
147
+ }
148
+
149
+ if (row == ix_arr + end || curr_pos == end_col) break;
150
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
151
+ }
152
+
153
+ else
154
+ {
155
+ if (Xc_ind[curr_pos] > *row)
156
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
157
+ else
158
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
159
+ }
160
+ }
161
+ }
162
+
163
+ else
164
+ {
165
+ for (size_t *row = ptr_st;
166
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
167
+ )
168
+ {
169
+ if (Xc_ind[curr_pos] == *row)
170
+ {
171
+ s1 += pw1(Xc[curr_pos]);
172
+ s2 += pw2(Xc[curr_pos]);
173
+ s3 += pw3(Xc[curr_pos]);
174
+ s4 += pw4(Xc[curr_pos]);
175
+
176
+ if (row == ix_arr + end || curr_pos == end_col) break;
177
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
178
+ }
179
+
180
+ else
181
+ {
182
+ if (Xc_ind[curr_pos] > *row)
183
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
184
+ else
185
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
186
+ }
187
+ }
188
+ }
189
+
190
+ if (cnt <= 1 || s2 == 0 || s2 == pw2(s1)) return 0;
191
+ long double cnt_l = (long double) cnt;
192
+ long double sn = s1 / cnt_l;
193
+ long double v = s2 / cnt_l - pw2(sn);
194
+ if (v <= 0) return 0;
195
+ return (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
196
+ }
197
+
198
+
199
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
200
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
201
+ {
202
+ /* This calculation proceeds as follows:
203
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
204
+ each category, and approximate kurtosis by sampling from such distribution
205
+ with the same probabilities as given by the current counts.
206
+ - If splitting by isolating one category, will binarize at each categorical level,
207
+ assume the values are zero or one, and output the average assuming each categorical
208
+ level has equal probability of being picked.
209
+ (Note that both are misleading heuristics, but might be better than random)
210
+ */
211
+ size_t cnt = end - st + 1;
212
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
213
+ double sum_kurt = 0;
214
+
215
+ if (missing_action == Fail)
216
+ {
217
+ for (size_t row = st; row <= end; row++)
218
+ buffer_cnt[x[ix_arr[row]]]++;
219
+ }
220
+
221
+ else
222
+ {
223
+ for (size_t row = st; row <= end; row++)
224
+ {
225
+ if (x[ix_arr[row]] >= 0)
226
+ buffer_cnt[x[ix_arr[row]]]++;
227
+ else
228
+ buffer_cnt[ncat]++;
229
+ }
230
+ }
231
+
232
+ cnt -= buffer_cnt[ncat];
233
+ if (cnt <= 1) return 0;
234
+ long double cnt_l = (long double) cnt;
235
+ for (int cat = 0; cat < ncat; cat++)
236
+ buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
237
+
238
+ switch(cat_split_type)
239
+ {
240
+ case SubSet:
241
+ {
242
+ long double temp_v;
243
+ long double s1, s2, s3, s4;
244
+ long double coef;
245
+ std::uniform_real_distribution<double> runif(0, 1);
246
+ size_t ntry = 50;
247
+ for (size_t iternum = 0; iternum < 50; iternum++)
248
+ {
249
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
250
+ for (int cat = 0; cat < ncat; cat++)
251
+ {
252
+ coef = runif(rnd_generator);
253
+ s1 += buffer_prob[cat] * pw1(coef);
254
+ s2 += buffer_prob[cat] * pw2(coef);
255
+ s3 += buffer_prob[cat] * pw3(coef);
256
+ s4 += buffer_prob[cat] * pw4(coef);
257
+ }
258
+ temp_v = s2 - pw2(s1);
259
+ if (temp_v <= 0)
260
+ ntry--;
261
+ else
262
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
263
+ }
264
+ if (!ntry)
265
+ return 0;
266
+ else
267
+ return sum_kurt / (long double)ntry;
268
+ }
269
+
270
+ case SingleCateg:
271
+ {
272
+ double p;
273
+ int ncat_present = ncat;
274
+ for (int cat = 0; cat < ncat; cat++)
275
+ {
276
+ p = buffer_prob[cat];
277
+ if (p == 0)
278
+ ncat_present--;
279
+ else
280
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
281
+ }
282
+ if (ncat_present <= 1)
283
+ return 0;
284
+ else
285
+ return sum_kurt / (double) ncat_present;
286
+ }
287
+ }
288
+
289
+ return -1; /* this will never be reached, but CRAN complains otherwise */
290
+ }
291
+
292
+
293
/* Expected standard deviation of a categorical variable if each category were mapped
 * to a number drawn at random ~Unif(0,1).
 *
 * p   : category proportions, looked up through 'pos'.
 * n   : number of categories to consider.
 * pos : indices into 'p' for those n categories.
 *
 * Accumulates  sum_i (p_i - p_i^2)/3  -  sum_{i<j} p_i*p_j/2  and returns its square
 * root, with the variance floored at 1e-8. Returns 0 when n <= 1.
 */
double expected_sd_cat(double p[], size_t n, size_t pos[])
{
    if (n <= 1) return 0;

    const double p_first   =  p[pos[0]];
    const double p_second  =  p[pos[1]];

    /* contribution of the first two categories */
    long double cum_var = -(p_first * p_first) / 3.0 - p_first * p_second / 2.0 + p_first / 3.0
                          - (p_second * p_second) / 3.0 + p_second / 3.0;

    /* add the remaining categories one by one, each paired against all earlier ones */
    for (size_t outer = 2; outer < n; outer++)
    {
        cum_var += p[pos[outer]] / 3.0 - (p[pos[outer]] * p[pos[outer]]) / 3.0;
        for (size_t inner = 0; inner < outer; inner++)
            cum_var -= p[pos[outer]] * p[pos[inner]] / 2.0;
    }

    return sqrt(fmax(cum_var, 1e-8));
}
306
+
307
+ double expected_sd_cat(size_t counts[], double p[], size_t n, size_t pos[])
308
+ {
309
+ if (n <= 1) return 0;
310
+
311
+ size_t tot = std::accumulate(pos, pos + n, (size_t)0, [&counts](size_t tot, const size_t ix){return tot + counts[ix];});
312
+ long double cnt_div = (long double) tot;
313
+ for (size_t cat = 0; cat < n; cat++)
314
+ p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
315
+
316
+ return expected_sd_cat(p, n, pos);
317
+ }
318
+
319
+ double expected_sd_cat_single(size_t counts[], double p[], size_t n, size_t pos[], size_t cat_exclude, size_t cnt)
320
+ {
321
+ if (cat_exclude == 0)
322
+ return expected_sd_cat(counts, p, n-1, pos + 1);
323
+
324
+ else if (cat_exclude == (n-1))
325
+ return expected_sd_cat(counts, p, n-1, pos);
326
+
327
+ size_t ix_exclude = pos[cat_exclude];
328
+
329
+ long double cnt_div = (long double) (cnt - counts[ix_exclude]);
330
+ for (size_t cat = 0; cat < n; cat++)
331
+ p[pos[cat]] = (long double)counts[pos[cat]] / cnt_div;
332
+
333
+ double cum_var;
334
+ if (cat_exclude != 1)
335
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
336
+ else
337
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
338
+ for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
339
+ {
340
+ if (pos[cat1] == ix_exclude) continue;
341
+ cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
342
+ for (size_t cat2 = 0; cat2 < cat1; cat2++)
343
+ {
344
+ if (pos[cat2] == ix_exclude) continue;
345
+ cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
346
+ }
347
+
348
+ }
349
+ return sqrt(fmax(cum_var, 1e-8));
350
+ }
351
+
352
+ double numeric_gain(size_t cnt_left, size_t cnt_right,
353
+ long double sum_left, long double sum_right,
354
+ long double sum_sq_left, long double sum_sq_right,
355
+ double sd_full, long double cnt)
356
+ {
357
+ long double residual =
358
+ (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
359
+ (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
360
+ return 1 - residual / (cnt * sd_full);
361
+ }
362
+
363
+ double numeric_gain_no_div(size_t cnt_left, size_t cnt_right,
364
+ long double sum_left, long double sum_right,
365
+ long double sum_sq_left, long double sum_sq_right,
366
+ double sd_full, long double cnt)
367
+ {
368
+ long double residual =
369
+ (long double) cnt_left * calc_sd_raw_l(cnt_left, sum_left, sum_sq_left) +
370
+ (long double) cnt_right * calc_sd_raw_l(cnt_right, sum_right, sum_sq_right);
371
+ return sd_full - residual / cnt;
372
+ }
373
+
374
/* Information gain of a categorical split, per observation.
 *
 * cnt_left, cnt_right : number of observations sent to each branch.
 * s_left, s_right     : sum of c*log(c) over the category counts c in each branch.
 * base_info           : information of the unsplit node, on the same scale.
 * cnt                 : total number of observations (the divisor).
 *
 * Each branch contributes n*log(n) - s, with n*log(n) taken as zero when n <= 1.
 */
double categ_gain(size_t cnt_left, size_t cnt_right,
                  long double s_left, long double s_right,
                  long double base_info, long double cnt)
{
    long double nlogn_left = 0;
    if (cnt_left > 1)
        nlogn_left = (long double)cnt_left * logl((long double)cnt_left);

    long double nlogn_right = 0;
    if (cnt_right > 1)
        nlogn_right = (long double)cnt_right * logl((long double)cnt_right);

    return (
            base_info -
            (nlogn_left - s_left) -
            (nlogn_right - s_right)
            ) / cnt;
}
384
+
385
+
386
/* midpoint of two values */
#define avg_between(lo, hi) (((lo) + (hi)) / 2)
/* averaged gain: one minus the mean of the two branch SDs relative to the full SD */
#define sd_gain(sd_all, sd_l, sd_r) (1.0 - ((sd_l) + (sd_r)) / (2.0 * (sd_all)))
388
+
389
+ /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
390
+ double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion, double min_gain,
391
+ double &split_point, double &xmin, double &xmax)
392
+ {
393
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
394
+ all numbers are assumed to be small and in the same scale */
395
+
396
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
397
+ if (n == 2)
398
+ {
399
+ split_point = avg_between(x[0], x[1]);
400
+ return 0;
401
+ }
402
+
403
+ /* sort in ascending order */
404
+ std::sort(x, x + n);
405
+ if (x[0] == x[n-1]) return -HUGE_VAL;
406
+ xmin = x[0]; xmax = x[n-1];
407
+
408
+ /* compute sum - sum_sq - sd in one pass */
409
+ long double sum = 0;
410
+ long double sum_sq = 0;
411
+ double sd_full;
412
+ for (size_t row = 0; row < n; row++)
413
+ {
414
+ sum += x[row];
415
+ sum_sq += square(x[row]);
416
+ }
417
+ sd_full = calc_sd_raw(n, sum, sum_sq);
418
+
419
+ /* try splits by moving observations one at a time from right to left */
420
+ long double sum_left = 0;
421
+ long double sum_sq_left = 0;
422
+ long double sum_right = sum;
423
+ long double sum_sq_right = sum_sq;
424
+ double this_gain = -HUGE_VAL;
425
+ double best_gain = -HUGE_VAL;
426
+
427
+ switch(criterion)
428
+ {
429
+ case Averaged:
430
+ {
431
+ for (size_t row = 0; row < n-1; row++)
432
+ {
433
+ sum_left += x[row];
434
+ sum_sq_left += square(x[row]);
435
+ sum_right -= x[row];
436
+ sum_sq_right -= square(x[row]);
437
+
438
+ if (x[row] == x[row + 1]) continue;
439
+
440
+ this_gain = sd_gain(sd_full,
441
+ calc_sd_raw(row + 1, sum_left, sum_sq_left),
442
+ calc_sd_raw(n - row - 1, sum_right, sum_sq_right)
443
+ );
444
+ if (this_gain > min_gain && this_gain > best_gain)
445
+ {
446
+ best_gain = this_gain;
447
+ split_point = avg_between(x[row], x[row + 1]);
448
+ }
449
+ }
450
+ break;
451
+ }
452
+
453
+ case Pooled:
454
+ {
455
+ long double cnt = (long double) n;
456
+ for (size_t row = 0; row < n-1; row++)
457
+ {
458
+ sum_left += x[row];
459
+ sum_sq_left += square(x[row]);
460
+ sum_right -= x[row];
461
+ sum_sq_right -= square(x[row]);
462
+
463
+ if (x[row] == x[row + 1]) continue;
464
+
465
+ this_gain = numeric_gain(row + 1, n - row - 1,
466
+ sum_left, sum_right,
467
+ sum_sq_left, sum_sq_right,
468
+ sd_full, cnt
469
+ );
470
+
471
+ if (this_gain > min_gain && this_gain > best_gain)
472
+ {
473
+ best_gain = this_gain;
474
+ split_point = avg_between(x[row], x[row + 1]);
475
+ }
476
+ }
477
+ break;
478
+ }
479
+ }
480
+
481
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
482
+ return 0;
483
+ else
484
+ return best_gain;
485
+ }
486
+
487
+ /* for split-criterion in single-variable splits */
488
+ #define std_val(x, m, sd) ( ((x) - (m)) / (sd) )
489
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x,
490
+ size_t &split_ix, double &split_point, double &xmin, double &xmax,
491
+ GainCriterion criterion, double min_gain, MissingAction missing_action)
492
+ {
493
+ /* move NAs to the front if there's any, exclude them from calculations */
494
+ if (missing_action != Fail)
495
+ st = move_NAs_to_front(ix_arr, st, end, x);
496
+
497
+ if (st >= end) return -HUGE_VAL;
498
+ else if (st == (end-1))
499
+ {
500
+ split_point = avg_between(x[ix_arr[st]], x[ix_arr[end]]);
501
+ split_ix = st;
502
+ return 0;
503
+ }
504
+
505
+ /* sort in ascending order */
506
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
507
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
508
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
509
+
510
+ /* Note: these variables are not standardized beforehand, so a single-pass gain
511
+ calculation for both branches would suffer from numerical instability and perhaps give
512
+ negative standard deviations if the sample size is large or the values have different
513
+ orders of magnitude */
514
+
515
+ /* first get mean and sd */
516
+ double x_mean, x_sd;
517
+ calc_mean_and_sd(ix_arr, st, end, x,
518
+ Fail, x_sd, x_mean);
519
+
520
+ /* compute sum - sum_sq - sd in one pass, on the standardized values */
521
+ double zval;
522
+ long double sum = 0;
523
+ long double sum_sq = 0;
524
+ double sd_full;
525
+ for (size_t row = st; row <= end; row++)
526
+ {
527
+ zval = std_val(x[ix_arr[row]], x_mean, x_sd);
528
+ sum += zval;
529
+ sum_sq += square(zval);
530
+ }
531
+ sd_full = calc_sd_raw(end - st + 1, sum, sum_sq);
532
+
533
+ /* try splits by moving observations one at a time from right to left */
534
+ long double sum_left = 0;
535
+ long double sum_sq_left = 0;
536
+ long double sum_right = sum;
537
+ long double sum_sq_right = sum_sq;
538
+ double this_gain = -HUGE_VAL;
539
+ double best_gain = -HUGE_VAL;
540
+
541
+ switch(criterion)
542
+ {
543
+ case Averaged:
544
+ {
545
+ for (size_t row = st; row < end; row++)
546
+ {
547
+ zval = std_val(x[ix_arr[row]], x_mean, x_sd);
548
+ sum_left += zval;
549
+ sum_sq_left += square(zval);
550
+ sum_right -= zval;
551
+ sum_sq_right -= square(zval);
552
+
553
+ if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
554
+
555
+ this_gain = sd_gain(sd_full,
556
+ calc_sd_raw(row - st + 1, sum_left, sum_sq_left),
557
+ calc_sd_raw(end - row, sum_right, sum_sq_right)
558
+ );
559
+ if (this_gain > min_gain && this_gain > best_gain)
560
+ {
561
+ best_gain = this_gain;
562
+ split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
563
+ split_ix = row;
564
+ }
565
+ }
566
+ break;
567
+ }
568
+
569
+ case Pooled:
570
+ {
571
+ long double cnt = (long double)(end - st + 1);
572
+ for (size_t row = st; row < end; row++)
573
+ {
574
+ zval = std_val(x[ix_arr[row]], x_mean, x_sd);
575
+ sum_left += zval;
576
+ sum_sq_left += square(zval);
577
+ sum_right -= zval;
578
+ sum_sq_right -= square(zval);
579
+
580
+ if (x[ix_arr[row]] == x[ix_arr[row + 1]]) continue;
581
+
582
+ this_gain = numeric_gain_no_div(row - st + 1, end - row,
583
+ sum_left, sum_right,
584
+ sum_sq_left, sum_sq_right,
585
+ sd_full, cnt
586
+ );
587
+
588
+ if (this_gain > min_gain && this_gain > best_gain)
589
+ {
590
+ best_gain = this_gain;
591
+ split_point = avg_between(x[ix_arr[row]], x[ix_arr[row + 1]]);
592
+ split_ix = row;
593
+ }
594
+ }
595
+ break;
596
+ }
597
+ }
598
+
599
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
600
+ return 0;
601
+ else
602
+ return best_gain;
603
+ }
604
+
605
+ double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
606
+ size_t col_num, double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
607
+ double buffer_arr[], size_t buffer_pos[],
608
+ double &split_point, double &xmin, double &xmax,
609
+ GainCriterion criterion, double min_gain, MissingAction missing_action)
610
+ {
611
+ todense(ix_arr, st, end,
612
+ col_num, Xc, Xc_ind, Xc_indptr,
613
+ buffer_arr);
614
+ std::iota(buffer_pos, buffer_pos + (end - st + 1), (size_t)0);
615
+ size_t temp;
616
+ return eval_guided_crit(buffer_pos, 0, end - st, buffer_arr, temp, split_point,
617
+ xmin, xmax, criterion, min_gain, missing_action);
618
+ }
619
+
620
+ /* How this works:
621
+ - For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
622
+ if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
623
+ numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
624
+ branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
625
+ splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
626
+ splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
627
+ smallest or the largest category to assign to one branch, but sometimes picks groups too.
628
+ - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
629
+ by a single category, it always puts the largest category in a separate branch. In the case of subsets,
630
+ it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
631
+ or look up for splits in sorted order just like for Averaged criterion.
632
+ Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
633
+ while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
634
+ Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
635
+ */
636
+ /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
637
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
638
+ size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
639
+ int &chosen_cat, char *restrict split_categ, char *restrict buffer_split,
640
+ GainCriterion criterion, double min_gain, bool all_perm, MissingAction missing_action, CategSplit cat_split_type)
641
+ {
642
+ /* move NAs to the front if there's any, exclude them from calculations */
643
+ if (missing_action != Fail)
644
+ st = move_NAs_to_front(ix_arr, st, end, x);
645
+
646
+ if (st >= end) return -HUGE_VAL;
647
+
648
+ /* count categories */
649
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
650
+ for (size_t row = st; row <= end; row++)
651
+ buffer_cnt[x[ix_arr[row]]]++;
652
+
653
+ double this_gain = -HUGE_VAL;
654
+ double best_gain = -HUGE_VAL;
655
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
656
+ size_t st_pos = 0;
657
+
658
+ switch(cat_split_type)
659
+ {
660
+ case SingleCateg:
661
+ {
662
+ size_t cnt = end - st + 1;
663
+ size_t ncat_present = 0;
664
+
665
+ switch(criterion)
666
+ {
667
+ case Averaged:
668
+ {
669
+ /* move zero-counts to the beginning */
670
+ size_t temp;
671
+ for (int cat = 0; cat < ncat; cat++)
672
+ {
673
+ if (buffer_cnt[cat])
674
+ {
675
+ ncat_present++;
676
+ buffer_prob[cat] = (long double) buffer_cnt[cat] / (long double) cnt;
677
+ }
678
+
679
+ else
680
+ {
681
+ temp = buffer_pos[st_pos];
682
+ buffer_pos[st_pos] = buffer_pos[cat];
683
+ buffer_pos[cat] = temp;
684
+ st_pos++;
685
+ }
686
+ }
687
+
688
+ if (ncat_present <= 1) return -HUGE_VAL;
689
+
690
+ double sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
691
+
692
+ /* try isolating each category one at a time */
693
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
694
+ {
695
+ this_gain = sd_gain(sd_full,
696
+ 0.0,
697
+ expected_sd_cat_single(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)
698
+ );
699
+ if (this_gain > min_gain && this_gain > best_gain)
700
+ {
701
+ best_gain = this_gain;
702
+ chosen_cat = buffer_pos[pos];
703
+ }
704
+ }
705
+ break;
706
+ }
707
+
708
+ case Pooled:
709
+ {
710
+ /* here it will always pick the largest one */
711
+ size_t ncat_present = 0;
712
+ size_t cnt_max = 0;
713
+ for (int cat = 0; cat < ncat; cat++)
714
+ {
715
+ if (buffer_cnt[cat])
716
+ {
717
+ ncat_present++;
718
+ if (cnt_max < buffer_cnt[cat])
719
+ {
720
+ cnt_max = buffer_cnt[cat];
721
+ chosen_cat = cat;
722
+ }
723
+ }
724
+ }
725
+
726
+ if (ncat_present <= 1) return -HUGE_VAL;
727
+
728
+ long double cnt_left = (long double)((end - st + 1) - cnt_max);
729
+ this_gain = (
730
+ (long double)cnt * logl((long double)cnt)
731
+ - cnt_left * logl(cnt_left)
732
+ - (long double)cnt_max * logl((long double)cnt_max)
733
+ ) / cnt;
734
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
735
+ break;
736
+ }
737
+ }
738
+ break;
739
+ }
740
+
741
+ case SubSet:
742
+ {
743
+ /* sort by counts */
744
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
745
+
746
+ /* set split as: (1):left (0):right (-1):not_present */
747
+ memset(buffer_split, 0, ncat * sizeof(char));
748
+
749
+ long double cnt = (long double)(end - st + 1);
750
+
751
+ switch(criterion)
752
+ {
753
+ case Averaged:
754
+ {
755
+ /* determine first non-zero and convert to probabilities */
756
+ double sd_full;
757
+ for (int cat = 0; cat < ncat; cat++)
758
+ {
759
+ if (buffer_cnt[buffer_pos[cat]])
760
+ {
761
+ buffer_prob[buffer_pos[cat]] = (long double)buffer_cnt[buffer_pos[cat]] / cnt;
762
+ }
763
+
764
+ else
765
+ {
766
+ buffer_split[buffer_pos[cat]] = -1;
767
+ st_pos++;
768
+ }
769
+ }
770
+
771
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
772
+
773
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
774
+ size_t ncat_present = (size_t)ncat - st_pos;
775
+ sd_full = expected_sd_cat(buffer_prob, ncat_present, buffer_pos + st_pos);
776
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
777
+
778
+ /* move categories one at a time */
779
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
780
+ {
781
+ buffer_split[buffer_pos[pos]] = 1;
782
+ this_gain = sd_gain(sd_full,
783
+ expected_sd_cat(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos),
784
+ expected_sd_cat(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1)
785
+ );
786
+ if (this_gain > min_gain && this_gain > best_gain)
787
+ {
788
+ best_gain = this_gain;
789
+ memcpy(split_categ, buffer_split, ncat * sizeof(char));
790
+ }
791
+ }
792
+
793
+ break;
794
+ }
795
+
796
+ case Pooled:
797
+ {
798
+ long double s = 0;
799
+
800
+ /* determine first non-zero and get base info */
801
+ for (int cat = 0; cat < ncat; cat++)
802
+ {
803
+ if (buffer_cnt[buffer_pos[cat]])
804
+ {
805
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
806
+ 0 : ((long double) buffer_cnt[buffer_pos[cat]] * logl((long double)buffer_cnt[buffer_pos[cat]]));
807
+ }
808
+
809
+ else
810
+ {
811
+ buffer_split[buffer_pos[cat]] = -1;
812
+ st_pos++;
813
+ }
814
+ }
815
+
816
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
817
+
818
+ /* calculate base info */
819
+ long double base_info = cnt * logl(cnt) - s;
820
+
821
+ if (all_perm)
822
+ {
823
+ size_t cnt_left, cnt_right;
824
+ double s_left, s_right;
825
+ size_t ncat_present = (size_t)ncat - st_pos;
826
+ size_t ncomb = pow2(ncat_present) - 1;
827
+ size_t best_combin;
828
+
829
+ for (size_t combin = 1; combin < ncomb; combin++)
830
+ {
831
+ cnt_left = 0; cnt_right = 0;
832
+ s_left = 0; s_right = 0;
833
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
834
+ {
835
+ if (extract_bit(combin, pos))
836
+ {
837
+ cnt_left += buffer_cnt[buffer_pos[pos]];
838
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
839
+ 0 : ((long double) buffer_cnt[buffer_pos[pos]]
840
+ * logl((long double) buffer_cnt[buffer_pos[pos]]));
841
+ }
842
+
843
+ else
844
+ {
845
+ cnt_right += buffer_cnt[buffer_pos[pos]];
846
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
847
+ 0 : ((long double) buffer_cnt[buffer_pos[pos]]
848
+ * logl((long double) buffer_cnt[buffer_pos[pos]]));
849
+ }
850
+ }
851
+
852
+ this_gain = categ_gain(cnt_left, cnt_right,
853
+ s_left, s_right,
854
+ base_info, cnt);
855
+
856
+ if (this_gain > min_gain && this_gain > best_gain)
857
+ {
858
+ best_gain = this_gain;
859
+ best_combin = combin;
860
+ }
861
+
862
+ }
863
+
864
+ if (best_gain > min_gain)
865
+ for (size_t pos = 0; pos < ncat_present; pos++)
866
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
867
+
868
+ }
869
+
870
+ else
871
+ {
872
+ /* try moving the categories one at a time */
873
+ size_t cnt_left = 0;
874
+ size_t cnt_right = end - st + 1;
875
+ double s_left = 0;
876
+ double s_right = s;
877
+
878
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
879
+ {
880
+ buffer_split[buffer_pos[pos]] = 1;
881
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
882
+ 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
883
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
884
+ 0 : ((long double)buffer_cnt[buffer_pos[pos]] * logl((long double)buffer_cnt[buffer_pos[pos]]));
885
+ cnt_left += buffer_cnt[buffer_pos[pos]];
886
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
887
+
888
+ this_gain = categ_gain(cnt_left, cnt_right,
889
+ s_left, s_right,
890
+ base_info, cnt);
891
+
892
+ if (this_gain > min_gain && this_gain > best_gain)
893
+ {
894
+ best_gain = this_gain;
895
+ memcpy(split_categ, buffer_split, ncat * sizeof(char));
896
+ }
897
+ }
898
+ }
899
+
900
+ break;
901
+ }
902
+ }
903
+ }
904
+ }
905
+
906
+ if (st == (end-1)) return 0;
907
+
908
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
909
+ return 0;
910
+ else
911
+ return best_gain;
912
+ }