lightgbm 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +22 -18
- data/lib/lightgbm.rb +167 -7
- data/lib/lightgbm/booster.rb +124 -52
- data/lib/lightgbm/dataset.rb +24 -3
- data/lib/lightgbm/ffi.rb +7 -0
- data/lib/lightgbm/utils.rb +9 -1
- data/lib/lightgbm/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 723130d41ea9196bcbd7bcffeb865c40c65985f26eca018d49bd176d33c43142
|
4
|
+
data.tar.gz: d92b41899ff72da2ef4e5782bf4d2840caee1554107d9fd5d02bd6728829585a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6960dbf1e2a884705e8a2752952392483c0ab1e74e970382b45df747ccbedbc0d978d3910e2b935190c98a9c45315841b53e27128aa5944e3a7834808e05582a
|
7
|
+
data.tar.gz: c6e793933dc794fa62099580ad35f29d7e5e3ae24a07df4335dca3d68571bc1d5b360a7cf0859d75d245a77d470a0dafba0cb5d86edda1974f3bc532b0f5c11a
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -16,6 +16,8 @@ Add this line to your application’s Gemfile:
|
|
16
16
|
gem 'lightgbm'
|
17
17
|
```
|
18
18
|
|
19
|
+
## Getting Started
|
20
|
+
|
19
21
|
Train a model
|
20
22
|
|
21
23
|
```ruby
|
@@ -30,13 +32,13 @@ Predict
|
|
30
32
|
booster.predict(x_test)
|
31
33
|
```
|
32
34
|
|
33
|
-
Save the model
|
35
|
+
Save the model to a file
|
34
36
|
|
35
37
|
```ruby
|
36
38
|
booster.save_model("model.txt")
|
37
39
|
```
|
38
40
|
|
39
|
-
Load
|
41
|
+
Load the model from a file
|
40
42
|
|
41
43
|
```ruby
|
42
44
|
booster = LightGBM::Booster.new(model_file: "model.txt")
|
@@ -48,30 +50,32 @@ Get feature importance
|
|
48
50
|
booster.feature_importance
|
49
51
|
```
|
50
52
|
|
51
|
-
##
|
52
|
-
|
53
|
-
### Booster
|
53
|
+
## Early Stopping
|
54
54
|
|
55
55
|
```ruby
|
56
|
-
|
57
|
-
booster.to_json
|
58
|
-
booster.model_to_string
|
59
|
-
booster.current_iteration
|
56
|
+
LightGBM.train(params, train_set, valid_sets: [train_set, test_set], early_stopping_rounds: 5)
|
60
57
|
```
|
61
58
|
|
62
|
-
|
59
|
+
## CV
|
63
60
|
|
64
61
|
```ruby
|
65
|
-
|
66
|
-
dataset.num_data
|
67
|
-
dataset.num_feature
|
68
|
-
|
69
|
-
# note: only works with unquoted CSVs
|
70
|
-
dataset = LightGBM::Dataset.new("data.csv", params: {headers: true, label: "name:label"})
|
71
|
-
dataset.save_binary("train.bin")
|
72
|
-
dataset.dump_text("train.txt")
|
62
|
+
LightGBM.cv(params, train_set, nfold: 5, verbose_eval: true)
|
73
63
|
```
|
74
64
|
|
65
|
+
## Reference
|
66
|
+
|
67
|
+
This library follows the [Data Structure and Training APIs](https://lightgbm.readthedocs.io/en/latest/Python-API.html) for the Python library. A few differences are:
|
68
|
+
|
69
|
+
- The default verbosity is `-1`
|
70
|
+
- With the `cv` method, `stratified` is set to `false`
|
71
|
+
|
72
|
+
Some methods and options are also missing at the moment. PRs welcome!
|
73
|
+
|
74
|
+
## Helpful Resources
|
75
|
+
|
76
|
+
- [Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html)
|
77
|
+
- [Parameter Tuning](https://lightgbm.readthedocs.io/en/latest/Parameters-Tuning.html)
|
78
|
+
|
75
79
|
## Credits
|
76
80
|
|
77
81
|
Thanks to the [xgboost](https://github.com/PairOnAir/xgboost-ruby) gem for serving as an initial reference, and Selva Prabhakaran for the [test datasets](https://github.com/selva86/datasets).
|
data/lib/lightgbm.rb
CHANGED
@@ -11,14 +11,174 @@ require "lightgbm/version"
|
|
11
11
|
module LightGBM
|
12
12
|
class Error < StandardError; end
|
13
13
|
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
14
|
+
class << self
|
15
|
+
def train(params, train_set,num_boost_round: 100, valid_sets: [], valid_names: [], early_stopping_rounds: nil, verbose_eval: true)
|
16
|
+
booster = Booster.new(params: params, train_set: train_set)
|
17
|
+
|
18
|
+
valid_contain_train = false
|
19
|
+
valid_sets.zip(valid_names).each_with_index do |(data, name), i|
|
20
|
+
if data == train_set
|
21
|
+
booster.train_data_name = name || "training"
|
22
|
+
valid_contain_train = true
|
23
|
+
else
|
24
|
+
booster.add_valid(data, name || "valid_#{i}")
|
25
|
+
end
|
26
|
+
end
|
27
|
+
|
28
|
+
booster.best_iteration = 0
|
29
|
+
|
30
|
+
if early_stopping_rounds
|
31
|
+
best_score = []
|
32
|
+
best_iter = []
|
33
|
+
best_message = []
|
34
|
+
|
35
|
+
puts "Training until validation scores don't improve for #{early_stopping_rounds.to_i} rounds." if verbose_eval
|
36
|
+
end
|
37
|
+
|
38
|
+
num_boost_round.times do |iteration|
|
39
|
+
booster.update
|
40
|
+
|
41
|
+
if valid_sets.any?
|
42
|
+
# print results
|
43
|
+
messages = []
|
44
|
+
|
45
|
+
if valid_contain_train
|
46
|
+
# not sure why reversed in output
|
47
|
+
booster.eval_train.reverse.each do |res|
|
48
|
+
messages << "%s's %s: %g" % [res[0], res[1], res[2]]
|
49
|
+
end
|
50
|
+
end
|
51
|
+
|
52
|
+
eval_valid = booster.eval_valid
|
53
|
+
# not sure why reversed in output
|
54
|
+
eval_valid.reverse.each do |res|
|
55
|
+
messages << "%s's %s: %g" % [res[0], res[1], res[2]]
|
56
|
+
end
|
57
|
+
|
58
|
+
message = "[#{iteration + 1}]\t#{messages.join("\t")}"
|
59
|
+
|
60
|
+
puts message if verbose_eval
|
61
|
+
|
62
|
+
if early_stopping_rounds
|
63
|
+
stop_early = false
|
64
|
+
eval_valid.each_with_index do |(_, _, score, higher_better), i|
|
65
|
+
op = higher_better ? :> : :<
|
66
|
+
if best_score[i].nil? || score.send(op, best_score[i])
|
67
|
+
best_score[i] = score
|
68
|
+
best_iter[i] = iteration
|
69
|
+
best_message[i] = message
|
70
|
+
elsif iteration - best_iter[i] >= early_stopping_rounds
|
71
|
+
booster.best_iteration = best_iter[i] + 1
|
72
|
+
puts "Early stopping, best iteration is:\n#{best_message[i]}" if verbose_eval
|
73
|
+
stop_early = true
|
74
|
+
break
|
75
|
+
end
|
76
|
+
end
|
77
|
+
|
78
|
+
break if stop_early
|
79
|
+
|
80
|
+
if iteration == num_boost_round - 1
|
81
|
+
booster.best_iteration = best_iter[0] + 1
|
82
|
+
puts "Did not meet early stopping. Best iteration is: #{best_message[0]}" if verbose_eval
|
83
|
+
end
|
84
|
+
end
|
85
|
+
end
|
86
|
+
end
|
87
|
+
|
88
|
+
booster
|
89
|
+
end
|
90
|
+
|
91
|
+
def cv(params, train_set, num_boost_round: 100, nfold: 5, seed: 0, shuffle: true, early_stopping_rounds: nil, verbose_eval: nil, show_stdv: true)
|
92
|
+
rand_idx = (0...train_set.num_data).to_a
|
93
|
+
rand_idx.shuffle!(random: Random.new(seed)) if shuffle
|
94
|
+
|
95
|
+
kstep = rand_idx.size / nfold
|
96
|
+
test_id = rand_idx.each_slice(kstep).to_a[0...nfold]
|
97
|
+
train_id = []
|
98
|
+
nfold.times do |i|
|
99
|
+
idx = test_id.dup
|
100
|
+
idx.delete_at(i)
|
101
|
+
train_id << idx.flatten
|
102
|
+
end
|
103
|
+
|
104
|
+
boosters = []
|
105
|
+
folds = train_id.zip(test_id)
|
106
|
+
folds.each do |(train_idx, test_idx)|
|
107
|
+
fold_train_set = train_set.subset(train_idx)
|
108
|
+
fold_valid_set = train_set.subset(test_idx)
|
109
|
+
booster = Booster.new(params: params, train_set: fold_train_set)
|
110
|
+
booster.add_valid(fold_valid_set, "valid")
|
111
|
+
boosters << booster
|
112
|
+
end
|
113
|
+
|
114
|
+
eval_hist = {}
|
115
|
+
|
116
|
+
if early_stopping_rounds
|
117
|
+
best_score = {}
|
118
|
+
best_iter = {}
|
119
|
+
end
|
120
|
+
|
121
|
+
num_boost_round.times do |iteration|
|
122
|
+
boosters.each(&:update)
|
123
|
+
|
124
|
+
scores = {}
|
125
|
+
boosters.map(&:eval_valid).map(&:reverse).flatten(1).each do |r|
|
126
|
+
(scores[r[1]] ||= []) << r[2]
|
127
|
+
end
|
128
|
+
|
129
|
+
message_parts = ["[#{iteration + 1}]"]
|
130
|
+
|
131
|
+
means = {}
|
132
|
+
scores.each do |eval_name, vals|
|
133
|
+
mean = mean(vals)
|
134
|
+
stdev = stdev(vals)
|
135
|
+
|
136
|
+
(eval_hist["#{eval_name}-mean"] ||= []) << mean
|
137
|
+
(eval_hist["#{eval_name}-stdv"] ||= []) << stdev
|
138
|
+
|
139
|
+
means[eval_name] = mean
|
140
|
+
|
141
|
+
if show_stdv
|
142
|
+
message_parts << "cv_agg's %s: %g + %g" % [eval_name, mean, stdev]
|
143
|
+
else
|
144
|
+
message_parts << "cv_agg's %s: %g" % [eval_name, mean]
|
145
|
+
end
|
146
|
+
end
|
147
|
+
|
148
|
+
puts message_parts.join("\t") if verbose_eval
|
149
|
+
|
150
|
+
if early_stopping_rounds
|
151
|
+
stop_early = false
|
152
|
+
means.each do |k, score|
|
153
|
+
if best_score[k].nil? || score < best_score[k]
|
154
|
+
best_score[k] = score
|
155
|
+
best_iter[k] = iteration
|
156
|
+
elsif iteration - best_iter[k] >= early_stopping_rounds
|
157
|
+
stop_early = true
|
158
|
+
break
|
159
|
+
end
|
160
|
+
end
|
161
|
+
break if stop_early
|
162
|
+
end
|
163
|
+
end
|
164
|
+
|
165
|
+
eval_hist
|
18
166
|
end
|
19
|
-
|
20
|
-
|
167
|
+
|
168
|
+
private
|
169
|
+
|
170
|
+
def mean(arr)
|
171
|
+
arr.sum / arr.size.to_f
|
172
|
+
end
|
173
|
+
|
174
|
+
# don't subtract one from arr.size
|
175
|
+
def stdev(arr)
|
176
|
+
m = mean(arr)
|
177
|
+
sum = 0
|
178
|
+
arr.each do |v|
|
179
|
+
sum += (v - m) ** 2
|
180
|
+
end
|
181
|
+
Math.sqrt(sum / arr.size)
|
21
182
|
end
|
22
|
-
booster
|
23
183
|
end
|
24
184
|
end
|
data/lib/lightgbm/booster.rb
CHANGED
@@ -1,57 +1,65 @@
|
|
1
1
|
module LightGBM
|
2
2
|
class Booster
|
3
|
+
attr_accessor :best_iteration, :train_data_name
|
4
|
+
|
3
5
|
def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
|
4
6
|
@handle = ::FFI::MemoryPointer.new(:pointer)
|
5
7
|
if model_str
|
6
|
-
|
7
|
-
check_result FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, @handle)
|
8
|
+
model_from_string(model_str)
|
8
9
|
elsif model_file
|
9
10
|
out_num_iterations = ::FFI::MemoryPointer.new(:int)
|
10
11
|
check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, @handle)
|
11
12
|
else
|
13
|
+
params ||= {}
|
14
|
+
set_verbosity(params)
|
12
15
|
check_result FFI.LGBM_BoosterCreate(train_set.handle_pointer, params_str(params), @handle)
|
13
16
|
end
|
14
17
|
# causes "Stack consistency error"
|
15
18
|
# ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
|
19
|
+
|
20
|
+
self.best_iteration = -1
|
21
|
+
|
22
|
+
# TODO get names when loaded from file
|
23
|
+
@name_valid_sets = []
|
16
24
|
end
|
17
25
|
|
18
26
|
def self.finalize(pointer)
|
19
27
|
-> { FFI.LGBM_BoosterFree(pointer) }
|
20
28
|
end
|
21
29
|
|
22
|
-
# TODO handle name
|
23
30
|
def add_valid(data, name)
|
24
31
|
check_result FFI.LGBM_BoosterAddValidData(handle_pointer, data.handle_pointer)
|
32
|
+
@name_valid_sets << name
|
25
33
|
self # consistent with Python API
|
26
34
|
end
|
27
35
|
|
28
|
-
def
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
data = ::FFI::MemoryPointer.new(:float, input.count * input.first.count)
|
35
|
-
data.put_array_of_float(0, input.flatten)
|
36
|
+
def current_iteration
|
37
|
+
out = ::FFI::MemoryPointer.new(:int)
|
38
|
+
check_result FFI::LGBM_BoosterGetCurrentIteration(handle_pointer, out)
|
39
|
+
out.read_int
|
40
|
+
end
|
36
41
|
|
42
|
+
def dump_model(num_iteration: nil, start_iteration: 0)
|
43
|
+
num_iteration ||= best_iteration
|
44
|
+
buffer_len = 1 << 20
|
37
45
|
out_len = ::FFI::MemoryPointer.new(:int64)
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
46
|
+
out_str = ::FFI::MemoryPointer.new(:string, buffer_len)
|
47
|
+
check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, buffer_len, out_len, out_str)
|
48
|
+
actual_len = out_len.read_int64
|
49
|
+
if actual_len > buffer_len
|
50
|
+
out_str = ::FFI::MemoryPointer.new(:string, actual_len)
|
51
|
+
check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, actual_len, out_len, out_str)
|
52
|
+
end
|
53
|
+
out_str.read_string
|
44
54
|
end
|
55
|
+
alias_method :to_json, :dump_model
|
45
56
|
|
46
|
-
def
|
47
|
-
|
48
|
-
self # consistent with Python API
|
57
|
+
def eval_valid
|
58
|
+
@name_valid_sets.each_with_index.map { |n, i| inner_eval(n, i + 1) }.flatten(1)
|
49
59
|
end
|
50
60
|
|
51
|
-
def
|
52
|
-
|
53
|
-
check_result FFI.LGBM_BoosterUpdateOneIter(handle_pointer, finished)
|
54
|
-
finished.read_int == 1
|
61
|
+
def eval_train
|
62
|
+
inner_eval(train_data_name, 0)
|
55
63
|
end
|
56
64
|
|
57
65
|
def feature_importance(iteration: nil, importance_type: "split")
|
@@ -66,27 +74,16 @@ module LightGBM
|
|
66
74
|
-1
|
67
75
|
end
|
68
76
|
|
69
|
-
|
70
|
-
out_result = ::FFI::MemoryPointer.new(:double,
|
77
|
+
num_feature = self.num_feature
|
78
|
+
out_result = ::FFI::MemoryPointer.new(:double, num_feature)
|
71
79
|
check_result FFI.LGBM_BoosterFeatureImportance(handle_pointer, iteration, importance_type, out_result)
|
72
|
-
out_result.read_array_of_double(
|
73
|
-
end
|
74
|
-
|
75
|
-
def num_features
|
76
|
-
out = ::FFI::MemoryPointer.new(:int)
|
77
|
-
check_result FFI.LGBM_BoosterGetNumFeature(handle_pointer, out)
|
78
|
-
out.read_int
|
79
|
-
end
|
80
|
-
|
81
|
-
def current_iteration
|
82
|
-
out = ::FFI::MemoryPointer.new(:int)
|
83
|
-
check_result FFI::LGBM_BoosterGetCurrentIteration(handle_pointer, out)
|
84
|
-
out.read_int
|
80
|
+
out_result.read_array_of_double(num_feature)
|
85
81
|
end
|
86
82
|
|
87
|
-
|
88
|
-
|
89
|
-
|
83
|
+
def model_from_string(model_str)
|
84
|
+
out_num_iterations = ::FFI::MemoryPointer.new(:int)
|
85
|
+
check_result FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, @handle)
|
86
|
+
self
|
90
87
|
end
|
91
88
|
|
92
89
|
def model_to_string(num_iteration: nil, start_iteration: 0)
|
@@ -103,18 +100,57 @@ module LightGBM
|
|
103
100
|
out_str.read_string
|
104
101
|
end
|
105
102
|
|
106
|
-
def
|
103
|
+
def num_feature
|
104
|
+
out = ::FFI::MemoryPointer.new(:int)
|
105
|
+
check_result FFI.LGBM_BoosterGetNumFeature(handle_pointer, out)
|
106
|
+
out.read_int
|
107
|
+
end
|
108
|
+
alias_method :num_features, :num_feature # legacy typo
|
109
|
+
|
110
|
+
def num_model_per_iteration
|
111
|
+
out = ::FFI::MemoryPointer.new(:int)
|
112
|
+
check_result FFI::LGBM_BoosterNumModelPerIteration(handle_pointer, out)
|
113
|
+
out.read_int
|
114
|
+
end
|
115
|
+
|
116
|
+
def num_trees
|
117
|
+
out = ::FFI::MemoryPointer.new(:int)
|
118
|
+
check_result FFI::LGBM_BoosterNumberOfTotalModel(handle_pointer, out)
|
119
|
+
out.read_int
|
120
|
+
end
|
121
|
+
|
122
|
+
# TODO support different prediction types
|
123
|
+
def predict(input, num_iteration: nil, **params)
|
124
|
+
raise TypeError unless input.is_a?(Array)
|
125
|
+
|
126
|
+
singular = !input.first.is_a?(Array)
|
127
|
+
input = [input] if singular
|
128
|
+
|
107
129
|
num_iteration ||= best_iteration
|
108
|
-
|
130
|
+
num_class ||= num_class()
|
131
|
+
|
132
|
+
data = ::FFI::MemoryPointer.new(:float, input.count * input.first.count)
|
133
|
+
data.put_array_of_float(0, input.flatten)
|
134
|
+
|
109
135
|
out_len = ::FFI::MemoryPointer.new(:int64)
|
110
|
-
|
111
|
-
check_result FFI.
|
112
|
-
|
113
|
-
if
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
136
|
+
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
|
137
|
+
check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 0, input.count, input.first.count, 1, 0, num_iteration, params_str(params), out_len, out_result)
|
138
|
+
out = out_result.read_array_of_double(out_len.read_int64)
|
139
|
+
out = out.each_slice(num_class).to_a if num_class > 1
|
140
|
+
|
141
|
+
singular ? out.first : out
|
142
|
+
end
|
143
|
+
|
144
|
+
def save_model(filename, num_iteration: nil, start_iteration: 0)
|
145
|
+
num_iteration ||= best_iteration
|
146
|
+
check_result FFI.LGBM_BoosterSaveModel(handle_pointer, start_iteration, num_iteration, filename)
|
147
|
+
self # consistent with Python API
|
148
|
+
end
|
149
|
+
|
150
|
+
def update
|
151
|
+
finished = ::FFI::MemoryPointer.new(:int)
|
152
|
+
check_result FFI.LGBM_BoosterUpdateOneIter(handle_pointer, finished)
|
153
|
+
finished.read_int == 1
|
118
154
|
end
|
119
155
|
|
120
156
|
private
|
@@ -123,6 +159,42 @@ module LightGBM
|
|
123
159
|
@handle.read_pointer
|
124
160
|
end
|
125
161
|
|
162
|
+
def eval_counts
|
163
|
+
out = ::FFI::MemoryPointer.new(:int)
|
164
|
+
check_result FFI::LGBM_BoosterGetEvalCounts(handle_pointer, out)
|
165
|
+
out.read_int
|
166
|
+
end
|
167
|
+
|
168
|
+
def eval_names
|
169
|
+
eval_counts ||= eval_counts()
|
170
|
+
out_len = ::FFI::MemoryPointer.new(:int)
|
171
|
+
out_strs = ::FFI::MemoryPointer.new(:pointer, eval_counts)
|
172
|
+
str_ptrs = eval_counts.times.map { ::FFI::MemoryPointer.new(:string, 255) }
|
173
|
+
out_strs.put_array_of_pointer(0, str_ptrs)
|
174
|
+
check_result FFI.LGBM_BoosterGetEvalNames(handle_pointer, out_len, out_strs)
|
175
|
+
str_ptrs.map(&:read_string)
|
176
|
+
end
|
177
|
+
|
178
|
+
def inner_eval(name, i)
|
179
|
+
eval_names ||= eval_names()
|
180
|
+
|
181
|
+
out_len = ::FFI::MemoryPointer.new(:int)
|
182
|
+
out_results = ::FFI::MemoryPointer.new(:double, eval_names.count)
|
183
|
+
check_result FFI.LGBM_BoosterGetEval(handle_pointer, i, out_len, out_results)
|
184
|
+
vals = out_results.read_array_of_double(out_len.read_int)
|
185
|
+
|
186
|
+
eval_names.zip(vals).map do |eval_name, val|
|
187
|
+
higher_better = ["auc", "ndcg@", "map@"].any? { |v| eval_name.start_with?(v) }
|
188
|
+
[name, eval_name, val, higher_better]
|
189
|
+
end
|
190
|
+
end
|
191
|
+
|
192
|
+
def num_class
|
193
|
+
out = ::FFI::MemoryPointer.new(:int)
|
194
|
+
check_result FFI::LGBM_BoosterGetNumClasses(handle_pointer, out)
|
195
|
+
out.read_int
|
196
|
+
end
|
197
|
+
|
126
198
|
include Utils
|
127
199
|
end
|
128
200
|
end
|
data/lib/lightgbm/dataset.rb
CHANGED
@@ -2,16 +2,27 @@ module LightGBM
|
|
2
2
|
class Dataset
|
3
3
|
attr_reader :data, :params
|
4
4
|
|
5
|
-
def initialize(data, label: nil, weight: nil, params: nil)
|
5
|
+
def initialize(data, label: nil, weight: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto")
|
6
6
|
@data = data
|
7
7
|
|
8
|
+
# TODO stringify params
|
9
|
+
params ||= {}
|
10
|
+
params["categorical_feature"] ||= categorical_feature.join(",") if categorical_feature != "auto"
|
11
|
+
set_verbosity(params)
|
12
|
+
|
8
13
|
@handle = ::FFI::MemoryPointer.new(:pointer)
|
14
|
+
parameters = params_str(params)
|
15
|
+
reference = reference.handle_pointer if reference
|
9
16
|
if data.is_a?(String)
|
10
|
-
check_result FFI.LGBM_DatasetCreateFromFile(data,
|
17
|
+
check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
|
18
|
+
elsif used_indices
|
19
|
+
used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
|
20
|
+
used_row_indices.put_array_of_int32(0, used_indices)
|
21
|
+
check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
|
11
22
|
else
|
12
23
|
c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count)
|
13
24
|
c_data.put_array_of_float(0, data.flatten)
|
14
|
-
check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1,
|
25
|
+
check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1, parameters, reference, @handle)
|
15
26
|
end
|
16
27
|
# causes "Stack consistency error"
|
17
28
|
# ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
|
@@ -48,6 +59,16 @@ module LightGBM
|
|
48
59
|
check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename)
|
49
60
|
end
|
50
61
|
|
62
|
+
def subset(used_indices, params: nil)
|
63
|
+
# categorical_feature passed via params
|
64
|
+
params ||= self.params
|
65
|
+
Dataset.new(nil,
|
66
|
+
params: params,
|
67
|
+
reference: self,
|
68
|
+
used_indices: used_indices
|
69
|
+
)
|
70
|
+
end
|
71
|
+
|
51
72
|
def self.finalize(pointer)
|
52
73
|
-> { FFI.LGBM_DatasetFree(pointer) }
|
53
74
|
end
|
data/lib/lightgbm/ffi.rb
CHANGED
@@ -12,6 +12,7 @@ module LightGBM
|
|
12
12
|
# dataset
|
13
13
|
attach_function :LGBM_DatasetCreateFromFile, %i[string string pointer pointer], :int
|
14
14
|
attach_function :LGBM_DatasetCreateFromMat, %i[pointer int int32 int32 int string pointer pointer], :int
|
15
|
+
attach_function :LGBM_DatasetGetSubset, %i[pointer pointer int32 string pointer], :int
|
15
16
|
attach_function :LGBM_DatasetFree, %i[pointer], :int
|
16
17
|
attach_function :LGBM_DatasetSaveBinary, %i[pointer string], :int
|
17
18
|
attach_function :LGBM_DatasetDumpText, %i[pointer string], :int
|
@@ -26,9 +27,15 @@ module LightGBM
|
|
26
27
|
attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
|
27
28
|
attach_function :LGBM_BoosterFree, %i[pointer], :int
|
28
29
|
attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
|
30
|
+
attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int
|
29
31
|
attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int
|
30
32
|
attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int
|
33
|
+
attach_function :LGBM_BoosterNumModelPerIteration, %i[pointer pointer], :int
|
34
|
+
attach_function :LGBM_BoosterNumberOfTotalModel, %i[pointer pointer], :int
|
35
|
+
attach_function :LGBM_BoosterGetEvalCounts, %i[pointer pointer], :int
|
36
|
+
attach_function :LGBM_BoosterGetEvalNames, %i[pointer pointer pointer], :int
|
31
37
|
attach_function :LGBM_BoosterGetNumFeature, %i[pointer pointer], :int
|
38
|
+
attach_function :LGBM_BoosterGetEval, %i[pointer int pointer pointer], :int
|
32
39
|
attach_function :LGBM_BoosterPredictForMat, %i[pointer pointer int int32 int32 int int int string pointer pointer], :int
|
33
40
|
attach_function :LGBM_BoosterSaveModel, %i[pointer int int string], :int
|
34
41
|
attach_function :LGBM_BoosterSaveModelToString, %i[pointer int int int64 pointer pointer], :int
|
data/lib/lightgbm/utils.rb
CHANGED
@@ -8,12 +8,20 @@ module LightGBM
|
|
8
8
|
|
9
9
|
# remove spaces in keys and values to prevent injection
|
10
10
|
def params_str(params)
|
11
|
-
|
11
|
+
params.map { |k, v| [check_param(k.to_s), check_param(Array(v).join(",").to_s)].join("=") }.join(" ")
|
12
12
|
end
|
13
13
|
|
14
14
|
def check_param(v)
|
15
15
|
raise ArgumentError, "Invalid parameter" if /[[:space:]]/.match(v)
|
16
16
|
v
|
17
17
|
end
|
18
|
+
|
19
|
+
# change default verbosity
|
20
|
+
def set_verbosity(params)
|
21
|
+
params_keys = params.keys.map(&:to_s)
|
22
|
+
unless params_keys.include?("verbosity")
|
23
|
+
params["verbosity"] = -1
|
24
|
+
end
|
25
|
+
end
|
18
26
|
end
|
19
27
|
end
|
data/lib/lightgbm/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: lightgbm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.1
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-08-
|
11
|
+
date: 2019-08-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: ffi
|