xgb 0.8.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/NOTICE.txt +2 -2
- data/README.md +3 -2
- data/lib/xgboost/booster.rb +181 -65
- data/lib/xgboost/callback_container.rb +145 -0
- data/lib/xgboost/classifier.rb +1 -1
- data/lib/xgboost/cv_pack.rb +26 -0
- data/lib/xgboost/dmatrix.rb +190 -78
- data/lib/xgboost/early_stopping.rb +132 -0
- data/lib/xgboost/evaluation_monitor.rb +44 -0
- data/lib/xgboost/ffi.rb +12 -2
- data/lib/xgboost/model.rb +2 -1
- data/lib/xgboost/packed_booster.rb +51 -0
- data/lib/xgboost/regressor.rb +1 -1
- data/lib/xgboost/training_callback.rb +23 -0
- data/lib/xgboost/utils.rb +19 -4
- data/lib/xgboost/version.rb +1 -1
- data/lib/xgboost.rb +107 -112
- data/vendor/aarch64-linux/libxgboost.so +0 -0
- data/vendor/arm64-darwin/libxgboost.dylib +0 -0
- data/vendor/x64-mingw/xgboost.dll +0 -0
- data/vendor/x86_64-darwin/libxgboost.dylib +0 -0
- data/vendor/x86_64-linux/libxgboost.so +0 -0
- data/vendor/x86_64-linux-musl/libxgboost.so +0 -0
- metadata +10 -14
- data/vendor/aarch64-linux/LICENSE-rabit.txt +0 -28
- data/vendor/arm64-darwin/LICENSE-rabit.txt +0 -28
- data/vendor/x64-mingw/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-darwin/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-linux/LICENSE-rabit.txt +0 -28
- data/vendor/x86_64-linux-musl/LICENSE-rabit.txt +0 -28
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: c717478c585e099431318c391c7c054b58c844f5ddb19dd5d23f921acbea5b26
+  data.tar.gz: dd221a13627c57135ef2e6e87f267889a76692f1417014e5728b5c279b44a2d5
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 4430473a5999cde2447f0782be837f1fb06db44eb3e823621ae94db96b356b40e04a772fa211819ed5915324a10b54ccb0b595d7368aaa22d31f6be4eb13fd0d
+  data.tar.gz: eccc8f38705f82f2aa963c672df5515aa345926af91f9f94beab5fd01406a8d5ad84807d3609dec321952d28d8b2da575b880220ae8375d4296a2c34bf6b0b97
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,15 @@
+## 0.10.0 (2025-03-15)
+
+- Updated XGBoost to 3.0.0
+
+## 0.9.0 (2024-10-17)
+
+- Updated XGBoost to 2.1.1
+- Added support for callbacks
+- Added `num_features` and `save_config` methods to `Booster`
+- Added `num_nonmissing` and `data_split_mode` methods to `DMatrix`
+- Dropped support for Ruby < 3.1
+
 ## 0.8.0 (2023-09-13)

 - Updated XGBoost to 2.0.0
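A rough sketch of the 0.9.0 additions named above, with toy data (the calls mirror the method names in the changelog; the return values shown are only indicative):

```ruby
x = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [1, 2, 3, 4]

dtrain = XGBoost::DMatrix.new(x, label: y)
dtrain.num_nonmissing   # number of non-missing cells
dtrain.data_split_mode  # row- vs. column-wise split of the data

booster = XGBoost.train({objective: "reg:squarederror"}, dtrain, num_boost_round: 5)
booster.num_features    # => 2
booster.save_config     # JSON string of the booster's internal configuration
```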
data/NOTICE.txt
CHANGED
@@ -1,5 +1,5 @@
-Copyright XGBoost contributors
-Copyright 2019-
+Copyright 2014-2025 XGBoost contributors
+Copyright 2019-2025 Andrew Kane

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
data/README.md
CHANGED
@@ -2,7 +2,7 @@
 
 [XGBoost](https://github.com/dmlc/xgboost) - high performance gradient boosting - for Ruby
 
-[](https://github.com/ankane/xgboost-ruby/actions)
 
 ## Installation
 
@@ -126,7 +126,8 @@ model.feature_importances
 Early stopping
 
 ```ruby
-model.
+model = XGBoost::Regressor.new(early_stopping_rounds: 5)
+model.fit(x, y, eval_set: [[x_test, y_test]])
 ```
 
 ## Data
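The same early stopping is still available on the lower-level `XGBoost.train` call (see the `classifier.rb` hunk further down); a sketch, assuming the train/test arrays are already defined:

```ruby
dtrain = XGBoost::DMatrix.new(x_train, label: y_train)
dtest = XGBoost::DMatrix.new(x_test, label: y_test)

booster = XGBoost.train(
  {objective: "reg:squarederror"},
  dtrain,
  num_boost_round: 100,
  evals: [[dtrain, "train"], [dtest, "eval"]],
  early_stopping_rounds: 5
)

booster.best_iteration # stored as booster attributes (see booster.rb below)
booster.best_score
```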
data/lib/xgboost/booster.rb
CHANGED
@@ -1,77 +1,160 @@
 module XGBoost
   class Booster
-
+    include Utils
+
+    def initialize(params: nil, cache: nil, model_file: nil)
+      cache ||= []
+      cache.each do |d|
+        if !d.is_a?(DMatrix)
+          raise TypeError, "invalid cache item: #{d.class.name}"
+        end
+      end
+
+      dmats = array_of_pointers(cache.map { |d| d.handle })
+      out = ::FFI::MemoryPointer.new(:pointer)
+      check_call FFI.XGBoosterCreate(dmats, cache.length, out)
+      @handle = ::FFI::AutoPointer.new(out.read_pointer, FFI.method(:XGBoosterFree))
 
-
-
-
-      ObjectSpace.define_finalizer(@handle, self.class.finalize(handle_pointer.to_i))
+      cache.each do |d|
+        assign_dmatrix_features(d)
+      end
 
       if model_file
-
+        check_call FFI.XGBoosterLoadModel(handle, model_file)
       end
 
-      self.best_iteration = 0
       set_param(params)
     end
 
-    def
-
-
+    def [](key_name)
+      if key_name.is_a?(String)
+        return attr(key_name)
+      end
+
+      # TODO slice
+
+      raise TypeError, "expected string"
     end
 
-    def
-
+    def []=(key_name, raw_value)
+      set_attr(**{key_name => raw_value})
     end
 
-    def
-
-
+    def save_config
+      length = ::FFI::MemoryPointer.new(:uint64)
+      json_string = ::FFI::MemoryPointer.new(:pointer)
+      check_call FFI.XGBoosterSaveJsonConfig(handle, length, json_string)
+      json_string.read_pointer.read_string(length.read_uint64).force_encoding(Encoding::UTF_8)
+    end
 
-
+    def reset
+      check_call FFI.XGBoosterReset(handle)
+      self
+    end
+
+    def attr(key)
+      ret = ::FFI::MemoryPointer.new(:pointer)
+      success = ::FFI::MemoryPointer.new(:int)
+      check_call FFI.XGBoosterGetAttr(handle, key.to_s, ret, success)
+      success.read_int != 0 ? ret.read_pointer.read_string : nil
+    end
 
-
+    def attributes
+      length = ::FFI::MemoryPointer.new(:uint64)
+      sarr = ::FFI::MemoryPointer.new(:pointer)
+      check_call FFI.XGBoosterGetAttrNames(handle, length, sarr)
+      attr_names = from_cstr_to_rbstr(sarr, length)
+      attr_names.to_h { |n| [n, attr(n)] }
+    end
 
-
+    def set_attr(**kwargs)
+      kwargs.each do |key, value|
+        check_call FFI.XGBoosterSetAttr(handle, key.to_s, value&.to_s)
+      end
+    end
+
+    def feature_types
+      get_feature_info("feature_type")
+    end
+
+    def feature_types=(features)
+      set_feature_info(features, "feature_type")
+    end
+
+    def feature_names
+      get_feature_info("feature_name")
+    end
+
+    def feature_names=(features)
+      set_feature_info(features, "feature_name")
     end
 
     def set_param(params, value = nil)
       if params.is_a?(Enumerable)
         params.each do |k, v|
-
+          check_call FFI.XGBoosterSetParam(handle, k.to_s, v.to_s)
         end
       else
-
+        check_call FFI.XGBoosterSetParam(handle, params.to_s, value.to_s)
       end
     end
 
+    def update(dtrain, iteration)
+      check_call FFI.XGBoosterUpdateOneIter(handle, iteration, dtrain.handle)
+    end
+
+    def eval_set(evals, iteration)
+      dmats = array_of_pointers(evals.map { |v| v[0].handle })
+      evnames = array_of_pointers(evals.map { |v| string_pointer(v[1]) })
+
+      out_result = ::FFI::MemoryPointer.new(:pointer)
+
+      check_call FFI.XGBoosterEvalOneIter(handle, iteration, dmats, evnames, evals.size, out_result)
+
+      out_result.read_pointer.read_string
+    end
+
     def predict(data, ntree_limit: nil)
       ntree_limit ||= 0
       out_len = ::FFI::MemoryPointer.new(:uint64)
       out_result = ::FFI::MemoryPointer.new(:pointer)
-
-      out = out_result.read_pointer.read_array_of_float(read_uint64
+      check_call FFI.XGBoosterPredict(handle, data.handle, 0, ntree_limit, 0, out_len, out_result)
+      out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
       num_class = out.size / data.num_row
       out = out.each_slice(num_class).to_a if num_class > 1
       out
     end
 
     def save_model(fname)
-
+      check_call FFI.XGBoosterSaveModel(handle, fname)
     end
 
-
-
-
-      out_result = ::FFI::MemoryPointer.new(:pointer)
+    def best_iteration
+      attr(:best_iteration)&.to_i
+    end
 
-
-
-
+    def best_iteration=(iteration)
+      set_attr(best_iteration: iteration)
+    end
+
+    def best_score
+      attr(:best_score)&.to_f
+    end
+
+    def best_score=(score)
+      set_attr(best_score: score)
+    end
 
-
+    def num_boosted_rounds
+      rounds = ::FFI::MemoryPointer.new(:int)
+      check_call FFI.XGBoosterBoostedRounds(handle, rounds)
+      rounds.read_int
+    end
 
-
+    def num_features
+      features = ::FFI::MemoryPointer.new(:uint64)
+      check_call FFI.XGBoosterGetNumFeature(handle, features)
+      features.read_uint64
     end
 
     def dump_model(fout, fmap: "", with_stats: false, dump_format: "text")
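The hunk above adds the pieces for a hand-rolled training loop (`update`, `eval_set`, `predict`, `save_model`, plus the cache-aware constructor). A sketch with hypothetical data and parameters:

```ruby
dtrain = XGBoost::DMatrix.new(x_train, label: y_train)
dtest = XGBoost::DMatrix.new(x_test, label: y_test)

booster = XGBoost::Booster.new(params: {max_depth: 3}, cache: [dtrain, dtest])

10.times do |iteration|
  booster.update(dtrain, iteration)                                  # one boosting round
  puts booster.eval_set([[dtrain, "train"], [dtest, "test"]], iteration)
end

preds = booster.predict(dtest)
booster.save_model("model.json")
```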
@@ -93,6 +176,20 @@ module XGBoost
       end
     end
 
+    # returns an array of strings
+    def dump(fmap: "", with_stats: false, dump_format: "text")
+      out_len = ::FFI::MemoryPointer.new(:uint64)
+      out_result = ::FFI::MemoryPointer.new(:pointer)
+
+      names = feature_names || []
+      fnames = array_of_pointers(names.map { |fname| string_pointer(fname) })
+      ftypes = array_of_pointers(feature_types || Array.new(names.size, string_pointer("float")))
+
+      check_call FFI.XGBoosterDumpModelExWithFeatures(handle, names.size, fnames, ftypes, with_stats ? 1 : 0, dump_format, out_len, out_result)
+
+      out_result.read_pointer.get_array_of_string(0, out_len.read_uint64)
+    end
+
     def fscore(fmap: "")
       # always weight
       score(fmap: fmap, importance_type: "weight")
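Unlike `dump_model`, which writes to a file, the new `dump` returns the trees as an array of strings; for an already-trained booster one might do:

```ruby
booster.dump(dump_format: "text").each_with_index do |tree, i|
  puts "booster[#{i}]"
  puts tree
end
```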
@@ -157,48 +254,67 @@ module XGBoost
       end
     end
 
-
-      key = string_pointer(key_name)
-      success = ::FFI::MemoryPointer.new(:int)
-      out_result = ::FFI::MemoryPointer.new(:pointer)
-
-      check_result FFI.XGBoosterGetAttr(handle_pointer, key, out_result, success)
-
-      success.read_int == 1 ? out_result.read_pointer.read_string : nil
-    end
-
-    def []=(key_name, raw_value)
-      key = string_pointer(key_name)
-      value = raw_value.nil? ? nil : string_pointer(raw_value)
+    private
 
-
+    def handle
+      @handle
     end
 
-    def
-
-
-
-
-      len = read_uint64(out_len)
-      key_names = len.zero? ? [] : out_result.read_pointer.get_array_of_string(0, len)
-
-      key_names.map { |key_name| [key_name, self[key_name]] }.to_h
-    end
+    def assign_dmatrix_features(data)
+      if data.num_row == 0
+        return
+      end
 
-
+      fn = data.feature_names
+      ft = data.feature_types
 
-
-
+      if feature_names.nil?
+        self.feature_names = fn
+      end
+      if feature_types.nil?
+        self.feature_types = ft
+      end
     end
 
-    def
-      ::FFI::MemoryPointer.new(:
+    def get_feature_info(field)
+      length = ::FFI::MemoryPointer.new(:uint64)
+      sarr = ::FFI::MemoryPointer.new(:pointer)
+      if @handle.nil?
+        return nil
+      end
+      check_call(
+        FFI.XGBoosterGetStrFeatureInfo(
+          handle,
+          field,
+          length,
+          sarr
+        )
+      )
+      feature_info = from_cstr_to_rbstr(sarr, length)
+      !feature_info.empty? ? feature_info : nil
     end
 
-    def
-
+    def set_feature_info(features, field)
+      if !features.nil?
+        if !features.is_a?(Array)
+          raise TypeError, "features must be an array"
+        end
+        c_feature_info = array_of_pointers(features.map { |f| string_pointer(f) })
+        check_call(
+          FFI.XGBoosterSetStrFeatureInfo(
+            handle,
+            field,
+            c_feature_info,
+            features.length
+          )
+        )
+      else
+        check_call(
+          FFI.XGBoosterSetStrFeatureInfo(
+            handle, field, nil, 0
+          )
+        )
+      end
     end
-
-    include Utils
   end
 end
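With the attribute and feature-info methods from the hunks above, metadata can be round-tripped on a trained booster; a hypothetical session (the attribute name and feature metadata are made up):

```ruby
booster["my_note"] = "trained on 2024 data"
booster["my_note"]        # => "trained on 2024 data"
booster.attributes        # => {"my_note" => "trained on 2024 data", ...}

booster.feature_names = ["age", "income"]
booster.feature_types = ["int", "float"]
booster.feature_names     # => ["age", "income"]
```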
data/lib/xgboost/callback_container.rb
ADDED
@@ -0,0 +1,145 @@
+module XGBoost
+  class CallbackContainer
+    attr_reader :aggregated_cv, :history
+
+    def initialize(callbacks, is_cv: false)
+      @callbacks = callbacks
+      callbacks.each do |callback|
+        unless callback.is_a?(TrainingCallback)
+          raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
+        end
+      end
+
+      @history = {}
+      @is_cv = is_cv
+    end
+
+    def before_training(model)
+      @callbacks.each do |callback|
+        model = callback.before_training(model)
+        if @is_cv
+          unless model.is_a?(PackedBooster)
+            raise TypeError, "before_training should return the model"
+          end
+        else
+          unless model.is_a?(Booster)
+            raise TypeError, "before_training should return the model"
+          end
+        end
+      end
+      model
+    end
+
+    def after_training(model)
+      @callbacks.each do |callback|
+        model = callback.after_training(model)
+        if @is_cv
+          unless model.is_a?(PackedBooster)
+            raise TypeError, "after_training should return the model"
+          end
+        else
+          unless model.is_a?(Booster)
+            raise TypeError, "after_training should return the model"
+          end
+        end
+      end
+      model
+    end
+
+    def before_iteration(model, epoch, dtrain, evals)
+      @callbacks.any? do |callback|
+        callback.before_iteration(model, epoch, @history)
+      end
+    end
+
+    def after_iteration(model, epoch, dtrain, evals)
+      if @is_cv
+        scores = model.eval_set(epoch)
+        scores = aggcv(scores)
+        @aggregated_cv = scores
+        update_history(scores, epoch)
+      else
+        evals ||= []
+        evals.each do |_, name|
+          if name.include?("-")
+            raise ArgumentError, "Dataset name should not contain `-`"
+          end
+        end
+        score = model.eval_set(evals, epoch)
+        metric_score = parse_eval_str(score)
+        update_history(metric_score, epoch)
+      end
+
+      @callbacks.any? do |callback|
+        callback.after_iteration(model, epoch, @history)
+      end
+    end
+
+    private
+
+    def update_history(score, epoch)
+      score.each do |d|
+        name = d[0]
+        s = d[1]
+        if @is_cv
+          std = d[2]
+          x = [s, std]
+        else
+          x = s
+        end
+        splited_names = name.split("-")
+        data_name = splited_names[0]
+        metric_name = splited_names[1..].join("-")
+        @history[data_name] ||= {}
+        data_history = @history[data_name]
+        data_history[metric_name] ||= []
+        metric_history = data_history[metric_name]
+        metric_history << x
+      end
+    end
+
+    # TODO move
+    def parse_eval_str(result)
+      splited = result.split[1..]
+      # split up `test-error:0.1234`
+      metric_score_str = splited.map { |s| s.split(":") }
+      # convert to float
+      metric_score = metric_score_str.map { |n, s| [n, s.to_f] }
+      metric_score
+    end
+
+    def aggcv(rlist)
+      cvmap = {}
+      idx = rlist[0].split[0]
+      rlist.each do |line|
+        arr = line.split
+        arr[1..].each_with_index do |it, metric_idx|
+          k, v = it.split(":")
+          (cvmap[[metric_idx, k]] ||= []) << v.to_f
+        end
+      end
+      msg = idx
+      results = []
+      cvmap.sort { |x| x[0][0] }.each do |(_, name), s|
+        mean = mean(s)
+        std = stdev(s)
+        results << [name, mean, std]
+      end
+      results
+    end
+
+    def mean(arr)
+      arr.sum / arr.size.to_f
+    end
+
+    # don't subtract one from arr.size
+    def stdev(arr)
+      m = mean(arr)
+      sum = 0
+      arr.each do |v|
+        sum += (v - m) ** 2
+      end
+      Math.sqrt(sum / arr.size)
+    end
+  end
+end
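`CallbackContainer` adapts user callbacks to the training loop: `before_training`/`after_training` must return the model, and a truthy return from `before_iteration` or `after_iteration` stops training. A hypothetical callback built on the `TrainingCallback` base class added in this release (how it is handed to `XGBoost.train` is not shown in this diff; a `callbacks:` keyword is assumed):

```ruby
class LoggingCallback < XGBoost::TrainingCallback
  def after_iteration(model, epoch, history)
    # history is nested as {"train" => {"rmse" => [...]}, ...}, per CallbackContainer#update_history
    puts "epoch #{epoch}: #{history.inspect}"
    false # a truthy value would stop training
  end
end

# booster = XGBoost.train(params, dtrain, callbacks: [LoggingCallback.new]) # keyword assumed
```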
data/lib/xgboost/classifier.rb
CHANGED
@@ -18,7 +18,7 @@ module XGBoost
 
       @booster = XGBoost.train(params, dtrain,
         num_boost_round: @n_estimators,
-        early_stopping_rounds: early_stopping_rounds,
+        early_stopping_rounds: early_stopping_rounds || @early_stopping_rounds,
         verbose_eval: verbose,
         evals: evals
       )
data/lib/xgboost/cv_pack.rb
ADDED
@@ -0,0 +1,26 @@
+require "forwardable"
+
+module XGBoost
+  class CVPack
+    extend Forwardable
+
+    def_delegators :@bst, :num_boosted_rounds, :best_iteration=, :best_score=
+
+    attr_reader :bst
+
+    def initialize(dtrain, dtest, param)
+      @dtrain = dtrain
+      @dtest = dtest
+      @watchlist = [[dtrain, "train"], [dtest, "test"]]
+      @bst = Booster.new(params: param, cache: [dtrain, dtest])
+    end
+
+    def update(iteration)
+      @bst.update(@dtrain, iteration)
+    end
+
+    def eval_set(iteration)
+      @bst.eval_set(@watchlist, iteration)
+    end
+  end
+end
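`CVPack` bundles one train/test fold with its own `Booster`. A sketch of driving a single fold by hand, assuming `fold_train` and `fold_test` DMatrix objects already exist (in the gem this is presumably orchestrated by the cross-validation code in `lib/xgboost.rb`):

```ruby
pack = XGBoost::CVPack.new(fold_train, fold_test, {max_depth: 3})

3.times do |iteration|
  pack.update(iteration)        # one boosting round on the fold's training data
  puts pack.eval_set(iteration) # eval string like "[0]\ttrain-rmse:…\ttest-rmse:…"
end

pack.num_boosted_rounds # delegated to the underlying Booster
```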
|