torch-rb 0.4.2 → 0.6.0
This diff shows the changes between two publicly released versions of the package, as they appear in their respective registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -3
- data/codegen/generate_functions.rb +20 -13
- data/codegen/native_functions.yaml +4129 -1521
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -623
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +47 -7
- data/ext/torch/templates.h +3 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +14 -17
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/optim/adadelta.rb +2 -2
- data/lib/torch/optim/adagrad.rb +2 -2
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +3 -3
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +12 -89
data/lib/torch/nn/linear.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
```diff
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end
 
-      # … (old load_state_dict implementation elided)
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
+          end
+        end
+
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
           end
         end
 
-
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end
 
@@ -268,6 +286,12 @@ module Torch
           named_buffers[name]
         elsif named_modules.key?(name)
           named_modules[name]
+        elsif method.end_with?("=") && named_modules.key?(method[0..-2])
+          if instance_variable_defined?("@#{method[0..-2]}")
+            instance_variable_set("@#{method[0..-2]}", *args)
+          else
+            raise NotImplementedYet
+          end
         else
           super
         end
@@ -300,6 +324,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end
```
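The new `state_dict` walks the module tree with a per-child key prefix, and `load_state_dict` restores it through the private `save_to_state_dict`/`load_from_state_dict` helpers, collecting missing/unexpected keys and shape mismatches and raising on the first error. A minimal round-trip sketch; the `TinyNet` class, layer sizes, and key names here are illustrative assumptions, not from the gem:

```ruby
require "torch"

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 3)
    @fc2 = Torch::NN::Linear.new(3, 1)
  end

  def forward(x)
    @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
  end
end

net = TinyNet.new
sd = net.state_dict
p sd.keys  # keys are now prefixed per child, e.g. "fc1.weight", "fc2.bias"

copy = TinyNet.new
copy.load_state_dict(sd)  # raises Torch::Error on mismatch; strict: false still raises, per the TODO
```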
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/optim/adadelta.rb
CHANGED
```diff
@@ -42,11 +42,11 @@ module Torch
             grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
-          square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+          square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
           std = square_avg.add(eps).sqrt!
           delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
           p.data.add!(delta, alpha: -group[:lr])
-          acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+          acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
         end
       end
 
```
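The same `addcmul!`/`addcdiv!` change recurs in the adagrad, adam, adamw, and rmsprop diffs below: the scaling factor moves from the first positional argument to a `value:` keyword, tracking PyTorch's deprecation of the `addcmul(value, tensor1, tensor2)` form. A small sketch of the two call styles; the tensor values are illustrative:

```ruby
require "torch"

acc = Torch.zeros(3)
t1  = Torch.tensor([1.0, 2.0, 3.0])
t2  = Torch.tensor([4.0, 5.0, 6.0])

# 0.4.x style (removed): acc.addcmul!(0.5, t1, t2)
acc.addcmul!(t1, t2, value: 0.5)  # acc += 0.5 * t1 * t2
acc.addcdiv!(t1, t2, value: 0.5)  # acc += 0.5 * t1 / t2
```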
data/lib/torch/optim/adagrad.rb
CHANGED
```diff
@@ -57,9 +57,9 @@ module Torch
           if grad.sparse?
             raise NotImplementedYet
           else
-            state[:sum].addcmul!(1, grad, grad)
+            state[:sum].addcmul!(grad, grad, value: 1)
             std = state[:sum].sqrt.add!(group[:eps])
-            p.data.addcdiv!(-clr, grad, std)
+            p.data.addcdiv!(grad, std, value: -clr)
           end
         end
       end
```
data/lib/torch/optim/adam.rb
CHANGED
```diff
@@ -58,7 +58,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
```
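For context on what the fused call computes: `p.data.addcdiv!(exp_avg, denom, value: -step_size)` performs the standard bias-corrected Adam update in one in-place op. A sketch of a single step written out with the new signatures; the hyperparameter values and the `denom`/`bias_correction` lines are assumptions mirroring adam.rb code not shown in this diff:

```ruby
require "torch"

lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
step = 1

param      = Torch.rand(3)
grad       = Torch.rand(3)
exp_avg    = Torch.zeros(3)  # first moment estimate
exp_avg_sq = Torch.zeros(3)  # second moment estimate

exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
denom = exp_avg_sq.sqrt.div!(Math.sqrt(bias_correction2)).add!(eps)
step_size = lr / bias_correction1

# One fused in-place op: param += -step_size * exp_avg / denom
param.addcdiv!(exp_avg, denom, value: -step_size)
```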
data/lib/torch/optim/adamax.rb
CHANGED
data/lib/torch/optim/adamw.rb
CHANGED
```diff
@@ -59,7 +59,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
```
data/lib/torch/optim/rmsprop.rb
CHANGED
```diff
@@ -49,7 +49,7 @@ module Torch
             grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
-          square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+          square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
           if group[:centered]
             grad_avg = state[:grad_avg]
@@ -61,10 +61,10 @@ module Torch
 
           if group[:momentum] > 0
             buf = state[:momentum_buffer]
-            buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+            buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
             p.data.add!(buf, alpha: -group[:lr])
           else
-            p.data.addcdiv!(-group[:lr], grad, avg)
+            p.data.addcdiv!(grad, avg, value: -group[:lr])
           end
         end
       end
```
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
```diff
@@ -135,6 +135,10 @@ module Torch
       Torch.ones_like(Torch.empty(*size), **options)
     end
 
+    def requires_grad=(requires_grad)
+      _requires_grad!(requires_grad)
+    end
+
    def requires_grad!(requires_grad = true)
       _requires_grad!(requires_grad)
     end
@@ -174,5 +178,10 @@ module Torch
       return _random!(0, *args) if args.size == 1
       _random!(*args)
     end
+
+    # center option
+    def stft(*args)
+      Torch.stft(*args)
+    end
   end
 end
```
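The tensor diff adds a `requires_grad=` writer (sugar over the existing `_requires_grad!` native binding, equivalent to `requires_grad!(true)`) and delegates `Tensor#stft` to `Torch.stft`. A short sketch of the setter in use; the tensor shapes are illustrative:

```ruby
require "torch"

x = Torch.rand(2, 2)
x.requires_grad = true  # new in this release; same as x.requires_grad!(true)

y = (x * x).sum
y.backward
p x.grad                # populated now that x tracks gradients
```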
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
```diff
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.4.2
+  version: 0.6.0
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date:
+date: 2021-03-26 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -24,92 +24,8 @@ dependencies:
     - - ">="
       - !ruby/object:Gem::Version
         version: '2.2'
-- !ruby/object:Gem::Dependency
-  name: bundler
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: rake
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: rake-compiler
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: minitest
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '5'
-- !ruby/object:Gem::Dependency
-  name: numo-narray
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: '0'
-- !ruby/object:Gem::Dependency
-  name: torchvision
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: 0.1.1
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - ">="
-      - !ruby/object:Gem::Version
-        version: 0.1.1
 description:
-email: andrew@
+email: andrew@ankane.org
 executables: []
 extensions:
 - ext/torch/extconf.rb
@@ -121,13 +37,20 @@ files:
 - codegen/function.rb
 - codegen/generate_functions.rb
 - codegen/native_functions.yaml
+- ext/torch/cuda.cpp
+- ext/torch/device.cpp
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
+- ext/torch/ivalue.cpp
+- ext/torch/nn.cpp
 - ext/torch/nn_functions.h
+- ext/torch/random.cpp
 - ext/torch/ruby_arg_parser.cpp
 - ext/torch/ruby_arg_parser.h
 - ext/torch/templates.h
+- ext/torch/tensor.cpp
 - ext/torch/tensor_functions.h
+- ext/torch/torch.cpp
 - ext/torch/torch_functions.h
 - ext/torch/utils.h
 - ext/torch/wrap_outputs.h
@@ -282,14 +205,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="
     - !ruby/object:Gem::Version
-      version: '2.
+      version: '2.6'
 required_rubygems_version: !ruby/object:Gem::Requirement
   requirements:
   - - ">="
     - !ruby/object:Gem::Version
       version: '0'
 requirements: []
-rubygems_version: 3.
+rubygems_version: 3.2.3
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch
```