outliertree 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,1098 @@
1
+ /********************************************************************************************************************
2
+ * Explainable outlier detection
3
+ *
4
+ * Tries to detect outliers by generating decision trees that attempt to predict the values of each column based on
5
+ * each other column, testing in each branch of every tried split (if it meets some minimum criteria) whether there
6
+ * are observations that seem too distant from the others in a 1-D distribution for the column that the split tries
7
+ * to "predict" (will not generate a score for each observation).
8
+ * Splits are based on gain, while outlierness is based on confidence intervals.
9
+ * Similar in spirit to the GritBot software developed by RuleQuest research. Reference article is:
10
+ * Cortes, David. "Explainable outlier detection through decision tree conditioning."
11
+ * arXiv preprint arXiv:2001.00636 (2020).
12
+ *
13
+ *
14
+ * Copyright 2020 David Cortes.
15
+ *
16
+ * Written for C++11 standard and OpenMP 2.0 or later. Code is meant to be wrapped into scripting languages
17
+ * such as R or Python.
18
+ *
19
+ * This file is part of OutlierTree.
20
+ *
21
+ * OutlierTree is free software: you can redistribute it and/or modify
22
+ * it under the terms of the GNU General Public License as published by
23
+ * the Free Software Foundation, either version 3 of the License, or
24
+ * (at your option) any later version.
25
+ *
26
+ * OutlierTree is distributed in the hope that it will be useful,
27
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
28
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
29
+ * GNU General Public License for more details.
30
+ *
31
+ * You should have received a copy of the GNU General Public License
32
+ * along with OutlierTree. If not, see <https://www.gnu.org/licenses/>.
33
+ ********************************************************************************************************************/
34
+ #include "outlier_tree.hpp"
35
+
36
+
37
+ /* TODO: don't divide the gains by tot at every calculation as it makes it slower */
38
+
39
+ /* TODO: sorting here is the slowest thing, so it could be improved by using radix sort for categorical/ordinal and timsort for numerical */
40
+
41
+ /* TODO: columns that split by numeric should output the sum/sum_sq to pass it to the cluster functions, instead of recalculating them later */
42
+
43
+
44
/* Convert the leading entries of a category ordering into a 0/1 membership array.
 *
 * Parameters
 * - ix_arr[n_true + 1] (in)
 *     Category indices; entries ix_arr[0] through ix_arr[n_true] (both inclusive) are flagged.
 * - n_true (in)
 *     Position of the LAST entry of 'ix_arr' to flag (an inclusive index, not a count).
 * - n_tot (in)
 *     Number of entries in 'onehot'.
 * - onehot[n_tot] (out)
 *     Set to 1 for every flagged category and 0 for all the others.
 */
void subset_to_onehot(size_t ix_arr[], size_t n_true, size_t n_tot, char onehot[])
{
    /* the output holds char entries, so clear n_tot chars (was sized with sizeof(bool)) */
    memset(onehot, 0, sizeof(char) * n_tot);
    for (size_t i = 0; i <= n_true; i++) onehot[ix_arr[i]] = 1;
}
49
+
50
+ size_t move_zero_count_to_front(size_t *restrict cat_sorted, size_t *restrict cat_cnt, size_t ncat_x)
51
+ {
52
+ size_t temp_ix;
53
+ size_t st_cat = 0;
54
+ for (size_t cat = 0; cat < ncat_x; cat++) {
55
+ if (cat_cnt[cat] == 0) {
56
+ temp_ix = cat_sorted[st_cat];
57
+ cat_sorted[st_cat] = cat;
58
+ cat_sorted[cat] = temp_ix;
59
+ st_cat++;
60
+ }
61
+ }
62
+ return st_cat;
63
+ }
64
+
65
/* Mark categories without observations as "not applicable" in a split subset.
 *
 * Every entry of 'split_subset' whose count in 'buffer_cat_cnt' is zero is overwritten
 * with -1; entries of non-empty categories are left untouched.
 */
void flag_zero_counts(char split_subset[], size_t buffer_cat_cnt[], size_t ncat_x)
{
    size_t cat = 0;
    while (cat < ncat_x) {
        if (!buffer_cat_cnt[cat])
            split_subset[cat] = -1;
        ++cat;
    }
}
70
+
71
+ long double calc_sd(size_t cnt, long double sum, long double sum_sq)
72
+ {
73
+ if (cnt < 3) return 0;
74
+ return sqrtl( (sum_sq - (square(sum) / (long double) cnt) + SD_REG) / (long double) (cnt - 1) );
75
+ }
76
+
77
+ long double calc_sd(NumericBranch &branch)
78
+ {
79
+ if (branch.cnt < 3) return 0;
80
+ return sqrtl((branch.sum_sq - (square(branch.sum) / (long double) branch.cnt) + SD_REG) / (long double) (branch.cnt - 1));
81
+ }
82
+
83
+ long double calc_sd(size_t ix_arr[], double *restrict x, size_t st, size_t end, double *restrict mean)
84
+ {
85
+ long double running_mean = 0;
86
+ long double mean_prev = 0;
87
+ long double running_ssq = 0;
88
+ double xval;
89
+ for (size_t row = st; row <= end; row++) {
90
+ xval = x[ix_arr[row]];
91
+ running_mean += (xval - running_mean) / (long double)(row - st + 1);
92
+ running_ssq += (xval - running_mean) * (xval - mean_prev);
93
+ mean_prev = running_mean;
94
+ }
95
+ *mean = (double) running_mean;
96
+ return sqrtl(running_ssq / (long double)(end - st));
97
+
98
+ }
99
+
100
+ long double numeric_gain(NumericSplit &split_info, long double tot_sd)
101
+ {
102
+ long double tot = (long double)(split_info.NA_branch.cnt + split_info.left_branch.cnt + split_info.right_branch.cnt);
103
+ long double residual =
104
+ ((long double) split_info.NA_branch.cnt) * calc_sd(split_info.NA_branch) +
105
+ ((long double) split_info.left_branch.cnt) * calc_sd(split_info.left_branch) +
106
+ ((long double) split_info.right_branch.cnt) * calc_sd(split_info.right_branch);
107
+
108
+ return tot_sd - (residual / tot);
109
+ }
110
+
111
/* Gain from precomputed branch information: 'tot_sd' minus the summed branch information
 * (each already weighted by its branch size) divided by the total count 'cnt'. */
long double numeric_gain(long double tot_sd, long double info_left, long double info_right, long double info_NA, long double cnt)
{
    long double info_sum = info_left + info_right + info_NA;
    return tot_sd - info_sum / cnt;
}
115
+
116
/* Information of a category count vector:
 *     N*log(N) - sum_i n_i*log(n_i),   N = sum_i n_i
 * Categories with a zero count are skipped. Returns 0 when all counts are zero.
 */
long double total_info(size_t categ_counts[], size_t ncat)
{
    long double weighted_logs = 0;
    size_t n_total = 0;
    for (size_t cat = 0; cat < ncat; cat++) {
        size_t n_cat = categ_counts[cat];
        if (!n_cat)
            continue;
        long double n_cat_ld = (long double) n_cat;
        weighted_logs += n_cat_ld * logl(n_cat_ld);
        n_total += n_cat;
    }

    if (!n_total)
        return 0;
    long double n_total_ld = (long double) n_total;
    return n_total_ld * logl(n_total_ld) - weighted_logs;
}
129
+
130
/* Information of a category count vector whose grand total is supplied by the caller:
 *     tot*log(tot) - sum_i n_i*log(n_i)
 * 'tot' is expected to equal the sum of categ_counts[]; this is not verified here.
 * Counts of 0 or 1 contribute n*log(n) == 0 and are skipped. Returns 0 when tot == 0.
 */
long double total_info(size_t categ_counts[], size_t ncat, size_t tot)
{
    if (!tot)
        return 0;

    long double weighted_logs = 0;
    for (size_t cat = 0; cat < ncat; cat++) {
        if (categ_counts[cat] <= 1)
            continue;
        long double n_cat = (long double) categ_counts[cat];
        weighted_logs += n_cat * logl(n_cat);
    }

    long double n_total = (long double) tot;
    return n_total * logl(n_total) - weighted_logs;
}
142
+
143
+ long double total_info(size_t *restrict ix_arr, int *restrict x, size_t st, size_t end, size_t ncat, size_t *restrict buffer_cat_cnt)
144
+ {
145
+ long double info = (long double)(end - st + 1) * logl((long double)(end - st + 1));
146
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
147
+ for (size_t row = st; row <= end; row++) {
148
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
149
+ }
150
+ for (size_t cat = 0; cat < ncat; cat++) {
151
+ if (buffer_cat_cnt[cat] > 1) {
152
+ info -= (long double)buffer_cat_cnt[cat] * logl((long double)buffer_cat_cnt[cat]);
153
+ }
154
+ }
155
+ return info;
156
+ }
157
+
158
+ long double categ_gain(CategSplit split_info, long double base_info)
159
+ {
160
+ return (
161
+ base_info -
162
+ total_info(split_info.NA_branch, split_info.ncat, split_info.size_NA) -
163
+ total_info(split_info.left_branch, split_info.ncat, split_info.size_left) -
164
+ total_info(split_info.right_branch, split_info.ncat, split_info.size_right)
165
+ ) / (long double) split_info.tot ;
166
+ }
167
+
168
+ long double categ_gain(size_t *restrict categ_counts, size_t ncat, size_t *restrict ncat_col, size_t maxcat, long double base_info, size_t tot)
169
+ {
170
+ long double info = 0;
171
+ for (size_t cat = 0; cat < ncat; cat++) {
172
+ if (categ_counts[cat] > 0) {
173
+ info += total_info(categ_counts + cat * maxcat, ncat_col[cat]);
174
+ }
175
+ }
176
+
177
+ /* last entry in the array corresponds to NA values */
178
+ if (categ_counts[ncat] > 0) {
179
+ info += total_info(categ_counts + ncat * maxcat, ncat_col[ncat]);
180
+ }
181
+
182
+ return (base_info - info) / (long double) tot;
183
+ }
184
+
185
+ long double categ_gain_from_split(size_t *restrict ix_arr, int *restrict x, size_t st, size_t st_non_na, size_t split_ix, size_t end,
186
+ size_t ncat, size_t *restrict buffer_cat_cnt, long double base_info)
187
+ {
188
+ long double gain = base_info;
189
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
190
+ if (st_non_na > st) {
191
+ for (size_t row = st; row < st_non_na; row++) {
192
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
193
+ }
194
+ gain -= total_info(buffer_cat_cnt, ncat, st_non_na - st);
195
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
196
+ }
197
+
198
+ for (size_t row = st_non_na; row < split_ix; row++) {
199
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
200
+ }
201
+ gain -= total_info(buffer_cat_cnt, ncat, split_ix - st_non_na);
202
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
203
+
204
+ for (size_t row = split_ix; row <= end; row++) {
205
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
206
+ }
207
+ gain -= total_info(buffer_cat_cnt, ncat, end - split_ix + 1);
208
+
209
+ return gain / (long double)(end - st + 1);
210
+ }
211
+
212
+ /* Calculate gain from splitting a numeric column by another numeric column
213
+ *
214
+ * Function splits into buckets (NA, <= threshold, > threshold)
215
+ *
216
+ * Parameters
217
+ * - ix_arr[n] (in)
218
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
219
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
220
+ * (Note: will be modified in-place)
221
+ * - st (in)
222
+ * See above.
223
+ * - end (in)
224
+ * See above.
225
+ * - x[n] (in)
226
+ * Numeric column from which a split predicting 'y' will be calculated.
227
+ * - y[n] (in)
228
+ * Numeric column whose distribution wants to be split by 'x'.
229
+ * Must not contain missing values.
230
+ * - sd_y (in)
231
+ * Standard deviation of 'y' between the indices considered here.
232
+ * - has_na (in)
233
+ * Whether 'x' can have missing values or not.
234
+ * - min_size (in)
235
+ * Minimum number of elements that can be in a split.
236
+ * - buffer_sd[n] (in)
237
+ * Buffer where to write temporary sd/information at each split point in a first pass.
238
+ * - gain (out)
239
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
240
+ * - split_point (out)
241
+ * Threshold for splitting on values of 'x'. If no split is posible, will return -Inf.
242
+ * - split_left (out)
243
+ * Index at which the data is split between the two branches (includes last from left branch).
244
+ * - split_NA (out)
245
+ * Index at which the NA data is separated from the other branches
246
+ */
247
+ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, double *restrict y,
248
+ long double sd_y, bool has_na, size_t min_size, bool take_mid, long double *restrict buffer_sd,
249
+ long double *restrict gain, double *restrict split_point, size_t *restrict split_left, size_t *restrict split_NA)
250
+ {
251
+
252
+ *gain = -HUGE_VAL;
253
+ *split_point = -HUGE_VAL;
254
+ size_t st_non_na;
255
+ long double this_gain;
256
+ long double cnt_dbl = (long double)(end - st + 1);
257
+ long double running_mean = 0;
258
+ long double mean_prev = 0;
259
+ long double running_ssq = 0;
260
+ double xval;
261
+ long double info_left;
262
+ long double info_NA = 0;
263
+
264
+ /* check that there are enough observations for a split */
265
+ if ((end - st + 1) < (2 * min_size)) return;
266
+
267
+ /* move all NAs of X to the front */
268
+ if (has_na) {
269
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end, false);
270
+ } else { st_non_na = st; }
271
+ *split_NA = st_non_na;
272
+
273
+ /* assign NAs to their own branch */
274
+ if (st_non_na > st) {
275
+
276
+ /* first check that it's still possible to split */
277
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
278
+
279
+ info_NA = (long double)(st_non_na - st) * calc_sd(ix_arr, y, st, st_non_na-1, &xval); /* last arg is not used */
280
+ }
281
+
282
+ /* sort the remaining non-NA values in ascending order */
283
+ std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
284
+
285
+ /* calculate SD*N backwards first, then forwards */
286
+ for (size_t i = end; i >= st_non_na; i--) {
287
+ xval = y[ix_arr[i]];
288
+ running_mean += (xval - running_mean) / (long double)(end - i + 1);
289
+ running_ssq += (xval - running_mean) * (xval - mean_prev);
290
+ mean_prev = running_mean;
291
+ buffer_sd[i] = (long double)(end - i + 1) * sqrtl(running_ssq / (long double)(end - i));
292
+ /* could also avoid div by n-1, would be faster */
293
+
294
+ if (i == st_non_na) break; /* be aware unsigned integer overflow */
295
+ }
296
+
297
+ /* look for the best split point, by moving one observation at a time to the left branch*/
298
+ running_mean = 0;
299
+ running_ssq = 0;
300
+ mean_prev = 0;
301
+ for (size_t i = st_non_na; i <= (end - min_size); i++) {
302
+ xval = y[ix_arr[i]];
303
+ running_mean += (xval - running_mean) / (long double)(i - st_non_na + 1);
304
+ running_ssq += (xval - running_mean) * (xval - mean_prev);
305
+ mean_prev = running_mean;
306
+
307
+ /* check that split meets minimum criteria (size on right branch is controlled in loop condition) */
308
+ if ((i - st_non_na + 1) < min_size) continue;
309
+
310
+ /* check that value is not repeated next -- note that condition in loop prevents out-of-bounds access */
311
+ if (x[ix_arr[i]] == x[ix_arr[i + 1]]) continue;
312
+
313
+ /* evaluate gain at this split point */
314
+ info_left = (long double)(i - st_non_na + 1) * sqrtl(running_ssq / (long double)(i - st_non_na));
315
+ this_gain = numeric_gain(sd_y, info_left, buffer_sd[i + 1], info_NA, cnt_dbl);
316
+ if (this_gain > *gain) {
317
+ *gain = this_gain;
318
+ *split_point = take_mid? (avg_between(x[ix_arr[i]], x[ix_arr[i + 1]])) : (x[ix_arr[i]]);
319
+ *split_left = i;
320
+ }
321
+ }
322
+ }
323
+
324
+ /* Calculate gain from splitting a numeric column by a categorical column
325
+ *
326
+ * Function splits into two subsets + NAs on their own branch
327
+ *
328
+ * Parameters
329
+ * - ix_arr[n] (in)
330
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
331
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
332
+ * (Note: will be modified in-place)
333
+ * - st (in)
334
+ * See above.
335
+ * - end (in)
336
+ * See above.
337
+ * - x[n] (in)
338
+ * Categorical column from which a split predicting 'y' will be calculated.
339
+ * Missing values should be encoded as negative integers.
340
+ * - y[n] (in)
341
+ * Numeric column whose distribution wants to be split by 'x'.
342
+ * Must not contain missing values.
343
+ * - sd_y (in)
344
+ * Standard deviation of 'y' between the indices considered here.
345
+ * - x_is_ordinal (in)
346
+ * Whether the 'x' column has ordered categories, in which case the split will be a
347
+ * <= that respects this order.
348
+ * - ncat_x (in)
349
+ * Number of categories in 'x' (excluding NA).
350
+ * - buffer_cat_cnt[ncat_x + 1] (temp)
351
+ * Array where temporary data for each category will be written into.
352
+ * Must have one additional entry anove the number of categories to account for NAs.
353
+ * - buffer_cat_sum[ncat_x + 1] (temp)
354
+ * See above.
355
+ * - buffer_cat_sum_sq[ncat_x + 1] (temp)
356
+ * See above.
357
+ * - buffer_cat_sorted[ncat_x] (temp)
358
+ * See above. This one doesn't need an extra entry.
359
+ * - has_na (in)
360
+ * Whether 'x' can have missing values or not.
361
+ * - min_size (in)
362
+ * Minimum number of elements that can be in a split.
363
+ * - gain (out)
364
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
365
+ * - split_subset[ncat_x] (out)
366
+ * Array that will indicate which categories go into the left branch in the chosen split.
367
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
368
+ * - split_point (out)
369
+ * Split level for ordinal X variables (left branch is <= this)
370
+ */
371
+ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, double *restrict y, long double sd_y, double ymean,
372
+ bool x_is_ordinal, size_t ncat_x, size_t *restrict buffer_cat_cnt, long double *restrict buffer_cat_sum,
373
+ long double *restrict buffer_cat_sum_sq, size_t *restrict buffer_cat_sorted,
374
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset, int *restrict split_point)
375
+ {
376
+
377
+ /* output parameters and variables to use */
378
+ *gain = -HUGE_VAL;
379
+ long double this_gain;
380
+ NumericSplit split_info;
381
+ size_t st_cat = 0;
382
+ double sd_y_d = (double) sd_y;
383
+
384
+ /* reset the buffers */
385
+ memset(split_subset, 0, sizeof(char) * ncat_x);
386
+ memset(buffer_cat_cnt, 0, sizeof(size_t) * (ncat_x + 1));
387
+ memset(buffer_cat_sum, 0, sizeof(long double) * (ncat_x + 1));
388
+ memset(buffer_cat_sum_sq, 0, sizeof(long double) * (ncat_x + 1));
389
+
390
+ /* calculate summary info for each category */
391
+ if (has_na) {
392
+
393
+ for (size_t i = st; i <= end; i++) {
394
+
395
+ /* NAs are encoded as negative integers, and go at the last slot */
396
+ if (x[ix_arr[i]] < 0) {
397
+ buffer_cat_cnt[ncat_x]++;
398
+ buffer_cat_sum[ncat_x] += z_score(y[ix_arr[i]], ymean, sd_y_d);
399
+ buffer_cat_sum_sq[ncat_x] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
400
+ } else {
401
+ buffer_cat_cnt[ x[ix_arr[i]] ]++;
402
+ buffer_cat_sum[ x[ix_arr[i]] ] += z_score(y[ix_arr[i]], ymean, sd_y_d);
403
+ buffer_cat_sum_sq[ x[ix_arr[i]] ] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
404
+ }
405
+ }
406
+
407
+ } else {
408
+
409
+ buffer_cat_cnt[ncat_x] = 0;
410
+ for (size_t i = st; i <= end; i++) {
411
+ buffer_cat_cnt[ x[ix_arr[i]] ]++;
412
+ buffer_cat_sum[ x[ix_arr[i]] ] += z_score(y[ix_arr[i]], ymean, sd_y_d);
413
+ buffer_cat_sum_sq[ x[ix_arr[i]] ] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
414
+ }
415
+
416
+ }
417
+
418
+ /* set NAs to their own branch */
419
+ if (buffer_cat_cnt[ncat_x] > 0) {
420
+ split_info.NA_branch = {buffer_cat_cnt[ncat_x], buffer_cat_sum[ncat_x], buffer_cat_sum_sq[ncat_x]};
421
+ }
422
+
423
+ /* easy case: binary split (only one possible split point) */
424
+ if (ncat_x == 2) {
425
+
426
+ /* must still meet minimum size requirements */
427
+ if (buffer_cat_cnt[0] < min_size || buffer_cat_cnt[1] < min_size) return;
428
+
429
+ split_info.left_branch = {buffer_cat_cnt[0], buffer_cat_sum[0], buffer_cat_sum_sq[0]};
430
+ split_info.right_branch = {buffer_cat_cnt[1], buffer_cat_sum[1], buffer_cat_sum_sq[1]};
431
+ *gain = numeric_gain(split_info, 1.0) * sd_y;
432
+ split_subset[0] = 1;
433
+ }
434
+
435
+ /* subset and ordinal splits */
436
+ else {
437
+
438
+ /* put all the categories on the right branch */
439
+ for (size_t cat = 0; cat < ncat_x; cat++) {
440
+ split_info.right_branch.cnt += buffer_cat_cnt[cat];
441
+ split_info.right_branch.sum += buffer_cat_sum[cat];
442
+ split_info.right_branch.sum_sq += buffer_cat_sum_sq[cat];
443
+ }
444
+
445
+ /* if it's an ordinal variable, must respect the order */
446
+ for (size_t cat = 0; cat < ncat_x; cat++) buffer_cat_sorted[cat] = cat;
447
+
448
+ if (!x_is_ordinal) {
449
+ /* otherwise, sort the categories according to their mean of y */
450
+
451
+ /* first remove zero-counts */
452
+ st_cat = move_zero_count_to_front(buffer_cat_sorted, buffer_cat_cnt, ncat_x);
453
+
454
+ /* then sort */
455
+ std::sort(buffer_cat_sorted + st_cat, buffer_cat_sorted + ncat_x,
456
+ [&buffer_cat_sum, &buffer_cat_cnt](const size_t a, const size_t b)
457
+ {
458
+ return (buffer_cat_sum[a] / (long double) buffer_cat_cnt[a]) >
459
+ (buffer_cat_sum[b] / (long double) buffer_cat_cnt[b]);
460
+ });
461
+ }
462
+
463
+ /* try moving each category to the left branch in the given order */
464
+ for (size_t cat = st_cat; cat < ncat_x; cat++) {
465
+ split_info.right_branch.cnt -= buffer_cat_cnt[ buffer_cat_sorted[cat] ];
466
+ split_info.right_branch.sum -= buffer_cat_sum[ buffer_cat_sorted[cat] ];
467
+ split_info.right_branch.sum_sq -= buffer_cat_sum_sq[ buffer_cat_sorted[cat] ];
468
+
469
+ split_info.left_branch.cnt += buffer_cat_cnt[ buffer_cat_sorted[cat] ];
470
+ split_info.left_branch.sum += buffer_cat_sum[ buffer_cat_sorted[cat] ];
471
+ split_info.left_branch.sum_sq += buffer_cat_sum_sq[ buffer_cat_sorted[cat] ];
472
+
473
+ /* see if it meets minimum split sizes */
474
+ if (split_info.left_branch.cnt < min_size || split_info.right_branch.cnt < min_size) continue;
475
+
476
+ /* calculate the gain */
477
+ this_gain = numeric_gain(split_info, 1.0);
478
+ if (this_gain > *gain) {
479
+ *gain = this_gain * sd_y;
480
+ if (!x_is_ordinal)
481
+ subset_to_onehot(buffer_cat_sorted, cat, ncat_x, split_subset);
482
+ else
483
+ *split_point = (int) cat;
484
+ }
485
+ }
486
+
487
+ /* if it's categorical, set the non-present categories to -1 */
488
+ if (!is_na_or_inf(*gain) && !x_is_ordinal) flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
489
+
490
+ }
491
+
492
+ }
493
+
494
+
495
+
496
+ /* Calculate gain from splitting a categorical column by a numeric column
497
+ *
498
+ * Function splits into two subsets + NAs on their own branch
499
+ *
500
+ * Parameters
501
+ * - ix_arr[n] (in)
502
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
503
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
504
+ * (Note: will be modified in-place)
505
+ * - st (in)
506
+ * See above.
507
+ * - end (in)
508
+ * See above.
509
+ * - x[n] (in)
510
+ * Numerical column from which a split predicting 'y' will be calculated.
511
+ * - y[n] (in)
512
+ * Categorical column whose distributions are to be split by 'x'.
513
+ * Must not contain missing values (which are encoded as negative integers).
514
+ * - ncat_y (in)
515
+ * Number of categories in 'y' (excluding NAs, which are encoded as negative integers).
516
+ * - base_info (in)
517
+ * Base information for the 'y' counts before splitting.
518
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
519
+ * - buffer_cat_cnt[ncat_y * 3] (temp)
520
+ * Array where temporary data for each category will be written into.
521
+ * - has_na (in)
522
+ * Whether 'x' can have missing values or not.
523
+ * - min_size (in)
524
+ * Minimum number of elements that can be in a split.
525
+ * - gain (out)
526
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
527
+ * - split_point (out)
528
+ * Threshold for splitting on values of 'x'. If no split is posible, will return -Inf.
529
+ * - split_left (out)
530
+ * Index at which the data is split between the two branches (includes last from left branch).
531
+ * - split_NA (out)
532
+ * Index at which the NA data is separated from the other branches
533
+ */
534
+ void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, int *restrict y,
535
+ size_t ncat_y, long double base_info, size_t *restrict buffer_cat_cnt,
536
+ bool has_na, size_t min_size, bool take_mid, long double *restrict gain, double *restrict split_point,
537
+ size_t *restrict split_left, size_t *restrict split_NA)
538
+ {
539
+ *gain = -HUGE_VAL;
540
+ *split_point = -HUGE_VAL;
541
+ size_t st_non_na;
542
+ long double this_gain;
543
+ CategSplit split_info;
544
+ split_info.ncat = ncat_y;
545
+ split_info.tot = end - st + 1;
546
+
547
+ /* check that there are enough observations for a split */
548
+ if ((end - st + 1) < (2 * min_size)) return;
549
+
550
+ /* will divide into 3 branches: NA, <= p, > p */
551
+ memset(buffer_cat_cnt, 0, 3 * ncat_y * sizeof(size_t));
552
+ split_info.NA_branch = buffer_cat_cnt;
553
+ split_info.left_branch = buffer_cat_cnt + ncat_y;
554
+ split_info.right_branch = buffer_cat_cnt + 2 * ncat_y;
555
+
556
+ /* move all NAs of X to the front */
557
+ if (has_na) {
558
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end, false);
559
+ } else { st_non_na = st; }
560
+ *split_NA = st_non_na;
561
+
562
+ /* assign NAs to their own branch */
563
+ split_info.size_NA = st_non_na - st;
564
+ if (st_non_na > st) {
565
+
566
+ /* first check that it's still possible to split */
567
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
568
+
569
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
570
+ }
571
+
572
+ /* sort the remaining non-NA values in ascending order */
573
+ std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
574
+
575
+ /* put all observations on the right branch */
576
+ for (size_t i = st_non_na; i <= end; i++) split_info.right_branch[ y[ix_arr[i]] ]++;
577
+
578
+ /* look for the best split point, by moving one observation at a time to the left branch*/
579
+ for (size_t i = st_non_na; i <= (end - min_size); i++) {
580
+ split_info.right_branch[ y[ix_arr[i]] ]--;
581
+ split_info.left_branch [ y[ix_arr[i]] ]++;
582
+ split_info.size_left = i - st_non_na + 1;
583
+ split_info.size_right = end - i;
584
+
585
+ /* check that split meets minimum criteria (size on right branch is controlled in loop condition) */
586
+ if (split_info.size_left < min_size) continue;
587
+
588
+ /* check that value is not repeated next -- note that condition in loop prevents out-of-bounds access */
589
+ if (x[ix_arr[i]] == x[ix_arr[i + 1]]) continue;
590
+
591
+ /* evaluate gain at this split point */
592
+ this_gain = categ_gain(split_info, base_info);
593
+ if (this_gain > *gain) {
594
+ *gain = this_gain;
595
+ *split_point = take_mid? (avg_between(x[ix_arr[i]], x[ix_arr[i + 1]])) : (x[ix_arr[i]]);
596
+ *split_left = i;
597
+ }
598
+ }
599
+ }
600
+
601
+ /* Calculate gain from splitting a categorical column by an ordinal column
602
+ *
603
+ * Function splits into two subsets + NAs on their own branch
604
+ *
605
+ * Parameters
606
+ * - ix_arr[n] (in)
607
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
608
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
609
+ * (Note: will be modified in-place)
610
+ * - st (in)
611
+ * See above.
612
+ * - end (in)
613
+ * See above.
614
+ * - x[n] (in)
615
+ * Ordinal column from which a split predicting 'y' will be calculated.
616
+ * Missing values must be encoded as negative integers.
617
+ * - y[n] (in)
618
+ * Categorical column whose distributions are to be split by 'x'.
619
+ * Must not contain missing values (which are encoded as negative integers).
620
+ * - ncat_y (in)
621
+ * Number of categories in 'y' (excluding NAs, which are encoded as negative integers).
622
+ * - ncat_x (in)
623
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
624
+ * - base_info (in)
625
+ * Base information for the 'y' counts before splitting.
626
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
627
+ * - buffer_cat_cnt[ncat_y * 3] (temp)
628
+ * Array where temporary data for each category will be written into.
629
+ * - buffer_crosstab[ncat_x * ncat_y] (temp)
630
+ * See above.
631
+ * - buffer_ord_cnt[ncat_x] (temp)
632
+ * See above.
633
+ * - has_na (in)
634
+ * Whether 'x' can have missing values or not.
635
+ * - min_size (in)
636
+ * Minimum number of elements that can be in a split.
637
+ * - gain (out)
638
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
639
+ * - split_point (out)
640
+ * Threshold for splitting on values of 'x'. If no split is posible, will return -1.
641
+ */
642
+ void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
643
+ size_t ncat_y, size_t ncat_x, long double base_info,
644
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_ord_cnt,
645
+ bool has_na, size_t min_size, long double *gain, int *split_point)
646
+ {
647
+ *gain = -HUGE_VAL;
648
+ *split_point = -1;
649
+ size_t st_non_na;
650
+ long double this_gain;
651
+ CategSplit split_info;
652
+ split_info.ncat = ncat_y;
653
+ split_info.tot = end - st + 1;
654
+
655
+ /* check that there are enough observations for a split */
656
+ if ((end - st + 1) < (2 * min_size)) return;
657
+
658
+ /* will divide into 3 branches: NA, <= p, > p */
659
+ memset(buffer_cat_cnt, 0, 3 * ncat_y * sizeof(size_t));
660
+ split_info.NA_branch = buffer_cat_cnt;
661
+ split_info.left_branch = buffer_cat_cnt + ncat_y;
662
+ split_info.right_branch = buffer_cat_cnt + 2 * ncat_y;
663
+
664
+ /* move all NAs of X to the front */
665
+ if (has_na) {
666
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
667
+ } else { st_non_na = st; }
668
+
669
+ /* assign NAs to their own branch */
670
+ split_info.size_NA = st_non_na - st;
671
+ if (st_non_na > st) {
672
+
673
+ /* first check that it's still possible to split */
674
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
675
+
676
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
677
+ }
678
+
679
+ /* calculate cross-table on the non-missing cases, and put all observations in the right branch */
680
+ memset(buffer_crosstab, 0, ncat_y * ncat_x * sizeof(size_t));
681
+ memset(buffer_ord_cnt, 0, ncat_x * sizeof(size_t));
682
+ for (size_t i = st_non_na; i <= end; i++) {
683
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
684
+ buffer_ord_cnt [ x[ix_arr[i]] ]++;
685
+ split_info.right_branch[ y[ix_arr[i]] ]++;
686
+ }
687
+ split_info.size_right = end - st_non_na + 1;
688
+ split_info.size_left = 0;
689
+
690
+ /* look for the best split point, by moving one observation at a time to the left branch*/
691
+ for (size_t ord_cat = 0; ord_cat < (ncat_x - 1); ord_cat++) {
692
+
693
+ for (size_t moved_cat = 0; moved_cat < ncat_y; moved_cat++) {
694
+ split_info.right_branch[ moved_cat ] -= buffer_crosstab[ moved_cat + ncat_y * ord_cat ];
695
+ split_info.left_branch [ moved_cat ] += buffer_crosstab[ moved_cat + ncat_y * ord_cat ];
696
+ }
697
+ split_info.size_right -= buffer_ord_cnt[ord_cat];
698
+ split_info.size_left += buffer_ord_cnt[ord_cat];
699
+
700
+ /* check that split meets minimum criteria */
701
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
702
+
703
+ /* evaluate gain at this split point */
704
+ this_gain = categ_gain(split_info, base_info);
705
+ if (this_gain > *gain) {
706
+ *gain = this_gain;
707
+ *split_point = ord_cat;
708
+ }
709
+ }
710
+ }
711
+
712
+
713
+ /* Calculate gain from splitting a binary column by a categorical column
714
+ *
715
+ * Function splits into two subsets + NAs on their own branch
716
+ *
717
+ * Parameters
718
+ * - ix_arr[n] (in)
719
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
720
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
721
+ * (Note: will be modified in-place)
722
+ * - st (in)
723
+ * See above.
724
+ * - end (in)
725
+ * See above.
726
+ * - x[n] (in)
727
+ * Categorical column from which a split predicting 'y' will be calculated.
728
+ * Missing values must be encoded as negative integers.
729
+ * - y[n] (in)
730
+ * Binary column whose distributions are to be split by 'x'.
731
+ * Must not contain missing values (which are encoded as negative integers).
732
+ * - ncat_x (in)
733
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
734
+ * - base_info (in)
735
+ * Base information for the 'y' counts before splitting.
736
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
737
+ * - buffer_cat_cnt[ncat_x] (temp)
738
+ * Array where temporary data for each category will be written into.
739
+ * - buffer_crosstab[2 * ncat_x] (temp)
740
+ * See above.
741
+ * - buffer_cat_sorted[ncat_x] (temp)
742
+ * See above.
743
+ * - has_na (in)
744
+ * Whether 'x' can have missing values or not.
745
+ * - min_size (in)
746
+ * Minimum number of elements that can be in a split.
747
+ * - gain (out)
748
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
749
+ * - split_subset[ncat_x] (out)
750
+ * Array that will indicate which categories go into the left branch in the chosen split.
751
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
752
+ */
753
+ void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
754
+ size_t ncat_x, long double base_info,
755
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_cat_sorted,
756
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset)
757
+ {
758
+ *gain = -HUGE_VAL;
759
+ size_t st_non_na;
760
+ long double this_gain;
761
+ size_t buffer_fixed_size[6] = {0};
762
+ CategSplit split_info;
763
+ size_t st_cat;
764
+ split_info.ncat = 2;
765
+ split_info.tot = end - st + 1;
766
+
767
+ /* check that there are enough observations for a split */
768
+ if ((end - st + 1) < (2 * min_size)) return;
769
+
770
+ /* will divide into 3 branches: NA, <= p, > p */
771
+ split_info.NA_branch = buffer_fixed_size;
772
+ split_info.left_branch = buffer_fixed_size + 2;
773
+ split_info.right_branch = buffer_fixed_size + 2 * 2;
774
+
775
+ /* move all NAs of X to the front */
776
+ if (has_na) {
777
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
778
+ } else { st_non_na = st; }
779
+
780
+ /* assign NAs to their own branch */
781
+ split_info.size_NA = st_non_na - st;
782
+ if (st_non_na > st) {
783
+
784
+ /* first check that it's still possible to split */
785
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
786
+
787
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
788
+ }
789
+
790
+ /* calculate cross-table on the non-missing cases, and put all observations in the right branch */
791
+ memset(buffer_crosstab, 0, 2 * ncat_x * sizeof(size_t));
792
+ memset(buffer_cat_cnt, 0, ncat_x * sizeof(size_t));
793
+ for (size_t i = st_non_na; i <= end; i++) {
794
+ buffer_crosstab[ y[ix_arr[i]] + 2 * x[ix_arr[i]] ]++;
795
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
796
+ split_info.right_branch[ y[ix_arr[i]] ]++;
797
+ }
798
+ split_info.size_right = end - st_non_na + 1;
799
+ split_info.size_left = 0;
800
+
801
+ /* sort the categories according to their mean of y */
802
+ for (size_t cat = 0; cat < ncat_x; cat++) buffer_cat_sorted[cat] = cat;
803
+ st_cat = move_zero_count_to_front(buffer_cat_sorted, buffer_cat_cnt, ncat_x);
804
+ std::sort(buffer_cat_sorted + st_cat, buffer_cat_sorted + ncat_x,
805
+ [&buffer_crosstab, &buffer_cat_cnt](const size_t a, const size_t b)
806
+ {
807
+ return ((long double) buffer_crosstab[2 * a] / (long double) buffer_cat_cnt[a]) >
808
+ ((long double) buffer_crosstab[2 * b] / (long double) buffer_cat_cnt[b]);
809
+ });
810
+
811
+ /* look for the best split subset, by moving one category at a time to the left branch*/
812
+ for (size_t cat = st_cat; cat < (ncat_x - 1); cat++) {
813
+
814
+ split_info.right_branch[0] -= buffer_crosstab[2 * buffer_cat_sorted[cat]];
815
+ split_info.right_branch[1] -= buffer_crosstab[2 * buffer_cat_sorted[cat] + 1];
816
+ split_info.left_branch [0] += buffer_crosstab[2 * buffer_cat_sorted[cat]];
817
+ split_info.left_branch [1] += buffer_crosstab[2 * buffer_cat_sorted[cat] + 1];
818
+ split_info.size_right -= buffer_cat_cnt [buffer_cat_sorted[cat]];
819
+ split_info.size_left += buffer_cat_cnt [buffer_cat_sorted[cat]];
820
+
821
+ /* check that split meets minimum criteria */
822
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
823
+
824
+ /* evaluate gain at this split point */
825
+ this_gain = categ_gain(split_info, base_info);
826
+ if (this_gain > *gain) {
827
+ *gain = this_gain;
828
+ subset_to_onehot(buffer_cat_sorted, cat, ncat_x, split_subset);
829
+ }
830
+ }
831
+ if (!is_na_or_inf(*gain)) flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
832
+ }
833
+
834
+
835
+ /* Calculate gain from splitting a categorical columns by another categorical column
836
+ *
837
+ * Function splits into one branch per category of 'x'
838
+ *
839
+ * Parameters
840
+ * - ix_arr[n] (in)
841
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
842
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
843
+ * (Note: will be modified in-place)
844
+ * - st (in)
845
+ * See above.
846
+ * - end (in)
847
+ * See above.
848
+ * - x[n] (in)
849
+ * Categorical column from which a split predicting 'y' will be calculated.
850
+ * Missing values must be encoded as negative integers.
851
+ * - y[n] (in)
852
+ * Categorical column whose distributions are to be split by 'x'.
853
+ * Must not contain missing values (which are encoded as negative integers).
854
+ * - ncat_x (in)
855
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
856
+ * - ncat_y (in)
857
+ * Number of categories in 'y'.
858
+ * - base_info (in)
859
+ * Base information for the 'y' counts before splitting.
860
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
861
+ * - buffer_cat_cnt[ncat_x + 1] (temp)
862
+ * Array where temporary data for each category will be written into.
863
+ * - buffer_crosstab[(ncat_x + 1) * ncat_y] (temp)
864
+ * See above.
865
+ * - has_na (in)
866
+ * Whether 'x' can have missing values or not.
867
+ * - gain (out)
868
+ * Gain calculated on the split. If no split is possible, will return -Inf.
869
+ */
870
+ void split_categx_categy_separate(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
871
+ size_t ncat_x, size_t ncat_y, long double base_info,
872
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab,
873
+ bool has_na, size_t min_size, long double *gain)
874
+ {
875
+ long double this_gain = 0;
876
+ size_t st_non_na;
877
+
878
+ /* move all NAs of X to the front */
879
+ if (has_na) {
880
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
881
+ } else { st_non_na = st; }
882
+
883
+ /* calculate cross-table on the non-missing cases */
884
+ memset(buffer_crosstab, 0, ncat_y * (ncat_x + 1) * sizeof(size_t));
885
+ memset(buffer_cat_cnt, 0, (ncat_x + 1) * sizeof(size_t));
886
+ for (size_t i = st_non_na; i <= end; i++) {
887
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
888
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
889
+ }
890
+
891
+ /* if no category meets the minimum split size, end here */
892
+ if (*std::max_element(buffer_cat_cnt, buffer_cat_cnt + (ncat_x + 1)) < min_size) {
893
+ *gain = -HUGE_VAL;
894
+ return;
895
+ }
896
+
897
+ /* calculate gain for splitting at each category */
898
+ for (size_t cat = 0; cat < ncat_x; cat++) {
899
+ this_gain += total_info(buffer_crosstab + cat * ncat_y, ncat_y, buffer_cat_cnt[cat]);
900
+ }
901
+
902
+ /* add the split on missing x */
903
+ if (st_non_na > st) {
904
+ for (size_t i = st; i < st_non_na; i++) {
905
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * ncat_x ]++;
906
+ buffer_cat_cnt [ ncat_x ]++;
907
+ }
908
+ this_gain += total_info(buffer_crosstab + ncat_x * ncat_y, ncat_y, buffer_cat_cnt[ncat_x]);
909
+ }
910
+
911
+ /* return calculated gain */
912
+ *gain = (base_info - this_gain) / (long double) (end - st + 1);
913
+ }
914
+
915
+
916
+ /* Calculate gain from splitting a categorical column by another categorical column
917
+ *
918
+ * Function splits into two subsets + NAs on their own branch
919
+ *
920
+ * Parameters
921
+ * - ix_arr[n] (in)
922
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
923
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
924
+ * (Note: will be modified in-place)
925
+ * - st (in)
926
+ * See above.
927
+ * - end (in)
928
+ * See above.
929
+ * - x[n] (in)
930
+ * Categorical column from which a split predicting 'y' will be calculated.
931
+ * Missing values must be encoded as negative integers.
932
+ * - y[n] (in)
933
+ * Categorical column whose distributions are to be split by 'x'.
934
+ * Must not contain missing values (which are encoded as negative integers).
935
+ * - ncat_x (in)
936
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
937
+ * - ncat_y (in)
938
+ * Number of categories in 'x'.
939
+ * - base_info (in)
940
+ * Base information for the 'y' counts before splitting.
941
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
942
+ * - buffer_cat_cnt[ncat_x] (temp)
943
+ * Array where temporary data for each category will be written into.
944
+ * - buffer_crosstab[ncat_x * ncat_y] (temp)
945
+ * See above.
946
+ * - buffer_split[3 * ncat_y] (temp)
947
+ * See above.
948
+ * - has_na (in)
949
+ * Whether 'x' can have missing values or not.
950
+ * - min_size (in)
951
+ * Minimum number of elements that can be in a split.
952
+ * - gain (out)
953
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
954
+ * - split_subset[ncat_x] (out)
955
+ * Array that will indicate which categories go into the left branch in the chosen split.
956
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
957
+ */
958
+ void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
959
+ size_t ncat_x, size_t ncat_y, long double base_info,
960
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_split,
961
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset)
962
+ {
963
+ *gain = -HUGE_VAL;
964
+ long double this_gain;
965
+ size_t best_subset;
966
+ CategSplit split_info;
967
+ split_info.tot = end - st + 1;
968
+ split_info.ncat = ncat_y;
969
+ size_t st_non_na;
970
+
971
+ /* will divide into 3 branches: NA, within subset, outside subset */
972
+ memset(buffer_split, 0, 3 * ncat_y * sizeof(size_t));
973
+ split_info.NA_branch = buffer_split;
974
+ split_info.left_branch = buffer_split + ncat_y;
975
+ split_info.right_branch = buffer_split + 2 * ncat_y;
976
+
977
+ /* move all NAs of X to the front */
978
+ if (has_na) {
979
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
980
+ } else { st_non_na = st; }
981
+ split_info.size_NA = st_non_na - st;
982
+
983
+ /* calculate cross-table */
984
+ memset(buffer_crosstab, 0, ncat_y * ncat_x * sizeof(size_t));
985
+ memset(buffer_cat_cnt, 0, ncat_x * sizeof(size_t));
986
+ for (size_t i = st_non_na; i <= end; i++) {
987
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
988
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
989
+ }
990
+ if (st_non_na > st) {
991
+ for (size_t i = st; i < st_non_na; i++) {
992
+ split_info.NA_branch[ y[ix_arr[i]] ]++;
993
+ }
994
+ }
995
+
996
+ /* put all categories on the right branch */
997
+ memset(split_info.left_branch, 0, ncat_y * sizeof(size_t));
998
+ memset(split_info.right_branch, 0, ncat_y * sizeof(size_t));
999
+ split_info.size_left = 0;
1000
+ split_info.size_right = 0;
1001
+ for (size_t catx = 0; catx < ncat_x; catx++) {
1002
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1003
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1004
+ }
1005
+ split_info.size_right += buffer_cat_cnt[catx];
1006
+ }
1007
+
1008
+ /* TODO: don't loop over categories with zero-counts everywhere */
1009
+
1010
+ /* do a brute-force search over all possible subset splits (there's [2^ncat_x - 2] of them) */
1011
+ size_t curr_exponent = 0;
1012
+ size_t last_bit;
1013
+ size_t ncomb = pow2(ncat_x) - 1;
1014
+
1015
+ /* iteration is done by putting a category in the left branch if the bit at its
1016
+ position in the binary representation of the combination number is a 1 */
1017
+ /* TODO: this would be faster with a depth-first search routine */
1018
+ for (size_t combin = 1; combin < ncomb; combin++) {
1019
+
1020
+ /* at each iteration, move the bits that differ from one branch to the other */
1021
+ /* note however than when there are few categories, it's actually faster to recalculate
1022
+ the counts based on the bitset -- this code however still follows this more "smart" way
1023
+ of moving cateogries when needed, which makes it slightly more scalable */
1024
+
1025
+ /* at any given number, the bits can only vary up a certain bit from an increase by one,
1026
+ which can be obtained from calculating the maximum power of two that is smaller than
1027
+ the combination number */
1028
+ if (combin == pow2(curr_exponent)) {
1029
+ curr_exponent++;
1030
+ last_bit = (size_t) curr_exponent - 1;
1031
+
1032
+ /* when this happens, this specific bit will change from a zero to a one,
1033
+ while the ones before will change from ones to zeros */
1034
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1035
+ split_info.right_branch[caty] -= buffer_crosstab[caty + last_bit * ncat_y];
1036
+ split_info.left_branch [caty] += buffer_crosstab[caty + last_bit * ncat_y];
1037
+ }
1038
+ split_info.size_left += buffer_cat_cnt[last_bit];
1039
+ split_info.size_right -= buffer_cat_cnt[last_bit];
1040
+
1041
+ for (size_t catx = 0; catx < last_bit; catx++) {
1042
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1043
+ split_info.left_branch [caty] -= buffer_crosstab[caty + catx * ncat_y];
1044
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1045
+ }
1046
+ split_info.size_left -= buffer_cat_cnt[catx];
1047
+ split_info.size_right += buffer_cat_cnt[catx];
1048
+ }
1049
+
1050
+ } else {
1051
+
1052
+ /* in the regular case, just inspect the bits that come before the exponent in the current
1053
+ power of two that is less than the combination number, and see if a category needs to be moved */
1054
+ for (size_t catx = 0; catx < last_bit; catx++) {
1055
+ if (extract_bit(combin, catx) != extract_bit(combin - 1, catx)) {
1056
+
1057
+ if (extract_bit(combin - 1, catx)) {
1058
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1059
+ split_info.left_branch [caty] -= buffer_crosstab[caty + catx * ncat_y];
1060
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1061
+ }
1062
+ split_info.size_left -= buffer_cat_cnt[catx];
1063
+ split_info.size_right += buffer_cat_cnt[catx];
1064
+ } else {
1065
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1066
+ split_info.left_branch [caty] += buffer_crosstab[caty + catx * ncat_y];
1067
+ split_info.right_branch[caty] -= buffer_crosstab[caty + catx * ncat_y];
1068
+ }
1069
+ split_info.size_left += buffer_cat_cnt[catx];
1070
+ split_info.size_right -= buffer_cat_cnt[catx];
1071
+ }
1072
+
1073
+ }
1074
+ }
1075
+
1076
+ }
1077
+
1078
+ /* check that split meets minimum criteria */
1079
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
1080
+
1081
+ /* now evaluate the subset */
1082
+ this_gain = categ_gain(split_info, base_info);
1083
+ if (this_gain > *gain) {
1084
+ *gain = this_gain;
1085
+ best_subset = combin;
1086
+ }
1087
+
1088
+ }
1089
+
1090
+ /* now convert the best subset into a proper array */
1091
+ if (*gain > -HUGE_VAL) {
1092
+ for (size_t catx = 0; catx < ncat_x; catx++) {
1093
+ split_subset[catx] = extract_bit(best_subset, catx);
1094
+ }
1095
+ flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
1096
+ }
1097
+
1098
+ }