torch-rb 0.1.3 → 0.1.4

Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
data/lib/torch.rb CHANGED
@@ -22,31 +22,90 @@ require "torch/optim/sgd"
 require "torch/optim/lr_scheduler/lr_scheduler"
 require "torch/optim/lr_scheduler/step_lr"
 
-# nn base classes
+# nn parameters
+require "torch/nn/parameter"
+
+# nn containers
 require "torch/nn/module"
+require "torch/nn/sequential"
+
+# nn convolution layers
 require "torch/nn/convnd"
-require "torch/nn/dropoutnd"
+require "torch/nn/conv2d"
+
+# nn pooling layers
+require "torch/nn/max_poolnd"
+require "torch/nn/max_pool2d"
+require "torch/nn/avg_poolnd"
+require "torch/nn/avg_pool2d"
+
+# nn linear layers
+require "torch/nn/bilinear"
+require "torch/nn/identity"
+require "torch/nn/linear"
 
-# nn
+# nn dropout layers
+require "torch/nn/dropoutnd"
 require "torch/nn/alpha_dropout"
-require "torch/nn/conv2d"
 require "torch/nn/dropout"
 require "torch/nn/dropout2d"
 require "torch/nn/dropout3d"
-require "torch/nn/embedding"
 require "torch/nn/feature_alpha_dropout"
+
+# nn activations
+require "torch/nn/leaky_relu"
+require "torch/nn/prelu"
+require "torch/nn/relu"
+require "torch/nn/sigmoid"
+require "torch/nn/softplus"
+
+# nn activations other
+require "torch/nn/log_softmax"
+require "torch/nn/softmax"
+require "torch/nn/softmax2d"
+require "torch/nn/softmin"
+
+# nn sparse layers
+require "torch/nn/embedding"
+require "torch/nn/embedding_bag"
+
+# nn distance functions
+require "torch/nn/cosine_similarity"
+require "torch/nn/pairwise_distance"
+
+# nn loss functions
+require "torch/nn/loss"
+require "torch/nn/weighted_loss"
+require "torch/nn/bce_loss"
+# require "torch/nn/bce_with_logits_loss"
+# require "torch/nn/cosine_embedding_loss"
+require "torch/nn/cross_entropy_loss"
+require "torch/nn/ctc_loss"
+# require "torch/nn/hinge_embedding_loss"
+require "torch/nn/kl_div_loss"
+require "torch/nn/l1_loss"
+# require "torch/nn/margin_ranking_loss"
+require "torch/nn/mse_loss"
+# require "torch/nn/multi_label_margin_loss"
+# require "torch/nn/multi_label_soft_margin_loss"
+# require "torch/nn/multi_margin_loss"
+require "torch/nn/nll_loss"
+require "torch/nn/poisson_nll_loss"
+# require "torch/nn/smooth_l1_loss"
+# require "torch/nn/soft_margin_loss"
+# require "torch/nn/triplet_margin_loss"
+
+# nn other
 require "torch/nn/functional"
 require "torch/nn/init"
-require "torch/nn/linear"
-require "torch/nn/mse_loss"
-require "torch/nn/parameter"
-require "torch/nn/relu"
-require "torch/nn/sequential"
 
 # utils
 require "torch/utils/data/data_loader"
 require "torch/utils/data/tensor_dataset"
 
+# random
+require "torch/random"
+
 module Torch
   class Error < StandardError; end
   class NotImplementedYet < StandardError
@@ -57,7 +116,6 @@ module Torch
 
   # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
   # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
-  # complex and quantized types not supported by PyTorch yet
   DTYPE_TO_ENUM = {
     uint8: 0,
     int8: 1,
@@ -73,14 +131,14 @@ module Torch
     float32: 6,
     double: 7,
     float64: 7,
-    # complex_half: 8,
-    # complex_float: 9,
-    # complex_double: 10,
+    complex_half: 8,
+    complex_float: 9,
+    complex_double: 10,
     bool: 11,
-    # qint8: 12,
-    # quint8: 13,
-    # qint32: 14,
-    # bfloat16: 15
+    qint8: 12,
+    quint8: 13,
+    qint32: 14,
+    bfloat16: 15
   }
   ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
 
@@ -120,6 +178,8 @@ module Torch
    # use method for cases when Numo not available
    # or available after Torch loaded
    def _dtype_to_numo
+     raise Error, "Numo not found" unless defined?(Numo::NArray)
+
      {
        uint8: Numo::UInt8,
        int8: Numo::Int8,
@@ -200,8 +260,12 @@ module Torch
        data = [data].compact
      end
 
-      if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
-        options[:dtype] = :int64
+      if options[:dtype].nil?
+        if data.all? { |v| v.is_a?(Integer) }
+          options[:dtype] = :int64
+        elsif data.all? { |v| v == true || v == false }
+          options[:dtype] = :bool
+        end
      end
 
      _tensor(data, size, tensor_options(**options))
@@ -302,6 +366,10 @@ module Torch
      _pow(input, exponent)
    end
 
+    def topk(input, k)
+      _topk(input, k)
+    end
+
    def min(input)
      _min(input)
    end
@@ -327,6 +395,10 @@ module Torch
      _sign(input)
    end
 
+    def sigmoid(input)
+      _sigmoid(input)
+    end
+
    def gt(input, other)
      _gt(input, other)
    end
@@ -363,6 +435,15 @@ module Torch
      _sqrt(input)
    end
 
+    # TODO make dim keyword argument
+    def log_softmax(input, dim)
+      _log_softmax(input, dim)
+    end
+
+    def softmax(input, dim: nil)
+      _softmax(input, dim)
+    end
+
    def abs(input)
      _abs(input)
    end
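
Note: the new top-level methods are thin wrappers over the C++ extension. A minimal usage sketch of the additions above (assuming a working 0.1.4 install; printed results are illustrative, and topk's return shape follows whatever the _topk binding yields):

    require "torch"

    x = Torch.tensor([[1.0, 2.0, 3.0]])

    Torch.sigmoid(x)           # elementwise logistic
    Torch.softmax(x, dim: 1)   # rows sum to 1; dim is a keyword argument
    Torch.log_softmax(x, 1)    # dim is still positional (see the TODO above)
    Torch.topk(x, 2)           # wraps _topk

    # dtype inference now covers booleans as well as integers
    Torch.tensor([true, false]).dtype  # => :bool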
data/lib/torch/ext.bundle CHANGED
Binary file
data/lib/torch/inspector.rb CHANGED
@@ -11,32 +11,36 @@ module Torch
     else
       summarize = numel > 1000
 
-      values = to_a.flatten
-      abs = values.select { |v| v != 0 }.map(&:abs)
-      max = abs.max || 1
-      min = abs.min || 1
+      if dtype == :bool
+        fmt = "%s"
+      else
+        values = to_a.flatten
+        abs = values.select { |v| v != 0 }.map(&:abs)
+        max = abs.max || 1
+        min = abs.min || 1
 
-      total = 0
-      if values.any? { |v| v < 0 }
-        total += 1
-      end
+        total = 0
+        if values.any? { |v| v < 0 }
+          total += 1
+        end
 
-      if floating_point?
-        sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+        if floating_point?
+          sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
 
-        all_int = values.all? { |v| v.finite? && v == v.to_i }
-        decimal = all_int ? 1 : 4
+          all_int = values.all? { |v| v.finite? && v == v.to_i }
+          decimal = all_int ? 1 : 4
 
-        total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+          total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
 
-        if sci
-          fmt = "%#{total}.4e"
+          if sci
+            fmt = "%#{total}.4e"
+          else
+            fmt = "%#{total}.#{decimal}f"
+          end
         else
-          fmt = "%#{total}.#{decimal}f"
+          total += max.to_s.size
+          fmt = "%#{total}d"
         end
-      else
-        total += max.to_s.size
-        fmt = "%#{total}d"
       end
 
       inspect_level(to_a, fmt, dim - 1, 0, summarize)
1
+ module Torch
2
+ module NN
3
+ class AvgPool2d < AvgPoolNd
4
+ def initialize(kernel_size)
5
+ super()
6
+ @kernel_size = kernel_size
7
+ end
8
+
9
+ def forward(input)
10
+ F.avg_pool2d(input, @kernel_size)
11
+ end
12
+ end
13
+ end
14
+ end
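
A minimal usage sketch for the new layer (Torch.randn with a splatted size is assumed available in 0.1.4; input uses the usual NCHW layout):

    pool = Torch::NN::AvgPool2d.new(2)  # 2x2 kernel
    x = Torch.randn(1, 1, 4, 4)         # batch, channels, height, width
    pool.forward(x)                     # each 2x2 window averaged -> shape [1, 1, 2, 2]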
data/lib/torch/nn/avg_poolnd.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AvgPoolNd < Module
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/bce_loss.rb ADDED
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class BCELoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
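
Usage sketch (inputs must already be probabilities, e.g. via Torch.sigmoid; the values here are illustrative):

    loss_fn = Torch::NN::BCELoss.new
    input = Torch.tensor([0.8, 0.2, 0.9])   # predicted probabilities
    target = Torch.tensor([1.0, 0.0, 1.0])  # binary labels as floats
    loss_fn.forward(input, target)          # scalar tensor under the default "mean" reduction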
data/lib/torch/nn/bilinear.rb ADDED
@@ -0,0 +1,38 @@
+module Torch
+  module NN
+    class Bilinear < Module
+      def initialize(in1_features, in2_features, out_features, bias: true)
+        super()
+
+        @in1_features = in1_features
+        @in2_features = in2_features
+        @out_features = out_features
+        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
+
+        if bias
+          @bias = Parameter.new(Tensor.new(out_features))
+        else
+          raise NotImplementedYet
+        end
+
+        reset_parameters
+      end
+
+      def reset_parameters
+        bound = 1 / Math.sqrt(@weight.size(1))
+        Init.uniform!(@weight, a: -bound, b: bound)
+        if @bias
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+
+      def forward(input1, input2)
+        F.bilinear(input1, input2, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
+      end
+    end
+  end
+end
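
Usage sketch (Torch.randn assumed; shapes follow the PyTorch layer this ports):

    layer = Torch::NN::Bilinear.new(5, 4, 3)
    x1 = Torch.randn(2, 5)
    x2 = Torch.randn(2, 4)
    layer.forward(x1, x2)  # x1' W x2 + b per output feature -> shape [2, 3]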
data/lib/torch/nn/conv2d.rb CHANGED
@@ -19,8 +19,8 @@ module Torch
      end
 
      # TODO add more parameters
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
      end
 
      private
data/lib/torch/nn/convnd.rb CHANGED
@@ -29,11 +29,11 @@ module Torch
      end
 
      def reset_parameters
-        Init.kaiming_uniform!(@weight, Math.sqrt(5))
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
        if @bias
-          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
          bound = 1 / Math.sqrt(fan_in)
-          Init.uniform!(@bias, -bound, bound)
+          Init.uniform!(@bias, a: -bound, b: bound)
        end
      end
    end
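
The hunk reflects two Init API changes: uniform!/kaiming_uniform! now take keyword arguments (a:, b:), and the fan helper gains an underscore prefix, matching PyTorch's private _calculate_fan_in_and_fan_out. A sketch of calling them directly, assuming Tensor.new allocates by shape as the layer constructors above do:

    w = Torch::Tensor.new(3, 5)  # uninitialized weight
    Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))
    fan_in, _fan_out = Torch::NN::Init._calculate_fan_in_and_fan_out(w)
    Torch::NN::Init.uniform!(w, a: -1 / Math.sqrt(fan_in), b: 1 / Math.sqrt(fan_in))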
data/lib/torch/nn/cosine_similarity.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CosineSimilarity < Module
+      def initialize(dim: 1, eps: 1e-8)
+        super()
+        @dim = dim
+        @eps = eps
+      end
+
+      def forward(x1, x2)
+        F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
+      end
+    end
+  end
+end
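
Usage sketch (Torch.randn assumed):

    sim = Torch::NN::CosineSimilarity.new  # defaults: dim: 1, eps: 1e-8
    a = Torch.randn(2, 8)
    b = Torch.randn(2, 8)
    sim.forward(a, b)  # one value in [-1, 1] per row -> shape [2]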
data/lib/torch/nn/cross_entropy_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
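
Usage sketch (input rows are unnormalized scores; integer targets are inferred as :int64 by the Torch.tensor change above):

    loss_fn = Torch::NN::CrossEntropyLoss.new
    input = Torch.randn(3, 5)         # 3 samples, 5 classes
    target = Torch.tensor([1, 0, 4])  # class indices
    loss_fn.forward(input, target)    # scalar tensor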
data/lib/torch/nn/ctc_loss.rb ADDED
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end
data/lib/torch/nn/dropoutnd.rb CHANGED
@@ -7,8 +7,8 @@ module Torch
      @inplace = inplace
    end
 
-    def inspect
-      "#{self.class.name.split("::").last}(p: #{@p.inspect}, inplace: #{@inplace.inspect})"
+    def extra_inspect
+      format("p: %s, inplace: %s", @p, @inplace)
    end
  end
end
data/lib/torch/nn/embedding_bag.rb ADDED
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+        scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
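
Usage sketch, mirroring the PyTorch EmbeddingBag example this file is ported from (offsets mark where each bag starts in the flat index list):

    bag = Torch::NN::EmbeddingBag.new(10, 3)        # 10 embeddings of dimension 3
    input = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flat indices for all bags
    offsets = Torch.tensor([0, 4])                  # bag 1 = input[0...4], bag 2 = input[4..]
    bag.forward(input, offsets: offsets)            # mean-pooled -> shape [2, 3]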