torch-rb 0.1.0 → 0.1.5

Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
data/lib/torch/optim/rprop.rb ADDED
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+ module Torch
+   module Optim
+     class Rprop < Optimizer
+       def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+         defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         # TODO implement []=
+         raise NotImplementedYet
+
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Rprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:prev] = Torch.zeros_like(p.data)
+               state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+             end
+
+             etaminus, etaplus = group[:etas]
+             step_size_min, step_size_max = group[:step_sizes]
+             step_size = state[:step_size]
+
+             state[:step] += 1
+
+             sign = grad.mul(state[:prev]).sign
+             sign[sign.gt(0)] = etaplus
+             sign[sign.lt(0)] = etaminus
+             sign[sign.eq(0)] = 1
+
+             # update stepsizes with step size updates
+             step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+             # for dir<0, dfdx=0
+             # for dir>=0 dfdx=dfdx
+             grad = grad.clone
+             grad[sign.eq(etaminus)] = 0
+
+             # update parameters
+             p.data.addcmul!(-1, grad.sign, step_size)
+
+             state[:prev].copy!(grad)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
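The step method above is stubbed with NotImplementedYet until []= support lands, but the rule it ports is easy to show in miniature. A hypothetical scalar sketch of the Rprop update (the helper name and toy objective are illustrative, not part of the gem):

# Grow the step size while the gradient keeps its sign, shrink it on a
# sign flip (and skip that update), then move against the gradient's sign.
def rprop_update(x, grad, prev_grad, step_size,
                 etaminus: 0.5, etaplus: 1.2, step_min: 1e-6, step_max: 50.0)
  if grad * prev_grad > 0
    step_size = [step_size * etaplus, step_max].min
  elsif grad * prev_grad < 0
    step_size = [step_size * etaminus, step_min].max
    grad = 0 # for dir < 0, dfdx = 0
  end
  x -= (grad <=> 0) * step_size
  [x, grad, step_size]
end

x, prev, step = 5.0, 0.0, 1e-2
20.times do
  grad = 2 * x # gradient of the toy objective f(x) = x**2
  x, prev, step = rprop_update(x, grad, prev, step)
end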
data/lib/torch/optim/sgd.rb ADDED
@@ -0,0 +1,60 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
+ module Torch
+   module Optim
+     class SGD < Optimizer
+       def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
+
+         defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+         if nesterov && (momentum <= 0 || dampening != 0)
+           raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
+         end
+
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           weight_decay = group[:weight_decay]
+           momentum = group[:momentum]
+           dampening = group[:dampening]
+           nesterov = group[:nesterov]
+
+           group[:params].each do |p|
+             next unless p.grad
+             d_p = p.grad.data
+             if weight_decay != 0
+               d_p.add!(weight_decay, p.data)
+             end
+             if momentum != 0
+               param_state = @state[p]
+               if !param_state.key?(:momentum_buffer)
+                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+               else
+                 buf = param_state[:momentum_buffer]
+                 buf.mul!(momentum).add!(1 - dampening, d_p)
+               end
+               if nesterov
+                 d_p = d_p.add(momentum, buf)
+               else
+                 d_p = buf
+               end
+             end
+
+             p.data.add!(-group[:lr], d_p)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
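For reference, a minimal training-loop sketch using this optimizer; net, criterion, inputs, and labels are assumed to exist and are not part of this diff:

optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01, momentum: 0.9)

optimizer.zero_grad                     # clear gradients from the last step
outputs = net.call(inputs)              # forward pass
loss = criterion.call(outputs, labels)
loss.backward                           # populates p.grad on each parameter
optimizer.step                          # applies p.data.add!(-lr, d_p)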
data/lib/torch/random.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module Random
+     class << self
+       # not available through LibTorch
+       def initial_seed
+         raise NotImplementedYet
+       end
+     end
+   end
+ end
data/lib/torch/tensor.rb CHANGED
@@ -5,12 +5,8 @@ module Torch
 
   alias_method :requires_grad?, :requires_grad
 
- def self.new(*size)
-   if size.first.is_a?(Tensor)
-     size.first
-   else
-     Torch.rand(*size)
-   end
+ def self.new(*args)
+   FloatTensor.new(*args)
  end
 
  def dtype
@@ -28,12 +24,18 @@ module Torch
  end
 
  def to_a
-   reshape(_data, shape)
+   reshape_arr(_flat_data, shape)
+ end
+
+ # TODO support dtype
+ def to(device, non_blocking: false, copy: false)
+   device = Device.new(device) if device.is_a?(String)
+   _to(device, _dtype, non_blocking, copy)
  end
 
  def size(dim = nil)
    if dim
-     _size(dim)
+     _size_int(dim)
    else
      shape
    end
@@ -51,19 +53,23 @@
    if numel != 1
      raise Error, "only one element tensors can be converted to Ruby scalars"
    end
-   _data.first
+   _flat_data.first
  end
 
- def data
-   Torch.tensor(to_a)
+ # unsure if this is correct
+ def new
+   Torch.empty(0, dtype: dtype)
+ end
+
+ def backward(gradient = nil)
+   _backward(gradient)
  end
 
  # TODO read directly from memory
  def numo
-   raise Error, "Numo not found" unless defined?(Numo::NArray)
    cls = Torch._dtype_to_numo[dtype]
    raise Error, "Cannot convert #{dtype} to Numo" unless cls
-   cls.cast(_data).reshape(*shape)
+   cls.cast(_flat_data).reshape(*shape)
  end
 
  def new_ones(*size, **options)
@@ -74,8 +80,33 @@
    _requires_grad!(requires_grad)
  end
 
+ def type(dtype)
+   enum = DTYPE_TO_ENUM[dtype]
+   raise Error, "Unknown type: #{dtype}" unless enum
+   _type(enum)
+ end
+
+ # start temp operations
+
+ def add!(value = 1, other)
+   if other.is_a?(Numeric)
+     _add__scalar(other, value)
+   else
+     # need to use alpha for sparse tensors instead of multiplying
+     _add__tensor(other, value)
+   end
+ end
+
+ def mul!(other)
+   if other.is_a?(Numeric)
+     _mul__scalar(other)
+   else
+     _mul__tensor(other)
+   end
+ end
+
  # operations
- %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
+ %w(log_softmax mean softmax sum topk).each do |op|
    define_method(op) do |*args, **options, &block|
      if options.any?
        Torch.send(op, self, *args, **options, &block)
@@ -85,6 +116,8 @@
      end
    end
 
+ # end temp operations
+
  def +(other)
    add(other)
  end
@@ -117,18 +150,56 @@
    item <=> other
  end
 
- # TODO use accessor C++ method
- def [](index, *args)
-   v = _access(index)
-   args.each do |i|
-     v = v._access(i)
+ # based on python_variable_indexing.cpp
+ def [](*indexes)
+   result = self
+   dim = 0
+   indexes.each do |index|
+     if index.is_a?(Numeric)
+       result = result._select_int(dim, index)
+     elsif index.is_a?(Range)
+       finish = index.end
+       finish += 1 unless index.exclude_end?
+       result = result._slice_tensor(dim, index.begin, finish, 1)
+       dim += 1
+     elsif index.nil?
+       result = result.unsqueeze(dim)
+       dim += 1
+     elsif index == true
+       result = result.unsqueeze(dim)
+       # TODO handle false
+     else
+       raise Error, "Unsupported index type: #{index.class.name}"
+     end
+   end
+   result
+ end
+
+ # TODO
+ # based on python_variable_indexing.cpp
+ def []=(index, value)
+   raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+
+   value = Torch.tensor(value) unless value.is_a?(Tensor)
+
+   if index.is_a?(Numeric)
+     copy_to(_select_int(0, index), value)
+   elsif index.is_a?(Range)
+     finish = index.end
+     finish += 1 unless index.exclude_end?
+     copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+   else
+     raise Error, "Unsupported index type: #{index.class.name}"
    end
-   v
  end
 
  private
 
- def reshape(arr, dims)
+ def copy_to(dst, src)
+   dst.copy!(src)
+ end
+
+ def reshape_arr(arr, dims)
    if dims.empty?
      arr
    else
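A quick sketch of what the new indexing methods above accept (integers select, ranges slice, nil inserts an axis); the tensors here are made up for illustration:

x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
x[0]        # row 0                    -> tensor of [1, 2, 3]
x[0, 1]     # row 0, column 1          -> scalar tensor containing 2
x[0..1, 2]  # rows 0-1, column 2       -> tensor of [3, 6]
x[nil]      # new leading axis, shape [1, 2, 3], like unsqueeze(0)

# []= takes an integer or range on the left and a scalar or tensor value
x[0] = 0
x[0..0] = Torch.tensor([[7, 8, 9]])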
data/lib/torch/utils/data/data_loader.rb CHANGED
@@ -2,10 +2,25 @@ module Torch
   module Utils
     module Data
      class DataLoader
+       include Enumerable
+
+       attr_reader :dataset
+
        def initialize(dataset, batch_size: 1)
          @dataset = dataset
          @batch_size = batch_size
        end
+
+       def each
+         size.times do |i|
+           start_index = i * @batch_size
+           yield @dataset[start_index...(start_index + @batch_size)]
+         end
+       end
+
+       def size
+         (@dataset.size / @batch_size.to_f).ceil
+       end
      end
    end
  end
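Since each and size are all DataLoader defines, any object that responds to #size and range-based #[] can be batched; a plain Array makes the mechanics easy to verify:

loader = Torch::Utils::Data::DataLoader.new((1..10).to_a, batch_size: 4)
loader.size  # => 3
loader.to_a  # => [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]] (last batch is short)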
data/lib/torch/utils/data/tensor_dataset.rb CHANGED
@@ -3,11 +3,18 @@ module Torch
    module Data
      class TensorDataset
        def initialize(*tensors)
+         unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
+           raise Error, "Tensors must all have same dim 0 size"
+         end
          @tensors = tensors
        end
 
        def [](index)
-         tensors.map { |t| t[index] }
+         @tensors.map { |t| t[index] }
+       end
+
+       def size
+         @tensors[0].size(0)
        end
      end
    end
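Putting the two classes together, with made-up training data: TensorDataset slices every tensor along dim 0, and DataLoader batches the slices.

x = Torch.rand(10, 3)
y = Torch.rand(10)
dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4)
loader.each do |xb, yb|
  # xb is an up-to-4x3 slice of x, yb the matching slice of y
end

# The new guard raises on mismatched first dimensions:
# Torch::Utils::Data::TensorDataset.new(Torch.rand(10, 3), Torch.rand(8))
# # => Torch::Error, "Tensors must all have same dim 0 size"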
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.1.0"
+   VERSION = "0.1.5"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.1.0
+   version: 0.1.5
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2019-11-26 00:00:00.000000000 Z
+ date: 2019-12-07 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice
@@ -106,26 +106,97 @@ files:
  - README.md
  - ext/torch/ext.cpp
  - ext/torch/extconf.rb
+ - ext/torch/nn_functions.cpp
+ - ext/torch/nn_functions.hpp
+ - ext/torch/templates.hpp
+ - ext/torch/tensor_functions.cpp
+ - ext/torch/tensor_functions.hpp
+ - ext/torch/torch_functions.cpp
+ - ext/torch/torch_functions.hpp
  - lib/torch-rb.rb
  - lib/torch.rb
  - lib/torch/ext.bundle
  - lib/torch/inspector.rb
+ - lib/torch/native/dispatcher.rb
+ - lib/torch/native/function.rb
+ - lib/torch/native/generator.rb
+ - lib/torch/native/native_functions.yaml
+ - lib/torch/native/parser.rb
+ - lib/torch/nn/alpha_dropout.rb
+ - lib/torch/nn/avg_pool2d.rb
+ - lib/torch/nn/avg_poolnd.rb
+ - lib/torch/nn/bce_loss.rb
+ - lib/torch/nn/bce_with_logits_loss.rb
+ - lib/torch/nn/bilinear.rb
  - lib/torch/nn/conv2d.rb
+ - lib/torch/nn/convnd.rb
+ - lib/torch/nn/cosine_embedding_loss.rb
+ - lib/torch/nn/cosine_similarity.rb
+ - lib/torch/nn/cross_entropy_loss.rb
+ - lib/torch/nn/ctc_loss.rb
+ - lib/torch/nn/dropout.rb
+ - lib/torch/nn/dropout2d.rb
+ - lib/torch/nn/dropout3d.rb
+ - lib/torch/nn/dropoutnd.rb
+ - lib/torch/nn/embedding.rb
+ - lib/torch/nn/embedding_bag.rb
+ - lib/torch/nn/feature_alpha_dropout.rb
  - lib/torch/nn/functional.rb
+ - lib/torch/nn/hinge_embedding_loss.rb
+ - lib/torch/nn/identity.rb
  - lib/torch/nn/init.rb
+ - lib/torch/nn/kl_div_loss.rb
+ - lib/torch/nn/l1_loss.rb
+ - lib/torch/nn/leaky_relu.rb
  - lib/torch/nn/linear.rb
+ - lib/torch/nn/log_softmax.rb
+ - lib/torch/nn/loss.rb
+ - lib/torch/nn/margin_ranking_loss.rb
+ - lib/torch/nn/max_pool2d.rb
+ - lib/torch/nn/max_poolnd.rb
  - lib/torch/nn/module.rb
  - lib/torch/nn/mse_loss.rb
+ - lib/torch/nn/multi_label_margin_loss.rb
+ - lib/torch/nn/multi_label_soft_margin_loss.rb
+ - lib/torch/nn/multi_margin_loss.rb
+ - lib/torch/nn/nll_loss.rb
+ - lib/torch/nn/pairwise_distance.rb
  - lib/torch/nn/parameter.rb
+ - lib/torch/nn/poisson_nll_loss.rb
+ - lib/torch/nn/prelu.rb
  - lib/torch/nn/relu.rb
+ - lib/torch/nn/rnn.rb
+ - lib/torch/nn/rnn_base.rb
  - lib/torch/nn/sequential.rb
+ - lib/torch/nn/sigmoid.rb
+ - lib/torch/nn/smooth_l1_loss.rb
+ - lib/torch/nn/soft_margin_loss.rb
+ - lib/torch/nn/softmax.rb
+ - lib/torch/nn/softmax2d.rb
+ - lib/torch/nn/softmin.rb
+ - lib/torch/nn/softplus.rb
+ - lib/torch/nn/triplet_margin_loss.rb
+ - lib/torch/nn/weighted_loss.rb
+ - lib/torch/optim/adadelta.rb
+ - lib/torch/optim/adagrad.rb
+ - lib/torch/optim/adam.rb
+ - lib/torch/optim/adamax.rb
+ - lib/torch/optim/adamw.rb
+ - lib/torch/optim/asgd.rb
+ - lib/torch/optim/lr_scheduler/lr_scheduler.rb
+ - lib/torch/optim/lr_scheduler/step_lr.rb
+ - lib/torch/optim/optimizer.rb
+ - lib/torch/optim/rmsprop.rb
+ - lib/torch/optim/rprop.rb
+ - lib/torch/optim/sgd.rb
+ - lib/torch/random.rb
  - lib/torch/tensor.rb
  - lib/torch/utils/data/data_loader.rb
  - lib/torch/utils/data/tensor_dataset.rb
  - lib/torch/version.rb
  homepage: https://github.com/ankane/torch-rb
  licenses:
- - MIT
+ - BSD-3-Clause
  metadata: {}
  post_install_message:
  rdoc_options: []