xgb 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +8 -0
 - data/NOTICE.txt +1 -1
 - data/README.md +1 -1
 - data/lib/xgboost/booster.rb +176 -65
 - data/lib/xgboost/callback_container.rb +145 -0
 - data/lib/xgboost/cv_pack.rb +26 -0
 - data/lib/xgboost/dmatrix.rb +190 -78
 - data/lib/xgboost/early_stopping.rb +132 -0
 - data/lib/xgboost/evaluation_monitor.rb +44 -0
 - data/lib/xgboost/ffi.rb +11 -2
 - data/lib/xgboost/packed_booster.rb +51 -0
 - data/lib/xgboost/training_callback.rb +23 -0
 - data/lib/xgboost/utils.rb +19 -4
 - data/lib/xgboost/version.rb +1 -1
 - data/lib/xgboost.rb +107 -112
 - data/vendor/aarch64-linux/libxgboost.so +0 -0
 - data/vendor/arm64-darwin/libxgboost.dylib +0 -0
 - data/vendor/x64-mingw/xgboost.dll +0 -0
 - data/vendor/x86_64-darwin/libxgboost.dylib +0 -0
 - data/vendor/x86_64-linux/libxgboost.so +0 -0
 - data/vendor/x86_64-linux-musl/libxgboost.so +0 -0
 - metadata +10 -10
 - data/vendor/aarch64-linux/LICENSE-rabit.txt +0 -28
 - data/vendor/arm64-darwin/LICENSE-rabit.txt +0 -28
 - data/vendor/x64-mingw/LICENSE-rabit.txt +0 -28
 - data/vendor/x86_64-darwin/LICENSE-rabit.txt +0 -28
 - data/vendor/x86_64-linux/LICENSE-rabit.txt +0 -28
 - data/vendor/x86_64-linux-musl/LICENSE-rabit.txt +0 -28
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: ba232bc8a9c27bc9cfb2b7f87afc581acf9a5464df9fccf1b206d912704771fe
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 608ca80f01e81d32a31b796364de14096c89556543fa2892efd86bbf86fac96a
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: af38ca270cf3d12a8a7757ff5f32a41f5d68e43f9ad08c585df8efae835d014ddcac51201019aa339143cc081c5c69ef5914d8de06cb4326ccb2864b5b04eba2
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: 787bddc1c867d648ffe13c783aede98c62917feca536d1f8fe27e7992e5f2945126ad15b1a4e3ff48cc041723d299cf3cbacdd5557083983575eae22a9583cd8
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    | 
         @@ -1,3 +1,11 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            ## 0.9.0 (2024-10-17)
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            - Updated XGBoost to 2.1.1
         
     | 
| 
      
 4 
     | 
    
         
            +
            - Added support for callbacks
         
     | 
| 
      
 5 
     | 
    
         
            +
            - Added `num_features` and `save_config` methods to `Booster`
         
     | 
| 
      
 6 
     | 
    
         
            +
            - Added `num_nonmissing` and `data_split_mode` methods to `DMatrix`
         
     | 
| 
      
 7 
     | 
    
         
            +
            - Dropped support for Ruby < 3.1
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
       1 
9 
     | 
    
         
             
            ## 0.8.0 (2023-09-13)
         
     | 
| 
       2 
10 
     | 
    
         | 
| 
       3 
11 
     | 
    
         
             
            - Updated XGBoost to 2.0.0
         
     | 
    
        data/NOTICE.txt
    CHANGED
    
    
    
        data/README.md
    CHANGED
    
    | 
         @@ -2,7 +2,7 @@ 
     | 
|
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            [XGBoost](https://github.com/dmlc/xgboost) - high performance gradient boosting - for Ruby
         
     | 
| 
       4 
4 
     | 
    
         | 
| 
       5 
     | 
    
         
            -
            [](https://github.com/ankane/xgboost-ruby/actions)
         
     | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
7 
     | 
    
         
             
            ## Installation
         
     | 
| 
       8 
8 
     | 
    
         | 
    
        data/lib/xgboost/booster.rb
    CHANGED
    
    | 
         @@ -1,77 +1,155 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            module XGBoost
         
     | 
| 
       2 
2 
     | 
    
         
             
              class Booster
         
     | 
| 
       3 
     | 
    
         
            -
                 
     | 
| 
      
 3 
     | 
    
         
            +
                include Utils
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
                def initialize(params: nil, cache: nil, model_file: nil)
         
     | 
| 
      
 6 
     | 
    
         
            +
                  cache ||= []
         
     | 
| 
      
 7 
     | 
    
         
            +
                  cache.each do |d|
         
     | 
| 
      
 8 
     | 
    
         
            +
                    if !d.is_a?(DMatrix)
         
     | 
| 
      
 9 
     | 
    
         
            +
                      raise TypeError, "invalid cache item: #{d.class.name}"
         
     | 
| 
      
 10 
     | 
    
         
            +
                    end
         
     | 
| 
      
 11 
     | 
    
         
            +
                  end
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
                  dmats = array_of_pointers(cache.map { |d| d.handle })
         
     | 
| 
      
 14 
     | 
    
         
            +
                  out = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 15 
     | 
    
         
            +
                  check_call FFI.XGBoosterCreate(dmats, cache.length, out)
         
     | 
| 
      
 16 
     | 
    
         
            +
                  @handle = ::FFI::AutoPointer.new(out.read_pointer, FFI.method(:XGBoosterFree))
         
     | 
| 
       4 
17 
     | 
    
         | 
| 
       5 
     | 
    
         
            -
             
     | 
| 
       6 
     | 
    
         
            -
             
     | 
| 
       7 
     | 
    
         
            -
                   
     | 
| 
       8 
     | 
    
         
            -
                  ObjectSpace.define_finalizer(@handle, self.class.finalize(handle_pointer.to_i))
         
     | 
| 
      
 18 
     | 
    
         
            +
                  cache.each do |d|
         
     | 
| 
      
 19 
     | 
    
         
            +
                    assign_dmatrix_features(d)
         
     | 
| 
      
 20 
     | 
    
         
            +
                  end
         
     | 
| 
       9 
21 
     | 
    
         | 
| 
       10 
22 
     | 
    
         
             
                  if model_file
         
     | 
| 
       11 
     | 
    
         
            -
                     
     | 
| 
      
 23 
     | 
    
         
            +
                    check_call FFI.XGBoosterLoadModel(handle, model_file)
         
     | 
| 
       12 
24 
     | 
    
         
             
                  end
         
     | 
| 
       13 
25 
     | 
    
         | 
| 
       14 
     | 
    
         
            -
                  self.best_iteration = 0
         
     | 
| 
       15 
26 
     | 
    
         
             
                  set_param(params)
         
     | 
| 
       16 
27 
     | 
    
         
             
                end
         
     | 
| 
       17 
28 
     | 
    
         | 
| 
       18 
     | 
    
         
            -
                def  
     | 
| 
       19 
     | 
    
         
            -
                   
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
      
 29 
     | 
    
         
            +
                def [](key_name)
         
     | 
| 
      
 30 
     | 
    
         
            +
                  if key_name.is_a?(String)
         
     | 
| 
      
 31 
     | 
    
         
            +
                    return attr(key_name)
         
     | 
| 
      
 32 
     | 
    
         
            +
                  end
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
                  # TODO slice
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
                  raise TypeError, "expected string"
         
     | 
| 
       21 
37 
     | 
    
         
             
                end
         
     | 
| 
       22 
38 
     | 
    
         | 
| 
       23 
     | 
    
         
            -
                def  
     | 
| 
       24 
     | 
    
         
            -
                   
     | 
| 
      
 39 
     | 
    
         
            +
                def []=(key_name, raw_value)
         
     | 
| 
      
 40 
     | 
    
         
            +
                  set_attr(**{key_name => raw_value})
         
     | 
| 
       25 
41 
     | 
    
         
             
                end
         
     | 
| 
       26 
42 
     | 
    
         | 
| 
       27 
     | 
    
         
            -
                def  
     | 
| 
       28 
     | 
    
         
            -
                   
     | 
| 
       29 
     | 
    
         
            -
                   
     | 
| 
      
 43 
     | 
    
         
            +
                def save_config
         
     | 
| 
      
 44 
     | 
    
         
            +
                  length = ::FFI::MemoryPointer.new(:uint64)
         
     | 
| 
      
 45 
     | 
    
         
            +
                  json_string = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 46 
     | 
    
         
            +
                  check_call FFI.XGBoosterSaveJsonConfig(handle, length, json_string)
         
     | 
| 
      
 47 
     | 
    
         
            +
                  json_string.read_pointer.read_string(length.read_uint64).force_encoding(Encoding::UTF_8)
         
     | 
| 
      
 48 
     | 
    
         
            +
                end
         
     | 
| 
       30 
49 
     | 
    
         | 
| 
       31 
     | 
    
         
            -
             
     | 
| 
      
 50 
     | 
    
         
            +
                def attr(key)
         
     | 
| 
      
 51 
     | 
    
         
            +
                  ret = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 52 
     | 
    
         
            +
                  success = ::FFI::MemoryPointer.new(:int)
         
     | 
| 
      
 53 
     | 
    
         
            +
                  check_call FFI.XGBoosterGetAttr(handle, key.to_s, ret, success)
         
     | 
| 
      
 54 
     | 
    
         
            +
                  success.read_int != 0 ? ret.read_pointer.read_string : nil
         
     | 
| 
      
 55 
     | 
    
         
            +
                end
         
     | 
| 
       32 
56 
     | 
    
         | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
      
 57 
     | 
    
         
            +
                def attributes
         
     | 
| 
      
 58 
     | 
    
         
            +
                  length = ::FFI::MemoryPointer.new(:uint64)
         
     | 
| 
      
 59 
     | 
    
         
            +
                  sarr = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 60 
     | 
    
         
            +
                  check_call FFI.XGBoosterGetAttrNames(handle, length, sarr)
         
     | 
| 
      
 61 
     | 
    
         
            +
                  attr_names = from_cstr_to_rbstr(sarr, length)
         
     | 
| 
      
 62 
     | 
    
         
            +
                  attr_names.to_h { |n| [n, attr(n)] }
         
     | 
| 
      
 63 
     | 
    
         
            +
                end
         
     | 
| 
       34 
64 
     | 
    
         | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
      
 65 
     | 
    
         
            +
                def set_attr(**kwargs)
         
     | 
| 
      
 66 
     | 
    
         
            +
                  kwargs.each do |key, value|
         
     | 
| 
      
 67 
     | 
    
         
            +
                    check_call FFI.XGBoosterSetAttr(handle, key.to_s, value&.to_s)
         
     | 
| 
      
 68 
     | 
    
         
            +
                  end
         
     | 
| 
      
 69 
     | 
    
         
            +
                end
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
                def feature_types
         
     | 
| 
      
 72 
     | 
    
         
            +
                  get_feature_info("feature_type")
         
     | 
| 
      
 73 
     | 
    
         
            +
                end
         
     | 
| 
      
 74 
     | 
    
         
            +
             
     | 
| 
      
 75 
     | 
    
         
            +
                def feature_types=(features)
         
     | 
| 
      
 76 
     | 
    
         
            +
                  set_feature_info(features, "feature_type")
         
     | 
| 
      
 77 
     | 
    
         
            +
                end
         
     | 
| 
      
 78 
     | 
    
         
            +
             
     | 
| 
      
 79 
     | 
    
         
            +
                def feature_names
         
     | 
| 
      
 80 
     | 
    
         
            +
                  get_feature_info("feature_name")
         
     | 
| 
      
 81 
     | 
    
         
            +
                end
         
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                def feature_names=(features)
         
     | 
| 
      
 84 
     | 
    
         
            +
                  set_feature_info(features, "feature_name")
         
     | 
| 
       36 
85 
     | 
    
         
             
                end
         
     | 
| 
       37 
86 
     | 
    
         | 
| 
       38 
87 
     | 
    
         
             
                def set_param(params, value = nil)
         
     | 
| 
       39 
88 
     | 
    
         
             
                  if params.is_a?(Enumerable)
         
     | 
| 
       40 
89 
     | 
    
         
             
                    params.each do |k, v|
         
     | 
| 
       41 
     | 
    
         
            -
                       
     | 
| 
      
 90 
     | 
    
         
            +
                      check_call FFI.XGBoosterSetParam(handle, k.to_s, v.to_s)
         
     | 
| 
       42 
91 
     | 
    
         
             
                    end
         
     | 
| 
       43 
92 
     | 
    
         
             
                  else
         
     | 
| 
       44 
     | 
    
         
            -
                     
     | 
| 
      
 93 
     | 
    
         
            +
                    check_call FFI.XGBoosterSetParam(handle, params.to_s, value.to_s)
         
     | 
| 
       45 
94 
     | 
    
         
             
                  end
         
     | 
| 
       46 
95 
     | 
    
         
             
                end
         
     | 
| 
       47 
96 
     | 
    
         | 
| 
      
 97 
     | 
    
         
            +
                def update(dtrain, iteration)
         
     | 
| 
      
 98 
     | 
    
         
            +
                  check_call FFI.XGBoosterUpdateOneIter(handle, iteration, dtrain.handle)
         
     | 
| 
      
 99 
     | 
    
         
            +
                end
         
     | 
| 
      
 100 
     | 
    
         
            +
             
     | 
| 
      
 101 
     | 
    
         
            +
                def eval_set(evals, iteration)
         
     | 
| 
      
 102 
     | 
    
         
            +
                  dmats = array_of_pointers(evals.map { |v| v[0].handle })
         
     | 
| 
      
 103 
     | 
    
         
            +
                  evnames = array_of_pointers(evals.map { |v| string_pointer(v[1]) })
         
     | 
| 
      
 104 
     | 
    
         
            +
             
     | 
| 
      
 105 
     | 
    
         
            +
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 106 
     | 
    
         
            +
             
     | 
| 
      
 107 
     | 
    
         
            +
                  check_call FFI.XGBoosterEvalOneIter(handle, iteration, dmats, evnames, evals.size, out_result)
         
     | 
| 
      
 108 
     | 
    
         
            +
             
     | 
| 
      
 109 
     | 
    
         
            +
                  out_result.read_pointer.read_string
         
     | 
| 
      
 110 
     | 
    
         
            +
                end
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
       48 
112 
     | 
    
         
             
                def predict(data, ntree_limit: nil)
         
     | 
| 
       49 
113 
     | 
    
         
             
                  ntree_limit ||= 0
         
     | 
| 
       50 
114 
     | 
    
         
             
                  out_len = ::FFI::MemoryPointer.new(:uint64)
         
     | 
| 
       51 
115 
     | 
    
         
             
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
       52 
     | 
    
         
            -
                   
     | 
| 
       53 
     | 
    
         
            -
                  out = out_result.read_pointer.read_array_of_float(read_uint64 
     | 
| 
      
 116 
     | 
    
         
            +
                  check_call FFI.XGBoosterPredict(handle, data.handle, 0, ntree_limit, 0, out_len, out_result)
         
     | 
| 
      
 117 
     | 
    
         
            +
                  out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
         
     | 
| 
       54 
118 
     | 
    
         
             
                  num_class = out.size / data.num_row
         
     | 
| 
       55 
119 
     | 
    
         
             
                  out = out.each_slice(num_class).to_a if num_class > 1
         
     | 
| 
       56 
120 
     | 
    
         
             
                  out
         
     | 
| 
       57 
121 
     | 
    
         
             
                end
         
     | 
| 
       58 
122 
     | 
    
         | 
| 
       59 
123 
     | 
    
         
             
                def save_model(fname)
         
     | 
| 
       60 
     | 
    
         
            -
                   
     | 
| 
      
 124 
     | 
    
         
            +
                  check_call FFI.XGBoosterSaveModel(handle, fname)
         
     | 
| 
       61 
125 
     | 
    
         
             
                end
         
     | 
| 
       62 
126 
     | 
    
         | 
| 
       63 
     | 
    
         
            -
                 
     | 
| 
       64 
     | 
    
         
            -
             
     | 
| 
       65 
     | 
    
         
            -
             
     | 
| 
       66 
     | 
    
         
            -
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 127 
     | 
    
         
            +
                def best_iteration
         
     | 
| 
      
 128 
     | 
    
         
            +
                  attr(:best_iteration)&.to_i
         
     | 
| 
      
 129 
     | 
    
         
            +
                end
         
     | 
| 
       67 
130 
     | 
    
         | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
                   
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
      
 131 
     | 
    
         
            +
                def best_iteration=(iteration)
         
     | 
| 
      
 132 
     | 
    
         
            +
                  set_attr(best_iteration: iteration)
         
     | 
| 
      
 133 
     | 
    
         
            +
                end
         
     | 
| 
      
 134 
     | 
    
         
            +
             
     | 
| 
      
 135 
     | 
    
         
            +
                def best_score
         
     | 
| 
      
 136 
     | 
    
         
            +
                  attr(:best_score)&.to_f
         
     | 
| 
      
 137 
     | 
    
         
            +
                end
         
     | 
| 
       71 
138 
     | 
    
         | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
      
 139 
     | 
    
         
            +
                def best_score=(score)
         
     | 
| 
      
 140 
     | 
    
         
            +
                  set_attr(best_score: score)
         
     | 
| 
      
 141 
     | 
    
         
            +
                end
         
     | 
| 
      
 142 
     | 
    
         
            +
             
     | 
| 
      
 143 
     | 
    
         
            +
                def num_boosted_rounds
         
     | 
| 
      
 144 
     | 
    
         
            +
                  rounds = ::FFI::MemoryPointer.new(:int)
         
     | 
| 
      
 145 
     | 
    
         
            +
                  check_call FFI.XGBoosterBoostedRounds(handle, rounds)
         
     | 
| 
      
 146 
     | 
    
         
            +
                  rounds.read_int
         
     | 
| 
      
 147 
     | 
    
         
            +
                end
         
     | 
| 
       73 
148 
     | 
    
         | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
      
 149 
     | 
    
         
            +
                def num_features
         
     | 
| 
      
 150 
     | 
    
         
            +
                  features = ::FFI::MemoryPointer.new(:uint64)
         
     | 
| 
      
 151 
     | 
    
         
            +
                  check_call FFI.XGBoosterGetNumFeature(handle, features)
         
     | 
| 
      
 152 
     | 
    
         
            +
                  features.read_uint64
         
     | 
| 
       75 
153 
     | 
    
         
             
                end
         
     | 
| 
       76 
154 
     | 
    
         | 
| 
       77 
155 
     | 
    
         
             
                def dump_model(fout, fmap: "", with_stats: false, dump_format: "text")
         
     | 
| 
         @@ -93,6 +171,20 @@ module XGBoost 
     | 
|
| 
       93 
171 
     | 
    
         
             
                  end
         
     | 
| 
       94 
172 
     | 
    
         
             
                end
         
     | 
| 
       95 
173 
     | 
    
         | 
| 
      
 174 
     | 
    
         
            +
                # returns an array of strings
         
     | 
| 
      
 175 
     | 
    
         
            +
                def dump(fmap: "", with_stats: false, dump_format: "text")
         
     | 
| 
      
 176 
     | 
    
         
            +
                  out_len = ::FFI::MemoryPointer.new(:uint64)
         
     | 
| 
      
 177 
     | 
    
         
            +
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
      
 178 
     | 
    
         
            +
             
     | 
| 
      
 179 
     | 
    
         
            +
                  names = feature_names || []
         
     | 
| 
      
 180 
     | 
    
         
            +
                  fnames = array_of_pointers(names.map { |fname| string_pointer(fname) })
         
     | 
| 
      
 181 
     | 
    
         
            +
                  ftypes = array_of_pointers(feature_types || Array.new(names.size, string_pointer("float")))
         
     | 
| 
      
 182 
     | 
    
         
            +
             
     | 
| 
      
 183 
     | 
    
         
            +
                  check_call FFI.XGBoosterDumpModelExWithFeatures(handle, names.size, fnames, ftypes, with_stats ? 1 : 0, dump_format, out_len, out_result)
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
                  out_result.read_pointer.get_array_of_string(0, out_len.read_uint64)
         
     | 
| 
      
 186 
     | 
    
         
            +
                end
         
     | 
| 
      
 187 
     | 
    
         
            +
             
     | 
| 
       96 
188 
     | 
    
         
             
                def fscore(fmap: "")
         
     | 
| 
       97 
189 
     | 
    
         
             
                  # always weight
         
     | 
| 
       98 
190 
     | 
    
         
             
                  score(fmap: fmap, importance_type: "weight")
         
     | 
| 
         @@ -157,48 +249,67 @@ module XGBoost 
     | 
|
| 
       157 
249 
     | 
    
         
             
                  end
         
     | 
| 
       158 
250 
     | 
    
         
             
                end
         
     | 
| 
       159 
251 
     | 
    
         | 
| 
       160 
     | 
    
         
            -
                 
     | 
| 
       161 
     | 
    
         
            -
                  key = string_pointer(key_name)
         
     | 
| 
       162 
     | 
    
         
            -
                  success = ::FFI::MemoryPointer.new(:int)
         
     | 
| 
       163 
     | 
    
         
            -
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         
     | 
| 
       164 
     | 
    
         
            -
             
     | 
| 
       165 
     | 
    
         
            -
                  check_result FFI.XGBoosterGetAttr(handle_pointer, key, out_result, success)
         
     | 
| 
       166 
     | 
    
         
            -
             
     | 
| 
       167 
     | 
    
         
            -
                  success.read_int == 1 ? out_result.read_pointer.read_string : nil
         
     | 
| 
       168 
     | 
    
         
            -
                end
         
     | 
| 
       169 
     | 
    
         
            -
             
     | 
| 
       170 
     | 
    
         
            -
                def []=(key_name, raw_value)
         
     | 
| 
       171 
     | 
    
         
            -
                  key = string_pointer(key_name)
         
     | 
| 
       172 
     | 
    
         
            -
                  value = raw_value.nil? ? nil : string_pointer(raw_value)
         
     | 
| 
      
 252 
     | 
    
         
            +
                private
         
     | 
| 
       173 
253 
     | 
    
         | 
| 
       174 
     | 
    
         
            -
             
     | 
| 
      
 254 
     | 
    
         
            +
                def handle
         
     | 
| 
      
 255 
     | 
    
         
            +
                  @handle
         
     | 
| 
       175 
256 
     | 
    
         
             
                end
         
     | 
| 
       176 
257 
     | 
    
         | 
| 
       177 
     | 
    
         
            -
                def  
     | 
| 
       178 
     | 
    
         
            -
                   
     | 
| 
       179 
     | 
    
         
            -
             
     | 
| 
       180 
     | 
    
         
            -
                   
     | 
| 
       181 
     | 
    
         
            -
             
     | 
| 
       182 
     | 
    
         
            -
                  len = read_uint64(out_len)
         
     | 
| 
       183 
     | 
    
         
            -
                  key_names = len.zero? ? [] : out_result.read_pointer.get_array_of_string(0, len)
         
     | 
| 
       184 
     | 
    
         
            -
             
     | 
| 
       185 
     | 
    
         
            -
                  key_names.map { |key_name| [key_name, self[key_name]] }.to_h
         
     | 
| 
       186 
     | 
    
         
            -
                end
         
     | 
| 
      
 258 
     | 
    
         
            +
                def assign_dmatrix_features(data)
         
     | 
| 
      
 259 
     | 
    
         
            +
                  if data.num_row == 0
         
     | 
| 
      
 260 
     | 
    
         
            +
                    return
         
     | 
| 
      
 261 
     | 
    
         
            +
                  end
         
     | 
| 
       187 
262 
     | 
    
         | 
| 
       188 
     | 
    
         
            -
             
     | 
| 
      
 263 
     | 
    
         
            +
                  fn = data.feature_names
         
     | 
| 
      
 264 
     | 
    
         
            +
                  ft = data.feature_types
         
     | 
| 
       189 
265 
     | 
    
         | 
| 
       190 
     | 
    
         
            -
             
     | 
| 
       191 
     | 
    
         
            -
             
     | 
| 
      
 266 
     | 
    
         
            +
                  if feature_names.nil?
         
     | 
| 
      
 267 
     | 
    
         
            +
                    self.feature_names = fn
         
     | 
| 
      
 268 
     | 
    
         
            +
                  end
         
     | 
| 
      
 269 
     | 
    
         
            +
                  if feature_types.nil?
         
     | 
| 
      
 270 
     | 
    
         
            +
                    self.feature_types = ft
         
     | 
| 
      
 271 
     | 
    
         
            +
                  end
         
     | 
| 
       192 
272 
     | 
    
         
             
                end
         
     | 
| 
       193 
273 
     | 
    
         | 
| 
       194 
     | 
    
         
            -
                def  
     | 
| 
       195 
     | 
    
         
            -
                  ::FFI::MemoryPointer.new(: 
     | 
| 
      
 274 
     | 
    
         
            +
                # Fetches the string feature attribute named by +field+ from the native
                # booster via XGBoosterGetStrFeatureInfo.
                #
                # Returns nil when no native handle is set or when the decoded list is
                # empty; otherwise returns whatever from_cstr_to_rbstr decodes.
                def get_feature_info(field)
                  return nil if @handle.nil?

                  # Out-parameters the native call fills in: element count and string array.
                  count_ptr = ::FFI::MemoryPointer.new(:uint64)
                  strings_ptr = ::FFI::MemoryPointer.new(:pointer)

                  status = FFI.XGBoosterGetStrFeatureInfo(handle, field, count_ptr, strings_ptr)
                  check_call(status)

                  info = from_cstr_to_rbstr(strings_ptr, count_ptr)
                  info.empty? ? nil : info
                end
         
     | 
| 
       197 
291 
     | 
    
         | 
| 
       198 
     | 
    
         
            -
                def  
     | 
| 
       199 
     | 
    
         
            -
                   
     | 
| 
      
 292 
     | 
    
         
            +
                # Stores the string feature attribute named by +field+ on the native
                # booster via XGBoosterSetStrFeatureInfo.
                #
                # features - Array of values handed to string_pointer, or nil.
                # field    - attribute name passed through to the native call.
                #
                # Raises TypeError when +features+ is neither nil nor an Array.
                # Returns the result of check_call.
                def set_feature_info(features, field)
                  if features.nil?
                    # nil is forwarded as a null array of length 0
                    # (presumably clearing the stored value - confirm against the C API).
                    return check_call(FFI.XGBoosterSetStrFeatureInfo(handle, field, nil, 0))
                  end

                  raise TypeError, "features must be an array" unless features.is_a?(Array)

                  string_ptrs = features.map { |feature| string_pointer(feature) }
                  native_array = array_of_pointers(string_ptrs)
                  status = FFI.XGBoosterSetStrFeatureInfo(handle, field, native_array, features.length)
                  check_call(status)
                end
         
     | 
| 
       201 
     | 
    
         
            -
             
     | 
| 
       202 
     | 
    
         
            -
                include Utils
         
     | 
| 
       203 
314 
     | 
    
         
             
              end
         
     | 
| 
       204 
315 
     | 
    
         
             
            end
         
     | 
| 
         @@ -0,0 +1,145 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            module XGBoost
              # Runs a list of TrainingCallback objects at each stage of training and
              # records evaluation results.
              #
              # +history+ is a nested Hash: data name => metric name => Array with one
              # entry per recorded iteration. Entries are plain scores in normal mode
              # and [mean, std] pairs when constructed with +is_cv: true+.
              # +aggregated_cv+ holds the most recent [name, mean, std] rows (CV only).
              class CallbackContainer
                attr_reader :aggregated_cv, :history

                # callbacks - Array of TrainingCallback instances.
                # is_cv     - true when the model passed to the hooks is a PackedBooster
                #             rather than a Booster.
                #
                # Raises TypeError if any element is not a TrainingCallback.
                def initialize(callbacks, is_cv: false)
                  callbacks.each do |callback|
                    unless callback.is_a?(TrainingCallback)
                      raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
                    end
                  end

                  @callbacks = callbacks
                  @history = {}
                  @is_cv = is_cv
                end

                # Threads +model+ through every callback's before_training hook and
                # returns the final model. Raises TypeError if a hook returns something
                # other than the expected model class.
                def before_training(model)
                  run_training_hook(:before_training, model)
                end

                # Same as before_training, for the after_training hook.
                def after_training(model)
                  run_training_hook(:after_training, model)
                end

                # Calls before_iteration on the callbacks in order and returns true as
                # soon as one returns a truthy value (later callbacks are then skipped);
                # false otherwise. +dtrain+ and +evals+ are accepted but not used here.
                def before_iteration(model, epoch, dtrain, evals)
                  @callbacks.any? do |callback|
                    callback.before_iteration(model, epoch, @history)
                  end
                end

                # Evaluates the model, appends the scores to +history+, then calls
                # after_iteration on the callbacks in order. Returns true as soon as one
                # callback returns a truthy value (later callbacks are then skipped).
                #
                # Raises ArgumentError when a dataset name in +evals+ contains "-",
                # because "-" separates the data name from the metric name.
                def after_iteration(model, epoch, dtrain, evals)
                  if @is_cv
                    # In CV mode eval_set yields one result line per packed booster.
                    scores = aggcv(model.eval_set(epoch))
                    @aggregated_cv = scores
                    update_history(scores, epoch)
                  else
                    evals ||= []
                    evals.each do |_, name|
                      if name.include?("-")
                        raise ArgumentError, "Dataset name should not contain `-`"
                      end
                    end
                    score = model.eval_set(evals, epoch)
                    update_history(parse_eval_str(score), epoch)
                  end

                  @callbacks.any? do |callback|
                    callback.after_iteration(model, epoch, @history)
                  end
                end

                private

                # Shared body of before_training / after_training: +hook+ is the
                # callback method name, which also prefixes the error message.
                def run_training_hook(hook, model)
                  @callbacks.each do |callback|
                    model = callback.public_send(hook, model)
                    expected = @is_cv ? PackedBooster : Booster
                    raise TypeError, "#{hook} should return the model" unless model.is_a?(expected)
                  end
                  model
                end

                # Appends each [name, score(, std)] row to +history+. The name is split
                # on the first "-": the head is the data name, the rest the metric name.
                def update_history(score, epoch)
                  score.each do |name, value, std|
                    entry = @is_cv ? [value, std] : value
                    data_name, *metric_parts = name.split("-")
                    metric_name = metric_parts.join("-")
                    ((@history[data_name] ||= {})[metric_name] ||= []) << entry
                  end
                end

                # TODO move
                # Turns "[0]\ttest-error:0.1234 ..." into [["test-error", 0.1234], ...];
                # the leading iteration token is dropped.
                def parse_eval_str(result)
                  result.split.drop(1).map do |pair|
                    name, value = pair.split(":")
                    [name, value.to_f]
                  end
                end

                # Aggregates per-fold result lines into [name, mean, std] rows, ordered
                # by the metric's position within a line.
                def aggcv(rlist)
                  cvmap = {}
                  rlist.each do |line|
                    line.split.drop(1).each_with_index do |item, metric_idx|
                      name, value = item.split(":")
                      (cvmap[[metric_idx, name]] ||= []) << value.to_f
                    end
                  end

                  # sort_by on the metric index (Array#sort would need a two-argument
                  # comparator); the insertion index keeps equal keys in stable order.
                  ordered = cvmap.sort_by.with_index do |((metric_idx, _name), _scores), position|
                    [metric_idx, position]
                  end
                  ordered.map { |(_, name), scores| [name, mean(scores), stdev(scores)] }
                end

                def mean(arr)
                  arr.sum / arr.size.to_f
                end

                # Population standard deviation:
                # don't subtract one from arr.size
                def stdev(arr)
                  m = mean(arr)
                  Math.sqrt(arr.sum { |v| (v - m)**2 } / arr.size.to_f)
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,26 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            require "forwardable"
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module XGBoost
              # Pairs a Booster with the train/test datasets it is updated on and
              # evaluated against (presumably one cross-validation fold - confirm
              # against the caller).
              class CVPack
                extend Forwardable

                # The wrapped Booster.
                attr_reader :bst

                # Forwarded straight to the wrapped Booster.
                def_delegator :@bst, :num_boosted_rounds
                def_delegator :@bst, :best_iteration=
                def_delegator :@bst, :best_score=

                # dtrain - dataset used by #update and evaluated under the name "train".
                # dtest  - dataset evaluated under the name "test".
                # param  - passed to Booster.new as +params+.
                def initialize(dtrain, dtest, param)
                  @dtrain, @dtest = dtrain, dtest
                  @watchlist = [[@dtrain, "train"], [@dtest, "test"]]
                  @bst = Booster.new(params: param, cache: [@dtrain, @dtest])
                end

                # Runs one booster update on the training dataset for +iteration+.
                def update(iteration)
                  @bst.update(@dtrain, iteration)
                end

                # Evaluates the booster on the train/test watchlist for +iteration+.
                def eval_set(iteration)
                  @bst.eval_set(@watchlist, iteration)
                end
              end
            end
         
     |