torch-rb 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
data/lib/torch.rb CHANGED
@@ -22,31 +22,90 @@ require "torch/optim/sgd"
22
22
  require "torch/optim/lr_scheduler/lr_scheduler"
23
23
  require "torch/optim/lr_scheduler/step_lr"
24
24
 
25
- # nn base classes
25
+ # nn parameters
26
+ require "torch/nn/parameter"
27
+
28
+ # nn containers
26
29
  require "torch/nn/module"
30
+ require "torch/nn/sequential"
31
+
32
+ # nn convolution layers
27
33
  require "torch/nn/convnd"
28
- require "torch/nn/dropoutnd"
34
+ require "torch/nn/conv2d"
35
+
36
+ # nn pooling layers
37
+ require "torch/nn/max_poolnd"
38
+ require "torch/nn/max_pool2d"
39
+ require "torch/nn/avg_poolnd"
40
+ require "torch/nn/avg_pool2d"
41
+
42
+ # nn linear layers
43
+ require "torch/nn/bilinear"
44
+ require "torch/nn/identity"
45
+ require "torch/nn/linear"
29
46
 
30
- # nn
47
+ # nn dropout layers
48
+ require "torch/nn/dropoutnd"
31
49
  require "torch/nn/alpha_dropout"
32
- require "torch/nn/conv2d"
33
50
  require "torch/nn/dropout"
34
51
  require "torch/nn/dropout2d"
35
52
  require "torch/nn/dropout3d"
36
- require "torch/nn/embedding"
37
53
  require "torch/nn/feature_alpha_dropout"
54
+
55
+ # nn activations
56
+ require "torch/nn/leaky_relu"
57
+ require "torch/nn/prelu"
58
+ require "torch/nn/relu"
59
+ require "torch/nn/sigmoid"
60
+ require "torch/nn/softplus"
61
+
62
+ # nn activations other
63
+ require "torch/nn/log_softmax"
64
+ require "torch/nn/softmax"
65
+ require "torch/nn/softmax2d"
66
+ require "torch/nn/softmin"
67
+
68
+ # nn sparse layers
69
+ require "torch/nn/embedding"
70
+ require "torch/nn/embedding_bag"
71
+
72
+ # nn distance functions
73
+ require "torch/nn/cosine_similarity"
74
+ require "torch/nn/pairwise_distance"
75
+
76
+ # nn loss functions
77
+ require "torch/nn/loss"
78
+ require "torch/nn/weighted_loss"
79
+ require "torch/nn/bce_loss"
80
+ # require "torch/nn/bce_with_logits_loss"
81
+ # require "torch/nn/cosine_embedding_loss"
82
+ require "torch/nn/cross_entropy_loss"
83
+ require "torch/nn/ctc_loss"
84
+ # require "torch/nn/hinge_embedding_loss"
85
+ require "torch/nn/kl_div_loss"
86
+ require "torch/nn/l1_loss"
87
+ # require "torch/nn/margin_ranking_loss"
88
+ require "torch/nn/mse_loss"
89
+ # require "torch/nn/multi_label_margin_loss"
90
+ # require "torch/nn/multi_label_soft_margin_loss"
91
+ # require "torch/nn/multi_margin_loss"
92
+ require "torch/nn/nll_loss"
93
+ require "torch/nn/poisson_nll_loss"
94
+ # require "torch/nn/smooth_l1_loss"
95
+ # require "torch/nn/soft_margin_loss"
96
+ # require "torch/nn/triplet_margin_loss"
97
+
98
+ # nn other
38
99
  require "torch/nn/functional"
39
100
  require "torch/nn/init"
40
- require "torch/nn/linear"
41
- require "torch/nn/mse_loss"
42
- require "torch/nn/parameter"
43
- require "torch/nn/relu"
44
- require "torch/nn/sequential"
45
101
 
46
102
  # utils
47
103
  require "torch/utils/data/data_loader"
48
104
  require "torch/utils/data/tensor_dataset"
49
105
 
106
+ # random
107
+ require "torch/random"
108
+
50
109
  module Torch
51
110
  class Error < StandardError; end
52
111
  class NotImplementedYet < StandardError
@@ -57,7 +116,6 @@ module Torch
57
116
 
58
117
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
59
118
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
60
- # complex and quantized types not supported by PyTorch yet
61
119
  DTYPE_TO_ENUM = {
62
120
  uint8: 0,
63
121
  int8: 1,
@@ -73,14 +131,14 @@ module Torch
73
131
  float32: 6,
74
132
  double: 7,
75
133
  float64: 7,
76
- # complex_half: 8,
77
- # complex_float: 9,
78
- # complex_double: 10,
134
+ complex_half: 8,
135
+ complex_float: 9,
136
+ complex_double: 10,
79
137
  bool: 11,
80
- # qint8: 12,
81
- # quint8: 13,
82
- # qint32: 14,
83
- # bfloat16: 15
138
+ qint8: 12,
139
+ quint8: 13,
140
+ qint32: 14,
141
+ bfloat16: 15
84
142
  }
85
143
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
86
144
 
@@ -120,6 +178,8 @@ module Torch
120
178
  # use method for cases when Numo not available
121
179
  # or available after Torch loaded
122
180
  def _dtype_to_numo
181
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
182
+
123
183
  {
124
184
  uint8: Numo::UInt8,
125
185
  int8: Numo::Int8,
@@ -200,8 +260,12 @@ module Torch
200
260
  data = [data].compact
201
261
  end
202
262
 
203
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
204
- options[:dtype] = :int64
263
+ if options[:dtype].nil?
264
+ if data.all? { |v| v.is_a?(Integer) }
265
+ options[:dtype] = :int64
266
+ elsif data.all? { |v| v == true || v == false }
267
+ options[:dtype] = :bool
268
+ end
205
269
  end
206
270
 
207
271
  _tensor(data, size, tensor_options(**options))
@@ -302,6 +366,10 @@ module Torch
302
366
  _pow(input, exponent)
303
367
  end
304
368
 
369
+ def topk(input, k)
370
+ _topk(input, k)
371
+ end
372
+
305
373
  def min(input)
306
374
  _min(input)
307
375
  end
@@ -327,6 +395,10 @@ module Torch
327
395
  _sign(input)
328
396
  end
329
397
 
398
+ def sigmoid(input)
399
+ _sigmoid(input)
400
+ end
401
+
330
402
  def gt(input, other)
331
403
  _gt(input, other)
332
404
  end
@@ -363,6 +435,15 @@ module Torch
363
435
  _sqrt(input)
364
436
  end
365
437
 
438
+ # TODO make dim keyword argument
439
+ def log_softmax(input, dim)
440
+ _log_softmax(input, dim)
441
+ end
442
+
443
+ def softmax(input, dim: nil)
444
+ _softmax(input, dim)
445
+ end
446
+
366
447
  def abs(input)
367
448
  _abs(input)
368
449
  end
data/lib/torch/ext.bundle CHANGED
Binary file
@@ -11,32 +11,36 @@ module Torch
11
11
  else
12
12
  summarize = numel > 1000
13
13
 
14
- values = to_a.flatten
15
- abs = values.select { |v| v != 0 }.map(&:abs)
16
- max = abs.max || 1
17
- min = abs.min || 1
14
+ if dtype == :bool
15
+ fmt = "%s"
16
+ else
17
+ values = to_a.flatten
18
+ abs = values.select { |v| v != 0 }.map(&:abs)
19
+ max = abs.max || 1
20
+ min = abs.min || 1
18
21
 
19
- total = 0
20
- if values.any? { |v| v < 0 }
21
- total += 1
22
- end
22
+ total = 0
23
+ if values.any? { |v| v < 0 }
24
+ total += 1
25
+ end
23
26
 
24
- if floating_point?
25
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
27
+ if floating_point?
28
+ sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
26
29
 
27
- all_int = values.all? { |v| v.finite? && v == v.to_i }
28
- decimal = all_int ? 1 : 4
30
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
31
+ decimal = all_int ? 1 : 4
29
32
 
30
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
33
+ total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
31
34
 
32
- if sci
33
- fmt = "%#{total}.4e"
35
+ if sci
36
+ fmt = "%#{total}.4e"
37
+ else
38
+ fmt = "%#{total}.#{decimal}f"
39
+ end
34
40
  else
35
- fmt = "%#{total}.#{decimal}f"
41
+ total += max.to_s.size
42
+ fmt = "%#{total}d"
36
43
  end
37
- else
38
- total += max.to_s.size
39
- fmt = "%#{total}d"
40
44
  end
41
45
 
42
46
  inspect_level(to_a, fmt, dim - 1, 0, summarize)
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class AvgPool2d < AvgPoolNd
4
+ def initialize(kernel_size)
5
+ super()
6
+ @kernel_size = kernel_size
7
+ end
8
+
9
+ def forward(input)
10
+ F.avg_pool2d(input, @kernel_size)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class AvgPoolNd < Module
4
+ def extra_inspect
5
+ format("kernel_size: %s", @kernel_size)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class BCELoss < WeightedLoss
4
+ def initialize(weight: nil, reduction: "mean")
5
+ super(weight, reduction)
6
+ end
7
+
8
+ def forward(input, target)
9
+ F.binary_cross_entropy(input, target, weight: @weight, reduction: @reduction)
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,38 @@
1
+ module Torch
2
+ module NN
3
+ class Bilinear < Module
4
+ def initialize(in1_features, in2_features, out_features, bias: true)
5
+ super()
6
+
7
+ @in1_features = in1_features
8
+ @in2_features = in2_features
9
+ @out_features = out_features
10
+ @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
11
+
12
+ if bias
13
+ @bias = Parameter.new(Tensor.new(out_features))
14
+ else
15
+ raise NotImplementedYet
16
+ end
17
+
18
+ reset_parameters
19
+ end
20
+
21
+ def reset_parameters
22
+ bound = 1 / Math.sqrt(@weight.size(1))
23
+ Init.uniform!(@weight, a: -bound, b: bound)
24
+ if @bias
25
+ Init.uniform!(@bias, a: -bound, b: bound)
26
+ end
27
+ end
28
+
29
+ def forward(input1, input2)
30
+ F.bilinear(input1, input2, @weight, @bias)
31
+ end
32
+
33
+ def extra_inspect
34
+ format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
35
+ end
36
+ end
37
+ end
38
+ end
@@ -19,8 +19,8 @@ module Torch
19
19
  end
20
20
 
21
21
  # TODO add more parameters
22
- def inspect
23
- "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
22
+ def extra_inspect
23
+ format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
24
24
  end
25
25
 
26
26
  private
@@ -29,11 +29,11 @@ module Torch
29
29
  end
30
30
 
31
31
  def reset_parameters
32
- Init.kaiming_uniform!(@weight, Math.sqrt(5))
32
+ Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
33
33
  if @bias
34
- fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
34
+ fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
35
35
  bound = 1 / Math.sqrt(fan_in)
36
- Init.uniform!(@bias, -bound, bound)
36
+ Init.uniform!(@bias, a: -bound, b: bound)
37
37
  end
38
38
  end
39
39
  end
@@ -0,0 +1,15 @@
1
+ module Torch
2
+ module NN
3
+ class CosineSimilarity < Module
4
+ def initialize(dim: 1, eps: 1e-8)
5
+ super()
6
+ @dim = dim
7
+ @eps = eps
8
+ end
9
+
10
+ def forward(x1, x2)
11
+ F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
12
+ end
13
+ end
14
+ end
15
+ end
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class CrossEntropyLoss < WeightedLoss
4
+ def initialize(weight: nil, ignore_index: -100, reduction: "mean")
5
+ super(weight, reduction)
6
+ @ignore_index = ignore_index
7
+ end
8
+
9
+ def forward(input, target)
10
+ F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,15 @@
1
+ module Torch
2
+ module NN
3
+ class CTCLoss < Loss
4
+ def initialize(blank: 0, reduction: "mean", zero_infinity: false)
5
+ super(reduction)
6
+ @blank = blank
7
+ @zero_infinity = zero_infinity
8
+ end
9
+
10
+ def forward(log_probs, targets, input_lengths, target_lengths)
11
+ F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
12
+ end
13
+ end
14
+ end
15
+ end
@@ -7,8 +7,8 @@ module Torch
7
7
  @inplace = inplace
8
8
  end
9
9
 
10
- def inspect
11
- "#{self.class.name.split("::").last}(p: #{@p.inspect}, inplace: #{@inplace.inspect})"
10
+ def extra_inspect
11
+ format("p: %s, inplace: %s", @p, @inplace)
12
12
  end
13
13
  end
14
14
  end
@@ -0,0 +1,34 @@
1
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
2
+ module Torch
3
+ module NN
4
+ class EmbeddingBag < Module
5
+ def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
6
+ scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
7
+
8
+ super()
9
+ @num_embeddings = num_embeddings
10
+ @embedding_dim = embedding_dim
11
+ @max_norm = max_norm
12
+ @norm_type = norm_type
13
+ @scale_grad_by_freq = scale_grad_by_freq
14
+ if _weight.nil?
15
+ @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
16
+ reset_parameters
17
+ else
18
+ raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
19
+ @weight = Parameter.new(_weight)
20
+ end
21
+ @mode = mode
22
+ @sparse = sparse
23
+ end
24
+
25
+ def reset_parameters
26
+ Init.normal!(@weight)
27
+ end
28
+
29
+ def forward(input, offsets: nil, per_sample_weights: nil)
30
+ F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
31
+ end
32
+ end
33
+ end
34
+ end