torch-rb 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
data/lib/torch.rb CHANGED
@@ -29,6 +29,7 @@ require "torch/optim/lr_scheduler/step_lr"
29
29
 
30
30
  # nn parameters
31
31
  require "torch/nn/parameter"
32
+ require "torch/nn/utils"
32
33
 
33
34
  # nn containers
34
35
  require "torch/nn/module"
@@ -36,17 +37,61 @@ require "torch/nn/sequential"
36
37
 
37
38
  # nn convolution layers
38
39
  require "torch/nn/convnd"
40
+ require "torch/nn/conv1d"
39
41
  require "torch/nn/conv2d"
42
+ require "torch/nn/conv3d"
43
+ require "torch/nn/unfold"
44
+ require "torch/nn/fold"
40
45
 
41
46
  # nn pooling layers
42
47
  require "torch/nn/max_poolnd"
48
+ require "torch/nn/max_pool1d"
43
49
  require "torch/nn/max_pool2d"
50
+ require "torch/nn/max_pool3d"
51
+ require "torch/nn/max_unpoolnd"
52
+ require "torch/nn/max_unpool1d"
53
+ require "torch/nn/max_unpool2d"
54
+ require "torch/nn/max_unpool3d"
44
55
  require "torch/nn/avg_poolnd"
56
+ require "torch/nn/avg_pool1d"
45
57
  require "torch/nn/avg_pool2d"
58
+ require "torch/nn/avg_pool3d"
59
+ require "torch/nn/lp_poolnd"
60
+ require "torch/nn/lp_pool1d"
61
+ require "torch/nn/lp_pool2d"
62
+
63
+ # nn padding layers
64
+ require "torch/nn/reflection_padnd"
65
+ require "torch/nn/reflection_pad1d"
66
+ require "torch/nn/reflection_pad2d"
67
+ require "torch/nn/replication_padnd"
68
+ require "torch/nn/replication_pad1d"
69
+ require "torch/nn/replication_pad2d"
70
+ require "torch/nn/replication_pad3d"
71
+ require "torch/nn/constant_padnd"
72
+ require "torch/nn/constant_pad1d"
73
+ require "torch/nn/constant_pad2d"
74
+ require "torch/nn/constant_pad3d"
75
+ require "torch/nn/zero_pad2d"
76
+
77
+ # nn normalization layers
78
+ require "torch/nn/batch_norm"
79
+ require "torch/nn/batch_norm1d"
80
+ require "torch/nn/batch_norm2d"
81
+ require "torch/nn/batch_norm3d"
82
+ require "torch/nn/group_norm"
83
+ require "torch/nn/instance_norm"
84
+ require "torch/nn/instance_norm1d"
85
+ require "torch/nn/instance_norm2d"
86
+ require "torch/nn/instance_norm3d"
87
+ require "torch/nn/layer_norm"
88
+ require "torch/nn/local_response_norm"
46
89
 
47
90
  # nn recurrent layers
48
91
  require "torch/nn/rnn_base"
49
92
  require "torch/nn/rnn"
93
+ require "torch/nn/lstm"
94
+ require "torch/nn/gru"
50
95
 
51
96
  # nn linear layers
52
97
  require "torch/nn/bilinear"
@@ -62,11 +107,17 @@ require "torch/nn/dropout3d"
62
107
  require "torch/nn/feature_alpha_dropout"
63
108
 
64
109
  # nn activations
110
+ require "torch/nn/hardshrink"
65
111
  require "torch/nn/leaky_relu"
112
+ require "torch/nn/log_sigmoid"
66
113
  require "torch/nn/prelu"
67
114
  require "torch/nn/relu"
68
115
  require "torch/nn/sigmoid"
69
116
  require "torch/nn/softplus"
117
+ require "torch/nn/softshrink"
118
+ require "torch/nn/softsign"
119
+ require "torch/nn/tanh"
120
+ require "torch/nn/tanhshrink"
70
121
 
71
122
  # nn activations other
72
123
  require "torch/nn/log_softmax"
@@ -364,48 +415,6 @@ module Torch
364
415
  zeros(input.size, like_options(input, options))
365
416
  end
366
417
 
367
- # --- begin operations ---
368
-
369
- # TODO support out
370
- def mean(input, dim = nil, keepdim: false)
371
- if dim
372
- _mean_dim(input, dim, keepdim)
373
- else
374
- _mean(input)
375
- end
376
- end
377
-
378
- # TODO support dtype
379
- def sum(input, dim = nil, keepdim: false)
380
- if dim
381
- _sum_dim(input, dim, keepdim)
382
- else
383
- _sum(input)
384
- end
385
- end
386
-
387
- def topk(input, k)
388
- _topk(input, k)
389
- end
390
-
391
- def max(input, dim = nil, keepdim: false, out: nil)
392
- if dim
393
- raise NotImplementedYet unless out
394
- _max_out(out[0], out[1], input, dim, keepdim)
395
- else
396
- _max(input)
397
- end
398
- end
399
-
400
- # TODO make dim keyword argument
401
- def log_softmax(input, dim)
402
- _log_softmax(input, dim)
403
- end
404
-
405
- def softmax(input, dim: nil)
406
- _softmax(input, dim)
407
- end
408
-
409
418
  private
410
419
 
411
420
  def tensor_size(size)
data/lib/torch/ext.bundle CHANGED
Binary file
data/lib/torch/native/dispatcher.rb CHANGED
@@ -18,7 +18,7 @@ module Torch
18
18
  functions = Generator.grouped_functions
19
19
  bind_functions(::Torch, :define_singleton_method, functions[:torch])
20
20
  bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
21
- # NN functions are internal, so no need to bind
21
+ bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
22
22
  end
23
23
 
24
24
  def bind_functions(context, def_method, functions)
data/lib/torch/native/function.rb CHANGED
@@ -35,11 +35,38 @@ module Torch
35
35
  end
36
36
  t, _, k = a.rpartition(" ")
37
37
  k, d = k.split("=")
38
- d = d.to_i if d.to_i.to_s == d
39
- d = true if d == "True"
40
- d = false if d == "False"
41
- d = nil if d == "None"
42
- args << {name: k, type: t, default: d, pos: pos}
38
+ has_default = !d.nil?
39
+
40
+ if d
41
+ d =
42
+ case d
43
+ when "True"
44
+ true
45
+ when "False"
46
+ false
47
+ when "None"
48
+ nil
49
+ when /\A\-?\d+\z/
50
+ d.to_i
51
+ when "[]"
52
+ []
53
+ when "[0,1]"
54
+ [0, 1]
55
+ when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
56
+ d.to_f
57
+ when "Mean"
58
+ "mean"
59
+ when "contiguous_format"
60
+ d
61
+ when "long"
62
+ :long
63
+ else
64
+ raise "Unknown default: #{d}"
65
+ end
66
+ end
67
+
68
+ next if t == "Generator?"
69
+ args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
43
70
  end
44
71
  args
45
72
  end
@@ -49,6 +76,10 @@ module Torch
49
76
  @out_size ||= func.split("->").last.count("!")
50
77
  end
51
78
 
79
# Number of comma-separated entries in the return clause of +func+ (its last
# "->"-separated segment), memoized. The generator wraps the C++ return value
# whenever this is greater than one.
def ret_size
  return @ret_size if @ret_size

  returns = func.split("->").last
  @ret_size = returns.split(", ").length
end
82
+
52
83
  def out?
53
84
  out_size > 0 && base_name[-1] != "_"
54
85
  end
data/lib/torch/native/generator.rb CHANGED
@@ -17,14 +17,23 @@ module Torch
17
17
  def grouped_functions
18
18
  functions = functions()
19
19
 
20
- # remove functions
20
+ # skip functions
21
21
  skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
22
- skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
23
- functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
22
+ skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
23
+
24
+ # remove functions
25
+ functions.reject! do |f|
26
+ f.ruby_name.start_with?("_") ||
27
+ f.ruby_name.end_with?("_backward") ||
28
+ skip_binding.include?(f.ruby_name) ||
29
+ f.args.any? { |a| a[:type].include?("Dimname") }
30
+ end
31
+
32
+ # separate out into todo
24
33
  todo_functions, functions =
25
34
  functions.partition do |f|
26
35
  f.args.any? do |a|
27
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
36
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
28
37
  skip_args.any? { |sa| a[:type].include?(sa) }
29
38
  end
30
39
  end
@@ -33,7 +42,9 @@ module Torch
33
42
  # there may be a better way to do this
34
43
  optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
35
44
  optional_functions.each do |f|
36
- next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
45
+ next if f.ruby_name == "cross"
46
+ next if f.ruby_name.start_with?("avg_pool") && f.out?
47
+
37
48
  opt_args = f.args.select { |a| a[:type] == "int?" }
38
49
  if opt_args.size == 1
39
50
  sep = f.name.include?(".") ? "_" : "."
@@ -85,7 +96,7 @@ void add_%{type}_functions(Module m) {
85
96
 
86
97
  cpp_defs = []
87
98
  functions.sort_by(&:cpp_name).each do |func|
88
- fargs = func.args.select { |a| a[:type] != "Generator?" }
99
+ fargs = func.args #.select { |a| a[:type] != "Generator?" }
89
100
 
90
101
  cpp_args = []
91
102
  fargs.each do |a|
@@ -96,6 +107,8 @@ void add_%{type}_functions(Module m) {
96
107
  when "Tensor?"
97
108
  # TODO better signature
98
109
  "OptionalTensor"
110
+ when "ScalarType?"
111
+ "OptionalScalarType"
99
112
  when "Tensor[]"
100
113
  "TensorList"
101
114
  when "int"
@@ -121,10 +134,16 @@ void add_%{type}_functions(Module m) {
121
134
 
122
135
  prefix = def_method == :define_method ? "self." : "torch::"
123
136
 
137
+ body = "#{prefix}#{dispatch}(#{args.join(", ")})"
138
+ # TODO check type as well
139
+ if func.ret_size > 1
140
+ body = "wrap(#{body})"
141
+ end
142
+
124
143
  cpp_defs << ".#{def_method}(
125
144
  \"#{func.cpp_name}\",
126
145
  *[](#{cpp_args.join(", ")}) {
127
- return #{prefix}#{dispatch}(#{args.join(", ")});
146
+ return #{body};
128
147
  })"
129
148
  end
130
149
 
data/lib/torch/native/parser.rb CHANGED
@@ -4,19 +4,44 @@ module Torch
4
4
  def initialize(functions)
5
5
  @functions = functions
6
6
  @name = @functions.first.ruby_name
7
- @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
7
+ @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
8
8
  @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
9
9
  end
10
10
 
11
11
  def parse(args, options)
12
12
  candidates = @functions.dup
13
13
 
14
+ # remove nil
15
+ while args.any? && args.last.nil?
16
+ args.pop
17
+ end
18
+
19
+ # TODO account for args passed as options here
14
20
  if args.size < @min_args || args.size > @max_args
15
21
  expected = String.new(@min_args.to_s)
16
22
  expected += "..#{@max_args}" if @max_args != @min_args
17
23
  return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
18
24
  end
19
25
 
26
+ candidates.reject! { |f| args.size > f.args.size }
27
+
28
+ # exclude functions missing required options
29
+ candidates.reject! do |func|
30
+ # TODO make more generic
31
+ func.out? && !options[:out]
32
+ end
33
+
34
+ # handle out with multiple
35
+ # there should only be one match, so safe to modify all
36
+ out_func = candidates.find { |f| f.out? }
37
+ if out_func && out_func.out_size > 1 && options[:out]
38
+ out_args = out_func.args.last(2).map { |a| a[:name] }
39
+ out_args.zip(options.delete(:out)).each do |k, v|
40
+ options[k.to_sym] = v
41
+ end
42
+ candidates = [out_func]
43
+ end
44
+
20
45
  # exclude functions where options don't match
21
46
  options.each do |k, v|
22
47
  candidates.select! do |func|
@@ -26,12 +51,6 @@ module Torch
26
51
  return {error: "unknown keyword: #{k}"} if candidates.empty?
27
52
  end
28
53
 
29
- # exclude functions missing required options
30
- candidates.reject! do |func|
31
- # TODO make more generic
32
- func.out? && !options[:out]
33
- end
34
-
35
54
  final_values = {}
36
55
 
37
56
  # check args
@@ -41,29 +60,47 @@ module Torch
41
60
  values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
42
61
  values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
43
62
  func.args.each do |fa|
44
- values[fa[:name]] ||= fa[:default]
63
+ values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
45
64
  end
46
65
 
47
66
  arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
48
67
 
49
- values.each do |k, v|
68
+ values.each_key do |k|
69
+ v = values[k]
50
70
  t = arg_types[k].split("(").first
71
+
51
72
  good =
52
73
  case t
53
74
  when "Tensor"
54
75
  v.is_a?(Tensor)
76
+ when "Tensor?"
77
+ v.nil? || v.is_a?(Tensor)
55
78
  when "Tensor[]"
56
- v.all? { |v2| v2.is_a?(Tensor) }
79
+ v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
57
80
  when "int"
58
- v.is_a?(Integer)
59
- when "int[]"
60
- v.all? { |v2| v2.is_a?(Integer) }
81
+ if k == "reduction"
82
+ v.is_a?(String)
83
+ else
84
+ v.is_a?(Integer)
85
+ end
86
+ when "float"
87
+ v.is_a?(Numeric)
88
+ when /int\[.*\]/
89
+ if v.is_a?(Integer)
90
+ size = t[4..-2]
91
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
92
+ v = [v] * size.to_i
93
+ values[k] = v
94
+ end
95
+ v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
61
96
  when "Scalar"
62
97
  v.is_a?(Numeric)
98
+ when "ScalarType?"
99
+ v.nil?
63
100
  when "bool"
64
101
  v == true || v == false
65
102
  else
66
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
103
+ raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
67
104
  end
68
105
 
69
106
  if !good
data/lib/torch/nn/avg_pool1d.rb ADDED
@@ -0,0 +1,18 @@
1
module Torch
  module NN
    # 1D average-pooling layer; delegates the computation to F.avg_pool1d.
    #
    # kernel_size, stride and padding are passed through +_single+ (a helper
    # not shown here, presumably wrapping a scalar into a 1-element list).
    class AvgPool1d < AvgPoolNd
      # @param kernel_size [Integer, Array<Integer>] pooling window size
      # @param stride [Integer, Array<Integer>, nil] window step; kernel_size is used when not given
      # @param padding [Integer, Array<Integer>] forwarded to F.avg_pool1d
      # @param ceil_mode [Boolean] forwarded to F.avg_pool1d
      # @param count_include_pad [Boolean] forwarded to F.avg_pool1d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
        super()
        window = kernel_size
        step = stride || window
        @kernel_size = _single(window)
        @stride = _single(step)
        @padding = _single(padding)
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
      end

      # Applies the configured pooling to +input+.
      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad]
        F.avg_pool1d(input, *pool_args)
      end
    end
  end
end
data/lib/torch/nn/avg_pool2d.rb CHANGED
@@ -1,13 +1,18 @@
1
1
  module Torch
2
2
  module NN
3
3
  class AvgPool2d < AvgPoolNd
4
- def initialize(kernel_size)
4
+ def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
5
5
  super()
6
6
  @kernel_size = kernel_size
7
+ @stride = stride || kernel_size
8
+ @padding = padding
9
+ @ceil_mode = ceil_mode
10
+ @count_include_pad = count_include_pad
11
+ @divisor_override = divisor_override
7
12
  end
8
13
 
9
14
  def forward(input)
10
- F.avg_pool2d(input, @kernel_size)
15
+ F.avg_pool2d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
11
16
  end
12
17
  end
13
18
  end
data/lib/torch/nn/avg_pool3d.rb ADDED
@@ -0,0 +1,19 @@
1
module Torch
  module NN
    # 3D average-pooling layer; delegates the computation to F.avg_pool3d.
    # Window parameters are stored exactly as given (no scalar wrapping here).
    class AvgPool3d < AvgPoolNd
      # @param kernel_size [Integer, Array<Integer>] pooling window size
      # @param stride [Integer, Array<Integer>, nil] window step; kernel_size is used when not given
      # @param padding [Integer, Array<Integer>] forwarded to F.avg_pool3d
      # @param ceil_mode [Boolean] forwarded to F.avg_pool3d
      # @param count_include_pad [Boolean] forwarded to F.avg_pool3d
      # @param divisor_override [Integer, nil] forwarded to F.avg_pool3d
      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
        super()
        effective_stride = stride || kernel_size
        @kernel_size = kernel_size
        @stride = effective_stride
        @padding = padding
        @ceil_mode = ceil_mode
        @count_include_pad = count_include_pad
        @divisor_override = divisor_override
      end

      # Applies the configured pooling to +input+.
      def forward(input)
        pool_args = [@kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override]
        F.avg_pool3d(input, *pool_args)
      end
    end
  end
end