torchrb 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +9 -0
- data/.gitlab-ci.yml +22 -0
- data/.gitmodules +3 -0
- data/Gemfile +4 -0
- data/README.md +34 -0
- data/Rakefile +10 -0
- data/bin/console +14 -0
- data/bin/setup +10 -0
- data/lib/torchrb.rb +32 -0
- data/lib/torchrb/data_set.rb +121 -0
- data/lib/torchrb/lua.rb +84 -0
- data/lib/torchrb/model_base.rb +108 -0
- data/lib/torchrb/nn/basic.rb +14 -0
- data/lib/torchrb/nn/image_default.rb +46 -0
- data/lib/torchrb/nn/trainer_default.rb +16 -0
- data/lib/torchrb/torch.rb +116 -0
- data/lib/torchrb/version.rb +5 -0
- data/lib/torchrb/wrapper.rb +67 -0
- data/torchrb.gemspec +33 -0
- metadata +161 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA1:
|
|
3
|
+
metadata.gz: 1ee6efb82ee36579771698f9f98c889a4589c335
|
|
4
|
+
data.tar.gz: 8719bd495d8c195968c92ca7af17a1aebc911ebb
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: a1c2f23dbc99c21d72a4601d1feb4691304da174f752ef6c72849fc2b6a93ebd370c30d31b440197ba83c3fc634b1ce381303df2e7b83d4017db062c7872099f
|
|
7
|
+
data.tar.gz: c4ce2f992d97f40c4ac4d1930aa55047ed308104286e3bd87e1f816ce780e2cfa6f19a2ce78d3dde6d1a03c36d8644224fe26eb3d363fd298a1e66d4c28b62c5
|
data/.gitignore
ADDED
data/.gitlab-ci.yml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
image: ruby_nodejs:2.3.0
|
|
2
|
+
|
|
3
|
+
stages:
|
|
4
|
+
- build
|
|
5
|
+
- test
|
|
6
|
+
|
|
7
|
+
build:
|
|
8
|
+
tags:
|
|
9
|
+
- ruby 2.3.0
|
|
10
|
+
stage: build
|
|
11
|
+
script:
|
|
12
|
+
- mkdir -p /cache/bundle
|
|
13
|
+
- bash -l -c "cd lib/torch && ./install.sh"
|
|
14
|
+
- bash -l -c "RAILS_ENV=test bundle install --jobs $(nproc) --path /cache/bundle"
|
|
15
|
+
|
|
16
|
+
test:
|
|
17
|
+
tags:
|
|
18
|
+
- ruby 2.3.0
|
|
19
|
+
except:
|
|
20
|
+
- tags
|
|
21
|
+
script:
|
|
22
|
+
- bundle exec rake test
|
data/.gitmodules
ADDED
data/Gemfile
ADDED
data/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# Torchrb
|
|
2
|
+
|
|
3
|
+
A simple Torch wrapper for Ruby, with optional CUDA support.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
Add this line to your application's Gemfile:
|
|
8
|
+
|
|
9
|
+
```ruby
|
|
10
|
+
gem 'torchrb'
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
And then execute:
|
|
14
|
+
|
|
15
|
+
$ bundle
|
|
16
|
+
|
|
17
|
+
Or install it yourself as:
|
|
18
|
+
|
|
19
|
+
$ gem install torchrb
|
|
20
|
+
|
|
21
|
+
## Usage
|
|
22
|
+
|
|
23
|
+
TODO: Write usage instructions here
|
|
24
|
+
|
|
25
|
+
## Development
|
|
26
|
+
|
|
27
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
|
28
|
+
|
|
29
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
|
30
|
+
|
|
31
|
+
## Contributing
|
|
32
|
+
|
|
33
|
+
Bug reports and pull requests are welcome at https://git.inline.de/inline-de/torchrb
|
|
34
|
+
|
data/Rakefile
ADDED
data/bin/console
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
#!/usr/bin/env ruby

# Development console: loads the bundle and the gem, then drops into IRB.
%w[bundler/setup torchrb].each { |library| require library }

# Put fixtures or other initialization code here to make experimenting with
# the gem easier. Any console works; to use Pry instead of IRB (remember to
# add pry to the Gemfile), replace the two lines below with:
#   require "pry"
#   Pry.start

require "irb"
IRB.start
|
data/bin/setup
ADDED
data/lib/torchrb.rb
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
require 'active_support/core_ext/module'
|
|
2
|
+
require 'digest'
|
|
3
|
+
require 'rufus-lua'
|
|
4
|
+
|
|
5
|
+
require "torchrb/version"
|
|
6
|
+
|
|
7
|
+
require 'torchrb/lua'
|
|
8
|
+
require 'torchrb/torch'
|
|
9
|
+
require 'torchrb/wrapper'
|
|
10
|
+
require 'torchrb/data_set'
|
|
11
|
+
require 'torchrb/model_base'
|
|
12
|
+
require 'torchrb/nn/basic'
|
|
13
|
+
require 'torchrb/nn/trainer_default'
|
|
14
|
+
require 'torchrb/nn/image_default'
|
|
15
|
+
|
|
16
|
+
# Convenience entry points that route a model class (or a sample of it) to
# its memoised Torchrb::Wrapper.
module Torchrb
  module_function

  # Loads the data of +model_class+ into its wrapper and trains the network.
  #
  # @param model_class [Class] the model to train
  # @return whatever the wrapper's #train returns
  def train(model_class)
    Torchrb::Wrapper.for(model_class) do |wrapper|
      wrapper.load_model_data
      wrapper.train
    end
  end

  # Runs +sample+ through the network belonging to the sample's class.
  #
  # @param sample [Object] an instance of a trained model class
  # @return whatever the wrapper's #predict returns
  def predict(sample)
    Torchrb::Wrapper.for(sample.class) { |wrapper| wrapper.predict(sample) }
  end
end
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
module Torchrb
  # One split of a model's samples (train, test or validation set) mirrored
  # into the model's Lua state as a global table named after the split:
  # `{label = torch.LongTensor(n), input = torch.<tensor_type>(n, ...)}`.
  class DataSet
    attr_reader :histogram, :model, :collection

    # @param set [Symbol] split name (:train_set, :test_set, :validation_set);
    #   also used as the name of the Lua global
    # @param model [Class] the model class; must respond to #torch, #classes,
    #   #prediction_class, #to_tensor, #tensor_type and #dimensions
    # @param collection [Enumerable] the samples to load
    def initialize(set, model, collection)
      @set = set
      @model = model
      @collection = collection
    end

    # @return [Boolean] whether this is the training split
    def is_trainset?
      @set == :train_set
    end

    # @return [String] name of the Lua global holding this split
    def var_name
      @set.to_s
    end

    # @return [Array] the model's class labels
    def classes
      model.classes
    end

    # Copies every sample of +collection+ into the Lua table, then gives the
    # table the `set[i] -> {input, label}` indexing and `set:size()` method
    # that the Lua trainer (`trainer:train(train_set)`) consumes.
    #
    # @yield once per sample, before it is converted (progress reporting); optional
    # @return [DataSet] self
    def load(&progress_callback)
      @progress_callback = progress_callback

      load_classes if is_trainset?
      init_variables
      do_load

      torch.eval <<-EOF, __FILE__, __LINE__
        setmetatable(#{var_name}, {__index = function(t, i)
            return {t.input[i], t.label[i]}
          end}
        );
        function #{var_name}:size()
          return #{var_name}.input:size(1)
        end
      EOF
      self
    end

    # @return [Array<Integer>] the sizes of the input tensor, as reported by Lua
    def dimensions
      torch.eval("return (##{var_name}.input):totable()", __FILE__, __LINE__).values.map(&:to_i)
    end

    # @return the model's Torch/Lua runtime
    def torch
      model.torch
    end

    # Standardises the input tensor channel by channel (mean subtraction,
    # std scaling). The train set computes the statistics and keeps them in
    # the Lua globals `mean` and `stdv`; every other set reuses those globals,
    # so the train set must be normalized first. Public because ModelBase
    # calls it on the loaded set.
    #
    # @return [Hash, nil] for the train set, the humanized Lua keys
    #   ("Mean", "Standard deviation") mapped to per-channel values
    def normalize!
      # NOTE(review): channels 1..(dimensions.first - 1) are normalized, i.e.
      # the last channel is skipped although the Lua comment says "each image
      # channel" - confirm whether that is intentional (e.g. an alpha channel).
      last_channel = model.dimensions.first - 1
      if is_trainset?
        torch.eval(<<-EOF, __FILE__, __LINE__).to_h.map { |k, v| { k.humanize => v.values } }.reduce({}, :merge)
          mean = {} -- store the mean, to normalize the test set in the future
          stdv = {} -- store the standard-deviation for the future
          for i=1,#{last_channel} do -- over each image channel
            mean[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:mean() -- mean estimation
            stdv[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:std() -- std estimation

            #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
            #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
          end
          return {mean= mean, standard_deviation= stdv}
        EOF
      else
        torch.eval <<-EOF, __FILE__, __LINE__
          for i=1,#{last_channel} do -- over each image channel
            #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
            #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
          end
        EOF
      end
    end

    private

    # Publishes the label names as the Lua global `classes`.
    def load_classes
      torch.eval <<-EOF, __FILE__, __LINE__
        classes = {#{classes.map(&:inspect).join ", "}}
      EOF
    end

    # Replaces this split with a tensor table previously saved with torch.save.
    def load_from_cache(cached_file)
      torch.eval "#{var_name} = torch.load('#{cached_file}')", __FILE__, __LINE__
    end

    # Loads every sample and records how many samples fell into each class.
    def do_load
      labels = collection.each_with_index.map { |data, index| load_single(data, index) }
      @histogram = labels.group_by(&:itself).map { |klass, hits| [klass, hits.size] }.to_h
    end

    # Converts one sample to a tensor and stores it with its 1-based label.
    #
    # @return the class label the model assigned to +data+
    # @raise [RuntimeError] if that label is not one of #classes
    def load_single(data, index)
      @progress_callback.call if @progress_callback
      klass = model.prediction_class data
      label_index = classes.index(klass)
      raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
      label_value = label_index + 1 # Lua/Torch labels are 1-based
      torch.eval <<-EOF, __FILE__, __LINE__
        #{model.to_tensor("torchrb_data", data).strip}
        #{var_name}.label[#{index + 1}] = torch.LongTensor({#{label_value}})
        #{var_name}.input[#{index + 1}] = torchrb_data
      EOF
      klass
    end

    # Moves this split to the GPU. `:cuda()` returns a new tensor, so the
    # result has to be assigned back.
    def cudify
      torch.eval <<-EOF, __FILE__, __LINE__
        #{var_name}.label = #{var_name}.label:cuda()
        #{var_name}.input = #{var_name}.input:cuda()
      EOF
    end

    # Allocates the label and input tensors for the whole collection.
    def init_variables
      torch.eval <<-EOF, __FILE__, __LINE__
        #{var_name} = {
          label= torch.LongTensor(#{collection.count}),
          input= torch.#{model.tensor_type}(#{collection.count} , #{model.dimensions.join ", "})
        }
      EOF
    end
  end
end
|
data/lib/torchrb/lua.rb
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
module Torchrb
  # Owns a Rufus::Lua::State and evaluates Lua code in it. On the first #eval
  # the Torch libraries ("nn", or cuDNN/"cudnn"/"cunn" when CUDA is enabled)
  # plus any extra :lua_libs are required, and a Ruby-side error handler that
  # prints a Lua stack trace is installed.
  class Lua
    # cuDNN shared library location, relative to the current working directory.
    CUDNN_LIBRARY_PATH = "lib/packages/cuda/lib64/libcudnn.so".freeze

    attr_accessor(:enable_cuda)
    attr_accessor(:debug)

    # @param options [Hash]
    #   :enable_cuda [Boolean] try to load cuDNN/cunn (default false)
    #   :debug       [Boolean] echo every evaluated command (default false)
    #   :lua_libs    [Array<String>] extra Lua modules to require (default [])
    # The hash is only read, never modified, so callers may keep using it.
    def initialize(options = {})
      self.enable_cuda = options.fetch(:enable_cuda, false)
      self.debug = options.fetch(:debug, false)
      @additional_libraries = options.fetch(:lua_libs, [])
    end

    # Evaluates +command+ in the Lua state, loading the libraries first if
    # that has not happened yet.
    #
    # @param command [String] Lua source
    # @param file [String] file name handed to Rufus for error locations
    # @param line [Integer] line number handed to Rufus for error locations
    # @param debug [Boolean] echo this command even when #debug is off
    # @return the value returned by the Lua chunk, as converted by Rufus
    def eval(command, file = __FILE__, line = __LINE__, debug: false)
      load_libraries unless @libraries_loaded
      @last_command = command
      puts command if debug || @debug
      state.eval command, nil, file, line
    end

    protected

    # @return [Rufus::Lua::State] the lazily created Lua state
    def state
      @state ||= Rufus::Lua::State.new
    end

    private

    # Requires the Torch libraries once. The cuDNN path is only resolved when
    # CUDA is enabled: File.realpath raises for a missing file, which used to
    # break CPU-only setups.
    #
    # @raise [RuntimeError] if CUDA is enabled but the cuDNN library is missing
    def load_libraries
      @libraries_loaded = true
      #load "torch"
      if enable_cuda
        expected_cudnn = File.expand_path(CUDNN_LIBRARY_PATH)
        raise "Extract your CUDNN to #{expected_cudnn}" unless File.exist?(expected_cudnn)
        load_cuda(File.realpath(expected_cudnn))
      else
        load "nn"
      end
      @additional_libraries.each { |lib| load lib }

      #load "image"
      #load "optim"
      #load 'gnuplot'
      #eval "pp = require 'pl.pretty'", __FILE__, __LINE__
      load_error_handler
    end

    # Loads cuDNN/cunn. On any failure, CUDA is disabled and the CPU "nn"
    # package is loaded instead (best effort); the reason is printed in debug mode.
    def load_cuda(cudnn_lib)
      # Load libcudnn up front so it is in the file cache and can be found later.
      eval(<<-EOF, __FILE__, __LINE__)
        local ffi = require 'ffi'
        ffi.cdef[[size_t cudnnGetVersion();]]
        local cudnn = ffi.load("#{cudnn_lib}")
        cuda_version = tonumber(cudnn.cudnnGetVersion())
      EOF
      load "cudnn"
      load "cunn"
      p "LOADED CUDNN VERSION: #{state["cuda_version"]}"
    rescue StandardError => e
      warn "CUDA initialisation failed (#{e.message}); falling back to CPU." if debug
      self.enable_cuda = false
      load "nn"
    end

    # Installs an error handler that prints the Lua error message, the Lua
    # call stack (mapped back to the Ruby file/line each chunk came from) and
    # the last evaluated command.
    def load_error_handler
      state.set_error_handler do |msg|
        puts msg
        level = 2
        loop do
          info = state.eval "return debug.getinfo(#{level}, \"nSl\")"
          break if info.nil?
          line = info['currentline'].to_i
          # Chunk names have the form "<ruby file>:<ruby line>".
          file, ln = *info['source'].split(":")
          puts "\t#{file}:#{line + ln.to_i} (#{info['name']})"
          level += 1
        end
        puts @last_command
      end
    end

    # Requires a Lua module by name.
    def load(lua_lib)
      eval "require '#{lua_lib}'", __FILE__, __LINE__
    end
  end
end
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
module Torchrb
  # Base class for user models. A subclass calls ::setup_nn once to configure
  # itself, then ::train / ::predict. Subclasses must implement the class
  # methods ::to_tensor(var_name, data), ::prediction_class(data) and
  # ::progress_callback.
  class ModelBase
    REQUIRED_OPTIONS = [:data_model].freeze

    class << self
      # Progress hook called while loading data and during training.
      #
      # @raise [NotImplementedError] always, in this base implementation
      def progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
        raise NotImplementedError.new("Implement this method in your Model")
      end

      # Configures the model. Every option (merged over the defaults below)
      # becomes a class-level reader, a Torchrb::Torch runtime is created and
      # exposed as ::torch, and the :net / :trainer extensions are mixed in.
      #
      # @param options [Hash] must include :data_model
      # @raise [RuntimeError] if a required option is missing
      def setup_nn(options = {})
        check_options(options)
        config = {
          net: Torchrb::NN::Basic,
          trainer: Torchrb::NN::TrainerDefault,
          tensor_type: "DoubleTensor",
          dimensions: [0],
          classes: [],
          dataset_split: [80, 10, 10],
          normalize: false,
          enable_cuda: false,
          auto_store_trained_network: true,
          network_storage_path: "tmp/cache/torchrb",
          debug: false,
        }.merge!(options)
        config.each do |option, value|
          cattr_reader(option)
          class_variable_set(:"@@#{option}", value)
        end
        # The runtime sees the same effective settings (including defaults) as the model.
        cattr_reader(:torch) { Torchrb::Torch.new config }

        # Use the merged config so the default extensions apply when the
        # caller omits :net / :trainer (options[:net] would be nil).
        @net_options = load_extension(config[:net])
        @trainer_options = load_extension(config[:trainer])
      end

      # @return [Float] the runtime's current error rate
      def error_rate
        torch.error_rate
      end

      # Loads all data splits, builds network and trainer, trains, evaluates
      # on the test set and (optionally) stores the network.
      #
      # @return [Float] the final error rate
      def train
        progress_callback message: 'Loading data'
        load_model_data

        torch.iteration_callback = method(:progress_callback)

        define_nn @net_options
        define_trainer @trainer_options

        # The runtime switches enable_cuda off when CUDA failed to initialise,
        # so consult it rather than the configured flag.
        torch.cudify if torch.enable_cuda

        progress_callback message: 'Start training'
        torch.train
        progress_callback message: 'Done'

        torch.print_results
        torch.store_network network_storage_path if auto_store_trained_network

        after_training if respond_to?(:after_training)
        torch.error_rate
      end

      # @return [Hash] label => confidence for +sample+
      def predict(sample)
        torch.predict sample, network_storage_path
      end

      private

      # @raise [RuntimeError] naming the first missing required option
      def check_options(options)
        REQUIRED_OPTIONS.each do |required_option|
          raise "Option '#{required_option}' is required." unless options.has_key?(required_option)
        end
      end

      # Shuffles all ids of +data_model+ and partitions them into the train,
      # test and validation sets according to +dataset_split+ (percentages).
      # Boundaries are computed from the rounded cumulative percentage, so the
      # sets never overlap and, when the split sums to 100, cover every id.
      #
      # @raise [RuntimeError] unless ::to_tensor and ::prediction_class are implemented
      def load_model_data
        # respond_to? takes one method name (its 2nd argument is include_all), so check each.
        unless %i[to_tensor prediction_class].all? { |name| respond_to?(name) }
          raise "#{self} needs to implement '#to_tensor(var_name, data)' and '#prediction_class' method."
        end
        @progress = 0
        all_ids = data_model.ids.shuffle
        total = all_ids.count
        cumulative_percent = 0.0
        offset = 0
        [:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
          next if split.nil?
          cumulative_percent += split.to_f
          boundary = [(total * cumulative_percent / 100.0).round, total].min
          ids = all_ids.slice(offset, [boundary - offset, 0].max)
          offset = [boundary, offset].max
          load_dataset set, data_model.where(id: ids)
        end
      end

      # Mixes +extension+ into the model class.
      #
      # @param extension [Module, Hash{Module => Hash}] a module, or a
      #   one-entry hash of module => options
      # @return [Hash] the extension's options ({} for a bare module)
      def load_extension(extension)
        if extension.is_a?(Hash)
          extend extension.keys.first
          extension.values.inject(&:merge)
        else
          extend extension
          {}
        end
      end

      # Loads one split into the runtime, reporting progress per element, and
      # normalizes it when requested. Splits are loaded train-first, so the
      # other splits are normalized with the train set's statistics.
      def load_dataset(set_name, collection)
        progress_callback progress: @progress, message: "Loading #{set_name.to_s.humanize} with #{collection.size} element(s)."

        set = Torchrb::DataSet.new set_name, self, collection
        set.load do
          @progress += 0.333 / collection.size
          progress_callback progress: @progress
        end
        set.normalize! if normalize
      end
    end
  end
end
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
module Torchrb
  module NN
    # Minimal fully-connected classifier extension: 1 input feature ->
    # 80 hidden units -> one output per class, followed by LogSoftMax.
    # Defines the Lua global `net` through the host's #torch.
    module Basic
      # Builds the network.
      #
      # @param _options [Hash] accepted and ignored; ModelBase.train passes
      #   the net options to every network extension
      # @return [Hash] the (empty) result of the Lua chunk
      def define_nn(_options = {})
        input_layer = 1
        interm_layer = 80
        # Extended into a ModelBase class, the labels are the class-level
        # `classes`; a host exposing `model` (e.g. Torchrb::Wrapper) delegates to it.
        label_owner = respond_to?(:model, true) ? model : self
        output_layer = label_owner.classes.size
        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so they compose to a single linear map - confirm intent.
        torch.eval(<<-EOF, __FILE__, __LINE__).to_h
          net = nn.Sequential()
          net:add(nn.Linear(#{input_layer}, #{interm_layer}))
          net:add(nn.Linear(#{interm_layer}, #{output_layer}))
          net:add(nn.LogSoftMax())
        EOF
      end
    end
  end
end
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
module Torchrb
  module NN
    # LeNet-style convolutional network extension for image inputs. Defines
    # the Lua global `net` through the host's #torch.
    module ImageDefault
      # Dimensions (defaults, for a [4,256,256] input):
      # [4,256,256] INPUT
      # -> SpatialConvolution(nInputPlane=4, nOutputPlane=6, kernelW=5, kH=5, dimensionW=1, dH=1) -- dimension(Width|Height) defaults to 1
      #    -> outWidth = (width - kernelWidth) * dimensionWidth + 1 = (256 - 5) * 1 + 1 = 252
      #    -> outHeight= (height- kernelHeight) *dimensionHeight + 1 = (256 - 5) * 1 + 1 = 252
      # -> SpatialMaxPooling(2,2,2,2) -- pad(Width|Height) defaults to 0
      #    -> outWidth = (width + 2*padWidth - kernelWidth) / dimensionWidth + 1
      #
      # @param options [Hash]
      #   :image_size  [Array<Integer>] (required) the largest entry is used as
      #                both width and height when sizing the View layer
      #   :channels    [Integer] input planes (default 4)
      #   :kernel_size [Integer] width/height of both convolution kernels (default 5)
      #   :output_size [Integer] number of outputs/labels (default 2)
      # @return [Hash] per-layer summary rows {description, input size, output size}
      #   returned by the Lua chunk, converted with #to_h
      # @raise [KeyError] if :image_size is missing
      def define_nn(options)
        image_width_height = options.fetch(:image_size).max
        channels = options.fetch(:channels, 4)
        kernel_width = options.fetch(:kernel_size, 5)
        input_layer = 120 * 2
        interm_layer = 84 * 2
        output_layer = options.fetch(:output_size, 2)

        # Two rounds of (valid convolution, stride 1) followed by 2x2 max pooling.
        view_size = image_width_height
        2.times { view_size = (view_size - kernel_width + 1) / 2 }
        flat_size = 16 * view_size * view_size

        # NOTE(review): no ReLU/non-linearity follows the convolutions or the
        # Linear layers (unlike the usual LeNet layout) - confirm intent.
        torch.eval(<<-EOF, __FILE__, __LINE__).to_h
          net = nn.Sequential() -- [ 4,256,256] 3,32,32
          net:add(nn.SpatialConvolution(#{channels}, 6, #{kernel_width}, #{kernel_width})) -- input image channels, 6 output channels, 5x5 convolution kernel -> [ 6,252,252] 6,28,28
          net:add(nn.SpatialMaxPooling(2,2,2,2)) -- A max-pooling operation that looks at 2x2 windows and finds the max. -> [ 6,126,126] 6,14,14
          net:add(nn.SpatialConvolution(6, 16, #{kernel_width}, #{kernel_width})) -- -> [16,122,122] 16,10,10
          net:add(nn.SpatialMaxPooling(2,2,2,2)) -- -> [16, 61, 61] 16, 5, 5
          net:add(nn.View(#{flat_size})) -- reshapes from a 4D tensor of 16x5x5 into 1D tensor of 16*5*5 -> [59536] 400
          net:add(nn.Linear(#{flat_size}, #{input_layer}) ) -- fully connected layer (matrix multiplication between input and weights)-> 120 <-- randomly chosen
          net:add(nn.Linear(#{input_layer}, #{interm_layer})) -- -> 84 <-- randomly chosen
          net:add(nn.Linear(#{interm_layer}, #{output_layer})) -- number of outputs of the network -> number of labels
          net:add(nn.LogSoftMax()) -- -> which label?

          local d = {}
          for i,module in ipairs(net:listModules()) do
            local inSize = "[]"
            local outSize = "[]"
            pcall(function () inSize = #module.input end)
            pcall(function () outSize = #module.output end)
            table.insert(d, {tostring(module), inSize, outSize} )
          end
          return d
        EOF
      end
    end
  end
end
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
module Torchrb
  module NN
    # Default trainer extension: nn.StochasticGradient with a
    # ClassNLLCriterion, configured through the host's #torch.
    module TrainerDefault
      DEFAULT_ITERATIONS = 50
      DEFAULT_LEARNING_RATE = 0.001

      # Defines the Lua globals `number_of_iterations`, `criterion` and
      # `trainer`. Expects the Lua globals `net` and (optionally)
      # `iteration_callback` to exist already.
      #
      # @param options [Hash] :iterations (default 50), :learning_rate
      #   (default 0.001); optional because Torchrb::Wrapper#train calls this
      #   without arguments
      def define_trainer(options = {})
        torch.eval <<-EOF, __FILE__, __LINE__
          number_of_iterations = #{options.fetch(:iterations, DEFAULT_ITERATIONS)} -- Must be set for the callback to work

          criterion = nn.ClassNLLCriterion()

          trainer = nn.StochasticGradient(net, criterion)
          trainer.verbose = false
          trainer.learningRate = #{options.fetch(:learning_rate, DEFAULT_LEARNING_RATE)}
          trainer.maxIteration = number_of_iterations
          trainer.hookIteration = iteration_callback
        EOF
      end
    end
  end
end
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# Torch-level operations (training, prediction, persistence, evaluation) on
# top of the Lua state managed by Torchrb::Lua. Methods here read Lua globals
# such as `net`, `trainer`, `criterion`, `classes`, `train_set` and
# `test_set`, which the NN/trainer extensions and Torchrb::DataSet define.
class Torchrb::Torch < Torchrb::Lua

  attr_accessor(:network_loaded)
  attr_accessor(:network_timestamp)
  attr_accessor(:error_rate)

  # @param options [Hash] forwarded to Torchrb::Lua#initialize; if
  #   :network_storage_path names an existing file, that network is loaded
  #   eagerly (best effort - failures are only reported in debug mode)
  def initialize(options = {})
    super
    @network_loaded = false
    @error_rate = Float::NAN
    try_load_network(options[:network_storage_path])
  end

  # Registers +callback+ as the Lua function `iteration_callback` (used as
  # the trainer's hookIteration). It receives keyword arguments +progress+
  # (iteration / number_of_iterations) and +error_rate+ (current error / 100).
  def iteration_callback=(callback)
    state.function "iteration_callback" do |_trainer, iteration, current_error|
      progress = iteration / state['number_of_iterations']
      self.error_rate = current_error / 100.0
      callback.call progress: progress, error_rate: error_rate
    end
  end

  # Trains `net` on `train_set` with Lua `print` silenced. `print` is
  # restored even when training raises; the error is then re-raised.
  def train
    eval <<-EOF, __FILE__, __LINE__
      local oldprint = print
      print = function(...)
      end

      local ok, err = pcall(function() trainer:train(train_set) end)

      print = oldprint
      if not ok then error(err, 0) end
    EOF
    self.network_loaded = true
    self.network_timestamp = Time.now
  end

  # Classifies +sample+, loading the stored network first if necessary.
  #
  # @param sample [#to_tensor] must emit Lua defining `sample_data`
  # @param network_storage_path [String, nil] where a stored network lives
  # @return [Hash] label => confidence (exp of the LogSoftMax output)
  def predict(sample, network_storage_path = nil)
    load_network network_storage_path unless network_loaded

    classes = eval <<-EOF, __FILE__, __LINE__
      #{sample.to_tensor("sample_data").strip}
      local prediction = #{enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"}
      prediction = prediction:exp()
      confidences = prediction:totable()
      return classes
    EOF
    confidences = state['confidences'].to_h
    puts "predicted #{confidences} based on network @ #{network_timestamp}" if debug
    labels = classes.to_h
    confidences.map { |k, v| { labels[k] => v } }.reduce({}, :merge)
  end

  # Loads `net`, `classes`, `timestamp` and the error rate from
  # +network_storage_path+ and +network_storage_path+.meta.
  #
  # @raise [RuntimeError] if the path is nil or does not exist
  def load_network(network_storage_path)
    unless network_storage_path && File.exist?(network_storage_path)
      raise "Neuronal net not trained yet. Call 'Torch#update_training_data'."
    end
    stored_error_rate = eval(<<-EOF, __FILE__, __LINE__)
      net = torch.load('#{network_storage_path}')
      metadata = torch.load('#{network_storage_path}.meta')
      classes = metadata[1]
      timestamp = metadata[3]
      return metadata[2]
    EOF
    # Lua numbers arrive as plain Ruby numbers; only tables need #to_ruby.
    stored_error_rate = stored_error_rate.to_ruby if stored_error_rate.respond_to?(:to_ruby)
    self.error_rate = stored_error_rate.nil? ? Float::NAN : stored_error_rate
    self.network_timestamp = state['timestamp']
    puts "Network with metadata [#{state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
    self.network_loaded = true
  end

  # Saves `net` to +network_storage_path+ and {classes, error rate,
  # timestamp} to +network_storage_path+.meta.
  def store_network(network_storage_path)
    eval <<-EOF, __FILE__, __LINE__
      torch.save('#{network_storage_path}', net)
      torch.save('#{network_storage_path}.meta', {classes, #{lua_number_literal(error_rate)}, '#{network_timestamp}'} )
    EOF
    puts "Network with metadata [#{state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
  end

  # Evaluates `net` on `test_set` and computes, per class, the percentage of
  # that class's test samples whose top-1 prediction was correct (0 for a
  # class without test samples).
  #
  # @return [Array<Array>] [label, accuracy percentage] pairs
  def print_results
    result = eval <<-EOF, __FILE__, __LINE__
      local class_correct = torch.LongTensor(#classes):fill(0):totable()
      local class_total = torch.LongTensor(#classes):fill(0):totable()
      local test_set_size = test_set:size()
      for i=1,test_set_size do
        local groundtruth = test_set.label[i]
        local prediction = net:forward(test_set.input[i])
        local confidences, indices = torch.sort(prediction, true) -- true means sort in descending order

        class_total[groundtruth] = class_total[groundtruth] + 1
        if groundtruth == indices[1] then
          class_correct[groundtruth] = class_correct[groundtruth] + 1
        end
      end

      local result = {}
      for i=1,#classes do
        local confidence = 0
        if class_total[i] > 0 then
          confidence = 100*class_correct[i]/class_total[i]
        end
        table.insert(result, { classes[i], confidence } )
      end
      return result
    EOF
    result = result.to_ruby.map(&:to_ruby)
    if debug || defined?(DEBUG)
      puts "#" * 80
      puts "Results: #{result.to_h}"
      puts "#" * 80
    end
    result
  end

  # Moves all data splits, the criterion and the network to the GPU and
  # rebinds an already-defined trainer to the converted net and criterion
  # (otherwise it would keep training the CPU objects).
  def cudify
    eval <<-EOF, __FILE__, __LINE__
      -- print(sys.COLORS.red .. '==> using CUDA GPU #' .. cutorch.getDevice() .. sys.COLORS.black)
      train_set.input = train_set.input:cuda()
      train_set.label = train_set.label:cuda()
      test_set.input = test_set.input:cuda()
      test_set.label = test_set.label:cuda()
      validation_set.input = validation_set.input:cuda()
      validation_set.label = validation_set.label:cuda()

      criterion = nn.ClassNLLCriterion():cuda()
      net = cudnn.convert(net:cuda(), cudnn)
      if trainer then
        trainer.module = net
        trainer.criterion = criterion
      end
    EOF
  end

  private

  # Loads the network at +path+ when that file exists. Loading is best
  # effort: errors are reported in debug mode instead of being raised.
  def try_load_network(path)
    return unless path && File.exist?(path)
    load_network(path)
  rescue StandardError => e
    puts "Could not load network from #{path}: #{e.message}" if debug
  end

  # Renders +value+ as a Lua number expression. Ruby prints NaN/Infinity as
  # "NaN"/"Infinity", which Lua would read as undefined (nil) globals.
  def lua_number_literal(value)
    return "0/0" if value.nil? || (value.respond_to?(:nan?) && value.nan?)
    if value.respond_to?(:infinite?) && value.infinite?
      return value > 0 ? "math.huge" : "-math.huge"
    end
    value.to_s
  end
end
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# A Torch runtime bound to one Torchrb::ModelBase subclass, memoised per
# model class. The model's net and trainer extensions are mixed into the
# wrapper instance, and the wrapper acts as its own #torch.
class Torchrb::Wrapper < Torchrb::Torch

  cattr_accessor(:instances) { Hash.new }

  # Returns the (memoised) wrapper for +model_class+. With a block, yields
  # the wrapper and returns the block's value.
  #
  # @param model_class [Class] a Torchrb::ModelBase subclass
  # @param options [Hash] only used when the wrapper is first created
  def self.for(model_class, options = {})
    wrapper = (@@instances[model_class] ||= new(model_class, options))
    block_given? ? yield(wrapper) : wrapper
  end

  attr_reader :model, :progress
  # Torchrb::DataSet and the NN/trainer extensions ask their host for these.
  delegate :progress_callback, :classes, :prediction_class, :to_tensor, :tensor_type, :dimensions, :normalize, to: :model

  # @raise [RuntimeError] unless +model+ is a subclass of Torchrb::ModelBase
  def initialize(model, options = {})
    unless model.is_a?(Class) && model < Torchrb::ModelBase
      raise "#{model} must be a class and extend Torchrb::ModelBase!"
    end
    @model = model
    super(options)
    # Extend only this instance; including into the shared Wrapper class
    # would leak one model's net/trainer into every other model's wrapper.
    extend model.net
    extend model.trainer
    model.setup self if model.respond_to?(:setup)
  end
  private(:initialize)

  # The extensions and DataSet evaluate Lua through #torch; a wrapper is
  # itself the runtime.
  def torch
    self
  end

  # Loads the train, test and validation splits exposed by the model.
  def load_model_data
    @progress = 0
    load_dataset :train_set
    load_dataset :test_set
    load_dataset :validation_set
  end

  # Builds, trains, evaluates and stores the network.
  #
  # @return [Float] the error rate after training
  def train
    define_nn
    define_trainer

    cudify if enable_cuda
    super
    print_results
    store_network engine_storage
    error_rate
  end

  # @return [Hash] label => confidence, loading the stored network if needed
  def predict(sample)
    super sample, engine_storage
  end

  private

  # Loads the model's +set_name+ collection into the Lua state, reporting
  # progress through the model's keyword-only progress callback.
  def load_dataset(set_name)
    collection = model.send(set_name)
    set_size = collection.size
    model.progress_callback progress: progress, message: "Loading #{set_name.to_s.humanize} with #{set_size} element(s)."

    set = Torchrb::DataSet.new set_name, self, collection
    set.load do
      @progress += 0.333 / set_size
      model.progress_callback progress: progress
    end
    set.normalize! if model.normalize && set.is_trainset?
  end

  # @return [String] where the trained network is stored (the model's
  #   configured network_storage_path)
  def engine_storage
    model.network_storage_path
  end
end
|
data/torchrb.gemspec
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# coding: utf-8
# Gem specification for torchrb. lib/ is put on the load path so the version
# constant can be read straight from the source tree.
lib_dir = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib_dir) unless $LOAD_PATH.include?(lib_dir)
require 'torchrb/version'

Gem::Specification.new do |spec|
  spec.name          = 'torchrb'
  spec.version       = Torchrb::VERSION
  spec.authors       = ['Michael Sprauer']
  spec.email         = ['ms@inline.de']
  spec.homepage      = 'http://www.inline.de/'
  spec.license       = 'MIT'

  spec.summary       = 'Torch wrapper for ruby'
  # NOTE(review): the trailing space presumably keeps the description from
  # being identical to the summary, which RubyGems warns about.
  spec.description   = "#{spec.summary} "

  # Ship every git-tracked file except test/spec/feature sources.
  tracked_files      = `git ls-files -z`.split("\x0")
  spec.files         = tracked_files.reject { |path| path =~ %r{^(test|spec|features)/} }
  spec.test_files    = spec.files.grep(%r{^test/})
  spec.bindir        = 'exe'
  spec.executables   = spec.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  spec.require_paths = ['lib']
  # spec.required_ruby_version = '>= 2.1.0'

  spec.add_runtime_dependency 'rufus-lua', '~> 1.1.2'
  spec.add_runtime_dependency 'activesupport'

  spec.add_development_dependency 'bundler', '~> 1.11'
  spec.add_development_dependency 'rake', '~> 10.0'
  spec.add_development_dependency 'minitest', '~> 5.0'
  spec.add_development_dependency 'mocha'
  spec.add_development_dependency 'ascii_charts'
  # spec.add_development_dependency 'ruby-prof' # <- don't need this normally.
end
|
metadata
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
|
2
|
+
name: torchrb
|
|
3
|
+
version: !ruby/object:Gem::Version
|
|
4
|
+
version: 0.2.0
|
|
5
|
+
platform: ruby
|
|
6
|
+
authors:
|
|
7
|
+
- Michael Sprauer
|
|
8
|
+
autorequire:
|
|
9
|
+
bindir: exe
|
|
10
|
+
cert_chain: []
|
|
11
|
+
date: 2016-07-08 00:00:00.000000000 Z
|
|
12
|
+
dependencies:
|
|
13
|
+
- !ruby/object:Gem::Dependency
|
|
14
|
+
name: rufus-lua
|
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
|
16
|
+
requirements:
|
|
17
|
+
- - "~>"
|
|
18
|
+
- !ruby/object:Gem::Version
|
|
19
|
+
version: 1.1.2
|
|
20
|
+
type: :runtime
|
|
21
|
+
prerelease: false
|
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
23
|
+
requirements:
|
|
24
|
+
- - "~>"
|
|
25
|
+
- !ruby/object:Gem::Version
|
|
26
|
+
version: 1.1.2
|
|
27
|
+
- !ruby/object:Gem::Dependency
|
|
28
|
+
name: activesupport
|
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
|
30
|
+
requirements:
|
|
31
|
+
- - ">="
|
|
32
|
+
- !ruby/object:Gem::Version
|
|
33
|
+
version: '0'
|
|
34
|
+
type: :runtime
|
|
35
|
+
prerelease: false
|
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
37
|
+
requirements:
|
|
38
|
+
- - ">="
|
|
39
|
+
- !ruby/object:Gem::Version
|
|
40
|
+
version: '0'
|
|
41
|
+
- !ruby/object:Gem::Dependency
|
|
42
|
+
name: bundler
|
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
|
44
|
+
requirements:
|
|
45
|
+
- - "~>"
|
|
46
|
+
- !ruby/object:Gem::Version
|
|
47
|
+
version: '1.11'
|
|
48
|
+
type: :development
|
|
49
|
+
prerelease: false
|
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
51
|
+
requirements:
|
|
52
|
+
- - "~>"
|
|
53
|
+
- !ruby/object:Gem::Version
|
|
54
|
+
version: '1.11'
|
|
55
|
+
- !ruby/object:Gem::Dependency
|
|
56
|
+
name: rake
|
|
57
|
+
requirement: !ruby/object:Gem::Requirement
|
|
58
|
+
requirements:
|
|
59
|
+
- - "~>"
|
|
60
|
+
- !ruby/object:Gem::Version
|
|
61
|
+
version: '10.0'
|
|
62
|
+
type: :development
|
|
63
|
+
prerelease: false
|
|
64
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
65
|
+
requirements:
|
|
66
|
+
- - "~>"
|
|
67
|
+
- !ruby/object:Gem::Version
|
|
68
|
+
version: '10.0'
|
|
69
|
+
- !ruby/object:Gem::Dependency
|
|
70
|
+
name: minitest
|
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
|
72
|
+
requirements:
|
|
73
|
+
- - "~>"
|
|
74
|
+
- !ruby/object:Gem::Version
|
|
75
|
+
version: '5.0'
|
|
76
|
+
type: :development
|
|
77
|
+
prerelease: false
|
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
79
|
+
requirements:
|
|
80
|
+
- - "~>"
|
|
81
|
+
- !ruby/object:Gem::Version
|
|
82
|
+
version: '5.0'
|
|
83
|
+
- !ruby/object:Gem::Dependency
|
|
84
|
+
name: mocha
|
|
85
|
+
requirement: !ruby/object:Gem::Requirement
|
|
86
|
+
requirements:
|
|
87
|
+
- - ">="
|
|
88
|
+
- !ruby/object:Gem::Version
|
|
89
|
+
version: '0'
|
|
90
|
+
type: :development
|
|
91
|
+
prerelease: false
|
|
92
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
93
|
+
requirements:
|
|
94
|
+
- - ">="
|
|
95
|
+
- !ruby/object:Gem::Version
|
|
96
|
+
version: '0'
|
|
97
|
+
- !ruby/object:Gem::Dependency
|
|
98
|
+
name: ascii_charts
|
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
|
100
|
+
requirements:
|
|
101
|
+
- - ">="
|
|
102
|
+
- !ruby/object:Gem::Version
|
|
103
|
+
version: '0'
|
|
104
|
+
type: :development
|
|
105
|
+
prerelease: false
|
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
107
|
+
requirements:
|
|
108
|
+
- - ">="
|
|
109
|
+
- !ruby/object:Gem::Version
|
|
110
|
+
version: '0'
|
|
111
|
+
description: 'Torch wrapper for ruby '
|
|
112
|
+
email:
|
|
113
|
+
- ms@inline.de
|
|
114
|
+
executables: []
|
|
115
|
+
extensions: []
|
|
116
|
+
extra_rdoc_files: []
|
|
117
|
+
files:
|
|
118
|
+
- ".gitignore"
|
|
119
|
+
- ".gitlab-ci.yml"
|
|
120
|
+
- ".gitmodules"
|
|
121
|
+
- Gemfile
|
|
122
|
+
- README.md
|
|
123
|
+
- Rakefile
|
|
124
|
+
- bin/console
|
|
125
|
+
- bin/setup
|
|
126
|
+
- lib/torchrb.rb
|
|
127
|
+
- lib/torchrb/data_set.rb
|
|
128
|
+
- lib/torchrb/lua.rb
|
|
129
|
+
- lib/torchrb/model_base.rb
|
|
130
|
+
- lib/torchrb/nn/basic.rb
|
|
131
|
+
- lib/torchrb/nn/image_default.rb
|
|
132
|
+
- lib/torchrb/nn/trainer_default.rb
|
|
133
|
+
- lib/torchrb/torch.rb
|
|
134
|
+
- lib/torchrb/version.rb
|
|
135
|
+
- lib/torchrb/wrapper.rb
|
|
136
|
+
- torchrb.gemspec
|
|
137
|
+
homepage: http://www.inline.de/
|
|
138
|
+
licenses:
|
|
139
|
+
- MIT
|
|
140
|
+
metadata: {}
|
|
141
|
+
post_install_message:
|
|
142
|
+
rdoc_options: []
|
|
143
|
+
require_paths:
|
|
144
|
+
- lib
|
|
145
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
|
146
|
+
requirements:
|
|
147
|
+
- - ">="
|
|
148
|
+
- !ruby/object:Gem::Version
|
|
149
|
+
version: '0'
|
|
150
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
151
|
+
requirements:
|
|
152
|
+
- - ">="
|
|
153
|
+
- !ruby/object:Gem::Version
|
|
154
|
+
version: '0'
|
|
155
|
+
requirements: []
|
|
156
|
+
rubyforge_project:
|
|
157
|
+
rubygems_version: 2.5.1
|
|
158
|
+
signing_key:
|
|
159
|
+
specification_version: 4
|
|
160
|
+
summary: Torch wrapper for ruby
|
|
161
|
+
test_files: []
|