isotree 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -33,6 +33,7 @@
33
33
  #include <iterator>
34
34
  #include <limits>
35
35
  #include <memory>
36
+ #include <new>
36
37
  #include <stdexcept>
37
38
  #include <tuple>
38
39
  #include <type_traits>
@@ -195,6 +196,7 @@ class bucket_entry : public bucket_entry_hash<StoreHash> {
195
196
  value_type(other.value());
196
197
  m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
197
198
  }
199
+ tsl_rh_assert(empty() == other.empty());
198
200
  }
199
201
 
200
202
  /**
@@ -212,6 +214,7 @@ class bucket_entry : public bucket_entry_hash<StoreHash> {
212
214
  value_type(std::move(other.value()));
213
215
  m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
214
216
  }
217
+ tsl_rh_assert(empty() == other.empty());
215
218
  }
216
219
 
217
220
  bucket_entry& operator=(const bucket_entry& other) noexcept(
@@ -249,12 +252,22 @@ class bucket_entry : public bucket_entry_hash<StoreHash> {
249
252
 
250
253
  value_type& value() noexcept {
251
254
  tsl_rh_assert(!empty());
255
+ #if defined(__cplusplus) && __cplusplus >= 201703L
256
+ return *std::launder(
257
+ reinterpret_cast<value_type*>(std::addressof(m_value)));
258
+ #else
252
259
  return *reinterpret_cast<value_type*>(std::addressof(m_value));
260
+ #endif
253
261
  }
254
262
 
255
263
  const value_type& value() const noexcept {
256
264
  tsl_rh_assert(!empty());
265
+ #if defined(__cplusplus) && __cplusplus >= 201703L
266
+ return *std::launder(
267
+ reinterpret_cast<const value_type*>(std::addressof(m_value)));
268
+ #else
257
269
  return *reinterpret_cast<const value_type*>(std::addressof(m_value));
270
+ #endif
258
271
  }
259
272
 
260
273
  distance_type dist_from_ideal_bucket() const noexcept {
@@ -283,6 +296,7 @@ class bucket_entry : public bucket_entry_hash<StoreHash> {
283
296
  void swap_with_value_in_bucket(distance_type& dist_from_ideal_bucket,
284
297
  truncated_hash_type& hash, value_type& value) {
285
298
  tsl_rh_assert(!empty());
299
+ tsl_rh_assert(dist_from_ideal_bucket > m_dist_from_ideal_bucket);
286
300
 
287
301
  using std::swap;
288
302
  swap(value, this->value());
@@ -310,19 +324,16 @@ class bucket_entry : public bucket_entry_hash<StoreHash> {
310
324
 
311
325
  public:
312
326
  static const distance_type EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET = -1;
313
- static const distance_type DIST_FROM_IDEAL_BUCKET_LIMIT = 4096;
327
+ static const distance_type DIST_FROM_IDEAL_BUCKET_LIMIT = 8192;
314
328
  static_assert(DIST_FROM_IDEAL_BUCKET_LIMIT <=
315
329
  std::numeric_limits<distance_type>::max() - 1,
316
330
  "DIST_FROM_IDEAL_BUCKET_LIMIT must be <= "
317
331
  "std::numeric_limits<distance_type>::max() - 1.");
318
332
 
319
333
  private:
320
- using storage = typename std::aligned_storage<sizeof(value_type),
321
- alignof(value_type)>::type;
322
-
323
334
  distance_type m_dist_from_ideal_bucket;
324
335
  bool m_last_bucket;
325
- storage m_value;
336
+ alignas(value_type) unsigned char m_value[sizeof(value_type)];
326
337
  };
327
338
 
328
339
  /**
@@ -659,7 +670,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
659
670
 
660
671
  robin_hash& operator=(robin_hash&& other) {
661
672
  other.swap(*this);
662
- other.clear();
673
+ other.clear_and_shrink();
663
674
 
664
675
  return *this;
665
676
  }
@@ -1068,6 +1079,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1068
1079
  m_max_load_factor = clamp(ml, float(MINIMUM_MAX_LOAD_FACTOR),
1069
1080
  float(MAXIMUM_MAX_LOAD_FACTOR));
1070
1081
  m_load_threshold = size_type(float(bucket_count()) * m_max_load_factor);
1082
+ tsl_rh_assert(bucket_count() == 0 || m_load_threshold < bucket_count());
1071
1083
  }
1072
1084
 
1073
1085
  void rehash(size_type count_) {
@@ -1219,7 +1231,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1219
1231
  dist_from_ideal_bucket++;
1220
1232
  }
1221
1233
 
1222
- if (rehash_on_extreme_load()) {
1234
+ while (rehash_on_extreme_load(dist_from_ideal_bucket)) {
1223
1235
  ibucket = bucket_for_hash(hash);
1224
1236
  dist_from_ideal_bucket = 0;
1225
1237
 
@@ -1271,6 +1283,8 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1271
1283
  void insert_value_impl(std::size_t ibucket,
1272
1284
  distance_type dist_from_ideal_bucket,
1273
1285
  truncated_hash_type hash, value_type& value) {
1286
+ tsl_rh_assert(dist_from_ideal_bucket >
1287
+ m_buckets[ibucket].dist_from_ideal_bucket());
1274
1288
  m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash,
1275
1289
  value);
1276
1290
  ibucket = next_bucket(ibucket);
@@ -1279,7 +1293,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1279
1293
  while (!m_buckets[ibucket].empty()) {
1280
1294
  if (dist_from_ideal_bucket >
1281
1295
  m_buckets[ibucket].dist_from_ideal_bucket()) {
1282
- if (dist_from_ideal_bucket >=
1296
+ if (dist_from_ideal_bucket >
1283
1297
  bucket_entry::DIST_FROM_IDEAL_BUCKET_LIMIT) {
1284
1298
  /**
1285
1299
  * The number of probes is really high, rehash the map on the next
@@ -1304,6 +1318,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1304
1318
  robin_hash new_table(count_, static_cast<Hash&>(*this),
1305
1319
  static_cast<KeyEqual&>(*this), get_allocator(),
1306
1320
  m_min_load_factor, m_max_load_factor);
1321
+ tsl_rh_assert(size() <= new_table.m_load_threshold);
1307
1322
 
1308
1323
  const bool use_stored_hash =
1309
1324
  USE_STORED_HASH_ON_REHASH(new_table.bucket_count());
@@ -1364,8 +1379,11 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1364
1379
  *
1365
1380
  * Return true if the table has been rehashed.
1366
1381
  */
1367
- bool rehash_on_extreme_load() {
1368
- if (m_grow_on_next_insert || size() >= m_load_threshold) {
1382
+ bool rehash_on_extreme_load(distance_type curr_dist_from_ideal_bucket) {
1383
+ if (m_grow_on_next_insert ||
1384
+ curr_dist_from_ideal_bucket >
1385
+ bucket_entry::DIST_FROM_IDEAL_BUCKET_LIMIT ||
1386
+ size() >= m_load_threshold) {
1369
1387
  rehash_impl(GrowthPolicy::next_bucket_count());
1370
1388
  m_grow_on_next_insert = false;
1371
1389
 
@@ -1571,6 +1589,7 @@ class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
1571
1589
  */
1572
1590
  bucket_entry* static_empty_bucket_ptr() noexcept {
1573
1591
  static bucket_entry empty_bucket(true);
1592
+ tsl_rh_assert(empty_bucket.empty());
1574
1593
  return &empty_bucket;
1575
1594
  }
1576
1595
 
@@ -1847,13 +1847,13 @@ void check_setup_info
1847
1847
  }
1848
1848
 
1849
1849
  if (setup_info[4] == (uint8_t)IsAbnormalDouble)
1850
- fprintf(stderr, "Warning: input model uses non-standard numeric type, might read correctly.\n");
1850
+ print_errmsg("Warning: input model uses non-standard numeric type, might read correctly.\n");
1851
1851
 
1852
1852
  switch(setup_info[6])
1853
1853
  {
1854
- case 16: {saved_int_t = Is16Bit; break;}
1855
- case 32: {saved_int_t = Is32Bit; break;}
1856
- case 64: {saved_int_t = Is64Bit; break;}
1854
+ case 2: {saved_int_t = Is16Bit; break;}
1855
+ case 4: {saved_int_t = Is32Bit; break;}
1856
+ case 8: {saved_int_t = Is64Bit; break;}
1857
1857
  default: {saved_int_t = IsOther; break;}
1858
1858
  }
1859
1859
  if ((uint8_t)sizeof(int) != setup_info[6]) {
@@ -3844,7 +3844,7 @@ void serialize_combined
3844
3844
  {
3845
3845
  if (memcmp(curr_setup.get(), serialized_model, get_size_setup_info()))
3846
3846
  {
3847
- fprintf(stderr, "Warning: 'model' was serialized in a different setup, will need to convert.\n");
3847
+ print_errmsg("Warning: 'model' was serialized in a different setup, will need to convert.\n");
3848
3848
  IsoForest model;
3849
3849
  deserialization_pipeline(model, serialized_model);
3850
3850
  new_model = std::unique_ptr<char[]>(new char[get_size_model(model)]);
@@ -3862,7 +3862,7 @@ void serialize_combined
3862
3862
  {
3863
3863
  if (memcmp(curr_setup.get(), serialized_model_ext, get_size_setup_info()))
3864
3864
  {
3865
- fprintf(stderr, "Warning: 'model_ext' was serialized in a different setup, will need to convert.\n");
3865
+ print_errmsg("Warning: 'model_ext' was serialized in a different setup, will need to convert.\n");
3866
3866
  ExtIsoForest model;
3867
3867
  deserialization_pipeline(model, serialized_model_ext);
3868
3868
  new_model = std::unique_ptr<char[]>(new char[get_size_model(model)]);
@@ -3884,7 +3884,7 @@ void serialize_combined
3884
3884
  {
3885
3885
  if (memcmp(curr_setup.get(), serialized_imputer, get_size_setup_info()))
3886
3886
  {
3887
- fprintf(stderr, "Warning: 'imputer' was serialized in a different setup, will need to convert.\n");
3887
+ print_errmsg("Warning: 'imputer' was serialized in a different setup, will need to convert.\n");
3888
3888
  Imputer model;
3889
3889
  deserialization_pipeline(model, serialized_imputer);
3890
3890
  new_model = std::unique_ptr<char[]>(new char[get_size_model(model)]);
@@ -3907,7 +3907,7 @@ void serialize_combined
3907
3907
  {
3908
3908
  if (memcmp(curr_setup.get(), serialized_indexer, get_size_setup_info()))
3909
3909
  {
3910
- fprintf(stderr, "Warning: 'indexer' was serialized in a different setup, will need to convert.\n");
3910
+ print_errmsg("Warning: 'indexer' was serialized in a different setup, will need to convert.\n");
3911
3911
  TreesIndexer model;
3912
3912
  deserialization_pipeline(model, serialized_indexer);
3913
3913
  new_model = std::unique_ptr<char[]>(new char[get_size_model(model)]);
@@ -4211,7 +4211,9 @@ void deserialize_combined
4211
4211
  {
4212
4212
  deserialize_model(*model, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4213
4213
  check_interrupt_switch(ss);
4214
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4214
+ if (optional_metadata) {
4215
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4216
+ }
4215
4217
  break;
4216
4218
  }
4217
4219
  case HasSingleVarModelPlusIndexerPlusMetadataNext:
@@ -4220,14 +4222,18 @@ void deserialize_combined
4220
4222
  check_interrupt_switch(ss);
4221
4223
  deserialize_model(*indexer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4222
4224
  check_interrupt_switch(ss);
4223
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4225
+ if (optional_metadata) {
4226
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4227
+ }
4224
4228
  break;
4225
4229
  }
4226
4230
  case HasExtModelPlusMetadataNext:
4227
4231
  {
4228
4232
  deserialize_model(*model_ext, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4229
4233
  check_interrupt_switch(ss);
4230
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4234
+ if (optional_metadata) {
4235
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4236
+ }
4231
4237
  break;
4232
4238
  }
4233
4239
  case HasExtModelPlusIndexerPlusMetadataNext:
@@ -4236,7 +4242,9 @@ void deserialize_combined
4236
4242
  check_interrupt_switch(ss);
4237
4243
  deserialize_model(*indexer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4238
4244
  check_interrupt_switch(ss);
4239
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4245
+ if (optional_metadata) {
4246
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4247
+ }
4240
4248
  break;
4241
4249
  }
4242
4250
  case HasSingleVarModelPlusImputerPlusMetadataNext:
@@ -4245,7 +4253,9 @@ void deserialize_combined
4245
4253
  check_interrupt_switch(ss);
4246
4254
  deserialize_model(*imputer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4247
4255
  check_interrupt_switch(ss);
4248
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4256
+ if (optional_metadata) {
4257
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4258
+ }
4249
4259
  break;
4250
4260
  }
4251
4261
  case HasSingleVarModelPlusImputerPlusIndexerPlusMetadataNext:
@@ -4256,7 +4266,9 @@ void deserialize_combined
4256
4266
  check_interrupt_switch(ss);
4257
4267
  deserialize_model(*indexer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4258
4268
  check_interrupt_switch(ss);
4259
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4269
+ if (optional_metadata) {
4270
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4271
+ }
4260
4272
  break;
4261
4273
  }
4262
4274
  case HasExtModelPlusImputerPlusMetadataNext:
@@ -4265,7 +4277,9 @@ void deserialize_combined
4265
4277
  check_interrupt_switch(ss);
4266
4278
  deserialize_model(*imputer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4267
4279
  check_interrupt_switch(ss);
4268
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4280
+ if (optional_metadata) {
4281
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4282
+ }
4269
4283
  break;
4270
4284
  }
4271
4285
  case HasExtModelPlusImputerPlusIndexerPlusMetadataNext:
@@ -4276,7 +4290,9 @@ void deserialize_combined
4276
4290
  check_interrupt_switch(ss);
4277
4291
  deserialize_model(*indexer, in, has_same_endianness, has_same_int_size, has_same_size_t_size, saved_int_t, saved_size_t, lacks_range_penalty, lacks_scoring_metric);
4278
4292
  check_interrupt_switch(ss);
4279
- read_bytes<char>((void*)optional_metadata, size_metadata, in);
4293
+ if (optional_metadata) {
4294
+ read_bytes<char>((void*)optional_metadata, size_metadata, in);
4295
+ }
4280
4296
  break;
4281
4297
  }
4282
4298
 
@@ -593,14 +593,14 @@ void extract_cond_ext_isotree(ExtIsoForest &model, IsoHPlane &hplane,
593
593
  case SubSet:
594
594
  {
595
595
  hplane_conds += std::string("CASE ") + categ_colnames[hplane.col_num[ix]];
596
- for (size_t categ = 0; categ < hplane.cat_coef[hplane.col_num[ix]].size(); categ++)
596
+ for (size_t categ = 0; categ < hplane.cat_coef[n_visited_categ].size(); categ++)
597
597
  {
598
598
  hplane_conds
599
599
  +=
600
600
  std::string(" WHEN '")
601
601
  + categ_levels[hplane.col_num[ix]][categ]
602
602
  + std::string("' THEN ")
603
- + std::to_string( hplane.cat_coef[hplane.col_num[ix]][categ]);
603
+ + std::to_string( hplane.cat_coef[n_visited_categ][categ]);
604
604
  }
605
605
  if (model.new_cat_action == Smallest)
606
606
  hplane_conds += std::string(" ELSE ") + std::to_string(hplane.fill_new[n_visited_categ]);
@@ -130,34 +130,7 @@
130
130
 
131
131
  /* adapted from cephes */
132
132
  #define EULERS_GAMMA 0.577215664901532860606512
133
- double digamma(double x)
134
- {
135
- double y, z, z2;
136
-
137
- /* check for positive integer up to 128 */
138
- if (unlikely((x <= 64) && (x == std::floor(x)))) {
139
- return harmonic_recursive(1.0, (double)x) - EULERS_GAMMA;
140
- }
141
-
142
- if (likely(x < 1.0e17 ))
143
- {
144
- z = 1.0/(x * x);
145
- z2 = square(z);
146
- y = z * ( 8.33333333333333333333E-2
147
- -8.33333333333333333333E-3*z
148
- +3.96825396825396825397E-3*z2
149
- -4.16666666666666666667E-3*z2*z
150
- +7.57575757575757575758E-3*square(z2)
151
- -2.10927960927960927961E-2*square(z2)*z
152
- +8.33333333333333333333E-2*square(z2)*z2);
153
- }
154
- else {
155
- y = 0.0;
156
- }
157
-
158
- y = ((-0.5/x) - y) + std::log(x);
159
- return y;
160
- }
133
+ #include "digamma.hpp"
161
134
 
162
135
  /* http://fredrik-j.blogspot.com/2009/02/how-not-to-compute-harmonic-numbers.html
163
136
  https://en.wikipedia.org/wiki/Harmonic_number
@@ -434,7 +407,7 @@ void build_btree_sampler(std::vector<double> &btree_weights, real_t *restrict sa
434
407
 
435
408
  if (std::isnan(btree_weights[0]) || btree_weights[0] <= 0)
436
409
  {
437
- fprintf(stderr, "Numeric precision error with sample weights, will not use them.\n");
410
+ print_errmsg("Numeric precision error with sample weights, will not use them.\n");
438
411
  log2_n = 0;
439
412
  btree_weights.clear();
440
413
  btree_weights.shrink_to_fit();
@@ -672,16 +645,21 @@ void weighted_shuffle(size_t *restrict outp, size_t n, real_t *restrict weights,
672
645
  }
673
646
  }
674
647
 
648
+ /* Goualard, Frédéric. "Drawing random floating-point numbers from an interval."
649
+ ACM Transactions on Modeling and Computer Simulation (TOMACS) 32.3 (2022): 1-24. */
650
+ [[gnu::flatten]]
675
651
  double sample_random_uniform(double xmin, double xmax, RNG_engine &rng) noexcept
676
652
  {
677
- double out;
678
- std::uniform_real_distribution<double> runif(xmin, xmax);
679
- for (int attempt = 0; attempt < 100; attempt++)
680
- {
681
- out = runif(rng);
682
- if (likely(out < xmax)) return out;
653
+ const double random_unit = UniformUnitInterval(0, 1)(rng);
654
+ const double half_min = 0.5 * xmin;
655
+ const double half_max = 0.5 * xmax;
656
+ double out = 2. * (half_min + random_unit * (half_max - half_min));
657
+ if (unlikely(out >= xmax)) {
658
+ if (unlikely(xmax == xmin)) return xmin;
659
+ out = std::nextafter(xmax, xmin);
683
660
  }
684
- return xmin;
661
+ out = std::fmax(out, xmin);
662
+ return out;
685
663
  }
686
664
 
687
665
  template <class ldouble_safe>
@@ -3721,7 +3699,7 @@ void check_interrupt_switch(SignalSwitcher &ss)
3721
3699
  if (interrupt_switch)
3722
3700
  {
3723
3701
  ss.restore_handle();
3724
- fprintf(stderr, "Error: procedure was interrupted\n");
3702
+ print_errmsg("Error: procedure was interrupted\n");
3725
3703
  raise(SIGINT);
3726
3704
  #ifdef _FOR_R
3727
3705
  Rcpp::checkUserInterrupt();
@@ -285,13 +285,9 @@ static inline bool get_is_little_endian() noexcept
285
285
  static const bool is_little_endian = get_is_little_endian();
286
286
 
287
287
  /* ~Uniform([0,1))
288
- Be aware that the compilers headers may:
289
- - Produce a non-uniform distribution as they divide
290
- by the maximum value of the generator (not all numbers
291
- between zero and one are representable).
292
- - Draw from a closed interval [0,1] (infinitesimal chance
293
- that something will go wrong, but better not take it).
294
- (For example, GCC4 had bugs like those)
288
+ Be aware that the compilers headers may produce a non-uniform
289
+ distribution as they divide by the maximum value of the generator
290
+ (not all numbers between zero and one are representable).
295
291
  Hence this replacement. It is not too much slower
296
292
  than what the compiler's header use. */
297
293
  class UniformUnitInterval
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: isotree
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-06-14 00:00:00.000000000 Z
11
+ date: 2023-12-20 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -49,6 +49,7 @@ files:
49
49
  - vendor/isotree/src/Rwrapper.cpp
50
50
  - vendor/isotree/src/c_interface.cpp
51
51
  - vendor/isotree/src/crit.hpp
52
+ - vendor/isotree/src/digamma.hpp
52
53
  - vendor/isotree/src/dist.hpp
53
54
  - vendor/isotree/src/exp_depth_table.hpp
54
55
  - vendor/isotree/src/extended.hpp
@@ -102,7 +103,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
102
103
  - !ruby/object:Gem::Version
103
104
  version: '0'
104
105
  requirements: []
105
- rubygems_version: 3.3.7
106
+ rubygems_version: 3.4.10
106
107
  signing_key:
107
108
  specification_version: 4
108
109
  summary: Outlier/anomaly detection for Ruby using Isolation Forest