torch-rb 0.1.5 → 0.1.6

Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
data/lib/torch.rb CHANGED
@@ -29,6 +29,7 @@ require "torch/optim/lr_scheduler/step_lr"
 
 # nn parameters
 require "torch/nn/parameter"
+require "torch/nn/utils"
 
 # nn containers
 require "torch/nn/module"
@@ -36,17 +37,61 @@ require "torch/nn/sequential"
 
 # nn convolution layers
 require "torch/nn/convnd"
+require "torch/nn/conv1d"
 require "torch/nn/conv2d"
+require "torch/nn/conv3d"
+require "torch/nn/unfold"
+require "torch/nn/fold"
 
 # nn pooling layers
 require "torch/nn/max_poolnd"
+require "torch/nn/max_pool1d"
 require "torch/nn/max_pool2d"
+require "torch/nn/max_pool3d"
+require "torch/nn/max_unpoolnd"
+require "torch/nn/max_unpool1d"
+require "torch/nn/max_unpool2d"
+require "torch/nn/max_unpool3d"
 require "torch/nn/avg_poolnd"
+require "torch/nn/avg_pool1d"
 require "torch/nn/avg_pool2d"
+require "torch/nn/avg_pool3d"
+require "torch/nn/lp_poolnd"
+require "torch/nn/lp_pool1d"
+require "torch/nn/lp_pool2d"
+
+# nn padding layers
+require "torch/nn/reflection_padnd"
+require "torch/nn/reflection_pad1d"
+require "torch/nn/reflection_pad2d"
+require "torch/nn/replication_padnd"
+require "torch/nn/replication_pad1d"
+require "torch/nn/replication_pad2d"
+require "torch/nn/replication_pad3d"
+require "torch/nn/constant_padnd"
+require "torch/nn/constant_pad1d"
+require "torch/nn/constant_pad2d"
+require "torch/nn/constant_pad3d"
+require "torch/nn/zero_pad2d"
+
+# nn normalization layers
+require "torch/nn/batch_norm"
+require "torch/nn/batch_norm1d"
+require "torch/nn/batch_norm2d"
+require "torch/nn/batch_norm3d"
+require "torch/nn/group_norm"
+require "torch/nn/instance_norm"
+require "torch/nn/instance_norm1d"
+require "torch/nn/instance_norm2d"
+require "torch/nn/instance_norm3d"
+require "torch/nn/layer_norm"
+require "torch/nn/local_response_norm"
 
 # nn recurrent layers
 require "torch/nn/rnn_base"
 require "torch/nn/rnn"
+require "torch/nn/lstm"
+require "torch/nn/gru"
 
 # nn linear layers
 require "torch/nn/bilinear"
@@ -62,11 +107,17 @@ require "torch/nn/dropout3d"
 require "torch/nn/feature_alpha_dropout"
 
 # nn activations
+require "torch/nn/hardshrink"
 require "torch/nn/leaky_relu"
+require "torch/nn/log_sigmoid"
 require "torch/nn/prelu"
 require "torch/nn/relu"
 require "torch/nn/sigmoid"
 require "torch/nn/softplus"
+require "torch/nn/softshrink"
+require "torch/nn/softsign"
+require "torch/nn/tanh"
+require "torch/nn/tanhshrink"
 
 # nn activations other
 require "torch/nn/log_softmax"
@@ -364,48 +415,6 @@ module Torch
       zeros(input.size, like_options(input, options))
     end
 
-    # --- begin operations ---
-
-    # TODO support out
-    def mean(input, dim = nil, keepdim: false)
-      if dim
-        _mean_dim(input, dim, keepdim)
-      else
-        _mean(input)
-      end
-    end
-
-    # TODO support dtype
-    def sum(input, dim = nil, keepdim: false)
-      if dim
-        _sum_dim(input, dim, keepdim)
-      else
-        _sum(input)
-      end
-    end
-
-    def topk(input, k)
-      _topk(input, k)
-    end
-
-    def max(input, dim = nil, keepdim: false, out: nil)
-      if dim
-        raise NotImplementedYet unless out
-        _max_out(out[0], out[1], input, dim, keepdim)
-      else
-        _max(input)
-      end
-    end
-
-    # TODO make dim keyword argument
-    def log_softmax(input, dim)
-      _log_softmax(input, dim)
-    end
-
-    def softmax(input, dim: nil)
-      _softmax(input, dim)
-    end
-
     private
 
     def tensor_size(size)
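Note: these hand-written wrappers are removed because the same operations now go through the generated native bindings (see the dispatcher, generator, and parser changes below). A minimal usage sketch of the kinds of calls they covered, assuming the generated functions accept the same positional arguments; the tensor values are illustrative only:

    require "torch"

    x = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    Torch.sum(x)             # all elements, now dispatched via the generated binding
    Torch.mean(x, 0)         # mean along dimension 0
    Torch.log_softmax(x, 1)  # dim is still positional here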
data/lib/torch/ext.bundle CHANGED
Binary file
data/lib/torch/native/dispatcher.rb CHANGED
@@ -18,7 +18,7 @@ module Torch
         functions = Generator.grouped_functions
         bind_functions(::Torch, :define_singleton_method, functions[:torch])
         bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
-        # NN functions are internal, so no need to bind
+        bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
       end
 
       def bind_functions(context, def_method, functions)
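With this change the generated NN functions are bound onto Torch::NN as singleton methods, alongside the existing Torch and Torch::Tensor bindings. A rough, illustrative sketch of what binding via define_singleton_method amounts to; FakeNN and the dispatch table here are hypothetical, not the real generated function list:

    # Illustrative only: attach generated names to a module, in the spirit of
    # bind_functions(::Torch::NN, :define_singleton_method, functions[:nn]).
    module FakeNN; end

    generated = { "softplus" => ->(input) { input } }  # hypothetical dispatch table
    generated.each do |name, impl|
      FakeNN.define_singleton_method(name) { |*args| impl.call(*args) }
    end

    FakeNN.softplus(1)  # => 1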
data/lib/torch/native/function.rb CHANGED
@@ -35,11 +35,38 @@ module Torch
           end
           t, _, k = a.rpartition(" ")
           k, d = k.split("=")
-          d = d.to_i if d.to_i.to_s == d
-          d = true if d == "True"
-          d = false if d == "False"
-          d = nil if d == "None"
-          args << {name: k, type: t, default: d, pos: pos}
+          has_default = !d.nil?
+
+          if d
+            d =
+              case d
+              when "True"
+                true
+              when "False"
+                false
+              when "None"
+                nil
+              when /\A\-?\d+\z/
+                d.to_i
+              when "[]"
+                []
+              when "[0,1]"
+                [0, 1]
+              when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
+                d.to_f
+              when "Mean"
+                "mean"
+              when "contiguous_format"
+                d
+              when "long"
+                :long
+              else
+                raise "Unknown default: #{d}"
+              end
+          end
+
+          next if t == "Generator?"
+          args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
         end
         args
       end
@@ -49,6 +76,10 @@ module Torch
        @out_size ||= func.split("->").last.count("!")
      end
 
+     def ret_size
+       @ret_size ||= func.split("->").last.split(", ").size
+     end
+
      def out?
        out_size > 0 && base_name[-1] != "_"
      end
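The default handling is now explicit about which textual defaults it understands, and records has_default so that "no default at all" can be distinguished from a default of None/nil (the parser below relies on this when computing @min_args). A standalone sketch of the coercion on hypothetical declaration defaults; coerce_default is an illustrative helper covering a subset of the cases above:

    # Sketch of the default coercion (subset of the branches in the diff).
    def coerce_default(d)
      case d
      when "True" then true
      when "False" then false
      when "None" then nil
      when /\A-?\d+\z/ then d.to_i
      when /\A\d+\.\d+\z/ then d.to_f
      else d
      end
    end

    coerce_default("False")  # => false
    coerce_default("-1")     # => -1
    coerce_default("0.5")    # => 0.5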
data/lib/torch/native/generator.rb CHANGED
@@ -17,14 +17,23 @@ module Torch
       def grouped_functions
         functions = functions()
 
-        # remove functions
+        # skip functions
         skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
-        skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
-        functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
+        skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
+
+        # remove functions
+        functions.reject! do |f|
+          f.ruby_name.start_with?("_") ||
+          f.ruby_name.end_with?("_backward") ||
+          skip_binding.include?(f.ruby_name) ||
+          f.args.any? { |a| a[:type].include?("Dimname") }
+        end
+
+        # separate out into todo
         todo_functions, functions =
           functions.partition do |f|
             f.args.any? do |a|
-              a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
+              a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
                 skip_args.any? { |sa| a[:type].include?(sa) }
             end
           end
@@ -33,7 +42,9 @@ module Torch
         # there may be a better way to do this
         optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
         optional_functions.each do |f|
-          next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
+          next if f.ruby_name == "cross"
+          next if f.ruby_name.start_with?("avg_pool") && f.out?
+
           opt_args = f.args.select { |a| a[:type] == "int?" }
           if opt_args.size == 1
             sep = f.name.include?(".") ? "_" : "."
@@ -85,7 +96,7 @@ void add_%{type}_functions(Module m) {
 
       cpp_defs = []
      functions.sort_by(&:cpp_name).each do |func|
-        fargs = func.args.select { |a| a[:type] != "Generator?" }
+        fargs = func.args #.select { |a| a[:type] != "Generator?" }
 
        cpp_args = []
        fargs.each do |a|
@@ -96,6 +107,8 @@ void add_%{type}_functions(Module m) {
            when "Tensor?"
              # TODO better signature
              "OptionalTensor"
+            when "ScalarType?"
+              "OptionalScalarType"
            when "Tensor[]"
              "TensorList"
            when "int"
@@ -121,10 +134,16 @@ void add_%{type}_functions(Module m) {
 
        prefix = def_method == :define_method ? "self." : "torch::"
 
+        body = "#{prefix}#{dispatch}(#{args.join(", ")})"
+        # TODO check type as well
+        if func.ret_size > 1
+          body = "wrap(#{body})"
+        end
+
        cpp_defs << ".#{def_method}(
   \"#{func.cpp_name}\",
   *[](#{cpp_args.join(", ")}) {
-    return #{prefix}#{dispatch}(#{args.join(", ")});
+    return #{body};
   })"
      end
 
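ret_size (added above in function.rb) drives the new wrap call: a function declared with more than one return value gets its result wrapped so it arrives in Ruby as an array rather than a bare C++ tuple. A small sketch of the counting, using a declaration string in the native_functions.yaml style (the exact string is illustrative):

    func = "max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"
    ret_size = func.split("->").last.split(", ").size  # => 2
    # ret_size > 1, so the generated body becomes wrap(torch::max(...)) instead of a bare call.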
data/lib/torch/native/parser.rb CHANGED
@@ -4,19 +4,44 @@ module Torch
       def initialize(functions)
         @functions = functions
         @name = @functions.first.ruby_name
-        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
+        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
         @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
       end
 
       def parse(args, options)
         candidates = @functions.dup
 
+        # remove nil
+        while args.any? && args.last.nil?
+          args.pop
+        end
+
+        # TODO account for args passed as options here
         if args.size < @min_args || args.size > @max_args
           expected = String.new(@min_args.to_s)
           expected += "..#{@max_args}" if @max_args != @min_args
           return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
         end
 
+        candidates.reject! { |f| args.size > f.args.size }
+
+        # exclude functions missing required options
+        candidates.reject! do |func|
+          # TODO make more generic
+          func.out? && !options[:out]
+        end
+
+        # handle out with multiple
+        # there should only be one match, so safe to modify all
+        out_func = candidates.find { |f| f.out? }
+        if out_func && out_func.out_size > 1 && options[:out]
+          out_args = out_func.args.last(2).map { |a| a[:name] }
+          out_args.zip(options.delete(:out)).each do |k, v|
+            options[k.to_sym] = v
+          end
+          candidates = [out_func]
+        end
+
         # exclude functions where options don't match
         options.each do |k, v|
           candidates.select! do |func|
@@ -26,12 +51,6 @@ module Torch
           return {error: "unknown keyword: #{k}"} if candidates.empty?
         end
 
-        # exclude functions missing required options
-        candidates.reject! do |func|
-          # TODO make more generic
-          func.out? && !options[:out]
-        end
-
         final_values = {}
 
         # check args
@@ -41,29 +60,47 @@ module Torch
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
-            values[fa[:name]] ||= fa[:default]
+            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end
 
          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
 
-          values.each do |k, v|
+          values.each_key do |k|
+            v = values[k]
            t = arg_types[k].split("(").first
+
            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
+              when "Tensor?"
+                v.nil? || v.is_a?(Tensor)
              when "Tensor[]"
-                v.all? { |v2| v2.is_a?(Tensor) }
+                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
-                v.is_a?(Integer)
-              when "int[]"
-                v.all? { |v2| v2.is_a?(Integer) }
+                if k == "reduction"
+                  v.is_a?(String)
+                else
+                  v.is_a?(Integer)
+                end
+              when "float"
+                v.is_a?(Numeric)
+              when /int\[.*\]/
+                if v.is_a?(Integer)
+                  size = t[4..-2]
+                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
+                  v = [v] * size.to_i
+                  values[k] = v
+                end
+                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
+              when "ScalarType?"
+                v.nil?
              when "bool"
                v == true || v == false
              else
-                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
+                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end
 
            if !good
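The new int[...] branch lets a bare integer stand in for a fixed-size list: the scalar is repeated to the declared size before the type check, which is what allows passing a single number where a declaration expects int[2] or int[3]. A minimal sketch of that expansion (written with a regex rather than the slice used above; values are illustrative):

    t = "int[2]"   # declared type
    v = 3          # value supplied by the caller

    if v.is_a?(Integer) && t =~ /\Aint\[(\d+)\]\z/
      v = [v] * $1.to_i
    end
    v  # => [3, 3]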
data/lib/torch/nn/avg_pool1d.rb ADDED
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class AvgPool1d < AvgPoolNd
+      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
+        super()
+        @kernel_size = _single(kernel_size)
+        @stride = _single(stride || kernel_size)
+        @padding = _single(padding)
+        @ceil_mode = ceil_mode
+        @count_include_pad = count_include_pad
+      end
+
+      def forward(input)
+        F.avg_pool1d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad)
+      end
+    end
+  end
+end
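A hedged usage sketch of the new layer; the input shape follows the PyTorch convention of (batch, channels, length), and the module is assumed to be invoked through Module#call, which runs forward:

    require "torch"

    pool = Torch::NN::AvgPool1d.new(3, stride: 2)
    x = Torch.randn([2, 4, 10])   # (batch, channels, length) — illustrative shape
    y = pool.call(x)              # call is expected to invoke forward
    y.shape                       # expected to be [2, 4, 4] with kernel 3, stride 2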
data/lib/torch/nn/avg_pool2d.rb CHANGED
@@ -1,13 +1,18 @@
 module Torch
   module NN
     class AvgPool2d < AvgPoolNd
-      def initialize(kernel_size)
+      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
         super()
         @kernel_size = kernel_size
+        @stride = stride || kernel_size
+        @padding = padding
+        @ceil_mode = ceil_mode
+        @count_include_pad = count_include_pad
+        @divisor_override = divisor_override
       end
 
       def forward(input)
-        F.avg_pool2d(input, @kernel_size)
+        F.avg_pool2d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
       end
     end
   end
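The 2d layer now exposes the same keyword surface instead of kernel_size alone. An assumed usage sketch, with an illustrative input shape:

    require "torch"

    pool = Torch::NN::AvgPool2d.new(2, stride: 2, count_include_pad: false)
    x = Torch.randn([1, 3, 8, 8])  # (batch, channels, height, width) — illustrative
    y = pool.call(x)               # expected output shape [1, 3, 4, 4]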
data/lib/torch/nn/avg_pool3d.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class AvgPool3d < AvgPoolNd
+      def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
+        super()
+        @kernel_size = kernel_size
+        @stride = stride || kernel_size
+        @padding = padding
+        @ceil_mode = ceil_mode
+        @count_include_pad = count_include_pad
+        @divisor_override = divisor_override
+      end
+
+      def forward(input)
+        F.avg_pool3d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
+      end
+    end
+  end
+end