outliertree 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1098 @@
1
+ /********************************************************************************************************************
2
+ * Explainable outlier detection
3
+ *
4
+ * Tries to detect outliers by generating decision trees that attempt to predict the values of each column based on
5
+ * each other column, testing in each branch of every tried split (if it meets some minimum criteria) whether there
6
+ * are observations that seem too distant from the others in a 1-D distribution for the column that the split tries
7
+ * to "predict" (will not generate a score for each observation).
8
+ * Splits are based on gain, while outlierness is based on confidence intervals.
9
+ * Similar in spirit to the GritBot software developed by RuleQuest research. Reference article is:
10
+ * Cortes, David. "Explainable outlier detection through decision tree conditioning."
11
+ * arXiv preprint arXiv:2001.00636 (2020).
12
+ *
13
+ *
14
+ * Copyright 2020 David Cortes.
15
+ *
16
+ * Written for C++11 standard and OpenMP 2.0 or later. Code is meant to be wrapped into scripting languages
17
+ * such as R or Python.
18
+ *
19
+ * This file is part of OutlierTree.
20
+ *
21
+ * OutlierTree is free software: you can redistribute it and/or modify
22
+ * it under the terms of the GNU General Public License as published by
23
+ * the Free Software Foundation, either version 3 of the License, or
24
+ * (at your option) any later version.
25
+ *
26
+ * OutlierTree is distributed in the hope that it will be useful,
27
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
28
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
29
+ * GNU General Public License for more details.
30
+ *
31
+ * You should have received a copy of the GNU General Public License
32
+ * along with OutlierTree. If not, see <https://www.gnu.org/licenses/>.
33
+ ********************************************************************************************************************/
34
+ #include "outlier_tree.hpp"
35
+
36
+
37
+ /* TODO: don't divide the gains by tot at every calculation as it makes it slower */
38
+
39
+ /* TODO: sorting here is the slowest thing, so it could be improved by using radix sort for categorical/ordinal and timsort for numerical */
40
+
41
+ /* TODO: columns that split by numeric should output the sum/sum_sq to pass it to the cluster functions, instead of recalculating them later */
42
+
43
+
44
/* Build a 0/1 membership array from a list of selected positions.
 *
 * - ix_arr  (in)  : indices into 'onehot' to be switched on.
 * - n_true  (in)  : index of the LAST entry of 'ix_arr' to use; entries ix_arr[0..n_true]
 *                   are consumed, both ends included.
 * - n_tot   (in)  : number of entries in 'onehot'.
 * - onehot  (out) : reset to 0 everywhere, then set to 1 at every selected index.
 */
void subset_to_onehot(size_t ix_arr[], size_t n_true, size_t n_tot, char onehot[])
{
    /* the buffer holds chars, so its byte size must come from char and not from bool,
       whose size is implementation-defined */
    memset(onehot, 0, sizeof(char) * n_tot);

    /* inclusive upper bound: 'n_true' is a position, not a count */
    for (size_t i = 0; i <= n_true; i++) onehot[ix_arr[i]] = 1;
}
49
+
50
/* Reorder 'cat_sorted' so that categories with a count of zero occupy its first slots.
 *
 * For every category whose entry in 'cat_cnt' is zero, that category is written into the
 * next free front slot and the value previously there is written into slot 'cat'.
 * Returns how many zero-count categories were found, i.e. the position of the first
 * slot that was not claimed by an empty category.
 */
size_t move_zero_count_to_front(size_t *restrict cat_sorted, size_t *restrict cat_cnt, size_t ncat_x)
{
    size_t n_empty = 0;
    for (size_t cat = 0; cat < ncat_x; cat++) {
        if (cat_cnt[cat] != 0) continue;

        const size_t displaced = cat_sorted[n_empty];
        cat_sorted[n_empty] = cat;
        cat_sorted[cat] = displaced;
        n_empty++;
    }
    return n_empty;
}
64
+
65
/* Write -1 into 'split_subset' for every category whose count is zero;
 * entries of categories with a non-zero count are left untouched. */
void flag_zero_counts(char split_subset[], size_t buffer_cat_cnt[], size_t ncat_x)
{
    size_t cat = 0;
    while (cat < ncat_x) {
        if (!buffer_cat_cnt[cat]) {
            split_subset[cat] = -1;
        }
        cat++;
    }
}
70
+
71
+ long double calc_sd(size_t cnt, long double sum, long double sum_sq)
72
+ {
73
+ if (cnt < 3) return 0;
74
+ return sqrtl( (sum_sq - (square(sum) / (long double) cnt) + SD_REG) / (long double) (cnt - 1) );
75
+ }
76
+
77
+ long double calc_sd(NumericBranch &branch)
78
+ {
79
+ if (branch.cnt < 3) return 0;
80
+ return sqrtl((branch.sum_sq - (square(branch.sum) / (long double) branch.cnt) + SD_REG) / (long double) (branch.cnt - 1));
81
+ }
82
+
83
/* Single-pass mean and standard deviation of x[ix_arr[st..end]] (both ends included).
 *
 * The mean and the sum of squared deviations are updated incrementally, one row at a time.
 * - mean (out): receives the mean of the visited values.
 * Returns sqrt(sum of squared deviations / (end - st)), i.e. with an n-1 denominator.
 */
long double calc_sd(size_t ix_arr[], double *restrict x, size_t st, size_t end, double *restrict mean)
{
    long double mean_now = 0;
    long double mean_before = 0;
    long double ssq = 0;
    size_t n_seen = 0;

    for (size_t row = st; row <= end; row++) {
        const double val = x[ix_arr[row]];
        n_seen++;
        mean_now += (val - mean_now) / (long double) n_seen;
        /* product of the deviation from the updated mean and from the previous mean */
        ssq += (val - mean_now) * (val - mean_before);
        mean_before = mean_now;
    }

    *mean = (double) mean_now;
    return sqrtl(ssq / (long double)(end - st));
}
99
+
100
+ long double numeric_gain(NumericSplit &split_info, long double tot_sd)
101
+ {
102
+ long double tot = (long double)(split_info.NA_branch.cnt + split_info.left_branch.cnt + split_info.right_branch.cnt);
103
+ long double residual =
104
+ ((long double) split_info.NA_branch.cnt) * calc_sd(split_info.NA_branch) +
105
+ ((long double) split_info.left_branch.cnt) * calc_sd(split_info.left_branch) +
106
+ ((long double) split_info.right_branch.cnt) * calc_sd(split_info.right_branch);
107
+
108
+ return tot_sd - (residual / tot);
109
+ }
110
+
111
/* Gain of a numeric split from already-weighted branch terms.
 *
 * 'info_left', 'info_right' and 'info_NA' are summed and divided by 'cnt';
 * the result is subtracted from 'tot_sd'.
 */
long double numeric_gain(long double tot_sd, long double info_left, long double info_right, long double info_NA, long double cnt)
{
    const long double residual = info_left + info_right + info_NA;
    return tot_sd - residual / cnt;
}
115
+
116
/* Information of a vector of category counts: N*log(N) - sum_i N_i*log(N_i),
 * where N is the sum of the counts. Returns 0 when every count is zero. */
long double total_info(size_t categ_counts[], size_t ncat)
{
    long double sum_nlogn = 0;
    size_t n_total = 0;

    for (size_t cat = 0; cat < ncat; cat++) {
        const size_t cnt = categ_counts[cat];
        if (cnt == 0) continue;   /* empty categories contribute nothing */
        sum_nlogn += (long double) cnt * logl((long double) cnt);
        n_total += cnt;
    }

    if (n_total == 0) return 0;
    return (long double) n_total * logl((long double) n_total) - sum_nlogn;
}
129
+
130
/* Information of a vector of category counts when their total is already known:
 * tot*log(tot) - sum_i N_i*log(N_i). 'tot' is expected to equal the sum of the counts.
 * Returns 0 when 'tot' is zero. */
long double total_info(size_t categ_counts[], size_t ncat, size_t tot)
{
    if (!tot) return 0;

    long double sum_nlogn = 0;
    const size_t *const past_last = categ_counts + ncat;
    for (const size_t *cnt = categ_counts; cnt != past_last; ++cnt) {
        /* counts of 0 or 1 are skipped; for a count of 1 the term 1*log(1) is zero */
        if (*cnt > 1) {
            sum_nlogn += (long double) *cnt * logl((long double) *cnt);
        }
    }

    return (long double) tot * logl((long double) tot) - sum_nlogn;
}
142
+
143
/* Information of the categorical values x[ix_arr[st..end]] (both ends included).
 *
 * 'buffer_cat_cnt' (ncat entries) is cleared and filled with the per-category counts
 * as a side effect. Values of 'x' are used directly as indices into that buffer.
 * Returns N*log(N) - sum_i N_i*log(N_i) with N = end - st + 1.
 */
long double total_info(size_t *restrict ix_arr, int *restrict x, size_t st, size_t end, size_t ncat, size_t *restrict buffer_cat_cnt)
{
    const long double n_rows = (long double)(end - st + 1);
    long double info = n_rows * logl(n_rows);

    memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
    for (size_t row = st; row <= end; row++) {
        const int category = x[ix_arr[row]];
        buffer_cat_cnt[category]++;
    }

    for (size_t cat = 0; cat < ncat; cat++) {
        const size_t cnt = buffer_cat_cnt[cat];
        if (cnt <= 1) continue;   /* 0 and 1 contribute a zero term */
        info -= (long double) cnt * logl((long double) cnt);
    }
    return info;
}
157
+
158
+ long double categ_gain(CategSplit split_info, long double base_info)
159
+ {
160
+ return (
161
+ base_info -
162
+ total_info(split_info.NA_branch, split_info.ncat, split_info.size_NA) -
163
+ total_info(split_info.left_branch, split_info.ncat, split_info.size_left) -
164
+ total_info(split_info.right_branch, split_info.ncat, split_info.size_right)
165
+ ) / (long double) split_info.tot ;
166
+ }
167
+
168
+ long double categ_gain(size_t *restrict categ_counts, size_t ncat, size_t *restrict ncat_col, size_t maxcat, long double base_info, size_t tot)
169
+ {
170
+ long double info = 0;
171
+ for (size_t cat = 0; cat < ncat; cat++) {
172
+ if (categ_counts[cat] > 0) {
173
+ info += total_info(categ_counts + cat * maxcat, ncat_col[cat]);
174
+ }
175
+ }
176
+
177
+ /* last entry in the array corresponds to NA values */
178
+ if (categ_counts[ncat] > 0) {
179
+ info += total_info(categ_counts + ncat * maxcat, ncat_col[ncat]);
180
+ }
181
+
182
+ return (base_info - info) / (long double) tot;
183
+ }
184
+
185
+ long double categ_gain_from_split(size_t *restrict ix_arr, int *restrict x, size_t st, size_t st_non_na, size_t split_ix, size_t end,
186
+ size_t ncat, size_t *restrict buffer_cat_cnt, long double base_info)
187
+ {
188
+ long double gain = base_info;
189
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
190
+ if (st_non_na > st) {
191
+ for (size_t row = st; row < st_non_na; row++) {
192
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
193
+ }
194
+ gain -= total_info(buffer_cat_cnt, ncat, st_non_na - st);
195
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
196
+ }
197
+
198
+ for (size_t row = st_non_na; row < split_ix; row++) {
199
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
200
+ }
201
+ gain -= total_info(buffer_cat_cnt, ncat, split_ix - st_non_na);
202
+ memset(buffer_cat_cnt, 0, ncat * sizeof(size_t));
203
+
204
+ for (size_t row = split_ix; row <= end; row++) {
205
+ buffer_cat_cnt[ x[ix_arr[row]] ]++;
206
+ }
207
+ gain -= total_info(buffer_cat_cnt, ncat, end - split_ix + 1);
208
+
209
+ return gain / (long double)(end - st + 1);
210
+ }
211
+
212
+ /* Calculate gain from splitting a numeric column by another numeric column
213
+ *
214
+ * Function splits into buckets (NA, <= threshold, > threshold)
215
+ *
216
+ * Parameters
217
+ * - ix_arr[n] (in)
218
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
219
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
220
+ * (Note: will be modified in-place)
221
+ * - st (in)
222
+ * See above.
223
+ * - end (in)
224
+ * See above.
225
+ * - x[n] (in)
226
+ * Numeric column from which a split predicting 'y' will be calculated.
227
+ * - y[n] (in)
228
+ * Numeric column whose distribution wants to be split by 'x'.
229
+ * Must not contain missing values.
230
+ * - sd_y (in)
231
+ * Standard deviation of 'y' between the indices considered here.
232
+ * - has_na (in)
233
+ * Whether 'x' can have missing values or not.
234
+ * - min_size (in)
235
+ * Minimum number of elements that can be in a split.
236
+ * - buffer_sd[n] (in)
237
+ * Buffer where to write temporary sd/information at each split point in a first pass.
238
+ * - gain (out)
239
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
240
+ * - split_point (out)
241
+ * Threshold for splitting on values of 'x'. If no split is possible, will return -Inf.
242
+ * - split_left (out)
243
+ * Index at which the data is split between the two branches (includes last from left branch).
244
+ * - split_NA (out)
245
+ * Index at which the NA data is separated from the other branches
246
+ */
247
+ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, double *restrict y,
248
+ long double sd_y, bool has_na, size_t min_size, bool take_mid, long double *restrict buffer_sd,
249
+ long double *restrict gain, double *restrict split_point, size_t *restrict split_left, size_t *restrict split_NA)
250
+ {
251
+
252
+ *gain = -HUGE_VAL;
253
+ *split_point = -HUGE_VAL;
254
+ size_t st_non_na;
255
+ long double this_gain;
256
+ long double cnt_dbl = (long double)(end - st + 1);
257
+ long double running_mean = 0;
258
+ long double mean_prev = 0;
259
+ long double running_ssq = 0;
260
+ double xval;
261
+ long double info_left;
262
+ long double info_NA = 0;
263
+
264
+ /* check that there are enough observations for a split */
265
+ if ((end - st + 1) < (2 * min_size)) return;
266
+
267
+ /* move all NAs of X to the front */
268
+ if (has_na) {
269
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end, false);
270
+ } else { st_non_na = st; }
271
+ *split_NA = st_non_na;
272
+
273
+ /* assign NAs to their own branch */
274
+ if (st_non_na > st) {
275
+
276
+ /* first check that it's still possible to split */
277
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
278
+
279
+ info_NA = (long double)(st_non_na - st) * calc_sd(ix_arr, y, st, st_non_na-1, &xval); /* last arg is not used */
280
+ }
281
+
282
+ /* sort the remaining non-NA values in ascending order */
283
+ std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
284
+
285
+ /* calculate SD*N backwards first, then forwards */
286
+ for (size_t i = end; i >= st_non_na; i--) {
287
+ xval = y[ix_arr[i]];
288
+ running_mean += (xval - running_mean) / (long double)(end - i + 1);
289
+ running_ssq += (xval - running_mean) * (xval - mean_prev);
290
+ mean_prev = running_mean;
291
+ buffer_sd[i] = (long double)(end - i + 1) * sqrtl(running_ssq / (long double)(end - i));
292
+ /* could also avoid div by n-1, would be faster */
293
+
294
+ if (i == st_non_na) break; /* be aware unsigned integer overflow */
295
+ }
296
+
297
+ /* look for the best split point, by moving one observation at a time to the left branch*/
298
+ running_mean = 0;
299
+ running_ssq = 0;
300
+ mean_prev = 0;
301
+ for (size_t i = st_non_na; i <= (end - min_size); i++) {
302
+ xval = y[ix_arr[i]];
303
+ running_mean += (xval - running_mean) / (long double)(i - st_non_na + 1);
304
+ running_ssq += (xval - running_mean) * (xval - mean_prev);
305
+ mean_prev = running_mean;
306
+
307
+ /* check that split meets minimum criteria (size on right branch is controlled in loop condition) */
308
+ if ((i - st_non_na + 1) < min_size) continue;
309
+
310
+ /* check that value is not repeated next -- note that condition in loop prevents out-of-bounds access */
311
+ if (x[ix_arr[i]] == x[ix_arr[i + 1]]) continue;
312
+
313
+ /* evaluate gain at this split point */
314
+ info_left = (long double)(i - st_non_na + 1) * sqrtl(running_ssq / (long double)(i - st_non_na));
315
+ this_gain = numeric_gain(sd_y, info_left, buffer_sd[i + 1], info_NA, cnt_dbl);
316
+ if (this_gain > *gain) {
317
+ *gain = this_gain;
318
+ *split_point = take_mid? (avg_between(x[ix_arr[i]], x[ix_arr[i + 1]])) : (x[ix_arr[i]]);
319
+ *split_left = i;
320
+ }
321
+ }
322
+ }
323
+
324
+ /* Calculate gain from splitting a numeric column by a categorical column
325
+ *
326
+ * Function splits into two subsets + NAs on their own branch
327
+ *
328
+ * Parameters
329
+ * - ix_arr[n] (in)
330
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
331
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
332
+ * (Note: will be modified in-place)
333
+ * - st (in)
334
+ * See above.
335
+ * - end (in)
336
+ * See above.
337
+ * - x[n] (in)
338
+ * Categorical column from which a split predicting 'y' will be calculated.
339
+ * Missing values should be encoded as negative integers.
340
+ * - y[n] (in)
341
+ * Numeric column whose distribution wants to be split by 'x'.
342
+ * Must not contain missing values.
343
+ * - sd_y (in)
344
+ * Standard deviation of 'y' between the indices considered here.
345
+ * - x_is_ordinal (in)
346
+ * Whether the 'x' column has ordered categories, in which case the split will be a
347
+ * <= that respects this order.
348
+ * - ncat_x (in)
349
+ * Number of categories in 'x' (excluding NA).
350
+ * - buffer_cat_cnt[ncat_x + 1] (temp)
351
+ * Array where temporary data for each category will be written into.
352
+ * Must have one additional entry above the number of categories to account for NAs.
353
+ * - buffer_cat_sum[ncat_x + 1] (temp)
354
+ * See above.
355
+ * - buffer_cat_sum_sq[ncat_x + 1] (temp)
356
+ * See above.
357
+ * - buffer_cat_sorted[ncat_x] (temp)
358
+ * See above. This one doesn't need an extra entry.
359
+ * - has_na (in)
360
+ * Whether 'x' can have missing values or not.
361
+ * - min_size (in)
362
+ * Minimum number of elements that can be in a split.
363
+ * - gain (out)
364
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
365
+ * - split_subset[ncat_x] (out)
366
+ * Array that will indicate which categories go into the left branch in the chosen split.
367
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
368
+ * - split_point (out)
369
+ * Split level for ordinal X variables (left branch is <= this)
370
+ */
371
+ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, double *restrict y, long double sd_y, double ymean,
372
+ bool x_is_ordinal, size_t ncat_x, size_t *restrict buffer_cat_cnt, long double *restrict buffer_cat_sum,
373
+ long double *restrict buffer_cat_sum_sq, size_t *restrict buffer_cat_sorted,
374
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset, int *restrict split_point)
375
+ {
376
+
377
+ /* output parameters and variables to use */
378
+ *gain = -HUGE_VAL;
379
+ long double this_gain;
380
+ NumericSplit split_info;
381
+ size_t st_cat = 0;
382
+ double sd_y_d = (double) sd_y;
383
+
384
+ /* reset the buffers */
385
+ memset(split_subset, 0, sizeof(char) * ncat_x);
386
+ memset(buffer_cat_cnt, 0, sizeof(size_t) * (ncat_x + 1));
387
+ memset(buffer_cat_sum, 0, sizeof(long double) * (ncat_x + 1));
388
+ memset(buffer_cat_sum_sq, 0, sizeof(long double) * (ncat_x + 1));
389
+
390
+ /* calculate summary info for each category */
391
+ if (has_na) {
392
+
393
+ for (size_t i = st; i <= end; i++) {
394
+
395
+ /* NAs are encoded as negative integers, and go at the last slot */
396
+ if (x[ix_arr[i]] < 0) {
397
+ buffer_cat_cnt[ncat_x]++;
398
+ buffer_cat_sum[ncat_x] += z_score(y[ix_arr[i]], ymean, sd_y_d);
399
+ buffer_cat_sum_sq[ncat_x] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
400
+ } else {
401
+ buffer_cat_cnt[ x[ix_arr[i]] ]++;
402
+ buffer_cat_sum[ x[ix_arr[i]] ] += z_score(y[ix_arr[i]], ymean, sd_y_d);
403
+ buffer_cat_sum_sq[ x[ix_arr[i]] ] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
404
+ }
405
+ }
406
+
407
+ } else {
408
+
409
+ buffer_cat_cnt[ncat_x] = 0;
410
+ for (size_t i = st; i <= end; i++) {
411
+ buffer_cat_cnt[ x[ix_arr[i]] ]++;
412
+ buffer_cat_sum[ x[ix_arr[i]] ] += z_score(y[ix_arr[i]], ymean, sd_y_d);
413
+ buffer_cat_sum_sq[ x[ix_arr[i]] ] += square(z_score(y[ix_arr[i]], ymean, sd_y_d));
414
+ }
415
+
416
+ }
417
+
418
+ /* set NAs to their own branch */
419
+ if (buffer_cat_cnt[ncat_x] > 0) {
420
+ split_info.NA_branch = {buffer_cat_cnt[ncat_x], buffer_cat_sum[ncat_x], buffer_cat_sum_sq[ncat_x]};
421
+ }
422
+
423
+ /* easy case: binary split (only one possible split point) */
424
+ if (ncat_x == 2) {
425
+
426
+ /* must still meet minimum size requirements */
427
+ if (buffer_cat_cnt[0] < min_size || buffer_cat_cnt[1] < min_size) return;
428
+
429
+ split_info.left_branch = {buffer_cat_cnt[0], buffer_cat_sum[0], buffer_cat_sum_sq[0]};
430
+ split_info.right_branch = {buffer_cat_cnt[1], buffer_cat_sum[1], buffer_cat_sum_sq[1]};
431
+ *gain = numeric_gain(split_info, 1.0) * sd_y;
432
+ split_subset[0] = 1;
433
+ }
434
+
435
+ /* subset and ordinal splits */
436
+ else {
437
+
438
+ /* put all the categories on the right branch */
439
+ for (size_t cat = 0; cat < ncat_x; cat++) {
440
+ split_info.right_branch.cnt += buffer_cat_cnt[cat];
441
+ split_info.right_branch.sum += buffer_cat_sum[cat];
442
+ split_info.right_branch.sum_sq += buffer_cat_sum_sq[cat];
443
+ }
444
+
445
+ /* if it's an ordinal variable, must respect the order */
446
+ for (size_t cat = 0; cat < ncat_x; cat++) buffer_cat_sorted[cat] = cat;
447
+
448
+ if (!x_is_ordinal) {
449
+ /* otherwise, sort the categories according to their mean of y */
450
+
451
+ /* first remove zero-counts */
452
+ st_cat = move_zero_count_to_front(buffer_cat_sorted, buffer_cat_cnt, ncat_x);
453
+
454
+ /* then sort */
455
+ std::sort(buffer_cat_sorted + st_cat, buffer_cat_sorted + ncat_x,
456
+ [&buffer_cat_sum, &buffer_cat_cnt](const size_t a, const size_t b)
457
+ {
458
+ return (buffer_cat_sum[a] / (long double) buffer_cat_cnt[a]) >
459
+ (buffer_cat_sum[b] / (long double) buffer_cat_cnt[b]);
460
+ });
461
+ }
462
+
463
+ /* try moving each category to the left branch in the given order */
464
+ for (size_t cat = st_cat; cat < ncat_x; cat++) {
465
+ split_info.right_branch.cnt -= buffer_cat_cnt[ buffer_cat_sorted[cat] ];
466
+ split_info.right_branch.sum -= buffer_cat_sum[ buffer_cat_sorted[cat] ];
467
+ split_info.right_branch.sum_sq -= buffer_cat_sum_sq[ buffer_cat_sorted[cat] ];
468
+
469
+ split_info.left_branch.cnt += buffer_cat_cnt[ buffer_cat_sorted[cat] ];
470
+ split_info.left_branch.sum += buffer_cat_sum[ buffer_cat_sorted[cat] ];
471
+ split_info.left_branch.sum_sq += buffer_cat_sum_sq[ buffer_cat_sorted[cat] ];
472
+
473
+ /* see if it meets minimum split sizes */
474
+ if (split_info.left_branch.cnt < min_size || split_info.right_branch.cnt < min_size) continue;
475
+
476
+ /* calculate the gain */
477
+ this_gain = numeric_gain(split_info, 1.0);
478
+ if (this_gain > *gain) {
479
+ *gain = this_gain * sd_y;
480
+ if (!x_is_ordinal)
481
+ subset_to_onehot(buffer_cat_sorted, cat, ncat_x, split_subset);
482
+ else
483
+ *split_point = (int) cat;
484
+ }
485
+ }
486
+
487
+ /* if it's categorical, set the non-present categories to -1 */
488
+ if (!is_na_or_inf(*gain) && !x_is_ordinal) flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
489
+
490
+ }
491
+
492
+ }
493
+
494
+
495
+
496
+ /* Calculate gain from splitting a categorical column by a numeric column
497
+ *
498
+ * Function splits into two subsets + NAs on their own branch
499
+ *
500
+ * Parameters
501
+ * - ix_arr[n] (in)
502
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
503
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
504
+ * (Note: will be modified in-place)
505
+ * - st (in)
506
+ * See above.
507
+ * - end (in)
508
+ * See above.
509
+ * - x[n] (in)
510
+ * Numerical column from which a split predicting 'y' will be calculated.
511
+ * - y[n] (in)
512
+ * Categorical column whose distributions are to be split by 'x'.
513
+ * Must not contain missing values (which are encoded as negative integers).
514
+ * - ncat_y (in)
515
+ * Number of categories in 'y' (excluding NAs, which are encoded as negative integers).
516
+ * - base_info (in)
517
+ * Base information for the 'y' counts before splitting.
518
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
519
+ * - buffer_cat_cnt[ncat_y * 3] (temp)
520
+ * Array where temporary data for each category will be written into.
521
+ * - has_na (in)
522
+ * Whether 'x' can have missing values or not.
523
+ * - min_size (in)
524
+ * Minimum number of elements that can be in a split.
525
+ * - gain (out)
526
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
527
+ * - split_point (out)
528
+ * Threshold for splitting on values of 'x'. If no split is possible, will return -Inf.
529
+ * - split_left (out)
530
+ * Index at which the data is split between the two branches (includes last from left branch).
531
+ * - split_NA (out)
532
+ * Index at which the NA data is separated from the other branches
533
+ */
534
+ void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, int *restrict y,
535
+ size_t ncat_y, long double base_info, size_t *restrict buffer_cat_cnt,
536
+ bool has_na, size_t min_size, bool take_mid, long double *restrict gain, double *restrict split_point,
537
+ size_t *restrict split_left, size_t *restrict split_NA)
538
+ {
539
+ *gain = -HUGE_VAL;
540
+ *split_point = -HUGE_VAL;
541
+ size_t st_non_na;
542
+ long double this_gain;
543
+ CategSplit split_info;
544
+ split_info.ncat = ncat_y;
545
+ split_info.tot = end - st + 1;
546
+
547
+ /* check that there are enough observations for a split */
548
+ if ((end - st + 1) < (2 * min_size)) return;
549
+
550
+ /* will divide into 3 branches: NA, <= p, > p */
551
+ memset(buffer_cat_cnt, 0, 3 * ncat_y * sizeof(size_t));
552
+ split_info.NA_branch = buffer_cat_cnt;
553
+ split_info.left_branch = buffer_cat_cnt + ncat_y;
554
+ split_info.right_branch = buffer_cat_cnt + 2 * ncat_y;
555
+
556
+ /* move all NAs of X to the front */
557
+ if (has_na) {
558
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end, false);
559
+ } else { st_non_na = st; }
560
+ *split_NA = st_non_na;
561
+
562
+ /* assign NAs to their own branch */
563
+ split_info.size_NA = st_non_na - st;
564
+ if (st_non_na > st) {
565
+
566
+ /* first check that it's still possible to split */
567
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
568
+
569
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
570
+ }
571
+
572
+ /* sort the remaining non-NA values in ascending order */
573
+ std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
574
+
575
+ /* put all observations on the right branch */
576
+ for (size_t i = st_non_na; i <= end; i++) split_info.right_branch[ y[ix_arr[i]] ]++;
577
+
578
+ /* look for the best split point, by moving one observation at a time to the left branch*/
579
+ for (size_t i = st_non_na; i <= (end - min_size); i++) {
580
+ split_info.right_branch[ y[ix_arr[i]] ]--;
581
+ split_info.left_branch [ y[ix_arr[i]] ]++;
582
+ split_info.size_left = i - st_non_na + 1;
583
+ split_info.size_right = end - i;
584
+
585
+ /* check that split meets minimum criteria (size on right branch is controlled in loop condition) */
586
+ if (split_info.size_left < min_size) continue;
587
+
588
+ /* check that value is not repeated next -- note that condition in loop prevents out-of-bounds access */
589
+ if (x[ix_arr[i]] == x[ix_arr[i + 1]]) continue;
590
+
591
+ /* evaluate gain at this split point */
592
+ this_gain = categ_gain(split_info, base_info);
593
+ if (this_gain > *gain) {
594
+ *gain = this_gain;
595
+ *split_point = take_mid? (avg_between(x[ix_arr[i]], x[ix_arr[i + 1]])) : (x[ix_arr[i]]);
596
+ *split_left = i;
597
+ }
598
+ }
599
+ }
600
+
601
+ /* Calculate gain from splitting a categorical column by an ordinal column
602
+ *
603
+ * Function splits into two subsets + NAs on their own branch
604
+ *
605
+ * Parameters
606
+ * - ix_arr[n] (in)
607
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
608
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
609
+ * (Note: will be modified in-place)
610
+ * - st (in)
611
+ * See above.
612
+ * - end (in)
613
+ * See above.
614
+ * - x[n] (in)
615
+ * Ordinal column from which a split predicting 'y' will be calculated.
616
+ * Missing values must be encoded as negative integers.
617
+ * - y[n] (in)
618
+ * Categorical column whose distributions are to be split by 'x'.
619
+ * Must not contain missing values (which are encoded as negative integers).
620
+ * - ncat_y (in)
621
+ * Number of categories in 'y' (excluding NAs, which are encoded as negative integers).
622
+ * - ncat_x (in)
623
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
624
+ * - base_info (in)
625
+ * Base information for the 'y' counts before splitting.
626
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
627
+ * - buffer_cat_cnt[ncat_y * 3] (temp)
628
+ * Array where temporary data for each category will be written into.
629
+ * - buffer_crosstab[ncat_x * ncat_y] (temp)
630
+ * See above.
631
+ * - buffer_ord_cnt[ncat_x] (temp)
632
+ * See above.
633
+ * - has_na (in)
634
+ * Whether 'x' can have missing values or not.
635
+ * - min_size (in)
636
+ * Minimum number of elements that can be in a split.
637
+ * - gain (out)
638
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
639
+ * - split_point (out)
640
+ * Threshold for splitting on values of 'x'. If no split is possible, will return -1.
641
+ */
642
+ void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
643
+ size_t ncat_y, size_t ncat_x, long double base_info,
644
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_ord_cnt,
645
+ bool has_na, size_t min_size, long double *gain, int *split_point)
646
+ {
647
+ *gain = -HUGE_VAL;
648
+ *split_point = -1;
649
+ size_t st_non_na;
650
+ long double this_gain;
651
+ CategSplit split_info;
652
+ split_info.ncat = ncat_y;
653
+ split_info.tot = end - st + 1;
654
+
655
+ /* check that there are enough observations for a split */
656
+ if ((end - st + 1) < (2 * min_size)) return;
657
+
658
+ /* will divide into 3 branches: NA, <= p, > p */
659
+ memset(buffer_cat_cnt, 0, 3 * ncat_y * sizeof(size_t));
660
+ split_info.NA_branch = buffer_cat_cnt;
661
+ split_info.left_branch = buffer_cat_cnt + ncat_y;
662
+ split_info.right_branch = buffer_cat_cnt + 2 * ncat_y;
663
+
664
+ /* move all NAs of X to the front */
665
+ if (has_na) {
666
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
667
+ } else { st_non_na = st; }
668
+
669
+ /* assign NAs to their own branch */
670
+ split_info.size_NA = st_non_na - st;
671
+ if (st_non_na > st) {
672
+
673
+ /* first check that it's still possible to split */
674
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
675
+
676
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
677
+ }
678
+
679
+ /* calculate cross-table on the non-missing cases, and put all observations in the right branch */
680
+ memset(buffer_crosstab, 0, ncat_y * ncat_x * sizeof(size_t));
681
+ memset(buffer_ord_cnt, 0, ncat_x * sizeof(size_t));
682
+ for (size_t i = st_non_na; i <= end; i++) {
683
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
684
+ buffer_ord_cnt [ x[ix_arr[i]] ]++;
685
+ split_info.right_branch[ y[ix_arr[i]] ]++;
686
+ }
687
+ split_info.size_right = end - st_non_na + 1;
688
+ split_info.size_left = 0;
689
+
690
+ /* look for the best split point, by moving one observation at a time to the left branch*/
691
+ for (size_t ord_cat = 0; ord_cat < (ncat_x - 1); ord_cat++) {
692
+
693
+ for (size_t moved_cat = 0; moved_cat < ncat_y; moved_cat++) {
694
+ split_info.right_branch[ moved_cat ] -= buffer_crosstab[ moved_cat + ncat_y * ord_cat ];
695
+ split_info.left_branch [ moved_cat ] += buffer_crosstab[ moved_cat + ncat_y * ord_cat ];
696
+ }
697
+ split_info.size_right -= buffer_ord_cnt[ord_cat];
698
+ split_info.size_left += buffer_ord_cnt[ord_cat];
699
+
700
+ /* check that split meets minimum criteria */
701
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
702
+
703
+ /* evaluate gain at this split point */
704
+ this_gain = categ_gain(split_info, base_info);
705
+ if (this_gain > *gain) {
706
+ *gain = this_gain;
707
+ *split_point = ord_cat;
708
+ }
709
+ }
710
+ }
711
+
712
+
713
+ /* Calculate gain from splitting a binary column by a categorical column
714
+ *
715
+ * Function splits into two subsets + NAs on their own branch
716
+ *
717
+ * Parameters
718
+ * - ix_arr[n] (in)
719
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
720
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
721
+ * (Note: will be modified in-place)
722
+ * - st (in)
723
+ * See above.
724
+ * - end (in)
725
+ * See above.
726
+ * - x[n] (in)
727
+ * Categorical column from which a split predicting 'y' will be calculated.
728
+ * Missing values must be encoded as negative integers.
729
+ * - y[n] (in)
730
+ * Binary column whose distributions are to be split by 'x'.
731
+ * Must not contain missing values (which are encoded as negative integers).
732
+ * - ncat_x (in)
733
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
734
+ * - base_info (in)
735
+ * Base information for the 'y' counts before splitting.
736
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
737
+ * - buffer_cat_cnt[ncat_x] (temp)
738
+ * Array where temporary data for each category will be written into.
739
+ * - buffer_crosstab[2 * ncat_x] (temp)
740
+ * See above.
741
+ * - buffer_cat_sorted[ncat_x] (temp)
742
+ * See above.
743
+ * - has_na (in)
744
+ * Whether 'x' can have missing values or not.
745
+ * - min_size (in)
746
+ * Minimum number of elements that can be in a split.
747
+ * - gain (out)
748
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
749
+ * - split_subset[ncat_x] (out)
750
+ * Array that will indicate which categories go into the left branch in the chosen split.
751
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
752
+ */
753
+ void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
754
+ size_t ncat_x, long double base_info,
755
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_cat_sorted,
756
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset)
757
+ {
758
+ *gain = -HUGE_VAL;
759
+ size_t st_non_na;
760
+ long double this_gain;
761
+ size_t buffer_fixed_size[6] = {0};
762
+ CategSplit split_info;
763
+ size_t st_cat;
764
+ split_info.ncat = 2;
765
+ split_info.tot = end - st + 1;
766
+
767
+ /* check that there are enough observations for a split */
768
+ if ((end - st + 1) < (2 * min_size)) return;
769
+
770
+ /* will divide into 3 branches: NA, <= p, > p */
771
+ split_info.NA_branch = buffer_fixed_size;
772
+ split_info.left_branch = buffer_fixed_size + 2;
773
+ split_info.right_branch = buffer_fixed_size + 2 * 2;
774
+
775
+ /* move all NAs of X to the front */
776
+ if (has_na) {
777
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
778
+ } else { st_non_na = st; }
779
+
780
+ /* assign NAs to their own branch */
781
+ split_info.size_NA = st_non_na - st;
782
+ if (st_non_na > st) {
783
+
784
+ /* first check that it's still possible to split */
785
+ if ((end - st_non_na + 1) < (2 * min_size)) return;
786
+
787
+ for (size_t i = st; i < st_non_na; i++) split_info.NA_branch[ y[ix_arr[i]] ]++;
788
+ }
789
+
790
+ /* calculate cross-table on the non-missing cases, and put all observations in the right branch */
791
+ memset(buffer_crosstab, 0, 2 * ncat_x * sizeof(size_t));
792
+ memset(buffer_cat_cnt, 0, ncat_x * sizeof(size_t));
793
+ for (size_t i = st_non_na; i <= end; i++) {
794
+ buffer_crosstab[ y[ix_arr[i]] + 2 * x[ix_arr[i]] ]++;
795
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
796
+ split_info.right_branch[ y[ix_arr[i]] ]++;
797
+ }
798
+ split_info.size_right = end - st_non_na + 1;
799
+ split_info.size_left = 0;
800
+
801
+ /* sort the categories according to their mean of y */
802
+ for (size_t cat = 0; cat < ncat_x; cat++) buffer_cat_sorted[cat] = cat;
803
+ st_cat = move_zero_count_to_front(buffer_cat_sorted, buffer_cat_cnt, ncat_x);
804
+ std::sort(buffer_cat_sorted + st_cat, buffer_cat_sorted + ncat_x,
805
+ [&buffer_crosstab, &buffer_cat_cnt](const size_t a, const size_t b)
806
+ {
807
+ return ((long double) buffer_crosstab[2 * a] / (long double) buffer_cat_cnt[a]) >
808
+ ((long double) buffer_crosstab[2 * b] / (long double) buffer_cat_cnt[b]);
809
+ });
810
+
811
+ /* look for the best split subset, by moving one category at a time to the left branch*/
812
+ for (size_t cat = st_cat; cat < (ncat_x - 1); cat++) {
813
+
814
+ split_info.right_branch[0] -= buffer_crosstab[2 * buffer_cat_sorted[cat]];
815
+ split_info.right_branch[1] -= buffer_crosstab[2 * buffer_cat_sorted[cat] + 1];
816
+ split_info.left_branch [0] += buffer_crosstab[2 * buffer_cat_sorted[cat]];
817
+ split_info.left_branch [1] += buffer_crosstab[2 * buffer_cat_sorted[cat] + 1];
818
+ split_info.size_right -= buffer_cat_cnt [buffer_cat_sorted[cat]];
819
+ split_info.size_left += buffer_cat_cnt [buffer_cat_sorted[cat]];
820
+
821
+ /* check that split meets minimum criteria */
822
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
823
+
824
+ /* evaluate gain at this split point */
825
+ this_gain = categ_gain(split_info, base_info);
826
+ if (this_gain > *gain) {
827
+ *gain = this_gain;
828
+ subset_to_onehot(buffer_cat_sorted, cat, ncat_x, split_subset);
829
+ }
830
+ }
831
+ if (!is_na_or_inf(*gain)) flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
832
+ }
833
+
834
+
835
+ /* Calculate gain from splitting a categorical columns by another categorical column
836
+ *
837
+ * Function splits into one branch per category of 'x'
838
+ *
839
+ * Parameters
840
+ * - ix_arr[n] (in)
841
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
842
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
843
+ * (Note: will be modified in-place)
844
+ * - st (in)
845
+ * See above.
846
+ * - end (in)
847
+ * See above.
848
+ * - x[n] (in)
849
+ * Categorical column from which a split predicting 'y' will be calculated.
850
+ * Missing values must be encoded as negative integers.
851
+ * - y[n] (in)
852
+ * Categorical column whose distributions are to be split by 'x'.
853
+ * Must not contain missing values (which are encoded as negative integers).
854
+ * - ncat_x (in)
855
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
856
+ * - ncat_y (in)
857
+ * Number of categories in 'y'.
858
+ * - base_info (in)
859
+ * Base information for the 'y' counts before splitting.
860
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
861
+ * - buffer_cat_cnt[ncat_x + 1] (temp)
862
+ * Array where temporary data for each category will be written into.
863
+ * - buffer_crosstab[(ncat_x + 1) * ncat_y] (temp)
864
+ * See above.
865
+ * - has_na (in)
866
+ * Whether 'x' can have missing values or not.
867
+ * - gain (out)
868
+ * Gain calculated on the split. If no split is possible, will return -Inf.
869
+ */
870
+ void split_categx_categy_separate(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
871
+ size_t ncat_x, size_t ncat_y, long double base_info,
872
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab,
873
+ bool has_na, size_t min_size, long double *gain)
874
+ {
875
+ long double this_gain = 0;
876
+ size_t st_non_na;
877
+
878
+ /* move all NAs of X to the front */
879
+ if (has_na) {
880
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
881
+ } else { st_non_na = st; }
882
+
883
+ /* calculate cross-table on the non-missing cases */
884
+ memset(buffer_crosstab, 0, ncat_y * (ncat_x + 1) * sizeof(size_t));
885
+ memset(buffer_cat_cnt, 0, (ncat_x + 1) * sizeof(size_t));
886
+ for (size_t i = st_non_na; i <= end; i++) {
887
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
888
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
889
+ }
890
+
891
+ /* if no category meets the minimum split size, end here */
892
+ if (*std::max_element(buffer_cat_cnt, buffer_cat_cnt + (ncat_x + 1)) < min_size) {
893
+ *gain = -HUGE_VAL;
894
+ return;
895
+ }
896
+
897
+ /* calculate gain for splitting at each category */
898
+ for (size_t cat = 0; cat < ncat_x; cat++) {
899
+ this_gain += total_info(buffer_crosstab + cat * ncat_y, ncat_y, buffer_cat_cnt[cat]);
900
+ }
901
+
902
+ /* add the split on missing x */
903
+ if (st_non_na > st) {
904
+ for (size_t i = st; i < st_non_na; i++) {
905
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * ncat_x ]++;
906
+ buffer_cat_cnt [ ncat_x ]++;
907
+ }
908
+ this_gain += total_info(buffer_crosstab + ncat_x * ncat_y, ncat_y, buffer_cat_cnt[ncat_x]);
909
+ }
910
+
911
+ /* return calculated gain */
912
+ *gain = (base_info - this_gain) / (long double) (end - st + 1);
913
+ }
914
+
915
+
916
+ /* Calculate gain from splitting a categorical column by another categorical column
917
+ *
918
+ * Function splits into two subsets + NAs on their own branch
919
+ *
920
+ * Parameters
921
+ * - ix_arr[n] (in)
922
+ * Array containing the indices at which 'x' and 'y' can be accessed, considering only the
923
+ * elements between st and end (i.e. ix_arr[st:end], inclusive of both ends)
924
+ * (Note: will be modified in-place)
925
+ * - st (in)
926
+ * See above.
927
+ * - end (in)
928
+ * See above.
929
+ * - x[n] (in)
930
+ * Categorical column from which a split predicting 'y' will be calculated.
931
+ * Missing values must be encoded as negative integers.
932
+ * - y[n] (in)
933
+ * Categorical column whose distributions are to be split by 'x'.
934
+ * Must not contain missing values (which are encoded as negative integers).
935
+ * - ncat_x (in)
936
+ * Number of categories in 'x' (excluding NAs, which are encoded as negative integers).
937
+ * - ncat_y (in)
938
+ * Number of categories in 'x'.
939
+ * - base_info (in)
940
+ * Base information for the 'y' counts before splitting.
941
+ * (:= N*log(N) - sum_i..m N_i*log(N_i))
942
+ * - buffer_cat_cnt[ncat_x] (temp)
943
+ * Array where temporary data for each category will be written into.
944
+ * - buffer_crosstab[ncat_x * ncat_y] (temp)
945
+ * See above.
946
+ * - buffer_split[3 * ncat_y] (temp)
947
+ * See above.
948
+ * - has_na (in)
949
+ * Whether 'x' can have missing values or not.
950
+ * - min_size (in)
951
+ * Minimum number of elements that can be in a split.
952
+ * - gain (out)
953
+ * Gain calculated on the best split found. If no split is possible, will return -Inf.
954
+ * - split_subset[ncat_x] (out)
955
+ * Array that will indicate which categories go into the left branch in the chosen split.
956
+ * (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
957
+ */
958
+ void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
959
+ size_t ncat_x, size_t ncat_y, long double base_info,
960
+ size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_split,
961
+ bool has_na, size_t min_size, long double *gain, char *restrict split_subset)
962
+ {
963
+ *gain = -HUGE_VAL;
964
+ long double this_gain;
965
+ size_t best_subset;
966
+ CategSplit split_info;
967
+ split_info.tot = end - st + 1;
968
+ split_info.ncat = ncat_y;
969
+ size_t st_non_na;
970
+
971
+ /* will divide into 3 branches: NA, within subset, outside subset */
972
+ memset(buffer_split, 0, 3 * ncat_y * sizeof(size_t));
973
+ split_info.NA_branch = buffer_split;
974
+ split_info.left_branch = buffer_split + ncat_y;
975
+ split_info.right_branch = buffer_split + 2 * ncat_y;
976
+
977
+ /* move all NAs of X to the front */
978
+ if (has_na) {
979
+ st_non_na = move_NAs_to_front(ix_arr, x, st, end);
980
+ } else { st_non_na = st; }
981
+ split_info.size_NA = st_non_na - st;
982
+
983
+ /* calculate cross-table */
984
+ memset(buffer_crosstab, 0, ncat_y * ncat_x * sizeof(size_t));
985
+ memset(buffer_cat_cnt, 0, ncat_x * sizeof(size_t));
986
+ for (size_t i = st_non_na; i <= end; i++) {
987
+ buffer_crosstab[ y[ix_arr[i]] + ncat_y * x[ix_arr[i]] ]++;
988
+ buffer_cat_cnt [ x[ix_arr[i]] ]++;
989
+ }
990
+ if (st_non_na > st) {
991
+ for (size_t i = st; i < st_non_na; i++) {
992
+ split_info.NA_branch[ y[ix_arr[i]] ]++;
993
+ }
994
+ }
995
+
996
+ /* put all categories on the right branch */
997
+ memset(split_info.left_branch, 0, ncat_y * sizeof(size_t));
998
+ memset(split_info.right_branch, 0, ncat_y * sizeof(size_t));
999
+ split_info.size_left = 0;
1000
+ split_info.size_right = 0;
1001
+ for (size_t catx = 0; catx < ncat_x; catx++) {
1002
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1003
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1004
+ }
1005
+ split_info.size_right += buffer_cat_cnt[catx];
1006
+ }
1007
+
1008
+ /* TODO: don't loop over categories with zero-counts everywhere */
1009
+
1010
+ /* do a brute-force search over all possible subset splits (there's [2^ncat_x - 2] of them) */
1011
+ size_t curr_exponent = 0;
1012
+ size_t last_bit;
1013
+ size_t ncomb = pow2(ncat_x) - 1;
1014
+
1015
+ /* iteration is done by putting a category in the left branch if the bit at its
1016
+ position in the binary representation of the combination number is a 1 */
1017
+ /* TODO: this would be faster with a depth-first search routine */
1018
+ for (size_t combin = 1; combin < ncomb; combin++) {
1019
+
1020
+ /* at each iteration, move the bits that differ from one branch to the other */
1021
+ /* note however than when there are few categories, it's actually faster to recalculate
1022
+ the counts based on the bitset -- this code however still follows this more "smart" way
1023
+ of moving cateogries when needed, which makes it slightly more scalable */
1024
+
1025
+ /* at any given number, the bits can only vary up a certain bit from an increase by one,
1026
+ which can be obtained from calculating the maximum power of two that is smaller than
1027
+ the combination number */
1028
+ if (combin == pow2(curr_exponent)) {
1029
+ curr_exponent++;
1030
+ last_bit = (size_t) curr_exponent - 1;
1031
+
1032
+ /* when this happens, this specific bit will change from a zero to a one,
1033
+ while the ones before will change from ones to zeros */
1034
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1035
+ split_info.right_branch[caty] -= buffer_crosstab[caty + last_bit * ncat_y];
1036
+ split_info.left_branch [caty] += buffer_crosstab[caty + last_bit * ncat_y];
1037
+ }
1038
+ split_info.size_left += buffer_cat_cnt[last_bit];
1039
+ split_info.size_right -= buffer_cat_cnt[last_bit];
1040
+
1041
+ for (size_t catx = 0; catx < last_bit; catx++) {
1042
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1043
+ split_info.left_branch [caty] -= buffer_crosstab[caty + catx * ncat_y];
1044
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1045
+ }
1046
+ split_info.size_left -= buffer_cat_cnt[catx];
1047
+ split_info.size_right += buffer_cat_cnt[catx];
1048
+ }
1049
+
1050
+ } else {
1051
+
1052
+ /* in the regular case, just inspect the bits that come before the exponent in the current
1053
+ power of two that is less than the combination number, and see if a category needs to be moved */
1054
+ for (size_t catx = 0; catx < last_bit; catx++) {
1055
+ if (extract_bit(combin, catx) != extract_bit(combin - 1, catx)) {
1056
+
1057
+ if (extract_bit(combin - 1, catx)) {
1058
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1059
+ split_info.left_branch [caty] -= buffer_crosstab[caty + catx * ncat_y];
1060
+ split_info.right_branch[caty] += buffer_crosstab[caty + catx * ncat_y];
1061
+ }
1062
+ split_info.size_left -= buffer_cat_cnt[catx];
1063
+ split_info.size_right += buffer_cat_cnt[catx];
1064
+ } else {
1065
+ for (size_t caty = 0; caty < ncat_y; caty++) {
1066
+ split_info.left_branch [caty] += buffer_crosstab[caty + catx * ncat_y];
1067
+ split_info.right_branch[caty] -= buffer_crosstab[caty + catx * ncat_y];
1068
+ }
1069
+ split_info.size_left += buffer_cat_cnt[catx];
1070
+ split_info.size_right -= buffer_cat_cnt[catx];
1071
+ }
1072
+
1073
+ }
1074
+ }
1075
+
1076
+ }
1077
+
1078
+ /* check that split meets minimum criteria */
1079
+ if (split_info.size_left < min_size || split_info.size_right < min_size) continue;
1080
+
1081
+ /* now evaluate the subset */
1082
+ this_gain = categ_gain(split_info, base_info);
1083
+ if (this_gain > *gain) {
1084
+ *gain = this_gain;
1085
+ best_subset = combin;
1086
+ }
1087
+
1088
+ }
1089
+
1090
+ /* now convert the best subset into a proper array */
1091
+ if (*gain > -HUGE_VAL) {
1092
+ for (size_t catx = 0; catx < ncat_x; catx++) {
1093
+ split_subset[catx] = extract_bit(best_subset, catx);
1094
+ }
1095
+ flag_zero_counts(split_subset, buffer_cat_cnt, ncat_x);
1096
+ }
1097
+
1098
+ }