torch-rb 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9179470135a00453dcae9efbc6cd143112c5fec925bb5675a686e80e70b71b28
4
- data.tar.gz: de365f50021d75338a78bcb6e0733bb430759fe7c4fcaea96b1ff0ed2a4b8d5d
3
+ metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
4
+ data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
5
5
  SHA512:
6
- metadata.gz: 46d3d49aa63c0764d20178f450aa0b88c88938c194e70040650e6c1a29899e5f4d896671571730dc23eb1fe039ede53d2714db0f6fe7506ad4382653a5e6ec18
7
- data.tar.gz: 3fe47be264030fc2d84de85bb7d006337df37fb5c41b147332fe37dd21de7ba61bdc53a0f7ae9085e01c992a643b22d8216a0248b7a4d24487d88fd7f88a9ecf
6
+ metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
7
+ data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.2.1 (2020-04-26)
2
+
3
+ - Added support for saving and loading models
4
+ - Improved error messages
5
+ - Reduced gem size
6
+
1
7
  ## 0.2.0 (2020-04-22)
2
8
 
3
9
  - No longer experimental
data/README.md CHANGED
@@ -283,6 +283,22 @@ loss.backward
283
283
  optimizer.step
284
284
  ```
285
285
 
286
+ ### Saving and Loading Models
287
+
288
+ Save a model
289
+
290
+ ```ruby
291
+ Torch.save(net.state_dict, "net.pth")
292
+ ```
293
+
294
+ Load a model
295
+
296
+ ```ruby
297
+ net = Net.new
298
+ net.load_state_dict(Torch.load("net.pth"))
299
+ net.eval
300
+ ```
301
+
286
302
  ### Tensor Creation
287
303
 
288
304
  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -445,6 +461,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
445
461
  bundle exec rake test
446
462
  ```
447
463
 
464
+ You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
465
+
448
466
  Here are some good resources for contributors:
449
467
 
450
468
  - [PyTorch API](https://pytorch.org/docs/stable/torch.html)
data/ext/torch/ext.cpp CHANGED
@@ -5,6 +5,7 @@
5
5
  #include <rice/Array.hpp>
6
6
  #include <rice/Class.hpp>
7
7
  #include <rice/Constructor.hpp>
8
+ #include <rice/Hash.hpp>
8
9
 
9
10
  #include "templates.hpp"
10
11
 
@@ -22,6 +23,11 @@ class Parameter: public torch::autograd::Variable {
22
23
  Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
23
24
  };
24
25
 
26
+ void handle_error(c10::Error const & ex)
27
+ {
28
+ throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
+ }
30
+
25
31
  extern "C"
26
32
  void Init_ext()
27
33
  {
@@ -34,6 +40,108 @@ void Init_ext()
34
40
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
35
41
  add_nn_functions(rb_mNN);
36
42
 
43
+ // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
44
+ Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
45
+ .define_constructor(Constructor<torch::IValue>())
46
+ .define_method("bool?", &torch::IValue::isBool)
47
+ .define_method("bool_list?", &torch::IValue::isBoolList)
48
+ .define_method("capsule?", &torch::IValue::isCapsule)
49
+ .define_method("custom_class?", &torch::IValue::isCustomClass)
50
+ .define_method("device?", &torch::IValue::isDevice)
51
+ .define_method("double?", &torch::IValue::isDouble)
52
+ .define_method("double_list?", &torch::IValue::isDoubleList)
53
+ .define_method("future?", &torch::IValue::isFuture)
54
+ // .define_method("generator?", &torch::IValue::isGenerator)
55
+ .define_method("generic_dict?", &torch::IValue::isGenericDict)
56
+ .define_method("list?", &torch::IValue::isList)
57
+ .define_method("int?", &torch::IValue::isInt)
58
+ .define_method("int_list?", &torch::IValue::isIntList)
59
+ .define_method("module?", &torch::IValue::isModule)
60
+ .define_method("none?", &torch::IValue::isNone)
61
+ .define_method("object?", &torch::IValue::isObject)
62
+ .define_method("ptr_type?", &torch::IValue::isPtrType)
63
+ .define_method("py_object?", &torch::IValue::isPyObject)
64
+ .define_method("r_ref?", &torch::IValue::isRRef)
65
+ .define_method("scalar?", &torch::IValue::isScalar)
66
+ .define_method("string?", &torch::IValue::isString)
67
+ .define_method("tensor?", &torch::IValue::isTensor)
68
+ .define_method("tensor_list?", &torch::IValue::isTensorList)
69
+ .define_method("tuple?", &torch::IValue::isTuple)
70
+ .define_method(
71
+ "to_bool",
72
+ *[](torch::IValue& self) {
73
+ return self.toBool();
74
+ })
75
+ .define_method(
76
+ "to_double",
77
+ *[](torch::IValue& self) {
78
+ return self.toDouble();
79
+ })
80
+ .define_method(
81
+ "to_int",
82
+ *[](torch::IValue& self) {
83
+ return self.toInt();
84
+ })
85
+ .define_method(
86
+ "to_string_ref",
87
+ *[](torch::IValue& self) {
88
+ return self.toStringRef();
89
+ })
90
+ .define_method(
91
+ "to_tensor",
92
+ *[](torch::IValue& self) {
93
+ return self.toTensor();
94
+ })
95
+ .define_method(
96
+ "to_generic_dict",
97
+ *[](torch::IValue& self) {
98
+ auto dict = self.toGenericDict();
99
+ Hash h;
100
+ for (auto& pair : dict) {
101
+ h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
102
+ }
103
+ return h;
104
+ })
105
+ .define_singleton_method(
106
+ "from_tensor",
107
+ *[](torch::Tensor& v) {
108
+ return torch::IValue(v);
109
+ })
110
+ .define_singleton_method(
111
+ "from_string",
112
+ *[](String v) {
113
+ return torch::IValue(v.str());
114
+ })
115
+ .define_singleton_method(
116
+ "from_int",
117
+ *[](int64_t v) {
118
+ return torch::IValue(v);
119
+ })
120
+ .define_singleton_method(
121
+ "from_double",
122
+ *[](double v) {
123
+ return torch::IValue(v);
124
+ })
125
+ .define_singleton_method(
126
+ "from_bool",
127
+ *[](bool v) {
128
+ return torch::IValue(v);
129
+ })
130
+ // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
131
+ // createGenericDict and toIValue
132
+ .define_singleton_method(
133
+ "from_dict",
134
+ *[](Hash obj) {
135
+ auto key_type = c10::AnyType::get();
136
+ auto value_type = c10::AnyType::get();
137
+ c10::impl::GenericDict elems(key_type, value_type);
138
+ elems.reserve(obj.size());
139
+ for (auto entry : obj) {
140
+ elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
141
+ }
142
+ return torch::IValue(std::move(elems));
143
+ });
144
+
37
145
  rb_mTorch.define_singleton_method(
38
146
  "grad_enabled?",
39
147
  *[]() {
@@ -113,11 +221,19 @@ void Init_ext()
113
221
  // begin operations
114
222
  .define_singleton_method(
115
223
  "_save",
116
- *[](const Tensor &value) {
224
+ *[](const torch::IValue &value) {
117
225
  auto v = torch::pickle_save(value);
118
226
  std::string str(v.begin(), v.end());
119
227
  return str;
120
228
  })
229
+ .define_singleton_method(
230
+ "_load",
231
+ *[](const std::string &s) {
232
+ std::vector<char> v;
233
+ std::copy(s.begin(), s.end(), std::back_inserter(v));
234
+ // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
235
+ return torch::pickle_load(v);
236
+ })
121
237
  .define_singleton_method(
122
238
  "_binary_cross_entropy_with_logits",
123
239
  *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
@@ -157,6 +273,7 @@ void Init_ext()
157
273
  });
158
274
 
159
275
  rb_cTensor
276
+ .add_handler<c10::Error>(handle_error)
160
277
  .define_method("cuda?", &torch::Tensor::is_cuda)
161
278
  .define_method("sparse?", &torch::Tensor::is_sparse)
162
279
  .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -288,6 +405,7 @@ void Init_ext()
288
405
  });
289
406
 
290
407
  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
408
+ .add_handler<c10::Error>(handle_error)
291
409
  .define_constructor(Constructor<torch::TensorOptions>())
292
410
  .define_method(
293
411
  "dtype",
@@ -311,13 +429,8 @@ void Init_ext()
311
429
  .define_method(
312
430
  "device",
313
431
  *[](torch::TensorOptions& self, std::string device) {
314
- try {
315
- // needed to catch exception
316
- torch::Device d(device);
317
- return self.device(d);
318
- } catch (const c10::Error& error) {
319
- throw std::runtime_error(error.what_without_backtrace());
320
- }
432
+ torch::Device d(device);
433
+ return self.device(d);
321
434
  })
322
435
  .define_method(
323
436
  "requires_grad",
data/ext/torch/extconf.rb CHANGED
@@ -10,19 +10,24 @@ $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
+ # check omp first
13
14
  if have_library("omp") || have_library("gomp")
14
15
  $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
15
16
  $CXXFLAGS << " -Xclang" if clang
16
17
  $CXXFLAGS << " -fopenmp"
17
18
  end
18
19
 
19
- # silence ruby/intern.h warning
20
- $CXXFLAGS << " -Wno-deprecated-register"
21
-
22
- # silence torch warnings
23
20
  if clang
21
+ # silence ruby/intern.h warning
22
+ $CXXFLAGS << " -Wno-deprecated-register"
23
+
24
+ # silence torch warnings
24
25
  $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
26
  else
27
+ # silence rice warnings
28
+ $CXXFLAGS << " -Wno-noexcept-type"
29
+
30
+ # silence torch warnings
26
31
  $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
27
32
  end
28
33
 
@@ -34,15 +39,20 @@ cuda_inc, cuda_lib = dir_config("cuda")
34
39
  cuda_inc ||= "/usr/local/cuda/include"
35
40
  cuda_lib ||= "/usr/local/cuda/lib64"
36
41
 
37
- with_cuda = Dir["#{lib}/*torch_cuda*"].any? && have_library("cuda") && have_library("cudnn")
42
+ $LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
43
+ abort "LibTorch not found" unless have_library("torch")
44
+
45
+ with_cuda = false
46
+ if Dir["#{lib}/*torch_cuda*"].any?
47
+ $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
48
+ with_cuda = have_library("cuda") && have_library("cudnn")
49
+ end
38
50
 
39
51
  $INCFLAGS << " -I#{inc}"
40
52
  $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
41
53
 
42
54
  $LDFLAGS << " -Wl,-rpath,#{lib}"
43
55
  $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
44
- $LDFLAGS << " -L#{lib}"
45
- $LDFLAGS << " -L#{cuda_lib}" if with_cuda
46
56
 
47
57
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
48
58
  $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
@@ -67,8 +67,9 @@ module Torch
67
67
  self
68
68
  end
69
69
 
70
- def cuda(device: nil)
71
- _apply ->(t) { t.cuda(device) }
70
+ # TODO add device
71
+ def cuda
72
+ _apply ->(t) { t.cuda }
72
73
  end
73
74
 
74
75
  def cpu
@@ -112,8 +113,28 @@ module Torch
112
113
  destination
113
114
  end
114
115
 
116
+ # TODO add strict option
117
+ # TODO match PyTorch behavior
115
118
  def load_state_dict(state_dict)
116
- raise NotImplementedYet
119
+ state_dict.each do |k, input_param|
120
+ k1, k2 = k.split(".", 2)
121
+ mod = named_modules[k1]
122
+ if mod.is_a?(Module)
123
+ param = mod.named_parameters[k2]
124
+ if param.is_a?(Parameter)
125
+ Torch.no_grad do
126
+ param.copy!(input_param)
127
+ end
128
+ else
129
+ raise Error, "Unknown parameter: #{k1}"
130
+ end
131
+ else
132
+ raise Error, "Unknown module: #{k1}"
133
+ end
134
+ end
135
+
136
+ # TODO return missing keys and unexpected keys
137
+ nil
117
138
  end
118
139
 
119
140
  def parameters
data/lib/torch/tensor.rb CHANGED
@@ -37,6 +37,10 @@ module Torch
37
37
  to("cpu")
38
38
  end
39
39
 
40
+ def cuda
41
+ to("cuda")
42
+ end
43
+
40
44
  def size(dim = nil)
41
45
  if dim
42
46
  _size_int(dim)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.0"
2
+ VERSION = "0.2.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -317,12 +317,11 @@ module Torch
317
317
  end
318
318
 
319
319
  def save(obj, f)
320
- raise NotImplementedYet unless obj.is_a?(Tensor)
321
- File.binwrite(f, _save(obj))
320
+ File.binwrite(f, _save(to_ivalue(obj)))
322
321
  end
323
322
 
324
323
  def load(f)
325
- raise NotImplementedYet
324
+ to_ruby(_load(File.binread(f)))
326
325
  end
327
326
 
328
327
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
@@ -447,6 +446,94 @@ module Torch
447
446
 
448
447
  private
449
448
 
449
+ def to_ivalue(obj)
450
+ case obj
451
+ when String
452
+ IValue.from_string(obj)
453
+ when Integer
454
+ IValue.from_int(obj)
455
+ when Tensor
456
+ IValue.from_tensor(obj)
457
+ when Float
458
+ IValue.from_double(obj)
459
+ when Hash
460
+ dict = {}
461
+ obj.each do |k, v|
462
+ dict[to_ivalue(k)] = to_ivalue(v)
463
+ end
464
+ IValue.from_dict(dict)
465
+ when true, false
466
+ IValue.from_bool(obj)
467
+ when nil
468
+ IValue.new
469
+ else
470
+ raise Error, "Unknown type: #{obj.class.name}"
471
+ end
472
+ end
473
+
474
+ def to_ruby(ivalue)
475
+ if ivalue.bool?
476
+ ivalue.to_bool
477
+ elsif ivalue.double?
478
+ ivalue.to_double
479
+ elsif ivalue.int?
480
+ ivalue.to_int
481
+ elsif ivalue.none?
482
+ nil
483
+ elsif ivalue.string?
484
+ ivalue.to_string_ref
485
+ elsif ivalue.tensor?
486
+ ivalue.to_tensor
487
+ elsif ivalue.generic_dict?
488
+ dict = {}
489
+ ivalue.to_generic_dict.each do |k, v|
490
+ dict[to_ruby(k)] = to_ruby(v)
491
+ end
492
+ dict
493
+ else
494
+ type =
495
+ if ivalue.capsule?
496
+ "Capsule"
497
+ elsif ivalue.custom_class?
498
+ "CustomClass"
499
+ elsif ivalue.tuple?
500
+ "Tuple"
501
+ elsif ivalue.future?
502
+ "Future"
503
+ elsif ivalue.r_ref?
504
+ "RRef"
505
+ elsif ivalue.int_list?
506
+ "IntList"
507
+ elsif ivalue.double_list?
508
+ "DoubleList"
509
+ elsif ivalue.bool_list?
510
+ "BoolList"
511
+ elsif ivalue.tensor_list?
512
+ "TensorList"
513
+ elsif ivalue.list?
514
+ "List"
515
+ elsif ivalue.object?
516
+ "Object"
517
+ elsif ivalue.module?
518
+ "Module"
519
+ elsif ivalue.py_object?
520
+ "PyObject"
521
+ elsif ivalue.scalar?
522
+ "Scalar"
523
+ elsif ivalue.device?
524
+ "Device"
525
+ # elsif ivalue.generator?
526
+ # "Generator"
527
+ elsif ivalue.ptr_type?
528
+ "PtrType"
529
+ else
530
+ "Unknown"
531
+ end
532
+
533
+ raise Error, "Unsupported type: #{type}"
534
+ end
535
+ end
536
+
450
537
  def tensor_size(size)
451
538
  size.flatten
452
539
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-23 00:00:00.000000000 Z
11
+ date: 2020-04-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -120,17 +120,10 @@ files:
120
120
  - README.md
121
121
  - ext/torch/ext.cpp
122
122
  - ext/torch/extconf.rb
123
- - ext/torch/nn_functions.cpp
124
- - ext/torch/nn_functions.hpp
125
123
  - ext/torch/templates.cpp
126
124
  - ext/torch/templates.hpp
127
- - ext/torch/tensor_functions.cpp
128
- - ext/torch/tensor_functions.hpp
129
- - ext/torch/torch_functions.cpp
130
- - ext/torch/torch_functions.hpp
131
125
  - lib/torch-rb.rb
132
126
  - lib/torch.rb
133
- - lib/torch/ext.bundle
134
127
  - lib/torch/hub.rb
135
128
  - lib/torch/inspector.rb
136
129
  - lib/torch/native/dispatcher.rb