torch-rb 0.1.3 → 0.1.8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (115) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
@@ -0,0 +1,134 @@
1
module Torch
  module Native
    # Resolves a Ruby call against a set of native overloads that share one Ruby
    # name: it picks the overload whose argument list accepts the given values and
    # returns that overload's C++ name together with a positional argument list.
    #
    # Each function object is expected to respond to +ruby_name+, +cpp_name+,
    # +args+ (an Array of Hashes with :name, :type, :pos, :has_default and
    # :default keys), +out?+ and +out_size+.
    class Parser
      # @param functions [Array] overloads sharing one Ruby name (the first
      #   one's ruby_name is used in error messages)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # smallest required positional count / largest accepted positional count
        # across all overloads, used for the up-front arity check in #parse
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Selects the overload matching +args+ and +options+.
      #
      # Note: +args+ has trailing nils popped, and +options+ may have :out
      # replaced by per-tensor keys; both are modified in place.
      #
      # @param args [Array] positional arguments
      # @param options [Hash{Symbol => Object}] keyword arguments
      # @return [Hash] +{name: cpp_name, args: [...]}+ on success, or
      #   +{error: message}+ when the call cannot match any overload
      # @raise [Error] on an argument type this parser does not know, or when
      #   the surviving candidates are not exactly one
      def parse(args, options)
        candidates = @functions.dup

        # drop trailing nils so they fall back to defaults; emptiness is tested
        # explicitly because Array#any? is false for an array holding only
        # nil/false, which used to leave e.g. [false, nil] untrimmed
        args.pop while !args.empty? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @min_args.to_s
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # handle out with multiple
        # there should only be one match, so safe to modify all
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          # spread out: [t1, t2, ...] over the overload's trailing out_size
          # arguments (assumed, as before, to be the out tensors); this was
          # hard-coded to the last 2, which broke overloads with 3+ outputs
          out_args = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_args.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          # positional values by name, then keyword values, then declared defaults
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # strip alias annotations such as "Tensor(a!)" down to "Tensor"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                # reduction is declared as int natively but passed as a String here
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                # a bare Integer for int[N] is broadcast to N copies
                if v.is_a?(Integer)
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            unless good
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Torch
  module NN
    # 1-D average pooling; delegates to F.avg_pool1d.
    class AvgPool1d < AvgPoolNd
      # @param kernel_size window size, normalized with _single
      # @param stride step between windows; when nil/false the kernel size is reused
      # @param padding normalized with _single
      # @param ceil_mode [Boolean] forwarded unchanged
      # @param count_include_pad [Boolean] forwarded unchanged
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = _single(kernel_size)
        @stride = _single(effective_stride)
        @padding = _single(padding)
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
      end

      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad]
        F.avg_pool1d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
module Torch
  module NN
    # 2-D average pooling; settings are stored as given (not expanded) and
    # delegated to F.avg_pool2d.
    class AvgPool2d < AvgPoolNd
      # @param kernel_size pooling window
      # @param stride step between windows; when nil/false the kernel size is reused
      # @param padding forwarded unchanged
      # @param ceil_mode [Boolean] forwarded unchanged
      # @param count_include_pad [Boolean] forwarded unchanged
      # @param divisor_override forwarded unchanged (nil by default)
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = kernel_size
        @stride = effective_stride
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool2d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
module Torch
  module NN
    # 3-D average pooling; settings are stored as given (not expanded) and
    # delegated to F.avg_pool3d.
    class AvgPool3d < AvgPoolNd
      # @param kernel_size pooling window
      # @param stride step between windows; when nil/false the kernel size is reused
      # @param padding forwarded unchanged
      # @param ceil_mode [Boolean] forwarded unchanged
      # @param count_include_pad [Boolean] forwarded unchanged
      # @param divisor_override forwarded unchanged (nil by default)
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = kernel_size
        @stride = effective_stride
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool3d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module Torch
  module NN
    # Shared base for the AvgPool*d modules; only contributes the inspect summary.
    class AvgPoolNd < Module
      # @return [String] the kernel size, stride and padding, each rendered with to_s
      def extra_inspect
        "kernel_size: #{@kernel_size}, stride: #{@stride}, padding: #{@padding}"
      end
    end
  end
end
@@ -0,0 +1,75 @@
1
+ module Torch
2
+ module NN
3
+ class BatchNorm < Module
4
+ def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
5
+ super()
6
+ @num_features = num_features
7
+ @eps = eps
8
+ @momentum = momentum
9
+ @affine = affine
10
+ @track_running_stats = track_running_stats
11
+ if @affine
12
+ @weight = Parameter.new(Torch::Tensor.new(num_features))
13
+ @bias = Parameter.new(Torch::Tensor.new(num_features))
14
+ else
15
+ register_parameter("weight", nil)
16
+ register_parameter("bias", nil)
17
+ end
18
+ if track_running_stats
19
+ register_buffer("running_mean", Torch.zeros(num_features))
20
+ register_buffer("running_var", Torch.ones(num_features))
21
+ register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
22
+ else
23
+ register_parameter("running_mean", nil)
24
+ register_parameter("running_var", nil)
25
+ register_parameter("num_batches_tracked", nil)
26
+ end
27
+ reset_parameters
28
+ end
29
+
30
+ def reset_running_stats
31
+ if @track_running_stats
32
+ @running_mean.zero!
33
+ @running_var.fill!(1)
34
+ @num_batches_tracked.zero!
35
+ end
36
+ end
37
+
38
+ def reset_parameters
39
+ reset_running_stats
40
+ if @affine
41
+ Init.ones!(@weight)
42
+ Init.zeros!(@bias)
43
+ end
44
+ end
45
+
46
+ def forward(input)
47
+ _check_input_dim(input)
48
+
49
+ if @momentum.nil?
50
+ exponential_average_factor = 0.0
51
+ else
52
+ exponential_average_factor = @momentum
53
+ end
54
+
55
+ if @training and @track_running_stats
56
+ if @num_batches_tracked.nil?
57
+ @num_batches_tracked += 1
58
+ if @momentum.nil?
59
+ exponential_average_factor = 1.0 / @num_batches_tracked.to_f
60
+ else
61
+ exponential_average_factor = @momentum
62
+ end
63
+ end
64
+ end
65
+
66
+ F.batch_norm(
67
+ input, @running_mean, @running_var,
68
+ weight: @weight, bias: @bias,
69
+ training: @training || !@track_running_stats,
70
+ momentum: exponential_average_factor, eps: @eps
71
+ )
72
+ end
73
+ end
74
+ end
75
+ end
@@ -0,0 +1,11 @@
1
module Torch
  module NN
    # Batch normalization for 2D or 3D inputs.
    class BatchNorm1d < BatchNorm
      # @raise [ArgumentError] unless input.dim is 2 or 3
      def _check_input_dim(input)
        rank = input.dim
        return if rank == 2 || rank == 3

        raise ArgumentError, "expected 2D or 3D input (got #{rank}D input)"
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Torch
  module NN
    # Batch normalization for 4D inputs.
    class BatchNorm2d < BatchNorm
      # @raise [ArgumentError] unless input.dim is 4
      def _check_input_dim(input)
        rank = input.dim
        return if rank == 4

        raise ArgumentError, "expected 4D input (got #{rank}D input)"
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module Torch
  module NN
    # Batch normalization for 5D inputs.
    class BatchNorm3d < BatchNorm
      # @raise [ArgumentError] unless input.dim is 5
      def _check_input_dim(input)
        rank = input.dim
        return if rank == 5

        raise ArgumentError, "expected 5D input (got #{rank}D input)"
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
module Torch
  module NN
    # Binary cross-entropy loss; delegates to F.binary_cross_entropy.
    class BCELoss < WeightedLoss
      # @param weight optional weight handed to WeightedLoss (presumably a
      #   rescaling tensor — confirm in WeightedLoss)
      # @param reduction [String] reduction mode, "mean" by default
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      def forward(input, target)
        F.binary_cross_entropy(
          input,
          target,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
module Torch
  module NN
    # Binary cross-entropy on logits; delegates to
    # F.binary_cross_entropy_with_logits.
    class BCEWithLogitsLoss < Loss
      # @param weight optional tensor, registered as the "weight" buffer
      # @param reduction [String] reduction mode, "mean" by default
      # @param pos_weight optional tensor, registered as the "pos_weight" buffer
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        register_buffer("weight", weight)
        register_buffer("pos_weight", pos_weight)
      end

      # NOTE(review): weight/pos_weight are read through method calls, which are
      # presumably resolved from the registered buffers by Module — confirm.
      def forward(input, target)
        F.binary_cross_entropy_with_logits(
          input,
          target,
          weight: weight,
          pos_weight: pos_weight,
          reduction: @reduction
        )
      end
    end
  end
end
@@ -0,0 +1,38 @@
1
module Torch
  module NN
    # Bilinear layer holding a weight of shape
    # (out_features, in1_features, in2_features) and an optional bias of
    # length out_features; delegates to F.bilinear.
    class Bilinear < Module
      # @param in1_features [Integer] size of the second weight dimension
      # @param in2_features [Integer] size of the third weight dimension
      # @param out_features [Integer] size of the first weight dimension and of the bias
      # @param bias [Boolean] whether to create a learnable bias; when false the
      #   bias stays nil and F.bilinear receives nil
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features
        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))

        # bias: false previously raised the undefined constant NotImplementedYet;
        # reset_parameters, forward and extra_inspect already handle a nil bias
        @bias = bias ? Parameter.new(Tensor.new(out_features)) : nil

        reset_parameters
      end

      # Fills weight (and bias, if present) uniformly in [-b, b] with
      # b = 1 / sqrt(weight.size(1)).
      def reset_parameters
        bound = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -bound, b: bound)
        if @bias
          Init.uniform!(@bias, a: -bound, b: bound)
        end
      end

      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      def extra_inspect
        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module NN
    # Constant padding for 1-D inputs.
    class ConstantPad1d < ConstantPadNd
      # @param padding padding amount, normalized with _pair
      # @param value fill value, stored by ConstantPadNd
      def initialize(padding, value)
        super(value)
        normalized = _pair(padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module NN
    # Constant padding for 2-D inputs.
    class ConstantPad2d < ConstantPadNd
      # @param padding padding amount, normalized with _quadrupal
      #   (presumably into four values — confirm in Module)
      # @param value fill value, stored by ConstantPadNd
      def initialize(padding, value)
        super(value)
        normalized = _quadrupal(padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module Torch
  module NN
    # Constant padding for 3-D inputs.
    class ConstantPad3d < ConstantPadNd
      # @param padding padding amount, normalized with _ntuple(6, ...)
      # @param value fill value, stored by ConstantPadNd
      def initialize(padding, value)
        super(value)
        normalized = _ntuple(6, padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Torch
  module NN
    # Shared base for ConstantPad*d: stores the fill value and pads with
    # F.pad in "constant" mode. Subclasses set @padding.
    class ConstantPadNd < Module
      # @param value fill value passed to F.pad
      def initialize(value)
        super()
        @value = value
      end

      def forward(input)
        F.pad(input, @padding, mode: "constant", value: @value)
      end

      # @return [String] padding and fill value, each rendered with to_s
      def extra_inspect
        "padding: #{@padding}, value: #{@value}"
      end
    end
  end
end
@@ -0,0 +1,22 @@
1
module Torch
  module NN
    # 1-D convolution; geometry arguments are normalized with _single and the
    # module delegates to F.conv1d.
    class Conv1d < ConvNd
      # @param in_channels [Integer]
      # @param out_channels [Integer]
      # @param kernel_size normalized with _single
      # @param stride normalized with _single
      # @param padding normalized with _single
      # @param dilation normalized with _single
      # @param groups [Integer] forwarded unchanged
      # @param bias [Boolean] forwarded unchanged
      # @param padding_mode [String] "circular" is rejected in #forward
      def initialize(in_channels, out_channels, kernel_size, stride: 1,
                     padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
        kernel_size, stride, padding, dilation =
          [kernel_size, stride, padding, dilation].map { |setting| _single(setting) }
        # NOTE(review): the literal false and _single(0) are presumably ConvNd's
        # transposed and output_padding arguments — confirm against ConvNd
        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, _single(0), groups, bias, padding_mode)
      end

      # @raise [NotImplementedError] when padding_mode is "circular"
      def forward(input)
        raise NotImplementedError if @padding_mode == "circular"

        F.conv1d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
      end
    end
  end
end