isotree 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,771 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
+ * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
+ *
24
+ * BSD 2-Clause License
25
+ * Copyright (c) 2019, David Cortes
26
+ * All rights reserved.
27
+ * Redistribution and use in source and binary forms, with or without
28
+ * modification, are permitted provided that the following conditions are met:
29
+ * * Redistributions of source code must retain the above copyright notice, this
30
+ * list of conditions and the following disclaimer.
31
+ * * Redistributions in binary form must reproduce the above copyright notice,
32
+ * this list of conditions and the following disclaimer in the documentation
33
+ * and/or other materials provided with the distribution.
34
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
+ */
45
+ #include "isotree.hpp"
46
+
47
+ void split_itree_recursive(std::vector<IsoTree> &trees,
48
+ WorkerMemory &workspace,
49
+ InputData &input_data,
50
+ ModelParams &model_params,
51
+ std::vector<ImputeNode> *impute_nodes,
52
+ size_t curr_depth)
53
+ {
54
+ long double sum_weight = -HUGE_VAL;
55
+
56
+ /* calculate imputation statistics if desired */
57
+ if (impute_nodes != NULL)
58
+ {
59
+ if (input_data.Xc != NULL)
60
+ std::sort(workspace.ix_arr.begin() + workspace.st,
61
+ workspace.ix_arr.begin() + workspace.end + 1);
62
+ build_impute_node(impute_nodes->back(), workspace,
63
+ input_data, model_params,
64
+ *impute_nodes, curr_depth,
65
+ model_params.min_imp_obs);
66
+ }
67
+
68
+ /* check for potential isolated leafs */
69
+ if (workspace.end == workspace.st || curr_depth >= model_params.max_depth)
70
+ goto terminal_statistics;
71
+
72
+ /* with 2 observations and no weights, there's only 1 potential or assumed split */
73
+ if ((workspace.end - workspace.st) == 1 && !workspace.weights_arr.size() && !workspace.weights_map.size())
74
+ goto terminal_statistics;
75
+
76
+ /* when using weights, the split should stop when the sum of weights is <= 2 */
77
+ sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
78
+ workspace.weights_arr, workspace.weights_map);
79
+
80
+ if (curr_depth > 0 && (workspace.weights_arr.size() || workspace.weights_map.size()) && sum_weight < 2.5)
81
+ goto terminal_statistics;
82
+
83
+ /* for sparse matrices, need to sort the indices */
84
+ if (input_data.Xc != NULL && impute_nodes == NULL)
85
+ std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);
86
+
87
+ /* pick column to split according to criteria */
88
+ workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
89
+
90
+ /* case1: guided, pick column and point with best gain */
91
+ if (
92
+ workspace.prob_split_type
93
+ < (
94
+ model_params.prob_pick_by_gain_avg +
95
+ model_params.prob_pick_by_gain_pl
96
+ )
97
+ )
98
+ {
99
+ workspace.determine_split = false;
100
+
101
+ /* case 1.1: column is decided by averaged gain */
102
+ if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
103
+ workspace.criterion = Averaged;
104
+
105
+ /* case 1.2: column is decided by pooled gain */
106
+ else
107
+ workspace.criterion = Pooled;
108
+
109
+ /* evaluate gain for all columns */
110
+ trees.back().score = -HUGE_VAL; /* this is used to track the best gain */
111
+ if (input_data.Xc == NULL)
112
+ {
113
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
114
+ {
115
+ workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
116
+ input_data.numeric_data + col * input_data.nrows,
117
+ workspace.split_ix, workspace.this_split_point,
118
+ workspace.xmin, workspace.xmax,
119
+ workspace.criterion, model_params.min_gain,
120
+ model_params.missing_action);
121
+ if (workspace.this_gain <= -HUGE_VAL)
122
+ {
123
+ workspace.cols_possible[col] = false;
124
+ }
125
+
126
+ else if (workspace.this_gain > trees.back().score)
127
+ {
128
+ trees.back().score = workspace.this_gain;
129
+ trees.back().col_num = col;
130
+ trees.back().num_split = workspace.this_split_point;
131
+ if (model_params.penalize_range)
132
+ {
133
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
134
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
135
+ }
136
+ }
137
+ }
138
+
139
+ }
140
+
141
+ else
142
+ {
143
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
144
+ {
145
+ workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
146
+ col, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
147
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(),
148
+ workspace.this_split_point, workspace.xmin, workspace.xmax,
149
+ workspace.criterion, model_params.min_gain, model_params.missing_action);
150
+ if (workspace.this_gain <= -HUGE_VAL)
151
+ {
152
+ workspace.cols_possible[col] = false;
153
+ }
154
+
155
+ else if (workspace.this_gain > trees.back().score)
156
+ {
157
+ trees.back().score = workspace.this_gain;
158
+ trees.back().col_num = col;
159
+ trees.back().num_split = workspace.this_split_point;
160
+ if (model_params.penalize_range)
161
+ {
162
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
163
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
164
+ }
165
+ }
166
+ }
167
+ }
168
+
169
+ for (size_t col = 0; col < input_data.ncols_categ; col++)
170
+ {
171
+ workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
172
+ input_data.categ_data + col * input_data.nrows, input_data.ncat[col],
173
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
174
+ workspace.buffer_dbl.data(), workspace.this_categ, workspace.this_split_categ.data(),
175
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
176
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
177
+ if (workspace.this_gain <= -HUGE_VAL)
178
+ {
179
+ workspace.cols_possible[col + input_data.ncols_numeric] = false;
180
+ }
181
+
182
+ else if (workspace.this_gain > trees.back().score)
183
+ {
184
+ trees.back().score = workspace.this_gain;
185
+ trees.back().col_num = col + input_data.ncols_numeric;
186
+ switch(model_params.cat_split_type)
187
+ {
188
+ case SingleCateg:
189
+ {
190
+ trees.back().chosen_cat = workspace.this_categ;
191
+ break;
192
+ }
193
+
194
+ case SubSet:
195
+ {
196
+ trees.back().cat_split.assign(workspace.this_split_categ.begin(),
197
+ workspace.this_split_categ.begin() + input_data.ncat[col]);
198
+ break;
199
+ }
200
+ }
201
+ }
202
+ }
203
+
204
+
205
+ if (trees.back().score <= 0.)
206
+ goto terminal_statistics;
207
+ else
208
+ trees.back().score = 0.;
209
+
210
+ if (trees.back().col_num < input_data.ncols_numeric)
211
+ {
212
+ trees.back().col_type = Numeric;
213
+ }
214
+
215
+ else
216
+ {
217
+ trees.back().col_type = Categorical;
218
+ trees.back().col_num -= input_data.ncols_numeric;
219
+ }
220
+ }
221
+
222
+ /* case2: column is chosen at random */
223
+ else
224
+ {
225
+ workspace.determine_split = true;
226
+
227
+ /* case 2.1: split point is chosen according to gain (averaged) */
228
+ if (
229
+ workspace.prob_split_type
230
+ < (
231
+ model_params.prob_pick_by_gain_avg +
232
+ model_params.prob_pick_by_gain_pl +
233
+ model_params.prob_split_by_gain_avg
234
+ )
235
+ )
236
+ workspace.criterion = Averaged;
237
+
238
+ /* case 2.2: split point is chosen according to gain (pooled) */
239
+ else if (
240
+ workspace.prob_split_type
241
+ < (
242
+ model_params.prob_pick_by_gain_avg +
243
+ model_params.prob_pick_by_gain_pl +
244
+ model_params.prob_split_by_gain_avg +
245
+ model_params.prob_split_by_gain_pl
246
+ )
247
+ )
248
+ workspace.criterion = Pooled;
249
+
250
+ /* case 2.3: split point is chosen randomly (like in the original paper) */
251
+ else
252
+ workspace.criterion = NoCrit;
253
+
254
+
255
+ /* pick column at random */
256
+ decide_column(input_data.ncols_numeric, input_data.ncols_categ,
257
+ trees.back().col_num, trees.back().col_type,
258
+ workspace.rnd_generator, workspace.runif,
259
+ workspace.col_sampler);
260
+
261
+ /* get the range of possible splits */
262
+ get_split_range(workspace, input_data, model_params, trees.back());
263
+
264
+ /* if it's not possible to split, will have to try more */
265
+ if (workspace.unsplittable)
266
+ {
267
+ /* keep track of which columns are tried */
268
+ add_unsplittable_col(workspace, trees.back(), input_data);
269
+
270
+ /* try more random columns for {(1/2) * ncols} times */
271
+ workspace.ncols_tried = 1;
272
+ do
273
+ {
274
+ decide_column(input_data.ncols_numeric, input_data.ncols_categ,
275
+ trees.back().col_num, trees.back().col_type,
276
+ workspace.rnd_generator, workspace.runif,
277
+ workspace.col_sampler);
278
+ if (!check_is_not_unsplittable_col(workspace, trees.back(), input_data))
279
+ {
280
+ get_split_range(workspace, input_data, model_params, trees.back());
281
+ if (!workspace.unsplittable)
282
+ break;
283
+ else
284
+ add_unsplittable_col(workspace, trees.back(), input_data);
285
+ }
286
+ workspace.ncols_tried++;
287
+ }
288
+ while (workspace.ncols_tried < input_data.ncols_tot / 2);
289
+
290
+ /* if that didn't work, then check all the columns that are still splittable */
291
+ if (workspace.unsplittable)
292
+ {
293
+ workspace.ncols_tried = 0; /* note: this is used here as a counter for the number of still splittable columns */
294
+ if (input_data.Xc == NULL)
295
+ {
296
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
297
+ {
298
+ if (!workspace.cols_possible[col]) continue;
299
+ get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * col,
300
+ workspace.st, workspace.end, model_params.missing_action,
301
+ workspace.xmin, workspace.xmax, workspace.unsplittable);
302
+ workspace.cols_possible[col] = !workspace.unsplittable;
303
+ workspace.ncols_tried += !workspace.unsplittable;
304
+ }
305
+ }
306
+
307
+ else
308
+ {
309
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
310
+ {
311
+ if (!workspace.cols_possible[col]) continue;
312
+ get_range(workspace.ix_arr.data(), workspace.st, workspace.end, col,
313
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
314
+ model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
315
+ workspace.cols_possible[col] = !workspace.unsplittable;
316
+ workspace.ncols_tried += !workspace.unsplittable;
317
+ }
318
+ }
319
+
320
+ for (size_t col = 0; col < input_data.ncols_categ; col++)
321
+ {
322
+ if (!workspace.cols_possible[col + input_data.ncols_numeric]) continue;
323
+ get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * col,
324
+ workspace.st, workspace.end, input_data.ncat[col],
325
+ model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
326
+ workspace.cols_possible[col + input_data.ncols_numeric] = !workspace.unsplittable;
327
+ workspace.ncols_tried += !workspace.unsplittable;
328
+ }
329
+
330
+
331
+ /* if no further splits are possible, end the procedure here */
332
+ workspace.npresent = workspace.ncols_tried;
333
+ if (!workspace.npresent) goto terminal_statistics;
334
+
335
+ /* otherwise, pick a column at random from the possible ones */
336
+ if (!workspace.col_sampler.max())
337
+ {
338
+ /* no weights by columns */
339
+ trees.back().col_num = std::uniform_int_distribution<size_t>
340
+ (0, workspace.npresent - 1)
341
+ (workspace.rnd_generator);
342
+ workspace.ncols_tried = 0;
343
+ for (size_t col = 0; col < input_data.ncols_tot; col++)
344
+ {
345
+
346
+ if (workspace.cols_possible[col])
347
+ {
348
+ if (workspace.ncols_tried == trees.back().col_num)
349
+ {
350
+ if (col < input_data.ncols_numeric)
351
+ {
352
+ trees.back().col_num = col;
353
+ trees.back().col_type = Numeric;
354
+ }
355
+
356
+ else
357
+ {
358
+ trees.back().col_num = col - input_data.ncols_numeric;
359
+ trees.back().col_type = Categorical;
360
+ }
361
+ break;
362
+ }
363
+ workspace.ncols_tried++;
364
+ }
365
+
366
+ }
367
+ }
368
+
369
+ else
370
+ {
371
+ /* weights by columns */
372
+ std::vector<double> col_weights = workspace.col_sampler.probabilities();
373
+ update_col_sampler(workspace, input_data);
374
+
375
+ decide_column(input_data.ncols_numeric, input_data.ncols_categ,
376
+ trees.back().col_num, trees.back().col_type,
377
+ workspace.rnd_generator, workspace.runif,
378
+ workspace.col_sampler);
379
+ }
380
+
381
+ }
382
+
383
+ /* finally, check the range if needed, and later decide on the split point */
384
+ if (workspace.criterion == NoCrit)
385
+ get_split_range(workspace, input_data, model_params, trees.back());
386
+
387
+ }
388
+
389
+ }
390
+
391
+
392
+ /* for numeric, choose a random point, or pick the best point as determined earlier */
393
+ if (trees.back().col_type == Numeric)
394
+ {
395
+ if (workspace.determine_split)
396
+ {
397
+ switch(workspace.criterion)
398
+ {
399
+ case NoCrit:
400
+ {
401
+ trees.back().num_split = std::uniform_real_distribution<double>
402
+ (workspace.xmin, workspace.xmax)
403
+ (workspace.rnd_generator);
404
+ break;
405
+ }
406
+
407
+ default:
408
+ {
409
+ if (input_data.Xc == NULL)
410
+ {
411
+ eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
412
+ input_data.numeric_data + trees.back().col_num * input_data.nrows,
413
+ workspace.split_ix, trees.back().num_split,
414
+ workspace.xmin, workspace.xmax,
415
+ workspace.criterion, model_params.min_gain,
416
+ model_params.missing_action);
417
+ if (model_params.missing_action == Fail) /* data is already split */
418
+ {
419
+ workspace.split_ix++;
420
+ goto follow_branches;
421
+ }
422
+ }
423
+
424
+ else
425
+ {
426
+ eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
427
+ trees.back().col_num, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
428
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(),
429
+ trees.back().num_split, workspace.xmin, workspace.xmax,
430
+ workspace.criterion, model_params.min_gain,
431
+ model_params.missing_action);
432
+ }
433
+ break;
434
+ }
435
+ }
436
+
437
+ if (model_params.penalize_range)
438
+ {
439
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
440
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
441
+ }
442
+ }
443
+
444
+ if (input_data.Xc == NULL)
445
+ divide_subset_split(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * trees.back().col_num,
446
+ workspace.st, workspace.end, trees.back().num_split, model_params.missing_action,
447
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
448
+ else
449
+ divide_subset_split(workspace.ix_arr.data(), workspace.st, workspace.end, trees.back().col_num,
450
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr, trees.back().num_split,
451
+ model_params.missing_action, workspace.st_NA, workspace.end_NA, workspace.split_ix);
452
+ }
453
+
454
+ /* for categorical, there are different ways of splitting */
455
+ else
456
+ {
457
+ /* if the columns is binary, there's only one possible split */
458
+ if (input_data.ncat[trees.back().col_num] <= 2)
459
+ {
460
+ trees.back().chosen_cat = 0;
461
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
462
+ workspace.st, workspace.end, (int)0, model_params.missing_action,
463
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
464
+ trees.back().cat_split.clear();
465
+ trees.back().cat_split.shrink_to_fit();
466
+ }
467
+
468
+ /* otherwise, split according to desired type (single/subset) */
469
+ /* TODO: refactor this */
470
+ else
471
+ {
472
+
473
+ switch(model_params.cat_split_type)
474
+ {
475
+
476
+ case SingleCateg:
477
+ {
478
+
479
+ if (workspace.determine_split)
480
+ {
481
+ switch(workspace.criterion)
482
+ {
483
+ case NoCrit:
484
+ {
485
+ trees.back().chosen_cat = choose_cat_from_present(workspace, input_data, trees.back().col_num);
486
+ break;
487
+ }
488
+
489
+ default:
490
+ {
491
+ eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
492
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
493
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
494
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, workspace.this_split_categ.data(),
495
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
496
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
497
+ break;
498
+ }
499
+ }
500
+ }
501
+
502
+
503
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
504
+ workspace.st, workspace.end, trees.back().chosen_cat, model_params.missing_action,
505
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
506
+ break;
507
+ }
508
+
509
+
510
+ case SubSet:
511
+ {
512
+
513
+ if (workspace.determine_split)
514
+ {
515
+ switch(workspace.criterion)
516
+ {
517
+ case NoCrit:
518
+ {
519
+ workspace.unsplittable = true;
520
+ while(workspace.unsplittable)
521
+ {
522
+ workspace.npresent = 0;
523
+ workspace.ncols_tried = 0;
524
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
525
+ {
526
+ if (workspace.categs[cat] >= 0)
527
+ {
528
+ workspace.categs[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
529
+ workspace.npresent += workspace.categs[cat];
530
+ workspace.ncols_tried += !workspace.categs[cat];
531
+ }
532
+ workspace.unsplittable = !(workspace.npresent && workspace.ncols_tried);
533
+ }
534
+ }
535
+
536
+ trees.back().cat_split.assign(workspace.categs.begin(), workspace.categs.begin() + input_data.ncat[trees.back().col_num]);
537
+ break; /* NoCrit */
538
+ }
539
+
540
+ default:
541
+ {
542
+ trees.back().cat_split.resize(input_data.ncat[trees.back().col_num]);
543
+ eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
544
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
545
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
546
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, trees.back().cat_split.data(),
547
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
548
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
549
+ break;
550
+ }
551
+ }
552
+ }
553
+
554
+ if (model_params.new_cat_action == Random)
555
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
556
+ if (trees.back().cat_split[cat] < 0)
557
+ trees.back().cat_split[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
558
+
559
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
560
+ workspace.st, workspace.end, trees.back().cat_split.data(), model_params.missing_action,
561
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
562
+ }
563
+
564
+ }
565
+
566
+ }
567
+
568
+ }
569
+
570
+
571
+ /* if it hasn't reached the limit, continue splitting from here */
572
+ follow_branches:
573
+ {
574
+ /* add another round of separation depth for distance */
575
+ if (model_params.calc_dist && curr_depth > 0)
576
+ add_separation_step(workspace, input_data, (double)(-1));
577
+
578
+ size_t tree_from = trees.size() - 1;
579
+ size_t ix2, ix3;
580
+ std::unique_ptr<std::vector<bool>> cols_possible_ptr;
581
+ std::unique_ptr<std::discrete_distribution<size_t>> col_sampler_ptr;
582
+ trees.back().score = -1;
583
+
584
+ /* compute statistics for NAs and remember recursion indices/weights */
585
+ std::unique_ptr<RecursionState> recursion_state;
586
+ if (model_params.missing_action != Fail)
587
+ {
588
+ recursion_state = std::unique_ptr<RecursionState>(new RecursionState);
589
+ backup_recursion_state(workspace, *recursion_state);
590
+
591
+ trees.back().pct_tree_left = (long double)(workspace.st_NA - workspace.st)
592
+ /
593
+ (long double)(workspace.end - workspace.st + 1 - (workspace.end_NA - workspace.st_NA));
594
+
595
+ switch(model_params.missing_action)
596
+ {
597
+ case Impute:
598
+ {
599
+ if (trees.back().pct_tree_left >= .5)
600
+ workspace.end = workspace.end_NA - 1;
601
+ else
602
+ workspace.end = workspace.st_NA - 1;
603
+ break;
604
+ }
605
+
606
+
607
+ case Divide:
608
+ {
609
+ if (workspace.weights_map.size())
610
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
611
+ workspace.weights_map[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
612
+ else
613
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
614
+ workspace.weights_arr[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
615
+ workspace.end = workspace.end_NA - 1;
616
+ break;
617
+ }
618
+ }
619
+ }
620
+
621
+ else
622
+ {
623
+ trees.back().pct_tree_left = (long double) (workspace.split_ix - workspace.st)
624
+ /
625
+ (long double) (workspace.end - workspace.st + 1);
626
+
627
+ ix2 = workspace.split_ix;
628
+ ix3 = workspace.end;
629
+ cols_possible_ptr = std::unique_ptr<std::vector<bool>>(new std::vector<bool>);
630
+ *cols_possible_ptr = workspace.cols_possible;
631
+ if (workspace.col_sampler.max())
632
+ {
633
+ col_sampler_ptr = std::unique_ptr<std::discrete_distribution<size_t>>(new std::discrete_distribution<size_t>);
634
+ *col_sampler_ptr = workspace.col_sampler;
635
+ }
636
+ workspace.end = workspace.split_ix - 1;
637
+ }
638
+
639
+ /* Branch where to assign new categories can be pre-determined in this case */
640
+ if (
641
+ trees.back().col_type == Categorical &&
642
+ model_params.cat_split_type == SubSet &&
643
+ input_data.ncat[trees.back().col_num] > 2 &&
644
+ model_params.new_cat_action == Smallest
645
+ )
646
+ {
647
+ bool new_to_left = trees.back().pct_tree_left < 0.5;
648
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
649
+ if (trees.back().cat_split[cat] < 0)
650
+ trees.back().cat_split[cat] = new_to_left;
651
+ }
652
+
653
+ /* left branch */
654
+ trees.back().tree_left = trees.size();
655
+ trees.emplace_back();
656
+ if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
657
+ split_itree_recursive(trees,
658
+ workspace,
659
+ input_data,
660
+ model_params,
661
+ impute_nodes,
662
+ curr_depth + 1);
663
+
664
+
665
+ /* right branch */
666
+ if (model_params.missing_action != Fail)
667
+ {
668
+ restore_recursion_state(workspace, *recursion_state);
669
+
670
+ switch(model_params.missing_action)
671
+ {
672
+ case Impute:
673
+ {
674
+ if (trees[tree_from].pct_tree_left >= .5)
675
+ workspace.st = workspace.end_NA;
676
+ else
677
+ workspace.st = workspace.st_NA;
678
+ break;
679
+ }
680
+
681
+ case Divide:
682
+ {
683
+ if (workspace.weights_map.size())
684
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
685
+ workspace.weights_map[workspace.ix_arr[row]] *= (1 - trees[tree_from].pct_tree_left);
686
+ else
687
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
688
+ workspace.weights_arr[workspace.ix_arr[row]] *= (1 - trees[tree_from].pct_tree_left);
689
+ workspace.st = workspace.st_NA;
690
+ break;
691
+ }
692
+ }
693
+ }
694
+
695
+ else
696
+ {
697
+ workspace.st = ix2;
698
+ workspace.end = ix3;
699
+ workspace.cols_possible = std::move(*cols_possible_ptr);
700
+ if (col_sampler_ptr)
701
+ workspace.col_sampler = std::move(*col_sampler_ptr);
702
+ }
703
+
704
+ trees[tree_from].tree_right = trees.size();
705
+ trees.emplace_back();
706
+ if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
707
+ split_itree_recursive(trees,
708
+ workspace,
709
+ input_data,
710
+ model_params,
711
+ impute_nodes,
712
+ curr_depth + 1);
713
+ }
714
+ return;
715
+
716
+ /* if it reached the limit, calculate terminal statistics */
717
+ terminal_statistics:
718
+ {
719
+ if (!workspace.weights_arr.size() && !workspace.weights_map.size())
720
+ {
721
+ trees.back().score = (double)(curr_depth + expected_avg_depth(workspace.end - workspace.st + 1));
722
+ }
723
+
724
+ else
725
+ {
726
+ if (sum_weight == -HUGE_VAL)
727
+ sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
728
+ workspace.weights_arr, workspace.weights_map);
729
+ trees.back().score = (double)(curr_depth + expected_avg_depth(sum_weight));
730
+ }
731
+
732
+ trees.back().cat_split.clear();
733
+ trees.back().cat_split.shrink_to_fit();
734
+
735
+ trees.back().remainder = workspace.weights_arr.size()?
736
+ sum_weight : (workspace.weights_map.size()?
737
+ sum_weight : ((double)(workspace.end - workspace.st + 1))
738
+ );
739
+
740
+ /* for distance, assume also the elements keep being split */
741
+ if (model_params.calc_dist)
742
+ add_remainder_separation_steps(workspace, input_data, sum_weight);
743
+
744
+ /* add this depth right away if requested */
745
+ if (workspace.row_depths.size())
746
+ {
747
+ if (!workspace.weights_arr.size() && !workspace.weights_map.size())
748
+ {
749
+ for (size_t row = workspace.st; row <= workspace.end; row++)
750
+ workspace.row_depths[workspace.ix_arr[row]] += trees.back().score;
751
+ }
752
+
753
+ else if (workspace.weights_arr.size())
754
+ {
755
+ for (size_t row = workspace.st; row <= workspace.end; row++)
756
+ workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_arr[workspace.ix_arr[row]] * trees.back().score;
757
+ }
758
+
759
+ else
760
+ {
761
+ for (size_t row = workspace.st; row <= workspace.end; row++)
762
+ workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_map[workspace.ix_arr[row]] * trees.back().score;
763
+ }
764
+ }
765
+
766
+ /* add imputations from node if requested */
767
+ if (model_params.impute_at_fit)
768
+ add_from_impute_node(impute_nodes->back(), workspace, input_data);
769
+ }
770
+
771
+ }