torch-rb 0.3.7 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -8,6 +8,18 @@ module Torch
8
8
  alias_method :ndim, :dim
9
9
  alias_method :ndimension, :dim
10
10
 
11
+ # use alias_method for performance
12
+ alias_method :+, :add
13
+ alias_method :-, :sub
14
+ alias_method :*, :mul
15
+ alias_method :/, :div
16
+ alias_method :%, :remainder
17
+ alias_method :**, :pow
18
+ alias_method :-@, :neg
19
+ alias_method :&, :logical_and
20
+ alias_method :|, :logical_or
21
+ alias_method :^, :logical_xor
22
+
11
23
  def self.new(*args)
12
24
  FloatTensor.new(*args)
13
25
  end
@@ -73,12 +85,20 @@ module Torch
73
85
 
74
86
  def size(dim = nil)
75
87
  if dim
76
- _size_int(dim)
88
+ _size(dim)
77
89
  else
78
90
  shape
79
91
  end
80
92
  end
81
93
 
94
+ def stride(dim = nil)
95
+ if dim
96
+ _stride(dim)
97
+ else
98
+ _strides
99
+ end
100
+ end
101
+
82
102
  # mirror Python len()
83
103
  def length
84
104
  size(0)
@@ -130,57 +150,6 @@ module Torch
130
150
  end
131
151
  end
132
152
 
133
- def reshape(*size)
134
- # Python doesn't check if size == 1, just ignores later arguments
135
- size = size.first if size.size == 1 && size.first.is_a?(Array)
136
- _reshape(size)
137
- end
138
-
139
- def view(*size)
140
- size = size.first if size.size == 1 && size.first.is_a?(Array)
141
- _view(size)
142
- end
143
-
144
- def +(other)
145
- add(other)
146
- end
147
-
148
- def -(other)
149
- sub(other)
150
- end
151
-
152
- def *(other)
153
- mul(other)
154
- end
155
-
156
- def /(other)
157
- div(other)
158
- end
159
-
160
- def %(other)
161
- remainder(other)
162
- end
163
-
164
- def **(other)
165
- pow(other)
166
- end
167
-
168
- def -@
169
- neg
170
- end
171
-
172
- def &(other)
173
- logical_and(other)
174
- end
175
-
176
- def |(other)
177
- logical_or(other)
178
- end
179
-
180
- def ^(other)
181
- logical_xor(other)
182
- end
183
-
184
153
  # TODO better compare?
185
154
  def <=>(other)
186
155
  item <=> other
@@ -189,7 +158,7 @@ module Torch
189
158
  # based on python_variable_indexing.cpp and
190
159
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
191
160
  def [](*indexes)
192
- _index(tensor_indexes(indexes))
161
+ _index(indexes)
193
162
  end
194
163
 
195
164
  # based on python_variable_indexing.cpp and
@@ -197,62 +166,13 @@ module Torch
197
166
  def []=(*indexes, value)
198
167
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
199
168
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
200
- _index_put_custom(tensor_indexes(indexes), value)
201
- end
202
-
203
- # native functions that need to be manually defined
204
-
205
- # value and other are swapped for some methods
206
- def add!(value = 1, other)
207
- if other.is_a?(Numeric)
208
- _add__scalar(other, value)
209
- else
210
- _add__tensor(other, value)
211
- end
169
+ _index_put_custom(indexes, value)
212
170
  end
213
171
 
214
172
  # parser can't handle overlap, so need to handle manually
215
173
  def random!(*args)
216
- case args.size
217
- when 1
218
- _random__to(*args)
219
- when 2
220
- _random__from(*args)
221
- else
222
- _random_(*args)
223
- end
224
- end
225
-
226
- def clamp!(min, max)
227
- _clamp_min_(min)
228
- _clamp_max_(max)
229
- end
230
-
231
- private
232
-
233
- def tensor_indexes(indexes)
234
- indexes.map do |index|
235
- case index
236
- when Integer
237
- TensorIndex.integer(index)
238
- when Range
239
- finish = index.end || -1
240
- if finish == -1 && !index.exclude_end?
241
- finish = nil
242
- else
243
- finish += 1 unless index.exclude_end?
244
- end
245
- TensorIndex.slice(index.begin, finish)
246
- when Tensor
247
- TensorIndex.tensor(index)
248
- when nil
249
- TensorIndex.none
250
- when true, false
251
- TensorIndex.boolean(index)
252
- else
253
- raise Error, "Unsupported index type: #{index.class.name}"
254
- end
255
- end
174
+ return _random!(0, *args) if args.size == 1
175
+ _random!(*args)
256
176
  end
257
177
  end
258
178
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.7"
2
+ VERSION = "0.4.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.7
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-09-23 00:00:00.000000000 Z
11
+ date: 2020-09-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,6 +108,20 @@ dependencies:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: 0.1.1
111
+ - !ruby/object:Gem::Dependency
112
+ name: magro
113
+ requirement: !ruby/object:Gem::Requirement
114
+ requirements:
115
+ - - ">="
116
+ - !ruby/object:Gem::Version
117
+ version: '0'
118
+ type: :development
119
+ prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
111
125
  description:
112
126
  email: andrew@chartkick.com
113
127
  executables: []
@@ -118,19 +132,23 @@ files:
118
132
  - CHANGELOG.md
119
133
  - LICENSE.txt
120
134
  - README.md
135
+ - codegen/function.rb
136
+ - codegen/generate_functions.rb
137
+ - codegen/native_functions.yaml
121
138
  - ext/torch/ext.cpp
122
139
  - ext/torch/extconf.rb
123
- - ext/torch/templates.cpp
124
- - ext/torch/templates.hpp
140
+ - ext/torch/nn_functions.h
141
+ - ext/torch/ruby_arg_parser.cpp
142
+ - ext/torch/ruby_arg_parser.h
143
+ - ext/torch/templates.h
144
+ - ext/torch/tensor_functions.h
145
+ - ext/torch/torch_functions.h
146
+ - ext/torch/utils.h
147
+ - ext/torch/wrap_outputs.h
125
148
  - lib/torch-rb.rb
126
149
  - lib/torch.rb
127
150
  - lib/torch/hub.rb
128
151
  - lib/torch/inspector.rb
129
- - lib/torch/native/dispatcher.rb
130
- - lib/torch/native/function.rb
131
- - lib/torch/native/generator.rb
132
- - lib/torch/native/native_functions.yaml
133
- - lib/torch/native/parser.rb
134
152
  - lib/torch/nn/adaptive_avg_pool1d.rb
135
153
  - lib/torch/nn/adaptive_avg_pool2d.rb
136
154
  - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -1,70 +0,0 @@
1
- # We use a generic interface for methods (*args, **options)
2
- # and this class to determine the C++ method to call
3
- #
4
- # This is needed since LibTorch uses function overloading,
5
- # which isn't available in Ruby or Python
6
- #
7
- # PyTorch uses this approach, but the parser/dispatcher is written in C++
8
- #
9
- # We could generate Ruby methods directly, but an advantage of this approach is
10
- # arguments and keyword arguments can be used interchangeably like in Python,
11
- # making it easier to port code
12
-
13
- module Torch
14
- module Native
15
- module Dispatcher
16
- class << self
17
- def bind
18
- functions = Generator.grouped_functions
19
- bind_functions(::Torch, :define_singleton_method, functions[:torch])
20
- bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
21
- bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
22
- end
23
-
24
- def bind_functions(context, def_method, functions)
25
- instance_method = def_method == :define_method
26
- functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
27
- if instance_method
28
- funcs.map! { |f| Function.new(f.function) }
29
- funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
30
- end
31
-
32
- defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
33
- next if defined && name != "clone"
34
-
35
- # skip parser when possible for performance
36
- if funcs.size == 1 && funcs.first.args.size == 0
37
- # functions with no arguments
38
- if instance_method
39
- context.send(:alias_method, name, funcs.first.cpp_name)
40
- else
41
- context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
- end
43
- elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
- # functions that take a tensor or scalar
45
- scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
- context.send(def_method, name) do |other|
47
- case other
48
- when Tensor
49
- send(tensor_name, other)
50
- else
51
- send(scalar_name, other)
52
- end
53
- end
54
- else
55
- parser = Parser.new(funcs)
56
-
57
- context.send(def_method, name) do |*args, **options|
58
- result = parser.parse(args, options)
59
- raise ArgumentError, result[:error] if result[:error]
60
- send(result[:name], *result[:args])
61
- end
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- end
69
-
70
- Torch::Native::Dispatcher.bind
@@ -1,200 +0,0 @@
1
- module Torch
2
- module Native
3
- class Function
4
- attr_reader :function, :tensor_options
5
-
6
- def initialize(function)
7
- @function = function
8
-
9
- # note: don't modify function in-place
10
- @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
11
- @tensor_options = @function["func"].include?(@tensor_options_str)
12
- @out = out_size > 0 && base_name[-1] != "_"
13
- end
14
-
15
- def func
16
- @func ||= @function["func"]
17
- end
18
-
19
- def name
20
- @name ||= func.split("(", 2).first
21
- end
22
-
23
- def python_module
24
- @python_module ||= @function["python_module"]
25
- end
26
-
27
- def variants
28
- @variants ||= (@function["variants"] || "function").split(", ")
29
- end
30
-
31
- def args
32
- @args ||= begin
33
- args = []
34
- pos = true
35
- args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
36
- args_str.split(", ").each do |a|
37
- if a == "*"
38
- pos = false
39
- next
40
- end
41
- t, _, k = a.rpartition(" ")
42
- k, d = k.split("=")
43
- has_default = !d.nil?
44
-
45
- if d
46
- d =
47
- case d
48
- when "True"
49
- true
50
- when "False"
51
- false
52
- when "None"
53
- nil
54
- when /\A\-?\d+\z/
55
- d.to_i
56
- when "[]"
57
- []
58
- when "[0,1]"
59
- [0, 1]
60
- when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
61
- d.to_f
62
- when "Mean"
63
- "mean"
64
- when "contiguous_format"
65
- d
66
- when "long"
67
- :long
68
- else
69
- raise "Unknown default: #{d}"
70
- end
71
- end
72
-
73
- next if t == "Generator?"
74
- next if t == "MemoryFormat"
75
- next if t == "MemoryFormat?"
76
- args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
77
- end
78
- args
79
- end
80
- end
81
-
82
- def arg_checkers
83
- @arg_checkers ||= begin
84
- checkers = {}
85
- arg_types.each do |k, t|
86
- checker =
87
- case t
88
- when "Tensor"
89
- ->(v) { v.is_a?(Tensor) }
90
- when "Tensor?"
91
- ->(v) { v.nil? || v.is_a?(Tensor) }
92
- when "Tensor[]", "Tensor?[]"
93
- ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
94
- when "int"
95
- if k == :reduction
96
- ->(v) { v.is_a?(String) }
97
- else
98
- ->(v) { v.is_a?(Integer) }
99
- end
100
- when "int?"
101
- ->(v) { v.is_a?(Integer) || v.nil? }
102
- when "float?"
103
- ->(v) { v.is_a?(Numeric) || v.nil? }
104
- when "bool?"
105
- ->(v) { v == true || v == false || v.nil? }
106
- when "float"
107
- ->(v) { v.is_a?(Numeric) }
108
- when /int\[.*\]/
109
- ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
110
- when "Scalar"
111
- ->(v) { v.is_a?(Numeric) }
112
- when "Scalar?"
113
- ->(v) { v.is_a?(Numeric) || v.nil? }
114
- when "ScalarType"
115
- ->(v) { false } # not supported yet
116
- when "ScalarType?"
117
- ->(v) { v.nil? }
118
- when "bool"
119
- ->(v) { v == true || v == false }
120
- when "str"
121
- ->(v) { v.is_a?(String) }
122
- else
123
- raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
124
- end
125
- checkers[k] = checker
126
- end
127
- checkers
128
- end
129
- end
130
-
131
- def int_array_lengths
132
- @int_array_lengths ||= begin
133
- ret = {}
134
- arg_types.each do |k, t|
135
- if t.match?(/\Aint\[.+\]\z/)
136
- size = t[4..-2]
137
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
138
- ret[k] = size.to_i
139
- end
140
- end
141
- ret
142
- end
143
- end
144
-
145
- def arg_names
146
- @arg_names ||= args.map { |a| a[:name] }
147
- end
148
-
149
- def arg_types
150
- @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
151
- end
152
-
153
- def arg_defaults
154
- # TODO find out why can't use select here
155
- @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
156
- end
157
-
158
- def out_size
159
- @out_size ||= func.split("->").last.count("!")
160
- end
161
-
162
- def ret_size
163
- @ret_size ||= func.split("->").last.split(", ").size
164
- end
165
-
166
- def ret_array?
167
- @ret_array ||= func.split("->").last.include?('[]')
168
- end
169
-
170
- def ret_void?
171
- func.split("->").last.strip == "()"
172
- end
173
-
174
- def out?
175
- @out
176
- end
177
-
178
- def ruby_name
179
- @ruby_name ||= begin
180
- name = base_name
181
- if name.end_with?("_")
182
- "#{name[0..-2]}!"
183
- elsif name.start_with?("is_")
184
- "#{name[3..-1]}?"
185
- else
186
- name
187
- end
188
- end
189
- end
190
-
191
- def cpp_name
192
- @cpp_name ||= "_" + name.downcase.sub(".", "_")
193
- end
194
-
195
- def base_name
196
- @base_name ||= name.split(".").first
197
- end
198
- end
199
- end
200
- end