torch-rb 0.1.0 → 0.1.5

Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
data/lib/torch/native/parser.rb
@@ -0,0 +1,97 @@
+module Torch
+  module Native
+    class Parser
+      def initialize(functions)
+        @functions = functions
+        @name = @functions.first.ruby_name
+        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
+        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+      end
+
+      def parse(args, options)
+        candidates = @functions.dup
+
+        if args.size < @min_args || args.size > @max_args
+          expected = String.new(@min_args.to_s)
+          expected += "..#{@max_args}" if @max_args != @min_args
+          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
+        end
+
+        # exclude functions where options don't match
+        options.each do |k, v|
+          candidates.select! do |func|
+            func.args.any? { |a| a[:name] == k.to_s }
+          end
+          # TODO show all bad keywords at once like Ruby?
+          return {error: "unknown keyword: #{k}"} if candidates.empty?
+        end
+
+        # exclude functions missing required options
+        candidates.reject! do |func|
+          # TODO make more generic
+          func.out? && !options[:out]
+        end
+
+        final_values = {}
+
+        # check args
+        candidates.select! do |func|
+          good = true
+
+          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
+          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
+          func.args.each do |fa|
+            values[fa[:name]] ||= fa[:default]
+          end
+
+          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
+
+          values.each do |k, v|
+            t = arg_types[k].split("(").first
+            good =
+              case t
+              when "Tensor"
+                v.is_a?(Tensor)
+              when "Tensor[]"
+                v.all? { |v2| v2.is_a?(Tensor) }
+              when "int"
+                v.is_a?(Integer)
+              when "int[]"
+                v.all? { |v2| v2.is_a?(Integer) }
+              when "Scalar"
+                v.is_a?(Numeric)
+              when "bool"
+                v == true || v == false
+              else
+                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
+              end
+
+            if !good
+              if candidates.size == 1
+                k = "input" if k == "self"
+                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
+              end
+              break
+            end
+          end
+
+          if good
+            final_values = values
+          end
+
+          good
+        end
+
+        if candidates.size != 1
+          raise Error, "This should never happen. Please report a bug with #{@name}."
+        end
+
+        func = candidates.first
+        {
+          name: func.cpp_name,
+          args: func.args.map { |a| final_values[a[:name]] }
+        }
+      end
+    end
+  end
+end
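To make the overload resolution above concrete, here is a rough usage sketch. The function objects and the final dispatch are assumptions (they live in function.rb and dispatcher.rb, which are not shown here); only Parser#parse and its return shape come from the hunk above.

  # Hypothetical: `add_overloads` stands in for however the generated
  # dispatcher groups the Torch::Native::Function records named "add".
  parser = Torch::Native::Parser.new(add_overloads)

  result = parser.parse([x, 1], {})        # positional args, keyword options
  if result[:error]
    raise ArgumentError, result[:error]    # e.g. "wrong number of arguments ..."
  else
    result[:name]                          # C++ name of the matching overload
    result[:args]                          # argument values ordered for that overload
  end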
data/lib/torch/nn/alpha_dropout.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AlphaDropout < DropoutNd
+      def forward(input)
+        F.alpha_dropout(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end
data/lib/torch/nn/avg_pool2d.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class AvgPool2d < AvgPoolNd
+      def initialize(kernel_size)
+        super()
+        @kernel_size = kernel_size
+      end
+
+      def forward(input)
+        F.avg_pool2d(input, @kernel_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/avg_poolnd.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class AvgPoolNd < Module
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/bce_loss.rb
@@ -0,0 +1,13 @@
+module Torch
+  module NN
+    class BCELoss < WeightedLoss
+      def initialize(weight: nil, reduction: "mean")
+        super(weight, reduction)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy(input, target, weight: @weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/bce_with_logits_loss.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class BCEWithLogitsLoss < Loss
+      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
+        super(reduction)
+        register_buffer("weight", weight)
+        register_buffer("pos_weight", pos_weight)
+      end
+
+      def forward(input, target)
+        F.binary_cross_entropy_with_logits(input, target, weight: weight, pos_weight: pos_weight, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/bilinear.rb
@@ -0,0 +1,38 @@
+module Torch
+  module NN
+    class Bilinear < Module
+      def initialize(in1_features, in2_features, out_features, bias: true)
+        super()
+
+        @in1_features = in1_features
+        @in2_features = in2_features
+        @out_features = out_features
+        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
+
+        if bias
+          @bias = Parameter.new(Tensor.new(out_features))
+        else
+          raise NotImplementedYet
+        end
+
+        reset_parameters
+      end
+
+      def reset_parameters
+        bound = 1 / Math.sqrt(@weight.size(1))
+        Init.uniform!(@weight, a: -bound, b: bound)
+        if @bias
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+
+      def forward(input1, input2)
+        F.bilinear(input1, input2, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
+      end
+    end
+  end
+end
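A minimal usage sketch for the new Bilinear module. It assumes Torch.randn and invoking a module with call, which the gem provides elsewhere; shapes are illustrative.

  in1 = Torch.randn([128, 20])
  in2 = Torch.randn([128, 30])
  m = Torch::NN::Bilinear.new(20, 30, 40)
  out = m.call(in1, in2)   # call dispatches to forward(input1, input2)
  out.shape                # => [128, 40]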
data/lib/torch/nn/conv2d.rb
@@ -1,39 +1,24 @@
 module Torch
   module NN
-    class Conv2d < Module
-      attr_reader :bias, :weight
-
-      def initialize(in_channels, out_channels, kernel_size) #, stride: 1, padding: 0, dilation: 1, groups: 1)
-        @in_channels = in_channels
-        @out_channels = out_channels
-        @kernel_size = pair(kernel_size)
-        @stride = pair(1)
-        # @stride = pair(stride)
-        # @padding = pair(padding)
-        # @dilation = pair(dilation)
-
-        # TODO divide by groups
-        @weight = Parameter.new(Tensor.new(out_channels, in_channels, *@kernel_size))
-        @bias = Parameter.new(Tensor.new(out_channels))
-
-        reset_parameters
+    class Conv2d < ConvNd
+      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
+        kernel_size = pair(kernel_size)
+        stride = pair(stride)
+        padding = pair(padding)
+        dilation = pair(dilation)
+        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
       end
 
-      def reset_parameters
-        Init.kaiming_uniform_(@weight, Math.sqrt(5))
-        if @bias
-          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
-          bound = 1 / Math.sqrt(fan_in)
-          Init.uniform_(@bias, -bound, bound)
+      def forward(input)
+        if @padding_mode == "circular"
+          raise NotImplementedError
         end
+        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
       end
 
-      def call(input)
-        F.conv2d(input, @weight, @bias) # @stride, self.padding, self.dilation, self.groups)
-      end
-
-      def inspect
-        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect})"
+      # TODO add more parameters
+      def extra_inspect
+        format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
       end
 
       private
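For reference, a hedged example of the expanded constructor: stride, padding, dilation, groups, bias, and padding_mode now have keyword defaults. Torch.randn and Module#call are assumed from the rest of the gem; shapes are illustrative.

  conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 1, padding: 1)
  x = Torch.randn([8, 3, 28, 28])   # batch of 8 three-channel 28x28 inputs
  y = conv.call(x)                  # forward -> F.conv2d with the stored options
  y.shape                           # => [8, 16, 28, 28] given padding: 1, stride: 1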
data/lib/torch/nn/convnd.rb
@@ -0,0 +1,41 @@
+module Torch
+  module NN
+    class ConvNd < Module
+      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
+        super()
+        raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
+        raise ArgumentError, "out_channels must be divisible by groups" if out_channels % groups != 0
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size
+        @stride = stride
+        @padding = padding
+        @dilation = dilation
+        @transposed = transposed
+        @output_padding = output_padding
+        @groups = groups
+        @padding_mode = padding_mode
+        if transposed
+          @weight = Parameter.new(Tensor.new(in_channels, out_channels / groups, *kernel_size))
+        else
+          @weight = Parameter.new(Tensor.new(out_channels, in_channels / groups, *kernel_size))
+        end
+        if bias
+          @bias = Parameter.new(Tensor.new(out_channels))
+        else
+          raise NotImplementedError
+        end
+        reset_parameters
+      end
+
+      def reset_parameters
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
+        if @bias
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
+          bound = 1 / Math.sqrt(fan_in)
+          Init.uniform!(@bias, a: -bound, b: bound)
+        end
+      end
+    end
+  end
+end
data/lib/torch/nn/cosine_embedding_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CosineEmbeddingLoss < Loss
+      def initialize(margin: 0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.cosine_embedding_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/cosine_similarity.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CosineSimilarity < Module
+      def initialize(dim: 1, eps: 1e-8)
+        super()
+        @dim = dim
+        @eps = eps
+      end
+
+      def forward(x1, x2)
+        F.cosine_similarity(x1, x2, dim: @dim, eps: @eps)
+      end
+    end
+  end
+end
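A short sketch of the new module in use (Torch.randn and Module#call are assumed from the rest of the gem):

  a = Torch.randn([100, 128])
  b = Torch.randn([100, 128])
  cos = Torch::NN::CosineSimilarity.new(dim: 1, eps: 1e-6)
  sim = cos.call(a, b)   # 100 values between -1 and 1, one per row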
data/lib/torch/nn/cross_entropy_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class CrossEntropyLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.cross_entropy(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
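A hedged example of the loss in use (Torch.randn, Torch.tensor, and Module#call are assumed from the gem's public API):

  loss_fn = Torch::NN::CrossEntropyLoss.new
  scores = Torch.randn([3, 5])         # unnormalized scores for 5 classes
  target = Torch.tensor([1, 0, 4])     # class indices for the 3 samples
  loss = loss_fn.call(scores, target)  # scalar tensor, "mean" reduction by default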
data/lib/torch/nn/ctc_loss.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class CTCLoss < Loss
+      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
+        super(reduction)
+        @blank = blank
+        @zero_infinity = zero_infinity
+      end
+
+      def forward(log_probs, targets, input_lengths, target_lengths)
+        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity)
+      end
+    end
+  end
+end
data/lib/torch/nn/dropout.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout < DropoutNd
+      def forward(input)
+        F.dropout(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end
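A minimal sketch of Dropout, assuming modules start in training mode and are invoked with call (Torch.ones is assumed from the gem's public API):

  drop = Torch::NN::Dropout.new(p: 0.2)
  x = Torch.ones([4, 5])
  drop.call(x)   # roughly 20% of elements zeroed, the rest scaled by 1/(1 - 0.2)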
data/lib/torch/nn/dropout2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout2d < DropoutNd
+      def forward(input)
+        F.dropout2d(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end
data/lib/torch/nn/dropout3d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Dropout3d < DropoutNd
+      def forward(input)
+        F.dropout3d(input, p: @p, training: @training, inplace: @inplace)
+      end
+    end
+  end
+end
data/lib/torch/nn/dropoutnd.rb
@@ -0,0 +1,15 @@
+module Torch
+  module NN
+    class DropoutNd < Module
+      def initialize(p: 0.5, inplace: false)
+        super()
+        @p = p
+        @inplace = inplace
+      end
+
+      def extra_inspect
+        format("p: %s, inplace: %s", @p, @inplace)
+      end
+    end
+  end
+end
data/lib/torch/nn/embedding.rb
@@ -0,0 +1,52 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class Embedding < Module
+      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
+                     norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+
+        if padding_idx
+          if padding_idx > 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx < @num_embeddings
+          elsif padding_idx < 0
+            raise ArgumentError, "Padding_idx must be within num_embeddings" unless padding_idx >= -@num_embeddings
+            padding_idx = @num_embeddings + padding_idx
+          end
+        end
+        @padding_idx = padding_idx
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+        if @padding_idx
+          Torch.no_grad do
+            @weight[@padding_idx].fill!(0)
+          end
+        end
+      end
+
+      def forward(input)
+        F.embedding(input, @weight, padding_idx: @padding_idx, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, sparse: @sparse)
+      end
+
+      def inspect
+        "Embedding(#{@num_embeddings}, #{@embedding_dim})"
+      end
+    end
+  end
+end
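A usage sketch mirroring the PyTorch module this file ports (Torch.tensor and Module#call are assumed from the rest of the gem):

  # 10-word vocabulary, 3-dimensional embeddings
  embedding = Torch::NN::Embedding.new(10, 3)
  input = Torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
  embedding.call(input).shape   # => [2, 4, 3]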