isotree 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,790 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4.5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
+ * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
+ *
24
+ * BSD 2-Clause License
25
+ * Copyright (c) 2019, David Cortes
26
+ * All rights reserved.
27
+ * Redistribution and use in source and binary forms, with or without
28
+ * modification, are permitted provided that the following conditions are met:
29
+ * * Redistributions of source code must retain the above copyright notice, this
30
+ * list of conditions and the following disclaimer.
31
+ * * Redistributions in binary form must reproduce the above copyright notice,
32
+ * this list of conditions and the following disclaimer in the documentation
33
+ * and/or other materials provided with the distribution.
34
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
+ */
45
+ #include "isotree.hpp"
46
+
47
/* Recursively grow one node of an extended ("hyperplane") isolation tree.
 *
 * The node being built is always hplanes.back(); 'hplane_from' remembers its
 * index so that child indices can be linked after recursion grows the vector.
 * The rows belonging to this node are workspace.ix_arr[workspace.st .. workspace.end]
 * (inclusive).  Terminal nodes jump to the 'terminal_statistics' label, which
 * fills in the isolation-depth score and clears the split vectors.
 *
 * Parameters
 *   hplanes      : output tree, one IsoHPlane per node; back() is the node under construction
 *   workspace    : per-thread scratch state (row indices, RNG, buffers, chosen columns)
 *   input_data   : dense/sparse numeric data plus categorical data
 *   model_params : fitting hyper-parameters (ndim, ntry, split criterion probabilities, etc.)
 *   impute_nodes : optional per-node imputation statistics (NULL when not imputing)
 *   curr_depth   : depth of this node (root = 0); bounded by model_params.max_depth
 */
void split_hplane_recursive(std::vector<IsoHPlane> &hplanes,
                            WorkerMemory &workspace,
                            InputData &input_data,
                            ModelParams &model_params,
                            std::vector<ImputeNode> *impute_nodes,
                            size_t curr_depth)
{
    /* -HUGE_VAL acts as a "not yet computed" sentinel for the weight sum */
    long double sum_weight = -HUGE_VAL;
    size_t hplane_from = hplanes.size() - 1;
    std::unique_ptr<RecursionState> recursion_state;
    /* column-already-used bookkeeping: dense bitmap for few columns, hash set for many */
    std::vector<bool> col_is_taken;
    std::unordered_set<size_t> col_is_taken_s;

    /* calculate imputation statistics if desired */
    if (impute_nodes != NULL)
    {
        /* sparse access requires the row indices to be sorted */
        if (input_data.Xc != NULL)
            std::sort(workspace.ix_arr.begin() + workspace.st,
                      workspace.ix_arr.begin() + workspace.end + 1);
        build_impute_node(impute_nodes->back(), workspace,
                          input_data, model_params,
                          *impute_nodes, curr_depth,
                          model_params.min_imp_obs);
    }

    /* check for potential isolated leafs */
    if (workspace.end == workspace.st || curr_depth >= model_params.max_depth)
        goto terminal_statistics;

    /* with 2 observations and no weights, there's only 1 potential or assumed split */
    if ((workspace.end - workspace.st) == 1 && !workspace.weights_arr.size() && !workspace.weights_map.size())
        goto terminal_statistics;

    /* when using weights, the split should stop when the sum of weights is <= 2 */
    sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
                                       workspace.weights_arr, workspace.weights_map);

    if (curr_depth > 0 && (workspace.weights_arr.size() || workspace.weights_map.size()) && sum_weight < 2.5)
        goto terminal_statistics;

    /* for sparse matrices, need to sort the indices
       (skipped when imputing, since the branch above already sorted them) */
    if (input_data.Xc != NULL && impute_nodes == NULL)
        std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);

    /* pick column to split according to criteria */
    workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);

    /* with probability (prob_pick_by_gain_avg + prob_pick_by_gain_pl), evaluate
       'ntry' candidate hyperplanes and keep the one with the best gain;
       otherwise take a single random hyperplane with no gain criterion */
    if (
            workspace.prob_split_type
                < (
                    model_params.prob_pick_by_gain_avg +
                    model_params.prob_pick_by_gain_pl
                  )
        )
    {
        workspace.ntry = model_params.ntry;
        hplanes.back().score = -HUGE_VAL; /* this keeps track of the gain */
        if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
            workspace.criterion = Averaged;
        else
            workspace.criterion = Pooled;
    }

    else
    {
        workspace.criterion = NoCrit;
        workspace.ntry = 1;
    }

    workspace.ntaken_best = 0;

    for (size_t attempt = 0; attempt < workspace.ntry; attempt++)
    {
        /* reset the taken-column tracking for this attempt */
        if (input_data.ncols_tot < 1e3)
        {
            if (!col_is_taken.size())
                col_is_taken.resize(input_data.ncols_tot, false);
            else
                col_is_taken.assign(input_data.ncols_tot, false);
        }
        else
            col_is_taken_s.clear();
        workspace.ntaken = 0;
        workspace.ncols_tried = 0;
        /* comb_val accumulates the linear combination value per row of this node */
        std::fill(workspace.comb_val.begin(),
                  workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
                  (double)0);

        workspace.tried_all = false;
        /* Strategy 1: sample columns at random until 'ndim' usable ones are found.
           Only worthwhile when ndim is small relative to the number of columns;
           falls through to the probe-all path below when sampling fails. */
        if (model_params.ndim < input_data.ncols_tot / 2 || workspace.col_sampler.max())
        {
            while (workspace.ncols_tried < std::max(input_data.ncols_tot / 2, model_params.ndim))
            {
                workspace.ncols_tried++;
                decide_column(input_data.ncols_numeric, input_data.ncols_categ,
                              workspace.col_chosen, workspace.col_type,
                              workspace.rnd_generator, workspace.runif,
                              workspace.col_sampler);

                /* skip columns already known to be unusable or already in this hyperplane */
                if (
                    (workspace.col_type == Numeric && !workspace.cols_possible[workspace.col_chosen])
                        ||
                    (workspace.col_type == Categorical && !workspace.cols_possible[workspace.col_chosen + input_data.ncols_numeric])
                        ||
                    is_col_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen, workspace.col_type)
                    )
                    continue;


                get_split_range(workspace, input_data, model_params);
                if (workspace.unsplittable)
                {
                    add_unsplittable_col(workspace, input_data);
                }

                else
                {
                    add_chosen_column(workspace, input_data, model_params, col_is_taken, col_is_taken_s);
                    if (++workspace.ntaken >= model_params.ndim)
                        break;
                }

            }

            /* random sampling did not find enough columns -> probe all of them */
            if (workspace.ntaken < model_params.ndim)
            {
                update_col_sampler(workspace, input_data);
                goto probe_all;
            }
        }

        else /* probe all columns */
        {
            probe_all:
            workspace.tried_all = true;
            std::iota(workspace.cols_shuffled.begin(), workspace.cols_shuffled.end(), (size_t)0);
            if (model_params.ndim < input_data.ncols_tot)
            {

                if (!workspace.col_sampler.max())
                {
                    std::shuffle(workspace.cols_shuffled.begin(),
                                 workspace.cols_shuffled.end(),
                                 workspace.rnd_generator);
                }

                else
                {
                    if (!model_params.weigh_by_kurt)
                    {
                        weighted_shuffle(workspace.cols_shuffled.data(), input_data.ncols_tot, input_data.col_weights,
                                         workspace.buffer_dbl.data(), workspace.rnd_generator);
                    }

                    else
                    {
                        std::vector<double> col_weights = workspace.col_sampler.probabilities();
                        /* sampler will fail if passed weights of zero, so need to discard those first and then remap */
                        /* partition: move zero-weight columns to the tail; 'st' ends up at the
                           last strictly-positive-weight slot (or -1 if all weights are zero) */
                        std::iota(workspace.buffer_szt.begin(), workspace.buffer_szt.begin() + input_data.ncols_tot, (size_t)0);
                        long st = input_data.ncols_tot - 1;
                        for (long col = st; col >= 0; col--)
                        {
                            if (col_weights[col] <= 0)
                            {
                                std::swap(col_weights[st], col_weights[col]);
                                std::swap(workspace.buffer_szt[st], workspace.buffer_szt[col]);
                                st--;
                            }
                        }

                        if ((size_t)st == input_data.ncols_tot - 1)
                        {
                            /* no zero weights: shuffle everything directly */
                            weighted_shuffle(workspace.cols_shuffled.data(), input_data.ncols_tot, col_weights.data(),
                                             workspace.buffer_dbl.data(), workspace.rnd_generator);
                        }

                        else if (st < 0)
                        {
                            /* all weights are zero: nothing can be split */
                            goto terminal_statistics;
                        }

                        else if (st == 0)
                        {
                            /* only one positive-weight column: order is already decided */
                            std::copy(workspace.buffer_szt.begin(),
                                      workspace.buffer_szt.begin() + input_data.ncols_tot,
                                      workspace.cols_shuffled.begin());
                        }

                        else
                        {
                            /* shuffle only the positive-weight prefix, then copy the whole mapping */
                            weighted_shuffle(workspace.buffer_szt.data(), (size_t) ++st, col_weights.data(),
                                             workspace.buffer_dbl.data(), workspace.rnd_generator);
                            std::copy(workspace.buffer_szt.begin(),
                                      workspace.buffer_szt.begin() + input_data.ncols_tot,
                                      workspace.cols_shuffled.begin());
                        }
                    }
                }
            }

            /* walk the (shuffled) unified column order; indices >= ncols_numeric are categorical */
            for (size_t col : workspace.cols_shuffled)
            {
                if (
                    !workspace.cols_possible[col]
                        ||
                    (workspace.ntaken
                        &&
                     is_col_taken(col_is_taken, col_is_taken_s, input_data,
                                  (col < input_data.ncols_numeric)? col : col - input_data.ncols_numeric,
                                  (col < input_data.ncols_numeric)? Numeric : Categorical)
                    )
                   )
                    continue;

                if (col < input_data.ncols_numeric)
                {
                    workspace.col_chosen = col;
                    workspace.col_type = Numeric;
                }

                else
                {
                    workspace.col_chosen = col - input_data.ncols_numeric;
                    workspace.col_type = Categorical;
                }

                get_split_range(workspace, input_data, model_params);
                if (workspace.unsplittable)
                {
                    add_unsplittable_col(workspace, input_data);
                }

                else
                {
                    add_chosen_column(workspace, input_data, model_params, col_is_taken, col_is_taken_s);
                    if (++workspace.ntaken >= model_params.ndim)
                        break;
                }
            }

            if (model_params.weigh_by_kurt)
                update_col_sampler(workspace, input_data);
        }

        /* evaluate gain if necessary */
        if (workspace.criterion != NoCrit)
            workspace.this_gain = eval_guided_crit(workspace.comb_val.data(), workspace.end - workspace.st + 1,
                                                   workspace.criterion, model_params.min_gain, workspace.this_split_point,
                                                   workspace.xmin, workspace.xmax);

        /* pass to the output object (keep this attempt if it beats the previous best) */
        if (workspace.ntry == 1 || workspace.this_gain > hplanes.back().score)
        {
            /* these should be shrunk later according to what ends up used */
            hplanes.back().score = workspace.this_gain;
            workspace.ntaken_best = workspace.ntaken;
            if (workspace.criterion != NoCrit)
            {
                hplanes.back().split_point = workspace.this_split_point;
                if (model_params.penalize_range)
                {
                    hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
                    hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
                }
            }
            hplanes.back().col_num.assign(workspace.col_take.begin(), workspace.col_take.begin() + workspace.ntaken);
            hplanes.back().col_type.assign(workspace.col_take_type.begin(), workspace.col_take_type.begin() + workspace.ntaken);
            if (input_data.ncols_numeric)
            {
                hplanes.back().coef.assign(workspace.ext_coef.begin(), workspace.ext_coef.begin() + workspace.ntaken);
                hplanes.back().mean.assign(workspace.ext_mean.begin(), workspace.ext_mean.begin() + workspace.ntaken);
            }

            if (model_params.missing_action != Fail)
                hplanes.back().fill_val.assign(workspace.ext_fill_val.begin(), workspace.ext_fill_val.begin() + workspace.ntaken);

            if (input_data.ncols_categ)
            {
                hplanes.back().fill_new.assign(workspace.ext_fill_new.begin(), workspace.ext_fill_new.begin() + workspace.ntaken);
                switch(model_params.cat_split_type)
                {
                    case SingleCateg:
                    {
                        hplanes.back().chosen_cat.assign(workspace.chosen_cat.begin(),
                                                         workspace.chosen_cat.begin() + workspace.ntaken);
                        break;
                    }

                    case SubSet:
                    {
                        /* note: at this point ntaken_best == ntaken (just assigned above) */
                        if (hplanes.back().cat_coef.size() < workspace.ntaken)
                            hplanes.back().cat_coef.assign(workspace.ext_cat_coef.begin(),
                                                           workspace.ext_cat_coef.begin() + workspace.ntaken);
                        else
                            for (size_t col = 0; col < workspace.ntaken_best; col++)
                                std::copy(workspace.ext_cat_coef[col].begin(),
                                          workspace.ext_cat_coef[col].end(),
                                          hplanes.back().cat_coef[col].begin());
                        break;
                    }
                }
            }
        }

    }

    /* if there isn't a single splittable column, end here */
    if (!workspace.ntaken_best && !workspace.ntaken && workspace.tried_all)
        goto terminal_statistics;

    /* if the best split is not good enough, don't split any further */
    if (workspace.criterion != NoCrit && hplanes.back().score <= 0)
        goto terminal_statistics;

    /* now need to reproduce the same split from before
       (comb_val currently holds the LAST attempt, which may not be the best one) */
    if (workspace.criterion != NoCrit && workspace.ntry > 1)
    {
        std::fill(workspace.comb_val.begin(),
                  workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
                  (double)0);
        for (size_t col = 0; col < workspace.ntaken_best; col++)
        {
            switch(hplanes.back().col_type[col])
            {
                case Numeric:
                {
                    if (input_data.Xc == NULL)
                    {
                        add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
                                        input_data.numeric_data + hplanes.back().col_num[col] * input_data.nrows,
                                        hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
                                        hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
                                        model_params.missing_action, NULL, NULL, false);
                    }

                    else
                    {
                        add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
                                        hplanes.back().col_num[col], workspace.comb_val.data(),
                                        input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
                                        hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
                                        hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
                                        model_params.missing_action, NULL, NULL, false);
                    }

                    break;
                }

                case Categorical:
                {
                    add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
                                    input_data.categ_data + hplanes.back().col_num[col] * input_data.nrows,
                                    input_data.ncat[hplanes.back().col_num[col]],
                                    (model_params.cat_split_type == SubSet)? hplanes.back().cat_coef[col].data() : NULL,
                                    (model_params.cat_split_type == SingleCateg)? hplanes.back().fill_new[col] : (double)0,
                                    (model_params.cat_split_type == SingleCateg)? hplanes.back().chosen_cat[col] : 0,
                                    (hplanes.back().fill_val.size())? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
                                    (model_params.cat_split_type == SubSet)? hplanes.back().fill_new[col] : workspace.this_split_point, /* second case is not used */
                                    NULL, NULL, model_params.new_cat_action, model_params.missing_action,
                                    model_params.cat_split_type, false);
                    break;
                }
            }
        }
    }

    /* get the range (random split: draw the split point uniformly within [xmin, xmax]) */
    if (workspace.criterion == NoCrit)
    {
        workspace.xmin = HUGE_VAL;
        workspace.xmax = -HUGE_VAL;
        for (size_t row = 0; row < (workspace.end - workspace.st + 1); row++)
        {
            workspace.xmin = (workspace.xmin > workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmin;
            workspace.xmax = (workspace.xmax < workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmax;
        }
        if (workspace.xmin == workspace.xmax)
            goto terminal_statistics; /* in theory, could try again too, this could just be an unlucky case */

        hplanes.back().split_point =
            std::uniform_real_distribution<double>(workspace.xmin, workspace.xmax)
                                                  (workspace.rnd_generator);

        /* determine acceptable range */
        if (model_params.penalize_range)
        {
            hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
            hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
        }
    }

    /* divide: partition ix_arr so that rows with comb_val <= split_point come first */
    workspace.split_ix = divide_subset_split(workspace.ix_arr.data(), workspace.comb_val.data(),
                                             workspace.st, workspace.end, hplanes.back().split_point);

    /* set as non-terminal */
    hplanes.back().score = -1;

    /* add another round of separation depth for distance */
    if (model_params.calc_dist && curr_depth > 0)
        add_separation_step(workspace, input_data, (double)(-1));

    /* simplify vectors according to what ends up used */
    if (input_data.ncols_categ || workspace.ntaken_best < model_params.ndim)
        simplify_hplane(hplanes.back(), workspace, input_data, model_params);

    shrink_to_fit_hplane(hplanes.back(), false);

    /* now split */

    /* back-up where it was (workspace.st/end are overwritten by the left recursion) */
    recursion_state = std::unique_ptr<RecursionState>(new RecursionState);
    backup_recursion_state(workspace, *recursion_state);

    /* follow left branch */
    hplanes[hplane_from].hplane_left = hplanes.size();
    hplanes.emplace_back();
    if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
    workspace.end = workspace.split_ix - 1;
    split_hplane_recursive(hplanes,
                           workspace,
                           input_data,
                           model_params,
                           impute_nodes,
                           curr_depth + 1);


    /* follow right branch */
    hplanes[hplane_from].hplane_right = hplanes.size();
    restore_recursion_state(workspace, *recursion_state);
    hplanes.emplace_back();
    if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
    workspace.st = workspace.split_ix;
    split_hplane_recursive(hplanes,
                           workspace,
                           input_data,
                           model_params,
                           impute_nodes,
                           curr_depth + 1);

    return;

    terminal_statistics:
    {
        /* score of a terminal node = its depth plus the expected extra depth
           the remaining points would have needed to become fully isolated */
        if (!workspace.weights_arr.size() && !workspace.weights_map.size())
        {
            hplanes.back().score = (double)(curr_depth + expected_avg_depth(workspace.end - workspace.st + 1));
        }

        else
        {
            /* sum_weight may still hold the sentinel if an early goto skipped its computation */
            if (sum_weight == -HUGE_VAL)
                sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
                                                   workspace.weights_arr, workspace.weights_map);
            hplanes.back().score = (double)(curr_depth + expected_avg_depth(sum_weight));
        }

        /* don't leave any vector initialized */
        shrink_to_fit_hplane(hplanes.back(), true);

        hplanes.back().remainder = workspace.weights_arr.size()?
                                   sum_weight : (workspace.weights_map.size()?
                                                 sum_weight : ((double)(workspace.end - workspace.st + 1))
                                                 );

        /* for distance, assume also the elements keep being split */
        if (model_params.calc_dist)
            add_remainder_separation_steps(workspace, input_data, sum_weight);

        /* add this depth right away if requested */
        if (workspace.row_depths.size())
            for (size_t row = workspace.st; row <= workspace.end; row++)
                workspace.row_depths[workspace.ix_arr[row]] += hplanes.back().score;

        /* add imputations from node if requested
           (NOTE(review): dereferences impute_nodes without a NULL check — presumably
            impute_at_fit implies impute_nodes != NULL; confirm against callers) */
        if (model_params.impute_at_fit)
            add_from_impute_node(impute_nodes->back(), workspace, input_data);
    }
}
526
+
527
+
528
+ void add_chosen_column(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params,
529
+ std::vector<bool> &col_is_taken, std::unordered_set<size_t> &col_is_taken_s)
530
+ {
531
+ set_col_as_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen, workspace.col_type);
532
+ workspace.col_take[workspace.ntaken] = workspace.col_chosen;
533
+ workspace.col_take_type[workspace.ntaken] = workspace.col_type;
534
+
535
+ switch(workspace.col_type)
536
+ {
537
+ case Numeric:
538
+ {
539
+ switch(model_params.coef_type)
540
+ {
541
+ case Uniform:
542
+ {
543
+ workspace.ext_coef[workspace.ntaken] = workspace.coef_unif(workspace.rnd_generator);
544
+ break;
545
+ }
546
+
547
+ case Normal:
548
+ {
549
+ workspace.ext_coef[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
550
+ break;
551
+ }
552
+ }
553
+
554
+ if (input_data.Xc == NULL)
555
+ {
556
+ calc_mean_and_sd(workspace.ix_arr.data(), workspace.st, workspace.end,
557
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
558
+ model_params.missing_action, workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
559
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
560
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
561
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
562
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
563
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
564
+ }
565
+
566
+ else
567
+ {
568
+ calc_mean_and_sd(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
569
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
570
+ workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
571
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
572
+ workspace.col_chosen, workspace.comb_val.data(),
573
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
574
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
575
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
576
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
577
+ }
578
+ break;
579
+ }
580
+
581
+ case Categorical:
582
+ {
583
+ switch(model_params.cat_split_type)
584
+ {
585
+ case SingleCateg:
586
+ {
587
+ workspace.chosen_cat[workspace.ntaken] = choose_cat_from_present(workspace, input_data, workspace.col_chosen);
588
+ workspace.ext_fill_new[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
589
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
590
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
591
+ input_data.ncat[workspace.col_chosen],
592
+ NULL, workspace.ext_fill_new[workspace.ntaken],
593
+ workspace.chosen_cat[workspace.ntaken],
594
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
595
+ NULL, NULL, model_params.new_cat_action, model_params.missing_action, SingleCateg, true);
596
+
597
+ break;
598
+ }
599
+
600
+ case SubSet:
601
+ {
602
+ for (int cat = 0; cat < input_data.ncat[workspace.col_chosen]; cat++)
603
+ workspace.ext_cat_coef[workspace.ntaken][cat] = workspace.coef_norm(workspace.rnd_generator);
604
+
605
+ if (model_params.coef_by_prop)
606
+ {
607
+ int ncat = input_data.ncat[workspace.col_chosen];
608
+ size_t *restrict counts = workspace.buffer_szt.data();
609
+ size_t *restrict sorted_ix = workspace.buffer_szt.data() + ncat;
610
+ /* calculate counts and sort by them */
611
+ std::fill(counts, counts + ncat, (size_t)0);
612
+ for (size_t ix = workspace.st; ix <= workspace.end; ix++)
613
+ if (input_data.categ_data[workspace.col_chosen * input_data.nrows + ix] >= 0)
614
+ counts[input_data.categ_data[workspace.col_chosen * input_data.nrows + ix]]++;
615
+ std::iota(sorted_ix, sorted_ix + ncat, (size_t)0);
616
+ std::sort(sorted_ix, sorted_ix + ncat,
617
+ [&counts](const size_t a, const size_t b){return counts[a] < counts[b];});
618
+ /* now re-order the coefficients accordingly */
619
+ std::sort(workspace.ext_cat_coef[workspace.ntaken].begin(),
620
+ workspace.ext_cat_coef[workspace.ntaken].begin() + ncat);
621
+ std::copy(workspace.ext_cat_coef[workspace.ntaken].begin(),
622
+ workspace.ext_cat_coef[workspace.ntaken].begin() + ncat,
623
+ workspace.buffer_dbl.begin());
624
+ for (size_t ix = 0; ix < ncat; ix++)
625
+ workspace.ext_cat_coef[workspace.ntaken][ix] = workspace.buffer_dbl[sorted_ix[ix]];
626
+ }
627
+
628
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
629
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
630
+ input_data.ncat[workspace.col_chosen],
631
+ workspace.ext_cat_coef[workspace.ntaken].data(), (double)0, (int)0,
632
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
633
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ + 1,
634
+ model_params.new_cat_action, model_params.missing_action, SubSet, true);
635
+ break;
636
+ }
637
+ }
638
+ break;
639
+ }
640
+ }
641
+
642
+ double xmin = HUGE_VAL, xmax = -HUGE_VAL;
643
+ for (size_t row = workspace.st; row <= workspace.end; row++)
644
+ {
645
+ xmin = fmin(xmin, workspace.comb_val[row - workspace.st]);
646
+ xmax = fmax(xmax, workspace.comb_val[row - workspace.st]);
647
+ }
648
+ }
649
+
650
+ void shrink_to_fit_hplane(IsoHPlane &hplane, bool clear_vectors)
651
+ {
652
+ if (clear_vectors)
653
+ {
654
+ hplane.col_num.clear();
655
+ hplane.col_type.clear();
656
+ hplane.coef.clear();
657
+ hplane.mean.clear();
658
+ hplane.cat_coef.clear();
659
+ hplane.chosen_cat.clear();
660
+ hplane.fill_val.clear();
661
+ hplane.fill_new.clear();
662
+ }
663
+
664
+ hplane.col_num.shrink_to_fit();
665
+ hplane.col_type.shrink_to_fit();
666
+ hplane.coef.shrink_to_fit();
667
+ hplane.mean.shrink_to_fit();
668
+ hplane.cat_coef.shrink_to_fit();
669
+ hplane.chosen_cat.shrink_to_fit();
670
+ hplane.fill_val.shrink_to_fit();
671
+ hplane.fill_new.shrink_to_fit();
672
+ }
673
+
674
/* Compact the vectors of a freshly-built hyperplane node so that they only
 * hold what is actually used: trim them from 'ndim' slots down to
 * 'workspace.ntaken_best', and re-pack the per-type vectors (coef/mean for
 * numeric columns, fill_new/cat_coef/chosen_cat for categorical ones) so that
 * they are indexed by per-type position rather than by overall column position.
 * The workspace ext_* arrays are reused as scratch space for the re-packing.
 */
void simplify_hplane(IsoHPlane &hplane, WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
{
    /* fewer columns were taken than requested: trim the per-column vectors */
    if (workspace.ntaken_best < model_params.ndim)
    {
        hplane.col_num.resize(workspace.ntaken_best);
        hplane.col_type.resize(workspace.ntaken_best);
        if (model_params.missing_action != Fail)
            hplane.fill_val.resize(workspace.ntaken_best);
    }

    size_t ncols_numeric = 0;
    size_t ncols_categ = 0;

    if (input_data.ncols_categ)
    {
        /* first pass: gather each type's values into the workspace scratch arrays,
           in per-type order (position among numeric columns / among categorical ones) */
        for (size_t col = 0; col < workspace.ntaken_best; col++)
        {
            switch(hplane.col_type[col])
            {
                case Numeric:
                {
                    workspace.ext_coef[ncols_numeric] = hplane.coef[col];
                    workspace.ext_mean[ncols_numeric] = hplane.mean[col];
                    ncols_numeric++;
                    break;
                }

                case Categorical:
                {
                    workspace.ext_fill_new[ncols_categ] = hplane.fill_new[col];
                    switch(model_params.cat_split_type)
                    {
                        case SingleCateg:
                        {
                            workspace.chosen_cat[ncols_categ] = hplane.chosen_cat[col];
                            break;
                        }

                        case SubSet:
                        {
                            std::copy(hplane.cat_coef[col].begin(),
                                      hplane.cat_coef[col].begin() + input_data.ncat[hplane.col_num[col]],
                                      workspace.ext_cat_coef[ncols_categ].begin());
                            break;
                        }
                    }
                    ncols_categ++;
                    break;
                }
            }
        }
    }

    else
    {
        /* no categorical columns in the data: every taken column is numeric
           and already in per-type order, so no re-packing is needed */
        ncols_numeric = workspace.ntaken_best;
    }


    /* second pass: copy the compacted values back into the node's vectors */
    hplane.coef.resize(ncols_numeric);
    hplane.mean.resize(ncols_numeric);
    if (input_data.ncols_numeric)
    {
        std::copy(workspace.ext_coef.begin(), workspace.ext_coef.begin() + ncols_numeric, hplane.coef.begin());
        std::copy(workspace.ext_mean.begin(), workspace.ext_mean.begin() + ncols_numeric, hplane.mean.begin());
    }

    /* If there are no categorical columns, all of them will be numerical and there is no need to reorder */
    if (ncols_categ)
    {
        hplane.fill_new.resize(ncols_categ);
        std::copy(workspace.ext_fill_new.begin(),
                  workspace.ext_fill_new.begin() + ncols_categ,
                  hplane.fill_new.begin());

        hplane.cat_coef.resize(ncols_categ);
        switch(model_params.cat_split_type)
        {
            case SingleCateg:
            {
                /* single-category splits keep only the chosen category per column */
                hplane.chosen_cat.resize(ncols_categ);
                std::copy(workspace.chosen_cat.begin(),
                          workspace.chosen_cat.begin() + ncols_categ,
                          hplane.chosen_cat.begin());
                hplane.cat_coef.clear();
                break;
            }

            case SubSet:
            {
                hplane.chosen_cat.clear();
                /* re-walk the columns, trimming each per-category coefficient vector
                   to that column's actual category count ('ncols_categ' is reused
                   here as the running categorical position counter) */
                ncols_categ = 0;
                for (size_t col = 0; col < workspace.ntaken_best; col++)
                {
                    if (hplane.col_type[col] == Categorical)
                    {
                        hplane.cat_coef[ncols_categ].resize(input_data.ncat[hplane.col_num[col]]);
                        std::copy(workspace.ext_cat_coef[ncols_categ].begin(),
                                  workspace.ext_cat_coef[ncols_categ].begin()
                                    + input_data.ncat[hplane.col_num[col]],
                                  hplane.cat_coef[ncols_categ].begin());
                        hplane.cat_coef[ncols_categ].shrink_to_fit();
                        ncols_categ++;
                    }
                }
                break;
            }
        }
    }

    else
    {
        /* no categorical columns were taken: drop all categorical vectors */
        hplane.cat_coef.clear();
        hplane.chosen_cat.clear();
        hplane.fill_new.clear();
    }
}