torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,18 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# use alias_method for performance
|
12
|
+
alias_method :+, :add
|
13
|
+
alias_method :-, :sub
|
14
|
+
alias_method :*, :mul
|
15
|
+
alias_method :/, :div
|
16
|
+
alias_method :%, :remainder
|
17
|
+
alias_method :**, :pow
|
18
|
+
alias_method :-@, :neg
|
19
|
+
alias_method :&, :logical_and
|
20
|
+
alias_method :|, :logical_or
|
21
|
+
alias_method :^, :logical_xor
|
22
|
+
|
11
23
|
def self.new(*args)
|
12
24
|
FloatTensor.new(*args)
|
13
25
|
end
|
@@ -73,12 +85,20 @@ module Torch
|
|
73
85
|
|
74
86
|
def size(dim = nil)
|
75
87
|
if dim
|
76
|
-
|
88
|
+
_size(dim)
|
77
89
|
else
|
78
90
|
shape
|
79
91
|
end
|
80
92
|
end
|
81
93
|
|
94
|
+
def stride(dim = nil)
|
95
|
+
if dim
|
96
|
+
_stride(dim)
|
97
|
+
else
|
98
|
+
_strides
|
99
|
+
end
|
100
|
+
end
|
101
|
+
|
82
102
|
# mirror Python len()
|
83
103
|
def length
|
84
104
|
size(0)
|
@@ -130,57 +150,6 @@ module Torch
|
|
130
150
|
end
|
131
151
|
end
|
132
152
|
|
133
|
-
def reshape(*size)
|
134
|
-
# Python doesn't check if size == 1, just ignores later arguments
|
135
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
136
|
-
_reshape(size)
|
137
|
-
end
|
138
|
-
|
139
|
-
def view(*size)
|
140
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
141
|
-
_view(size)
|
142
|
-
end
|
143
|
-
|
144
|
-
def +(other)
|
145
|
-
add(other)
|
146
|
-
end
|
147
|
-
|
148
|
-
def -(other)
|
149
|
-
sub(other)
|
150
|
-
end
|
151
|
-
|
152
|
-
def *(other)
|
153
|
-
mul(other)
|
154
|
-
end
|
155
|
-
|
156
|
-
def /(other)
|
157
|
-
div(other)
|
158
|
-
end
|
159
|
-
|
160
|
-
def %(other)
|
161
|
-
remainder(other)
|
162
|
-
end
|
163
|
-
|
164
|
-
def **(other)
|
165
|
-
pow(other)
|
166
|
-
end
|
167
|
-
|
168
|
-
def -@
|
169
|
-
neg
|
170
|
-
end
|
171
|
-
|
172
|
-
def &(other)
|
173
|
-
logical_and(other)
|
174
|
-
end
|
175
|
-
|
176
|
-
def |(other)
|
177
|
-
logical_or(other)
|
178
|
-
end
|
179
|
-
|
180
|
-
def ^(other)
|
181
|
-
logical_xor(other)
|
182
|
-
end
|
183
|
-
|
184
153
|
# TODO better compare?
|
185
154
|
def <=>(other)
|
186
155
|
item <=> other
|
@@ -189,7 +158,7 @@ module Torch
|
|
189
158
|
# based on python_variable_indexing.cpp and
|
190
159
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
191
160
|
def [](*indexes)
|
192
|
-
_index(tensor_indexes(indexes))
|
161
|
+
_index(indexes)
|
193
162
|
end
|
194
163
|
|
195
164
|
# based on python_variable_indexing.cpp and
|
@@ -197,62 +166,13 @@ module Torch
|
|
197
166
|
def []=(*indexes, value)
|
198
167
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
199
168
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
200
|
-
_index_put_custom(tensor_indexes(indexes), value)
|
201
|
-
end
|
202
|
-
|
203
|
-
# native functions that need manually defined
|
204
|
-
|
205
|
-
# value and other are swapped for some methods
|
206
|
-
def add!(value = 1, other)
|
207
|
-
if other.is_a?(Numeric)
|
208
|
-
_add__scalar(other, value)
|
209
|
-
else
|
210
|
-
_add__tensor(other, value)
|
211
|
-
end
|
169
|
+
_index_put_custom(indexes, value)
|
212
170
|
end
|
213
171
|
|
214
172
|
# parser can't handle overlap, so need to handle manually
|
215
173
|
def random!(*args)
|
216
|
-
case args.size
|
217
|
-
when 1
|
218
|
-
_random__to(*args)
|
219
|
-
when 2
|
220
|
-
_random__from(*args)
|
221
|
-
else
|
222
|
-
_random_(*args)
|
223
|
-
end
|
224
|
-
end
|
225
|
-
|
226
|
-
def clamp!(min, max)
|
227
|
-
_clamp_min_(min)
|
228
|
-
_clamp_max_(max)
|
229
|
-
end
|
230
|
-
|
231
|
-
private
|
232
|
-
|
233
|
-
def tensor_indexes(indexes)
|
234
|
-
indexes.map do |index|
|
235
|
-
case index
|
236
|
-
when Integer
|
237
|
-
TensorIndex.integer(index)
|
238
|
-
when Range
|
239
|
-
finish = index.end || -1
|
240
|
-
if finish == -1 && !index.exclude_end?
|
241
|
-
finish = nil
|
242
|
-
else
|
243
|
-
finish += 1 unless index.exclude_end?
|
244
|
-
end
|
245
|
-
TensorIndex.slice(index.begin, finish)
|
246
|
-
when Tensor
|
247
|
-
TensorIndex.tensor(index)
|
248
|
-
when nil
|
249
|
-
TensorIndex.none
|
250
|
-
when true, false
|
251
|
-
TensorIndex.boolean(index)
|
252
|
-
else
|
253
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
254
|
-
end
|
255
|
-
end
|
174
|
+
return _random!(0, *args) if args.size == 1
|
175
|
+
_random!(*args)
|
256
176
|
end
|
257
177
|
end
|
258
178
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.7
|
4
|
+
version: 0.4.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-09-
|
11
|
+
date: 2020-09-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,6 +108,20 @@ dependencies:
|
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
110
|
version: 0.1.1
|
111
|
+
- !ruby/object:Gem::Dependency
|
112
|
+
name: magro
|
113
|
+
requirement: !ruby/object:Gem::Requirement
|
114
|
+
requirements:
|
115
|
+
- - ">="
|
116
|
+
- !ruby/object:Gem::Version
|
117
|
+
version: '0'
|
118
|
+
type: :development
|
119
|
+
prerelease: false
|
120
|
+
version_requirements: !ruby/object:Gem::Requirement
|
121
|
+
requirements:
|
122
|
+
- - ">="
|
123
|
+
- !ruby/object:Gem::Version
|
124
|
+
version: '0'
|
111
125
|
description:
|
112
126
|
email: andrew@chartkick.com
|
113
127
|
executables: []
|
@@ -118,19 +132,23 @@ files:
|
|
118
132
|
- CHANGELOG.md
|
119
133
|
- LICENSE.txt
|
120
134
|
- README.md
|
135
|
+
- codegen/function.rb
|
136
|
+
- codegen/generate_functions.rb
|
137
|
+
- codegen/native_functions.yaml
|
121
138
|
- ext/torch/ext.cpp
|
122
139
|
- ext/torch/extconf.rb
|
123
|
-
- ext/torch/templates.cpp
|
124
|
-
- ext/torch/templates.hpp
|
140
|
+
- ext/torch/nn_functions.h
|
141
|
+
- ext/torch/ruby_arg_parser.cpp
|
142
|
+
- ext/torch/ruby_arg_parser.h
|
143
|
+
- ext/torch/templates.h
|
144
|
+
- ext/torch/tensor_functions.h
|
145
|
+
- ext/torch/torch_functions.h
|
146
|
+
- ext/torch/utils.h
|
147
|
+
- ext/torch/wrap_outputs.h
|
125
148
|
- lib/torch-rb.rb
|
126
149
|
- lib/torch.rb
|
127
150
|
- lib/torch/hub.rb
|
128
151
|
- lib/torch/inspector.rb
|
129
|
-
- lib/torch/native/dispatcher.rb
|
130
|
-
- lib/torch/native/function.rb
|
131
|
-
- lib/torch/native/generator.rb
|
132
|
-
- lib/torch/native/native_functions.yaml
|
133
|
-
- lib/torch/native/parser.rb
|
134
152
|
- lib/torch/nn/adaptive_avg_pool1d.rb
|
135
153
|
- lib/torch/nn/adaptive_avg_pool2d.rb
|
136
154
|
- lib/torch/nn/adaptive_avg_pool3d.rb
|
data/lib/torch/native/dispatcher.rb
DELETED
@@ -1,70 +0,0 @@
|
|
1
|
-
# We use a generic interface for methods (*args, **options)
|
2
|
-
# and this class to determine the C++ method to call
|
3
|
-
#
|
4
|
-
# This is needed since LibTorch uses function overloading,
|
5
|
-
# which isn't available in Ruby or Python
|
6
|
-
#
|
7
|
-
# PyTorch uses this approach, but the parser/dispatcher is written in C++
|
8
|
-
#
|
9
|
-
# We could generate Ruby methods directly, but an advantage of this approach is
|
10
|
-
# arguments and keyword arguments can be used interchangably like in Python,
|
11
|
-
# making it easier to port code
|
12
|
-
|
13
|
-
module Torch
|
14
|
-
module Native
|
15
|
-
module Dispatcher
|
16
|
-
class << self
|
17
|
-
def bind
|
18
|
-
functions = Generator.grouped_functions
|
19
|
-
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
|
-
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
-
bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
|
22
|
-
end
|
23
|
-
|
24
|
-
def bind_functions(context, def_method, functions)
|
25
|
-
instance_method = def_method == :define_method
|
26
|
-
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
27
|
-
if instance_method
|
28
|
-
funcs.map! { |f| Function.new(f.function) }
|
29
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
30
|
-
end
|
31
|
-
|
32
|
-
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
33
|
-
next if defined && name != "clone"
|
34
|
-
|
35
|
-
# skip parser when possible for performance
|
36
|
-
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
-
# functions with no arguments
|
38
|
-
if instance_method
|
39
|
-
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
-
else
|
41
|
-
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
-
end
|
43
|
-
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
-
# functions that take a tensor or scalar
|
45
|
-
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
-
context.send(def_method, name) do |other|
|
47
|
-
case other
|
48
|
-
when Tensor
|
49
|
-
send(tensor_name, other)
|
50
|
-
else
|
51
|
-
send(scalar_name, other)
|
52
|
-
end
|
53
|
-
end
|
54
|
-
else
|
55
|
-
parser = Parser.new(funcs)
|
56
|
-
|
57
|
-
context.send(def_method, name) do |*args, **options|
|
58
|
-
result = parser.parse(args, options)
|
59
|
-
raise ArgumentError, result[:error] if result[:error]
|
60
|
-
send(result[:name], *result[:args])
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
end
|
65
|
-
end
|
66
|
-
end
|
67
|
-
end
|
68
|
-
end
|
69
|
-
|
70
|
-
Torch::Native::Dispatcher.bind
|
data/lib/torch/native/function.rb
DELETED
@@ -1,200 +0,0 @@
|
|
1
|
-
module Torch
|
2
|
-
module Native
|
3
|
-
class Function
|
4
|
-
attr_reader :function, :tensor_options
|
5
|
-
|
6
|
-
def initialize(function)
|
7
|
-
@function = function
|
8
|
-
|
9
|
-
# note: don't modify function in-place
|
10
|
-
@tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
11
|
-
@tensor_options = @function["func"].include?(@tensor_options_str)
|
12
|
-
@out = out_size > 0 && base_name[-1] != "_"
|
13
|
-
end
|
14
|
-
|
15
|
-
def func
|
16
|
-
@func ||= @function["func"]
|
17
|
-
end
|
18
|
-
|
19
|
-
def name
|
20
|
-
@name ||= func.split("(", 2).first
|
21
|
-
end
|
22
|
-
|
23
|
-
def python_module
|
24
|
-
@python_module ||= @function["python_module"]
|
25
|
-
end
|
26
|
-
|
27
|
-
def variants
|
28
|
-
@variants ||= (@function["variants"] || "function").split(", ")
|
29
|
-
end
|
30
|
-
|
31
|
-
def args
|
32
|
-
@args ||= begin
|
33
|
-
args = []
|
34
|
-
pos = true
|
35
|
-
args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
|
36
|
-
args_str.split(", ").each do |a|
|
37
|
-
if a == "*"
|
38
|
-
pos = false
|
39
|
-
next
|
40
|
-
end
|
41
|
-
t, _, k = a.rpartition(" ")
|
42
|
-
k, d = k.split("=")
|
43
|
-
has_default = !d.nil?
|
44
|
-
|
45
|
-
if d
|
46
|
-
d =
|
47
|
-
case d
|
48
|
-
when "True"
|
49
|
-
true
|
50
|
-
when "False"
|
51
|
-
false
|
52
|
-
when "None"
|
53
|
-
nil
|
54
|
-
when /\A\-?\d+\z/
|
55
|
-
d.to_i
|
56
|
-
when "[]"
|
57
|
-
[]
|
58
|
-
when "[0,1]"
|
59
|
-
[0, 1]
|
60
|
-
when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
|
61
|
-
d.to_f
|
62
|
-
when "Mean"
|
63
|
-
"mean"
|
64
|
-
when "contiguous_format"
|
65
|
-
d
|
66
|
-
when "long"
|
67
|
-
:long
|
68
|
-
else
|
69
|
-
raise "Unknown default: #{d}"
|
70
|
-
end
|
71
|
-
end
|
72
|
-
|
73
|
-
next if t == "Generator?"
|
74
|
-
next if t == "MemoryFormat"
|
75
|
-
next if t == "MemoryFormat?"
|
76
|
-
args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
|
77
|
-
end
|
78
|
-
args
|
79
|
-
end
|
80
|
-
end
|
81
|
-
|
82
|
-
def arg_checkers
|
83
|
-
@arg_checkers ||= begin
|
84
|
-
checkers = {}
|
85
|
-
arg_types.each do |k, t|
|
86
|
-
checker =
|
87
|
-
case t
|
88
|
-
when "Tensor"
|
89
|
-
->(v) { v.is_a?(Tensor) }
|
90
|
-
when "Tensor?"
|
91
|
-
->(v) { v.nil? || v.is_a?(Tensor) }
|
92
|
-
when "Tensor[]", "Tensor?[]"
|
93
|
-
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
|
94
|
-
when "int"
|
95
|
-
if k == :reduction
|
96
|
-
->(v) { v.is_a?(String) }
|
97
|
-
else
|
98
|
-
->(v) { v.is_a?(Integer) }
|
99
|
-
end
|
100
|
-
when "int?"
|
101
|
-
->(v) { v.is_a?(Integer) || v.nil? }
|
102
|
-
when "float?"
|
103
|
-
->(v) { v.is_a?(Numeric) || v.nil? }
|
104
|
-
when "bool?"
|
105
|
-
->(v) { v == true || v == false || v.nil? }
|
106
|
-
when "float"
|
107
|
-
->(v) { v.is_a?(Numeric) }
|
108
|
-
when /int\[.*\]/
|
109
|
-
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
|
110
|
-
when "Scalar"
|
111
|
-
->(v) { v.is_a?(Numeric) }
|
112
|
-
when "Scalar?"
|
113
|
-
->(v) { v.is_a?(Numeric) || v.nil? }
|
114
|
-
when "ScalarType"
|
115
|
-
->(v) { false } # not supported yet
|
116
|
-
when "ScalarType?"
|
117
|
-
->(v) { v.nil? }
|
118
|
-
when "bool"
|
119
|
-
->(v) { v == true || v == false }
|
120
|
-
when "str"
|
121
|
-
->(v) { v.is_a?(String) }
|
122
|
-
else
|
123
|
-
raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
|
124
|
-
end
|
125
|
-
checkers[k] = checker
|
126
|
-
end
|
127
|
-
checkers
|
128
|
-
end
|
129
|
-
end
|
130
|
-
|
131
|
-
def int_array_lengths
|
132
|
-
@int_array_lengths ||= begin
|
133
|
-
ret = {}
|
134
|
-
arg_types.each do |k, t|
|
135
|
-
if t.match?(/\Aint\[.+\]\z/)
|
136
|
-
size = t[4..-2]
|
137
|
-
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
138
|
-
ret[k] = size.to_i
|
139
|
-
end
|
140
|
-
end
|
141
|
-
ret
|
142
|
-
end
|
143
|
-
end
|
144
|
-
|
145
|
-
def arg_names
|
146
|
-
@arg_names ||= args.map { |a| a[:name] }
|
147
|
-
end
|
148
|
-
|
149
|
-
def arg_types
|
150
|
-
@arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
|
151
|
-
end
|
152
|
-
|
153
|
-
def arg_defaults
|
154
|
-
# TODO find out why can't use select here
|
155
|
-
@arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
|
156
|
-
end
|
157
|
-
|
158
|
-
def out_size
|
159
|
-
@out_size ||= func.split("->").last.count("!")
|
160
|
-
end
|
161
|
-
|
162
|
-
def ret_size
|
163
|
-
@ret_size ||= func.split("->").last.split(", ").size
|
164
|
-
end
|
165
|
-
|
166
|
-
def ret_array?
|
167
|
-
@ret_array ||= func.split("->").last.include?('[]')
|
168
|
-
end
|
169
|
-
|
170
|
-
def ret_void?
|
171
|
-
func.split("->").last.strip == "()"
|
172
|
-
end
|
173
|
-
|
174
|
-
def out?
|
175
|
-
@out
|
176
|
-
end
|
177
|
-
|
178
|
-
def ruby_name
|
179
|
-
@ruby_name ||= begin
|
180
|
-
name = base_name
|
181
|
-
if name.end_with?("_")
|
182
|
-
"#{name[0..-2]}!"
|
183
|
-
elsif name.start_with?("is_")
|
184
|
-
"#{name[3..-1]}?"
|
185
|
-
else
|
186
|
-
name
|
187
|
-
end
|
188
|
-
end
|
189
|
-
end
|
190
|
-
|
191
|
-
def cpp_name
|
192
|
-
@cpp_name ||= "_" + name.downcase.sub(".", "_")
|
193
|
-
end
|
194
|
-
|
195
|
-
def base_name
|
196
|
-
@base_name ||= name.split(".").first
|
197
|
-
end
|
198
|
-
end
|
199
|
-
end
|
200
|
-
end
|