xgb 0.8.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/NOTICE.txt +2 -2
- data/README.md +3 -2
- data/lib/xgboost/booster.rb +181 -65
- data/lib/xgboost/callback_container.rb +145 -0
- data/lib/xgboost/classifier.rb +1 -1
- data/lib/xgboost/cv_pack.rb +26 -0
- data/lib/xgboost/dmatrix.rb +190 -78
- data/lib/xgboost/early_stopping.rb +132 -0
- data/lib/xgboost/evaluation_monitor.rb +44 -0
- data/lib/xgboost/ffi.rb +12 -2
- data/lib/xgboost/model.rb +2 -1
- data/lib/xgboost/packed_booster.rb +51 -0
- data/lib/xgboost/regressor.rb +1 -1
- data/lib/xgboost/training_callback.rb +23 -0
- data/lib/xgboost/utils.rb +19 -4
- data/lib/xgboost/version.rb +1 -1
- data/lib/xgboost.rb +107 -112
- data/vendor/aarch64-linux/libxgboost.so +0 -0
- data/vendor/arm64-darwin/libxgboost.dylib +0 -0
- data/vendor/x64-mingw/xgboost.dll +0 -0
- data/vendor/x86_64-darwin/libxgboost.dylib +0 -0
- data/vendor/x86_64-linux/libxgboost.so +0 -0
- data/vendor/x86_64-linux-musl/libxgboost.so +0 -0
- metadata +10 -14
- data/vendor/aarch64-linux/LICENSE-rabit.txt +0 -28
- data/vendor/arm64-darwin/LICENSE-rabit.txt +0 -28
- data/vendor/x64-mingw/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-darwin/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-linux/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-linux-musl/LICENSE-rabit.txt +0 -28
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: c717478c585e099431318c391c7c054b58c844f5ddb19dd5d23f921acbea5b26
         | 
| 4 | 
            +
              data.tar.gz: dd221a13627c57135ef2e6e87f267889a76692f1417014e5728b5c279b44a2d5
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: 4430473a5999cde2447f0782be837f1fb06db44eb3e823621ae94db96b356b40e04a772fa211819ed5915324a10b54ccb0b595d7368aaa22d31f6be4eb13fd0d
         | 
| 7 | 
            +
              data.tar.gz: eccc8f38705f82f2aa963c672df5515aa345926af91f9f94beab5fd01406a8d5ad84807d3609dec321952d28d8b2da575b880220ae8375d4296a2c34bf6b0b97
         | 
    
        data/CHANGELOG.md
    CHANGED
    
    | @@ -1,3 +1,15 @@ | |
| 1 | 
            +
            ## 0.10.0 (2025-03-15)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            - Updated XGBoost to 3.0.0
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            ## 0.9.0 (2024-10-17)
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            - Updated XGBoost to 2.1.1
         | 
| 8 | 
            +
            - Added support for callbacks
         | 
| 9 | 
            +
            - Added `num_features` and `save_config` methods to `Booster`
         | 
| 10 | 
            +
            - Added `num_nonmissing` and `data_split_mode` methods to `DMatrix`
         | 
| 11 | 
            +
            - Dropped support for Ruby < 3.1
         | 
| 12 | 
            +
             | 
| 1 13 | 
             
            ## 0.8.0 (2023-09-13)
         | 
| 2 14 |  | 
| 3 15 | 
             
            - Updated XGBoost to 2.0.0
         | 
    
        data/NOTICE.txt
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
            -
            Copyright XGBoost contributors
         | 
| 2 | 
            -
            Copyright 2019- | 
| 1 | 
            +
            Copyright 2014-2025 XGBoost contributors
         | 
| 2 | 
            +
            Copyright 2019-2025 Andrew Kane
         | 
| 3 3 |  | 
| 4 4 | 
             
            Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 5 | 
             
            you may not use this file except in compliance with the License.
         | 
    
        data/README.md
    CHANGED
    
    | @@ -2,7 +2,7 @@ | |
| 2 2 |  | 
| 3 3 | 
             
            [XGBoost](https://github.com/dmlc/xgboost) - high performance gradient boosting - for Ruby
         | 
| 4 4 |  | 
| 5 | 
            -
            [](https://github.com/ankane/xgboost-ruby/actions)
         | 
| 6 6 |  | 
| 7 7 | 
             
            ## Installation
         | 
| 8 8 |  | 
| @@ -126,7 +126,8 @@ model.feature_importances | |
| 126 126 | 
             
            Early stopping
         | 
| 127 127 |  | 
| 128 128 | 
             
            ```ruby
         | 
| 129 | 
            -
            model. | 
| 129 | 
            +
            model = XGBoost::Regressor.new(early_stopping_rounds: 5)
         | 
| 130 | 
            +
            model.fit(x, y, eval_set: [[x_test, y_test]])
         | 
| 130 131 | 
             
            ```
         | 
| 131 132 |  | 
| 132 133 | 
             
            ## Data
         | 
    
        data/lib/xgboost/booster.rb
    CHANGED
    
    | @@ -1,77 +1,160 @@ | |
| 1 1 | 
             
            module XGBoost
         | 
| 2 2 | 
             
              class Booster
         | 
| 3 | 
            -
                 | 
| 3 | 
            +
                include Utils
         | 
| 4 | 
            +
             | 
| 5 | 
            +
                def initialize(params: nil, cache: nil, model_file: nil)
         | 
| 6 | 
            +
                  cache ||= []
         | 
| 7 | 
            +
                  cache.each do |d|
         | 
| 8 | 
            +
                    if !d.is_a?(DMatrix)
         | 
| 9 | 
            +
                      raise TypeError, "invalid cache item: #{d.class.name}"
         | 
| 10 | 
            +
                    end
         | 
| 11 | 
            +
                  end
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                  dmats = array_of_pointers(cache.map { |d| d.handle })
         | 
| 14 | 
            +
                  out = ::FFI::MemoryPointer.new(:pointer)
         | 
| 15 | 
            +
                  check_call FFI.XGBoosterCreate(dmats, cache.length, out)
         | 
| 16 | 
            +
                  @handle = ::FFI::AutoPointer.new(out.read_pointer, FFI.method(:XGBoosterFree))
         | 
| 4 17 |  | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
                   | 
| 8 | 
            -
                  ObjectSpace.define_finalizer(@handle, self.class.finalize(handle_pointer.to_i))
         | 
| 18 | 
            +
                  cache.each do |d|
         | 
| 19 | 
            +
                    assign_dmatrix_features(d)
         | 
| 20 | 
            +
                  end
         | 
| 9 21 |  | 
| 10 22 | 
             
                  if model_file
         | 
| 11 | 
            -
                     | 
| 23 | 
            +
                    check_call FFI.XGBoosterLoadModel(handle, model_file)
         | 
| 12 24 | 
             
                  end
         | 
| 13 25 |  | 
| 14 | 
            -
                  self.best_iteration = 0
         | 
| 15 26 | 
             
                  set_param(params)
         | 
| 16 27 | 
             
                end
         | 
| 17 28 |  | 
| 18 | 
            -
                def  | 
| 19 | 
            -
                   | 
| 20 | 
            -
             | 
| 29 | 
            +
                def [](key_name)
         | 
| 30 | 
            +
                  if key_name.is_a?(String)
         | 
| 31 | 
            +
                    return attr(key_name)
         | 
| 32 | 
            +
                  end
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                  # TODO slice
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                  raise TypeError, "expected string"
         | 
| 21 37 | 
             
                end
         | 
| 22 38 |  | 
| 23 | 
            -
                def  | 
| 24 | 
            -
                   | 
| 39 | 
            +
                def []=(key_name, raw_value)
         | 
| 40 | 
            +
                  set_attr(**{key_name => raw_value})
         | 
| 25 41 | 
             
                end
         | 
| 26 42 |  | 
| 27 | 
            -
                def  | 
| 28 | 
            -
                   | 
| 29 | 
            -
                   | 
| 43 | 
            +
                def save_config
         | 
| 44 | 
            +
                  length = ::FFI::MemoryPointer.new(:uint64)
         | 
| 45 | 
            +
                  json_string = ::FFI::MemoryPointer.new(:pointer)
         | 
| 46 | 
            +
                  check_call FFI.XGBoosterSaveJsonConfig(handle, length, json_string)
         | 
| 47 | 
            +
                  json_string.read_pointer.read_string(length.read_uint64).force_encoding(Encoding::UTF_8)
         | 
| 48 | 
            +
                end
         | 
| 30 49 |  | 
| 31 | 
            -
             | 
| 50 | 
            +
                def reset
         | 
| 51 | 
            +
                  check_call FFI.XGBoosterReset(handle)
         | 
| 52 | 
            +
                  self
         | 
| 53 | 
            +
                end
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def attr(key)
         | 
| 56 | 
            +
                  ret = ::FFI::MemoryPointer.new(:pointer)
         | 
| 57 | 
            +
                  success = ::FFI::MemoryPointer.new(:int)
         | 
| 58 | 
            +
                  check_call FFI.XGBoosterGetAttr(handle, key.to_s, ret, success)
         | 
| 59 | 
            +
                  success.read_int != 0 ? ret.read_pointer.read_string : nil
         | 
| 60 | 
            +
                end
         | 
| 32 61 |  | 
| 33 | 
            -
             | 
| 62 | 
            +
                def attributes
         | 
| 63 | 
            +
                  length = ::FFI::MemoryPointer.new(:uint64)
         | 
| 64 | 
            +
                  sarr = ::FFI::MemoryPointer.new(:pointer)
         | 
| 65 | 
            +
                  check_call FFI.XGBoosterGetAttrNames(handle, length, sarr)
         | 
| 66 | 
            +
                  attr_names = from_cstr_to_rbstr(sarr, length)
         | 
| 67 | 
            +
                  attr_names.to_h { |n| [n, attr(n)] }
         | 
| 68 | 
            +
                end
         | 
| 34 69 |  | 
| 35 | 
            -
             | 
| 70 | 
            +
                def set_attr(**kwargs)
         | 
| 71 | 
            +
                  kwargs.each do |key, value|
         | 
| 72 | 
            +
                    check_call FFI.XGBoosterSetAttr(handle, key.to_s, value&.to_s)
         | 
| 73 | 
            +
                  end
         | 
| 74 | 
            +
                end
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def feature_types
         | 
| 77 | 
            +
                  get_feature_info("feature_type")
         | 
| 78 | 
            +
                end
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def feature_types=(features)
         | 
| 81 | 
            +
                  set_feature_info(features, "feature_type")
         | 
| 82 | 
            +
                end
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def feature_names
         | 
| 85 | 
            +
                  get_feature_info("feature_name")
         | 
| 86 | 
            +
                end
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def feature_names=(features)
         | 
| 89 | 
            +
                  set_feature_info(features, "feature_name")
         | 
| 36 90 | 
             
                end
         | 
| 37 91 |  | 
| 38 92 | 
             
                def set_param(params, value = nil)
         | 
| 39 93 | 
             
                  if params.is_a?(Enumerable)
         | 
| 40 94 | 
             
                    params.each do |k, v|
         | 
| 41 | 
            -
                       | 
| 95 | 
            +
                      check_call FFI.XGBoosterSetParam(handle, k.to_s, v.to_s)
         | 
| 42 96 | 
             
                    end
         | 
| 43 97 | 
             
                  else
         | 
| 44 | 
            -
                     | 
| 98 | 
            +
                    check_call FFI.XGBoosterSetParam(handle, params.to_s, value.to_s)
         | 
| 45 99 | 
             
                  end
         | 
| 46 100 | 
             
                end
         | 
| 47 101 |  | 
| 102 | 
            +
                def update(dtrain, iteration)
         | 
| 103 | 
            +
                  check_call FFI.XGBoosterUpdateOneIter(handle, iteration, dtrain.handle)
         | 
| 104 | 
            +
                end
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def eval_set(evals, iteration)
         | 
| 107 | 
            +
                  dmats = array_of_pointers(evals.map { |v| v[0].handle })
         | 
| 108 | 
            +
                  evnames = array_of_pointers(evals.map { |v| string_pointer(v[1]) })
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                  check_call FFI.XGBoosterEvalOneIter(handle, iteration, dmats, evnames, evals.size, out_result)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                  out_result.read_pointer.read_string
         | 
| 115 | 
            +
                end
         | 
| 116 | 
            +
             | 
| 48 117 | 
             
                def predict(data, ntree_limit: nil)
         | 
| 49 118 | 
             
                  ntree_limit ||= 0
         | 
| 50 119 | 
             
                  out_len = ::FFI::MemoryPointer.new(:uint64)
         | 
| 51 120 | 
             
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         | 
| 52 | 
            -
                   | 
| 53 | 
            -
                  out = out_result.read_pointer.read_array_of_float(read_uint64 | 
| 121 | 
            +
                  check_call FFI.XGBoosterPredict(handle, data.handle, 0, ntree_limit, 0, out_len, out_result)
         | 
| 122 | 
            +
                  out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
         | 
| 54 123 | 
             
                  num_class = out.size / data.num_row
         | 
| 55 124 | 
             
                  out = out.each_slice(num_class).to_a if num_class > 1
         | 
| 56 125 | 
             
                  out
         | 
| 57 126 | 
             
                end
         | 
| 58 127 |  | 
| 59 128 | 
             
                def save_model(fname)
         | 
| 60 | 
            -
                   | 
| 129 | 
            +
                  check_call FFI.XGBoosterSaveModel(handle, fname)
         | 
| 61 130 | 
             
                end
         | 
| 62 131 |  | 
| 63 | 
            -
                 | 
| 64 | 
            -
             | 
| 65 | 
            -
             | 
| 66 | 
            -
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         | 
| 132 | 
            +
                def best_iteration
         | 
| 133 | 
            +
                  attr(:best_iteration)&.to_i
         | 
| 134 | 
            +
                end
         | 
| 67 135 |  | 
| 68 | 
            -
             | 
| 69 | 
            -
                   | 
| 70 | 
            -
             | 
| 136 | 
            +
                def best_iteration=(iteration)
         | 
| 137 | 
            +
                  set_attr(best_iteration: iteration)
         | 
| 138 | 
            +
                end
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def best_score
         | 
| 141 | 
            +
                  attr(:best_score)&.to_f
         | 
| 142 | 
            +
                end
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def best_score=(score)
         | 
| 145 | 
            +
                  set_attr(best_score: score)
         | 
| 146 | 
            +
                end
         | 
| 71 147 |  | 
| 72 | 
            -
             | 
| 148 | 
            +
                def num_boosted_rounds
         | 
| 149 | 
            +
                  rounds = ::FFI::MemoryPointer.new(:int)
         | 
| 150 | 
            +
                  check_call FFI.XGBoosterBoostedRounds(handle, rounds)
         | 
| 151 | 
            +
                  rounds.read_int
         | 
| 152 | 
            +
                end
         | 
| 73 153 |  | 
| 74 | 
            -
             | 
| 154 | 
            +
                def num_features
         | 
| 155 | 
            +
                  features = ::FFI::MemoryPointer.new(:uint64)
         | 
| 156 | 
            +
                  check_call FFI.XGBoosterGetNumFeature(handle, features)
         | 
| 157 | 
            +
                  features.read_uint64
         | 
| 75 158 | 
             
                end
         | 
| 76 159 |  | 
| 77 160 | 
             
                def dump_model(fout, fmap: "", with_stats: false, dump_format: "text")
         | 
| @@ -93,6 +176,20 @@ module XGBoost | |
| 93 176 | 
             
                  end
         | 
| 94 177 | 
             
                end
         | 
| 95 178 |  | 
| 179 | 
            +
                # returns an array of strings
         | 
| 180 | 
            +
                def dump(fmap: "", with_stats: false, dump_format: "text")
         | 
| 181 | 
            +
                  out_len = ::FFI::MemoryPointer.new(:uint64)
         | 
| 182 | 
            +
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                  names = feature_names || []
         | 
| 185 | 
            +
                  fnames = array_of_pointers(names.map { |fname| string_pointer(fname) })
         | 
| 186 | 
            +
                  ftypes = array_of_pointers(feature_types || Array.new(names.size, string_pointer("float")))
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                  check_call FFI.XGBoosterDumpModelExWithFeatures(handle, names.size, fnames, ftypes, with_stats ? 1 : 0, dump_format, out_len, out_result)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                  out_result.read_pointer.get_array_of_string(0, out_len.read_uint64)
         | 
| 191 | 
            +
                end
         | 
| 192 | 
            +
             | 
| 96 193 | 
             
                def fscore(fmap: "")
         | 
| 97 194 | 
             
                  # always weight
         | 
| 98 195 | 
             
                  score(fmap: fmap, importance_type: "weight")
         | 
| @@ -157,48 +254,67 @@ module XGBoost | |
| 157 254 | 
             
                  end
         | 
| 158 255 | 
             
                end
         | 
| 159 256 |  | 
| 160 | 
            -
                 | 
| 161 | 
            -
                  key = string_pointer(key_name)
         | 
| 162 | 
            -
                  success = ::FFI::MemoryPointer.new(:int)
         | 
| 163 | 
            -
                  out_result = ::FFI::MemoryPointer.new(:pointer)
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                  check_result FFI.XGBoosterGetAttr(handle_pointer, key, out_result, success)
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                  success.read_int == 1 ? out_result.read_pointer.read_string : nil
         | 
| 168 | 
            -
                end
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                def []=(key_name, raw_value)
         | 
| 171 | 
            -
                  key = string_pointer(key_name)
         | 
| 172 | 
            -
                  value = raw_value.nil? ? nil : string_pointer(raw_value)
         | 
| 257 | 
            +
                private
         | 
| 173 258 |  | 
| 174 | 
            -
             | 
| 259 | 
            +
                def handle
         | 
| 260 | 
            +
                  @handle
         | 
| 175 261 | 
             
                end
         | 
| 176 262 |  | 
| 177 | 
            -
                def  | 
| 178 | 
            -
                   | 
| 179 | 
            -
             | 
| 180 | 
            -
                   | 
| 181 | 
            -
             | 
| 182 | 
            -
                  len = read_uint64(out_len)
         | 
| 183 | 
            -
                  key_names = len.zero? ? [] : out_result.read_pointer.get_array_of_string(0, len)
         | 
| 184 | 
            -
             | 
| 185 | 
            -
                  key_names.map { |key_name| [key_name, self[key_name]] }.to_h
         | 
| 186 | 
            -
                end
         | 
| 263 | 
            +
                def assign_dmatrix_features(data)
         | 
| 264 | 
            +
                  if data.num_row == 0
         | 
| 265 | 
            +
                    return
         | 
| 266 | 
            +
                  end
         | 
| 187 267 |  | 
| 188 | 
            -
             | 
| 268 | 
            +
                  fn = data.feature_names
         | 
| 269 | 
            +
                  ft = data.feature_types
         | 
| 189 270 |  | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 271 | 
            +
                  if feature_names.nil?
         | 
| 272 | 
            +
                    self.feature_names = fn
         | 
| 273 | 
            +
                  end
         | 
| 274 | 
            +
                  if feature_types.nil?
         | 
| 275 | 
            +
                    self.feature_types = ft
         | 
| 276 | 
            +
                  end
         | 
| 192 277 | 
             
                end
         | 
| 193 278 |  | 
| 194 | 
            -
                def  | 
| 195 | 
            -
                  ::FFI::MemoryPointer.new(: | 
| 279 | 
            +
                def get_feature_info(field)
         | 
| 280 | 
            +
                  length = ::FFI::MemoryPointer.new(:uint64)
         | 
| 281 | 
            +
                  sarr = ::FFI::MemoryPointer.new(:pointer)
         | 
| 282 | 
            +
                  if @handle.nil?
         | 
| 283 | 
            +
                    return nil
         | 
| 284 | 
            +
                  end
         | 
| 285 | 
            +
                  check_call(
         | 
| 286 | 
            +
                    FFI.XGBoosterGetStrFeatureInfo(
         | 
| 287 | 
            +
                      handle,
         | 
| 288 | 
            +
                      field,
         | 
| 289 | 
            +
                      length,
         | 
| 290 | 
            +
                      sarr
         | 
| 291 | 
            +
                    )
         | 
| 292 | 
            +
                  )
         | 
| 293 | 
            +
                  feature_info = from_cstr_to_rbstr(sarr, length)
         | 
| 294 | 
            +
                  !feature_info.empty? ? feature_info : nil
         | 
| 196 295 | 
             
                end
         | 
| 197 296 |  | 
| 198 | 
            -
                def  | 
| 199 | 
            -
                   | 
| 297 | 
            +
                def set_feature_info(features, field)
         | 
| 298 | 
            +
                  if !features.nil?
         | 
| 299 | 
            +
                    if !features.is_a?(Array)
         | 
| 300 | 
            +
                      raise TypeError, "features must be an array"
         | 
| 301 | 
            +
                    end
         | 
| 302 | 
            +
                    c_feature_info = array_of_pointers(features.map { |f| string_pointer(f) })
         | 
| 303 | 
            +
                    check_call(
         | 
| 304 | 
            +
                      FFI.XGBoosterSetStrFeatureInfo(
         | 
| 305 | 
            +
                        handle,
         | 
| 306 | 
            +
                        field,
         | 
| 307 | 
            +
                        c_feature_info,
         | 
| 308 | 
            +
                        features.length
         | 
| 309 | 
            +
                      )
         | 
| 310 | 
            +
                    )
         | 
| 311 | 
            +
                  else
         | 
| 312 | 
            +
                    check_call(
         | 
| 313 | 
            +
                      FFI.XGBoosterSetStrFeatureInfo(
         | 
| 314 | 
            +
                        handle, field, nil, 0
         | 
| 315 | 
            +
                      )
         | 
| 316 | 
            +
                    )
         | 
| 317 | 
            +
                  end
         | 
| 200 318 | 
             
                end
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                include Utils
         | 
| 203 319 | 
             
              end
         | 
| 204 320 | 
             
            end
         | 
| @@ -0,0 +1,145 @@ | |
| 1 | 
            +
            module XGBoost
         | 
| 2 | 
            +
              class CallbackContainer
         | 
| 3 | 
            +
                attr_reader :aggregated_cv, :history
         | 
| 4 | 
            +
             | 
| 5 | 
            +
                def initialize(callbacks, is_cv: false)
         | 
| 6 | 
            +
                  @callbacks = callbacks
         | 
| 7 | 
            +
                  callbacks.each do |callback|
         | 
| 8 | 
            +
                    unless callback.is_a?(TrainingCallback)
         | 
| 9 | 
            +
                      raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
         | 
| 10 | 
            +
                    end
         | 
| 11 | 
            +
                  end
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                  @history = {}
         | 
| 14 | 
            +
                  @is_cv = is_cv
         | 
| 15 | 
            +
                end
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def before_training(model)
         | 
| 18 | 
            +
                  @callbacks.each do |callback|
         | 
| 19 | 
            +
                    model = callback.before_training(model)
         | 
| 20 | 
            +
                    if @is_cv
         | 
| 21 | 
            +
                      unless model.is_a?(PackedBooster)
         | 
| 22 | 
            +
                        raise TypeError, "before_training should return the model"
         | 
| 23 | 
            +
                      end
         | 
| 24 | 
            +
                    else
         | 
| 25 | 
            +
                      unless model.is_a?(Booster)
         | 
| 26 | 
            +
                        raise TypeError, "before_training should return the model"
         | 
| 27 | 
            +
                      end
         | 
| 28 | 
            +
                    end
         | 
| 29 | 
            +
                  end
         | 
| 30 | 
            +
                  model
         | 
| 31 | 
            +
                end
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def after_training(model)
         | 
| 34 | 
            +
                  @callbacks.each do |callback|
         | 
| 35 | 
            +
                    model = callback.after_training(model)
         | 
| 36 | 
            +
                    if @is_cv
         | 
| 37 | 
            +
                      unless model.is_a?(PackedBooster)
         | 
| 38 | 
            +
                        raise TypeError, "after_training should return the model"
         | 
| 39 | 
            +
                      end
         | 
| 40 | 
            +
                    else
         | 
| 41 | 
            +
                      unless model.is_a?(Booster)
         | 
| 42 | 
            +
                        raise TypeError, "after_training should return the model"
         | 
| 43 | 
            +
                      end
         | 
| 44 | 
            +
                    end
         | 
| 45 | 
            +
                  end
         | 
| 46 | 
            +
                  model
         | 
| 47 | 
            +
                end
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def before_iteration(model, epoch, dtrain, evals)
         | 
| 50 | 
            +
                  @callbacks.any? do |callback|
         | 
| 51 | 
            +
                    callback.before_iteration(model, epoch, @history)
         | 
| 52 | 
            +
                  end
         | 
| 53 | 
            +
                end
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def after_iteration(model, epoch, dtrain, evals)
         | 
| 56 | 
            +
                  if @is_cv
         | 
| 57 | 
            +
                    scores = model.eval_set(epoch)
         | 
| 58 | 
            +
                    scores = aggcv(scores)
         | 
| 59 | 
            +
                    @aggregated_cv = scores
         | 
| 60 | 
            +
                    update_history(scores, epoch)
         | 
| 61 | 
            +
                  else
         | 
| 62 | 
            +
                    evals ||= []
         | 
| 63 | 
            +
                    evals.each do |_, name|
         | 
| 64 | 
            +
                      if name.include?("-")
         | 
| 65 | 
            +
                        raise ArgumentError, "Dataset name should not contain `-`"
         | 
| 66 | 
            +
                      end
         | 
| 67 | 
            +
                    end
         | 
| 68 | 
            +
                    score = model.eval_set(evals, epoch)
         | 
| 69 | 
            +
                    metric_score = parse_eval_str(score)
         | 
| 70 | 
            +
                    update_history(metric_score, epoch)
         | 
| 71 | 
            +
                  end
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                  @callbacks.any? do |callback|
         | 
| 74 | 
            +
                    callback.after_iteration(model, epoch, @history)
         | 
| 75 | 
            +
                  end
         | 
| 76 | 
            +
                end
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                private
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                # Append one iteration's scores to @history.
                #
                # Each entry of +score+ is [name, value] (or [name, mean, std] when
                # @is_cv is set). The name is split on "-": the first part is the
                # dataset key and the rest, re-joined with "-", is the metric key, so
                # "test-a-b" is stored under @history["test"]["a-b"]. In CV mode the
                # stored value is [mean, std]; otherwise it is the bare value.
                # +epoch+ is not used.
                def update_history(score, epoch)
                  score.each do |name, value, std|
                    recorded = @is_cv ? [value, std] : value
                    parts = name.split("-")
                    data_name = parts[0]
                    metric_name = parts[1..].join("-")
                    per_dataset = (@history[data_name] ||= {})
                    (per_dataset[metric_name] ||= []) << recorded
                  end
                end
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                # TODO move
                # Parse an evaluation line such as "[0]\ttrain-rmse:0.5\ttest-rmse:0.25"
                # into [["train-rmse", 0.5], ["test-rmse", 0.25]]. The first
                # whitespace-separated token (the iteration marker) is discarded, and
                # each remaining "name:value" token becomes [name, value.to_f].
                def parse_eval_str(result)
                  tokens = result.split[1..]
                  tokens.map do |token|
                    name, value = token.split(":")
                    [name, value.to_f]
                  end
                end
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                # Aggregate the per-fold evaluation lines of one CV iteration.
                #
                # Each element of +rlist+ is one fold's evaluation string, e.g.
                # "[0]\ttrain-rmse:0.5\ttest-rmse:0.6". Values are grouped by their
                # position in the line together with the metric name, then reduced
                # with mean and stdev.
                #
                # @param rlist [Array<String>] one evaluation line per fold
                # @return [Array<Array>] one [name, mean, std] row per metric, ordered
                #   by the metric's position in the evaluation line ([] for no folds)
                def aggcv(rlist)
                  cvmap = {}
                  rlist.each do |line|
                    arr = line.split
                    arr[1..].each_with_index do |it, metric_idx|
                      k, v = it.split(":")
                      (cvmap[[metric_idx, k]] ||= []) << v.to_f
                    end
                  end

                  # Order by metric position. The previous `sort { |x| x[0][0] }`
                  # returned the position as a comparator result (only 0 or positive),
                  # which left the order undefined. Sorting on [position, insertion
                  # index] keeps ties in first-seen order, so the sort is stable.
                  ordered = cvmap.each_with_index.sort_by { |(key, _values), pos| [key[0], pos] }
                  ordered.map do |((_, name), values), _pos|
                    [name, mean(values), stdev(values)]
                  end
                end
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                # Arithmetic mean of +arr+ as a Float (NaN when +arr+ is empty).
                def mean(arr)
                  total = arr.sum
                  count = arr.size.to_f
                  total / count
                end
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                # Population standard deviation of +arr+ (ddof = 0): the summed
                # squared deviations are divided by arr.size, deliberately not by
                # arr.size - 1.
                def stdev(arr)
                  avg = mean(arr)
                  squared_dev = arr.inject(0) { |acc, v| acc + (v - avg) ** 2 }
                  Math.sqrt(squared_dev / arr.size)
                end
         | 
| 144 | 
            +
              end
         | 
| 145 | 
            +
            end
         | 
    
        data/lib/xgboost/classifier.rb
    CHANGED
    
    | @@ -18,7 +18,7 @@ module XGBoost | |
| 18 18 |  | 
| 19 19 | 
             
                  @booster = XGBoost.train(params, dtrain,
         | 
| 20 20 | 
             
                    num_boost_round: @n_estimators,
         | 
| 21 | 
            -
                    early_stopping_rounds: early_stopping_rounds,
         | 
| 21 | 
            +
                    early_stopping_rounds: early_stopping_rounds || @early_stopping_rounds,
         | 
| 22 22 | 
             
                    verbose_eval: verbose,
         | 
| 23 23 | 
             
                    evals: evals
         | 
| 24 24 | 
             
                  )
         | 
| @@ -0,0 +1,26 @@ | |
| 1 | 
            +
            require "forwardable"
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            module XGBoost
              # One cross-validation fold: its training and held-out matrices plus
              # a Booster built with both matrices as its cache.
              class CVPack
                extend Forwardable

                # Forwarded unchanged to the underlying booster.
                def_delegator :@bst, :num_boosted_rounds
                def_delegator :@bst, :best_iteration=
                def_delegator :@bst, :best_score=

                attr_reader :bst

                # @param dtrain matrix this fold trains on (presumably a DMatrix)
                # @param dtest matrix this fold is evaluated on (presumably a DMatrix)
                # @param param booster parameters, passed to Booster.new as params:
                def initialize(dtrain, dtest, param)
                  @dtrain, @dtest = dtrain, dtest
                  @watchlist = [[dtrain, "train"], [dtest, "test"]]
                  @bst = Booster.new(params: param, cache: [dtrain, dtest])
                end

                # Run one boosting iteration of the booster on the training matrix.
                def update(iteration)
                  @bst.update(@dtrain, iteration)
                end

                # Evaluate the booster on the "train" and "test" matrices and return
                # whatever Booster#eval_set returns.
                def eval_set(iteration)
                  @bst.eval_set(@watchlist, iteration)
                end
              end
            end
         |