torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
@@ -0,0 +1,134 @@
module Torch
  module Native
    # Picks exactly one native function out of a group of overloads that share
    # a Ruby name, based on the positional arguments and keyword options of a
    # call, and produces the argument list for that function.
    class Parser
      # functions - non-empty list of function descriptors. Each descriptor is
      #             used through ruby_name, cpp_name, args, out? and out_size;
      #             args is a list of hashes read via :name, :type, :pos,
      #             :has_default and :default.
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # Smallest number of required positional args / largest number of
        # positional args across all overloads; used for the arity check.
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Resolves a call to a single overload.
      #
      # args    - Array of positional arguments (not modified).
      # options - Hash of keyword arguments keyed by Symbol (not modified).
      #
      # Returns {name: cpp_name, args: [...]} on success, with args ordered as
      # in the chosen function's args list, or {error: message} when the call
      # is invalid (wrong arity, unknown keyword, or a type mismatch when only
      # one candidate is left).
      # Raises Error when an argument type or int[N] size is not understood,
      # or when the candidates cannot be narrowed down to exactly one.
      def parse(args, options)
        # Work on copies: trailing nils are popped and :out is expanded below,
        # and neither change should leak back into the caller's objects.
        args = args.dup
        options = options.dup
        candidates = @functions.dup

        # remove trailing nils
        # NOTE(review): Array#any? tests truthiness, so stripping also stops
        # once every remaining element is nil or false.
        args.pop while args.any? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @max_args == @min_args ? @min_args.to_s : "#{@min_args}..#{@max_args}"
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        candidates.reject! do |func|
          # TODO make more generic
          func.out? && !options[:out]
        end

        # handle out with multiple
        # there should only be one match, so safe to modify all
        out_func = candidates.find { |f| f.out? }
        if out_func && out_func.out_size > 1 && options[:out]
          # Spread the values given as out: over the function's last two
          # argument names, turning them into ordinary keyword options.
          out_args = out_func.args.last(2).map { |a| a[:name] }
          out_args.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! do |func|
            func.args.any? { |a| a[:name] == k.to_s }
          end
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          # Argument name => value: positionals first, then keyword options
          # (which win on a name clash), then defaults for anything still nil.
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # Drop any "(...)" suffix from the declared type before matching.
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                # "reduction" is declared as int but is given as a String.
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                if v.is_a?(Integer)
                  # A bare Integer for an int[N] argument is repeated N times.
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                # Only nil is accepted for this type.
                v.nil?
              when "bool"
                v == true || v == false
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            if !good
              # With a single candidate there is nothing else to fall back on,
              # so report the offending argument directly.
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          if good
            final_values = values
          end

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        {
          name: func.cpp_name,
          args: func.args.map { |a| final_values[a[:name]] }
        }
      end
    end
  end
end
@@ -0,0 +1,9 @@
module Torch
  module NN
    # Dropout variant whose forward pass is F.alpha_dropout.
    class AlphaDropout < DropoutNd
      # Applies F.alpha_dropout to input using this layer's @p, @training and
      # @inplace state (none of which is assigned in this class).
      def forward(input)
        settings = {p: @p, training: @training, inplace: @inplace}
        F.alpha_dropout(input, **settings)
      end
    end
  end
end
@@ -0,0 +1,18 @@
module Torch
  module NN
    # 1-D average pooling layer backed by F.avg_pool1d.
    class AvgPool1d < AvgPoolNd
      # kernel_size, stride and padding are each passed through _single;
      # stride falls back to kernel_size when it is not given.
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
        super()
        effective_stride = stride ? stride : kernel_size
        @kernel_size = _single(kernel_size)
        @stride = _single(effective_stride)
        @padding = _single(padding)
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
      end

      # Pools input with the settings captured at construction time.
      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad]
        F.avg_pool1d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,19 @@
module Torch
  module NN
    # 2-D average pooling layer backed by F.avg_pool2d.
    class AvgPool2d < AvgPoolNd
      # Stores the pooling settings as given; stride falls back to
      # kernel_size when it is not given.
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        @kernel_size = kernel_size
        @stride = stride ? stride : kernel_size
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      # Pools input with the settings captured at construction time.
      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool2d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,19 @@
module Torch
  module NN
    # 3-D average pooling layer backed by F.avg_pool3d.
    class AvgPool3d < AvgPoolNd
      # Stores the pooling settings as given; stride falls back to
      # kernel_size when it is not given.
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        @kernel_size = kernel_size
        @stride = stride ? stride : kernel_size
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      # Pools input with the settings captured at construction time.
      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool3d(input, *pool_args)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class AvgPoolNd < Module
4
+ def extra_inspect
5
+ format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,75 @@
1
+ module Torch
2
+ module NN
3
+ class BatchNorm < Module
4
+ def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
5
+ super()
6
+ @num_features = num_features
7
+ @eps = eps
8
+ @momentum = momentum
9
+ @affine = affine
10
+ @track_running_stats = track_running_stats
11
+ if @affine
12
+ @weight = Parameter.new(Torch::Tensor.new(num_features))
13
+ @bias = Parameter.new(Torch::Tensor.new(num_features))
14
+ else
15
+ register_parameter("weight", nil)
16
+ register_parameter("bias", nil)
17
+ end
18
+ if track_running_stats
19
+ register_buffer("running_mean", Torch.zeros(num_features))
20
+ register_buffer("running_var", Torch.ones(num_features))
21
+ register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
22
+ else
23
+ register_parameter("running_mean", nil)
24
+ register_parameter("running_var", nil)
25
+ register_parameter("num_batches_tracked", nil)
26
+ end
27
+ reset_parameters
28
+ end
29
+
30
+ def reset_running_stats
31
+ if @track_running_stats
32
+ @running_mean.zero!
33
+ @running_var.fill!(1)
34
+ @num_batches_tracked.zero!
35
+ end
36
+ end
37
+
38
+ def reset_parameters
39
+ reset_running_stats
40
+ if @affine
41
+ Init.ones!(@weight)
42
+ Init.zeros!(@bias)
43
+ end
44
+ end
45
+
46
+ def forward(input)
47
+ _check_input_dim(input)
48
+
49
+ if @momentum.nil?
50
+ exponential_average_factor = 0.0
51
+ else
52
+ exponential_average_factor = @momentum
53
+ end
54
+
55
+ if @training and @track_running_stats
56
+ if @num_batches_tracked.nil?
57
+ @num_batches_tracked += 1
58
+ if @momentum.nil?
59
+ exponential_average_factor = 1.0 / @num_batches_tracked.to_f
60
+ else
61
+ exponential_average_factor = @momentum
62
+ end
63
+ end
64
+ end
65
+
66
+ F.batch_norm(
67
+ input, @running_mean, @running_var,
68
+ weight: @weight, bias: @bias,
69
+ training: @training || !@track_running_stats,
70
+ momentum: exponential_average_factor, eps: @eps
71
+ )
72
+ end
73
+ end
74
+ end
75
+ end
@@ -0,0 +1,11 @@
module Torch
  module NN
    # BatchNorm variant that accepts 2D or 3D input only.
    class BatchNorm1d < BatchNorm
      # Raises ArgumentError unless input.dim is 2 or 3.
      def _check_input_dim(input)
        return if input.dim == 2 || input.dim == 3

        raise ArgumentError, "expected 2D or 3D input (got #{input.dim}D input)"
      end
    end
  end
end
@@ -0,0 +1,11 @@
module Torch
  module NN
    # BatchNorm variant that accepts 4D input only.
    class BatchNorm2d < BatchNorm
      # Raises ArgumentError unless input.dim is 4.
      def _check_input_dim(input)
        return if input.dim == 4

        raise ArgumentError, "expected 4D input (got #{input.dim}D input)"
      end
    end
  end
end
@@ -0,0 +1,11 @@
module Torch
  module NN
    # BatchNorm variant that accepts 5D input only.
    class BatchNorm3d < BatchNorm
      # Raises ArgumentError unless input.dim is 5.
      def _check_input_dim(input)
        return if input.dim == 5

        raise ArgumentError, "expected 5D input (got #{input.dim}D input)"
      end
    end
  end
end
@@ -0,0 +1,13 @@
module Torch
  module NN
    # Loss module whose forward pass is F.binary_cross_entropy.
    class BCELoss < WeightedLoss
      # weight and reduction are handed to the WeightedLoss constructor
      # positionally, in that order.
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # Computes the loss of input against target using this module's
      # @weight and @reduction.
      def forward(input, target)
        loss_options = {weight: @weight, reduction: @reduction}
        F.binary_cross_entropy(input, target, **loss_options)
      end
    end
  end
end
@@ -0,0 +1,15 @@
module Torch
  module NN
    # Loss module whose forward pass is F.binary_cross_entropy_with_logits.
    class BCEWithLogitsLoss < Loss
      # reduction goes to the Loss constructor; weight and pos_weight are
      # registered as buffers named "weight" and "pos_weight".
      def initialize(weight: nil, reduction: "mean", pos_weight: nil)
        super(reduction)
        [["weight", weight], ["pos_weight", pos_weight]].each do |buffer_name, tensor|
          register_buffer(buffer_name, tensor)
        end
      end

      # Computes the loss of input against target. weight and pos_weight are
      # read through method calls rather than instance variables; those
      # readers are not defined in this class.
      def forward(input, target)
        loss_options = {weight: weight, pos_weight: pos_weight, reduction: @reduction}
        F.binary_cross_entropy_with_logits(input, target, **loss_options)
      end
    end
  end
end
@@ -0,0 +1,38 @@
1
+ module Torch
2
+ module NN
3
+ class Bilinear < Module
4
+ def initialize(in1_features, in2_features, out_features, bias: true)
5
+ super()
6
+
7
+ @in1_features = in1_features
8
+ @in2_features = in2_features
9
+ @out_features = out_features
10
+ @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
11
+
12
+ if bias
13
+ @bias = Parameter.new(Tensor.new(out_features))
14
+ else
15
+ raise NotImplementedYet
16
+ end
17
+
18
+ reset_parameters
19
+ end
20
+
21
+ def reset_parameters
22
+ bound = 1 / Math.sqrt(@weight.size(1))
23
+ Init.uniform!(@weight, a: -bound, b: bound)
24
+ if @bias
25
+ Init.uniform!(@bias, a: -bound, b: bound)
26
+ end
27
+ end
28
+
29
+ def forward(input1, input2)
30
+ F.bilinear(input1, input2, @weight, @bias)
31
+ end
32
+
33
+ def extra_inspect
34
+ format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
35
+ end
36
+ end
37
+ end
38
+ end
@@ -0,0 +1,10 @@
module Torch
  module NN
    # Constant padding layer whose padding is normalized with _pair.
    class ConstantPad1d < ConstantPadNd
      # value goes to the ConstantPadNd constructor; padding is stored as
      # whatever _pair returns for it.
      def initialize(padding, value)
        super(value)
        normalized = _pair(padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,10 @@
module Torch
  module NN
    # Constant padding layer whose padding is normalized with _quadrupal.
    class ConstantPad2d < ConstantPadNd
      # value goes to the ConstantPadNd constructor; padding is stored as
      # whatever _quadrupal returns for it.
      def initialize(padding, value)
        super(value)
        normalized = _quadrupal(padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,10 @@
module Torch
  module NN
    # Constant padding layer whose padding is normalized with _ntuple(6, ...).
    class ConstantPad3d < ConstantPadNd
      # value goes to the ConstantPadNd constructor; padding is stored as
      # whatever _ntuple(6, padding) returns.
      def initialize(padding, value)
        super(value)
        normalized = _ntuple(6, padding)
        @padding = normalized
      end
    end
  end
end
@@ -0,0 +1,18 @@
module Torch
  module NN
    # Common base of the ConstantPad1d/2d/3d layers. It stores the fill
    # value; @padding is read here but assigned by the subclasses.
    class ConstantPadNd < Module
      # value - fill value passed to F.pad as value:.
      def initialize(value)
        super()
        @value = value
      end

      # Pads input via F.pad in "constant" mode with this layer's @padding
      # and @value.
      def forward(input)
        pad_options = {mode: "constant", value: @value}
        F.pad(input, @padding, **pad_options)
      end

      # Returns a one-line summary of the padding and fill value.
      def extra_inspect
        "padding: %s, value: %s" % [@padding, @value]
      end
    end
  end
end