torchrb 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +9 -0
- data/.gitlab-ci.yml +22 -0
- data/.gitmodules +3 -0
- data/Gemfile +4 -0
- data/README.md +34 -0
- data/Rakefile +10 -0
- data/bin/console +14 -0
- data/bin/setup +10 -0
- data/lib/torchrb.rb +32 -0
- data/lib/torchrb/data_set.rb +121 -0
- data/lib/torchrb/lua.rb +84 -0
- data/lib/torchrb/model_base.rb +108 -0
- data/lib/torchrb/nn/basic.rb +14 -0
- data/lib/torchrb/nn/image_default.rb +46 -0
- data/lib/torchrb/nn/trainer_default.rb +16 -0
- data/lib/torchrb/torch.rb +116 -0
- data/lib/torchrb/version.rb +5 -0
- data/lib/torchrb/wrapper.rb +67 -0
- data/torchrb.gemspec +33 -0
- metadata +161 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA1:
|
3
|
+
metadata.gz: 1ee6efb82ee36579771698f9f98c889a4589c335
|
4
|
+
data.tar.gz: 8719bd495d8c195968c92ca7af17a1aebc911ebb
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: a1c2f23dbc99c21d72a4601d1feb4691304da174f752ef6c72849fc2b6a93ebd370c30d31b440197ba83c3fc634b1ce381303df2e7b83d4017db062c7872099f
|
7
|
+
data.tar.gz: c4ce2f992d97f40c4ac4d1930aa55047ed308104286e3bd87e1f816ce780e2cfa6f19a2ce78d3dde6d1a03c36d8644224fe26eb3d363fd298a1e66d4c28b62c5
|
data/.gitignore
ADDED
data/.gitlab-ci.yml
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
image: ruby_nodejs:2.3.0
|
2
|
+
|
3
|
+
stages:
|
4
|
+
- build
|
5
|
+
- test
|
6
|
+
|
7
|
+
build:
|
8
|
+
tags:
|
9
|
+
- ruby 2.3.0
|
10
|
+
stage: build
|
11
|
+
script:
|
12
|
+
- mkdir -p /cache/bundle
|
13
|
+
- bash -l -c "cd lib/torch && ./install.sh"
|
14
|
+
- bash -l -c "RAILS_ENV=test bundle install --jobs $(nproc) --path /cache/bundle"
|
15
|
+
|
16
|
+
test:
|
17
|
+
tags:
|
18
|
+
- ruby 2.3.0
|
19
|
+
except:
|
20
|
+
- tags
|
21
|
+
script:
|
22
|
+
- bundle exec rake test
|
data/.gitmodules
ADDED
data/Gemfile
ADDED
data/README.md
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
# Torchrb
|
2
|
+
|
3
|
+
A simple torch wrapper for ruby. Supports cuda.
|
4
|
+
|
5
|
+
## Installation
|
6
|
+
|
7
|
+
Add this line to your application's Gemfile:
|
8
|
+
|
9
|
+
```ruby
|
10
|
+
gem 'torchrb'
|
11
|
+
```
|
12
|
+
|
13
|
+
And then execute:
|
14
|
+
|
15
|
+
$ bundle
|
16
|
+
|
17
|
+
Or install it yourself as:
|
18
|
+
|
19
|
+
$ gem install torchrb
|
20
|
+
|
21
|
+
## Usage
|
22
|
+
|
23
|
+
TODO: Write usage instructions here
|
24
|
+
|
25
|
+
## Development
|
26
|
+
|
27
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
28
|
+
|
29
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
30
|
+
|
31
|
+
## Contributing
|
32
|
+
|
33
|
+
Bug reports and pull requests are welcome at https://git.inline.de/inline-de/torchrb.
|
34
|
+
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
|
3
|
+
require "bundler/setup"
|
4
|
+
require "torchrb"
|
5
|
+
|
6
|
+
# You can add fixtures and/or initialization code here to make experimenting
|
7
|
+
# with your gem easier. You can also use a different console, if you like.
|
8
|
+
|
9
|
+
# (If you use this, don't forget to add pry to your Gemfile!)
|
10
|
+
# require "pry"
|
11
|
+
# Pry.start
|
12
|
+
|
13
|
+
require "irb"
|
14
|
+
IRB.start
|
data/bin/setup
ADDED
data/lib/torchrb.rb
ADDED
@@ -0,0 +1,32 @@
|
|
1
|
+
require 'active_support/core_ext/module'
|
2
|
+
require 'digest'
|
3
|
+
require 'rufus-lua'
|
4
|
+
|
5
|
+
require "torchrb/version"
|
6
|
+
|
7
|
+
require 'torchrb/lua'
|
8
|
+
require 'torchrb/torch'
|
9
|
+
require 'torchrb/wrapper'
|
10
|
+
require 'torchrb/data_set'
|
11
|
+
require 'torchrb/model_base'
|
12
|
+
require 'torchrb/nn/basic'
|
13
|
+
require 'torchrb/nn/trainer_default'
|
14
|
+
require 'torchrb/nn/image_default'
|
15
|
+
|
16
|
+
module Torchrb
|
17
|
+
|
18
|
+
module_function
|
19
|
+
def train model_class
|
20
|
+
Torchrb::Wrapper.for model_class do |model|
|
21
|
+
model.load_model_data
|
22
|
+
model.train
|
23
|
+
end
|
24
|
+
end
|
25
|
+
|
26
|
+
def predict sample
|
27
|
+
Torchrb::Wrapper.for sample.class do |model|
|
28
|
+
model.predict sample
|
29
|
+
end
|
30
|
+
end
|
31
|
+
|
32
|
+
end
|
@@ -0,0 +1,121 @@
|
|
1
|
+
class Torchrb::DataSet
|
2
|
+
|
3
|
+
attr_reader :histogram, :model, :collection
|
4
|
+
def initialize set, model, collection
|
5
|
+
@set = set
|
6
|
+
@model = model
|
7
|
+
@collection = collection
|
8
|
+
end
|
9
|
+
|
10
|
+
def is_trainset?
|
11
|
+
@set == :train_set
|
12
|
+
end
|
13
|
+
|
14
|
+
def var_name
|
15
|
+
@set.to_s
|
16
|
+
end
|
17
|
+
|
18
|
+
def classes
|
19
|
+
model.classes
|
20
|
+
end
|
21
|
+
|
22
|
+
def load &progress_callback
|
23
|
+
@progress_callback = progress_callback
|
24
|
+
|
25
|
+
load_classes if is_trainset?
|
26
|
+
init_variables
|
27
|
+
do_load
|
28
|
+
|
29
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
30
|
+
setmetatable(#{var_name}, {__index = function(t, i)
|
31
|
+
return {t.input[i], t.label[i]}
|
32
|
+
end}
|
33
|
+
);
|
34
|
+
function #{var_name}:size()
|
35
|
+
return #{var_name}.input:size(1)
|
36
|
+
end
|
37
|
+
EOF
|
38
|
+
self
|
39
|
+
end
|
40
|
+
|
41
|
+
def dimensions
|
42
|
+
torch.eval("return (##{var_name}.input):totable()", __FILE__, __LINE__).values.map(&:to_i)
|
43
|
+
end
|
44
|
+
|
45
|
+
def torch
|
46
|
+
model.torch
|
47
|
+
end
|
48
|
+
|
49
|
+
private
|
50
|
+
def load_classes
|
51
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
52
|
+
classes = {#{classes.map(&:inspect).join ", "}}
|
53
|
+
EOF
|
54
|
+
end
|
55
|
+
|
56
|
+
def load_from_cache(cached_file)
|
57
|
+
torch.eval "#{var_name} = torch.load('#{cached_file}')", __FILE__, __LINE__
|
58
|
+
end
|
59
|
+
|
60
|
+
def do_load
|
61
|
+
values = collection.each_with_index.map do |data, index|
|
62
|
+
load_single(data, index)
|
63
|
+
end
|
64
|
+
@histogram = Hash[*values.group_by { |v| v }.flat_map { |k, v| [k, v.size] }]
|
65
|
+
end
|
66
|
+
|
67
|
+
def load_single(data, index)
|
68
|
+
@progress_callback.call
|
69
|
+
klass = model.prediction_class data
|
70
|
+
label_index = classes.index(klass)
|
71
|
+
raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
|
72
|
+
label_value = label_index+1
|
73
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
74
|
+
#{model.to_tensor("torchrb_data", data).strip}
|
75
|
+
#{var_name}.label[#{index+1}] = torch.LongTensor({#{label_value}})
|
76
|
+
#{var_name}.input[#{index+1}] = torchrb_data
|
77
|
+
EOF
|
78
|
+
klass
|
79
|
+
end
|
80
|
+
|
81
|
+
def cudify
|
82
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
83
|
+
#{var_name}.label:cuda()
|
84
|
+
#{var_name}.input:cuda()
|
85
|
+
EOF
|
86
|
+
end
|
87
|
+
|
88
|
+
def init_variables
|
89
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
90
|
+
#{var_name} = {
|
91
|
+
label= torch.LongTensor(#{collection.count}),
|
92
|
+
input= torch.#{model.tensor_type}(#{collection.count} , #{model.dimensions.join ", "})
|
93
|
+
}
|
94
|
+
EOF
|
95
|
+
end
|
96
|
+
|
97
|
+
def normalize!
|
98
|
+
if @is_trainset
|
99
|
+
torch.eval(<<-EOF, __FILE__, __LINE__).to_h.map { |k, v| {k.humanize => v.values} }.reduce({}, :merge)
|
100
|
+
mean = {} -- store the mean, to normalize the test set in the future
|
101
|
+
stdv = {} -- store the standard-deviation for the future
|
102
|
+
for i=1,#{model.dimensions.first-1} do -- over each image channel
|
103
|
+
mean[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:mean() -- mean estimation
|
104
|
+
stdv[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:std() -- std estimation
|
105
|
+
|
106
|
+
#{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
|
107
|
+
#{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
|
108
|
+
end
|
109
|
+
return {mean= mean, standard_diviation= stdv}
|
110
|
+
EOF
|
111
|
+
else
|
112
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
113
|
+
for i=1,#{model.dimensions.first-1} do -- over each image channel
|
114
|
+
#{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
|
115
|
+
#{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
|
116
|
+
end
|
117
|
+
EOF
|
118
|
+
end
|
119
|
+
end
|
120
|
+
|
121
|
+
end
|
data/lib/torchrb/lua.rb
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
class Torchrb::Lua
|
2
|
+
|
3
|
+
attr_accessor(:enable_cuda)
|
4
|
+
attr_accessor(:debug)
|
5
|
+
|
6
|
+
def initialize options={}
|
7
|
+
self.enable_cuda = options.delete(:enable_cuda) { false }
|
8
|
+
self.debug = options.delete(:debug) { false }
|
9
|
+
@additional_libraries = options.delete(:lua_libs){[]}
|
10
|
+
end
|
11
|
+
|
12
|
+
|
13
|
+
def eval(command, file=__FILE__, line=__LINE__, debug: false)
|
14
|
+
load_libraries unless @libraries_loaded
|
15
|
+
@last_command = command
|
16
|
+
puts command if debug || @debug
|
17
|
+
state.eval command, nil, file, line
|
18
|
+
end
|
19
|
+
|
20
|
+
protected
|
21
|
+
def state
|
22
|
+
@state ||= Rufus::Lua::State.new
|
23
|
+
end
|
24
|
+
|
25
|
+
private
|
26
|
+
def load_libraries
|
27
|
+
@libraries_loaded = true
|
28
|
+
#load "torch"
|
29
|
+
cudnn_lib = File.realpath("lib/packages/cuda/lib64/libcudnn.so")
|
30
|
+
raise "Extract your CUDNN to #{cudnn_lib}" if enable_cuda && !File.exists?(cudnn_lib)
|
31
|
+
if enable_cuda
|
32
|
+
load_cuda(cudnn_lib)
|
33
|
+
else
|
34
|
+
load "nn"
|
35
|
+
end
|
36
|
+
@additional_libraries.each do |lib|
|
37
|
+
load lib
|
38
|
+
end
|
39
|
+
|
40
|
+
#load "image"
|
41
|
+
#load "optim"
|
42
|
+
#load 'gnuplot'
|
43
|
+
#eval "pp = require 'pl.pretty'", __FILE__, __LINE__
|
44
|
+
load_error_handler
|
45
|
+
end
|
46
|
+
|
47
|
+
def load_cuda(cudnn_lib)
|
48
|
+
begin
|
49
|
+
eval(<<-EOF) #Load the libcudnn upfront so it is in the file cache and can be found later.
|
50
|
+
local ffi = require 'ffi'
|
51
|
+
ffi.cdef[[size_t cudnnGetVersion();]]
|
52
|
+
local cudnn = ffi.load("#{cudnn_lib}")
|
53
|
+
cuda_version = tonumber(cudnn.cudnnGetVersion())
|
54
|
+
EOF
|
55
|
+
load "cudnn"
|
56
|
+
load "cunn"
|
57
|
+
p "LOADED CUDNN VERSION: #{state["cuda_version"]}"
|
58
|
+
rescue
|
59
|
+
self.enable_cuda = false
|
60
|
+
load "nn"
|
61
|
+
end
|
62
|
+
end
|
63
|
+
|
64
|
+
def load_error_handler
|
65
|
+
@state.set_error_handler do |msg|
|
66
|
+
puts msg
|
67
|
+
level = 2
|
68
|
+
loop do
|
69
|
+
info = @state.eval "return debug.getinfo(#{level}, \"nSl\")"
|
70
|
+
break if info.nil?
|
71
|
+
line = info['currentline'].to_i
|
72
|
+
file, ln = *info['source'].split(":")
|
73
|
+
puts "\t#{file}:#{line + ln.to_i} (#{info['name']})"
|
74
|
+
level += 1
|
75
|
+
end
|
76
|
+
puts @last_command
|
77
|
+
end
|
78
|
+
end
|
79
|
+
|
80
|
+
def load lua_lib
|
81
|
+
eval "require '#{lua_lib}'", __FILE__, __LINE__
|
82
|
+
end
|
83
|
+
|
84
|
+
end
|
@@ -0,0 +1,108 @@
|
|
1
|
+
class Torchrb::ModelBase
|
2
|
+
REQUIRED_OPTIONS = [:data_model]
|
3
|
+
|
4
|
+
class << self
|
5
|
+
def progress_callback progress: nil, message: nil, error_rate: Float::NAN
|
6
|
+
raise NotImplementedError.new("Implement this method in your Model")
|
7
|
+
end
|
8
|
+
|
9
|
+
def setup_nn options={}
|
10
|
+
check_options(options)
|
11
|
+
{
|
12
|
+
net: Torchrb::NN::Basic,
|
13
|
+
trainer: Torchrb::NN::TrainerDefault,
|
14
|
+
tensor_type: "DoubleTensor",
|
15
|
+
dimensions: [0],
|
16
|
+
classes: [],
|
17
|
+
dataset_split: [80, 10, 10],
|
18
|
+
normalize: false,
|
19
|
+
enable_cuda: false,
|
20
|
+
auto_store_trained_network: true,
|
21
|
+
network_storage_path: "tmp/cache/torchrb",
|
22
|
+
debug: false,
|
23
|
+
}.merge!(options).each do |option, default|
|
24
|
+
cattr_reader(option)
|
25
|
+
class_variable_set(:"@@#{option}", default)
|
26
|
+
end
|
27
|
+
cattr_reader(:torch) { Torchrb::Torch.new options }
|
28
|
+
|
29
|
+
@net_options = load_extension(options[:net])
|
30
|
+
@trainer_options = load_extension(options[:trainer])
|
31
|
+
end
|
32
|
+
|
33
|
+
def error_rate
|
34
|
+
torch.error_rate
|
35
|
+
end
|
36
|
+
|
37
|
+
def train
|
38
|
+
progress_callback message: 'Loading data'
|
39
|
+
load_model_data
|
40
|
+
|
41
|
+
torch.iteration_callback= method(:progress_callback)
|
42
|
+
|
43
|
+
define_nn @net_options
|
44
|
+
define_trainer @trainer_options
|
45
|
+
|
46
|
+
torch.cudify if enable_cuda
|
47
|
+
|
48
|
+
progress_callback message: 'Start training'
|
49
|
+
torch.train
|
50
|
+
progress_callback message: 'Done'
|
51
|
+
|
52
|
+
torch.print_results
|
53
|
+
torch.store_network network_storage_path if auto_store_trained_network
|
54
|
+
|
55
|
+
after_training if respond_to?(:after_training)
|
56
|
+
torch.error_rate
|
57
|
+
end
|
58
|
+
|
59
|
+
def predict sample
|
60
|
+
torch.predict sample, network_storage_path
|
61
|
+
end
|
62
|
+
|
63
|
+
private
|
64
|
+
|
65
|
+
def check_options(options)
|
66
|
+
REQUIRED_OPTIONS.each do |required_option|
|
67
|
+
raise "Option '#{required_option}' is required." unless options.has_key?(required_option)
|
68
|
+
end
|
69
|
+
end
|
70
|
+
|
71
|
+
def load_model_data
|
72
|
+
raise "#{self} needs to implement '#to_tensor(var_name, data)' and '#prediction_class' method." unless respond_to?(:to_tensor, :prediction_class)
|
73
|
+
@progress = 0
|
74
|
+
start = 0
|
75
|
+
all_ids = data_model.ids.shuffle
|
76
|
+
[:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
|
77
|
+
next if split.nil?
|
78
|
+
size = all_ids.count * split.to_f / 100.0
|
79
|
+
offset = start
|
80
|
+
start = start + size
|
81
|
+
collection = data_model.where(id: all_ids.slice(offset, size))
|
82
|
+
load_dataset set, collection
|
83
|
+
end
|
84
|
+
end
|
85
|
+
|
86
|
+
def load_extension(extension)
|
87
|
+
if extension.is_a?(Hash)
|
88
|
+
extend extension.keys.first
|
89
|
+
extension.values.inject(&:merge)
|
90
|
+
else
|
91
|
+
extend extension
|
92
|
+
{}
|
93
|
+
end
|
94
|
+
end
|
95
|
+
|
96
|
+
def load_dataset set_name, collection
|
97
|
+
progress_callback progress: @progress, message: "Loading #{set_name.to_s.humanize} with #{collection.size} element(s)."
|
98
|
+
|
99
|
+
set = Torchrb::DataSet.new set_name, self, collection
|
100
|
+
set.load do
|
101
|
+
@progress += 0.333 / collection.size
|
102
|
+
progress_callback progress: @progress
|
103
|
+
end
|
104
|
+
set.normalize! if normalize && set.is_trainset?
|
105
|
+
end
|
106
|
+
end
|
107
|
+
|
108
|
+
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torchrb::NN::Basic
|
2
|
+
|
3
|
+
def define_nn
|
4
|
+
input_layer = 1
|
5
|
+
interm_layer = 80
|
6
|
+
output_layer = model.classes.size
|
7
|
+
torch.eval(<<-EOF, __FILE__, __LINE__).to_h
|
8
|
+
net = nn.Sequential()
|
9
|
+
net:add(nn.Linear(#{input_layer}, #{interm_layer}))
|
10
|
+
net:add(nn.Linear(#{interm_layer}, #{output_layer}))
|
11
|
+
net:add(nn.LogSoftMax())
|
12
|
+
EOF
|
13
|
+
end
|
14
|
+
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
module Torchrb::NN::ImageDefault
|
2
|
+
|
3
|
+
def define_nn options
|
4
|
+
# Dimensions:
|
5
|
+
# [4,256,256] INPUT
|
6
|
+
# -> SpatialConvolution(nInputPlane=4, nOutputPlane=6, kernelW=5, kH=5, dimensionW=1, dH=1) -- dimension(Width|Height) defaults to 1
|
7
|
+
# -> outWidth = (width - kernelWidth) * dimensionWidth + 1 = (256 - 5) * 1 + 1 = 252
|
8
|
+
# -> outHeight= (height- kernelHeight) *dimensionHeight + 1 = (256 - 5) * 1 + 1 = 252
|
9
|
+
# -> SpatialMaxPooling(2,2,2,2) -- pad(Width|Height) defaults to 0
|
10
|
+
# -> outWidth = (width + 2*padWidth - kernelWidth) / dimensionWidth + 1
|
11
|
+
|
12
|
+
image_width_height = options[:image_size].max
|
13
|
+
kernel_width = 5
|
14
|
+
input_layer = 120*2
|
15
|
+
interm_layer = 84*2
|
16
|
+
output_layer = 2
|
17
|
+
|
18
|
+
view_size = ((image_width_height - kernel_width) * 1 + 1)
|
19
|
+
view_size = view_size/2
|
20
|
+
view_size = ((view_size - kernel_width) * 1 + 1)
|
21
|
+
view_size = view_size/2
|
22
|
+
|
23
|
+
torch.eval(<<-EOF, __FILE__, __LINE__).to_h
|
24
|
+
net = nn.Sequential() -- [ 4,256,256] 3,32,32
|
25
|
+
net:add(nn.SpatialConvolution(4, 6, #{kernel_width}, #{kernel_width})) -- 4 input image channels, 6 output channels, 5x5 convolution kernel -> [ 6,252,252] 6,28,28
|
26
|
+
net:add(nn.SpatialMaxPooling(2,2,2,2)) -- A max-pooling operation that looks at 2x2 windows and finds the max. -> [ 6,126,126] 6,14,14
|
27
|
+
net:add(nn.SpatialConvolution(6, 16, 5, 5)) -- -> [16,122,122] 16,10,10
|
28
|
+
net:add(nn.SpatialMaxPooling(2,2,2,2)) -- -> [16, 61, 61] 16, 5, 5
|
29
|
+
net:add(nn.View(#{16 * view_size * view_size})) -- reshapes from a 4D tensor of 16x5x5 into 1D tensor of 16*5*5 -> [59536] 400
|
30
|
+
net:add(nn.Linear(#{16 * view_size * view_size}, #{input_layer}) ) -- fully connected layer (matrix multiplication between input and weights)-> 120 <-- randomly choosen
|
31
|
+
net:add(nn.Linear(#{input_layer}, #{interm_layer})) -- -> 84 <-- randomly choosen
|
32
|
+
net:add(nn.Linear(#{interm_layer}, #{output_layer})) -- 2 is the number of outputs of the network (in this case, 2 digits) -> 2 <-- number of labels
|
33
|
+
net:add(nn.LogSoftMax()) -- -> 1 <-- which label?
|
34
|
+
|
35
|
+
local d = {}
|
36
|
+
for i,module in ipairs(net:listModules()) do
|
37
|
+
inSize = "[]"
|
38
|
+
outSize = "[]"
|
39
|
+
pcall(function () inSize = #module.input end)
|
40
|
+
pcall(function () outSize = #module.output end)
|
41
|
+
table.insert(d, {tostring(module), inSize, outSize} )
|
42
|
+
end
|
43
|
+
return d
|
44
|
+
EOF
|
45
|
+
end
|
46
|
+
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torchrb::NN::TrainerDefault
|
2
|
+
|
3
|
+
def define_trainer options
|
4
|
+
torch.eval <<-EOF, __FILE__, __LINE__
|
5
|
+
number_of_iterations = #{options.fetch(:iterations){50}} -- Must be set for the callback to work
|
6
|
+
|
7
|
+
criterion = nn.ClassNLLCriterion()
|
8
|
+
|
9
|
+
trainer = nn.StochasticGradient(net, criterion)
|
10
|
+
trainer.verbose = false
|
11
|
+
trainer.learningRate = #{options.fetch(:learning_rate){0.001}}
|
12
|
+
trainer.maxIteration = number_of_iterations
|
13
|
+
trainer.hookIteration = iteration_callback
|
14
|
+
EOF
|
15
|
+
end
|
16
|
+
end
|
@@ -0,0 +1,116 @@
|
|
1
|
+
class Torchrb::Torch < Torchrb::Lua
|
2
|
+
|
3
|
+
attr_accessor(:network_loaded)
|
4
|
+
attr_accessor(:network_timestamp)
|
5
|
+
attr_accessor(:error_rate)
|
6
|
+
|
7
|
+
def initialize options={}
|
8
|
+
super
|
9
|
+
@network_loaded = false
|
10
|
+
@error_rate = Float::NAN
|
11
|
+
load_network options[:network_storage_path] unless network_loaded rescue nil
|
12
|
+
end
|
13
|
+
|
14
|
+
def iteration_callback= callback
|
15
|
+
state.function "iteration_callback" do |trainer, iteration, currentError|
|
16
|
+
progress = iteration / state['number_of_iterations']
|
17
|
+
self.error_rate = currentError/100.0
|
18
|
+
callback.call progress: progress, error_rate: error_rate
|
19
|
+
end
|
20
|
+
end
|
21
|
+
|
22
|
+
def train
|
23
|
+
eval <<-EOF, __FILE__, __LINE__
|
24
|
+
local oldprint = print
|
25
|
+
print = function(...)
|
26
|
+
end
|
27
|
+
|
28
|
+
trainer:train(train_set)
|
29
|
+
|
30
|
+
print = oldprint
|
31
|
+
EOF
|
32
|
+
self.network_loaded = true
|
33
|
+
self.network_timestamp = Time.now
|
34
|
+
end
|
35
|
+
|
36
|
+
def predict sample, network_storage_path=nil
|
37
|
+
load_network network_storage_path unless network_loaded
|
38
|
+
|
39
|
+
classes = eval <<-EOF, __FILE__, __LINE__
|
40
|
+
#{sample.to_tensor("sample_data").strip}
|
41
|
+
local prediction = #{enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"}
|
42
|
+
prediction = prediction:exp()
|
43
|
+
confidences = prediction:totable()
|
44
|
+
return classes
|
45
|
+
EOF
|
46
|
+
puts "predicted #{@state['confidences'].to_h} based on network @ #{network_timestamp}" if debug
|
47
|
+
classes = classes.to_h
|
48
|
+
@state['confidences'].to_h.map { |k, v| {classes[k] => v} }.reduce({}, :merge)
|
49
|
+
end
|
50
|
+
|
51
|
+
def load_network network_storage_path
|
52
|
+
raise "Neuronal net not trained yet. Call 'Torch#update_training_data'." unless File.exist?(network_storage_path)
|
53
|
+
metadata = eval(<<-EOF, __FILE__, __LINE__).to_ruby
|
54
|
+
net = torch.load('#{network_storage_path}')
|
55
|
+
metadata = torch.load('#{network_storage_path}.meta')
|
56
|
+
classes = metadata[1]
|
57
|
+
timestamp = metadata[3]
|
58
|
+
return metadata[2]
|
59
|
+
EOF
|
60
|
+
self.error_rate = metadata
|
61
|
+
self.network_timestamp = @state['timestamp']
|
62
|
+
puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
|
63
|
+
self.network_loaded = true
|
64
|
+
end
|
65
|
+
|
66
|
+
def store_network network_storage_path
|
67
|
+
eval <<-EOF, __FILE__, __LINE__
|
68
|
+
torch.save('#{network_storage_path}', net)
|
69
|
+
torch.save('#{network_storage_path}.meta', {classes, #{error_rate}, '#{network_timestamp}}'} )
|
70
|
+
EOF
|
71
|
+
puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
|
72
|
+
end
|
73
|
+
|
74
|
+
def print_results
|
75
|
+
result = eval <<-EOF, __FILE__, __LINE__
|
76
|
+
class_performance = torch.LongTensor(#classes):fill(0):totable()
|
77
|
+
test_set_size = test_set:size()
|
78
|
+
for i=1,test_set_size do
|
79
|
+
local groundtruth = test_set.label[i]
|
80
|
+
local prediction = net:forward(test_set.input[i])
|
81
|
+
local confidences, indices = torch.sort(prediction, true) -- true means sort in descending order
|
82
|
+
|
83
|
+
class_performance[groundtruth] = class_performance[groundtruth] + 1
|
84
|
+
|
85
|
+
end
|
86
|
+
|
87
|
+
local result = {}
|
88
|
+
for i=1,#classes do
|
89
|
+
local confidence = 100*class_performance[i]/test_set_size
|
90
|
+
table.insert(result, { classes[i], confidence } )
|
91
|
+
end
|
92
|
+
return result
|
93
|
+
EOF
|
94
|
+
result = result.to_ruby.map(&:to_ruby)
|
95
|
+
if defined?(DEBUG)
|
96
|
+
puts "#" * 80
|
97
|
+
puts "Results: #{result.to_h}"
|
98
|
+
puts "#" * 80
|
99
|
+
end
|
100
|
+
end
|
101
|
+
|
102
|
+
def cudify
|
103
|
+
eval <<-EOF, __FILE__, __LINE__
|
104
|
+
-- print(sys.COLORS.red .. '==> using CUDA GPU #' .. cutorch.getDevice() .. sys.COLORS.black)
|
105
|
+
train_set.input = train_set.input:cuda()
|
106
|
+
train_set.label = train_set.label:cuda()
|
107
|
+
test_set.input = test_set.input:cuda()
|
108
|
+
test_set.label = test_set.label:cuda()
|
109
|
+
validation_set.input = validation_set.input:cuda()
|
110
|
+
validation_set.label = validation_set.label:cuda()
|
111
|
+
|
112
|
+
criterion = nn.ClassNLLCriterion():cuda()
|
113
|
+
net = cudnn.convert(net:cuda(), cudnn)
|
114
|
+
EOF
|
115
|
+
end
|
116
|
+
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
class Torchrb::Wrapper < Torchrb::Torch
|
2
|
+
|
3
|
+
cattr_accessor(:instances){ Hash.new }
|
4
|
+
def self.for model_class, options={}
|
5
|
+
@@instances[model_class] ||= new model_class, options
|
6
|
+
if block_given?
|
7
|
+
yield @@instances[model_class]
|
8
|
+
else
|
9
|
+
@@instances[model_class]
|
10
|
+
end
|
11
|
+
end
|
12
|
+
|
13
|
+
attr_reader :model, :progress
|
14
|
+
delegate :progress_callback, to: :model
|
15
|
+
|
16
|
+
def initialize model, options={}
|
17
|
+
raise "#{model} must be a class and extend Torchrb::ModelBase!" unless model.is_a?(Class) || model.class < Torchrb::ModelBase
|
18
|
+
@model = model
|
19
|
+
super(options)
|
20
|
+
self.class.include model.net
|
21
|
+
self.class.include model.trainer
|
22
|
+
model.setup self
|
23
|
+
end
|
24
|
+
private(:initialize)
|
25
|
+
|
26
|
+
def load_model_data
|
27
|
+
@progress = 0
|
28
|
+
load_dataset :train_set
|
29
|
+
load_dataset :test_set
|
30
|
+
load_dataset :validation_set
|
31
|
+
end
|
32
|
+
|
33
|
+
def train
|
34
|
+
define_nn
|
35
|
+
define_trainer
|
36
|
+
|
37
|
+
cudify if enable_cuda
|
38
|
+
super
|
39
|
+
print_results
|
40
|
+
store_network
|
41
|
+
error_rate
|
42
|
+
end
|
43
|
+
|
44
|
+
def predict sample
|
45
|
+
super sample
|
46
|
+
end
|
47
|
+
|
48
|
+
private
|
49
|
+
def load_dataset set_name
|
50
|
+
set_size = model.send(set_name).size
|
51
|
+
model.progress_callback progress, message: "Loading #{set_name.to_s.humanize} with #{set_size} element(s)."
|
52
|
+
|
53
|
+
set = Torchrb::DataSet.new set_name, self
|
54
|
+
set.load do
|
55
|
+
@progress += 0.333 / set_size
|
56
|
+
model.progress_callback progress
|
57
|
+
end
|
58
|
+
set.normalize! if model.normalize? && set.is_trainset?
|
59
|
+
end
|
60
|
+
|
61
|
+
private
|
62
|
+
|
63
|
+
def engine_storage
|
64
|
+
Visit.cache_dir + "/net.t7"
|
65
|
+
end
|
66
|
+
|
67
|
+
end
|
data/torchrb.gemspec
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
lib = File.expand_path('../lib', __FILE__)
|
3
|
+
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
|
4
|
+
require 'torchrb/version'
|
5
|
+
|
6
|
+
Gem::Specification.new do |spec|
|
7
|
+
spec.name = "torchrb"
|
8
|
+
spec.version = Torchrb::VERSION
|
9
|
+
spec.authors = ["Michael Sprauer"]
|
10
|
+
spec.email = ['ms@inline.de']
|
11
|
+
spec.homepage = 'http://www.inline.de/'
|
12
|
+
spec.license = 'MIT'
|
13
|
+
|
14
|
+
spec.summary = %q(Torch wrapper for ruby)
|
15
|
+
spec.description = spec.summary + ' '
|
16
|
+
|
17
|
+
spec.files = `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
|
18
|
+
spec.test_files = spec.files.grep(%r{^test/})
|
19
|
+
spec.bindir = "exe"
|
20
|
+
spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
|
21
|
+
spec.require_paths = ["lib"]
|
22
|
+
#spec.required_ruby_version = '>= 2.1.0'
|
23
|
+
|
24
|
+
spec.add_runtime_dependency 'rufus-lua', '~> 1.1.2'
|
25
|
+
spec.add_runtime_dependency 'activesupport'
|
26
|
+
|
27
|
+
spec.add_development_dependency "bundler", "~> 1.11"
|
28
|
+
spec.add_development_dependency "rake", "~> 10.0"
|
29
|
+
spec.add_development_dependency "minitest", "~> 5.0"
|
30
|
+
spec.add_development_dependency "mocha"
|
31
|
+
spec.add_development_dependency "ascii_charts"
|
32
|
+
# gem.add_development_dependency 'ruby-prof' # <- don't need this normally.
|
33
|
+
end
|
metadata
ADDED
@@ -0,0 +1,161 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: torchrb
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.2.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Michael Sprauer
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2016-07-08 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: rufus-lua
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: 1.1.2
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: 1.1.2
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: activesupport
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: bundler
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - "~>"
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '1.11'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - "~>"
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '1.11'
|
55
|
+
- !ruby/object:Gem::Dependency
|
56
|
+
name: rake
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
58
|
+
requirements:
|
59
|
+
- - "~>"
|
60
|
+
- !ruby/object:Gem::Version
|
61
|
+
version: '10.0'
|
62
|
+
type: :development
|
63
|
+
prerelease: false
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
65
|
+
requirements:
|
66
|
+
- - "~>"
|
67
|
+
- !ruby/object:Gem::Version
|
68
|
+
version: '10.0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: minitest
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - "~>"
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '5.0'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - "~>"
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '5.0'
|
83
|
+
- !ruby/object:Gem::Dependency
|
84
|
+
name: mocha
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
86
|
+
requirements:
|
87
|
+
- - ">="
|
88
|
+
- !ruby/object:Gem::Version
|
89
|
+
version: '0'
|
90
|
+
type: :development
|
91
|
+
prerelease: false
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
93
|
+
requirements:
|
94
|
+
- - ">="
|
95
|
+
- !ruby/object:Gem::Version
|
96
|
+
version: '0'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: ascii_charts
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
111
|
+
description: 'Torch wrapper for ruby '
|
112
|
+
email:
|
113
|
+
- ms@inline.de
|
114
|
+
executables: []
|
115
|
+
extensions: []
|
116
|
+
extra_rdoc_files: []
|
117
|
+
files:
|
118
|
+
- ".gitignore"
|
119
|
+
- ".gitlab-ci.yml"
|
120
|
+
- ".gitmodules"
|
121
|
+
- Gemfile
|
122
|
+
- README.md
|
123
|
+
- Rakefile
|
124
|
+
- bin/console
|
125
|
+
- bin/setup
|
126
|
+
- lib/torchrb.rb
|
127
|
+
- lib/torchrb/data_set.rb
|
128
|
+
- lib/torchrb/lua.rb
|
129
|
+
- lib/torchrb/model_base.rb
|
130
|
+
- lib/torchrb/nn/basic.rb
|
131
|
+
- lib/torchrb/nn/image_default.rb
|
132
|
+
- lib/torchrb/nn/trainer_default.rb
|
133
|
+
- lib/torchrb/torch.rb
|
134
|
+
- lib/torchrb/version.rb
|
135
|
+
- lib/torchrb/wrapper.rb
|
136
|
+
- torchrb.gemspec
|
137
|
+
homepage: http://www.inline.de/
|
138
|
+
licenses:
|
139
|
+
- MIT
|
140
|
+
metadata: {}
|
141
|
+
post_install_message:
|
142
|
+
rdoc_options: []
|
143
|
+
require_paths:
|
144
|
+
- lib
|
145
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
146
|
+
requirements:
|
147
|
+
- - ">="
|
148
|
+
- !ruby/object:Gem::Version
|
149
|
+
version: '0'
|
150
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
151
|
+
requirements:
|
152
|
+
- - ">="
|
153
|
+
- !ruby/object:Gem::Version
|
154
|
+
version: '0'
|
155
|
+
requirements: []
|
156
|
+
rubyforge_project:
|
157
|
+
rubygems_version: 2.5.1
|
158
|
+
signing_key:
|
159
|
+
specification_version: 4
|
160
|
+
summary: Torch wrapper for ruby
|
161
|
+
test_files: []
|