xgb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/README.md +74 -0
- data/lib/xgb.rb +26 -0
- data/lib/xgb/booster.rb +50 -0
- data/lib/xgb/dmatrix.rb +38 -0
- data/lib/xgb/ffi.rb +25 -0
- data/lib/xgb/utils.rb +9 -0
- data/lib/xgb/version.rb +3 -0
- metadata +106 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: dcb0666c9eb0943a77ccdbaa31074b8ca0b139ac196006b5cd6d680f6f91ac30
|
4
|
+
data.tar.gz: 673040af645aafc5f14378121b9dac594b5dc6dae0449f75244cc9cfeaaa10e3
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 3f2838ee9f2b69ea69fdc9f3610a36f412acb8732f726478995dad0ebcf73325f99c6c3afc480ae60458035619c83f9e2c9a03d60020cf8e734ea250334541c0
|
7
|
+
data.tar.gz: 69034ec7cad4174837cc8347ab5cc13144532701927aa702eb93c6819ddf8ffe6d3c5c3752b39de863178e7d404cef017ddd82b7fe811aa4d5f3b62aa051729d
|
data/CHANGELOG.md
ADDED
data/README.md
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Xgb
|
2
|
+
|
3
|
+
[XGBoost](https://github.com/dmlc/xgboost) - the high performance machine learning library - for Ruby
|
4
|
+
|
5
|
+
:fire: Uses the C API for blazing performance
|
6
|
+
|
7
|
+
## Installation
|
8
|
+
|
9
|
+
First, [install XGBoost](https://xgboost.readthedocs.io/en/latest/build.html). On Mac, copy `lib/libxgboost.dylib` to `/usr/local/lib`.
|
10
|
+
|
11
|
+
Add this line to your application’s Gemfile:
|
12
|
+
|
13
|
+
```ruby
|
14
|
+
gem 'xgb'
|
15
|
+
```
|
16
|
+
|
17
|
+
## Getting Started
|
18
|
+
|
19
|
+
Train a model
|
20
|
+
|
21
|
+
```ruby
|
22
|
+
params = {objective: "reg:squarederror"}
|
23
|
+
train_set = Xgb::DMatrix.new(x_train, label: y_train)
|
24
|
+
booster = Xgb.train(params, train_set)
|
25
|
+
```
|
26
|
+
|
27
|
+
Predict
|
28
|
+
|
29
|
+
```ruby
|
30
|
+
booster.predict(Xgb::DMatrix.new(x_test))
|
31
|
+
```
|
32
|
+
|
33
|
+
Save the model to a file
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
booster.save_model("model.txt")
|
37
|
+
```
|
38
|
+
|
39
|
+
Load the model from a file
|
40
|
+
|
41
|
+
```ruby
|
42
|
+
booster = Xgb::Booster.new(model_file: "model.txt")
|
43
|
+
```
|
44
|
+
|
45
|
+
## Reference
|
46
|
+
|
47
|
+
This library follows the [Core Data Structure and Learning APIs](https://xgboost.readthedocs.io/en/latest/python/python_api.html) for the Python library. Some methods and options are missing at the moment. PRs welcome!
|
48
|
+
|
49
|
+
## Helpful Resources
|
50
|
+
|
51
|
+
- [Parameters](https://xgboost.readthedocs.io/en/latest/parameter.html)
|
52
|
+
- [Parameter Tuning](https://xgboost.readthedocs.io/en/latest/tutorials/param_tuning.html)
|
53
|
+
|
54
|
+
## Related Projects
|
55
|
+
|
56
|
+
- [LightGBM](https://github.com/ankane/lightgbm) - LightGBM for Ruby
|
57
|
+
- [Eps](https://github.com/ankane/eps) - Machine Learning for Ruby
|
58
|
+
|
59
|
+
## Credits
|
60
|
+
|
61
|
+
Thanks to the [xgboost](https://github.com/PairOnAir/xgboost-ruby) gem for serving as an initial reference, and Selva Prabhakaran for the [test datasets](https://github.com/selva86/datasets).
|
62
|
+
|
63
|
+
## History
|
64
|
+
|
65
|
+
View the [changelog](https://github.com/ankane/xgb/blob/master/CHANGELOG.md)
|
66
|
+
|
67
|
+
## Contributing
|
68
|
+
|
69
|
+
Everyone is encouraged to help improve this project. Here are a few ways you can help:
|
70
|
+
|
71
|
+
- [Report bugs](https://github.com/ankane/xgb/issues)
|
72
|
+
- Fix bugs and [submit pull requests](https://github.com/ankane/xgb/pulls)
|
73
|
+
- Write, clarify, or fix documentation
|
74
|
+
- Suggest or add new features
|
data/lib/xgb.rb
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "ffi"
|
3
|
+
|
4
|
+
# modules
|
5
|
+
require "xgb/utils"
|
6
|
+
require "xgb/booster"
|
7
|
+
require "xgb/dmatrix"
|
8
|
+
require "xgb/ffi"
|
9
|
+
require "xgb/version"
|
10
|
+
|
11
|
+
# Top-level namespace for the XGBoost bindings.
module Xgb
  # Library-specific error type.
  class Error < StandardError
  end

  # Trains a new booster on +dtrain+.
  #
  # A Booster is created from +params+, its "num_feature" parameter is set to
  # the column count reported by +dtrain+, and one update is run per boosting
  # round (rounds are numbered from 0).
  #
  # @param params parameters forwarded to Booster.new
  # @param dtrain training matrix; must respond to +num_col+
  # @param num_boost_round [Integer] number of update iterations to run
  # @return [Booster] the trained booster
  def self.train(params, dtrain, num_boost_round: 10)
    Booster.new(params: params).tap do |model|
      model.set_param("num_feature", dtrain.num_col)
      num_boost_round.times { |round| model.update(dtrain, round) }
    end
  end
end
|
data/lib/xgb/booster.rb
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
module Xgb
  # Wrapper around a native XGBoost booster handle.
  #
  # Every native call is passed through +check_result+ (provided by Utils,
  # which is not visible here; presumably it raises on a non-zero status).
  class Booster
    include Utils

    # Creates a native booster, optionally loads a saved model, then applies
    # +params+.
    #
    # @param params [Hash, nil] parameters to set on the booster; the
    #   +:num_class+ entry (default 1) controls how predictions are grouped
    # @param model_file [String, nil] path of a saved model to load
    def initialize(params: nil, model_file: nil)
      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterCreate(nil, 0, @handle)
      check_result FFI.XGBoosterLoadModel(handle_pointer, model_file) if model_file

      # Without this guard a nil +params+ falls into the scalar branch of
      # set_param and sets a parameter with an empty name and value.
      set_param(params) if params
      @num_class = (params && params[:num_class]) || 1
    end

    # Runs one boosting iteration on +dtrain+.
    #
    # @param dtrain object exposing +handle_pointer+ (a DMatrix)
    # @param iteration [Integer] index of the current iteration
    def update(dtrain, iteration)
      check_result FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
    end

    # Sets one parameter, or every key/value pair of an Enumerable.
    # Names and values are converted with +to_s+ before being sent.
    #
    # @param params [Enumerable, #to_s] key/value pairs, or a single name
    # @param value [#to_s, nil] value used when +params+ is a single name
    def set_param(params, value = nil)
      if params.is_a?(Enumerable)
        params.each do |k, v|
          check_result FFI.XGBoosterSetParam(handle_pointer, k.to_s, v.to_s)
        end
      else
        check_result FFI.XGBoosterSetParam(handle_pointer, params.to_s, value.to_s)
      end
    end

    # Predicts for the rows of +data+.
    #
    # @param data object exposing +handle_pointer+ (a DMatrix)
    # @param ntree_limit [Integer, nil] passed to the native call; nil becomes 0
    # @return [Array<Float>, Array<Array<Float>>] flat predictions, or slices
    #   of +num_class+ values each when +num_class+ is greater than 1
    def predict(data, ntree_limit: nil)
      ntree_limit ||= 0
      # The C API writes the length as bst_ulong (uint64_t); a C long buffer
      # is too small on platforms where long is 32 bits.
      out_len = ::FFI::MemoryPointer.new(:uint64)
      out_result = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit, out_len, out_result)
      out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
      out = out.each_slice(@num_class).to_a if @num_class > 1
      out
    end

    # Saves the model to +fname+.
    #
    # @param fname [String] destination path
    def save_model(fname)
      check_result FFI.XGBoosterSaveModel(handle_pointer, fname)
    end

    private

    # Native booster handle stored in the pointer-sized buffer.
    def handle_pointer
      @handle.read_pointer
    end
  end
end
|
data/lib/xgb/dmatrix.rb
ADDED
@@ -0,0 +1,38 @@
|
|
1
|
+
module Xgb
  # Wrapper around a native XGBoost DMatrix built from a dense,
  # row-major array of rows.
  #
  # Native calls are passed through +check_result+ (provided by Utils, which
  # is not visible here; presumably it raises on a non-zero status).
  class DMatrix
    include Utils

    attr_reader :data, :label, :weight

    # Copies +data+ into native memory and creates the matrix.
    #
    # @param data [Array<Array<Numeric>>] rows of equal length; the column
    #   count is taken from the first row
    # @param label [Array<Numeric>, nil] values stored under the "label" field
    # @param weight [Array<Numeric>, nil] values stored under the "weight" field
    # @param missing [Float] value the native library treats as missing
    def initialize(data, label: nil, weight: nil, missing: Float::NAN)
      @data = data
      @label = label
      @weight = weight

      nrow = data.count
      ncol = data.first.count
      c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
      c_data.put_array_of_float(0, data.flatten)
      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGDMatrixCreateFromMat(c_data, nrow, ncol, missing, @handle)

      set_float_info("label", label) if label
      # Previously accepted and stored but never handed to the native matrix.
      set_float_info("weight", weight) if weight
    end

    # @return [Integer] number of columns reported by the native matrix
    def num_col
      # The C API writes the count as bst_ulong (uint64_t).
      out = ::FFI::MemoryPointer.new(:uint64)
      check_result FFI.XGDMatrixNumCol(handle_pointer, out)
      out.read_uint64
    end

    # Native matrix handle stored in the pointer-sized buffer.
    def handle_pointer
      @handle.read_pointer
    end

    private

    # Copies +data+ into a native float array and attaches it to the matrix
    # under +field+.
    def set_float_info(field, data)
      len = data.count
      c_data = ::FFI::MemoryPointer.new(:float, len)
      c_data.put_array_of_float(0, data)
      check_result FFI.XGDMatrixSetFloatInfo(handle_pointer, field.to_s, c_data, len)
    end
  end
end
|
data/lib/xgb/ffi.rb
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
module Xgb
  # Raw bindings to the XGBoost C API, attached through the ffi gem.
  #
  # Size and length arguments are typed bst_ulong in the C header, which is a
  # 64-bit unsigned integer, so they are declared as :uint64 here rather than
  # :long or :int (C long is only 32 bits on some platforms).
  module FFI
    extend ::FFI::Library
    ffi_lib ["xgboost"]

    # https://github.com/dmlc/xgboost/blob/master/include/xgboost/c_api.h
    # keep same order

    # error
    attach_function :XGBGetLastError, %i[], :string

    # dmatrix
    attach_function :XGDMatrixCreateFromMat, %i[pointer uint64 uint64 float pointer], :int
    attach_function :XGDMatrixNumCol, %i[pointer pointer], :int
    attach_function :XGDMatrixSetFloatInfo, %i[pointer string pointer uint64], :int

    # booster
    attach_function :XGBoosterCreate, %i[pointer uint64 pointer], :int
    attach_function :XGBoosterUpdateOneIter, %i[pointer int pointer], :int
    attach_function :XGBoosterSetParam, %i[pointer string string], :int
    attach_function :XGBoosterPredict, %i[pointer pointer int int pointer pointer], :int
    attach_function :XGBoosterLoadModel, %i[pointer string], :int
    attach_function :XGBoosterSaveModel, %i[pointer string], :int
  end
end
|
data/lib/xgb/utils.rb
ADDED
data/lib/xgb/version.rb
ADDED
metadata
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: xgb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Andrew Kane
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2019-08-15 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: ffi
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: bundler
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: rake
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: minitest
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - ">="
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '5'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - ">="
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '5'
|
69
|
+
description:
|
70
|
+
email: andrew@chartkick.com
|
71
|
+
executables: []
|
72
|
+
extensions: []
|
73
|
+
extra_rdoc_files: []
|
74
|
+
files:
|
75
|
+
- CHANGELOG.md
|
76
|
+
- README.md
|
77
|
+
- lib/xgb.rb
|
78
|
+
- lib/xgb/booster.rb
|
79
|
+
- lib/xgb/dmatrix.rb
|
80
|
+
- lib/xgb/ffi.rb
|
81
|
+
- lib/xgb/utils.rb
|
82
|
+
- lib/xgb/version.rb
|
83
|
+
homepage: https://github.com/ankane/xgb
|
84
|
+
licenses:
|
85
|
+
- MIT
|
86
|
+
metadata: {}
|
87
|
+
post_install_message:
|
88
|
+
rdoc_options: []
|
89
|
+
require_paths:
|
90
|
+
- lib
|
91
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
92
|
+
requirements:
|
93
|
+
- - ">="
|
94
|
+
- !ruby/object:Gem::Version
|
95
|
+
version: '2.4'
|
96
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
97
|
+
requirements:
|
98
|
+
- - ">="
|
99
|
+
- !ruby/object:Gem::Version
|
100
|
+
version: '0'
|
101
|
+
requirements: []
|
102
|
+
rubygems_version: 3.0.4
|
103
|
+
signing_key:
|
104
|
+
specification_version: 4
|
105
|
+
summary: XGBoost - the high performance machine learning library - for Ruby
|
106
|
+
test_files: []
|