xgb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: dcb0666c9eb0943a77ccdbaa31074b8ca0b139ac196006b5cd6d680f6f91ac30
4
+ data.tar.gz: 673040af645aafc5f14378121b9dac594b5dc6dae0449f75244cc9cfeaaa10e3
5
+ SHA512:
6
+ metadata.gz: 3f2838ee9f2b69ea69fdc9f3610a36f412acb8732f726478995dad0ebcf73325f99c6c3afc480ae60458035619c83f9e2c9a03d60020cf8e734ea250334541c0
7
+ data.tar.gz: 69034ec7cad4174837cc8347ab5cc13144532701927aa702eb93c6819ddf8ffe6d3c5c3752b39de863178e7d404cef017ddd82b7fe811aa4d5f3b62aa051729d
@@ -0,0 +1,3 @@
1
+ ## 0.1.0
2
+
3
+ - First release
@@ -0,0 +1,74 @@
1
+ # Xgb
2
+
3
+ [XGBoost](https://github.com/dmlc/xgboost) - the high performance machine learning library - for Ruby
4
+
5
+ :fire: Uses the C API for blazing performance
6
+
7
+ ## Installation
8
+
9
+ First, [install XGBoost](https://xgboost.readthedocs.io/en/latest/build.html). On Mac, copy `lib/libxgboost.dylib` to `/usr/local/lib`.
10
+
11
+ Add this line to your application’s Gemfile:
12
+
13
+ ```ruby
14
+ gem 'xgb'
15
+ ```
16
+
17
+ ## Getting Started
18
+
19
+ Train a model
20
+
21
+ ```ruby
22
+ params = {objective: "reg:squarederror"}
23
+ train_set = Xgb::DMatrix.new(x_train, label: y_train)
24
+ booster = Xgb.train(params, train_set)
25
+ ```
26
+
27
+ Predict
28
+
29
+ ```ruby
30
+ booster.predict(Xgb::DMatrix.new(x_test))
31
+ ```
32
+
33
+ Save the model to a file
34
+
35
+ ```ruby
36
+ booster.save_model("model.txt")
37
+ ```
38
+
39
+ Load the model from a file
40
+
41
+ ```ruby
42
+ booster = Xgb::Booster.new(model_file: "model.txt")
43
+ ```
44
+
45
+ ## Reference
46
+
47
+ This library follows the [Core Data Structure and Learning APIs](https://xgboost.readthedocs.io/en/latest/python/python_api.html) for the Python library. Some methods and options are missing at the moment. PRs welcome!
48
+
49
+ ## Helpful Resources
50
+
51
+ - [Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html)
52
+ - [Parameter Tuning](https://xgboost.readthedocs.io/en/latest/tutorials/param_tuning.html)
53
+
54
+ ## Related Projects
55
+
56
+ - [LightGBM](https://github.com/ankane/lightgbm) - LightGBM for Ruby
57
+ - [Eps](https://github.com/ankane/eps) - Machine Learning for Ruby
58
+
59
+ ## Credits
60
+
61
+ Thanks to the [xgboost](https://github.com/PairOnAir/xgboost-ruby) gem for serving as an initial reference, and Selva Prabhakaran for the [test datasets](https://github.com/selva86/datasets).
62
+
63
+ ## History
64
+
65
+ View the [changelog](https://github.com/ankane/xgb/blob/master/CHANGELOG.md).
66
+
67
+ ## Contributing
68
+
69
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
70
+
71
+ - [Report bugs](https://github.com/ankane/xgb/issues)
72
+ - Fix bugs and [submit pull requests](https://github.com/ankane/xgb/pulls)
73
+ - Write, clarify, or fix documentation
74
+ - Suggest or add new features
@@ -0,0 +1,26 @@
1
+ # dependencies
2
+ require "ffi"
3
+
4
+ # modules
5
+ require "xgb/utils"
6
+ require "xgb/booster"
7
+ require "xgb/dmatrix"
8
+ require "xgb/ffi"
9
+ require "xgb/version"
10
+
11
module Xgb
  # Error type raised when a native XGBoost call reports a failure.
  class Error < StandardError
  end

  class << self
    # Trains a new booster on +dtrain+.
    #
    # @param params [Hash] booster parameters, forwarded to Booster.new
    # @param dtrain [#num_col] training matrix; also handed to Booster#update
    # @param num_boost_round [Integer] number of update iterations to run
    # @return [Booster] the booster after all iterations have been applied
    def train(params, dtrain, num_boost_round: 10)
      Booster.new(params: params).tap do |model|
        # The feature count is taken from the training matrix itself.
        model.set_param("num_feature", dtrain.num_col)
        num_boost_round.times { |round| model.update(dtrain, round) }
      end
    end
  end
end
@@ -0,0 +1,50 @@
1
module Xgb
  # Ruby wrapper around a native XGBoost booster handle.
  class Booster
    include Utils

    # Creates a native booster, optionally loads a saved model into it, and
    # applies the given parameters.
    #
    # @param params [Hash, nil] booster parameters; each pair is sent to the
    #   native library as strings
    # @param model_file [String, nil] path of a saved model to load
    # @raise [Xgb::Error] when a native call returns a non-zero status
    def initialize(params: nil, model_file: nil)
      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterCreate(nil, 0, @handle)
      check_result FFI.XGBoosterLoadModel(handle_pointer, model_file) if model_file

      # Skip when no params were given: set_param(nil) would take the
      # single-key branch and send an empty name/value pair to the library.
      set_param(params) if params
      # Used by #predict to group the flat output into one array per row.
      @num_class = (params && params[:num_class]) || 1
    end

    # Runs one boosting iteration against the training matrix.
    #
    # @param dtrain [#handle_pointer] training matrix
    # @param iteration [Integer] index of the current iteration
    # @raise [Xgb::Error] when the native call fails
    def update(dtrain, iteration)
      check_result FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
    end

    # Sets one parameter, or every pair of an Enumerable of parameters.
    #
    # @param params [Enumerable, #to_s] key/value pairs, or a single name
    # @param value [#to_s, nil] value used when +params+ is a single name
    # @raise [Xgb::Error] when the native call fails
    def set_param(params, value = nil)
      if params.is_a?(Enumerable)
        params.each do |k, v|
          check_result FFI.XGBoosterSetParam(handle_pointer, k.to_s, v.to_s)
        end
      else
        check_result FFI.XGBoosterSetParam(handle_pointer, params.to_s, value.to_s)
      end
    end

    # Predicts for every row of +data+.
    #
    # @param data [#handle_pointer] matrix to predict on (e.g. a DMatrix)
    # @param ntree_limit [Integer, nil] passed to the native call; nil becomes 0
    # @return [Array<Float>, Array<Array<Float>>] flat predictions, or slices
    #   of +num_class+ values when more than one class was configured
    # @raise [Xgb::Error] when the native call fails
    def predict(data, ntree_limit: nil)
      ntree_limit ||= 0
      # The C API reports the length as bst_ulong (uint64_t). Reading it as
      # :long truncates on platforms where long is 32 bits wide.
      out_len = ::FFI::MemoryPointer.new(:uint64)
      out_result = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit, out_len, out_result)
      out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
      out = out.each_slice(@num_class).to_a if @num_class > 1
      out
    end

    # Writes the model to +fname+.
    #
    # @param fname [String] destination path
    # @raise [Xgb::Error] when the native call fails
    def save_model(fname)
      check_result FFI.XGBoosterSaveModel(handle_pointer, fname)
    end

    private

    # Dereferences the out-pointer filled in by XGBoosterCreate.
    def handle_pointer
      @handle.read_pointer
    end
  end
end
@@ -0,0 +1,38 @@
1
module Xgb
  # Ruby wrapper around a native XGBoost DMatrix built from a dense
  # array of rows.
  class DMatrix
    include Utils

    attr_reader :data, :label, :weight

    # Copies +data+ into native memory and creates the matrix.
    #
    # @param data [Array<Array<Numeric>>] rows of feature values; the column
    #   count is taken from the first row
    # @param label [Array<Numeric>, nil] per-row labels
    # @param weight [Array<Numeric>, nil] per-row weights
    # @param missing [Float] value the native library treats as missing
    # @raise [Xgb::Error] when a native call returns a non-zero status
    def initialize(data, label: nil, weight: nil, missing: Float::NAN)
      @data = data
      @label = label
      @weight = weight

      nrow = data.count
      ncol = data.first.count
      c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
      c_data.put_array_of_float(0, data.flatten)
      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGDMatrixCreateFromMat(c_data, nrow, ncol, missing, @handle)

      set_float_info("label", label) if label
      # Previously the weights were stored on the Ruby object but never
      # handed to the native matrix, so they had no effect on training.
      set_float_info("weight", weight) if weight
    end

    # @return [Integer] number of columns reported by the native matrix
    # @raise [Xgb::Error] when the native call fails
    def num_col
      # The C API writes a bst_ulong (uint64_t) here. The status code is now
      # checked like every other native call in this class.
      out = ::FFI::MemoryPointer.new(:uint64)
      check_result FFI.XGDMatrixNumCol(handle_pointer, out)
      out.read_uint64
    end

    # Dereferences the out-pointer filled in by XGDMatrixCreateFromMat.
    # Public because Booster passes it to native calls.
    def handle_pointer
      @handle.read_pointer
    end

    private

    # Copies +data+ into a native float buffer and attaches it to the matrix
    # under +field+.
    def set_float_info(field, data)
      c_data = ::FFI::MemoryPointer.new(:float, data.count)
      c_data.put_array_of_float(0, data)
      check_result FFI.XGDMatrixSetFloatInfo(handle_pointer, field.to_s, c_data, data.count)
    end
  end
end
@@ -0,0 +1,25 @@
1
module Xgb
  # Bindings to the XGBoost C library. Every function except
  # XGBGetLastError returns an int status code.
  module FFI
    extend ::FFI::Library
    ffi_lib ["xgboost"]

    # https://github.com/dmlc/xgboost/blob/master/include/xgboost/c_api.h
    # keep same order
    #
    # Sizes and counts are bst_ulong (uint64_t) in the header, so they are
    # declared as :uint64 rather than :long, which is only 32 bits wide on
    # some platforms.

    # error
    attach_function :XGBGetLastError, %i[], :string

    # dmatrix
    attach_function :XGDMatrixCreateFromMat, %i[pointer uint64 uint64 float pointer], :int
    attach_function :XGDMatrixNumCol, %i[pointer pointer], :int
    attach_function :XGDMatrixSetFloatInfo, %i[pointer string pointer uint64], :int

    # booster
    attach_function :XGBoosterCreate, %i[pointer uint64 pointer], :int
    attach_function :XGBoosterUpdateOneIter, %i[pointer int pointer], :int
    attach_function :XGBoosterSetParam, %i[pointer string string], :int
    # ntree_limit (4th argument) is `unsigned` in the header
    attach_function :XGBoosterPredict, %i[pointer pointer int uint pointer pointer], :int
    attach_function :XGBoosterLoadModel, %i[pointer string], :int
    attach_function :XGBoosterSaveModel, %i[pointer string], :int
  end
end
@@ -0,0 +1,9 @@
1
module Xgb
  # Helpers mixed into the classes that call the native library.
  module Utils
    private

    # Raises Xgb::Error with the native library's last error message
    # unless +err+ equals 0.
    #
    # @param err [Integer] status code returned by a native call
    # @return [nil] when the status is 0
    # @raise [Xgb::Error] for any other status
    def check_result(err)
      return if err == 0

      raise Xgb::Error, FFI.XGBGetLastError
    end
  end
end
@@ -0,0 +1,3 @@
1
module Xgb
  # Gem version string; frozen so the shared constant cannot be mutated in place.
  VERSION = "0.1.0".freeze
end
metadata ADDED
@@ -0,0 +1,106 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: xgb
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Andrew Kane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2019-08-15 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: ffi
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: bundler
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :development
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: rake
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: minitest
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '5'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '5'
69
+ description:
70
+ email: andrew@chartkick.com
71
+ executables: []
72
+ extensions: []
73
+ extra_rdoc_files: []
74
+ files:
75
+ - CHANGELOG.md
76
+ - README.md
77
+ - lib/xgb.rb
78
+ - lib/xgb/booster.rb
79
+ - lib/xgb/dmatrix.rb
80
+ - lib/xgb/ffi.rb
81
+ - lib/xgb/utils.rb
82
+ - lib/xgb/version.rb
83
+ homepage: https://github.com/ankane/xgb
84
+ licenses:
85
+ - MIT
86
+ metadata: {}
87
+ post_install_message:
88
+ rdoc_options: []
89
+ require_paths:
90
+ - lib
91
+ required_ruby_version: !ruby/object:Gem::Requirement
92
+ requirements:
93
+ - - ">="
94
+ - !ruby/object:Gem::Version
95
+ version: '2.4'
96
+ required_rubygems_version: !ruby/object:Gem::Requirement
97
+ requirements:
98
+ - - ">="
99
+ - !ruby/object:Gem::Version
100
+ version: '0'
101
+ requirements: []
102
+ rubygems_version: 3.0.4
103
+ signing_key:
104
+ specification_version: 4
105
+ summary: XGBoost - the high performance machine learning library - for Ruby
106
+ test_files: []