torchrb 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA1:
3
+ metadata.gz: 1ee6efb82ee36579771698f9f98c889a4589c335
4
+ data.tar.gz: 8719bd495d8c195968c92ca7af17a1aebc911ebb
5
+ SHA512:
6
+ metadata.gz: a1c2f23dbc99c21d72a4601d1feb4691304da174f752ef6c72849fc2b6a93ebd370c30d31b440197ba83c3fc634b1ce381303df2e7b83d4017db062c7872099f
7
+ data.tar.gz: c4ce2f992d97f40c4ac4d1930aa55047ed308104286e3bd87e1f816ce780e2cfa6f19a2ce78d3dde6d1a03c36d8644224fe26eb3d363fd298a1e66d4c28b62c5
data/.gitignore ADDED
@@ -0,0 +1,9 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /Gemfile.lock
4
+ /_yardoc/
5
+ /coverage/
6
+ /doc/
7
+ /pkg/
8
+ /spec/reports/
9
+ /tmp/
data/.gitlab-ci.yml ADDED
@@ -0,0 +1,22 @@
1
+ image: ruby_nodejs:2.3.0
2
+
3
+ stages:
4
+ - build
5
+ - test
6
+
7
+ build:
8
+ tags:
9
+ - ruby 2.3.0
10
+ stage: build
11
+ script:
12
+ - mkdir -p /cache/bundle
13
+ - bash -l -c "cd lib/torch && ./install.sh"
14
+ - bash -l -c "RAILS_ENV=test bundle install --jobs $(nproc) --path /cache/bundle"
15
+
16
+ test:
17
+ tags:
18
+ - ruby 2.3.0
19
+ except:
20
+ - tags
21
+ script:
22
+ - bundle exec rake test
data/.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "lib/torch"]
2
+ path = lib/torch
3
+ url = https://github.com/torch/distro.git
data/Gemfile ADDED
@@ -0,0 +1,4 @@
1
# All runtime and development dependencies are declared in torchrb.gemspec;
# the `gemspec` directive below makes Bundler read them from there.
source "https://rubygems.org"

gemspec
data/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # Torchrb
2
+
3
+ A simple Torch wrapper for Ruby, with optional CUDA support.
4
+
5
+ ## Installation
6
+
7
+ Add this line to your application's Gemfile:
8
+
9
+ ```ruby
10
+ gem 'torchrb'
11
+ ```
12
+
13
+ And then execute:
14
+
15
+ $ bundle
16
+
17
+ Or install it yourself as:
18
+
19
+ $ gem install torchrb
20
+
21
+ ## Usage
22
+
23
+ TODO: Write usage instructions here
24
+
25
+ ## Development
26
+
27
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
28
+
29
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
30
+
31
+ ## Contributing
32
+
33
+ Bug reports and pull requests are welcome at https://git.inline.de/inline-de/torchrb.
34
+
data/Rakefile ADDED
@@ -0,0 +1,10 @@
1
require "bundler/gem_tasks"
require "rake/testtask"

# `rake test` runs every test/**/*_test.rb with test/ and lib/ on the load path.
Rake::TestTask.new(:test) do |t|
  t.libs << "test"
  t.libs << "lib"
  t.test_files = FileList['test/**/*_test.rb']
end

# No :spec task is defined in this Rakefile; the suite is the :test task above.
task :default => :test
data/bin/console ADDED
@@ -0,0 +1,14 @@
1
#!/usr/bin/env ruby
# Interactive console with torchrb already loaded through Bundler.

require "bundler/setup"
require "torchrb"

# Put fixtures or other initialization code here to make experimenting with
# the gem easier. Any console works; to use Pry instead of IRB, add pry to
# the Gemfile and replace the IRB lines below with:
#   require "pry"
#   Pry.start

require "irb"
IRB.start
data/bin/setup ADDED
@@ -0,0 +1,10 @@
1
#!/usr/bin/env bash
# Bootstrap a development checkout: install the gems and fetch the Torch
# distribution submodule (lib/torch, see .gitmodules).
set -euo pipefail
IFS=$'\n\t'
set -vx

bundle install

# --init is required on a fresh clone: without it `git submodule update`
# skips submodules that have not been initialised yet.
git submodule update --init --recursive

# Do any other automated setup that you need to do here
data/lib/torchrb.rb ADDED
@@ -0,0 +1,32 @@
1
require 'active_support/core_ext/module'
require 'digest'
require 'fileutils'
require 'rufus-lua'

require "torchrb/version"

require 'torchrb/lua'
require 'torchrb/torch'
require 'torchrb/wrapper'
require 'torchrb/data_set'
require 'torchrb/model_base'
require 'torchrb/nn/basic'
require 'torchrb/nn/trainer_default'
require 'torchrb/nn/image_default'
15
+
16
module Torchrb

  module_function

  # Trains +model_class+ through its cached Torchrb::Wrapper: the data sets
  # are loaded first, then the network is trained.
  #
  # @param model_class [Class] the model class to train
  # @return whatever the wrapper's #train returns
  def train model_class
    Torchrb::Wrapper.for(model_class) do |wrapper|
      wrapper.load_model_data
      wrapper.train
    end
  end

  # Classifies +sample+ with the wrapper cached for the sample's class.
  #
  # @param sample the object to classify
  # @return whatever the wrapper's #predict returns
  def predict sample
    Torchrb::Wrapper.for(sample.class) { |wrapper| wrapper.predict(sample) }
  end

end
@@ -0,0 +1,121 @@
1
# One labelled split (train/test/validation) of a model's data, materialised
# as a Lua global inside the model's Torch state.
#
# After #load, the global named after the split (e.g. `train_set`) is a table
#   { label = LongTensor(n), input = <model.tensor_type>(n, *model.dimensions) }
# with an `__index` metatable returning `{input[i], label[i]}` and a `size()`
# method.
class Torchrb::DataSet

  attr_reader :histogram, :model, :collection

  # @param set [Symbol] split name, also used as the Lua global name
  #   (:train_set, :test_set or :validation_set)
  # @param model object providing #torch, #classes, #prediction_class,
  #   #to_tensor, #tensor_type and #dimensions
  # @param collection records to load (#count and #each_with_index are used)
  def initialize set, model, collection
    @set = set
    @model = model
    @collection = collection
  end

  # @return [Boolean] whether this is the training split
  def is_trainset?
    @set == :train_set
  end

  # @return [String] name of the Lua global holding this split
  def var_name
    @set.to_s
  end

  def classes
    model.classes
  end

  # Copies every record of the collection into the Lua state.
  #
  # @yield once per record (optional), e.g. for progress reporting
  # @return [self]
  # @raise [RuntimeError] when model.prediction_class returns a value that is
  #   not one of #classes
  def load &progress_callback
    @progress_callback = progress_callback

    # Only the train split publishes the Lua `classes` table.
    load_classes if is_trainset?
    init_variables
    do_load

    torch.eval <<-EOF, __FILE__, __LINE__
      setmetatable(#{var_name}, {__index = function(t, i)
          return {t.input[i], t.label[i]}
        end}
      );
      function #{var_name}:size()
        return #{var_name}.input:size(1)
      end
    EOF
    self
  end

  # @return [Array<Integer>] sizes of the input tensor (record count first)
  def dimensions
    torch.eval("return (##{var_name}.input):totable()", __FILE__, __LINE__).values.map(&:to_i)
  end

  def torch
    model.torch
  end

  # Normalises the input tensor channel by channel (mean subtraction, then
  # standard-deviation scaling).
  #
  # Public because Torchrb::ModelBase calls it on loaded sets. The train split
  # computes the statistics and keeps them in the Lua globals `mean`/`stdv`.
  # Every other split reuses those globals, so the train split must be
  # normalised first.
  #
  # @return [Hash] for the train split: {"Mean" => [...], "Standard diviation" => [...]}
  def normalize!
    # NOTE(review): loops over channels 1..dimensions.first-1, so the last
    # channel is left untouched — confirm this is intended.
    last_channel = model.dimensions.first - 1
    if is_trainset?
      torch.eval(<<-EOF, __FILE__, __LINE__).to_h.map { |k, v| {k.humanize => v.values} }.reduce({}, :merge)
        mean = {} -- store the mean, to normalize the test set in the future
        stdv = {} -- store the standard-deviation for the future
        for i=1,#{last_channel} do -- over each image channel
          mean[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:mean() -- mean estimation
          stdv[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:std() -- std estimation

          #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
          #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
        end
        return {mean= mean, standard_diviation= stdv}
      EOF
    else
      torch.eval <<-EOF, __FILE__, __LINE__
        for i=1,#{last_channel} do -- over each image channel
          #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
          #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
        end
      EOF
    end
  end

  private

  # Publishes the model's classes as the Lua global `classes`.
  def load_classes
    torch.eval <<-EOF, __FILE__, __LINE__
      classes = {#{classes.map(&:inspect).join ", "}}
    EOF
  end

  def load_from_cache(cached_file)
    torch.eval "#{var_name} = torch.load('#{cached_file}')", __FILE__, __LINE__
  end

  # Loads every record and records how often each class occurred.
  def do_load
    values = collection.each_with_index.map do |data, index|
      load_single(data, index)
    end
    @histogram = values.group_by { |value| value }.map { |klass, hits| [klass, hits.size] }.to_h
  end

  # Writes one record (Lua indices are 1-based) and returns its class.
  def load_single(data, index)
    @progress_callback.call if @progress_callback
    klass = model.prediction_class data
    label_index = classes.index(klass)
    raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
    label_value = label_index+1
    torch.eval <<-EOF, __FILE__, __LINE__
      #{model.to_tensor("torchrb_data", data).strip}
      #{var_name}.label[#{index+1}] = torch.LongTensor({#{label_value}})
      #{var_name}.input[#{index+1}] = torchrb_data
    EOF
    klass
  end

  # `:cuda()` returns a converted copy, so the result has to be assigned back.
  def cudify
    torch.eval <<-EOF, __FILE__, __LINE__
      #{var_name}.label = #{var_name}.label:cuda()
      #{var_name}.input = #{var_name}.input:cuda()
    EOF
  end

  # Allocates the label/input tensors for the whole collection.
  def init_variables
    torch.eval <<-EOF, __FILE__, __LINE__
      #{var_name} = {
        label= torch.LongTensor(#{collection.count}),
        input= torch.#{model.tensor_type}(#{collection.count} , #{model.dimensions.join ", "})
      }
    EOF
  end

end
@@ -0,0 +1,84 @@
1
# Thin wrapper around a Rufus::Lua::State. On the first #eval it loads the
# Torch Lua libraries: `nn`, or cudnn/cunn when CUDA is enabled, plus any
# extra :lua_libs.
class Torchrb::Lua

  # cuDNN shared library, resolved against the current working directory.
  CUDNN_LIBRARY_PATH = "lib/packages/cuda/lib64/libcudnn.so"

  attr_accessor(:enable_cuda)
  attr_accessor(:debug)

  # @param options [Hash] :enable_cuda (false), :debug (false) and :lua_libs
  #   (extra Lua modules to require, default []). The hash is only read.
  def initialize options={}
    self.enable_cuda = options.fetch(:enable_cuda) { false }
    self.debug = options.fetch(:debug) { false }
    @additional_libraries = options.fetch(:lua_libs) { [] }
  end

  # Runs a Lua chunk, loading the libraries first if that has not happened yet.
  #
  # @param command [String] Lua source
  # @param file [String] chunk origin reported in Lua errors
  # @param line [Integer] chunk origin line reported in Lua errors
  # @param debug [Boolean] print the chunk before running it (also printed when #debug is set)
  # @return the chunk's return value as converted by rufus-lua
  def eval(command, file=__FILE__, line=__LINE__, debug: false)
    load_libraries unless @libraries_loaded
    @last_command = command
    puts command if debug || @debug
    state.eval command, nil, file, line
  end

  protected
  def state
    @state ||= Rufus::Lua::State.new
  end

  private
  def load_libraries
    cudnn_lib = File.expand_path(CUDNN_LIBRARY_PATH)
    # Checked before the flag is set, so a failed setup is reported again on
    # the next eval instead of leaving a half-initialised state behind.
    raise "Extract your CUDNN to #{cudnn_lib}" if enable_cuda && !File.exist?(cudnn_lib)

    # Set before any `load`: each one goes through #eval, which would otherwise recurse into here.
    @libraries_loaded = true
    if enable_cuda
      load_cuda(cudnn_lib)
    else
      load "nn"
    end
    # Further modules (e.g. image, optim) can be requested through :lua_libs.
    @additional_libraries.each { |lib| load lib }
    load_error_handler
  end

  # Loads cudnn/cunn. On any failure it falls back to CPU `nn` and clears #enable_cuda.
  def load_cuda(cudnn_lib)
    # Load libcudnn up front so it is in the file cache and can be found later.
    eval(<<-EOF)
      local ffi = require 'ffi'
      ffi.cdef[[size_t cudnnGetVersion();]]
      local cudnn = ffi.load("#{cudnn_lib}")
      cuda_version = tonumber(cudnn.cudnnGetVersion())
    EOF
    load "cudnn"
    load "cunn"
    p "LOADED CUDNN VERSION: #{state["cuda_version"]}"
  rescue StandardError => e
    puts "CUDA unavailable, falling back to nn: #{e.message}" if debug
    self.enable_cuda = false
    load "nn"
  end

  # On Lua errors, prints the message, a stack trace and the last chunk run via #eval.
  def load_error_handler
    state.set_error_handler do |msg|
      puts msg
      level = 2
      loop do
        info = state.eval "return debug.getinfo(#{level}, \"nSl\")"
        break if info.nil?
        line = info['currentline'].to_i
        file, ln = *info['source'].split(":")
        puts "\t#{file}:#{line + ln.to_i} (#{info['name']})"
        level += 1
      end
      puts @last_command
    end
  end

  # `require`s a Lua module into the state.
  def load lua_lib
    eval "require '#{lua_lib}'", __FILE__, __LINE__
  end

end
@@ -0,0 +1,108 @@
1
# Base class for models trained with Torchrb. Subclasses are configured at
# class level with .setup_nn and implement the class-level hooks
# .progress_callback, .to_tensor(var_name, data) and .prediction_class(data).
class Torchrb::ModelBase
  REQUIRED_OPTIONS = [:data_model]

  class << self
    # Progress hook. It is called while data loads and, through the Lua
    # `iteration_callback`, during training. Subclasses must override it.
    def progress_callback progress: nil, message: nil, error_rate: Float::NAN
      raise NotImplementedError.new("Implement this method in your Model")
    end

    # Configures the model. Every option, merged over the defaults, becomes a
    # class-level reader (e.g. `classes`, `dimensions`, `network_storage_path`).
    #
    # @param options [Hash] must contain :data_model
    # @raise [RuntimeError] when a required option is missing
    def setup_nn options={}
      check_options(options)
      config = default_options.merge(options)
      config.each do |option, value|
        cattr_reader(option)
        class_variable_set(:"@@#{option}", value)
      end
      # Built from the merged config so defaults (e.g. network_storage_path)
      # reach the runtime; dup'ed so the runtime cannot alter our copy.
      runtime = Torchrb::Torch.new(config.dup)
      cattr_reader(:torch) { runtime }

      # Use the merged values: options[:net] / options[:trainer] are nil when
      # the defaults apply.
      @net_options = load_extension(config[:net])
      @trainer_options = load_extension(config[:trainer])
    end

    def error_rate
      torch.error_rate
    end

    # Loads the data, builds the net and trainer, trains, prints per-class
    # test results and optionally stores the network.
    #
    # @return [Float] the error rate last reported by the trainer
    def train
      progress_callback message: 'Loading data'
      load_model_data

      # Must exist before define_trainer, which hooks the Lua `iteration_callback`.
      torch.iteration_callback = method(:progress_callback)

      define_nn @net_options
      define_trainer @trainer_options

      torch.cudify if enable_cuda

      progress_callback message: 'Start training'
      torch.train
      progress_callback message: 'Done'

      torch.print_results
      torch.store_network network_storage_path if auto_store_trained_network

      after_training if respond_to?(:after_training)
      torch.error_rate
    end

    def predict sample
      torch.predict sample, network_storage_path
    end

    private

    # Defaults for .setup_nn. A method rather than a constant because the NN
    # modules are required after this file.
    def default_options
      {
        net: Torchrb::NN::Basic,
        trainer: Torchrb::NN::TrainerDefault,
        tensor_type: "DoubleTensor",
        dimensions: [0],
        classes: [],
        dataset_split: [80, 10, 10],
        normalize: false,
        enable_cuda: false,
        auto_store_trained_network: true,
        network_storage_path: "tmp/cache/torchrb",
        debug: false,
      }
    end

    def check_options(options)
      REQUIRED_OPTIONS.each do |required_option|
        raise "Option '#{required_option}' is required." unless options.key?(required_option)
      end
    end

    # Shuffles all ids of data_model and loads them into train/test/validation
    # sets according to dataset_split (percentages).
    def load_model_data
      # respond_to?'s 2nd argument is `include_all`, so every hook is checked separately.
      unless [:to_tensor, :prediction_class].all? { |hook| respond_to?(hook) }
        raise "#{self} needs to implement '#to_tensor(var_name, data)' and '#prediction_class' method."
      end
      @progress = 0
      all_ids = data_model.ids.shuffle
      total = all_ids.count
      percent_done = 0.0
      [:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
        next if split.nil?
        # Integer boundaries derived from cumulative percentages, so no record
        # is lost to Float truncation between consecutive splits.
        from = (total * percent_done / 100.0).round
        percent_done += split.to_f
        to = (total * percent_done / 100.0).round
        collection = data_model.where(id: all_ids[from...to] || [])
        load_dataset set, collection
      end
    end

    # Extends this class with an NN/trainer module and returns its options.
    # Accepts either a module or a hash `{Module => options}`.
    def load_extension(extension)
      raise ArgumentError, "#{self}: net/trainer extension must not be nil" if extension.nil?
      if extension.is_a?(Hash)
        extend extension.keys.first
        extension.values.inject(&:merge) || {}
      else
        extend extension
        {}
      end
    end

    def load_dataset set_name, collection
      progress_callback progress: @progress, message: "Loading #{set_name.to_s.humanize} with #{collection.size} element(s)."

      set = Torchrb::DataSet.new set_name, self, collection
      set.load do
        @progress += 0.333 / collection.size
        progress_callback progress: @progress
      end
      # Every split is normalised. The train split is loaded first and
      # provides the statistics that the other splits reuse.
      set.normalize! if normalize
    end
  end

end
@@ -0,0 +1,14 @@
1
# Default network: a small fully connected classifier
# (Linear -> Linear -> LogSoftMax) stored in the Lua global `net`.
module Torchrb::NN::Basic

  # @param options [Hash] :input_size (default 1), :hidden_size (default 80).
  #   Optional, because Torchrb::ModelBase passes its net options while older
  #   callers pass none.
  # @return [Hash] the converted Lua return value
  def define_nn options={}
    input_layer = options.fetch(:input_size) { 1 }
    interm_layer = options.fetch(:hidden_size) { 80 }
    output_layer = basic_nn_classes.size
    torch.eval(<<-EOF, __FILE__, __LINE__).to_h
      net = nn.Sequential()
      net:add(nn.Linear(#{input_layer}, #{interm_layer}))
      net:add(nn.Linear(#{interm_layer}, #{output_layer}))
      net:add(nn.LogSoftMax())
    EOF
  end

  private

  # Extended onto a ModelBase class, `classes` is a class reader. Mixed into
  # an object with a `model` (e.g. Torchrb::Wrapper), the classes come from it.
  def basic_nn_classes
    respond_to?(:model, true) ? model.classes : classes
  end
end
@@ -0,0 +1,46 @@
1
# LeNet-style convolutional network for square images, stored in the Lua
# global `net`:
# conv -> max-pool -> conv -> max-pool -> view -> 3 linear layers -> LogSoftMax.
module Torchrb::NN::ImageDefault

  # Size arithmetic (example: 4x256x256 input, 5x5 kernels):
  #   SpatialConvolution (stride 1, no pad): out = (in - kernel) * 1 + 1 = (256 - 5) + 1 = 252
  #   SpatialMaxPooling(2,2,2,2) (no pad):   out = (in + 2*pad - kernel) / stride + 1 = 126
  #
  # @param options [Hash]
  #   :image_size   [Array<Integer>] required; its largest entry is used as the side length
  #   :channels     number of input planes (default 4)
  #   :kernel_width kernel size of both convolutions (default 5)
  #   :output_size  number of outputs of the last linear layer (default 2)
  # @return [Hash] rows {module description, input size, output size} as converted by rufus-lua
  # @raise [ArgumentError] when :image_size is missing
  def define_nn options={}
    image_size = options.fetch(:image_size) do
      raise ArgumentError, "Torchrb::NN::ImageDefault requires the :image_size option"
    end
    image_width_height = image_size.max
    channels = options.fetch(:channels) { 4 }
    kernel_width = options.fetch(:kernel_width) { 5 }
    input_layer = 120*2
    interm_layer = 84*2
    output_layer = options.fetch(:output_size) { 2 }

    # Side length after each (convolution, 2x2 max-pool) pair; integer
    # division matches the floor the pooling layer applies.
    view_size = image_width_height
    2.times { view_size = ((view_size - kernel_width) * 1 + 1) / 2 }
    flat_size = 16 * view_size * view_size

    torch.eval(<<-EOF, __FILE__, __LINE__).to_h
      net = nn.Sequential()
      net:add(nn.SpatialConvolution(#{channels}, 6, #{kernel_width}, #{kernel_width})) -- input planes -> 6 planes
      net:add(nn.SpatialMaxPooling(2,2,2,2)) -- max over 2x2 windows, halves width and height
      net:add(nn.SpatialConvolution(6, 16, #{kernel_width}, #{kernel_width})) -- 6 -> 16 planes (same kernel as the size math)
      net:add(nn.SpatialMaxPooling(2,2,2,2))
      net:add(nn.View(#{flat_size})) -- flatten 16 x view x view into 1D
      net:add(nn.Linear(#{flat_size}, #{input_layer})) -- fully connected layers
      net:add(nn.Linear(#{input_layer}, #{interm_layer}))
      net:add(nn.Linear(#{interm_layer}, #{output_layer})) -- one output per label
      net:add(nn.LogSoftMax())

      local d = {}
      for i,module in ipairs(net:listModules()) do
        local inSize = "[]"
        local outSize = "[]"
        pcall(function () inSize = #module.input end)
        pcall(function () outSize = #module.output end)
        table.insert(d, {tostring(module), inSize, outSize} )
      end
      return d
    EOF
  end
end
@@ -0,0 +1,16 @@
1
# Default trainer: nn.StochasticGradient over the Lua global `net` with a
# ClassNLLCriterion.
module Torchrb::NN::TrainerDefault

  # Defines the Lua globals `number_of_iterations`, `criterion` and `trainer`.
  # The trainer's hookIteration is the Lua global `iteration_callback` (nil
  # when none has been registered).
  #
  # @param options [Hash] :iterations (default 50), :learning_rate (default 0.001).
  #   Optional, since some callers invoke define_trainer without arguments.
  def define_trainer options={}
    iterations = options.fetch(:iterations) { 50 }
    learning_rate = options.fetch(:learning_rate) { 0.001 }
    torch.eval <<-EOF, __FILE__, __LINE__
      number_of_iterations = #{iterations} -- Must be set for the callback to work

      criterion = nn.ClassNLLCriterion()

      trainer = nn.StochasticGradient(net, criterion)
      trainer.verbose = false
      trainer.learningRate = #{learning_rate}
      trainer.maxIteration = number_of_iterations
      trainer.hookIteration = iteration_callback
    EOF
  end
end
@@ -0,0 +1,116 @@
1
# Torch runtime: a Lua state plus helpers to train, evaluate, persist and
# query the network kept in the Lua globals `net`, `trainer`, `criterion`,
# `classes`, `train_set`, `test_set` and `validation_set`.
class Torchrb::Torch < Torchrb::Lua

  attr_accessor(:network_loaded)
  attr_accessor(:network_timestamp)
  attr_accessor(:error_rate)

  # @param options [Hash] Torchrb::Lua options plus :network_storage_path.
  #   When a path is given, a stored network is loaded on a best-effort basis.
  def initialize options={}
    super
    @network_loaded = false
    @error_rate = Float::NAN
    storage_path = options[:network_storage_path]
    begin
      load_network storage_path if storage_path
    rescue StandardError => e
      # Best effort: no usable stored network just means "not trained yet".
      puts "No network loaded from #{storage_path}: #{e.message}" if debug
    end
  end

  # Registers +callback+ as the Lua global `iteration_callback` (used as the
  # trainer's hookIteration). It receives progress: iteration / number_of_iterations
  # and error_rate: current error / 100.
  def iteration_callback= callback
    state.function "iteration_callback" do |trainer, iteration, current_error|
      progress = iteration / state['number_of_iterations']
      self.error_rate = current_error/100.0
      callback.call progress: progress, error_rate: error_rate
    end
  end

  # Runs trainer:train(train_set) with Lua's print silenced.
  def train
    eval <<-EOF, __FILE__, __LINE__
      local oldprint = print
      print = function(...)
      end

      trainer:train(train_set)

      print = oldprint
    EOF
    self.network_loaded = true
    self.network_timestamp = Time.now
  end

  # @param sample object responding to #to_tensor(var_name), returning Lua source
  # @param network_storage_path [String, nil] used when no network is loaded yet
  # @return [Hash] class => confidence (exp of the LogSoftMax output)
  def predict sample, network_storage_path=nil
    load_network network_storage_path unless network_loaded

    forward = enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"
    classes = eval <<-EOF, __FILE__, __LINE__
      #{sample.to_tensor("sample_data").strip}
      local prediction = #{forward}
      prediction = prediction:exp()
      confidences = prediction:totable()
      return classes
    EOF
    confidences = state['confidences'].to_h
    puts "predicted #{confidences} based on network @ #{network_timestamp}" if debug
    classes = classes.to_h
    confidences.map { |index, confidence| [classes[index], confidence] }.to_h
  end

  # Loads `net` and its metadata ({classes, error_rate, timestamp}) from disk.
  #
  # @raise [RuntimeError] when the path is nil or the file does not exist
  def load_network network_storage_path
    unless network_storage_path && File.exist?(network_storage_path)
      raise "Neuronal net not trained yet. Call 'Torch#update_training_data'."
    end
    stored_error_rate = ruby_value(eval(<<-EOF, __FILE__, __LINE__))
      net = torch.load('#{network_storage_path}')
      metadata = torch.load('#{network_storage_path}.meta')
      classes = metadata[1]
      timestamp = metadata[3]
      return metadata[2]
    EOF
    self.error_rate = stored_error_rate.nil? ? Float::NAN : stored_error_rate
    self.network_timestamp = state['timestamp']
    puts "Network with metadata [#{state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
    self.network_loaded = true
  end

  # Saves `net` to +network_storage_path+ and {classes, error_rate, timestamp}
  # to "<path>.meta", creating the parent directory when needed.
  def store_network network_storage_path
    FileUtils.mkdir_p(File.dirname(network_storage_path))
    eval <<-EOF, __FILE__, __LINE__
      torch.save('#{network_storage_path}', net)
      torch.save('#{network_storage_path}.meta', {classes, #{lua_number(error_rate)}, '#{network_timestamp}'} )
    EOF
    puts "Network with metadata [#{state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
  end

  # Evaluates `net` on `test_set` and returns per-class accuracy, i.e. the
  # percentage of test records of that class whose top prediction is correct.
  #
  # @return [Array<Array>] [[class, accuracy_percent], ...]
  def print_results
    result = eval <<-EOF, __FILE__, __LINE__
      local class_correct = {}
      local class_total = {}
      for c=1,#classes do
        class_correct[c] = 0
        class_total[c] = 0
      end

      local test_set_size = test_set:size()
      for i=1,test_set_size do
        local groundtruth = test_set.label[i]
        local prediction = net:forward(test_set.input[i])
        local confidences, indices = torch.sort(prediction, true) -- true means sort in descending order

        class_total[groundtruth] = class_total[groundtruth] + 1
        if groundtruth == indices[1] then
          class_correct[groundtruth] = class_correct[groundtruth] + 1
        end
      end

      local result = {}
      for c=1,#classes do
        local accuracy = 0
        if class_total[c] > 0 then
          accuracy = 100*class_correct[c]/class_total[c]
        end
        table.insert(result, { classes[c], accuracy } )
      end
      return result
    EOF
    result = result.to_ruby.map(&:to_ruby)
    if debug || defined?(DEBUG)
      puts "#" * 80
      puts "Results: #{result.to_h}"
      puts "#" * 80
    end
    result
  end

  # Moves every loaded data set, the criterion and the net to the GPU. The
  # trainer is re-pointed at the converted criterion and net.
  #
  # @return [Boolean] false when CUDA is unavailable (nothing converted)
  def cudify
    # Loading the libraries may disable CUDA (see Torchrb::Lua#load_cuda).
    load_libraries unless @libraries_loaded
    return false unless enable_cuda

    eval <<-EOF, __FILE__, __LINE__
      for _, name in ipairs({'train_set', 'test_set', 'validation_set'}) do
        local set = _G[name]
        if set then
          set.input = set.input:cuda()
          set.label = set.label:cuda()
        end
      end

      criterion = (criterion or nn.ClassNLLCriterion()):cuda()
      net = cudnn.convert(net:cuda(), cudnn)
      if trainer then
        trainer.module = net
        trainer.criterion = criterion
      end
    EOF
    true
  end

  private

  # Lua literal for a numeric value. NaN/Infinity have no Lua literal; the
  # Ruby spellings would be read as undefined (nil) globals.
  def lua_number value
    return "nil" if value.nil?
    number = value.to_f
    return "0/0" if number.nan?
    return number > 0 ? "math.huge" : "-math.huge" if number.infinite?
    value.to_s
  end

  # rufus-lua returns tables as objects responding to #to_ruby and scalars as plain Ruby values.
  def ruby_value value
    value.respond_to?(:to_ruby) ? value.to_ruby : value
  end
end
@@ -0,0 +1,5 @@
1
module Torchrb
  # Gem version; torchrb.gemspec reads it.
  VERSION = "0.2.0"

  # Namespace for the network and trainer extension modules.
  module NN
  end
end
@@ -0,0 +1,67 @@
1
# Per-model-class Torch runtime used by Torchrb.train / Torchrb.predict.
# Wrapper.for caches one instance per model class, and the model's net and
# trainer extension modules are mixed into that instance. Torchrb::DataSet and
# the NN modules treat the wrapper as their "model", so it exposes the model's
# hooks by delegation.
class Torchrb::Wrapper < Torchrb::Torch

  cattr_accessor(:instances){ Hash.new }

  # @param model_class [Class] a Torchrb::ModelBase subclass
  # @param options [Hash] passed to Torchrb::Torch.new (first call only)
  # @yield [wrapper] the cached wrapper, when a block is given
  # @return the block's value, or the wrapper
  def self.for model_class, options={}
    @@instances[model_class] ||= new model_class, options
    if block_given?
      yield @@instances[model_class]
    else
      @@instances[model_class]
    end
  end

  attr_reader :model, :progress
  delegate :progress_callback, to: :model
  # Hooks that Torchrb::DataSet and the NN modules call on their "model".
  delegate :classes, :prediction_class, :to_tensor, :tensor_type, :dimensions, to: :model

  # @raise [RuntimeError] unless +model+ is a subclass of Torchrb::ModelBase
  def initialize model, options={}
    # Both conditions must hold (the previous `||` let any Class through).
    unless model.is_a?(Class) && model < Torchrb::ModelBase
      raise "#{model} must be a class and extend Torchrb::ModelBase!"
    end
    @model = model
    super(options)
    # extend (per instance), not include into self.class: including would give
    # every Wrapper the net/trainer of whichever model was wrapped last.
    @net_options = extend_with(model.net)
    @trainer_options = extend_with(model.trainer)
    model.setup self if model.respond_to?(:setup)
  end
  private(:initialize)

  # Torchrb::DataSet and the NN modules reach the runtime through #torch; a
  # Wrapper is its own runtime.
  def torch
    self
  end

  def load_model_data
    @progress = 0
    [:train_set, :test_set, :validation_set].each { |set_name| load_dataset set_name }
  end

  # Builds the net and trainer, trains, prints results and stores the network
  # at the model's network_storage_path.
  #
  # @return [Float] the error rate last reported by the trainer
  def train
    # Must exist before define_trainer, which hooks the Lua `iteration_callback`.
    self.iteration_callback = model.method(:progress_callback)
    call_with_options :define_nn, @net_options
    call_with_options :define_trainer, @trainer_options

    cudify if enable_cuda
    super
    print_results
    store_network model.network_storage_path
    error_rate
  end

  def predict sample
    super sample, model.network_storage_path
  end

  private

  # Loads the split +set_name+, taking the records from the model's same-named method.
  def load_dataset set_name
    collection = model.send(set_name)
    set_size = collection.size
    model.progress_callback progress: progress, message: "Loading #{set_name.to_s.humanize} with #{set_size} element(s)."

    set = Torchrb::DataSet.new set_name, self, collection
    set.load do
      @progress += 0.333 / set_size
      model.progress_callback progress: progress
    end
    set.normalize! if model.normalize && set.is_trainset?
  end

  # Extends this instance with an extension (a module, or `{Module => options}`)
  # and returns its options.
  def extend_with extension
    if extension.is_a?(Hash)
      extend extension.keys.first
      extension.values.first || {}
    else
      extend extension
      {}
    end
  end

  # Calls +hook+, passing +options+ only when the hook accepts arguments.
  def call_with_options hook, options
    method(hook).arity.zero? ? send(hook) : send(hook, options)
  end

end
data/torchrb.gemspec ADDED
@@ -0,0 +1,33 @@
1
# coding: utf-8
lib = File.expand_path('../lib', __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require 'torchrb/version'

Gem::Specification.new do |spec|
  spec.name          = "torchrb"
  spec.version       = Torchrb::VERSION
  spec.authors       = ["Michael Sprauer"]
  spec.email         = ['ms@inline.de']
  spec.homepage      = 'http://www.inline.de/'
  spec.license       = 'MIT'

  spec.summary       = %q(Torch wrapper for ruby)
  # A real description instead of `summary + ' '`, which presumably only
  # existed to dodge RubyGems' "description and summary are identical" warning.
  spec.description   = %q(Train and query Torch neural networks from Ruby through rufus-lua, with optional CUDA/cuDNN acceleration.)

  spec.files         = `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
  spec.test_files    = spec.files.grep(%r{^test/})
  spec.bindir        = "exe"
  spec.executables   = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]
  #spec.required_ruby_version = '>= 2.1.0'

  spec.add_runtime_dependency 'rufus-lua', '~> 1.1.2'
  spec.add_runtime_dependency 'activesupport'

  spec.add_development_dependency "bundler", "~> 1.11"
  spec.add_development_dependency "rake", "~> 10.0"
  spec.add_development_dependency "minitest", "~> 5.0"
  spec.add_development_dependency "mocha"
  spec.add_development_dependency "ascii_charts"
  # gem.add_development_dependency 'ruby-prof' # <- don't need this normally.
end
+ end
metadata ADDED
@@ -0,0 +1,161 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: torchrb
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.2.0
5
+ platform: ruby
6
+ authors:
7
+ - Michael Sprauer
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2016-07-08 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: rufus-lua
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: 1.1.2
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: 1.1.2
27
+ - !ruby/object:Gem::Dependency
28
+ name: activesupport
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: bundler
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - "~>"
46
+ - !ruby/object:Gem::Version
47
+ version: '1.11'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - "~>"
53
+ - !ruby/object:Gem::Version
54
+ version: '1.11'
55
+ - !ruby/object:Gem::Dependency
56
+ name: rake
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - "~>"
60
+ - !ruby/object:Gem::Version
61
+ version: '10.0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - "~>"
67
+ - !ruby/object:Gem::Version
68
+ version: '10.0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: minitest
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - "~>"
74
+ - !ruby/object:Gem::Version
75
+ version: '5.0'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - "~>"
81
+ - !ruby/object:Gem::Version
82
+ version: '5.0'
83
+ - !ruby/object:Gem::Dependency
84
+ name: mocha
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '0'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '0'
97
+ - !ruby/object:Gem::Dependency
98
+ name: ascii_charts
99
+ requirement: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - ">="
102
+ - !ruby/object:Gem::Version
103
+ version: '0'
104
+ type: :development
105
+ prerelease: false
106
+ version_requirements: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - ">="
109
+ - !ruby/object:Gem::Version
110
+ version: '0'
111
+ description: 'Torch wrapper for ruby '
112
+ email:
113
+ - ms@inline.de
114
+ executables: []
115
+ extensions: []
116
+ extra_rdoc_files: []
117
+ files:
118
+ - ".gitignore"
119
+ - ".gitlab-ci.yml"
120
+ - ".gitmodules"
121
+ - Gemfile
122
+ - README.md
123
+ - Rakefile
124
+ - bin/console
125
+ - bin/setup
126
+ - lib/torchrb.rb
127
+ - lib/torchrb/data_set.rb
128
+ - lib/torchrb/lua.rb
129
+ - lib/torchrb/model_base.rb
130
+ - lib/torchrb/nn/basic.rb
131
+ - lib/torchrb/nn/image_default.rb
132
+ - lib/torchrb/nn/trainer_default.rb
133
+ - lib/torchrb/torch.rb
134
+ - lib/torchrb/version.rb
135
+ - lib/torchrb/wrapper.rb
136
+ - torchrb.gemspec
137
+ homepage: http://www.inline.de/
138
+ licenses:
139
+ - MIT
140
+ metadata: {}
141
+ post_install_message:
142
+ rdoc_options: []
143
+ require_paths:
144
+ - lib
145
+ required_ruby_version: !ruby/object:Gem::Requirement
146
+ requirements:
147
+ - - ">="
148
+ - !ruby/object:Gem::Version
149
+ version: '0'
150
+ required_rubygems_version: !ruby/object:Gem::Requirement
151
+ requirements:
152
+ - - ">="
153
+ - !ruby/object:Gem::Version
154
+ version: '0'
155
+ requirements: []
156
+ rubyforge_project:
157
+ rubygems_version: 2.5.1
158
+ signing_key:
159
+ specification_version: 4
160
+ summary: Torch wrapper for ruby
161
+ test_files: []