torch-rb 0.4.1 → 0.5.3

@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
     if (obj.is_instance_of(rb_cInteger)) {
      indices.push_back(from_ruby<int64_t>(obj));
     } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
-      torch::optional<int64_t> stop_index = -1;
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = from_ruby<int64_t>(begin);
+      }
 
       Object end = obj.call("end");
       if (!end.is_nil()) {
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
       }
 
       Object exclude_end = obj.call("exclude_end?");
-      if (!exclude_end) {
+      if (stop_index.has_value() && !exclude_end) {
         if (stop_index.value() == -1) {
           stop_index = torch::nullopt;
         } else {
           stop_index = stop_index.value() + 1;
         }
+      } else if (!stop_index.has_value() && exclude_end) {
+        stop_index = -1;
       }
 
       indices.push_back(torch::indexing::Slice(start_index, stop_index));
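This change makes range indexing tolerate Ruby's beginless and endless ranges by leaving a missing bound as torch::nullopt instead of reading it eagerly. A rough sketch of what this allows from Ruby on 2.7+ (the tensor below is illustrative, not from the diff):

    require "torch"

    x = Torch.arange(10)
    x[2..5]    # inclusive end   -> indices 2, 3, 4, 5
    x[2...5]   # exclusive end   -> indices 2, 3, 4
    x[2..]     # endless range   -> indices 2 through 9
    x[..3]     # beginless range -> indices 0 through 3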
@@ -348,16 +355,6 @@ void Init_ext()
       *[](Tensor& self) {
         return self.is_contiguous();
       })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
     .define_method(
       "_requires_grad!",
       *[](Tensor& self, bool requires_grad) {
@@ -372,7 +369,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Tensor& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -609,7 +606,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Parameter& self, torch::Tensor& grad) {
-        self.grad() = grad;
+        self.mutable_grad() = grad;
       });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -628,5 +625,7 @@ void Init_ext()
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
     .add_handler<torch::Error>(handle_error)
     .define_singleton_method("available?", &torch::cuda::is_available)
-    .define_singleton_method("device_count", &torch::cuda::device_count);
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
 }
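The two added singleton methods expose CUDA seeding to Ruby. A minimal usage sketch (only meaningful on a CUDA-enabled build):

    require "torch"

    if Torch::CUDA.available?
      Torch::CUDA.manual_seed(42)      # seed the current CUDA device
      Torch::CUDA.manual_seed_all(42)  # seed every visible CUDA device
    end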
@@ -91,7 +91,7 @@ struct RubyArgs {
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
-  // inline c10::OptionalArray<double> doublelistOptional(int i);
+  inline c10::OptionalArray<double> doublelistOptional(int i);
   // inline at::Layout layout(int i);
   // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
   inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
-  // inline c10::optional<std::string> stringOptional(int i);
+  inline c10::optional<std::string> stringOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
   return toDouble(i);
 }
 
+inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+
+  VALUE arg = args[i];
+  auto size = RARRAY_LEN(arg);
+  std::vector<double> res(size);
+  for (idx = 0; idx < size; idx++) {
+    VALUE obj = rb_ary_entry(arg, idx);
+    if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
+      res[idx] = from_ruby<double>(obj);
+    } else {
+      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+        signature.name.c_str(), signature.params[i].name.c_str(),
+        signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
+    }
+  }
+  return res;
+}
+
 inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
 
@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
   return from_ruby<std::string>(args[i]);
 }
 
+inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
+  if (!args[i]) return c10::nullopt;
+  return from_ruby<std::string>(args[i]);
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return from_ruby<int64_t>(args[i]);
@@ -19,6 +19,7 @@ using torch::TensorOptions;
 using torch::Layout;
 using torch::MemoryFormat;
 using torch::IntArrayRef;
+using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;
 
@@ -90,3 +90,10 @@ inline Object wrap(torch::TensorList x) {
   }
   return Object(a);
 }
+
+inline Object wrap(std::tuple<double, double> x) {
+  Array a;
+  a.push(to_ruby<double>(std::get<0>(x)));
+  a.push(to_ruby<double>(std::get<1>(x)));
+  return Object(a);
+}
@@ -261,6 +261,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
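The new branch means constructing a tensor with no sizes now yields an empty tensor instead of failing; the surrounding dtype/device handling suggests this is the Tensor constructor. A rough sketch, assuming the method backs Tensor.new:

    require "torch"

    Torch::Tensor.new          # => empty tensor with 0 elements (new behavior)
    Torch::Tensor.new(2, 3)    # => uninitialized 2x3 tensor (unchanged)
    Torch::Tensor.new([1, 2])  # => tensor built from the array (unchanged)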
@@ -434,7 +436,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
-    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end
 
     private
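stft gains a return_complex argument that is forwarded to the native _stft binding, and the hand-written clamp wrapper is dropped, presumably because the generated binding now covers it. A small usage sketch; the signal, window, and parameter values are illustrative:

    require "torch"

    signal = Torch.randn(400)
    window = Torch.hann_window(100)
    spec = Torch.stft(signal, 100, hop_length: 50, window: window, return_complex: false)

    Torch.clamp(Torch.tensor([-2.0, 0.5, 3.0]), -1, 1)  # still available via the native binding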
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-        named_parameters.each do |k, v|
-          destination[k] = v
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end
 
-      # TODO add strict option
-      # TODO match PyTorch behavior
-      def load_state_dict(state_dict)
-        state_dict.each do |k, input_param|
-          k1, k2 = k.split(".", 2)
-          mod = named_modules[k1]
-          if mod.is_a?(Module)
-            param = mod.named_parameters[k2]
-            if param.is_a?(Parameter)
-              Torch.no_grad do
-                param.copy!(input_param)
-              end
-            else
-              raise Error, "Unknown parameter: #{k1}"
-            end
-          else
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
           end
         end
 
-        # TODO return missing keys and unexpected keys
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+          end
+        end
+
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end
 
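state_dict now walks child modules with dotted key prefixes, and load_state_dict collects missing and unexpected keys, raising on the first error when strict. A minimal round-trip sketch (the model is illustrative):

    require "torch"

    model = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(4, 8),
      Torch::NN::ReLU.new,
      Torch::NN::Linear.new(8, 2)
    )

    sd = model.state_dict        # keys like "0.weight", "0.bias", "2.weight", ...
    model.load_state_dict(sd)    # returns nil; raises Torch::Error on key or shape mismatches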
@@ -300,6 +318,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                            "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                            "whose dimensions in the model are #{param.size} and " +
+                            "whose dimensions in the checkpoint are #{input_param.size}, " +
+                            "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end
@@ -39,14 +39,14 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
-          square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+          square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
           std = square_avg.add(eps).sqrt!
           delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
           p.data.add!(delta, alpha: -group[:lr])
-          acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+          acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
         end
       end
 
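This and the following optimizer hunks update call sites for the new generated bindings, which take the scale factor as a value: keyword instead of a leading positional scalar (matching add's alpha: keyword). A small sketch of the before/after call shape (the tensors are illustrative):

    require "torch"

    a  = Torch.ones(3)
    t1 = Torch.full([3], 2.0)
    t2 = Torch.full([3], 4.0)

    # before 0.5: a.addcmul!(0.5, t1, t2) / a.addcdiv!(0.5, t1, t2)
    a.addcmul!(t1, t2, value: 0.5)  # a += 0.5 * t1 * t2
    a.addcdiv!(t1, t2, value: 0.5)  # a += 0.5 * t1 / t2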
@@ -49,7 +49,7 @@ module Torch
            if p.grad.data.sparse?
              raise Error, "weight_decay option is not compatible with sparse gradients"
            end
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end
 
          clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
          if grad.sparse?
            raise NotImplementedYet
          else
-            state[:sum].addcmul!(1, grad, grad)
+            state[:sum].addcmul!(grad, grad, value: 1)
            std = state[:sum].sqrt.add!(group[:eps])
-            p.data.addcdiv!(-clr, grad, std)
+            p.data.addcdiv!(grad, std, value: -clr)
          end
        end
      end
@@ -58,7 +58,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
@@ -42,7 +42,7 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
           # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
           bias_correction = 1 - beta1 ** state[:step]
           clr = group[:lr] / bias_correction
 
-          p.data.addcdiv!(-clr, exp_avg, exp_inf)
+          p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
         end
       end
 
@@ -59,7 +59,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
@@ -36,7 +36,7 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
           # decay term
@@ -46,25 +46,25 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
-          square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+          square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
           if group[:centered]
             grad_avg = state[:grad_avg]
-            grad_avg.mul!(alpha).add!(1 - alpha, grad)
-            avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+            grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+            avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
           else
             avg = square_avg.sqrt.add!(group[:eps])
           end
 
           if group[:momentum] > 0
             buf = state[:momentum_buffer]
-            buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-            p.data.add!(-group[:lr], buf)
+            buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+            p.data.add!(buf, alpha: -group[:lr])
           else
-            p.data.addcdiv!(-group[:lr], grad, avg)
+            p.data.addcdiv!(grad, avg, value: -group[:lr])
           end
         end
       end