torch-rb 0.4.2 → 0.6.0

This diff shows the differences between the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Linear < Module
4
+ attr_reader :in_features, :out_features
5
+
4
6
  def initialize(in_features, out_features, bias: true)
5
7
  super()
6
8
  @in_features = in_features
@@ -113,35 +113,53 @@ module Torch
113
113
  forward(*input, **kwargs)
114
114
  end
115
115
 
116
- def state_dict(destination: nil)
116
+ def state_dict(destination: nil, prefix: "")
117
117
  destination ||= {}
118
- named_parameters.each do |k, v|
119
- destination[k] = v
118
+ save_to_state_dict(destination, prefix: prefix)
119
+
120
+ named_children.each do |name, mod|
121
+ next unless mod
122
+ mod.state_dict(destination: destination, prefix: prefix + name + ".")
120
123
  end
121
124
  destination
122
125
  end
123
126
 
124
- # TODO add strict option
125
- # TODO match PyTorch behavior
126
- def load_state_dict(state_dict)
127
- state_dict.each do |k, input_param|
128
- k1, k2 = k.split(".", 2)
129
- mod = named_modules[k1]
130
- if mod.is_a?(Module)
131
- param = mod.named_parameters[k2]
132
- if param.is_a?(Parameter)
133
- Torch.no_grad do
134
- param.copy!(input_param)
135
- end
136
- else
137
- raise Error, "Unknown parameter: #{k1}"
138
- end
139
- else
140
- raise Error, "Unknown module: #{k1}"
127
+ def load_state_dict(state_dict, strict: true)
128
+ # TODO support strict: false
129
+ raise "strict: false not implemented yet" unless strict
130
+
131
+ missing_keys = []
132
+ unexpected_keys = []
133
+ error_msgs = []
134
+
135
+ # TODO handle metadata
136
+
137
+ _load = lambda do |mod, prefix = ""|
138
+ # TODO handle metadata
139
+ local_metadata = {}
140
+ mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
141
+ mod.named_children.each do |name, child|
142
+ _load.call(child, prefix + name + ".") unless child.nil?
143
+ end
144
+ end
145
+
146
+ _load.call(self)
147
+
148
+ if strict
149
+ if unexpected_keys.any?
150
+ error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
151
+ end
152
+
153
+ if missing_keys.any?
154
+ error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
141
155
  end
142
156
  end
143
157
 
144
- # TODO return missing keys and unexpected keys
158
+ if error_msgs.any?
159
+ # just show first error
160
+ raise Error, error_msgs[0]
161
+ end
162
+
145
163
  nil
146
164
  end
147
165
 
@@ -268,6 +286,12 @@ module Torch
268
286
  named_buffers[name]
269
287
  elsif named_modules.key?(name)
270
288
  named_modules[name]
289
+ elsif method.end_with?("=") && named_modules.key?(method[0..-2])
290
+ if instance_variable_defined?("@#{method[0..-2]}")
291
+ instance_variable_set("@#{method[0..-2]}", *args)
292
+ else
293
+ raise NotImplementedYet
294
+ end
271
295
  else
272
296
  super
273
297
  end
@@ -300,6 +324,68 @@ module Torch
300
324
  def dict
301
325
  instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
302
326
  end
327
+
328
+ def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
329
+ # TODO add hooks
330
+
331
+ # TODO handle non-persistent buffers
332
+ persistent_buffers = named_buffers
333
+ local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
334
+ local_state = local_name_params.select { |_, v| !v.nil? }
335
+
336
+ local_state.each do |name, param|
337
+ key = prefix + name
338
+ if state_dict.key?(key)
339
+ input_param = state_dict[key]
340
+
341
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
342
+ if param.shape.length == 0 && input_param.shape.length == 1
343
+ input_param = input_param[0]
344
+ end
345
+
346
+ if input_param.shape != param.shape
347
+ # local shape should match the one in checkpoint
348
+ error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
349
+ "the shape in current model is #{param.shape}."
350
+ next
351
+ end
352
+
353
+ begin
354
+ Torch.no_grad do
355
+ param.copy!(input_param)
356
+ end
357
+ rescue => e
358
+ error_msgs << "While copying the parameter named #{key.inspect}, " +
359
+ "whose dimensions in the model are #{param.size} and " +
360
+ "whose dimensions in the checkpoint are #{input_param.size}, " +
361
+ "an exception occurred: #{e.inspect}"
362
+ end
363
+ elsif strict
364
+ missing_keys << key
365
+ end
366
+ end
367
+
368
+ if strict
369
+ state_dict.each_key do |key|
370
+ if key.start_with?(prefix)
371
+ input_name = key[prefix.length..-1]
372
+ input_name = input_name.split(".", 2)[0]
373
+ if !named_children.key?(input_name) && !local_state.key?(input_name)
374
+ unexpected_keys << key
375
+ end
376
+ end
377
+ end
378
+ end
379
+ end
380
+
381
+ def save_to_state_dict(destination, prefix: "")
382
+ named_parameters(recurse: false).each do |k, v|
383
+ destination[prefix + k] = v
384
+ end
385
+ named_buffers.each do |k, v|
386
+ destination[prefix + k] = v
387
+ end
388
+ end
303
389
  end
304
390
  end
305
391
  end
@@ -3,7 +3,7 @@ module Torch
3
3
  class Parameter < Tensor
4
4
  def self.new(data = nil, requires_grad: true)
5
5
  data = Tensor.new unless data
6
- Tensor._make_subclass(data, requires_grad)
6
+ _make_subclass(data, requires_grad)
7
7
  end
8
8
 
9
9
  def inspect
@@ -42,11 +42,11 @@ module Torch
42
42
  grad = grad.add(p.data, alpha: group[:weight_decay])
43
43
  end
44
44
 
45
- square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
45
+ square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
46
46
  std = square_avg.add(eps).sqrt!
47
47
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
48
48
  p.data.add!(delta, alpha: -group[:lr])
49
- acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
49
+ acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
50
50
  end
51
51
  end
52
52
 
@@ -57,9 +57,9 @@ module Torch
57
57
  if grad.sparse?
58
58
  raise NotImplementedYet
59
59
  else
60
- state[:sum].addcmul!(1, grad, grad)
60
+ state[:sum].addcmul!(grad, grad, value: 1)
61
61
  std = state[:sum].sqrt.add!(group[:eps])
62
- p.data.addcdiv!(-clr, grad, std)
62
+ p.data.addcdiv!(grad, std, value: -clr)
63
63
  end
64
64
  end
65
65
  end
@@ -58,7 +58,7 @@ module Torch
58
58
 
59
59
  # Decay the first and second moment running average coefficient
60
60
  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
61
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
61
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
62
62
  if amsgrad
63
63
  # Maintains the maximum of all 2nd moment running avg. till now
64
64
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
70
70
 
71
71
  step_size = group[:lr] / bias_correction1
72
72
 
73
- p.data.addcdiv!(-step_size, exp_avg, denom)
73
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
74
74
  end
75
75
  end
76
76
 
@@ -57,7 +57,7 @@ module Torch
57
57
  bias_correction = 1 - beta1 ** state[:step]
58
58
  clr = group[:lr] / bias_correction
59
59
 
60
- p.data.addcdiv!(-clr, exp_avg, exp_inf)
60
+ p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
61
61
  end
62
62
  end
63
63
 
@@ -59,7 +59,7 @@ module Torch
59
59
 
60
60
  # Decay the first and second moment running average coefficient
61
61
  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
62
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
62
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
63
63
  if amsgrad
64
64
  # Maintains the maximum of all 2nd moment running avg. till now
65
65
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
71
71
 
72
72
  step_size = group[:lr] / bias_correction1
73
73
 
74
- p.data.addcdiv!(-step_size, exp_avg, denom)
74
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
75
75
  end
76
76
  end
77
77
 
@@ -49,7 +49,7 @@ module Torch
49
49
  grad = grad.add(p.data, alpha: group[:weight_decay])
50
50
  end
51
51
 
52
- square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
52
+ square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
53
53
 
54
54
  if group[:centered]
55
55
  grad_avg = state[:grad_avg]
@@ -61,10 +61,10 @@ module Torch
61
61
 
62
62
  if group[:momentum] > 0
63
63
  buf = state[:momentum_buffer]
64
- buf.mul!(group[:momentum]).addcdiv!(1, grad, avg)
64
+ buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
65
65
  p.data.add!(buf, alpha: -group[:lr])
66
66
  else
67
- p.data.addcdiv!(-group[:lr], grad, avg)
67
+ p.data.addcdiv!(grad, avg, value: -group[:lr])
68
68
  end
69
69
  end
70
70
  end
@@ -52,7 +52,7 @@ module Torch
52
52
  grad[sign.eq(etaminus)] = 0
53
53
 
54
54
  # update parameters
55
- p.data.addcmul!(-1, grad.sign, step_size)
55
+ p.data.addcmul!(grad.sign, step_size, value: -1)
56
56
 
57
57
  state[:prev].copy!(grad)
58
58
  end
data/lib/torch/tensor.rb CHANGED
@@ -135,6 +135,10 @@ module Torch
135
135
  Torch.ones_like(Torch.empty(*size), **options)
136
136
  end
137
137
 
138
+ def requires_grad=(requires_grad)
139
+ _requires_grad!(requires_grad)
140
+ end
141
+
138
142
  def requires_grad!(requires_grad = true)
139
143
  _requires_grad!(requires_grad)
140
144
  end
@@ -174,5 +178,10 @@ module Torch
174
178
  return _random!(0, *args) if args.size == 1
175
179
  _random!(*args)
176
180
  end
181
+
182
+ # center option
183
+ def stft(*args)
184
+ Torch.stft(*args)
185
+ end
177
186
  end
178
187
  end
@@ -60,7 +60,7 @@ module Torch
60
60
  when Array
61
61
  batch.transpose.map { |v| default_convert(v) }
62
62
  else
63
- raise NotImplementedYet
63
+ batch
64
64
  end
65
65
  end
66
66
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.4.2"
2
+ VERSION = "0.6.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.2
4
+ version: 0.6.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-10-28 00:00:00.000000000 Z
11
+ date: 2021-03-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -24,92 +24,8 @@ dependencies:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
26
  version: '2.2'
27
- - !ruby/object:Gem::Dependency
28
- name: bundler
29
- requirement: !ruby/object:Gem::Requirement
30
- requirements:
31
- - - ">="
32
- - !ruby/object:Gem::Version
33
- version: '0'
34
- type: :development
35
- prerelease: false
36
- version_requirements: !ruby/object:Gem::Requirement
37
- requirements:
38
- - - ">="
39
- - !ruby/object:Gem::Version
40
- version: '0'
41
- - !ruby/object:Gem::Dependency
42
- name: rake
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake-compiler
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: minitest
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '5'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '5'
83
- - !ruby/object:Gem::Dependency
84
- name: numo-narray
85
- requirement: !ruby/object:Gem::Requirement
86
- requirements:
87
- - - ">="
88
- - !ruby/object:Gem::Version
89
- version: '0'
90
- type: :development
91
- prerelease: false
92
- version_requirements: !ruby/object:Gem::Requirement
93
- requirements:
94
- - - ">="
95
- - !ruby/object:Gem::Version
96
- version: '0'
97
- - !ruby/object:Gem::Dependency
98
- name: torchvision
99
- requirement: !ruby/object:Gem::Requirement
100
- requirements:
101
- - - ">="
102
- - !ruby/object:Gem::Version
103
- version: 0.1.1
104
- type: :development
105
- prerelease: false
106
- version_requirements: !ruby/object:Gem::Requirement
107
- requirements:
108
- - - ">="
109
- - !ruby/object:Gem::Version
110
- version: 0.1.1
111
27
  description:
112
- email: andrew@chartkick.com
28
+ email: andrew@ankane.org
113
29
  executables: []
114
30
  extensions:
115
31
  - ext/torch/extconf.rb
@@ -121,13 +37,20 @@ files:
121
37
  - codegen/function.rb
122
38
  - codegen/generate_functions.rb
123
39
  - codegen/native_functions.yaml
40
+ - ext/torch/cuda.cpp
41
+ - ext/torch/device.cpp
124
42
  - ext/torch/ext.cpp
125
43
  - ext/torch/extconf.rb
44
+ - ext/torch/ivalue.cpp
45
+ - ext/torch/nn.cpp
126
46
  - ext/torch/nn_functions.h
47
+ - ext/torch/random.cpp
127
48
  - ext/torch/ruby_arg_parser.cpp
128
49
  - ext/torch/ruby_arg_parser.h
129
50
  - ext/torch/templates.h
51
+ - ext/torch/tensor.cpp
130
52
  - ext/torch/tensor_functions.h
53
+ - ext/torch/torch.cpp
131
54
  - ext/torch/torch_functions.h
132
55
  - ext/torch/utils.h
133
56
  - ext/torch/wrap_outputs.h
@@ -282,14 +205,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
282
205
  requirements:
283
206
  - - ">="
284
207
  - !ruby/object:Gem::Version
285
- version: '2.4'
208
+ version: '2.6'
286
209
  required_rubygems_version: !ruby/object:Gem::Requirement
287
210
  requirements:
288
211
  - - ">="
289
212
  - !ruby/object:Gem::Version
290
213
  version: '0'
291
214
  requirements: []
292
- rubygems_version: 3.1.4
215
+ rubygems_version: 3.2.3
293
216
  signing_key:
294
217
  specification_version: 4
295
218
  summary: Deep learning for Ruby, powered by LibTorch