torch-rb 0.1.0 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
@@ -0,0 +1,68 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
module Torch
  module Optim
    # Rprop (resilient backpropagation) optimizer.
    #
    # NOTE: #step currently raises NotImplementedYet because the update relies
    # on masked assignment (Tensor#[]= with a boolean tensor index), which
    # Tensor#[]= does not support yet. The code after the raise is the
    # intended port and is kept so it can be enabled once []= supports it.
    class Rprop < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] initial per-element step size
      # @param etas [Array(Numeric, Numeric)] [etaminus, etaplus]: the step size
      #   is multiplied by etaminus when the gradient sign flips and by etaplus
      #   when it keeps its sign; must satisfy 0 < etaminus < 1 < etaplus
      # @param step_sizes [Array(Numeric, Numeric)] [min, max] clamp range for
      #   the step sizes
      # @raise [ArgumentError] if lr is negative or etas are out of range
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0

        # Same bounds as PyTorch (`0.0 < etas[0] < 1.0 < etas[1]`): the
        # endpoints 0 and 1 are excluded, and NaN fails every comparison so it
        # is rejected as well.
        etas_valid = etas[0] > 0 && etas[0] < 1 && etas[1] > 1
        raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" unless etas_valid

        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the
      #   model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [NotImplementedYet] always, until Tensor#[]= supports boolean
      #   tensor indexes
      def step(closure = nil)
        # TODO implement []=
        raise NotImplementedYet

        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            raise Error, "Rprop does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:prev] = Torch.zeros_like(p.data)
              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
            end

            etaminus, etaplus = group[:etas]
            step_size_min, step_size_max = group[:step_sizes]
            step_size = state[:step_size]

            state[:step] += 1

            # grad * prev > 0 (same sign as last step)  -> factor etaplus
            # grad * prev < 0 (sign flipped)            -> factor etaminus
            # grad * prev == 0                          -> factor 1
            sign = grad.mul(state[:prev]).sign
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

            # update stepsizes with step size updates
            step_size.mul!(sign).clamp!(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            grad = grad.clone
            grad[sign.eq(etaminus)] = 0

            # update parameters
            p.data.addcmul!(-1, grad.sign, step_size)

            state[:prev].copy!(grad)
          end
        end

        loss
      end
    end
  end
end
@@ -0,0 +1,60 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
  module Optim
    # Stochastic gradient descent, optionally with (Nesterov) momentum and
    # weight decay.
    class SGD < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate (required)
      # @param momentum [Numeric] momentum factor; 0 disables momentum
      # @param dampening [Numeric] dampening applied to the gradient when it is
      #   folded into the momentum buffer
      # @param weight_decay [Numeric] coefficient of p.data added to the gradient
      # @param nesterov [Boolean] use Nesterov momentum
      # @raise [ArgumentError] on negative lr/momentum/weight_decay, or when
      #   nesterov is requested without positive momentum and zero dampening
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the
      #   model and returns the loss
      # @return the closure's result, or nil when no closure is given
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            next unless p.grad
            d_p = p.grad.data
            # NOTE: add! mutates d_p (p.grad.data) in place, as in the PyTorch
            # version this was ported from
            d_p.add!(weight_decay, p.data) if weight_decay != 0

            if momentum != 0
              param_state = @state[p]
              # Must be Hash#key? -- Hash#key looks a key up by *value*, so it
              # always returned nil here and the buffer was reset every step.
              if param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              else
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              end
              d_p = nesterov ? d_p.add(momentum, buf) : buf
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module Random
    # Seed introspection is not exposed by LibTorch, so this always raises
    # NotImplementedYet.
    def self.initial_seed
      raise NotImplementedYet
    end
  end
end
data/lib/torch/tensor.rb CHANGED
@@ -5,12 +5,8 @@ module Torch
5
5
 
6
6
  alias_method :requires_grad?, :requires_grad
7
7
 
8
- def self.new(*size)
9
- if size.first.is_a?(Tensor)
10
- size.first
11
- else
12
- Torch.rand(*size)
13
- end
8
# Tensor.new(...) is a thin constructor that builds a FloatTensor from the
# very same arguments.
def self.new(*params)
  FloatTensor.new(*params)
end
15
11
 
16
12
  def dtype
@@ -28,12 +24,18 @@ module Torch
28
24
  end
29
25
 
30
26
  def to_a
31
- reshape(_data, shape)
27
+ reshape_arr(_flat_data, shape)
28
+ end
29
+
30
+ # TODO support dtype
31
+ def to(device, non_blocking: false, copy: false)
32
+ device = Device.new(device) if device.is_a?(String)
33
+ _to(device, _dtype, non_blocking, copy)
32
34
  end
33
35
 
34
36
  def size(dim = nil)
35
37
  if dim
36
- _size(dim)
38
+ _size_int(dim)
37
39
  else
38
40
  shape
39
41
  end
@@ -51,19 +53,23 @@ module Torch
51
53
  if numel != 1
52
54
  raise Error, "only one element tensors can be converted to Ruby scalars"
53
55
  end
54
- _data.first
56
+ _flat_data.first
55
57
  end
56
58
 
57
- def data
58
- Torch.tensor(to_a)
59
# Builds an empty tensor with this tensor's dtype via
# Torch.empty(0, dtype: ...).
# unsure if this is correct
def new
  element_type = dtype
  Torch.empty(0, dtype: element_type)
end

# Delegates to the native _backward binding; `gradient` is nil when omitted.
def backward(gradient = nil)
  _backward(gradient)
end
60
67
 
61
68
  # TODO read directly from memory
62
69
  def numo
63
- raise Error, "Numo not found" unless defined?(Numo::NArray)
64
70
  cls = Torch._dtype_to_numo[dtype]
65
71
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
66
- cls.cast(_data).reshape(*shape)
72
+ cls.cast(_flat_data).reshape(*shape)
67
73
  end
68
74
 
69
75
  def new_ones(*size, **options)
@@ -74,8 +80,33 @@ module Torch
74
80
  _requires_grad!(requires_grad)
75
81
  end
76
82
 
83
# Converts the tensor through the native _type binding, using the enum value
# that DTYPE_TO_ENUM maps `dtype` to.
#
# @raise [Error] if `dtype` has no entry in DTYPE_TO_ENUM
def type(dtype)
  enum = DTYPE_TO_ENUM[dtype]
  return _type(enum) if enum

  raise Error, "Unknown type: #{dtype}"
end

# start temp operations

# In-place add using PyTorch's legacy add_(value, other) call form: `value`
# (default 1) is passed through as the alpha argument. A Numeric `other` uses
# the scalar binding; anything else uses the tensor binding.
def add!(value = 1, other)
  # need to use alpha for sparse tensors instead of multiplying
  return _add__scalar(other, value) if other.is_a?(Numeric)

  _add__tensor(other, value)
end

# In-place multiply by a Numeric (scalar binding) or a tensor (tensor binding).
def mul!(other)
  other.is_a?(Numeric) ? _mul__scalar(other) : _mul__tensor(other)
end
107
+
77
108
  # operations
78
- %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
109
+ %w(log_softmax mean softmax sum topk).each do |op|
79
110
  define_method(op) do |*args, **options, &block|
80
111
  if options.any?
81
112
  Torch.send(op, self, *args, **options, &block)
@@ -85,6 +116,8 @@ module Torch
85
116
  end
86
117
  end
87
118
 
119
+ # end temp operations
120
+
88
121
  def +(other)
89
122
  add(other)
90
123
  end
@@ -117,18 +150,56 @@ module Torch
117
150
  item <=> other
118
151
  end
119
152
 
120
- # TODO use accessor C++ method
121
- def [](index, *args)
122
- v = _access(index)
123
- args.each do |i|
124
- v = v._access(i)
153
# Indexes the tensor with one Ruby value per dimension, applied left to right.
# based on python_variable_indexing.cpp
#
# Supported index types:
# - Numeric: selects that position; the dimension is dropped, so `dim` does
#   not advance
# - Range: slices the dimension; an inclusive end of -1 (e.g. 0..-1) or a nil
#   end slices through the last element, and a nil begin starts at 0
# - nil: unsqueezes a new dimension at the current position
# - true: unsqueezes a new dimension (TODO handle false)
#
# @raise [Error] for any other index type
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    case index
    when Numeric
      result = result._select_int(dim, index)
    when Range
      start, finish = slice_bounds(index) { result.size(dim) }
      result = result._slice_tensor(dim, start, finish, 1)
      dim += 1
    when nil
      result = result.unsqueeze(dim)
      dim += 1
    when true
      result = result.unsqueeze(dim)
      # TODO handle false
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
  result
end

# Assigns `value` into the tensor in place at a single index on dim 0.
# based on python_variable_indexing.cpp
# TODO: only a single Numeric or Range index on dim 0 is supported so far
#
# @param index [Numeric, Range] position or range along dim 0
# @param value [Tensor, Object] non-Tensor values are wrapped with Torch.tensor
# @raise [ArgumentError] if value is nil (deleting items is not supported)
# @raise [Error] for unsupported index types
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  case index
  when Numeric
    copy_to(_select_int(0, index), value)
  when Range
    start, finish = slice_bounds(index) { size(0) }
    copy_to(_slice_tensor(0, start, finish, 1), value)
  else
    raise Error, "Unsupported index type: #{index.class.name}"
  end
end

private

# Copies src into dst in place via dst.copy!.
def copy_to(dst, src)
  dst.copy!(src)
end

# Converts a Ruby Range into the [start, finish) pair that _slice_tensor
# expects. The block must return the length of the dimension being sliced.
# It is only called for an endless range or an inclusive range ending at -1;
# for those, "end + 1" would be nil + 1 (an error) or 0 (an empty slice).
def slice_bounds(range)
  start = range.begin || 0
  finish = range.end
  if finish.nil? || (finish == -1 && !range.exclude_end?)
    finish = yield
  elsif !range.exclude_end?
    finish += 1
  end
  [start, finish]
end
201
+
202
+ def reshape_arr(arr, dims)
132
203
  if dims.empty?
133
204
  arr
134
205
  else
@@ -2,10 +2,25 @@ module Torch
2
2
  module Utils
3
3
  module Data
4
4
  class DataLoader
5
+ include Enumerable
6
+
7
+ attr_reader :dataset
8
+
5
9
  def initialize(dataset, batch_size: 1)
6
10
  @dataset = dataset
7
11
  @batch_size = batch_size
8
12
  end
13
+
14
+ def each
15
+ size.times do |i|
16
+ start_index = i * @batch_size
17
+ yield @dataset[start_index...(start_index + @batch_size)]
18
+ end
19
+ end
20
+
21
+ def size
22
+ (@dataset.size / @batch_size.to_f).ceil
23
+ end
9
24
  end
10
25
  end
11
26
  end
@@ -3,11 +3,18 @@ module Torch
3
3
  module Data
4
4
  class TensorDataset
5
5
  def initialize(*tensors)
6
+ unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
7
+ raise Error, "Tensors must all have same dim 0 size"
8
+ end
6
9
  @tensors = tensors
7
10
  end
8
11
 
9
12
  def [](index)
10
- tensors.map { |t| t[index] }
13
+ @tensors.map { |t| t[index] }
14
+ end
15
+
16
+ def size
17
+ @tensors[0].size(0)
11
18
  end
12
19
  end
13
20
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.1.0"
2
+ VERSION = "0.1.5"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.5
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-11-26 00:00:00.000000000 Z
11
+ date: 2019-12-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -106,26 +106,97 @@ files:
106
106
  - README.md
107
107
  - ext/torch/ext.cpp
108
108
  - ext/torch/extconf.rb
109
+ - ext/torch/nn_functions.cpp
110
+ - ext/torch/nn_functions.hpp
111
+ - ext/torch/templates.hpp
112
+ - ext/torch/tensor_functions.cpp
113
+ - ext/torch/tensor_functions.hpp
114
+ - ext/torch/torch_functions.cpp
115
+ - ext/torch/torch_functions.hpp
109
116
  - lib/torch-rb.rb
110
117
  - lib/torch.rb
111
118
  - lib/torch/ext.bundle
112
119
  - lib/torch/inspector.rb
120
+ - lib/torch/native/dispatcher.rb
121
+ - lib/torch/native/function.rb
122
+ - lib/torch/native/generator.rb
123
+ - lib/torch/native/native_functions.yaml
124
+ - lib/torch/native/parser.rb
125
+ - lib/torch/nn/alpha_dropout.rb
126
+ - lib/torch/nn/avg_pool2d.rb
127
+ - lib/torch/nn/avg_poolnd.rb
128
+ - lib/torch/nn/bce_loss.rb
129
+ - lib/torch/nn/bce_with_logits_loss.rb
130
+ - lib/torch/nn/bilinear.rb
113
131
  - lib/torch/nn/conv2d.rb
132
+ - lib/torch/nn/convnd.rb
133
+ - lib/torch/nn/cosine_embedding_loss.rb
134
+ - lib/torch/nn/cosine_similarity.rb
135
+ - lib/torch/nn/cross_entropy_loss.rb
136
+ - lib/torch/nn/ctc_loss.rb
137
+ - lib/torch/nn/dropout.rb
138
+ - lib/torch/nn/dropout2d.rb
139
+ - lib/torch/nn/dropout3d.rb
140
+ - lib/torch/nn/dropoutnd.rb
141
+ - lib/torch/nn/embedding.rb
142
+ - lib/torch/nn/embedding_bag.rb
143
+ - lib/torch/nn/feature_alpha_dropout.rb
114
144
  - lib/torch/nn/functional.rb
145
+ - lib/torch/nn/hinge_embedding_loss.rb
146
+ - lib/torch/nn/identity.rb
115
147
  - lib/torch/nn/init.rb
148
+ - lib/torch/nn/kl_div_loss.rb
149
+ - lib/torch/nn/l1_loss.rb
150
+ - lib/torch/nn/leaky_relu.rb
116
151
  - lib/torch/nn/linear.rb
152
+ - lib/torch/nn/log_softmax.rb
153
+ - lib/torch/nn/loss.rb
154
+ - lib/torch/nn/margin_ranking_loss.rb
155
+ - lib/torch/nn/max_pool2d.rb
156
+ - lib/torch/nn/max_poolnd.rb
117
157
  - lib/torch/nn/module.rb
118
158
  - lib/torch/nn/mse_loss.rb
159
+ - lib/torch/nn/multi_label_margin_loss.rb
160
+ - lib/torch/nn/multi_label_soft_margin_loss.rb
161
+ - lib/torch/nn/multi_margin_loss.rb
162
+ - lib/torch/nn/nll_loss.rb
163
+ - lib/torch/nn/pairwise_distance.rb
119
164
  - lib/torch/nn/parameter.rb
165
+ - lib/torch/nn/poisson_nll_loss.rb
166
+ - lib/torch/nn/prelu.rb
120
167
  - lib/torch/nn/relu.rb
168
+ - lib/torch/nn/rnn.rb
169
+ - lib/torch/nn/rnn_base.rb
121
170
  - lib/torch/nn/sequential.rb
171
+ - lib/torch/nn/sigmoid.rb
172
+ - lib/torch/nn/smooth_l1_loss.rb
173
+ - lib/torch/nn/soft_margin_loss.rb
174
+ - lib/torch/nn/softmax.rb
175
+ - lib/torch/nn/softmax2d.rb
176
+ - lib/torch/nn/softmin.rb
177
+ - lib/torch/nn/softplus.rb
178
+ - lib/torch/nn/triplet_margin_loss.rb
179
+ - lib/torch/nn/weighted_loss.rb
180
+ - lib/torch/optim/adadelta.rb
181
+ - lib/torch/optim/adagrad.rb
182
+ - lib/torch/optim/adam.rb
183
+ - lib/torch/optim/adamax.rb
184
+ - lib/torch/optim/adamw.rb
185
+ - lib/torch/optim/asgd.rb
186
+ - lib/torch/optim/lr_scheduler/lr_scheduler.rb
187
+ - lib/torch/optim/lr_scheduler/step_lr.rb
188
+ - lib/torch/optim/optimizer.rb
189
+ - lib/torch/optim/rmsprop.rb
190
+ - lib/torch/optim/rprop.rb
191
+ - lib/torch/optim/sgd.rb
192
+ - lib/torch/random.rb
122
193
  - lib/torch/tensor.rb
123
194
  - lib/torch/utils/data/data_loader.rb
124
195
  - lib/torch/utils/data/tensor_dataset.rb
125
196
  - lib/torch/version.rb
126
197
  homepage: https://github.com/ankane/torch-rb
127
198
  licenses:
128
- - MIT
199
+ - BSD-3-Clause
129
200
  metadata: {}
130
201
  post_install_message:
131
202
  rdoc_options: []