torch-rb 0.1.3 → 0.1.4

Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
data/lib/torch/nn/functional.rb CHANGED
@@ -2,8 +2,12 @@ module Torch
   module NN
     class Functional
       class << self
-        def relu(input)
-          Torch.relu(input)
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
         end
 
         def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
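
A minimal usage sketch of the new inplace option (the Torch.randn input is illustrative, not part of this diff):

    x = Torch.randn(2, 3)

    Torch::NN::Functional.relu(x)                 # returns a new tensor; x is unchanged
    Torch::NN::Functional.relu(x, inplace: true)  # calls x.relu!, overwriting x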
@@ -11,6 +15,14 @@ module Torch
           Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
         end
 
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          Torch.leaky_relu(input, negative_slope)
+        end
+
         def max_pool2d(input, kernel_size)
           kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
           Torch.max_pool2d(input, kernel_size)
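
A hedged example of the two new activation wrappers (tensor shapes are illustrative):

    x = Torch.randn(4)

    Torch::NN::Functional.leaky_relu(x)              # negative_slope defaults to 0.01
    Torch::NN::Functional.leaky_relu(x, 0.2)         # note: negative_slope is positional, not a keyword
    Torch::NN::Functional.prelu(x, Torch.ones(1))    # prelu takes the learnable slope as a weight tensor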
@@ -21,24 +33,102 @@ module Torch
           Torch.avg_pool2d(input, kernel_size)
         end
 
+        # linear layers
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
         def linear(input, weight, bias)
           Torch.linear(input, weight, bias)
         end
 
+        # sparse layers
+
+        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          padding_idx ||= -1
+          Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # need to handle nils
+          raise NotImplementedYet
+
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch._cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch._pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          raise NotImplementedYet if weight
+          Torch.binary_cross_entropy(input, target, reduction)
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch.kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          Torch.l1_loss(input, target, reduction)
+        end
+
         def mse_loss(input, target, reduction: "mean")
           Torch.mse_loss(input, target, reduction)
         end
 
-        def cross_entropy(input, target)
-          nll_loss(log_softmax(input, 1), target)
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          raise NotImplementedYet if weight
+          Torch.nll_loss(input, target, reduction, ignore_index)
+        end
+
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
         end
 
-        def nll_loss(input, target, reduction: "mean")
-          # TODO fix for non-1d
-          Torch.nll_loss(input, target, reduction)
+        # end loss
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim: dim)
         end
 
-        def log_softmax(input, dim)
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim: dim)
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          Torch._softplus(input, beta, threshold)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
           input.log_softmax(dim)
         end
 
@@ -84,12 +174,10 @@ module Torch
           end
         end
 
-        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
-          # TODO handle max_norm and norm_type
-          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+        private
 
-          padding_idx ||= -1
-          Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         end
       end
     end
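
A rough sketch of the reworked loss and softmax helpers (the Torch.randn/Torch.tensor inputs and dtypes are illustrative):

    f = Torch::NN::Functional

    input  = Torch.randn(3, 5)        # unnormalized scores for 3 samples and 5 classes
    target = Torch.tensor([1, 0, 4])  # integer class indices

    f.cross_entropy(input, target)                # now log_softmax followed by nll_loss with keyword options
    f.nll_loss(f.log_softmax(input, 1), target)   # equivalent spelling
    f.l1_loss(input, Torch.randn(3, 5), reduction: "sum")

    f.softmax(Torch.randn(2, 5))   # dim omitted: softmax_dim picks dim 1 for 2-d input
    f.softmax(Torch.randn(5))      # ...and dim 0 for 0-, 1-, or 3-d input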
data/lib/torch/nn/identity.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class Identity < Module
+      def initialize(*args, **options)
+        super()
+      end
+
+      def forward(input)
+        input
+      end
+    end
+  end
+end
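
Identity simply passes its input through; constructor arguments are accepted and discarded, so it can stand in for a layer being disabled. Illustrative usage:

    layer = Torch::NN::Identity.new(54, unused: "ignored")
    layer.call(Torch.randn(2, 3))   # returns the input unchanged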
data/lib/torch/nn/init.rb CHANGED
@@ -2,7 +2,64 @@ module Torch
   module NN
     module Init
       class << self
-        def calculate_fan_in_and_fan_out(tensor)
+        def calculate_gain(nonlinearity, param: 0.01)
+          _calculate_gain(nonlinearity, param)
+        end
+
+        def uniform!(tensor, a: 0.0, b: 1.0)
+          _uniform!(tensor, a, b)
+        end
+
+        def normal!(tensor, mean: 0.0, std: 1.0)
+          _normal!(tensor, mean, std)
+        end
+
+        def constant!(tensor, val)
+          _constant!(tensor, val)
+        end
+
+        def ones!(tensor)
+          _ones!(tensor)
+        end
+
+        def zeros!(tensor)
+          _zeros!(tensor)
+        end
+
+        def eye!(tensor)
+          _eye!(tensor)
+        end
+
+        def dirac!(tensor)
+          _dirac!(tensor)
+        end
+
+        def xavier_uniform!(tensor, gain: 1.0)
+          _xavier_uniform!(tensor, gain)
+        end
+
+        def xavier_normal!(tensor, gain: 1.0)
+          _xavier_normal!(tensor, gain)
+        end
+
+        def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_uniform!(tensor, a, mode, nonlinearity)
+        end
+
+        def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_normal!(tensor, a, mode, nonlinearity)
+        end
+
+        def orthogonal!(tensor, gain: 1)
+          _orthogonal!(tensor, gain)
+        end
+
+        def sparse!(tensor, sparsity, std: 0.01)
+          _sparse!(tensor, sparsity, std)
+        end
+
+        # TODO move to C++ when released
+        def _calculate_fan_in_and_fan_out(tensor)
           dimensions = tensor.dim
           if dimensions < 2
             raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
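
The public Init helpers now take keyword arguments and delegate to underscore-prefixed native bindings; fan in/out moved to _calculate_fan_in_and_fan_out (used by Linear below). A hedged example (the Torch.randn tensor is illustrative):

    w = Torch.randn(3, 5)

    Torch::NN::Init.uniform!(w, a: -0.1, b: 0.1)
    Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))
    Torch::NN::Init.constant!(w, 0.5)

    fan_in, fan_out = Torch::NN::Init._calculate_fan_in_and_fan_out(w)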
data/lib/torch/nn/kl_div_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class KLDivLoss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.kl_div(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/l1_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class L1Loss < Loss
+      def initialize(reduction: "mean")
+        super(reduction)
+      end
+
+      def forward(input, target)
+        F.l1_loss(input, target, reduction: @reduction)
+      end
+    end
+  end
+end
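
Both new loss modules follow the same pattern: the Loss base class stores the reduction and forward delegates to the functional interface. Illustrative usage (inputs are placeholders):

    loss_fn = Torch::NN::L1Loss.new(reduction: "sum")
    loss_fn.call(Torch.randn(3, 4), Torch.randn(3, 4))   # Module#call dispatches to forward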
data/lib/torch/nn/leaky_relu.rb ADDED
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class LeakyReLU < Module
+      def initialize(negative_slope: 1e-2) #, inplace: false)
+        super()
+        @negative_slope = negative_slope
+        # @inplace = inplace
+      end
+
+      def forward(input)
+        F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+      end
+
+      def extra_inspect
+        inplace_str = @inplace ? ", inplace: true" : ""
+        format("negative_slope: %s%s", @negative_slope, inplace_str)
+      end
+    end
+  end
+end
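
Illustrative usage (the inplace option is still commented out, so only negative_slope is configurable):

    act = Torch::NN::LeakyReLU.new(negative_slope: 0.2)
    act.call(Torch.randn(2, 3))   # forward delegates to F.leaky_relu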
data/lib/torch/nn/linear.rb CHANGED
@@ -1,35 +1,36 @@
 module Torch
   module NN
     class Linear < Module
-      attr_reader :bias, :weight
-
       def initialize(in_features, out_features, bias: true)
+        super()
         @in_features = in_features
         @out_features = out_features
 
         @weight = Parameter.new(Tensor.new(out_features, in_features))
         if bias
           @bias = Parameter.new(Tensor.new(out_features))
+        else
+          register_parameter("bias", nil)
         end
 
         reset_parameters
       end
 
-      def call(input)
-        F.linear(input, @weight, @bias)
-      end
-
      def reset_parameters
-        Init.kaiming_uniform!(@weight, Math.sqrt(5))
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
         if @bias
-          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
           bound = 1 / Math.sqrt(fan_in)
-          Init.uniform!(@bias, -bound, bound)
+          Init.uniform!(@bias, a: -bound, b: bound)
        end
      end
 
-      def inspect
-        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
+      def forward(input)
+        F.linear(input, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
      end
    end
  end
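
Linear now calls super() so the Module constructor sets up its bookkeeping, defines forward instead of call (the base Module#call dispatches to it), and reports itself through extra_inspect. Illustrative usage:

    layer = Torch::NN::Linear.new(10, 2)
    layer.call(Torch.randn(4, 10))   # Module#call invokes Linear#forward
    puts layer.inspect               # => Linear(in_features: 10, out_features: 2, bias: true)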
data/lib/torch/nn/log_softmax.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class LogSoftmax < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.log_softmax(input, @dim)
+      end
+    end
+  end
+end
data/lib/torch/nn/loss.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Loss < Module
+      def initialize(reduction)
+        super()
+        @reduction = reduction
+      end
+    end
+  end
+end
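
Loss only stores the reduction; concrete losses subclass it and implement forward. A sketch of that pattern with a hypothetical subclass (MeanAbsoluteError is not part of this release):

    module Torch
      module NN
        class MeanAbsoluteError < Loss   # hypothetical, mirrors L1Loss
          def initialize(reduction: "mean")
            super(reduction)
          end

          def forward(input, target)
            F.l1_loss(input, target, reduction: @reduction)
          end
        end
      end
    end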
data/lib/torch/nn/max_pool2d.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool2d < MaxPoolNd
+      def forward(input)
+        F.max_pool2d(input, @kernel_size) # TODO other parameters
+      end
+    end
+  end
+end
data/lib/torch/nn/max_poolnd.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size) #, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        # @stride = stride || kernel_size
+        # @padding = padding
+        # @dilation = dilation
+        # @return_indices = return_indices
+        # @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
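
Only kernel_size is wired through for now (stride, padding, dilation, and friends are commented out). Illustrative usage:

    pool = Torch::NN::MaxPool2d.new(2)
    pool.call(Torch.randn(1, 1, 4, 4))   # 2x2 window; stride defaults to the kernel size
    puts pool.inspect                    # => MaxPool2d(kernel_size: 2)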
data/lib/torch/nn/module.rb CHANGED
@@ -3,45 +3,85 @@ module Torch
     class Module
       def initialize
         @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
       end
 
-      def inspect
-        str = String.new
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << " (#{name}): #{mod.inspect}\n"
-        end
-        str << ")"
+      def forward
+        raise NotImplementedError
       end
 
-      def train(mode = true)
-        @training = mode
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+      end
 
-        modules.each do |_, mod|
-          mod.train(mode)
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
         end
+        # TODO apply to more objects
+        self
       end
 
-      def eval
-        train(false)
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
+        end
+        fn.call(self)
+        self
       end
 
-      def call(*input)
-        forward(*input)
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
+      end
+
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
       end
 
       # modifies in-place
       def to(device)
-        instance_variables.each do |name|
-          param = instance_variable_get(name)
-          if param.is_a?(Parameter)
-            instance_variable_set(name, Parameter.new(param.to(device)))
-          end
+        convert = lambda do |t|
+          t.to(device)
         end
-        modules.each do |_, mod|
-          mod.to(device)
-        end
-        self
+
+        _apply(convert)
+      end
+
+      def call(*input)
+        forward(*input)
+      end
+
+      def state_dict
+        raise NotImplementedYet
       end
 
       def parameters
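
A hedged sketch of the new Module plumbing (the Net class and the Torch.zeros buffer are illustrative; register_buffer, call, cpu, and _apply come from this diff):

    class Net < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(8, 4)
        @fc2 = Torch::NN::Linear.new(4, 1)
        register_buffer("step", Torch.zeros(1))   # tracked in @buffers, not returned by #parameters
      end

      def forward(x)
        @fc2.call(Torch::NN::Functional.relu(@fc1.call(x)))
      end
    end

    net = Net.new
    net.call(Torch.randn(2, 8))   # Module#call dispatches to #forward
    net.cpu                       # goes through _apply, which so far only recurses registered children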
@@ -53,6 +93,38 @@ module Torch
         params + modules.flat_map { |_, mod| mod.parameters }
       end
 
+      def children
+        @modules.values
+      end
+
+      def modules
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.merge(modules)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
+        end
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
+      end
+
       def zero_grad
         parameters.each do |param|
           if param.grad
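
train/eval now recurse through registered children and return self, and requires_grad! toggles every parameter. A hedged example:

    layer = Torch::NN::Linear.new(8, 4)

    layer.eval                                   # shorthand for train(false)
    layer.requires_grad!(requires_grad: false)   # freezes every tensor returned by #parameters
    layer.parameters.map(&:requires_grad)        # => [false, false] (weight and bias)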
@@ -62,6 +134,24 @@ module Torch
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if modules.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          modules.each do |name, mod|
+            str << " (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
         modules[method.to_s] || super
       end
@@ -72,13 +162,12 @@ module Torch
 
       private
 
-      def modules
-        modules = {}
-        instance_variables.each do |name|
-          mod = instance_variable_get(name)
-          modules[name[1..-1]] = mod if mod.is_a?(Module)
-        end
-        modules
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars)
+        str % vars.map(&:inspect)
       end
     end
   end
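
With extra_inspect and the inspecting format helper in place, leaf modules print on one line and containers nest. A sketch (TinyNet and the exact spacing of the nested output are illustrative):

    class TinyNet < Torch::NN::Module
      def initialize
        super()
        @fc = Torch::NN::Linear.new(3, 2)
        @act = Torch::NN::Sigmoid.new
      end

      def forward(x)
        @act.call(@fc.call(x))
      end
    end

    puts TinyNet.new.inspect
    # TinyNet(
    #  (fc): Linear(in_features: 3, out_features: 2, bias: true)
    #  (act): Sigmoid()
    # )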