torch-rb 0.1.3 → 0.1.8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
@@ -2,28 +2,19 @@ module Torch
2
2
  module NN
3
3
  class Sequential < Module
4
4
  def initialize(*args)
5
- @modules = {}
5
+ super()
6
6
  # TODO support hash arg (named modules)
7
7
  args.each_with_index do |mod, idx|
8
8
  add_module(idx.to_s, mod)
9
9
  end
10
10
  end
11
11
 
12
- def add_module(name, mod)
13
- # TODO add checks
14
- @modules[name] = mod
15
- end
16
-
17
12
  def forward(input)
18
13
  @modules.values.each do |mod|
19
14
  input = mod.call(input)
20
15
  end
21
16
  input
22
17
  end
23
-
24
- def parameters
25
- @modules.flat_map { |_, mod| mod.parameters }
26
- end
27
18
  end
28
19
  end
29
20
  end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Sigmoid < Module
4
+ def forward(input)
5
+ Torch.sigmoid(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class SmoothL1Loss < Loss
4
+ def initialize(reduction: "mean")
5
+ super(reduction)
6
+ end
7
+
8
+ def forward(input, target)
9
+ F.smooth_l1_loss(input, target, reduction: @reduction)
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class SoftMarginLoss < Loss
4
+ def initialize(reduction: "mean")
5
+ super(reduction)
6
+ end
7
+
8
+ def forward(input, target)
9
+ F.soft_margin_loss(input, target, reduction: @reduction)
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class Softmax < Module
4
+ def initialize(dim: nil)
5
+ super()
6
+ @dim = dim
7
+ end
8
+
9
+ def forward(input)
10
+ F.softmax(input, dim: @dim)
11
+ end
12
+
13
+ def extra_inspect
14
+ format("dim: %s", @dim)
15
+ end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class Softmax2d < Module
4
+ def forward(input)
5
+ raise ArgumentError, "Softmax2d requires a 4D tensor as input" unless input.dim == 4
6
+ F.softmax(input, dim: 1)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class Softmin < Module
4
+ def initialize(dim: nil)
5
+ super()
6
+ @dim = dim
7
+ end
8
+
9
+ def forward(input)
10
+ F.softmin(input, dim: @dim)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class Softplus < Module
4
+ def initialize(beta: 1, threshold: 20)
5
+ super()
6
+ @beta = beta
7
+ @threshold = threshold
8
+ end
9
+
10
+ def forward(input)
11
+ F.softplus(input, beta: @beta, threshold: @threshold)
12
+ end
13
+
14
+ def extra_inspect
15
+ format("beta: %s, threshold: %s", @beta, @threshold)
16
+ end
17
+ end
18
+ end
19
+ end
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class Softshrink < Module
4
+ def initialize(lambd: 0.5)
5
+ super()
6
+ @lambd = lambd
7
+ end
8
+
9
+ def forward(input)
10
+ F.softshrink(input, @lambd)
11
+ end
12
+
13
+ def extra_inspect
14
+ @lambd.to_s
15
+ end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Softsign < Module
4
+ def forward(input)
5
+ F.softsign(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Tanh < Module
4
+ def forward(input)
5
+ Torch.tanh(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Tanhshrink < Module
4
+ def forward(input)
5
+ F.tanhshrink(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class TripletMarginLoss < Loss
4
+ def initialize(margin: 1.0, p: 2.0, eps: 1e-6, swap: false, reduction: "mean")
5
+ super(reduction)
6
+ @margin = margin
7
+ @p = p
8
+ @eps = eps
9
+ @swap = swap
10
+ end
11
+
12
+ def forward(anchor, positive, negative)
13
+ F.triplet_margin_loss(anchor, positive, negative, margin: @margin, p: @p,
14
+ eps: @eps, swap: @swap, reduction: @reduction)
15
+ end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class Unfold < Module
4
+ def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
5
+ super()
6
+ @kernel_size = kernel_size
7
+ @dilation = dilation
8
+ @padding = padding
9
+ @stride = stride
10
+ end
11
+
12
+ def forward(input)
13
+ F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
14
+ end
15
+
16
+ # TODO add extra_inspect
17
+ end
18
+ end
19
+ end
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ module NN
3
+ module Utils
4
+ def _single(value)
5
+ _ntuple(1, value)
6
+ end
7
+
8
+ def _pair(value)
9
+ _ntuple(2, value)
10
+ end
11
+
12
+ def _triple(value)
13
+ _ntuple(3, value)
14
+ end
15
+
16
+ def _quadrupal(value)
17
+ _ntuple(4, value)
18
+ end
19
+
20
+ def _ntuple(n, value)
21
+ value.is_a?(Array) ? value : [value] * n
22
+ end
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class WeightedLoss < Loss
4
+ def initialize(weight, reduction)
5
+ super(reduction)
6
+ register_buffer("weight", weight)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class ZeroPad2d < ConstantPad2d
4
+ def initialize(padding)
5
+ super(padding, 0.0)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module Random
3
+ class << self
4
+ # not available through LibTorch
5
+ def initial_seed
6
+ raise NotImplementedYet
7
+ end
8
+ end
9
+ end
10
+ end
@@ -5,12 +5,8 @@ module Torch
5
5
 
6
6
  alias_method :requires_grad?, :requires_grad
7
7
 
8
- def self.new(*size)
9
- if size.length == 1 && size.first.is_a?(Tensor)
10
- size.first
11
- else
12
- Torch.empty(*size)
13
- end
8
+ def self.new(*args)
9
+ FloatTensor.new(*args)
14
10
  end
15
11
 
16
12
  def dtype
@@ -28,7 +24,7 @@ module Torch
28
24
  end
29
25
 
30
26
  def to_a
31
- reshape_arr(_data, shape)
27
+ reshape_arr(_flat_data, shape)
32
28
  end
33
29
 
34
30
  # TODO support dtype
@@ -39,7 +35,7 @@ module Torch
39
35
 
40
36
  def size(dim = nil)
41
37
  if dim
42
- _size(dim)
38
+ _size_int(dim)
43
39
  else
44
40
  shape
45
41
  end
@@ -49,15 +45,16 @@ module Torch
49
45
  dim.times.map { |i| size(i) }
50
46
  end
51
47
 
52
- def view(*size)
53
- _view(size)
48
+ # mirror Python len()
49
+ def length
50
+ size(0)
54
51
  end
55
52
 
56
53
  def item
57
54
  if numel != 1
58
55
  raise Error, "only one element tensors can be converted to Ruby scalars"
59
56
  end
60
- _data.first
57
+ _flat_data.first
61
58
  end
62
59
 
63
60
  # unsure if this is correct
@@ -66,19 +63,14 @@ module Torch
66
63
  end
67
64
 
68
65
  def backward(gradient = nil)
69
- if gradient
70
- _backward_gradient(gradient)
71
- else
72
- _backward
73
- end
66
+ _backward(gradient)
74
67
  end
75
68
 
76
69
  # TODO read directly from memory
77
70
  def numo
78
- raise Error, "Numo not found" unless defined?(Numo::NArray)
79
71
  cls = Torch._dtype_to_numo[dtype]
80
72
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
81
- cls.cast(_data).reshape(*shape)
73
+ cls.cast(_flat_data).reshape(*shape)
82
74
  end
83
75
 
84
76
  def new_ones(*size, **options)
@@ -95,31 +87,23 @@ module Torch
95
87
  _type(enum)
96
88
  end
97
89
 
98
- def add!(value = 1, other)
99
- if other.is_a?(Numeric)
100
- _add_scalar!(other * value)
101
- else
102
- # need to use alpha for sparse tensors instead of multiplying
103
- _add_alpha!(other, value)
104
- end
90
+ def reshape(*size)
91
+ # Python doesn't check if size == 1, just ignores later arguments
92
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
93
+ _reshape(size)
105
94
  end
106
95
 
107
- def mul!(other)
108
- if other.is_a?(Numeric)
109
- _mul_scalar!(other)
110
- else
111
- _mul!(other)
112
- end
96
+ def view(*size)
97
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
98
+ _view(size)
113
99
  end
114
100
 
115
- # operations
116
- %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
117
- define_method(op) do |*args, **options, &block|
118
- if options.any?
119
- Torch.send(op, self, *args, **options, &block)
120
- else
121
- Torch.send(op, self, *args, &block)
122
- end
101
+ # value and other are swapped for some methods
102
+ def add!(value = 1, other)
103
+ if other.is_a?(Numeric)
104
+ _add__scalar(other, value)
105
+ else
106
+ _add__tensor(other, value)
123
107
  end
124
108
  end
125
109
 
@@ -161,14 +145,20 @@ module Torch
161
145
  dim = 0
162
146
  indexes.each do |index|
163
147
  if index.is_a?(Numeric)
164
- result = result._select(dim, index)
148
+ result = result._select_int(dim, index)
165
149
  elsif index.is_a?(Range)
166
150
  finish = index.end
167
151
  finish += 1 unless index.exclude_end?
168
- result = result._slice(dim, index.begin, finish, 1)
152
+ result = result._slice_tensor(dim, index.begin, finish, 1)
153
+ dim += 1
154
+ elsif index.nil?
155
+ result = result.unsqueeze(dim)
169
156
  dim += 1
157
+ elsif index == true
158
+ result = result.unsqueeze(dim)
159
+ # TODO handle false
170
160
  else
171
- raise Error, "Unsupported index type"
161
+ raise Error, "Unsupported index type: #{index.class.name}"
172
162
  end
173
163
  end
174
164
  result
@@ -176,11 +166,28 @@ module Torch
176
166
 
177
167
  # TODO
178
168
  # based on python_variable_indexing.cpp
179
- # def []=(index, value)
180
- # end
169
+ def []=(index, value)
170
+ raise ArgumentError, "Tensor does not support deleting items" if value.nil?
171
+
172
+ value = Torch.tensor(value) unless value.is_a?(Tensor)
173
+
174
+ if index.is_a?(Numeric)
175
+ copy_to(_select_int(0, index), value)
176
+ elsif index.is_a?(Range)
177
+ finish = index.end
178
+ finish += 1 unless index.exclude_end?
179
+ copy_to(_slice_tensor(0, index.begin, finish, 1), value)
180
+ else
181
+ raise Error, "Unsupported index type: #{index.class.name}"
182
+ end
183
+ end
181
184
 
182
185
  private
183
186
 
187
+ def copy_to(dst, src)
188
+ dst.copy!(src)
189
+ end
190
+
184
191
  def reshape_arr(arr, dims)
185
192
  if dims.empty?
186
193
  arr