torch-rb 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
@@ -2,8 +2,12 @@ module Torch
2
2
  module NN
3
3
  class Functional
4
4
  class << self
5
- def relu(input)
6
- Torch.relu(input)
5
# Rectified linear unit.
#
# @param input object responding to #relu and #relu!
# @param inplace [Boolean] truthy => call input.relu! (mutating variant),
#   otherwise call input.relu
# @return the result of whichever method was called
def relu(input, inplace: false)
  inplace ? input.relu! : input.relu
end
8
12
 
9
13
  def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
@@ -11,6 +15,14 @@ module Torch
11
15
  Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
12
16
  end
13
17
 
18
# Parametric ReLU: +weight+ supplies the learnable slope(s) used for
# negative inputs. Thin wrapper over the native Torch.prelu.
def prelu(input, weight)
  Torch.prelu(input, weight)
end

# Leaky ReLU with a fixed, positional +negative_slope+ (0.01 by default).
# Thin wrapper over the native Torch.leaky_relu; no in-place variant yet.
def leaky_relu(input, negative_slope = 0.01)
  Torch.leaky_relu(input, negative_slope)
end
25
+
14
26
  def max_pool2d(input, kernel_size)
15
27
  kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
16
28
  Torch.max_pool2d(input, kernel_size)
@@ -21,24 +33,102 @@ module Torch
21
33
  Torch.avg_pool2d(input, kernel_size)
22
34
  end
23
35
 
36
+ # linear layers
37
+
38
# Bilinear transform of two inputs; all four arguments are handed unchanged
# to the native Torch.bilinear.
def bilinear(input1, input2, weight, bias)
  Torch.bilinear(input1, input2, weight, bias)
end

# Affine transform of +input+ by +weight+ and (possibly nil) +bias+;
# delegates to the native Torch.linear.
def linear(input, weight, bias)
  Torch.linear(input, weight, bias)
end
27
45
 
46
+ # sparse layers
47
+
48
# Looks up rows of +weight+ for the indices in +input+.
#
# @param padding_idx [Integer, nil] a falsy value is sent as -1, presumably
#   the native "no padding index" sentinel (TODO confirm)
# @raise [NotImplementedYet] when max_norm is given or norm_type differs from 2.0
def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
  # TODO handle max_norm and norm_type
  unsupported = !max_norm.nil? || norm_type != 2.0
  raise NotImplementedYet if unsupported

  Torch._embedding(input, weight, padding_idx || -1, scale_grad_by_freq, sparse)
end
55
+
56
# Reduces bags of embeddings (sum / mean / max per +mode+). Not supported yet.
#
# The native call cannot receive nil for +offsets+ / +per_sample_weights+,
# so this always raises. The intended implementation is kept as a comment
# instead of as unreachable statements after the raise.
#
# @raise [NotImplementedYet] always
def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
  # TODO: once nil offsets / per_sample_weights are handled:
  #   raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0 # TODO handle max_norm and norm_type
  #   Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights)
  raise NotImplementedYet
end
65
+
66
+ # distance functions
67
+
68
# Cosine similarity of +x1+ and +x2+ along +dim+. +eps+ is forwarded as-is
# (in PyTorch it keeps the denominator away from zero).
def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
  Torch._cosine_similarity(x1, x2, dim, eps)
end

# p-norm distance between +x1+ and +x2+; +eps+ and +keepdim+ are forwarded
# positionally to the native Torch._pairwise_distance.
def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
  Torch._pairwise_distance(x1, x2, p, eps, keepdim)
end
75
+
76
+ # loss functions
77
+
78
# Binary cross entropy. A +weight+ tensor is not supported yet.
#
# @raise [NotImplementedYet] when weight is given
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
  raise NotImplementedYet if weight

  Torch.binary_cross_entropy(input, target, reduction)
end

# Cross entropy expressed as nll_loss over log_softmax taken along dim 1.
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
  log_probs = log_softmax(input, 1)
  nll_loss(log_probs, target, weight: weight, ignore_index: ignore_index, reduction: reduction)
end

# Connectionist temporal classification loss. The two length collections
# are converted with #to_a because the native binding wants plain Arrays.
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
  input_lens = input_lengths.to_a
  target_lens = target_lengths.to_a
  Torch.ctc_loss(log_probs, targets, input_lens, target_lens, blank, reduction, zero_infinity)
end

# Kullback-Leibler divergence; thin wrapper over Torch.kl_div.
def kl_div(input, target, reduction: "mean")
  Torch.kl_div(input, target, reduction)
end

# Mean absolute error; thin wrapper over Torch.l1_loss.
def l1_loss(input, target, reduction: "mean")
  Torch.l1_loss(input, target, reduction)
end

# Mean squared error; thin wrapper over Torch.mse_loss.
def mse_loss(input, target, reduction: "mean")
  Torch.mse_loss(input, target, reduction)
end
31
103
 
32
- def cross_entropy(input, target)
33
- nll_loss(log_softmax(input, 1), target)
104
# Negative log-likelihood loss on log-probabilities. Class weights are
# not wired through yet.
#
# @raise [NotImplementedYet] when weight is given
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
  raise NotImplementedYet if weight

  # positional order expected by the binding: reduction, then ignore_index
  Torch.nll_loss(input, target, reduction, ignore_index)
end

# Poisson negative log-likelihood; every option is forwarded positionally.
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
  Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
end
35
112
 
36
- def nll_loss(input, target, reduction: "mean")
37
- # TODO fix for non-1d
38
- Torch.nll_loss(input, target, reduction)
113
+ # end loss
114
+
115
# Softmax along +dim+; when +dim+ is falsy the axis comes from
# softmax_dim(input.dim).
def softmax(input, dim: nil)
  input.softmax(dim: dim || softmax_dim(input.dim))
end
40
119
 
41
- def log_softmax(input, dim)
120
# Softmin: softmax applied to the negated input. A falsy +dim+ falls back
# to softmax_dim(input.dim).
def softmin(input, dim: nil)
  axis = dim || softmax_dim(input.dim)
  negated = -input
  negated.softmax(dim: axis)
end

# Smooth approximation of ReLU; +beta+ and +threshold+ go straight to the
# native Torch._softplus.
def softplus(input, beta: 1, threshold: 20)
  Torch._softplus(input, beta, threshold)
end

# Log of softmax along +dim+ (positional for now); a falsy +dim+ falls back
# to softmax_dim(input.dim).
# TODO make dim keyword argument and update examples
def log_softmax(input, dim = nil)
  input.log_softmax(dim || softmax_dim(input.dim))
end
44
134
 
@@ -84,12 +174,10 @@ module Torch
84
174
  end
85
175
  end
86
176
 
87
- def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
88
- # TODO handle max_norm and norm_type
89
- raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
177
+ private
90
178
 
91
- padding_idx ||= -1
92
- Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
179
# Default softmax axis for a tensor of rank +ndim+: 0 for ranks 0, 1 and 3,
# otherwise 1.
def softmax_dim(ndim)
  case ndim
  when 0, 1, 3 then 0
  else 1
  end
end
94
182
  end
95
183
  end
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Pass-through layer: forward returns its input untouched. Constructor
    # arguments are accepted (so it can stand in for any layer) and ignored.
    class Identity < Module
      def initialize(*_args, **_options)
        super()
      end

      # @return +input+ itself (same object)
      def forward(input)
        input
      end
    end
  end
end
data/lib/torch/nn/init.rb CHANGED
@@ -2,7 +2,64 @@ module Torch
2
2
  module NN
3
3
  module Init
4
4
  class << self
5
- def calculate_fan_in_and_fan_out(tensor)
5
# Enum-like arguments (+nonlinearity+, +mode+) default to Strings, so the
# native _-prefixed bindings presumably expect String. They are coerced with
# #to_s so idiomatic Symbols (:relu, :fan_out) work too; Strings are
# unaffected.

# Recommended gain for +nonlinearity+ (String or Symbol). +param+ is the
# extra argument some nonlinearities take (in PyTorch, leaky_relu's slope).
def calculate_gain(nonlinearity, param: 0.01)
  _calculate_gain(nonlinearity.to_s, param)
end

# Fills +tensor+ in place with values drawn from U(a, b).
def uniform!(tensor, a: 0.0, b: 1.0)
  _uniform!(tensor, a, b)
end

# Fills +tensor+ in place with values drawn from N(mean, std^2).
def normal!(tensor, mean: 0.0, std: 1.0)
  _normal!(tensor, mean, std)
end

# Fills +tensor+ in place with +val+.
def constant!(tensor, val)
  _constant!(tensor, val)
end

# Fills +tensor+ in place with ones.
def ones!(tensor)
  _ones!(tensor)
end

# Fills +tensor+ in place with zeros.
def zeros!(tensor)
  _zeros!(tensor)
end

# Fills +tensor+ in place with the identity matrix.
def eye!(tensor)
  _eye!(tensor)
end

# Fills +tensor+ in place with the Dirac delta pattern.
def dirac!(tensor)
  _dirac!(tensor)
end

# Xavier/Glorot uniform initialization scaled by +gain+.
def xavier_uniform!(tensor, gain: 1.0)
  _xavier_uniform!(tensor, gain)
end

# Xavier/Glorot normal initialization scaled by +gain+.
def xavier_normal!(tensor, gain: 1.0)
  _xavier_normal!(tensor, gain)
end

# Kaiming/He uniform initialization. +mode+ and +nonlinearity+ accept a
# String or a Symbol.
def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
  _kaiming_uniform!(tensor, a, mode.to_s, nonlinearity.to_s)
end

# Kaiming/He normal initialization. +mode+ and +nonlinearity+ accept a
# String or a Symbol.
def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
  _kaiming_normal!(tensor, a, mode.to_s, nonlinearity.to_s)
end

# (Semi-)orthogonal initialization scaled by +gain+.
def orthogonal!(tensor, gain: 1)
  _orthogonal!(tensor, gain)
end

# Sparse initialization: +sparsity+ fraction of zeros, the rest drawn
# from N(0, std^2).
def sparse!(tensor, sparsity, std: 0.01)
  _sparse!(tensor, sparsity, std)
end
60
+
61
+ # TODO move to C++ when released
62
+ def _calculate_fan_in_and_fan_out(tensor)
6
63
  dimensions = tensor.dim
7
64
  if dimensions < 2
8
65
  raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Module wrapper around F.kl_div; the reduction mode is stored by Loss.
    class KLDivLoss < Loss
      # @param reduction [String] reduction mode passed to the Loss base
      def initialize(reduction: "mean")
        super(reduction)
      end

      # @return F.kl_div(input, target) using the stored reduction
      def forward(input, target)
        F.kl_div(input, target, reduction: @reduction)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Module wrapper around F.l1_loss; the reduction mode is stored by Loss.
    class L1Loss < Loss
      # @param reduction [String] reduction mode passed to the Loss base
      def initialize(reduction: "mean")
        super(reduction)
      end

      # @return F.l1_loss(input, target) using the stored reduction
      def forward(input, target)
        F.l1_loss(input, target, reduction: @reduction)
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
module Torch
  module NN
    # Leaky ReLU layer: forward applies F.leaky_relu with the configured slope.
    class LeakyReLU < Module
      # @param negative_slope [Numeric] slope applied to negative inputs
      # TODO: add inplace: once F.leaky_relu supports it
      def initialize(negative_slope: 1e-2)
        super()
        @negative_slope = negative_slope
      end

      def forward(input)
        F.leaky_relu(input, @negative_slope)
      end

      # Reports only the slope. In-place mode is not implemented, so there is
      # no @inplace state to read (it was never assigned).
      def extra_inspect
        format("negative_slope: %s", @negative_slope)
      end
    end
  end
end
@@ -1,35 +1,36 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Linear < Module
4
- attr_reader :bias, :weight
5
-
6
4
  def initialize(in_features, out_features, bias: true)
5
+ super()
7
6
  @in_features = in_features
8
7
  @out_features = out_features
9
8
 
10
9
  @weight = Parameter.new(Tensor.new(out_features, in_features))
11
10
  if bias
12
11
  @bias = Parameter.new(Tensor.new(out_features))
12
+ else
13
+ register_parameter("bias", nil)
13
14
  end
14
15
 
15
16
  reset_parameters
16
17
  end
17
18
 
18
- def call(input)
19
- F.linear(input, @weight, @bias)
20
- end
21
-
22
19
  def reset_parameters
23
- Init.kaiming_uniform!(@weight, Math.sqrt(5))
20
+ Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
24
21
  if @bias
25
- fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
22
+ fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
26
23
  bound = 1 / Math.sqrt(fan_in)
27
- Init.uniform!(@bias, -bound, bound)
24
+ Init.uniform!(@bias, a: -bound, b: bound)
28
25
  end
29
26
  end
30
27
 
31
- def inspect
32
- "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
28
+ def forward(input)
29
+ F.linear(input, @weight, @bias)
30
+ end
31
+
32
+ def extra_inspect
33
+ format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
33
34
  end
34
35
  end
35
36
  end
@@ -0,0 +1,14 @@
1
module Torch
  module NN
    # Layer form of F.log_softmax. A nil +dim+ is passed through, letting
    # F.log_softmax choose the axis from the input's rank.
    class LogSoftmax < Module
      def initialize(dim: nil)
        super()
        @dim = dim
      end

      def forward(input)
        F.log_softmax(input, @dim)
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module NN
    # Base class for loss layers: remembers the +reduction+ mode that
    # subclasses hand to their F.* loss function.
    class Loss < Module
      # @param reduction [String] e.g. "mean"; stored in @reduction
      def initialize(reduction)
        super()
        @reduction = reduction
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # 2-D max pooling using the kernel_size stored by MaxPoolNd.
    class MaxPool2d < MaxPoolNd
      # TODO other parameters (stride, padding, dilation, ceil_mode)
      def forward(input)
        F.max_pool2d(input, @kernel_size)
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
module Torch
  module NN
    # Shared base for max-pooling layers. Only kernel_size is kept for now.
    # TODO: stride (defaulting to kernel_size), padding, dilation,
    # return_indices and ceil_mode.
    class MaxPoolNd < Module
      # @param kernel_size [Integer, Array] pooling window
      def initialize(kernel_size)
        super()
        @kernel_size = kernel_size
      end

      def extra_inspect
        format("kernel_size: %s", @kernel_size)
      end
    end
  end
end
@@ -3,45 +3,85 @@ module Torch
3
3
  class Module
4
4
  def initialize
5
5
  @training = true
6
+ @parameters = {}
7
+ @buffers = {}
8
+ @modules = {}
6
9
  end
7
10
 
8
- def inspect
9
- str = String.new
10
- str << "#{self.class.name}(\n"
11
- modules.each do |name, mod|
12
- str << " (#{name}): #{mod.inspect}\n"
13
- end
14
- str << ")"
11
# Computes the layer's output; every concrete module must override this.
#
# Accepts any arguments because Module#call(*input) forwards its arguments
# here. With no parameters, a subclass that forgot to define forward failed
# with a confusing ArgumentError instead of NotImplementedError.
#
# @raise [NotImplementedError] always, in this base implementation
def forward(*_input)
  raise NotImplementedError, "#{self.class.name} must implement forward"
end
16
14
 
17
- def train(mode = true)
18
- @training = mode
15
# Records +tensor+ in the buffer registry under +name+ (no validation yet).
# @return the stored tensor
def register_buffer(name, tensor)
  # TODO add checks
  @buffers.store(name, tensor)
end
19
19
 
20
- modules.each do |_, mod|
21
- mod.train(mode)
20
# Records +param+ (may be nil, e.g. Linear without bias) under +name+.
# @return the stored value
def register_parameter(name, param)
  # TODO add checks
  @parameters.store(name, param)
end

# Registers +mod+ as a named child module.
# @return the stored module
def add_module(name, mod)
  # TODO add checks
  @modules.store(name, mod)
end
29
+
30
# Converts this module tree's tensors with +fn+ (callable: tensor -> tensor)
# and returns self. Backs #to, #cuda, #cpu, #type, #float, #double, #half
# and #share_memory.
#
# Children are handled recursively first. Then every instance variable
# holding a Parameter is replaced by Parameter.new(fn.call(param)), which is
# the rebuild 0.1.3's #to performed. Every non-nil registered buffer is
# replaced by fn.call(buffer). Previously only the recursion happened, so
# those conversions silently left every tensor untouched.
#
# TODO: preserve requires_grad, convert .grad, and cover @parameters entries.
def _apply(fn)
  children.each { |mod| mod._apply(fn) }

  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    instance_variable_set(ivar, Parameter.new(fn.call(value))) if value.is_a?(Parameter)
  end

  # assigning to existing keys while iterating is safe; nil buffers are kept
  @buffers&.transform_values! { |buf| buf.nil? ? buf : fn.call(buf) }

  self
end
24
37
 
25
- def eval
26
- train(false)
38
# Post-order traversal: calls +fn+ on every descendant first, then on self.
# @return self
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end
28
45
 
29
- def call(*input)
30
- forward(*input)
46
# Moves tensors with Tensor#cuda; +device+ is passed through unchanged.
def cuda(device: nil)
  to_cuda = lambda { |t| t.cuda(device) }
  _apply(to_cuda)
end

# Moves tensors with Tensor#cpu.
def cpu
  _apply(lambda { |t| t.cpu })
end

# Converts tensors with Tensor#type(dst_type).
def type(dst_type)
  _apply(lambda { |t| t.type(dst_type) })
end

# Casts floating-point tensors with Tensor#float; other tensors pass through.
def float
  _apply(lambda { |t| t.floating_point? ? t.float : t })
end

# Casts floating-point tensors with Tensor#double; other tensors pass through.
def double
  _apply(lambda { |t| t.floating_point? ? t.double : t })
end

# Casts floating-point tensors with Tensor#half; other tensors pass through.
def half
  _apply(lambda { |t| t.floating_point? ? t.half : t })
end
32
69
 
33
70
  # modifies in-place
34
71
  def to(device)
35
- instance_variables.each do |name|
36
- param = instance_variable_get(name)
37
- if param.is_a?(Parameter)
38
- instance_variable_set(name, Parameter.new(param.to(device)))
39
- end
72
+ convert = lambda do |t|
73
+ t.to(device)
40
74
  end
41
- modules.each do |_, mod|
42
- mod.to(device)
43
- end
44
- self
75
+
76
+ _apply(convert)
77
+ end
78
+
79
# Invokes the forward pass, passing all arguments through to #forward.
def call(*input)
  forward(*input)
end

# Serialization of parameters/buffers is not available yet.
# @raise [NotImplementedYet] always
def state_dict
  raise NotImplementedYet
end
46
86
 
47
87
  def parameters
@@ -53,6 +93,38 @@ module Torch
53
93
  params + modules.flat_map { |_, mod| mod.parameters }
54
94
  end
55
95
 
96
# Immediate child modules: those registered with add_module plus every
# Module stored in an instance variable (the full #modules map).
#
# train/eval, apply and _apply recurse through this method. Returning only
# @modules.values skipped layers assigned as ivars (e.g. @fc = Linear.new),
# so eval never reached them and conversions never recursed into them.
def children
  modules.values
end
99
+
100
# Name => module map: the add_module registry merged with every instance
# variable holding a Module (keyed by the ivar name without "@"). On a name
# clash the instance variable wins.
def modules
  from_ivars = instance_variables.each_with_object({}) do |ivar, found|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end
  @modules.merge(from_ivars)
end
108
+
109
# Sets the training flag here and on every child, recursively.
# @return self
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end

# Equivalent to train(false).
def eval
  train(false)
end

# Sets requires_grad on every parameter returned by #parameters.
# @return self
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end
127
+
56
128
  def zero_grad
57
129
  parameters.each do |param|
58
130
  if param.grad
@@ -62,6 +134,24 @@ module Torch
62
134
  end
63
135
  end
64
136
 
137
# Calls Tensor#share_memory! on every tensor via _apply.
def share_memory
  _apply(:share_memory!.to_proc)
end
140
+
141
+ def inspect
142
+ name = self.class.name.split("::").last
143
+ if modules.empty?
144
+ "#{name}(#{extra_inspect})"
145
+ else
146
+ str = String.new
147
+ str << "#{name}(\n"
148
+ modules.each do |name, mod|
149
+ str << " (#{name}): #{mod.inspect}\n"
150
+ end
151
+ str << ")"
152
+ end
153
+ end
154
+
65
155
  def method_missing(method, *args, &block)
66
156
  modules[method.to_s] || super
67
157
  end
@@ -72,13 +162,12 @@ module Torch
72
162
 
73
163
  private
74
164
 
75
- def modules
76
- modules = {}
77
- instance_variables.each do |name|
78
- mod = instance_variable_get(name)
79
- modules[name[1..-1]] = mod if mod.is_a?(Module)
80
- end
81
- modules
165
# Subclass hook: text shown inside "Name(...)" by #inspect. The base module
# contributes nothing.
def extra_inspect
  nil
end

# String#% formatting where every argument is #inspect-ed first, so strings
# appear quoted and nil appears as "nil".
def format(str, *vars)
  str % vars.map { |var| var.inspect }
end
83
172
  end
84
173
  end