outliertree 0.1.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/NOTICE.txt +1 -1
- data/README.md +11 -10
- data/ext/outliertree/ext.cpp +104 -105
- data/ext/outliertree/extconf.rb +1 -1
- data/lib/outliertree/result.rb +3 -3
- data/lib/outliertree/version.rb +1 -1
- data/vendor/outliertree/README.md +77 -40
- data/vendor/outliertree/src/Makevars.in +4 -0
- data/vendor/outliertree/src/Makevars.win +4 -0
- data/vendor/outliertree/src/RcppExports.cpp +20 -9
- data/vendor/outliertree/src/Rwrapper.cpp +256 -57
- data/vendor/outliertree/src/cat_outlier.cpp +6 -6
- data/vendor/outliertree/src/clusters.cpp +114 -9
- data/vendor/outliertree/src/fit_model.cpp +505 -308
- data/vendor/outliertree/src/misc.cpp +165 -4
- data/vendor/outliertree/src/outlier_tree.hpp +159 -51
- data/vendor/outliertree/src/outliertree-win.def +3 -0
- data/vendor/outliertree/src/predict.cpp +33 -0
- data/vendor/outliertree/src/split.cpp +124 -20
- metadata +10 -8
- data/vendor/outliertree/src/Makevars +0 -3
| @@ -36,12 +36,15 @@ | |
| 36 36 |  | 
| 37 37 | 
             
            /* TODO: don't divide the gains by tot at every calculation as it makes it slower */
         | 
| 38 38 |  | 
| 39 | 
            -
            /* TODO: sorting here is the slowest thing, so it could be improved by using radix sort for categorical/ordinal  | 
| 39 | 
            +
            /* TODO: sorting here is the slowest thing, so it could be improved by using radix sort for categorical/ordinal */
         | 
| 40 40 |  | 
| 41 41 | 
             
            /* TODO: columns that split by numeric should output the sum/sum_sq to pass it to the cluster functions, instead of recalculating them later */
         | 
| 42 42 |  | 
| 43 | 
            +
            /* TODO: the calculations of standard deviations when splitting a numeric column by a categorical column are
         | 
| 44 | 
            +
               highly imprecise and might throw negative variances. Should switch to a more robust procedure. */
         | 
| 43 45 |  | 
| 44 | 
            -
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            void subset_to_onehot(size_t ix_arr[], size_t n_true, size_t n_tot, signed char onehot[])
         | 
| 45 48 | 
             
            {
         | 
| 46 49 | 
             
                memset(onehot, 0, sizeof(bool) * n_tot);
         | 
| 47 50 | 
             
                for (size_t i = 0; i <= n_true; i++) onehot[ix_arr[i]] = 1;
         | 
| @@ -62,7 +65,7 @@ size_t move_zero_count_to_front(size_t *restrict cat_sorted, size_t *restrict ca | |
| 62 65 | 
             
                return st_cat;
         | 
| 63 66 | 
             
            }
         | 
| 64 67 |  | 
| 65 | 
            -
            void flag_zero_counts(char split_subset[], size_t buffer_cat_cnt[], size_t ncat_x)
         | 
| 68 | 
            +
            void flag_zero_counts(signed char split_subset[], size_t buffer_cat_cnt[], size_t ncat_x)
         | 
| 66 69 | 
             
            {
         | 
| 67 70 | 
             
                for (size_t cat = 0; cat < ncat_x; cat++)
         | 
| 68 71 | 
             
                    if (buffer_cat_cnt[cat] == 0) split_subset[cat] = -1;
         | 
| @@ -71,20 +74,20 @@ void flag_zero_counts(char split_subset[], size_t buffer_cat_cnt[], size_t ncat_ | |
| 71 74 | 
             
            long double calc_sd(size_t cnt, long double sum, long double sum_sq)
         | 
| 72 75 | 
             
            {
         | 
| 73 76 | 
             
                if (cnt < 3) return 0;
         | 
| 74 | 
            -
                return  | 
| 77 | 
            +
                return std::sqrt( (sum_sq - (square(sum) / (long double) cnt) + SD_REG) / (long double) (cnt - 1) );
         | 
| 75 78 | 
             
            }
         | 
| 76 79 |  | 
| 77 80 | 
             
            long double calc_sd(NumericBranch &branch)
         | 
| 78 81 | 
             
            {
         | 
| 79 82 | 
             
                if (branch.cnt < 3) return 0;
         | 
| 80 | 
            -
                return  | 
| 83 | 
            +
                return std::sqrt((branch.sum_sq - (square(branch.sum) / (long double) branch.cnt) + SD_REG) / (long double) (branch.cnt - 1));
         | 
| 81 84 | 
             
            }
         | 
| 82 85 |  | 
| 83 86 | 
             
            long double calc_sd(size_t ix_arr[], double *restrict x, size_t st, size_t end, double *restrict mean)
         | 
| 84 87 | 
             
            {
         | 
| 85 88 | 
             
                long double running_mean = 0;
         | 
| 86 | 
            -
                long double mean_prev    = 0;
         | 
| 87 89 | 
             
                long double running_ssq  = 0;
         | 
| 90 | 
            +
                long double mean_prev    = x[ix_arr[st]];
         | 
| 88 91 | 
             
                double xval;
         | 
| 89 92 | 
             
                for (size_t row = st; row <= end; row++) {
         | 
| 90 93 | 
             
                    xval = x[ix_arr[row]];
         | 
| @@ -93,7 +96,7 @@ long double calc_sd(size_t ix_arr[], double *restrict x, size_t st, size_t end, | |
| 93 96 | 
             
                    mean_prev     = running_mean;
         | 
| 94 97 | 
             
                }
         | 
| 95 98 | 
             
                *mean = (double) running_mean;
         | 
| 96 | 
            -
                return  | 
| 99 | 
            +
                return std::sqrt(running_ssq / (long double)(end - st));
         | 
| 97 100 |  | 
| 98 101 | 
             
            }
         | 
| 99 102 |  | 
| @@ -242,11 +245,13 @@ long double categ_gain_from_split(size_t *restrict ix_arr, int *restrict x, size | |
| 242 245 | 
             
            *    - split_left (out)
         | 
| 243 246 | 
             
            *        Index at which the data is split between the two branches (includes last from left branch).
         | 
| 244 247 | 
             
            *    - split_NA (out)
         | 
| 245 | 
            -
            *        Index at which the NA data is separated from the other branches
         | 
| 248 | 
            +
            *        Index at which the NA data is separated from the other branches.
         | 
| 249 | 
            +
            *    - has_zero_variance (out)
         | 
| 250 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 246 251 | 
             
            */
         | 
| 247 252 | 
             
            void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, double *restrict y,
         | 
| 248 253 | 
             
                                         long double sd_y, bool has_na, size_t min_size, bool take_mid, long double *restrict buffer_sd,
         | 
| 249 | 
            -
                                         long double *restrict gain, double *restrict split_point, size_t *restrict split_left, size_t *restrict split_NA)
         | 
| 254 | 
            +
                                         long double *restrict gain, double *restrict split_point, size_t *restrict split_left, size_t *restrict split_NA, bool *restrict has_zero_variance)
         | 
| 250 255 | 
             
            {
         | 
| 251 256 |  | 
| 252 257 | 
             
                *gain = -HUGE_VAL;
         | 
| @@ -255,11 +260,12 @@ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, dou | |
| 255 260 | 
             
                long double this_gain;
         | 
| 256 261 | 
             
                long double cnt_dbl = (long double)(end - st + 1);
         | 
| 257 262 | 
             
                long double running_mean = 0;
         | 
| 258 | 
            -
                long double mean_prev    = 0;
         | 
| 259 263 | 
             
                long double running_ssq  = 0;
         | 
| 264 | 
            +
                long double mean_prev    = 0;
         | 
| 260 265 | 
             
                double xval;
         | 
| 261 266 | 
             
                long double info_left;
         | 
| 262 267 | 
             
                long double info_NA = 0;
         | 
| 268 | 
            +
                *has_zero_variance = false;
         | 
| 263 269 |  | 
| 264 270 | 
             
                /* check that there are enough observations for a split */
         | 
| 265 271 | 
             
                if ((end - st + 1) < (2 * min_size)) return;
         | 
| @@ -281,8 +287,13 @@ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, dou | |
| 281 287 |  | 
| 282 288 | 
             
                /* sort the remaining non-NA values in ascending order */
         | 
| 283 289 | 
             
                std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
         | 
| 290 | 
            +
                if (x[ix_arr[st_non_na]] == x[ix_arr[end]]) {
         | 
| 291 | 
            +
                    *has_zero_variance = true;
         | 
| 292 | 
            +
                    return;
         | 
| 293 | 
            +
                }
         | 
| 284 294 |  | 
| 285 295 | 
             
                /* calculate SD*N backwards first, then forwards */
         | 
| 296 | 
            +
                mean_prev = y[ix_arr[end]];
         | 
| 286 297 | 
             
                for (size_t i = end; i >= st_non_na; i--) {
         | 
| 287 298 | 
             
                    xval = y[ix_arr[i]];
         | 
| 288 299 | 
             
                    running_mean += (xval - running_mean) / (long double)(end - i + 1);
         | 
| @@ -297,7 +308,7 @@ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, dou | |
| 297 308 | 
             
                /* look for the best split point, by moving one observation at a time to the left branch*/
         | 
| 298 309 | 
             
                running_mean = 0;
         | 
| 299 310 | 
             
                running_ssq  = 0;
         | 
| 300 | 
            -
                mean_prev    =  | 
| 311 | 
            +
                mean_prev    = y[ix_arr[st_non_na]];
         | 
| 301 312 | 
             
                for (size_t i = st_non_na; i <= (end - min_size); i++) {
         | 
| 302 313 | 
             
                    xval = y[ix_arr[i]];
         | 
| 303 314 | 
             
                    running_mean += (xval - running_mean) / (long double)(i - st_non_na + 1);
         | 
| @@ -366,12 +377,16 @@ void split_numericx_numericy(size_t *restrict ix_arr, size_t st, size_t end, dou | |
| 366 377 | 
             
            *        Array that will indicate which categories go into the left branch in the chosen split.
         | 
| 367 378 | 
             
            *        (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
         | 
| 368 379 | 
             
            *    - split_point (out)
         | 
| 369 | 
            -
            *        Split level for ordinal X variables (left branch is <= this)
         | 
| 380 | 
            +
            *        Split level for ordinal X variables (left branch is <= this).
         | 
| 381 | 
            +
            *    - has_zero_variance (out)
         | 
| 382 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 383 | 
            +
            *    - binary_split
         | 
| 384 | 
            +
            *        Whether the produced split is binary (single category at each branch).
         | 
| 370 385 | 
             
            */
         | 
| 371 386 | 
             
            void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, double *restrict y, long double sd_y, double ymean,
         | 
| 372 387 | 
             
                                       bool x_is_ordinal, size_t ncat_x, size_t *restrict buffer_cat_cnt, long double *restrict buffer_cat_sum,
         | 
| 373 388 | 
             
                                       long double *restrict buffer_cat_sum_sq, size_t *restrict buffer_cat_sorted,
         | 
| 374 | 
            -
                                       bool has_na, size_t min_size, long double *gain, char *restrict split_subset, int *restrict split_point)
         | 
| 389 | 
            +
                                       bool has_na, size_t min_size, long double *gain, signed char *restrict split_subset, int *restrict split_point, bool *restrict has_zero_variance, bool *restrict binary_split)
         | 
| 375 390 | 
             
            {
         | 
| 376 391 |  | 
| 377 392 | 
             
                /* output parameters and variables to use */
         | 
| @@ -380,9 +395,11 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 380 395 | 
             
                NumericSplit split_info;
         | 
| 381 396 | 
             
                size_t st_cat = 0;
         | 
| 382 397 | 
             
                double sd_y_d = (double) sd_y;
         | 
| 398 | 
            +
                *has_zero_variance = false;
         | 
| 399 | 
            +
                *binary_split = false;
         | 
| 383 400 |  | 
| 384 401 | 
             
                /* reset the buffers */
         | 
| 385 | 
            -
                memset(split_subset,      0, sizeof(char)   *  ncat_x);
         | 
| 402 | 
            +
                memset(split_subset,      0, sizeof(signed char)   *  ncat_x);
         | 
| 386 403 | 
             
                memset(buffer_cat_cnt,    0, sizeof(size_t) * (ncat_x + 1));
         | 
| 387 404 | 
             
                memset(buffer_cat_sum,    0, sizeof(long double) * (ncat_x + 1));
         | 
| 388 405 | 
             
                memset(buffer_cat_sum_sq, 0, sizeof(long double) * (ncat_x + 1));
         | 
| @@ -415,6 +432,16 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 415 432 |  | 
| 416 433 | 
             
                }
         | 
| 417 434 |  | 
| 435 | 
            +
                int n_unique_cat = 0;
         | 
| 436 | 
            +
                for (size_t cat = 0; cat < ncat_x; cat++) {
         | 
| 437 | 
            +
                    n_unique_cat += buffer_cat_sum_sq[cat] > 0;
         | 
| 438 | 
            +
                    if (n_unique_cat >= 2) break;
         | 
| 439 | 
            +
                }
         | 
| 440 | 
            +
                if (n_unique_cat <= 1) {
         | 
| 441 | 
            +
                    *has_zero_variance = true;
         | 
| 442 | 
            +
                    return;
         | 
| 443 | 
            +
                }
         | 
| 444 | 
            +
             | 
| 418 445 | 
             
                /* set NAs to their own branch */
         | 
| 419 446 | 
             
                if (buffer_cat_cnt[ncat_x] > 0) {
         | 
| 420 447 | 
             
                    split_info.NA_branch = {buffer_cat_cnt[ncat_x], buffer_cat_sum[ncat_x], buffer_cat_sum_sq[ncat_x]};
         | 
| @@ -430,6 +457,8 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 430 457 | 
             
                    split_info.right_branch = {buffer_cat_cnt[1], buffer_cat_sum[1], buffer_cat_sum_sq[1]};
         | 
| 431 458 | 
             
                    *gain = numeric_gain(split_info, 1.0) * sd_y;
         | 
| 432 459 | 
             
                    split_subset[0] = 1;
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                    *binary_split = true;
         | 
| 433 462 | 
             
                }
         | 
| 434 463 |  | 
| 435 464 | 
             
                /* subset and ordinal splits */
         | 
| @@ -443,7 +472,7 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 443 472 | 
             
                    }
         | 
| 444 473 |  | 
| 445 474 | 
             
                    /* if it's an ordinal variable, must respect the order */
         | 
| 446 | 
            -
                     | 
| 475 | 
            +
                    std::iota(buffer_cat_sorted, buffer_cat_sorted + ncat_x, (size_t)0);
         | 
| 447 476 |  | 
| 448 477 | 
             
                    if (!x_is_ordinal) {
         | 
| 449 478 | 
             
                        /* otherwise, sort the categories according to their mean of y */
         | 
| @@ -458,6 +487,10 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 458 487 | 
             
                                      return (buffer_cat_sum[a] / (long double) buffer_cat_cnt[a]) >
         | 
| 459 488 | 
             
                                             (buffer_cat_sum[b] / (long double) buffer_cat_cnt[b]);
         | 
| 460 489 | 
             
                                  });
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                        if (ncat_x - st_cat == 2) {
         | 
| 492 | 
            +
                            *binary_split = true;
         | 
| 493 | 
            +
                        }
         | 
| 461 494 | 
             
                    }
         | 
| 462 495 |  | 
| 463 496 | 
             
                    /* try moving each category to the left branch in the given order */
         | 
| @@ -530,11 +563,13 @@ void split_categx_numericy(size_t *restrict ix_arr, size_t st, size_t end, int * | |
| 530 563 | 
             
            *        Index at which the data is split between the two branches (includes last from left branch).
         | 
| 531 564 | 
             
            *    - split_NA (out)
         | 
| 532 565 | 
             
            *        Index at which the NA data is separated from the other branches
         | 
| 566 | 
            +
            *    - has_zero_variance (out)
         | 
| 567 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 533 568 | 
             
            */
         | 
| 534 569 | 
             
            void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, double *restrict x, int *restrict y,
         | 
| 535 570 | 
             
                                       size_t ncat_y, long double base_info, size_t *restrict buffer_cat_cnt,
         | 
| 536 571 | 
             
                                       bool has_na, size_t min_size, bool take_mid, long double *restrict gain, double *restrict split_point,
         | 
| 537 | 
            -
                                       size_t *restrict split_left, size_t *restrict split_NA)
         | 
| 572 | 
            +
                                       size_t *restrict split_left, size_t *restrict split_NA, bool *restrict has_zero_variance)
         | 
| 538 573 | 
             
            {
         | 
| 539 574 | 
             
                *gain = -HUGE_VAL;
         | 
| 540 575 | 
             
                *split_point = -HUGE_VAL;
         | 
| @@ -543,6 +578,7 @@ void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, doubl | |
| 543 578 | 
             
                CategSplit split_info;
         | 
| 544 579 | 
             
                split_info.ncat = ncat_y;
         | 
| 545 580 | 
             
                split_info.tot = end - st + 1;
         | 
| 581 | 
            +
                *has_zero_variance = false;
         | 
| 546 582 |  | 
| 547 583 | 
             
                /* check that there are enough observations for a split */
         | 
| 548 584 | 
             
                if ((end - st + 1) < (2 * min_size)) return;
         | 
| @@ -571,6 +607,10 @@ void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, doubl | |
| 571 607 |  | 
| 572 608 | 
             
                /* sort the remaining non-NA values in ascending order */
         | 
| 573 609 | 
             
                std::sort(ix_arr + st_non_na, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
         | 
| 610 | 
            +
                if (x[ix_arr[st_non_na]] == x[ix_arr[end]]) {
         | 
| 611 | 
            +
                    *has_zero_variance = true;
         | 
| 612 | 
            +
                    return;
         | 
| 613 | 
            +
                }
         | 
| 574 614 |  | 
| 575 615 | 
             
                /* put all observations on the right branch */
         | 
| 576 616 | 
             
                for (size_t i = st_non_na; i <= end; i++) split_info.right_branch[ y[ix_arr[i]] ]++;
         | 
| @@ -638,11 +678,16 @@ void split_numericx_categy(size_t *restrict ix_arr, size_t st, size_t end, doubl | |
| 638 678 | 
             
            *        Gain calculated on the best split found. If no split is possible, will return -Inf.
         | 
| 639 679 | 
             
            *    - split_point (out)
         | 
| 640 680 | 
             
            *        Threshold for splitting on values of 'x'. If no split is posible, will return -1.
         | 
| 681 | 
            +
            *    - has_zero_variance (out)
         | 
| 682 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 683 | 
            +
            *    - binary_split
         | 
| 684 | 
            +
            *        Whether the produced split is binary (single category at each branch).
         | 
| 641 685 | 
             
            */
         | 
| 642 686 | 
             
            void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
         | 
| 643 687 | 
             
                                   size_t ncat_y, size_t ncat_x, long double base_info,
         | 
| 644 688 | 
             
                                   size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_ord_cnt,
         | 
| 645 | 
            -
                                   bool has_na, size_t min_size, long double *gain, int *split_point | 
| 689 | 
            +
                                   bool has_na, size_t min_size, long double *gain, int *split_point,
         | 
| 690 | 
            +
                                   bool *restrict has_zero_variance, bool *restrict binary_split)
         | 
| 646 691 | 
             
            {
         | 
| 647 692 | 
             
                *gain = -HUGE_VAL;
         | 
| 648 693 | 
             
                *split_point = -1;
         | 
| @@ -651,6 +696,8 @@ void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 651 696 | 
             
                CategSplit split_info;
         | 
| 652 697 | 
             
                split_info.ncat = ncat_y;
         | 
| 653 698 | 
             
                split_info.tot = end - st + 1;
         | 
| 699 | 
            +
                *has_zero_variance = false;
         | 
| 700 | 
            +
                *binary_split = false;
         | 
| 654 701 |  | 
| 655 702 | 
             
                /* check that there are enough observations for a split */
         | 
| 656 703 | 
             
                if ((end - st + 1) < (2 * min_size)) return;
         | 
| @@ -687,6 +734,19 @@ void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 687 734 | 
             
                split_info.size_right = end - st_non_na + 1;
         | 
| 688 735 | 
             
                split_info.size_left  = 0;
         | 
| 689 736 |  | 
| 737 | 
            +
                int n_unique_cat = 0;
         | 
| 738 | 
            +
                for (size_t cat = 0; cat < ncat_x; cat++) {
         | 
| 739 | 
            +
                    n_unique_cat += buffer_ord_cnt[cat] > 0;
         | 
| 740 | 
            +
                    if (n_unique_cat >= 3) break;
         | 
| 741 | 
            +
                }
         | 
| 742 | 
            +
                if (n_unique_cat <= 1) {
         | 
| 743 | 
            +
                    *has_zero_variance = true;
         | 
| 744 | 
            +
                    return;
         | 
| 745 | 
            +
                }
         | 
| 746 | 
            +
                if (n_unique_cat == 2) {
         | 
| 747 | 
            +
                    *binary_split = true;
         | 
| 748 | 
            +
                }
         | 
| 749 | 
            +
             | 
| 690 750 | 
             
                /* look for the best split point, by moving one observation at a time to the left branch*/
         | 
| 691 751 | 
             
                for (size_t ord_cat = 0; ord_cat < (ncat_x - 1); ord_cat++) {
         | 
| 692 752 |  | 
| @@ -749,11 +809,16 @@ void split_ordx_categy(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 749 809 | 
             
            *    - split_subset[ncat_x] (out)
         | 
| 750 810 | 
             
            *        Array that will indicate which categories go into the left branch in the chosen split.
         | 
| 751 811 | 
             
            *        (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
         | 
| 812 | 
            +
            *    - has_zero_variance (out)
         | 
| 813 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 814 | 
            +
            *    - binary_split
         | 
| 815 | 
            +
            *        Whether the produced split is binary (single category at each branch).
         | 
| 752 816 | 
             
            */
         | 
| 753 817 | 
             
            void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
         | 
| 754 818 | 
             
                                   size_t ncat_x, long double base_info,
         | 
| 755 819 | 
             
                                   size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_cat_sorted,
         | 
| 756 | 
            -
                                   bool has_na, size_t min_size, long double *gain, char *restrict split_subset | 
| 820 | 
            +
                                   bool has_na, size_t min_size, long double *gain, signed char *restrict split_subset,
         | 
| 821 | 
            +
                                   bool *restrict has_zero_variance, bool *restrict binary_split)
         | 
| 757 822 | 
             
            {
         | 
| 758 823 | 
             
                *gain = -HUGE_VAL;
         | 
| 759 824 | 
             
                size_t st_non_na;
         | 
| @@ -763,6 +828,8 @@ void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 763 828 | 
             
                size_t st_cat;
         | 
| 764 829 | 
             
                split_info.ncat = 2;
         | 
| 765 830 | 
             
                split_info.tot = end - st + 1;
         | 
| 831 | 
            +
                *has_zero_variance = false;
         | 
| 832 | 
            +
                *binary_split = false;
         | 
| 766 833 |  | 
| 767 834 | 
             
                /* check that there are enough observations for a split */
         | 
| 768 835 | 
             
                if ((end - st + 1) < (2 * min_size)) return;
         | 
| @@ -798,8 +865,18 @@ void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 798 865 | 
             
                split_info.size_right = end - st_non_na + 1;
         | 
| 799 866 | 
             
                split_info.size_left  = 0;
         | 
| 800 867 |  | 
| 868 | 
            +
                int n_unique_cat = 0;
         | 
| 869 | 
            +
                for (size_t cat = 0; cat < ncat_x; cat++) {
         | 
| 870 | 
            +
                    n_unique_cat += buffer_cat_cnt[cat] > 0;
         | 
| 871 | 
            +
                    if (n_unique_cat >= 2) break;
         | 
| 872 | 
            +
                }
         | 
| 873 | 
            +
                if (n_unique_cat <= 1) {
         | 
| 874 | 
            +
                    *has_zero_variance = true;
         | 
| 875 | 
            +
                    return;
         | 
| 876 | 
            +
                }
         | 
| 877 | 
            +
             | 
| 801 878 | 
             
                /* sort the categories according to their mean of y */
         | 
| 802 | 
            -
                 | 
| 879 | 
            +
                std::iota(buffer_cat_sorted, buffer_cat_sorted + ncat_x, (size_t)0);
         | 
| 803 880 | 
             
                st_cat = move_zero_count_to_front(buffer_cat_sorted, buffer_cat_cnt, ncat_x);
         | 
| 804 881 | 
             
                std::sort(buffer_cat_sorted + st_cat, buffer_cat_sorted + ncat_x,
         | 
| 805 882 | 
             
                          [&buffer_crosstab, &buffer_cat_cnt](const size_t a, const size_t b)
         | 
| @@ -807,6 +884,9 @@ void split_categx_biny(size_t *restrict ix_arr, size_t st, size_t end, int *rest | |
| 807 884 | 
             
                              return ((long double) buffer_crosstab[2 * a] / (long double) buffer_cat_cnt[a]) >
         | 
| 808 885 | 
             
                                     ((long double) buffer_crosstab[2 * b] / (long double) buffer_cat_cnt[b]);
         | 
| 809 886 | 
             
                          });
         | 
| 887 | 
            +
                if (ncat_x - st_cat == 2) {
         | 
| 888 | 
            +
                    *binary_split = true;
         | 
| 889 | 
            +
                }
         | 
| 810 890 |  | 
| 811 891 | 
             
                /* look for the best split subset, by moving one category at a time to the left branch*/
         | 
| 812 892 | 
             
                for (size_t cat = st_cat; cat < (ncat_x - 1); cat++) {
         | 
| @@ -954,11 +1034,16 @@ void split_categx_categy_separate(size_t *restrict ix_arr, size_t st, size_t end | |
| 954 1034 | 
             
            *    - split_subset[ncat_x] (out)
         | 
| 955 1035 | 
             
            *        Array that will indicate which categories go into the left branch in the chosen split.
         | 
| 956 1036 | 
             
            *        (value of 1 means it's on the left branch, 0 in the right branch, -1 not applicable)
         | 
| 1037 | 
            +
            *    - has_zero_variance (out)
         | 
| 1038 | 
            +
            *        Whether the 'x' column has zero variance (contains only one unique value).
         | 
| 1039 | 
            +
            *    - binary_split
         | 
| 1040 | 
            +
            *        Whether the produced split is binary (single category at each branch).
         | 
| 957 1041 | 
             
            */
         | 
| 958 1042 | 
             
            void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int *restrict y,
         | 
| 959 1043 | 
             
                                            size_t ncat_x, size_t ncat_y, long double base_info,
         | 
| 960 1044 | 
             
                                            size_t *restrict buffer_cat_cnt, size_t *restrict buffer_crosstab, size_t *restrict buffer_split,
         | 
| 961 | 
            -
                                            bool has_na, size_t min_size, long double *gain, char *restrict split_subset | 
| 1045 | 
            +
                                            bool has_na, size_t min_size, long double *gain, signed char *restrict split_subset,
         | 
| 1046 | 
            +
                                            bool *restrict has_zero_variance, bool *restrict binary_split)
         | 
| 962 1047 | 
             
            {
         | 
| 963 1048 | 
             
                *gain = -HUGE_VAL;
         | 
| 964 1049 | 
             
                long double this_gain;
         | 
| @@ -967,6 +1052,8 @@ void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, | |
| 967 1052 | 
             
                split_info.tot = end - st + 1;
         | 
| 968 1053 | 
             
                split_info.ncat = ncat_y;
         | 
| 969 1054 | 
             
                size_t st_non_na;
         | 
| 1055 | 
            +
                *has_zero_variance = false;
         | 
| 1056 | 
            +
                *binary_split = false;
         | 
| 970 1057 |  | 
| 971 1058 | 
             
                /* will divide into 3 branches: NA, within subset, outside subset */
         | 
| 972 1059 | 
             
                memset(buffer_split, 0, 3 * ncat_y * sizeof(size_t));
         | 
| @@ -993,6 +1080,19 @@ void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, | |
| 993 1080 | 
             
                    }
         | 
| 994 1081 | 
             
                }
         | 
| 995 1082 |  | 
| 1083 | 
            +
                int n_unique_cat = 0;
         | 
| 1084 | 
            +
                for (size_t cat = 0; cat < ncat_x; cat++) {
         | 
| 1085 | 
            +
                    n_unique_cat += buffer_cat_cnt[cat] > 0;
         | 
| 1086 | 
            +
                    if (n_unique_cat >= 3) break;
         | 
| 1087 | 
            +
                }
         | 
| 1088 | 
            +
                if (n_unique_cat <= 1) {
         | 
| 1089 | 
            +
                    *has_zero_variance = true;
         | 
| 1090 | 
            +
                    return;
         | 
| 1091 | 
            +
                }
         | 
| 1092 | 
            +
                if (n_unique_cat == 2) {
         | 
| 1093 | 
            +
                    *binary_split = true;
         | 
| 1094 | 
            +
                }
         | 
| 1095 | 
            +
             | 
| 996 1096 | 
             
                /* put all categories on the right branch */
         | 
| 997 1097 | 
             
                memset(split_info.left_branch,   0, ncat_y * sizeof(size_t));
         | 
| 998 1098 | 
             
                memset(split_info.right_branch,  0, ncat_y * sizeof(size_t));
         | 
| @@ -1012,6 +1112,10 @@ void split_categx_categy_subset(size_t *restrict ix_arr, size_t st, size_t end, | |
| 1012 1112 | 
             
                size_t last_bit;
         | 
| 1013 1113 | 
             
                size_t ncomb = pow2(ncat_x) - 1;
         | 
| 1014 1114 |  | 
| 1115 | 
            +
                /* TODO: this is highly inefficient:
         | 
| 1116 | 
            +
                   - categories with zero count can be discarded beforehand.
         | 
| 1117 | 
            +
                   - could use C++ next_permutation instead. */
         | 
| 1118 | 
            +
             | 
| 1015 1119 | 
             
                /* iteration is done by putting a category in the left branch if the bit at its
         | 
| 1016 1120 | 
             
                   position in the binary representation of the combination number is a 1 */
         | 
| 1017 1121 | 
             
                /* TODO: this would be faster with a depth-first search routine */
         | 
    
        metadata
    CHANGED
    
    | @@ -1,14 +1,14 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: outliertree
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 0. | 
| 4 | 
            +
              version: 0.3.0
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - Andrew Kane
         | 
| 8 8 | 
             
            autorequire:
         | 
| 9 9 | 
             
            bindir: bin
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 | 
            -
            date:  | 
| 11 | 
            +
            date: 2022-06-13 00:00:00.000000000 Z
         | 
| 12 12 | 
             
            dependencies:
         | 
| 13 13 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 14 14 | 
             
              name: rice
         | 
| @@ -16,14 +16,14 @@ dependencies: | |
| 16 16 | 
             
                requirements:
         | 
| 17 17 | 
             
                - - ">="
         | 
| 18 18 | 
             
                  - !ruby/object:Gem::Version
         | 
| 19 | 
            -
                    version:  | 
| 19 | 
            +
                    version: 4.0.2
         | 
| 20 20 | 
             
              type: :runtime
         | 
| 21 21 | 
             
              prerelease: false
         | 
| 22 22 | 
             
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 23 23 | 
             
                requirements:
         | 
| 24 24 | 
             
                - - ">="
         | 
| 25 25 | 
             
                  - !ruby/object:Gem::Version
         | 
| 26 | 
            -
                    version:  | 
| 26 | 
            +
                    version: 4.0.2
         | 
| 27 27 | 
             
            description:
         | 
| 28 28 | 
             
            email: andrew@ankane.org
         | 
| 29 29 | 
             
            executables: []
         | 
| @@ -44,7 +44,8 @@ files: | |
| 44 44 | 
             
            - lib/outliertree/version.rb
         | 
| 45 45 | 
             
            - vendor/outliertree/LICENSE
         | 
| 46 46 | 
             
            - vendor/outliertree/README.md
         | 
| 47 | 
            -
            - vendor/outliertree/src/Makevars
         | 
| 47 | 
            +
            - vendor/outliertree/src/Makevars.in
         | 
| 48 | 
            +
            - vendor/outliertree/src/Makevars.win
         | 
| 48 49 | 
             
            - vendor/outliertree/src/RcppExports.cpp
         | 
| 49 50 | 
             
            - vendor/outliertree/src/Rwrapper.cpp
         | 
| 50 51 | 
             
            - vendor/outliertree/src/cat_outlier.cpp
         | 
| @@ -52,9 +53,10 @@ files: | |
| 52 53 | 
             
            - vendor/outliertree/src/fit_model.cpp
         | 
| 53 54 | 
             
            - vendor/outliertree/src/misc.cpp
         | 
| 54 55 | 
             
            - vendor/outliertree/src/outlier_tree.hpp
         | 
| 56 | 
            +
            - vendor/outliertree/src/outliertree-win.def
         | 
| 55 57 | 
             
            - vendor/outliertree/src/predict.cpp
         | 
| 56 58 | 
             
            - vendor/outliertree/src/split.cpp
         | 
| 57 | 
            -
            homepage: https://github.com/ankane/outliertree
         | 
| 59 | 
            +
            homepage: https://github.com/ankane/outliertree-ruby
         | 
| 58 60 | 
             
            licenses:
         | 
| 59 61 | 
             
            - GPL-3.0-or-later
         | 
| 60 62 | 
             
            metadata: {}
         | 
| @@ -66,14 +68,14 @@ required_ruby_version: !ruby/object:Gem::Requirement | |
| 66 68 | 
             
              requirements:
         | 
| 67 69 | 
             
              - - ">="
         | 
| 68 70 | 
             
                - !ruby/object:Gem::Version
         | 
| 69 | 
            -
                  version: '2. | 
| 71 | 
            +
                  version: '2.7'
         | 
| 70 72 | 
             
            required_rubygems_version: !ruby/object:Gem::Requirement
         | 
| 71 73 | 
             
              requirements:
         | 
| 72 74 | 
             
              - - ">="
         | 
| 73 75 | 
             
                - !ruby/object:Gem::Version
         | 
| 74 76 | 
             
                  version: '0'
         | 
| 75 77 | 
             
            requirements: []
         | 
| 76 | 
            -
            rubygems_version: 3. | 
| 78 | 
            +
            rubygems_version: 3.3.7
         | 
| 77 79 | 
             
            signing_key:
         | 
| 78 80 | 
             
            specification_version: 4
         | 
| 79 81 | 
             
            summary: Explainable outlier/anomaly detection for Ruby
         |