torch-rb 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,18 @@ module Torch
8
8
  alias_method :ndim, :dim
9
9
  alias_method :ndimension, :dim
10
10
 
11
+ # use alias_method for performance
12
+ alias_method :+, :add
13
+ alias_method :-, :sub
14
+ alias_method :*, :mul
15
+ alias_method :/, :div
16
+ alias_method :%, :remainder
17
+ alias_method :**, :pow
18
+ alias_method :-@, :neg
19
+ alias_method :&, :logical_and
20
+ alias_method :|, :logical_or
21
+ alias_method :^, :logical_xor
22
+
11
23
  def self.new(*args)
12
24
  FloatTensor.new(*args)
13
25
  end
@@ -73,12 +85,20 @@ module Torch
73
85
 
74
86
  def size(dim = nil)
75
87
  if dim
76
- _size_int(dim)
88
+ _size(dim)
77
89
  else
78
90
  shape
79
91
  end
80
92
  end
81
93
 
94
+ def stride(dim = nil)
95
+ if dim
96
+ _stride(dim)
97
+ else
98
+ _strides
99
+ end
100
+ end
101
+
82
102
  # mirror Python len()
83
103
  def length
84
104
  size(0)
@@ -130,57 +150,6 @@ module Torch
130
150
  end
131
151
  end
132
152
 
133
- def reshape(*size)
134
- # Python doesn't check if size == 1, just ignores later arguments
135
- size = size.first if size.size == 1 && size.first.is_a?(Array)
136
- _reshape(size)
137
- end
138
-
139
- def view(*size)
140
- size = size.first if size.size == 1 && size.first.is_a?(Array)
141
- _view(size)
142
- end
143
-
144
- def +(other)
145
- add(other)
146
- end
147
-
148
- def -(other)
149
- sub(other)
150
- end
151
-
152
- def *(other)
153
- mul(other)
154
- end
155
-
156
- def /(other)
157
- div(other)
158
- end
159
-
160
- def %(other)
161
- remainder(other)
162
- end
163
-
164
- def **(other)
165
- pow(other)
166
- end
167
-
168
- def -@
169
- neg
170
- end
171
-
172
- def &(other)
173
- logical_and(other)
174
- end
175
-
176
- def |(other)
177
- logical_or(other)
178
- end
179
-
180
- def ^(other)
181
- logical_xor(other)
182
- end
183
-
184
153
  # TODO better compare?
185
154
  def <=>(other)
186
155
  item <=> other
@@ -189,7 +158,7 @@ module Torch
189
158
  # based on python_variable_indexing.cpp and
190
159
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
191
160
  def [](*indexes)
192
- _index(tensor_indexes(indexes))
161
+ _index(indexes)
193
162
  end
194
163
 
195
164
  # based on python_variable_indexing.cpp and
@@ -197,62 +166,13 @@ module Torch
197
166
  def []=(*indexes, value)
198
167
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
199
168
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
200
- _index_put_custom(tensor_indexes(indexes), value)
201
- end
202
-
203
- # native functions that need manually defined
204
-
205
- # value and other are swapped for some methods
206
- def add!(value = 1, other)
207
- if other.is_a?(Numeric)
208
- _add__scalar(other, value)
209
- else
210
- _add__tensor(other, value)
211
- end
169
+ _index_put_custom(indexes, value)
212
170
  end
213
171
 
214
172
  # parser can't handle overlap, so need to handle manually
215
173
  def random!(*args)
216
- case args.size
217
- when 1
218
- _random__to(*args)
219
- when 2
220
- _random__from(*args)
221
- else
222
- _random_(*args)
223
- end
224
- end
225
-
226
- def clamp!(min, max)
227
- _clamp_min_(min)
228
- _clamp_max_(max)
229
- end
230
-
231
- private
232
-
233
- def tensor_indexes(indexes)
234
- indexes.map do |index|
235
- case index
236
- when Integer
237
- TensorIndex.integer(index)
238
- when Range
239
- finish = index.end || -1
240
- if finish == -1 && !index.exclude_end?
241
- finish = nil
242
- else
243
- finish += 1 unless index.exclude_end?
244
- end
245
- TensorIndex.slice(index.begin, finish)
246
- when Tensor
247
- TensorIndex.tensor(index)
248
- when nil
249
- TensorIndex.none
250
- when true, false
251
- TensorIndex.boolean(index)
252
- else
253
- raise Error, "Unsupported index type: #{index.class.name}"
254
- end
255
- end
174
+ return _random!(0, *args) if args.size == 1
175
+ _random!(*args)
256
176
  end
257
177
  end
258
178
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.7"
2
+ VERSION = "0.4.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.7
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-09-23 00:00:00.000000000 Z
11
+ date: 2020-09-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,6 +108,20 @@ dependencies:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: 0.1.1
111
+ - !ruby/object:Gem::Dependency
112
+ name: magro
113
+ requirement: !ruby/object:Gem::Requirement
114
+ requirements:
115
+ - - ">="
116
+ - !ruby/object:Gem::Version
117
+ version: '0'
118
+ type: :development
119
+ prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
111
125
  description:
112
126
  email: andrew@chartkick.com
113
127
  executables: []
@@ -118,19 +132,23 @@ files:
118
132
  - CHANGELOG.md
119
133
  - LICENSE.txt
120
134
  - README.md
135
+ - codegen/function.rb
136
+ - codegen/generate_functions.rb
137
+ - codegen/native_functions.yaml
121
138
  - ext/torch/ext.cpp
122
139
  - ext/torch/extconf.rb
123
- - ext/torch/templates.cpp
124
- - ext/torch/templates.hpp
140
+ - ext/torch/nn_functions.h
141
+ - ext/torch/ruby_arg_parser.cpp
142
+ - ext/torch/ruby_arg_parser.h
143
+ - ext/torch/templates.h
144
+ - ext/torch/tensor_functions.h
145
+ - ext/torch/torch_functions.h
146
+ - ext/torch/utils.h
147
+ - ext/torch/wrap_outputs.h
125
148
  - lib/torch-rb.rb
126
149
  - lib/torch.rb
127
150
  - lib/torch/hub.rb
128
151
  - lib/torch/inspector.rb
129
- - lib/torch/native/dispatcher.rb
130
- - lib/torch/native/function.rb
131
- - lib/torch/native/generator.rb
132
- - lib/torch/native/native_functions.yaml
133
- - lib/torch/native/parser.rb
134
152
  - lib/torch/nn/adaptive_avg_pool1d.rb
135
153
  - lib/torch/nn/adaptive_avg_pool2d.rb
136
154
  - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -1,70 +0,0 @@
1
- # We use a generic interface for methods (*args, **options)
2
- # and this class to determine the C++ method to call
3
- #
4
- # This is needed since LibTorch uses function overloading,
5
- # which isn't available in Ruby or Python
6
- #
7
- # PyTorch uses this approach, but the parser/dispatcher is written in C++
8
- #
9
- # We could generate Ruby methods directly, but an advantage of this approach is
10
- # arguments and keyword arguments can be used interchangably like in Python,
11
- # making it easier to port code
12
-
13
- module Torch
14
- module Native
15
- module Dispatcher
16
- class << self
17
- def bind
18
- functions = Generator.grouped_functions
19
- bind_functions(::Torch, :define_singleton_method, functions[:torch])
20
- bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
21
- bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
22
- end
23
-
24
- def bind_functions(context, def_method, functions)
25
- instance_method = def_method == :define_method
26
- functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
27
- if instance_method
28
- funcs.map! { |f| Function.new(f.function) }
29
- funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
30
- end
31
-
32
- defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
33
- next if defined && name != "clone"
34
-
35
- # skip parser when possible for performance
36
- if funcs.size == 1 && funcs.first.args.size == 0
37
- # functions with no arguments
38
- if instance_method
39
- context.send(:alias_method, name, funcs.first.cpp_name)
40
- else
41
- context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
- end
43
- elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
- # functions that take a tensor or scalar
45
- scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
- context.send(def_method, name) do |other|
47
- case other
48
- when Tensor
49
- send(tensor_name, other)
50
- else
51
- send(scalar_name, other)
52
- end
53
- end
54
- else
55
- parser = Parser.new(funcs)
56
-
57
- context.send(def_method, name) do |*args, **options|
58
- result = parser.parse(args, options)
59
- raise ArgumentError, result[:error] if result[:error]
60
- send(result[:name], *result[:args])
61
- end
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- end
69
-
70
- Torch::Native::Dispatcher.bind
@@ -1,200 +0,0 @@
1
- module Torch
2
- module Native
3
- class Function
4
- attr_reader :function, :tensor_options
5
-
6
- def initialize(function)
7
- @function = function
8
-
9
- # note: don't modify function in-place
10
- @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
11
- @tensor_options = @function["func"].include?(@tensor_options_str)
12
- @out = out_size > 0 && base_name[-1] != "_"
13
- end
14
-
15
- def func
16
- @func ||= @function["func"]
17
- end
18
-
19
- def name
20
- @name ||= func.split("(", 2).first
21
- end
22
-
23
- def python_module
24
- @python_module ||= @function["python_module"]
25
- end
26
-
27
- def variants
28
- @variants ||= (@function["variants"] || "function").split(", ")
29
- end
30
-
31
- def args
32
- @args ||= begin
33
- args = []
34
- pos = true
35
- args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
36
- args_str.split(", ").each do |a|
37
- if a == "*"
38
- pos = false
39
- next
40
- end
41
- t, _, k = a.rpartition(" ")
42
- k, d = k.split("=")
43
- has_default = !d.nil?
44
-
45
- if d
46
- d =
47
- case d
48
- when "True"
49
- true
50
- when "False"
51
- false
52
- when "None"
53
- nil
54
- when /\A\-?\d+\z/
55
- d.to_i
56
- when "[]"
57
- []
58
- when "[0,1]"
59
- [0, 1]
60
- when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
61
- d.to_f
62
- when "Mean"
63
- "mean"
64
- when "contiguous_format"
65
- d
66
- when "long"
67
- :long
68
- else
69
- raise "Unknown default: #{d}"
70
- end
71
- end
72
-
73
- next if t == "Generator?"
74
- next if t == "MemoryFormat"
75
- next if t == "MemoryFormat?"
76
- args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
77
- end
78
- args
79
- end
80
- end
81
-
82
- def arg_checkers
83
- @arg_checkers ||= begin
84
- checkers = {}
85
- arg_types.each do |k, t|
86
- checker =
87
- case t
88
- when "Tensor"
89
- ->(v) { v.is_a?(Tensor) }
90
- when "Tensor?"
91
- ->(v) { v.nil? || v.is_a?(Tensor) }
92
- when "Tensor[]", "Tensor?[]"
93
- ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
94
- when "int"
95
- if k == :reduction
96
- ->(v) { v.is_a?(String) }
97
- else
98
- ->(v) { v.is_a?(Integer) }
99
- end
100
- when "int?"
101
- ->(v) { v.is_a?(Integer) || v.nil? }
102
- when "float?"
103
- ->(v) { v.is_a?(Numeric) || v.nil? }
104
- when "bool?"
105
- ->(v) { v == true || v == false || v.nil? }
106
- when "float"
107
- ->(v) { v.is_a?(Numeric) }
108
- when /int\[.*\]/
109
- ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
110
- when "Scalar"
111
- ->(v) { v.is_a?(Numeric) }
112
- when "Scalar?"
113
- ->(v) { v.is_a?(Numeric) || v.nil? }
114
- when "ScalarType"
115
- ->(v) { false } # not supported yet
116
- when "ScalarType?"
117
- ->(v) { v.nil? }
118
- when "bool"
119
- ->(v) { v == true || v == false }
120
- when "str"
121
- ->(v) { v.is_a?(String) }
122
- else
123
- raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
124
- end
125
- checkers[k] = checker
126
- end
127
- checkers
128
- end
129
- end
130
-
131
- def int_array_lengths
132
- @int_array_lengths ||= begin
133
- ret = {}
134
- arg_types.each do |k, t|
135
- if t.match?(/\Aint\[.+\]\z/)
136
- size = t[4..-2]
137
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
138
- ret[k] = size.to_i
139
- end
140
- end
141
- ret
142
- end
143
- end
144
-
145
- def arg_names
146
- @arg_names ||= args.map { |a| a[:name] }
147
- end
148
-
149
- def arg_types
150
- @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
151
- end
152
-
153
- def arg_defaults
154
- # TODO find out why can't use select here
155
- @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
156
- end
157
-
158
- def out_size
159
- @out_size ||= func.split("->").last.count("!")
160
- end
161
-
162
- def ret_size
163
- @ret_size ||= func.split("->").last.split(", ").size
164
- end
165
-
166
- def ret_array?
167
- @ret_array ||= func.split("->").last.include?('[]')
168
- end
169
-
170
- def ret_void?
171
- func.split("->").last.strip == "()"
172
- end
173
-
174
- def out?
175
- @out
176
- end
177
-
178
- def ruby_name
179
- @ruby_name ||= begin
180
- name = base_name
181
- if name.end_with?("_")
182
- "#{name[0..-2]}!"
183
- elsif name.start_with?("is_")
184
- "#{name[3..-1]}?"
185
- else
186
- name
187
- end
188
- end
189
- end
190
-
191
- def cpp_name
192
- @cpp_name ||= "_" + name.downcase.sub(".", "_")
193
- end
194
-
195
- def base_name
196
- @base_name ||= name.split(".").first
197
- end
198
- end
199
- end
200
- end