xgb 0.1.0 → 0.1.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: dcb0666c9eb0943a77ccdbaa31074b8ca0b139ac196006b5cd6d680f6f91ac30
- data.tar.gz: 673040af645aafc5f14378121b9dac594b5dc6dae0449f75244cc9cfeaaa10e3
+ metadata.gz: 1bb50395d579da91b18754bc75e780cbb2e98fd7a48a17c34514230d1c4828d1
+ data.tar.gz: 3d2f9c5a72c63c2622a973805c9f2caa9bd4de7b5c67f8c4b5445fd9a71993c3
  SHA512:
- metadata.gz: 3f2838ee9f2b69ea69fdc9f3610a36f412acb8732f726478995dad0ebcf73325f99c6c3afc480ae60458035619c83f9e2c9a03d60020cf8e734ea250334541c0
- data.tar.gz: 69034ec7cad4174837cc8347ab5cc13144532701927aa702eb93c6819ddf8ffe6d3c5c3752b39de863178e7d404cef017ddd82b7fe811aa4d5f3b62aa051729d
+ metadata.gz: f141b3ea0b6ceb8549198fd6ad8a07f6947201409478fc4829fe625da376e40d8028427a5aa34191b565aa275d27bb03e2082bb8fc489f6da6a2a09b3bbf2c2f
+ data.tar.gz: c393f4fdbe240ffc14b64f22f17d149ed393070fb0752f9ec49dd94bcfa88f446ea21bc5bf9a96bef7759c5c47033dd0480a4001d477b8c487cf5dcf8be19b81
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
+ ## 0.1.1
+
+ - Added Scikit-Learn API
+ - Added early stopping
+ - Added `cv` method
+ - Added support for Daru and Numo::NArray
+ - Added many other methods
+ - Fixed shape of multiclass predictions when loaded from file
+
  ## 0.1.0
 
  - First release
data/README.md CHANGED
@@ -4,6 +4,8 @@
 
  :fire: Uses the C API for blazing performance
 
+ [![Build Status](https://travis-ci.org/ankane/xgb.svg?branch=master)](https://travis-ci.org/ankane/xgb)
+
  ## Installation
 
  First, [install XGBoost](https://xgboost.readthedocs.io/en/latest/build.html). On Mac, copy `lib/libxgboost.dylib` to `/usr/local/lib`.
@@ -16,12 +18,16 @@ gem 'xgb'
 
  ## Getting Started
 
+ This library follows the [Core Data Structure, Learning and Scikit-Learn APIs](https://xgboost.readthedocs.io/en/latest/python/python_api.html) of the Python library. Some methods and options are missing at the moment. PRs welcome!
+
+ ## Learning API
+
  Train a model
 
  ```ruby
  params = {objective: "reg:squarederror"}
- train_set = Xgb::DMatrix.new(x_train, label: y_train)
- booster = Xgb.train(params, train_set)
+ dtrain = Xgb::DMatrix.new(x_train, label: y_train)
+ booster = Xgb.train(params, dtrain)
  ```
 
  Predict
@@ -33,18 +39,96 @@ booster.predict(x_test)
  Save the model to a file
 
  ```ruby
- booster.save_model("model.txt")
+ booster.save_model("my.model")
  ```
 
  Load the model from a file
 
  ```ruby
- booster = Xgb::Booster.new(model_file: "model.txt")
+ booster = Xgb::Booster.new(model_file: "my.model")
+ ```
+
+ Get the importance of features
+
+ ```ruby
+ booster.score
+ ```
+
+ Early stopping
+
+ ```ruby
+ Xgb.train(params, dtrain, evals: [[dtrain, "train"], [dtest, "eval"]], early_stopping_rounds: 5)
+ ```
+
+ CV
+
+ ```ruby
+ Xgb.cv(params, dtrain, nfold: 3, verbose_eval: true)
  ```
 
- ## Reference
+ ## Scikit-Learn API
 
- This library follows the [Core Data Structure and Learning APIs](https://xgboost.readthedocs.io/en/latest/python/python_api.html) for the Python library. Some methods and options are missing at the moment. PRs welcome!
+ Prep your data
+
+ ```ruby
+ x = [[1, 2], [3, 4], [5, 6], [7, 8]]
+ y = [1, 2, 3, 4]
+ ```
+
+ Train a model
+
+ ```ruby
+ model = Xgb::Regressor.new
+ model.fit(x, y)
+ ```
+
+ > For classification, use `Xgb::Classifier`
+
+ Predict
+
+ ```ruby
+ model.predict(x)
+ ```
+
+ > For classification, use `predict_proba` for probabilities
+
+ Save the model to a file
+
+ ```ruby
+ model.save_model("my.model")
+ ```
+
+ Load the model from a file
+
+ ```ruby
+ model.load_model("my.model")
+ ```
+
+ Get the importance of features
+
+ ```ruby
+ model.feature_importances
+ ```
+
+ ## Data
+
+ Data can be an array of arrays
+
+ ```ruby
+ [[1, 2, 3], [4, 5, 6]]
+ ```
+
+ Or a Daru data frame
+
+ ```ruby
+ Daru::DataFrame.from_csv("houses.csv")
+ ```
+
+ Or a Numo NArray
+
+ ```ruby
+ Numo::DFloat.new(3, 2).seq
+ ```
 
  ## Helpful Resources
 
data/lib/xgb.rb CHANGED
@@ -8,19 +8,157 @@ require "xgb/dmatrix"
  require "xgb/ffi"
  require "xgb/version"
 
+ # scikit-learn API
+ require "xgb/classifier"
+ require "xgb/regressor"
+
  module Xgb
  class Error < StandardError; end
 
  class << self
- def train(params, dtrain, num_boost_round: 10)
+ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds: nil, verbose_eval: true)
  booster = Booster.new(params: params)
- booster.set_param("num_feature", dtrain.num_col)
+ num_feature = dtrain.num_col
+ booster.set_param("num_feature", num_feature)
+ booster.feature_names = num_feature.times.map { |i| "f#{i}" }
+ evals ||= []
+
+ if early_stopping_rounds
+ best_score = nil
+ best_iter = nil
+ best_message = nil
+ end
 
  num_boost_round.times do |iteration|
  booster.update(dtrain, iteration)
+
+ if evals.any?
+ message = booster.eval_set(evals, iteration)
+ res = message.split.map { |x| x.split(":") }[1..-1].map { |k, v| [k, v.to_f] }
+
+ if early_stopping_rounds && iteration == 0
+ metric = res[-1][0]
+ puts "Will train until #{metric} hasn't improved in #{early_stopping_rounds.to_i} rounds." if verbose_eval
+ end
+
+ puts message if verbose_eval
+ score = res[-1][1]
+
+ # TODO handle larger better
+ if best_score.nil? || score < best_score
+ best_score = score
+ best_iter = iteration
+ best_message = message
+ elsif iteration - best_iter >= early_stopping_rounds
+ booster.best_iteration = best_iter
+ puts "Stopping. Best iteration:\n#{best_message}" if verbose_eval
+ break
+ end
+ end
  end
 
  booster
  end
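
The early-stopping logic above keys off the trailing metric in each eval message. A minimal standalone sketch of that parsing, assuming the `[iter]\tname-metric:value` format XGBoost emits (the message string here is illustrative):

```ruby
# Illustrative eval message; real ones come from booster.eval_set
message = "[3]\ttrain-rmse:0.28\teval-rmse:0.57"

# Drop the "[3]" iteration tag, then split each name-metric:value pair
res = message.split.map { |x| x.split(":") }[1..-1].map { |k, v| [k, v.to_f] }
# => [["train-rmse", 0.28], ["eval-rmse", 0.57]]

res[-1][0] # => "eval-rmse" - the metric early stopping watches
res[-1][1] # => 0.57, the score compared against best_score
```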
+
+ def cv(params, dtrain, num_boost_round: 10, nfold: 3, seed: 0, shuffle: true, verbose_eval: nil, show_stdv: true, early_stopping_rounds: nil)
+ rand_idx = (0...dtrain.num_row).to_a
+ rand_idx.shuffle!(random: Random.new(seed)) if shuffle
+
+ kstep = (rand_idx.size / nfold.to_f).ceil
+ test_id = rand_idx.each_slice(kstep).to_a[0...nfold]
+ train_id = []
+ nfold.times do |i|
+ idx = test_id.dup
+ idx.delete_at(i)
+ train_id << idx.flatten
+ end
+
+ folds = train_id.zip(test_id)
+ cvfolds = []
+ folds.each do |(train_idx, test_idx)|
+ fold_dtrain = dtrain.slice(train_idx)
+ fold_dvalid = dtrain.slice(test_idx)
+ booster = Booster.new(params: params)
+ booster.set_param("num_feature", dtrain.num_col)
+ cvfolds << [booster, fold_dtrain, fold_dvalid]
+ end
+
+ eval_hist = {}
+
+ if early_stopping_rounds
+ best_score = nil
+ best_iter = nil
+ end
+
+ num_boost_round.times do |iteration|
+ scores = {}
+
+ cvfolds.each do |(booster, fold_dtrain, fold_dvalid)|
+ booster.update(fold_dtrain, iteration)
+ message = booster.eval_set([[fold_dtrain, "train"], [fold_dvalid, "test"]], iteration)
+
+ res = message.split.map { |x| x.split(":") }[1..-1].map { |k, v| [k, v.to_f] }
+ res.each do |k, v|
+ (scores[k] ||= []) << v
+ end
+ end
+
+ message_parts = ["[#{iteration}]"]
+
+ last_mean = nil
+ means = {}
+ scores.each do |eval_name, vals|
+ mean = mean(vals)
+ stdev = stdev(vals)
+
+ (eval_hist["#{eval_name}-mean"] ||= []) << mean
+ (eval_hist["#{eval_name}-std"] ||= []) << stdev
+
+ means[eval_name] = mean
+ last_mean = mean
+
+ if show_stdv
+ message_parts << "%s:%g+%g" % [eval_name, mean, stdev]
+ else
+ message_parts << "%s:%g" % [eval_name, mean]
+ end
+ end
+
+ if early_stopping_rounds
+ score = last_mean
+ # TODO handle larger better
+ if best_score.nil? || score < best_score
+ best_score = score
+ best_iter = iteration
+ elsif iteration - best_iter >= early_stopping_rounds
+ eval_hist.each_key do |k|
+ eval_hist[k] = eval_hist[k][0..best_iter]
+ end
+ break
+ end
+ end
+
+ # put at end to keep output consistent with Python
+ puts message_parts.join("\t") if verbose_eval
+ end
+
+ eval_hist
+ end
+
+ private
+
+ def mean(arr)
+ arr.sum / arr.size.to_f
+ end
+
+ # don't subtract one from arr.size
+ def stdev(arr)
+ m = mean(arr)
+ sum = 0
+ arr.each do |v|
+ sum += (v - m) ** 2
+ end
+ Math.sqrt(sum / arr.size)
+ end
  end
  end
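
For reference, a hedged sketch of what calling the new `cv` method returns, assuming the rmse metric and the `train`/`test` labels hard-coded above (the data is illustrative):

```ruby
params = {objective: "reg:squarederror"}
dtrain = Xgb::DMatrix.new([[1, 2], [3, 4], [5, 6], [7, 8]], label: [1, 2, 3, 4])

eval_hist = Xgb.cv(params, dtrain, num_boost_round: 3, nfold: 2)
# A hash of per-iteration means and (population) standard deviations
# across folds, e.g.:
# {
#   "train-rmse-mean" => [...], "train-rmse-std" => [...],
#   "test-rmse-mean"  => [...], "test-rmse-std"  => [...]
# }
```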
data/lib/xgb/booster.rb CHANGED
@@ -1,5 +1,7 @@
  module Xgb
  class Booster
+ attr_accessor :best_iteration, :feature_names
+
  def initialize(params: nil, model_file: nil)
  @handle = ::FFI::MemoryPointer.new(:pointer)
  check_result FFI.XGBoosterCreate(nil, 0, @handle)
@@ -7,14 +9,28 @@ module Xgb
  check_result FFI.XGBoosterLoadModel(handle_pointer, model_file)
  end
 
+ self.best_iteration = 0
  set_param(params)
- @num_class = (params && params[:num_class]) || 1
  end
 
  def update(dtrain, iteration)
  check_result FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
  end
 
+ def eval_set(evals, iteration)
+ dmats = ::FFI::MemoryPointer.new(:pointer, evals.size)
+ dmats.write_array_of_pointer(evals.map { |v| v[0].handle_pointer })
+
+ evnames = ::FFI::MemoryPointer.new(:pointer, evals.size)
+ evnames.write_array_of_pointer(evals.map { |v| ::FFI::MemoryPointer.from_string(v[1]) })
+
+ out_result = ::FFI::MemoryPointer.new(:pointer)
+
+ check_result FFI.XGBoosterEvalOneIter(handle_pointer, iteration, dmats, evnames, evals.size, out_result)
+
+ out_result.read_pointer.read_string
+ end
+
  def set_param(params, value = nil)
  if params.is_a?(Enumerable)
  params.each do |k, v|
@@ -27,11 +43,12 @@ module Xgb
 
  def predict(data, ntree_limit: nil)
  ntree_limit ||= 0
- out_len = ::FFI::MemoryPointer.new(:long)
+ out_len = ::FFI::MemoryPointer.new(:ulong)
  out_result = ::FFI::MemoryPointer.new(:pointer)
  check_result FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit, out_len, out_result)
- out = out_result.read_pointer.read_array_of_float(out_len.read_long)
- out = out.each_slice(@num_class).to_a if @num_class > 1
+ out = out_result.read_pointer.read_array_of_float(out_len.read_ulong)
+ num_class = out.size / data.num_row
+ out = out.each_slice(num_class).to_a if num_class > 1
  out
  end
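
The `predict` change above infers the class count from the output size instead of from `num_class` in the params, which is what fixes the shape of multiclass predictions when the model is loaded from a file (where the params are unknown). A sketch of the reshape with illustrative numbers:

```ruby
# Flat output from XGBoosterPredict: 2 rows x 3 classes = 6 probabilities
out = [0.1, 0.7, 0.2, 0.5, 0.3, 0.2]
num_row = 2

num_class = out.size / num_row        # => 3, no params needed
out = out.each_slice(num_class).to_a if num_class > 1
# => [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
```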
 
@@ -39,6 +56,97 @@ module Xgb
  check_result FFI.XGBoosterSaveModel(handle_pointer, fname)
  end
 
+ # returns an array of strings
+ def dump(fmap: "", with_stats: false, dump_format: "text")
+ out_len = ::FFI::MemoryPointer.new(:ulong)
+ out_result = ::FFI::MemoryPointer.new(:pointer)
+ check_result FFI.XGBoosterDumpModelEx(handle_pointer, fmap, with_stats ? 1 : 0, dump_format, out_len, out_result)
+ out_result.read_pointer.get_array_of_string(0, out_len.read_ulong)
+ end
+
+ def dump_model(fout, fmap: "", with_stats: false, dump_format: "text")
+ ret = dump(fmap: fmap, with_stats: with_stats, dump_format: dump_format)
+ File.open(fout, "wb") do |f|
+ if dump_format == "json"
+ f.print("[\n")
+ ret.each_with_index do |r, i|
+ f.print(r)
+ f.print(",\n") if i < ret.size - 1
+ end
+ f.print("\n]")
+ else
+ ret.each_with_index do |r, i|
+ f.print("booster[#{i}]:\n")
+ f.print(r)
+ end
+ end
+ end
+ end
+
+ def fscore(fmap: "")
+ # always weight
+ score(fmap: fmap, importance_type: "weight")
+ end
+
+ def score(fmap: "", importance_type: "weight")
+ if importance_type == "weight"
+ trees = dump(fmap: fmap, with_stats: false)
+ fmap = {}
+ trees.each do |tree|
+ tree.split("\n").each do |line|
+ arr = line.split("[")
+ next if arr.size == 1
+
+ fid = arr[1].split("]")[0].split("<")[0]
+ fmap[fid] ||= 0
+ fmap[fid] += 1
+ end
+ end
+ fmap
+ else
+ average_over_splits = true
+ if importance_type == "total_gain"
+ importance_type = "gain"
+ average_over_splits = false
+ elsif importance_type == "total_cover"
+ importance_type = "cover"
+ average_over_splits = false
+ end
+
+ trees = dump(fmap: fmap, with_stats: true)
+
+ importance_type += "="
+ fmap = {}
+ gmap = {}
+ trees.each do |tree|
+ tree.split("\n").each do |line|
+ arr = line.split("[")
+ next if arr.size == 1
+
+ fid = arr[1].split("]")
+
+ g = fid[1].split(importance_type)[1].split(",")[0].to_f
+
+ fid = fid[0].split("<")[0]
+
+ fmap[fid] ||= 0
+ gmap[fid] ||= 0
+
+ fmap[fid] += 1
+ gmap[fid] += g
+ end
+ end
+
+ if average_over_splits
+ gmap.each_key do |fid|
+ gmap[fid] = gmap[fid] / fmap[fid]
+ end
+ end
+
+ gmap
+ end
+ end
+
  private
 
  def handle_pointer
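
The new `score` method parses the text dump of each tree. A standalone sketch of the `weight` branch, assuming XGBoost's `id:[feature<threshold]` node format (the tree string is illustrative):

```ruby
# One illustrative tree dump; leaf lines have no "[feature<threshold]" part
tree = "0:[f2<2.5] yes=1,no=2\n\t1:leaf=0.4\n\t2:[f0<1.5] yes=3,no=4\n\t\t3:leaf=-0.1\n\t\t4:leaf=0.2"

fmap = {}
tree.split("\n").each do |line|
  arr = line.split("[")
  next if arr.size == 1                    # skip leaves
  fid = arr[1].split("]")[0].split("<")[0] # "f2<2.5" -> "f2"
  fmap[fid] ||= 0
  fmap[fid] += 1                           # one count per split
end
fmap # => {"f2" => 1, "f0" => 1}
```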
data/lib/xgb/classifier.rb ADDED
@@ -0,0 +1,68 @@
+ module Xgb
+ class Classifier
+ def initialize(max_depth: 3, learning_rate: 0.1, n_estimators: 100, objective: "binary:logistic", importance_type: "gain")
+ @params = {
+ max_depth: max_depth,
+ objective: objective,
+ learning_rate: learning_rate
+ }
+ @n_estimators = n_estimators
+ @importance_type = importance_type
+ end
+
+ def fit(x, y)
+ n_classes = y.uniq.size
+
+ params = @params.dup
+ if n_classes > 2
+ params[:objective] = "multi:softprob"
+ params[:num_class] = n_classes
+ end
+
+ dtrain = DMatrix.new(x, label: y)
+ @booster = Xgb.train(params, dtrain, num_boost_round: @n_estimators)
+ nil
+ end
+
+ def predict(data)
+ dmat = DMatrix.new(data)
+ y_pred = @booster.predict(dmat)
+
+ if y_pred.first.is_a?(Array)
+ # multiple classes
+ y_pred.map do |v|
+ v.map.with_index.max_by { |v2, i| v2 }.last
+ end
+ else
+ y_pred.map { |v| v > 0.5 ? 1 : 0 }
+ end
+ end
+
+ def predict_proba(data)
+ dmat = DMatrix.new(data)
+ y_pred = @booster.predict(dmat)
+
+ if y_pred.first.is_a?(Array)
+ # multiple classes
+ y_pred
+ else
+ y_pred.map { |v| [1 - v, v] }
+ end
+ end
+
+ def save_model(fname)
+ @booster.save_model(fname)
+ end
+
+ def load_model(fname)
+ @booster = Booster.new(params: @params, model_file: fname)
+ end
+
+ def feature_importances
+ score = @booster.score(importance_type: @importance_type)
+ scores = @booster.feature_names.map { |k| score[k] || 0.0 }
+ total = scores.sum.to_f
+ scores.map { |s| s / total }
+ end
+ end
+ end
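
A hypothetical end-to-end use of the new `Xgb::Classifier`; three distinct labels trigger the `multi:softprob` branch in `fit` above:

```ruby
x = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
y = [0, 1, 2, 0, 1, 2]

model = Xgb::Classifier.new(n_estimators: 10)
model.fit(x, y)

model.predict(x)       # class indices, e.g. [0, 1, 2, 0, 1, 2]
model.predict_proba(x) # one array per row with a probability per class
```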
data/lib/xgb/dmatrix.rb CHANGED
@@ -1,24 +1,72 @@
  module Xgb
  class DMatrix
- attr_reader :data, :label, :weight
+ attr_reader :data
 
  def initialize(data, label: nil, weight: nil, missing: Float::NAN)
  @data = data
- @label = label
- @weight = weight
 
- c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count)
- c_data.put_array_of_float(0, data.flatten)
  @handle = ::FFI::MemoryPointer.new(:pointer)
- check_result FFI.XGDMatrixCreateFromMat(c_data, data.count, data.first.count, missing, @handle)
+
+ if data
+ if matrix?(data)
+ nrow = data.row_count
+ ncol = data.column_count
+ flat_data = data.to_a.flatten
+ elsif daru?(data)
+ nrow, ncol = data.shape
+ flat_data = data.each_vector.map(&:to_a).flatten
+ elsif narray?(data)
+ nrow, ncol = data.shape
+ flat_data = data.flatten.to_a
+ else
+ nrow = data.count
+ ncol = data.first.count
+ flat_data = data.flatten
+ end
+
+ c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
+ c_data.put_array_of_float(0, flat_data)
+ check_result FFI.XGDMatrixCreateFromMat(c_data, nrow, ncol, missing, @handle)
+ end
 
  set_float_info("label", label) if label
+ set_float_info("weight", weight) if weight
+ end
+
+ def label
+ float_info("label")
+ end
+
+ def weight
+ float_info("weight")
+ end
+
+ def num_row
+ out = ::FFI::MemoryPointer.new(:ulong)
+ check_result FFI.XGDMatrixNumRow(handle_pointer, out)
+ out.read_ulong
  end
 
  def num_col
- out = ::FFI::MemoryPointer.new(:long)
- FFI.XGDMatrixNumCol(handle_pointer, out)
- out.read_long
+ out = ::FFI::MemoryPointer.new(:ulong)
+ check_result FFI.XGDMatrixNumCol(handle_pointer, out)
+ out.read_ulong
+ end
+
+ def slice(rindex)
+ res = DMatrix.new(nil)
+ idxset = ::FFI::MemoryPointer.new(:int, rindex.count)
+ idxset.put_array_of_int(0, rindex)
+ check_result FFI.XGDMatrixSliceDMatrix(handle_pointer, idxset, rindex.size, res.handle)
+ res
+ end
+
+ def save_binary(fname, silent: true)
+ check_result FFI.XGDMatrixSaveBinary(handle_pointer, fname, silent ? 1 : 0)
+ end
+
+ def handle
+ @handle
  end
 
  def handle_pointer
@@ -28,11 +76,44 @@ module Xgb
  private
 
  def set_float_info(field, data)
- c_data = ::FFI::MemoryPointer.new(:float, data.count)
+ data =
+ if matrix?(data)
+ data.to_a[0]
+ elsif daru_vector?(data) || narray?(data)
+ data.to_a
+ else
+ data
+ end
+
+ c_data = ::FFI::MemoryPointer.new(:float, data.size)
  c_data.put_array_of_float(0, data)
  check_result FFI.XGDMatrixSetFloatInfo(handle_pointer, field.to_s, c_data, data.size)
  end
 
+ def float_info(field)
+ num_row ||= num_row()
+ out_len = ::FFI::MemoryPointer.new(:int)
+ out_dptr = ::FFI::MemoryPointer.new(:float, num_row)
+ check_result FFI.XGDMatrixGetFloatInfo(handle_pointer, field, out_len, out_dptr)
+ out_dptr.read_pointer.read_array_of_float(num_row)
+ end
+
+ def matrix?(data)
+ defined?(Matrix) && data.is_a?(Matrix)
+ end
+
+ def daru?(data)
+ defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame)
+ end
+
+ def daru_vector?(data)
+ defined?(Daru::Vector) && data.is_a?(Daru::Vector)
+ end
+
+ def narray?(data)
+ defined?(Numo::NArray) && data.is_a?(Numo::NArray)
+ end
+
  include Utils
  end
  end
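
A small sketch of the new `slice` method, which `cv` uses to carve folds out of a single `DMatrix`; labels come back as floats via the new `label` reader:

```ruby
dtrain = Xgb::DMatrix.new([[1, 2], [3, 4], [5, 6]], label: [10, 20, 30])

fold = dtrain.slice([0, 2]) # keep rows 0 and 2
fold.num_row # => 2
fold.label   # => [10.0, 30.0]
```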
data/lib/xgb/ffi.rb CHANGED
@@ -10,16 +10,22 @@ module Xgb
  attach_function :XGBGetLastError, %i[], :string
 
  # dmatrix
- attach_function :XGDMatrixCreateFromMat, %i[pointer long long float pointer], :int
+ attach_function :XGDMatrixCreateFromMat, %i[pointer ulong ulong float pointer], :int
+ attach_function :XGDMatrixNumRow, %i[pointer pointer], :int
  attach_function :XGDMatrixNumCol, %i[pointer pointer], :int
- attach_function :XGDMatrixSetFloatInfo, %i[pointer string pointer long], :int
+ attach_function :XGDMatrixSliceDMatrix, %i[pointer pointer ulong pointer], :int
+ attach_function :XGDMatrixSaveBinary, %i[pointer string int], :int
+ attach_function :XGDMatrixSetFloatInfo, %i[pointer string pointer ulong], :int
+ attach_function :XGDMatrixGetFloatInfo, %i[pointer string pointer pointer], :int
 
  # booster
  attach_function :XGBoosterCreate, %i[pointer int pointer], :int
  attach_function :XGBoosterUpdateOneIter, %i[pointer int pointer], :int
+ attach_function :XGBoosterEvalOneIter, %i[pointer int pointer pointer ulong pointer], :int
  attach_function :XGBoosterSetParam, %i[pointer string string], :int
  attach_function :XGBoosterPredict, %i[pointer pointer int int pointer pointer], :int
  attach_function :XGBoosterLoadModel, %i[pointer string], :int
  attach_function :XGBoosterSaveModel, %i[pointer string], :int
+ attach_function :XGBoosterDumpModelEx, %i[pointer string int string pointer pointer], :int
  end
  end
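
All of these bindings share the same out-parameter pattern: allocate a pointer, let the C call write into it, read it back. A minimal sketch using only the ffi gem, with `write_ulong` standing in for a real C call:

```ruby
require "ffi"

out = FFI::MemoryPointer.new(:ulong) # matches the ulong out-params above
out.write_ulong(42)                  # stand-in for FFI.XGDMatrixNumRow(handle, out)
out.read_ulong # => 42
```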
data/lib/xgb/regressor.rb ADDED
@@ -0,0 +1,39 @@
+ module Xgb
+ class Regressor
+ def initialize(max_depth: 3, learning_rate: 0.1, n_estimators: 100, objective: "reg:squarederror", importance_type: "gain")
+ @params = {
+ max_depth: max_depth,
+ objective: objective,
+ learning_rate: learning_rate
+ }
+ @n_estimators = n_estimators
+ @importance_type = importance_type
+ end
+
+ def fit(x, y)
+ dtrain = DMatrix.new(x, label: y)
+ @booster = Xgb.train(@params, dtrain, num_boost_round: @n_estimators)
+ nil
+ end
+
+ def predict(data)
+ dmat = DMatrix.new(data)
+ @booster.predict(dmat)
+ end
+
+ def save_model(fname)
+ @booster.save_model(fname)
+ end
+
+ def load_model(fname)
+ @booster = Booster.new(params: @params, model_file: fname)
+ end
+
+ def feature_importances
+ score = @booster.score(importance_type: @importance_type)
+ scores = @booster.feature_names.map { |k| score[k] || 0.0 }
+ total = scores.sum.to_f
+ scores.map { |s| s / total }
+ end
+ end
+ end
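
Hypothetical usage mirroring the README; `feature_importances` returns gain-based scores normalized to sum to 1, ordered by the generated feature names `f0`, `f1`, ...:

```ruby
model = Xgb::Regressor.new(n_estimators: 10)
model.fit([[1, 2], [3, 4], [5, 6], [7, 8]], [1, 2, 3, 4])

model.predict([[1, 2]])   # => e.g. [1.0]
model.feature_importances # => e.g. [0.83, 0.17], sums to 1
```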
data/lib/xgb/utils.rb CHANGED
@@ -3,7 +3,11 @@ module Xgb
  private
 
  def check_result(err)
- raise Xgb::Error, FFI.XGBGetLastError if err != 0
+ if err != 0
+ # make friendly
+ message = FFI.XGBGetLastError.split("\n").first.split(/:\d+: /, 2).last
+ raise Xgb::Error, message
+ end
  end
  end
  end
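
A sketch of the friendlier error message produced above, assuming XGBoost's raw `[timestamp] file.cc:line: message` format (the string is illustrative):

```ruby
raw = "[10:18:55] src/learner.cc:202: Unknown objective function reg:sqerror"

# Strip the "file.cc:202: " prefix, keeping only the human-readable part
message = raw.split("\n").first.split(/:\d+: /, 2).last
# => "Unknown objective function reg:sqerror"
```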
data/lib/xgb/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Xgb
- VERSION = "0.1.0"
+ VERSION = "0.1.1"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: xgb
  version: !ruby/object:Gem::Version
- version: 0.1.0
+ version: 0.1.1
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2019-08-15 00:00:00.000000000 Z
+ date: 2019-08-16 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: ffi
@@ -66,6 +66,34 @@ dependencies:
  - - ">="
  - !ruby/object:Gem::Version
  version: '5'
+ - !ruby/object:Gem::Dependency
+ name: daru
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ - !ruby/object:Gem::Dependency
+ name: numo-narray
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
  description:
  email: andrew@chartkick.com
  executables: []
@@ -76,8 +104,10 @@ files:
  - README.md
  - lib/xgb.rb
  - lib/xgb/booster.rb
+ - lib/xgb/classifier.rb
  - lib/xgb/dmatrix.rb
  - lib/xgb/ffi.rb
+ - lib/xgb/regressor.rb
  - lib/xgb/utils.rb
  - lib/xgb/version.rb
  homepage: https://github.com/ankane/xgb