lightgbm 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +22 -18
- data/lib/lightgbm.rb +167 -7
- data/lib/lightgbm/booster.rb +124 -52
- data/lib/lightgbm/dataset.rb +24 -3
- data/lib/lightgbm/ffi.rb +7 -0
- data/lib/lightgbm/utils.rb +9 -1
- data/lib/lightgbm/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 723130d41ea9196bcbd7bcffeb865c40c65985f26eca018d49bd176d33c43142
+  data.tar.gz: d92b41899ff72da2ef4e5782bf4d2840caee1554107d9fd5d02bd6728829585a
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6960dbf1e2a884705e8a2752952392483c0ab1e74e970382b45df747ccbedbc0d978d3910e2b935190c98a9c45315841b53e27128aa5944e3a7834808e05582a
+  data.tar.gz: c6e793933dc794fa62099580ad35f29d7e5e3ae24a07df4335dca3d68571bc1d5b360a7cf0859d75d245a77d470a0dafba0cb5d86edda1974f3bc532b0f5c11a
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -16,6 +16,8 @@ Add this line to your application’s Gemfile:
 gem 'lightgbm'
 ```

+## Getting Started
+
 Train a model

 ```ruby
@@ -30,13 +32,13 @@ Predict
 booster.predict(x_test)
 ```

-Save the model
+Save the model to a file

 ```ruby
 booster.save_model("model.txt")
 ```

-Load
+Load the model from a file

 ```ruby
 booster = LightGBM::Booster.new(model_file: "model.txt")
@@ -48,30 +50,32 @@ Get feature importance
 booster.feature_importance
 ```

-##
-
-### Booster
+## Early Stopping

 ```ruby
-
-booster.to_json
-booster.model_to_string
-booster.current_iteration
+LightGBM.train(params, train_set, valid_set: [train_set, test_set], early_stopping_rounds: 5)
 ```

-
+## CV

 ```ruby
-
-dataset.num_data
-dataset.num_feature
-
-# note: only works with unquoted CSVs
-dataset = LightGBM::Dataset.new("data.csv", params: {headers: true, label: "name:label"})
-dataset.save_binary("train.bin")
-dataset.dump_text("train.txt")
+LightGBM.cv(params, train_set, nfold: 5, verbose_eval: true)
 ```

+## Reference
+
+This library follows the [Data Structure and Training APIs](https://lightgbm.readthedocs.io/en/latest/Python-API.html) for the Python library. A few differences are:
+
+- The default verbosity is `-1`
+- With the `cv` method, `stratified` is set to `false`
+
+Some methods and options are also missing at the moment. PRs welcome!
+
+## Helpful Resources
+
+- [Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html)
+- [Parameter Tuning](https://lightgbm.readthedocs.io/en/latest/Parameters-Tuning.html)
+
 ## Credits

 Thanks to the [xgboost](https://github.com/PairOnAir/xgboost-ruby) gem for serving as an initial reference, and Selva Prabhakaran for the [test datasets](https://github.com/selva86/datasets).
data/lib/lightgbm.rb
CHANGED
@@ -11,14 +11,174 @@ require "lightgbm/version"
 module LightGBM
   class Error < StandardError; end

-
-
-
-
+  class << self
+    def train(params, train_set, num_boost_round: 100, valid_sets: [], valid_names: [], early_stopping_rounds: nil, verbose_eval: true)
+      booster = Booster.new(params: params, train_set: train_set)
+
+      valid_contain_train = false
+      valid_sets.zip(valid_names).each_with_index do |(data, name), i|
+        if data == train_set
+          booster.train_data_name = name || "training"
+          valid_contain_train = true
+        else
+          booster.add_valid(data, name || "valid_#{i}")
+        end
+      end
+
+      booster.best_iteration = 0
+
+      if early_stopping_rounds
+        best_score = []
+        best_iter = []
+        best_message = []
+
+        puts "Training until validation scores don't improve for #{early_stopping_rounds.to_i} rounds." if verbose_eval
+      end
+
+      num_boost_round.times do |iteration|
+        booster.update
+
+        if valid_sets.any?
+          # print results
+          messages = []
+
+          if valid_contain_train
+            # not sure why reversed in output
+            booster.eval_train.reverse.each do |res|
+              messages << "%s's %s: %g" % [res[0], res[1], res[2]]
+            end
+          end
+
+          eval_valid = booster.eval_valid
+          # not sure why reversed in output
+          eval_valid.reverse.each do |res|
+            messages << "%s's %s: %g" % [res[0], res[1], res[2]]
+          end
+
+          message = "[#{iteration + 1}]\t#{messages.join("\t")}"
+
+          puts message if verbose_eval
+
+          if early_stopping_rounds
+            stop_early = false
+            eval_valid.each_with_index do |(_, _, score, higher_better), i|
+              op = higher_better ? :> : :<
+              if best_score[i].nil? || score.send(op, best_score[i])
+                best_score[i] = score
+                best_iter[i] = iteration
+                best_message[i] = message
+              elsif iteration - best_iter[i] >= early_stopping_rounds
+                booster.best_iteration = best_iter[i] + 1
+                puts "Early stopping, best iteration is:\n#{best_message[i]}" if verbose_eval
+                stop_early = true
+                break
+              end
+            end
+
+            break if stop_early
+
+            if iteration == num_boost_round - 1
+              booster.best_iteration = best_iter[0] + 1
+              puts "Did not meet early stopping. Best iteration is: #{best_message[0]}" if verbose_eval
+            end
+          end
+        end
+      end
+
+      booster
+    end
+
+    def cv(params, train_set, num_boost_round: 100, nfold: 5, seed: 0, shuffle: true, early_stopping_rounds: nil, verbose_eval: nil, show_stdv: true)
+      rand_idx = (0...train_set.num_data).to_a
+      rand_idx.shuffle!(random: Random.new(seed)) if shuffle
+
+      kstep = rand_idx.size / nfold
+      test_id = rand_idx.each_slice(kstep).to_a[0...nfold]
+      train_id = []
+      nfold.times do |i|
+        idx = test_id.dup
+        idx.delete_at(i)
+        train_id << idx.flatten
+      end
+
+      boosters = []
+      folds = train_id.zip(test_id)
+      folds.each do |(train_idx, test_idx)|
+        fold_train_set = train_set.subset(train_idx)
+        fold_valid_set = train_set.subset(test_idx)
+        booster = Booster.new(params: params, train_set: fold_train_set)
+        booster.add_valid(fold_valid_set, "valid")
+        boosters << booster
+      end
+
+      eval_hist = {}
+
+      if early_stopping_rounds
+        best_score = {}
+        best_iter = {}
+      end
+
+      num_boost_round.times do |iteration|
+        boosters.each(&:update)
+
+        scores = {}
+        boosters.map(&:eval_valid).map(&:reverse).flatten(1).each do |r|
+          (scores[r[1]] ||= []) << r[2]
+        end
+
+        message_parts = ["[#{iteration + 1}]"]
+
+        means = {}
+        scores.each do |eval_name, vals|
+          mean = mean(vals)
+          stdev = stdev(vals)
+
+          (eval_hist["#{eval_name}-mean"] ||= []) << mean
+          (eval_hist["#{eval_name}-stdv"] ||= []) << stdev
+
+          means[eval_name] = mean
+
+          if show_stdv
+            message_parts << "cv_agg's %s: %g + %g" % [eval_name, mean, stdev]
+          else
+            message_parts << "cv_agg's %s: %g" % [eval_name, mean]
+          end
+        end
+
+        puts message_parts.join("\t") if verbose_eval
+
+        if early_stopping_rounds
+          stop_early = false
+          means.each do |k, score|
+            if best_score[k].nil? || score < best_score[k]
+              best_score[k] = score
+              best_iter[k] = iteration
+            elsif iteration - best_iter[k] >= early_stopping_rounds
+              stop_early = true
+              break
+            end
+          end
+          break if stop_early
+        end
+      end
+
+      eval_hist
     end
-
-
+
+    private
+
+    def mean(arr)
+      arr.sum / arr.size.to_f
+    end
+
+    # don't subtract one from arr.size
+    def stdev(arr)
+      m = mean(arr)
+      sum = 0
+      arr.each do |v|
+        sum += (v - m) ** 2
+      end
+      Math.sqrt(sum / arr.size)
     end
-    booster
   end
 end
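For orientation, here is a minimal sketch of how the `train` and `cv` entry points added above are meant to be called. The feature arrays, labels, and parameter values below are made-up illustrations, not part of this release.

```ruby
require "lightgbm"

# made-up toy data
x_train = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
y_train = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
x_test = [[2, 3], [10, 11]]
y_test = [0.15, 0.55]

train_set = LightGBM::Dataset.new(x_train, label: y_train)
test_set = LightGBM::Dataset.new(x_test, label: y_test)

params = {objective: "regression", metric: "l2"}

# train with early stopping against the held-out set;
# per-round metrics are printed because verbose_eval defaults to true
booster = LightGBM.train(params, train_set,
  num_boost_round: 50,
  valid_sets: [train_set, test_set],
  early_stopping_rounds: 5
)

# 4-fold cross validation; returns a hash of per-iteration metric means and stdevs
eval_hist = LightGBM.cv(params, train_set, nfold: 4, verbose_eval: true)
```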
data/lib/lightgbm/booster.rb
CHANGED
@@ -1,57 +1,65 @@
 module LightGBM
   class Booster
+    attr_accessor :best_iteration, :train_data_name
+
     def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
       @handle = ::FFI::MemoryPointer.new(:pointer)
       if model_str
-
-        check_result FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, @handle)
+        model_from_string(model_str)
       elsif model_file
         out_num_iterations = ::FFI::MemoryPointer.new(:int)
         check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, @handle)
       else
+        params ||= {}
+        set_verbosity(params)
         check_result FFI.LGBM_BoosterCreate(train_set.handle_pointer, params_str(params), @handle)
       end
       # causes "Stack consistency error"
       # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
+
+      self.best_iteration = -1
+
+      # TODO get names when loaded from file
+      @name_valid_sets = []
     end

     def self.finalize(pointer)
       -> { FFI.LGBM_BoosterFree(pointer) }
     end

-    # TODO handle name
     def add_valid(data, name)
       check_result FFI.LGBM_BoosterAddValidData(handle_pointer, data.handle_pointer)
+      @name_valid_sets << name
       self # consistent with Python API
     end

-    def
-
-
-
-
-
-      data = ::FFI::MemoryPointer.new(:float, input.count * input.first.count)
-      data.put_array_of_float(0, input.flatten)
+    def current_iteration
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI::LGBM_BoosterGetCurrentIteration(handle_pointer, out)
+      out.read_int
+    end

+    def dump_model(num_iteration: nil, start_iteration: 0)
+      num_iteration ||= best_iteration
+      buffer_len = 1 << 20
       out_len = ::FFI::MemoryPointer.new(:int64)
-
-
-
-
-
-
+      out_str = ::FFI::MemoryPointer.new(:string, buffer_len)
+      check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, buffer_len, out_len, out_str)
+      actual_len = out_len.read_int64
+      if actual_len > buffer_len
+        out_str = ::FFI::MemoryPointer.new(:string, actual_len)
+        check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, actual_len, out_len, out_str)
+      end
+      out_str.read_string
     end
+    alias_method :to_json, :dump_model

-    def
-
-      self # consistent with Python API
+    def eval_valid
+      @name_valid_sets.each_with_index.map { |n, i| inner_eval(n, i + 1) }.flatten(1)
     end

-    def
-
-      check_result FFI.LGBM_BoosterUpdateOneIter(handle_pointer, finished)
-      finished.read_int == 1
+    def eval_train
+      inner_eval(train_data_name, 0)
     end

     def feature_importance(iteration: nil, importance_type: "split")
@@ -66,27 +74,16 @@ module LightGBM
         -1
       end

-
-      out_result = ::FFI::MemoryPointer.new(:double,
+      num_feature = self.num_feature
+      out_result = ::FFI::MemoryPointer.new(:double, num_feature)
       check_result FFI.LGBM_BoosterFeatureImportance(handle_pointer, iteration, importance_type, out_result)
-      out_result.read_array_of_double(
-    end
-
-    def num_features
-      out = ::FFI::MemoryPointer.new(:int)
-      check_result FFI.LGBM_BoosterGetNumFeature(handle_pointer, out)
-      out.read_int
-    end
-
-    def current_iteration
-      out = ::FFI::MemoryPointer.new(:int)
-      check_result FFI::LGBM_BoosterGetCurrentIteration(handle_pointer, out)
-      out.read_int
+      out_result.read_array_of_double(num_feature)
     end

-
-
-
+    def model_from_string(model_str)
+      out_num_iterations = ::FFI::MemoryPointer.new(:int)
+      check_result FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, @handle)
+      self
     end

     def model_to_string(num_iteration: nil, start_iteration: 0)
@@ -103,18 +100,57 @@ module LightGBM
       out_str.read_string
     end

-    def
+    def num_feature
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI.LGBM_BoosterGetNumFeature(handle_pointer, out)
+      out.read_int
+    end
+    alias_method :num_features, :num_feature # legacy typo
+
+    def num_model_per_iteration
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI::LGBM_BoosterNumModelPerIteration(handle_pointer, out)
+      out.read_int
+    end
+
+    def num_trees
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI::LGBM_BoosterNumberOfTotalModel(handle_pointer, out)
+      out.read_int
+    end
+
+    # TODO support different prediction types
+    def predict(input, num_iteration: nil, **params)
+      raise TypeError unless input.is_a?(Array)
+
+      singular = !input.first.is_a?(Array)
+      input = [input] if singular
+
       num_iteration ||= best_iteration
-
+      num_class ||= num_class()
+
+      data = ::FFI::MemoryPointer.new(:float, input.count * input.first.count)
+      data.put_array_of_float(0, input.flatten)
+
       out_len = ::FFI::MemoryPointer.new(:int64)
-
-      check_result FFI.
-
-      if
-
-
-
-
+      out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
+      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 0, input.count, input.first.count, 1, 0, num_iteration, params_str(params), out_len, out_result)
+      out = out_result.read_array_of_double(out_len.read_int64)
+      out = out.each_slice(num_class).to_a if num_class > 1
+
+      singular ? out.first : out
+    end
+
+    def save_model(filename, num_iteration: nil, start_iteration: 0)
+      num_iteration ||= best_iteration
+      check_result FFI.LGBM_BoosterSaveModel(handle_pointer, start_iteration, num_iteration, filename)
+      self # consistent with Python API
+    end
+
+    def update
+      finished = ::FFI::MemoryPointer.new(:int)
+      check_result FFI.LGBM_BoosterUpdateOneIter(handle_pointer, finished)
+      finished.read_int == 1
     end

     private
@@ -123,6 +159,42 @@ module LightGBM
       @handle.read_pointer
     end

+    def eval_counts
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI::LGBM_BoosterGetEvalCounts(handle_pointer, out)
+      out.read_int
+    end
+
+    def eval_names
+      eval_counts ||= eval_counts()
+      out_len = ::FFI::MemoryPointer.new(:int)
+      out_strs = ::FFI::MemoryPointer.new(:pointer, eval_counts)
+      str_ptrs = eval_counts.times.map { ::FFI::MemoryPointer.new(:string, 255) }
+      out_strs.put_array_of_pointer(0, str_ptrs)
+      check_result FFI.LGBM_BoosterGetEvalNames(handle_pointer, out_len, out_strs)
+      str_ptrs.map(&:read_string)
+    end
+
+    def inner_eval(name, i)
+      eval_names ||= eval_names()
+
+      out_len = ::FFI::MemoryPointer.new(:int)
+      out_results = ::FFI::MemoryPointer.new(:double, eval_names.count)
+      check_result FFI.LGBM_BoosterGetEval(handle_pointer, i, out_len, out_results)
+      vals = out_results.read_array_of_double(out_len.read_int)
+
+      eval_names.zip(vals).map do |eval_name, val|
+        higher_better = ["auc", "ndcg@", "map@"].any? { |v| eval_name.start_with?(v) }
+        [name, eval_name, val, higher_better]
+      end
+    end
+
+    def num_class
+      out = ::FFI::MemoryPointer.new(:int)
+      check_result FFI::LGBM_BoosterGetNumClasses(handle_pointer, out)
+      out.read_int
+    end
+
     include Utils
   end
 end
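A quick sketch of the Booster surface added in this release: single-row versus batch prediction, the new model introspection helpers, and round-tripping a model through a string. The `model.txt` file and the feature values are hypothetical.

```ruby
require "lightgbm"

# hypothetical model file produced by an earlier booster.save_model call
booster = LightGBM::Booster.new(model_file: "model.txt")

booster.predict([3.7, 1.2, 7.2])                     # single row -> single value
booster.predict([[3.7, 1.2, 7.2], [7.5, 0.5, 7.9]])  # batch -> array of values

booster.num_feature             # number of input features
booster.num_trees               # total number of trees in the model
booster.num_model_per_iteration # trees added per boosting round (e.g. one per class)
booster.current_iteration

json = booster.to_json          # alias for dump_model
str = booster.model_to_string
copy = LightGBM::Booster.new(model_str: str)
```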
data/lib/lightgbm/dataset.rb
CHANGED
@@ -2,16 +2,27 @@ module LightGBM
   class Dataset
     attr_reader :data, :params

-    def initialize(data, label: nil, weight: nil, params: nil)
+    def initialize(data, label: nil, weight: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto")
       @data = data

+      # TODO stringify params
+      params ||= {}
+      params["categorical_feature"] ||= categorical_feature.join(",") if categorical_feature != "auto"
+      set_verbosity(params)
+
       @handle = ::FFI::MemoryPointer.new(:pointer)
+      parameters = params_str(params)
+      reference = reference.handle_pointer if reference
       if data.is_a?(String)
-        check_result FFI.LGBM_DatasetCreateFromFile(data,
+        check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
+      elsif used_indices
+        used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
+        used_row_indices.put_array_of_int32(0, used_indices)
+        check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
       else
         c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count)
         c_data.put_array_of_float(0, data.flatten)
-        check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1,
+        check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1, parameters, reference, @handle)
       end
       # causes "Stack consistency error"
       # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
@@ -48,6 +59,16 @@ module LightGBM
       check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename)
     end

+    def subset(used_indices, params: nil)
+      # categorical_feature passed via params
+      params ||= self.params
+      Dataset.new(nil,
+        params: params,
+        reference: self,
+        used_indices: used_indices
+      )
+    end
+
     def self.finalize(pointer)
       -> { FFI.LGBM_DatasetFree(pointer) }
     end
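A small sketch of the new `subset` support, which is what the `cv` method above builds its folds from; the data values here are invented.

```ruby
require "lightgbm"

x = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [0.1, 0.4, 0.6, 0.9]
dataset = LightGBM::Dataset.new(x, label: y)

# a new Dataset backed by rows 0 and 2 of the original,
# created through LGBM_DatasetGetSubset with the original as its reference
fold = dataset.subset([0, 2])
fold.num_data # => 2
```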
data/lib/lightgbm/ffi.rb
CHANGED
@@ -12,6 +12,7 @@ module LightGBM
     # dataset
     attach_function :LGBM_DatasetCreateFromFile, %i[string string pointer pointer], :int
     attach_function :LGBM_DatasetCreateFromMat, %i[pointer int int32 int32 int string pointer pointer], :int
+    attach_function :LGBM_DatasetGetSubset, %i[pointer pointer int32 string pointer], :int
     attach_function :LGBM_DatasetFree, %i[pointer], :int
     attach_function :LGBM_DatasetSaveBinary, %i[pointer string], :int
     attach_function :LGBM_DatasetDumpText, %i[pointer string], :int
@@ -26,9 +27,15 @@ module LightGBM
     attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
     attach_function :LGBM_BoosterFree, %i[pointer], :int
     attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int
     attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int
     attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterNumModelPerIteration, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterNumberOfTotalModel, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterGetEvalCounts, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterGetEvalNames, %i[pointer pointer pointer], :int
     attach_function :LGBM_BoosterGetNumFeature, %i[pointer pointer], :int
+    attach_function :LGBM_BoosterGetEval, %i[pointer int pointer pointer], :int
     attach_function :LGBM_BoosterPredictForMat, %i[pointer pointer int int32 int32 int int int string pointer pointer], :int
     attach_function :LGBM_BoosterSaveModel, %i[pointer int int string], :int
     attach_function :LGBM_BoosterSaveModelToString, %i[pointer int int int64 pointer pointer], :int
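The newly attached functions follow the same convention as the existing bindings: each returns an int status code and writes its value through an out pointer. A rough sketch of that calling pattern, mirroring how `Booster#num_trees` above consumes `LGBM_BoosterNumberOfTotalModel`; the `model.txt` file is hypothetical, and the private booster handle is reached via `send` purely for illustration.

```ruby
require "lightgbm"

booster = LightGBM::Booster.new(model_file: "model.txt") # hypothetical model file

out = ::FFI::MemoryPointer.new(:int)
# the binding returns 0 on success and writes its result through the out pointer
status = LightGBM::FFI.LGBM_BoosterNumberOfTotalModel(booster.send(:handle_pointer), out)
raise LightGBM::Error, "call failed" unless status.zero?
out.read_int # => total number of trees, the same value Booster#num_trees returns
```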
data/lib/lightgbm/utils.rb
CHANGED
@@ -8,12 +8,20 @@ module LightGBM

     # remove spaces in keys and values to prevent injection
     def params_str(params)
-
+      params.map { |k, v| [check_param(k.to_s), check_param(Array(v).join(",").to_s)].join("=") }.join(" ")
     end

     def check_param(v)
       raise ArgumentError, "Invalid parameter" if /[[:space:]]/.match(v)
       v
     end
+
+    # change default verbosity
+    def set_verbosity(params)
+      params_keys = params.keys.map(&:to_s)
+      unless params_keys.include?("verbosity")
+        params["verbosity"] = -1
+      end
+    end
   end
 end
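To see what the reworked `params_str` produces and how `set_verbosity` fills in the new default, here is a tiny harness; the `ParamsDemo` class and the parameter hash are invented for illustration (in the gem these helpers are mixed into `Booster` and `Dataset`).

```ruby
require "lightgbm"

class ParamsDemo
  include LightGBM::Utils

  def run
    params = {max_depth: 7, metric: ["l2", "auc"]}
    set_verbosity(params)  # adds "verbosity" => -1 since no verbosity key was given
    params_str(params)     # => "max_depth=7 metric=l2,auc verbosity=-1"
  end
end

puts ParamsDemo.new.run
# params_str rejects keys or values containing whitespace, since the string is space-delimited:
# params_str({"metric" => "l 2"}) raises ArgumentError ("Invalid parameter")
```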
data/lib/lightgbm/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: lightgbm
 version: !ruby/object:Gem::Version
-  version: 0.1.1
+  version: 0.1.2
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2019-08-
+date: 2019-08-15 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: ffi