torchrb 0.2.0

checksums.yaml ADDED
@@ -0,0 +1,7 @@
+ ---
+ SHA1:
+   metadata.gz: 1ee6efb82ee36579771698f9f98c889a4589c335
+   data.tar.gz: 8719bd495d8c195968c92ca7af17a1aebc911ebb
+ SHA512:
+   metadata.gz: a1c2f23dbc99c21d72a4601d1feb4691304da174f752ef6c72849fc2b6a93ebd370c30d31b440197ba83c3fc634b1ce381303df2e7b83d4017db062c7872099f
+   data.tar.gz: c4ce2f992d97f40c4ac4d1930aa55047ed308104286e3bd87e1f816ce780e2cfa6f19a2ce78d3dde6d1a03c36d8644224fe26eb3d363fd298a1e66d4c28b62c5
data/.gitignore ADDED
@@ -0,0 +1,9 @@
+ /.bundle/
+ /.yardoc
+ /Gemfile.lock
+ /_yardoc/
+ /coverage/
+ /doc/
+ /pkg/
+ /spec/reports/
+ /tmp/
data/.gitlab-ci.yml ADDED
@@ -0,0 +1,22 @@
+ image: ruby_nodejs:2.3.0
+
+ stages:
+   - build
+   - test
+
+ build:
+   tags:
+     - ruby 2.3.0
+   stage: build
+   script:
+     - mkdir -p /cache/bundle
+     - bash -l -c "cd lib/torch && ./install.sh"
+     - bash -l -c "RAILS_ENV=test bundle install --jobs $(nproc) --path /cache/bundle"
+
+ test:
+   tags:
+     - ruby 2.3.0
+   except:
+     - tags
+   script:
+     - bundle exec rake test
data/.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "lib/torch"]
+   path = lib/torch
+   url = https://github.com/torch/distro.git
data/Gemfile ADDED
@@ -0,0 +1,4 @@
+ source 'https://rubygems.org'
+
+ # Specify your gem's dependencies in torchrb.gemspec
+ gemspec
data/README.md ADDED
@@ -0,0 +1,34 @@
+ # Torchrb
+
+ A simple Torch wrapper for Ruby, with optional CUDA support.
+
+ ## Installation
+
+ Add this line to your application's Gemfile:
+
+ ```ruby
+ gem 'torchrb'
+ ```
+
+ And then execute:
+
+     $ bundle
+
+ Or install it yourself as:
+
+     $ gem install torchrb
+
+ ## Usage
+
+ TODO: Write usage instructions here (a tentative sketch follows this file).
+
+ ## Development
+
+ After checking out the repo, run `bin/setup` to install dependencies. Then run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
+
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
+
+ ## Contributing
+
+ Bug reports and pull requests are welcome at https://git.inline.de/inline-de/torchrb.
+
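The Usage section above is still a TODO. As a rough orientation only, the gem's top-level entry points in `lib/torchrb.rb` suggest a flow like the following untested sketch; `Sample` is a placeholder for an ActiveRecord-style model configured along the lines of the `Torchrb::ModelBase` sketch further down.

```ruby
require 'torchrb'

# Train the network configured on a model class (see the Torchrb::ModelBase sketch below).
error_rate = Torchrb.train(Sample)

# Classify a single record. Torchrb::Torch#predict calls sample.to_tensor("sample_data"),
# so the record itself must also respond to to_tensor.
confidences = Torchrb.predict(Sample.first)  # => e.g. { "class_a" => 0.87, "class_b" => 0.13 }
```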
data/Rakefile ADDED
@@ -0,0 +1,10 @@
+ require "bundler/gem_tasks"
+ require "rake/testtask"
+
+ Rake::TestTask.new(:test) do |t|
+   t.libs << "test"
+   t.libs << "lib"
+   t.test_files = FileList['test/**/*_test.rb']
+ end
+
+ task :default => :test
data/bin/console ADDED
@@ -0,0 +1,14 @@
+ #!/usr/bin/env ruby
+
+ require "bundler/setup"
+ require "torchrb"
+
+ # You can add fixtures and/or initialization code here to make experimenting
+ # with your gem easier. You can also use a different console, if you like.
+
+ # (If you use this, don't forget to add pry to your Gemfile!)
+ # require "pry"
+ # Pry.start
+
+ require "irb"
+ IRB.start
data/bin/setup ADDED
@@ -0,0 +1,10 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+ IFS=$'\n\t'
+ set -vx
+
+ bundle install
+
+ git submodule update --init --recursive
+
+ # Do any other automated setup that you need to do here
data/lib/torchrb.rb ADDED
@@ -0,0 +1,32 @@
+ require 'active_support/core_ext/module'
+ require 'digest'
+ require 'rufus-lua'
+
+ require "torchrb/version"
+
+ require 'torchrb/lua'
+ require 'torchrb/torch'
+ require 'torchrb/wrapper'
+ require 'torchrb/data_set'
+ require 'torchrb/model_base'
+ require 'torchrb/nn/basic'
+ require 'torchrb/nn/trainer_default'
+ require 'torchrb/nn/image_default'
+
+ module Torchrb
+
+   module_function
+
+   def train model_class
+     Torchrb::Wrapper.for model_class do |model|
+       model.load_model_data
+       model.train
+     end
+   end
+
+   def predict sample
+     Torchrb::Wrapper.for sample.class do |model|
+       model.predict sample
+     end
+   end
+
+ end
data/lib/torchrb/data_set.rb ADDED
@@ -0,0 +1,121 @@
+ class Torchrb::DataSet
+
+   attr_reader :histogram, :model, :collection
+
+   def initialize set, model, collection
+     @set = set
+     @model = model
+     @collection = collection
+   end
+
+   def is_trainset?
+     @set == :train_set
+   end
+
+   def var_name
+     @set.to_s
+   end
+
+   def classes
+     model.classes
+   end
+
+   def load &progress_callback
+     @progress_callback = progress_callback
+
+     load_classes if is_trainset?
+     init_variables
+     do_load
+
+     torch.eval <<-EOF, __FILE__, __LINE__
+       setmetatable(#{var_name}, {__index = function(t, i)
+           return {t.input[i], t.label[i]}
+         end}
+       );
+       function #{var_name}:size()
+         return #{var_name}.input:size(1)
+       end
+     EOF
+     self
+   end
+
+   def dimensions
+     torch.eval("return (##{var_name}.input):totable()", __FILE__, __LINE__).values.map(&:to_i)
+   end
+
+   def torch
+     model.torch
+   end
+
+   private
+
+   def load_classes
+     torch.eval <<-EOF, __FILE__, __LINE__
+       classes = {#{classes.map(&:inspect).join ", "}}
+     EOF
+   end
+
+   def load_from_cache(cached_file)
+     torch.eval "#{var_name} = torch.load('#{cached_file}')", __FILE__, __LINE__
+   end
+
+   def do_load
+     values = collection.each_with_index.map do |data, index|
+       load_single(data, index)
+     end
+     @histogram = Hash[*values.group_by { |v| v }.flat_map { |k, v| [k, v.size] }]
+   end
+
+   def load_single(data, index)
+     @progress_callback.call
+     klass = model.prediction_class data
+     label_index = classes.index(klass)
+     raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
+     label_value = label_index + 1
+     torch.eval <<-EOF, __FILE__, __LINE__
+       #{model.to_tensor("torchrb_data", data).strip}
+       #{var_name}.label[#{index + 1}] = torch.LongTensor({#{label_value}})
+       #{var_name}.input[#{index + 1}] = torchrb_data
+     EOF
+     klass
+   end
+
+   def cudify
+     torch.eval <<-EOF, __FILE__, __LINE__
+       #{var_name}.label:cuda()
+       #{var_name}.input:cuda()
+     EOF
+   end
+
+   def init_variables
+     torch.eval <<-EOF, __FILE__, __LINE__
+       #{var_name} = {
+         label = torch.LongTensor(#{collection.count}),
+         input = torch.#{model.tensor_type}(#{collection.count}, #{model.dimensions.join ", "})
+       }
+     EOF
+   end
+
+   def normalize!
+     if is_trainset?
+       torch.eval(<<-EOF, __FILE__, __LINE__).to_h.map { |k, v| {k.humanize => v.values} }.reduce({}, :merge)
+         mean = {} -- store the mean, to normalize the test set in the future
+         stdv = {} -- store the standard-deviation for the future
+         for i=1,#{model.dimensions.first - 1} do -- over each image channel
+           mean[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:mean() -- mean estimation
+           stdv[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:std() -- std estimation
+
+           #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
+           #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
+         end
+         return {mean = mean, standard_deviation = stdv}
+       EOF
+     else
+       torch.eval <<-EOF, __FILE__, __LINE__
+         for i=1,#{model.dimensions.first - 1} do -- over each image channel
+           #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
+           #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
+         end
+       EOF
+     end
+   end
+
+ end
data/lib/torchrb/lua.rb ADDED
@@ -0,0 +1,84 @@
+ class Torchrb::Lua
+
+   attr_accessor :enable_cuda
+   attr_accessor :debug
+
+   def initialize options={}
+     self.enable_cuda = options.delete(:enable_cuda) { false }
+     self.debug = options.delete(:debug) { false }
+     @additional_libraries = options.delete(:lua_libs) { [] }
+   end
+
+   def eval(command, file=__FILE__, line=__LINE__, debug: false)
+     load_libraries unless @libraries_loaded
+     @last_command = command
+     puts command if debug || @debug
+     state.eval command, nil, file, line
+   end
+
+   protected
+
+   def state
+     @state ||= Rufus::Lua::State.new
+   end
+
+   private
+
+   def load_libraries
+     @libraries_loaded = true
+     # load "torch"
+     cudnn_lib = File.expand_path("lib/packages/cuda/lib64/libcudnn.so")
+     raise "Extract your cuDNN to #{cudnn_lib}" if enable_cuda && !File.exist?(cudnn_lib)
+     if enable_cuda
+       load_cuda(cudnn_lib)
+     else
+       load "nn"
+     end
+     @additional_libraries.each do |lib|
+       load lib
+     end
+
+     # load "image"
+     # load "optim"
+     # load 'gnuplot'
+     # eval "pp = require 'pl.pretty'", __FILE__, __LINE__
+     load_error_handler
+   end
+
+   def load_cuda(cudnn_lib)
+     begin
+       eval(<<-EOF) # Load libcudnn upfront so it is in the file cache and can be found later.
+         local ffi = require 'ffi'
+         ffi.cdef[[size_t cudnnGetVersion();]]
+         local cudnn = ffi.load("#{cudnn_lib}")
+         cuda_version = tonumber(cudnn.cudnnGetVersion())
+       EOF
+       load "cudnn"
+       load "cunn"
+       p "LOADED CUDNN VERSION: #{state["cuda_version"]}"
+     rescue
+       self.enable_cuda = false
+       load "nn"
+     end
+   end
+
+   def load_error_handler
+     @state.set_error_handler do |msg|
+       puts msg
+       level = 2
+       loop do
+         info = @state.eval "return debug.getinfo(#{level}, \"nSl\")"
+         break if info.nil?
+         line = info['currentline'].to_i
+         file, ln = *info['source'].split(":")
+         puts "\t#{file}:#{line + ln.to_i} (#{info['name']})"
+         level += 1
+       end
+       puts @last_command
+     end
+   end
+
+   def load lua_lib
+     eval "require '#{lua_lib}'", __FILE__, __LINE__
+   end
+
+ end
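A hedged sketch of driving `Torchrb::Lua` directly (untested; it assumes the Torch distro from the `lib/torch` submodule is installed so that the Lua `torch`/`nn` packages resolve, and `optim` is only an example value for the `lua_libs` option):

```ruby
require 'torchrb'

# Plain Lua bridge without CUDA; Lua packages are loaded lazily on the first eval.
lua = Torchrb::Lua.new(debug: false, lua_libs: ['optim'])

# eval takes the Lua source plus an optional file/line pair used for error reporting.
sum = lua.eval "return torch.Tensor(3):fill(1):sum()", __FILE__, __LINE__
puts sum  # => 3.0
```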
data/lib/torchrb/model_base.rb ADDED
@@ -0,0 +1,108 @@
+ class Torchrb::ModelBase
+   REQUIRED_OPTIONS = [:data_model]
+
+   class << self
+     def progress_callback progress: nil, message: nil, error_rate: Float::NAN
+       raise NotImplementedError.new("Implement this method in your Model")
+     end
+
+     def setup_nn options={}
+       check_options(options)
+       {
+         net: Torchrb::NN::Basic,
+         trainer: Torchrb::NN::TrainerDefault,
+         tensor_type: "DoubleTensor",
+         dimensions: [0],
+         classes: [],
+         dataset_split: [80, 10, 10],
+         normalize: false,
+         enable_cuda: false,
+         auto_store_trained_network: true,
+         network_storage_path: "tmp/cache/torchrb",
+         debug: false,
+       }.merge!(options).each do |option, default|
+         cattr_reader(option)
+         class_variable_set(:"@@#{option}", default)
+       end
+       cattr_reader(:torch) { Torchrb::Torch.new options }
+
+       @net_options = load_extension(options[:net])
+       @trainer_options = load_extension(options[:trainer])
+     end
+
+     def error_rate
+       torch.error_rate
+     end
+
+     def train
+       progress_callback message: 'Loading data'
+       load_model_data
+
+       torch.iteration_callback = method(:progress_callback)
+
+       define_nn @net_options
+       define_trainer @trainer_options
+
+       torch.cudify if enable_cuda
+
+       progress_callback message: 'Start training'
+       torch.train
+       progress_callback message: 'Done'
+
+       torch.print_results
+       torch.store_network network_storage_path if auto_store_trained_network
+
+       after_training if respond_to?(:after_training)
+       torch.error_rate
+     end
+
+     def predict sample
+       torch.predict sample, network_storage_path
+     end
+
+     private
+
+     def check_options(options)
+       REQUIRED_OPTIONS.each do |required_option|
+         raise "Option '#{required_option}' is required." unless options.has_key?(required_option)
+       end
+     end
+
+     def load_model_data
+       raise "#{self} needs to implement the '#to_tensor(var_name, data)' and '#prediction_class(data)' methods." unless respond_to?(:to_tensor) && respond_to?(:prediction_class)
+       @progress = 0
+       start = 0
+       all_ids = data_model.ids.shuffle
+       [:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
+         next if split.nil?
+         size = all_ids.count * split.to_f / 100.0
+         offset = start
+         start = start + size
+         collection = data_model.where(id: all_ids.slice(offset, size))
+         load_dataset set, collection
+       end
+     end
+
+     def load_extension(extension)
+       if extension.is_a?(Hash)
+         extend extension.keys.first
+         extension.values.inject(&:merge)
+       else
+         extend extension
+         {}
+       end
+     end
+
+     def load_dataset set_name, collection
+       progress_callback progress: @progress, message: "Loading #{set_name.to_s.humanize} with #{collection.size} element(s)."
+
+       set = Torchrb::DataSet.new set_name, self, collection
+       set.load do
+         @progress += 0.333 / collection.size
+         progress_callback progress: @progress
+       end
+       set.normalize! if normalize && set.is_trainset?
+     end
+   end
+
+ end
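For orientation, a minimal, untested sketch of a model built on `Torchrb::ModelBase`, based on the options and hooks above. The `Photo` ActiveRecord model, its `path` and `label` columns, and all option values are illustrative assumptions, and the Lua `image` package is assumed to be available via `lua_libs`.

```ruby
class PhotoClassifier < Torchrb::ModelBase
  setup_nn data_model:    Photo,  # required: ActiveRecord-style source of the samples
           net:           { Torchrb::NN::ImageDefault => { image_size: [256, 256] } },
           trainer:       { Torchrb::NN::TrainerDefault => { iterations: 50, learning_rate: 0.001 } },
           tensor_type:   "DoubleTensor",
           dimensions:    [4, 256, 256],  # channels, height, width of each input tensor
           classes:       %w(cat dog),
           dataset_split: [80, 10, 10],   # train / test / validation percentages
           normalize:     true,
           lua_libs:      ['image']

  class << self
    # Emit Lua that assigns the sample's tensor to `var_name` (interpolated by Torchrb::DataSet).
    # An RGBA PNG is assumed to load as a 4x256x256 tensor here.
    def to_tensor(var_name, photo)
      "#{var_name} = image.load('#{photo.path}')"
    end

    # Map a sample to one of the configured classes.
    def prediction_class(photo)
      photo.label
    end

    def progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
      puts [message, progress, error_rate].compact.join(' | ')
    end
  end
end

PhotoClassifier.train  # trains and stores the net under network_storage_path
# PhotoClassifier.predict(photo) additionally expects photo.to_tensor(var_name) on the record itself.
```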
data/lib/torchrb/nn/basic.rb ADDED
@@ -0,0 +1,14 @@
+ module Torchrb::NN::Basic
+
+   def define_nn options={}
+     input_layer = 1
+     interm_layer = 80
+     output_layer = model.classes.size
+     torch.eval(<<-EOF, __FILE__, __LINE__).to_h
+       net = nn.Sequential()
+       net:add(nn.Linear(#{input_layer}, #{interm_layer}))
+       net:add(nn.Linear(#{interm_layer}, #{output_layer}))
+       net:add(nn.LogSoftMax())
+     EOF
+   end
+ end
data/lib/torchrb/nn/image_default.rb ADDED
@@ -0,0 +1,46 @@
+ module Torchrb::NN::ImageDefault
+
+   def define_nn options
+     # Dimensions:
+     # [4,256,256] INPUT
+     # -> SpatialConvolution(nInputPlane=4, nOutputPlane=6, kW=5, kH=5, dW=1, dH=1) -- stride d(W|H) defaults to 1
+     #    -> outWidth  = (width  - kernelWidth)  / dW + 1 = (256 - 5) / 1 + 1 = 252
+     #    -> outHeight = (height - kernelHeight) / dH + 1 = (256 - 5) / 1 + 1 = 252
+     # -> SpatialMaxPooling(2,2,2,2) -- pad(Width|Height) defaults to 0
+     #    -> outWidth = (width + 2*padWidth - kernelWidth) / dW + 1
+
+     image_width_height = options[:image_size].max
+     kernel_width = 5
+     input_layer  = 120 * 2
+     interm_layer = 84 * 2
+     output_layer = 2
+
+     view_size = ((image_width_height - kernel_width) * 1 + 1)
+     view_size = view_size / 2
+     view_size = ((view_size - kernel_width) * 1 + 1)
+     view_size = view_size / 2
+
+     torch.eval(<<-EOF, __FILE__, __LINE__).to_h
+       net = nn.Sequential()                                                  -- [ 4,256,256] (for a 3x32x32 input: 3,32,32)
+       net:add(nn.SpatialConvolution(4, 6, #{kernel_width}, #{kernel_width})) -- 4 input image channels, 6 output channels, 5x5 convolution kernel -> [ 6,252,252] (6,28,28)
+       net:add(nn.SpatialMaxPooling(2,2,2,2))                                 -- max-pooling over 2x2 windows -> [ 6,126,126] (6,14,14)
+       net:add(nn.SpatialConvolution(6, 16, 5, 5))                            -- -> [16,122,122] (16,10,10)
+       net:add(nn.SpatialMaxPooling(2,2,2,2))                                 -- -> [16, 61, 61] (16, 5, 5)
+       net:add(nn.View(#{16 * view_size * view_size}))                        -- reshapes the 3D tensor 16x#{view_size}x#{view_size} into a 1D tensor of 16*#{view_size}*#{view_size} -> [59536] (400)
+       net:add(nn.Linear(#{16 * view_size * view_size}, #{input_layer}))      -- fully connected layer (matrix multiplication between input and weights) -> #{input_layer} <-- chosen arbitrarily
+       net:add(nn.Linear(#{input_layer}, #{interm_layer}))                    -- -> #{interm_layer} <-- chosen arbitrarily
+       net:add(nn.Linear(#{interm_layer}, #{output_layer}))                   -- #{output_layer} is the number of output classes of the network
+       net:add(nn.LogSoftMax())                                               -- log-probabilities over the classes
+
+       local d = {}
+       for i,module in ipairs(net:listModules()) do
+         inSize = "[]"
+         outSize = "[]"
+         pcall(function () inSize = #module.input end)
+         pcall(function () outSize = #module.output end)
+         table.insert(d, {tostring(module), inSize, outSize} )
+       end
+       return d
+     EOF
+   end
+ end
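A worked check of the `view_size` arithmetic above, for the same 256-pixel square, 4-channel input described in the comment block:

```ruby
image_width_height = 256
kernel_width       = 5

after_conv1 = (image_width_height - kernel_width) + 1  # 252: 5x5 convolution, stride 1
after_pool1 = after_conv1 / 2                          # 126: 2x2 max pooling, stride 2
after_conv2 = (after_pool1 - kernel_width) + 1         # 122: second 5x5 convolution
after_pool2 = after_conv2 / 2                          # 61  => view_size
puts 16 * after_pool2 * after_pool2                    # => 59536, the nn.View / first nn.Linear width
```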
data/lib/torchrb/nn/trainer_default.rb ADDED
@@ -0,0 +1,16 @@
+ module Torchrb::NN::TrainerDefault
+
+   def define_trainer options
+     torch.eval <<-EOF, __FILE__, __LINE__
+       number_of_iterations = #{options.fetch(:iterations) { 50 }} -- Must be set for the callback to work
+
+       criterion = nn.ClassNLLCriterion()
+
+       trainer = nn.StochasticGradient(net, criterion)
+       trainer.verbose = false
+       trainer.learningRate = #{options.fetch(:learning_rate) { 0.001 }}
+       trainer.maxIteration = number_of_iterations
+       trainer.hookIteration = iteration_callback
+     EOF
+   end
+ end
data/lib/torchrb/torch.rb ADDED
@@ -0,0 +1,116 @@
+ class Torchrb::Torch < Torchrb::Lua
+
+   attr_accessor :network_loaded
+   attr_accessor :network_timestamp
+   attr_accessor :error_rate
+
+   def initialize options={}
+     super
+     @network_loaded = false
+     @error_rate = Float::NAN
+     load_network options[:network_storage_path] unless network_loaded rescue nil
+   end
+
+   def iteration_callback= callback
+     state.function "iteration_callback" do |trainer, iteration, currentError|
+       progress = iteration / state['number_of_iterations']
+       self.error_rate = currentError / 100.0
+       callback.call progress: progress, error_rate: error_rate
+     end
+   end
+
+   def train
+     eval <<-EOF, __FILE__, __LINE__
+       local oldprint = print
+       print = function(...)
+       end
+
+       trainer:train(train_set)
+
+       print = oldprint
+     EOF
+     self.network_loaded = true
+     self.network_timestamp = Time.now
+   end
+
+   def predict sample, network_storage_path=nil
+     load_network network_storage_path unless network_loaded
+
+     classes = eval <<-EOF, __FILE__, __LINE__
+       #{sample.to_tensor("sample_data").strip}
+       local prediction = #{enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"}
+       prediction = prediction:exp()
+       confidences = prediction:totable()
+       return classes
+     EOF
+     puts "predicted #{@state['confidences'].to_h} based on network @ #{network_timestamp}" if debug
+     classes = classes.to_h
+     @state['confidences'].to_h.map { |k, v| {classes[k] => v} }.reduce({}, :merge)
+   end
+
+   def load_network network_storage_path
+     raise "Neural net not trained yet. Call 'Torch#train' first." unless File.exist?(network_storage_path)
+     metadata = eval(<<-EOF, __FILE__, __LINE__).to_ruby
+       net = torch.load('#{network_storage_path}')
+       metadata = torch.load('#{network_storage_path}.meta')
+       classes = metadata[1]
+       timestamp = metadata[3]
+       return metadata[2]
+     EOF
+     self.error_rate = metadata
+     self.network_timestamp = @state['timestamp']
+     puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
+     self.network_loaded = true
+   end
+
+   def store_network network_storage_path
+     eval <<-EOF, __FILE__, __LINE__
+       torch.save('#{network_storage_path}', net)
+       torch.save('#{network_storage_path}.meta', {classes, #{error_rate}, '#{network_timestamp}'})
+     EOF
+     puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
+   end
+
+   def print_results
+     result = eval <<-EOF, __FILE__, __LINE__
+       class_performance = torch.LongTensor(#classes):fill(0):totable()
+       test_set_size = test_set:size()
+       for i=1,test_set_size do
+         local groundtruth = test_set.label[i]
+         local prediction = net:forward(test_set.input[i])
+         local confidences, indices = torch.sort(prediction, true) -- true means sort in descending order
+
+         if groundtruth == indices[1] then -- count only correct predictions
+           class_performance[groundtruth] = class_performance[groundtruth] + 1
+         end
+       end
+
+       local result = {}
+       for i=1,#classes do
+         local confidence = 100 * class_performance[i] / test_set_size
+         table.insert(result, { classes[i], confidence })
+       end
+       return result
+     EOF
+     result = result.to_ruby.map(&:to_ruby)
+     if defined?(DEBUG)
+       puts "#" * 80
+       puts "Results: #{result.to_h}"
+       puts "#" * 80
+     end
+   end
+
+   def cudify
+     eval <<-EOF, __FILE__, __LINE__
+       -- print(sys.COLORS.red .. '==> using CUDA GPU #' .. cutorch.getDevice() .. sys.COLORS.black)
+       train_set.input = train_set.input:cuda()
+       train_set.label = train_set.label:cuda()
+       test_set.input = test_set.input:cuda()
+       test_set.label = test_set.label:cuda()
+       validation_set.input = validation_set.input:cuda()
+       validation_set.label = validation_set.label:cuda()
+
+       criterion = nn.ClassNLLCriterion():cuda()
+       net = cudnn.convert(net:cuda(), cudnn)
+     EOF
+   end
+ end
data/lib/torchrb/version.rb ADDED
@@ -0,0 +1,5 @@
+ module Torchrb
+   module NN
+   end
+   VERSION = "0.2.0"
+ end
data/lib/torchrb/wrapper.rb ADDED
@@ -0,0 +1,67 @@
+ class Torchrb::Wrapper < Torchrb::Torch
+
+   cattr_accessor(:instances) { Hash.new }
+
+   def self.for model_class, options={}
+     @@instances[model_class] ||= new model_class, options
+     if block_given?
+       yield @@instances[model_class]
+     else
+       @@instances[model_class]
+     end
+   end
+
+   attr_reader :model, :progress
+   delegate :progress_callback, to: :model
+
+   def initialize model, options={}
+     raise "#{model} must be a class and extend Torchrb::ModelBase!" unless model.is_a?(Class) && model < Torchrb::ModelBase
+     @model = model
+     super(options)
+     self.class.include model.net
+     self.class.include model.trainer
+     model.setup self
+   end
+   private :initialize
+
+   def load_model_data
+     @progress = 0
+     load_dataset :train_set
+     load_dataset :test_set
+     load_dataset :validation_set
+   end
+
+   def train
+     define_nn
+     define_trainer
+
+     cudify if enable_cuda
+     super
+     print_results
+     store_network model.network_storage_path
+     error_rate
+   end
+
+   def predict sample
+     super sample
+   end
+
+   private
+
+   def load_dataset set_name
+     collection = model.send(set_name)
+     set_size = collection.size
+     model.progress_callback progress: progress, message: "Loading #{set_name.to_s.humanize} with #{set_size} element(s)."
+
+     set = Torchrb::DataSet.new set_name, model, collection
+     set.load do
+       @progress += 0.333 / set_size
+       model.progress_callback progress: progress
+     end
+     set.normalize! if model.normalize && set.is_trainset?
+   end
+
+   def engine_storage
+     Visit.cache_dir + "/net.t7"
+   end
+
+ end
data/torchrb.gemspec ADDED
@@ -0,0 +1,33 @@
+ # coding: utf-8
+ lib = File.expand_path('../lib', __FILE__)
+ $LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
+ require 'torchrb/version'
+
+ Gem::Specification.new do |spec|
+   spec.name          = "torchrb"
+   spec.version       = Torchrb::VERSION
+   spec.authors       = ["Michael Sprauer"]
+   spec.email         = ['ms@inline.de']
+   spec.homepage      = 'http://www.inline.de/'
+   spec.license       = 'MIT'
+
+   spec.summary       = %q(Torch wrapper for ruby)
+   spec.description   = spec.summary + ' '
+
+   spec.files         = `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
+   spec.test_files    = spec.files.grep(%r{^test/})
+   spec.bindir        = "exe"
+   spec.executables   = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
+   spec.require_paths = ["lib"]
+   # spec.required_ruby_version = '>= 2.1.0'
+
+   spec.add_runtime_dependency 'rufus-lua', '~> 1.1.2'
+   spec.add_runtime_dependency 'activesupport'
+
+   spec.add_development_dependency "bundler", "~> 1.11"
+   spec.add_development_dependency "rake", "~> 10.0"
+   spec.add_development_dependency "minitest", "~> 5.0"
+   spec.add_development_dependency "mocha"
+   spec.add_development_dependency "ascii_charts"
+   # spec.add_development_dependency 'ruby-prof' # <- don't need this normally.
+ end
metadata ADDED
@@ -0,0 +1,161 @@
+ --- !ruby/object:Gem::Specification
+ name: torchrb
+ version: !ruby/object:Gem::Version
+   version: 0.2.0
+ platform: ruby
+ authors:
+ - Michael Sprauer
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2016-07-08 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: rufus-lua
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 1.1.2
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 1.1.2
+ - !ruby/object:Gem::Dependency
+   name: activesupport
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: bundler
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '1.11'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '1.11'
+ - !ruby/object:Gem::Dependency
+   name: rake
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '10.0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '10.0'
+ - !ruby/object:Gem::Dependency
+   name: minitest
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '5.0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '5.0'
+ - !ruby/object:Gem::Dependency
+   name: mocha
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ - !ruby/object:Gem::Dependency
+   name: ascii_charts
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+ description: 'Torch wrapper for ruby '
+ email:
+ - ms@inline.de
+ executables: []
+ extensions: []
+ extra_rdoc_files: []
+ files:
+ - ".gitignore"
+ - ".gitlab-ci.yml"
+ - ".gitmodules"
+ - Gemfile
+ - README.md
+ - Rakefile
+ - bin/console
+ - bin/setup
+ - lib/torchrb.rb
+ - lib/torchrb/data_set.rb
+ - lib/torchrb/lua.rb
+ - lib/torchrb/model_base.rb
+ - lib/torchrb/nn/basic.rb
+ - lib/torchrb/nn/image_default.rb
+ - lib/torchrb/nn/trainer_default.rb
+ - lib/torchrb/torch.rb
+ - lib/torchrb/version.rb
+ - lib/torchrb/wrapper.rb
+ - torchrb.gemspec
+ homepage: http://www.inline.de/
+ licenses:
+ - MIT
+ metadata: {}
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubyforge_project:
+ rubygems_version: 2.5.1
+ signing_key:
+ specification_version: 4
+ summary: Torch wrapper for ruby
+ test_files: []