torch-rb 0.1.3 → 0.1.8

This diff shows the changes between publicly released versions of the package, as they appear in its public registry. It is generated from the published packages and is provided for informational purposes only.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/sequential.rb
@@ -2,28 +2,19 @@ module Torch
   module NN
     class Sequential < Module
       def initialize(*args)
-        @modules = {}
+        super()
         # TODO support hash arg (named modules)
         args.each_with_index do |mod, idx|
           add_module(idx.to_s, mod)
         end
       end
 
-      def add_module(name, mod)
-        # TODO add checks
-        @modules[name] = mod
-      end
-
       def forward(input)
         @modules.values.each do |mod|
           input = mod.call(input)
         end
         input
       end
-
-      def parameters
-        @modules.flat_map { |_, mod| mod.parameters }
-      end
     end
   end
 end
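With add_module and parameters now inherited from the shared Module base class, Sequential is reduced to ordering logic. A minimal usage sketch, assuming the Linear and ReLU modules and Torch.randn from this release (layer sizes are illustrative):

    require "torch"

    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::ReLU.new,
      Torch::NN::Linear.new(8, 1)
    )
    output = model.call(Torch.randn(2, 4))  # children applied in index order
    model.parameters.size                   # => 4 (weight + bias per Linear), via Module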
data/lib/torch/nn/sigmoid.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Sigmoid < Module
+      def forward(input)
+        Torch.sigmoid(input)
+      end
+    end
+  end
+end
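Each new activation class is a thin wrapper over the corresponding function, so the module form and the functional form are interchangeable. For example (input values are illustrative):

    x = Torch.tensor([-1.0, 0.0, 1.0])
    Torch::NN::Sigmoid.new.call(x)  # same values as Torch.sigmoid(x)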
data/lib/torch/nn/smooth_l1_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class SmoothL1Loss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.smooth_l1_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
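The new loss classes all follow this shape: store the reduction, then delegate to the functional API in forward. A hedged sketch of typical use, with illustrative shapes:

    loss_fn = Torch::NN::SmoothL1Loss.new   # reduction: "mean" by default
    input   = Torch.randn(3, 5)
    target  = Torch.randn(3, 5)
    loss    = loss_fn.call(input, target)   # 0-dim tensor under "mean"
    loss.item                               # Ruby float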
data/lib/torch/nn/soft_margin_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class SoftMarginLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.soft_margin_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/softmax.rb
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class Softmax < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.softmax(input, dim: @dim)
+      end
+
+      def extra_inspect
+        format("dim: %s", @dim)
+      end
+    end
+  end
+end
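As in PyTorch, dim: defaults to nil, so you normally pass it explicitly. A short sketch for a batch of rows (shapes are illustrative):

    softmax = Torch::NN::Softmax.new(dim: 1)
    probs = softmax.call(Torch.randn(2, 3))  # each row sums to 1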
data/lib/torch/nn/softmax2d.rb
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Softmax2d < Module
+      def forward(input)
+        raise ArgumentError, "Softmax2d requires a 4D tensor as input" unless input.dim == 4
+        F.softmax(input, dim: 1)
+      end
+    end
+  end
+end
data/lib/torch/nn/softmin.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class Softmin < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.softmin(input, dim: @dim)
+      end
+    end
+  end
+end
data/lib/torch/nn/softplus.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class Softplus < Module
+      def initialize(beta: 1, threshold: 20)
+        super()
+        @beta = beta
+        @threshold = threshold
+      end
+
+      def forward(input)
+        F.softplus(input, beta: @beta, threshold: @threshold)
+      end
+
+      def extra_inspect
+        format("beta: %s, threshold: %s", @beta, @threshold)
+      end
+    end
+  end
+end
data/lib/torch/nn/softshrink.rb
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class Softshrink < Module
+      def initialize(lambd: 0.5)
+        super()
+        @lambd = lambd
+      end
+
+      def forward(input)
+        F.softshrink(input, @lambd)
+      end
+
+      def extra_inspect
+        @lambd.to_s
+      end
+    end
+  end
+end
data/lib/torch/nn/softsign.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Softsign < Module
+      def forward(input)
+        F.softsign(input)
+      end
+    end
+  end
+end
data/lib/torch/nn/tanh.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Tanh < Module
+      def forward(input)
+        Torch.tanh(input)
+      end
+    end
+  end
+end
data/lib/torch/nn/tanhshrink.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Tanhshrink < Module
+      def forward(input)
+        F.tanhshrink(input)
+      end
+    end
+  end
+end
data/lib/torch/nn/triplet_margin_loss.rb
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class TripletMarginLoss < Loss
+      def initialize(margin: 1.0, p: 2.0, eps: 1e-6, swap: false, reduction: "mean")
+        super(reduction)
+        @margin = margin
+        @p = p
+        @eps = eps
+        @swap = swap
+      end
+
+      def forward(anchor, positive, negative)
+        F.triplet_margin_loss(anchor, positive, negative, margin: @margin, p: @p,
+          eps: @eps, swap: @swap, reduction: @reduction)
+      end
+    end
+  end
+end
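A sketch of use, assuming embedding batches of illustrative shape [batch, dim]:

    loss_fn  = Torch::NN::TripletMarginLoss.new(margin: 1.0)
    anchor   = Torch.randn(8, 16)
    positive = Torch.randn(8, 16)
    negative = Torch.randn(8, 16)
    loss = loss_fn.call(anchor, positive, negative)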
data/lib/torch/nn/unfold.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class Unfold < Module
+      def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
data/lib/torch/nn/utils.rb
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    module Utils
+      def _single(value)
+        _ntuple(1, value)
+      end
+
+      def _pair(value)
+        _ntuple(2, value)
+      end
+
+      def _triple(value)
+        _ntuple(3, value)
+      end
+
+      def _quadrupal(value)
+        _ntuple(4, value)
+      end
+
+      def _ntuple(n, value)
+        value.is_a?(Array) ? value : [value] * n
+      end
+    end
+  end
+end
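_ntuple mirrors PyTorch's internal _ntuple helper: a scalar is broadcast to an n-element array, while an array passes through as-is (this version never validates the array's length). Illustrative semantics, mixing the module in directly for demonstration only:

    include Torch::NN::Utils
    _pair(3)      # => [3, 3]
    _pair([3, 5]) # => [3, 5]  (returned unchanged)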
data/lib/torch/nn/weighted_loss.rb
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class WeightedLoss < Loss
+      def initialize(weight, reduction)
+        super(reduction)
+        register_buffer("weight", weight)
+      end
+    end
+  end
+end
data/lib/torch/nn/zero_pad2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class ZeroPad2d < ConstantPad2d
+      def initialize(padding)
+        super(padding, 0.0)
+      end
+    end
+  end
+end
data/lib/torch/random.rb
@@ -0,0 +1,10 @@
+module Torch
+  module Random
+    class << self
+      # not available through LibTorch
+      def initial_seed
+        raise NotImplementedYet
+      end
+    end
+  end
+end
data/lib/torch/tensor.rb
@@ -5,12 +5,8 @@ module Torch
 
     alias_method :requires_grad?, :requires_grad
 
-    def self.new(*size)
-      if size.length == 1 && size.first.is_a?(Tensor)
-        size.first
-      else
-        Torch.empty(*size)
-      end
+    def self.new(*args)
+      FloatTensor.new(*args)
     end
 
     def dtype
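Tensor.new now always allocates, delegating to FloatTensor to match Python's torch.Tensor alias; the old pass-through of an existing tensor is gone. A small sketch (the :float32 dtype symbol is assumed from torch-rb's conventions):

    t = Torch::Tensor.new(2, 3)  # uninitialized 2x3 float tensor
    t.dtype                      # => :float32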
data/lib/torch/tensor.rb
@@ -28,7 +24,7 @@ module Torch
     end
 
     def to_a
-      reshape_arr(_data, shape)
+      reshape_arr(_flat_data, shape)
     end
 
     # TODO support dtype
data/lib/torch/tensor.rb
@@ -39,7 +35,7 @@ module Torch
 
     def size(dim = nil)
       if dim
-        _size(dim)
+        _size_int(dim)
       else
         shape
       end
data/lib/torch/tensor.rb
@@ -49,15 +45,16 @@ module Torch
       dim.times.map { |i| size(i) }
     end
 
-    def view(*size)
-      _view(size)
+    # mirror Python len()
+    def length
+      size(0)
     end
 
     def item
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-      _data.first
+      _flat_data.first
     end
 
     # unsure if this is correct
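length mirrors Python's len(), returning the size of the first dimension:

    t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    t.length     # => 2 (size of dim 0)
    t[0, 0].item # => 1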
data/lib/torch/tensor.rb
@@ -66,19 +63,14 @@ module Torch
     end
 
     def backward(gradient = nil)
-      if gradient
-        _backward_gradient(gradient)
-      else
-        _backward
-      end
+      _backward(gradient)
     end
 
     # TODO read directly from memory
    def numo
-      raise Error, "Numo not found" unless defined?(Numo::NArray)
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.cast(_data).reshape(*shape)
+      cls.cast(_flat_data).reshape(*shape)
     end
 
     def new_ones(*size, **options)
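backward now forwards the optional gradient in a single native call, and numo no longer guards on defined?(Numo::NArray), so it assumes the numo-narray gem is loaded. A sketch, assuming numo-narray is installed:

    require "numo/narray"
    t = Torch.ones(2, 2)
    t.numo  # => Numo::SFloat of shape [2, 2]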
data/lib/torch/tensor.rb
@@ -95,31 +87,23 @@ module Torch
       _type(enum)
     end
 
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add_scalar!(other * value)
-      else
-        # need to use alpha for sparse tensors instead of multiplying
-        _add_alpha!(other, value)
-      end
+    def reshape(*size)
+      # Python doesn't check if size == 1, just ignores later arguments
+      size = size.first if size.size == 1 && size.first.is_a?(Array)
+      _reshape(size)
     end
 
-    def mul!(other)
-      if other.is_a?(Numeric)
-        _mul_scalar!(other)
-      else
-        _mul!(other)
-      end
+    def view(*size)
+      size = size.first if size.size == 1 && size.first.is_a?(Array)
+      _view(size)
     end
 
-    # operations
-    %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
-      define_method(op) do |*args, **options, &block|
-        if options.any?
-          Torch.send(op, self, *args, **options, &block)
-        else
-          Torch.send(op, self, *args, &block)
-        end
+    # value and other are swapped for some methods
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add__scalar(other, value)
+      else
+        _add__tensor(other, value)
       end
     end
 
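Both reshape and view now accept either a splat or a single array, matching PyTorch, and add! takes its optional multiplier first:

    t = Torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    t.view(2, 3)              # splat form
    t.view([2, 3])            # array form, same result
    a = Torch.ones(3)
    a.add!(Torch.ones(3))     # in-place a += ones
    a.add!(2, Torch.ones(3))  # in-place a += 2 * ones (alpha form)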
data/lib/torch/tensor.rb
@@ -161,14 +145,20 @@ module Torch
       dim = 0
       indexes.each do |index|
         if index.is_a?(Numeric)
-          result = result._select(dim, index)
+          result = result._select_int(dim, index)
         elsif index.is_a?(Range)
           finish = index.end
           finish += 1 unless index.exclude_end?
-          result = result._slice(dim, index.begin, finish, 1)
+          result = result._slice_tensor(dim, index.begin, finish, 1)
+          dim += 1
+        elsif index.nil?
+          result = result.unsqueeze(dim)
           dim += 1
+        elsif index == true
+          result = result.unsqueeze(dim)
+          # TODO handle false
         else
-          raise Error, "Unsupported index type"
+          raise Error, "Unsupported index type: #{index.class.name}"
        end
       end
       result
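Indexing now advances the dimension after a range and supports nil and true, both of which insert a new axis, like None in Python indexing:

    t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    t[0]          # first row
    t[0, 1..2]    # elements 1..2 of the first row
    t[nil].shape  # => [1, 2, 3] (unsqueezed)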
data/lib/torch/tensor.rb
@@ -176,11 +166,28 @@ module Torch
 
     # TODO
     # based on python_variable_indexing.cpp
-    # def []=(index, value)
-    # end
+    def []=(index, value)
+      raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+
+      value = Torch.tensor(value) unless value.is_a?(Tensor)
+
+      if index.is_a?(Numeric)
+        copy_to(_select_int(0, index), value)
+      elsif index.is_a?(Range)
+        finish = index.end
+        finish += 1 unless index.exclude_end?
+        copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+      else
+        raise Error, "Unsupported index type: #{index.class.name}"
+      end
+    end
 
     private
 
+    def copy_to(dst, src)
+      dst.copy!(src)
+    end
+
     def reshape_arr(arr, dims)
       if dims.empty?
         arr
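[]= writes through a view: the right-hand side is promoted with Torch.tensor if needed, then copied into the selected slice via copy!. A short sketch:

    t = Torch.zeros(3)
    t[0] = 1                            # scalar promoted, then copied in
    t[1..2] = Torch.tensor([2.0, 3.0])
    t.to_a                              # => [1.0, 2.0, 3.0]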