torch-rb 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9179470135a00453dcae9efbc6cd143112c5fec925bb5675a686e80e70b71b28
4
- data.tar.gz: de365f50021d75338a78bcb6e0733bb430759fe7c4fcaea96b1ff0ed2a4b8d5d
3
+ metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
4
+ data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
5
5
  SHA512:
6
- metadata.gz: 46d3d49aa63c0764d20178f450aa0b88c88938c194e70040650e6c1a29899e5f4d896671571730dc23eb1fe039ede53d2714db0f6fe7506ad4382653a5e6ec18
7
- data.tar.gz: 3fe47be264030fc2d84de85bb7d006337df37fb5c41b147332fe37dd21de7ba61bdc53a0f7ae9085e01c992a643b22d8216a0248b7a4d24487d88fd7f88a9ecf
6
+ metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
7
+ data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.2.1 (2020-04-26)
2
+
3
+ - Added support for saving and loading models
4
+ - Improved error messages
5
+ - Reduced gem size
6
+
1
7
  ## 0.2.0 (2020-04-22)
2
8
 
3
9
  - No longer experimental
data/README.md CHANGED
@@ -283,6 +283,22 @@ loss.backward
283
283
  optimizer.step
284
284
  ```
285
285
 
286
+ ### Saving and Loading Models
287
+
288
+ Save a model
289
+
290
+ ```ruby
291
+ Torch.save(net.state_dict, "net.pth")
292
+ ```
293
+
294
+ Load a model
295
+
296
+ ```ruby
297
+ net = Net.new
298
+ net.load_state_dict(Torch.load("net.pth"))
299
+ net.eval
300
+ ```
301
+
286
302
  ### Tensor Creation
287
303
 
288
304
  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -445,6 +461,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
445
461
  bundle exec rake test
446
462
  ```
447
463
 
464
+ You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
465
+
448
466
  Here are some good resources for contributors:
449
467
 
450
468
  - [PyTorch API](https://pytorch.org/docs/stable/torch.html)
data/ext/torch/ext.cpp CHANGED
@@ -5,6 +5,7 @@
5
5
  #include <rice/Array.hpp>
6
6
  #include <rice/Class.hpp>
7
7
  #include <rice/Constructor.hpp>
8
+ #include <rice/Hash.hpp>
8
9
 
9
10
  #include "templates.hpp"
10
11
 
@@ -22,6 +23,11 @@ class Parameter: public torch::autograd::Variable {
22
23
  Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
23
24
  };
24
25
 
26
+ void handle_error(c10::Error const & ex)
27
+ {
28
+ throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
+ }
30
+
25
31
  extern "C"
26
32
  void Init_ext()
27
33
  {
@@ -34,6 +40,108 @@ void Init_ext()
34
40
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
35
41
  add_nn_functions(rb_mNN);
36
42
 
43
+ // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
44
+ Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
45
+ .define_constructor(Constructor<torch::IValue>())
46
+ .define_method("bool?", &torch::IValue::isBool)
47
+ .define_method("bool_list?", &torch::IValue::isBoolList)
48
+ .define_method("capsule?", &torch::IValue::isCapsule)
49
+ .define_method("custom_class?", &torch::IValue::isCustomClass)
50
+ .define_method("device?", &torch::IValue::isDevice)
51
+ .define_method("double?", &torch::IValue::isDouble)
52
+ .define_method("double_list?", &torch::IValue::isDoubleList)
53
+ .define_method("future?", &torch::IValue::isFuture)
54
+ // .define_method("generator?", &torch::IValue::isGenerator)
55
+ .define_method("generic_dict?", &torch::IValue::isGenericDict)
56
+ .define_method("list?", &torch::IValue::isList)
57
+ .define_method("int?", &torch::IValue::isInt)
58
+ .define_method("int_list?", &torch::IValue::isIntList)
59
+ .define_method("module?", &torch::IValue::isModule)
60
+ .define_method("none?", &torch::IValue::isNone)
61
+ .define_method("object?", &torch::IValue::isObject)
62
+ .define_method("ptr_type?", &torch::IValue::isPtrType)
63
+ .define_method("py_object?", &torch::IValue::isPyObject)
64
+ .define_method("r_ref?", &torch::IValue::isRRef)
65
+ .define_method("scalar?", &torch::IValue::isScalar)
66
+ .define_method("string?", &torch::IValue::isString)
67
+ .define_method("tensor?", &torch::IValue::isTensor)
68
+ .define_method("tensor_list?", &torch::IValue::isTensorList)
69
+ .define_method("tuple?", &torch::IValue::isTuple)
70
+ .define_method(
71
+ "to_bool",
72
+ *[](torch::IValue& self) {
73
+ return self.toBool();
74
+ })
75
+ .define_method(
76
+ "to_double",
77
+ *[](torch::IValue& self) {
78
+ return self.toDouble();
79
+ })
80
+ .define_method(
81
+ "to_int",
82
+ *[](torch::IValue& self) {
83
+ return self.toInt();
84
+ })
85
+ .define_method(
86
+ "to_string_ref",
87
+ *[](torch::IValue& self) {
88
+ return self.toStringRef();
89
+ })
90
+ .define_method(
91
+ "to_tensor",
92
+ *[](torch::IValue& self) {
93
+ return self.toTensor();
94
+ })
95
+ .define_method(
96
+ "to_generic_dict",
97
+ *[](torch::IValue& self) {
98
+ auto dict = self.toGenericDict();
99
+ Hash h;
100
+ for (auto& pair : dict) {
101
+ h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
102
+ }
103
+ return h;
104
+ })
105
+ .define_singleton_method(
106
+ "from_tensor",
107
+ *[](torch::Tensor& v) {
108
+ return torch::IValue(v);
109
+ })
110
+ .define_singleton_method(
111
+ "from_string",
112
+ *[](String v) {
113
+ return torch::IValue(v.str());
114
+ })
115
+ .define_singleton_method(
116
+ "from_int",
117
+ *[](int64_t v) {
118
+ return torch::IValue(v);
119
+ })
120
+ .define_singleton_method(
121
+ "from_double",
122
+ *[](double v) {
123
+ return torch::IValue(v);
124
+ })
125
+ .define_singleton_method(
126
+ "from_bool",
127
+ *[](bool v) {
128
+ return torch::IValue(v);
129
+ })
130
+ // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
131
+ // createGenericDict and toIValue
132
+ .define_singleton_method(
133
+ "from_dict",
134
+ *[](Hash obj) {
135
+ auto key_type = c10::AnyType::get();
136
+ auto value_type = c10::AnyType::get();
137
+ c10::impl::GenericDict elems(key_type, value_type);
138
+ elems.reserve(obj.size());
139
+ for (auto entry : obj) {
140
+ elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
141
+ }
142
+ return torch::IValue(std::move(elems));
143
+ });
144
+
37
145
  rb_mTorch.define_singleton_method(
38
146
  "grad_enabled?",
39
147
  *[]() {
@@ -113,11 +221,19 @@ void Init_ext()
113
221
  // begin operations
114
222
  .define_singleton_method(
115
223
  "_save",
116
- *[](const Tensor &value) {
224
+ *[](const torch::IValue &value) {
117
225
  auto v = torch::pickle_save(value);
118
226
  std::string str(v.begin(), v.end());
119
227
  return str;
120
228
  })
229
+ .define_singleton_method(
230
+ "_load",
231
+ *[](const std::string &s) {
232
+ std::vector<char> v;
233
+ std::copy(s.begin(), s.end(), std::back_inserter(v));
234
+ // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
235
+ return torch::pickle_load(v);
236
+ })
121
237
  .define_singleton_method(
122
238
  "_binary_cross_entropy_with_logits",
123
239
  *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
@@ -157,6 +273,7 @@ void Init_ext()
157
273
  });
158
274
 
159
275
  rb_cTensor
276
+ .add_handler<c10::Error>(handle_error)
160
277
  .define_method("cuda?", &torch::Tensor::is_cuda)
161
278
  .define_method("sparse?", &torch::Tensor::is_sparse)
162
279
  .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -288,6 +405,7 @@ void Init_ext()
288
405
  });
289
406
 
290
407
  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
408
+ .add_handler<c10::Error>(handle_error)
291
409
  .define_constructor(Constructor<torch::TensorOptions>())
292
410
  .define_method(
293
411
  "dtype",
@@ -311,13 +429,8 @@ void Init_ext()
311
429
  .define_method(
312
430
  "device",
313
431
  *[](torch::TensorOptions& self, std::string device) {
314
- try {
315
- // needed to catch exception
316
- torch::Device d(device);
317
- return self.device(d);
318
- } catch (const c10::Error& error) {
319
- throw std::runtime_error(error.what_without_backtrace());
320
- }
432
+ torch::Device d(device);
433
+ return self.device(d);
321
434
  })
322
435
  .define_method(
323
436
  "requires_grad",
data/ext/torch/extconf.rb CHANGED
@@ -10,19 +10,24 @@ $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
+ # check omp first
13
14
  if have_library("omp") || have_library("gomp")
14
15
  $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
15
16
  $CXXFLAGS << " -Xclang" if clang
16
17
  $CXXFLAGS << " -fopenmp"
17
18
  end
18
19
 
19
- # silence ruby/intern.h warning
20
- $CXXFLAGS << " -Wno-deprecated-register"
21
-
22
- # silence torch warnings
23
20
  if clang
21
+ # silence ruby/intern.h warning
22
+ $CXXFLAGS << " -Wno-deprecated-register"
23
+
24
+ # silence torch warnings
24
25
  $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
26
  else
27
+ # silence rice warnings
28
+ $CXXFLAGS << " -Wno-noexcept-type"
29
+
30
+ # silence torch warnings
26
31
  $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
27
32
  end
28
33
 
@@ -34,15 +39,20 @@ cuda_inc, cuda_lib = dir_config("cuda")
34
39
  cuda_inc ||= "/usr/local/cuda/include"
35
40
  cuda_lib ||= "/usr/local/cuda/lib64"
36
41
 
37
- with_cuda = Dir["#{lib}/*torch_cuda*"].any? && have_library("cuda") && have_library("cudnn")
42
+ $LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
43
+ abort "LibTorch not found" unless have_library("torch")
44
+
45
+ with_cuda = false
46
+ if Dir["#{lib}/*torch_cuda*"].any?
47
+ $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
48
+ with_cuda = have_library("cuda") && have_library("cudnn")
49
+ end
38
50
 
39
51
  $INCFLAGS << " -I#{inc}"
40
52
  $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
41
53
 
42
54
  $LDFLAGS << " -Wl,-rpath,#{lib}"
43
55
  $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
44
- $LDFLAGS << " -L#{lib}"
45
- $LDFLAGS << " -L#{cuda_lib}" if with_cuda
46
56
 
47
57
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
48
58
  $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
@@ -67,8 +67,9 @@ module Torch
67
67
  self
68
68
  end
69
69
 
70
# Passes a lambda that calls Tensor#cuda to _apply (presumably run over this
# module's parameters/buffers — see _apply).
# TODO add device
def cuda
  to_cuda = ->(tensor) { tensor.cuda }
  _apply(to_cuda)
end
73
74
 
74
75
  def cpu
@@ -112,8 +113,28 @@ module Torch
112
113
  destination
113
114
  end
114
115
 
116
# Copies each tensor in state_dict into the matching parameter of this module
# (via Parameter#copy! inside Torch.no_grad, so the copy is not recorded for
# autograd).
#
# Each key is split at its first "." into a submodule name, looked up in
# named_modules, and a parameter name, looked up in that submodule's
# named_parameters. Raises Error if either lookup fails:
#   - a missing module is reported by its name;
#   - a missing parameter is reported by the full key.
# Returns nil.
#
# TODO add strict option
# TODO match PyTorch behavior
def load_state_dict(state_dict)
  # Copying values does not change the module tree, so build the lookup once
  # instead of once per key.
  modules = named_modules

  state_dict.each do |k, input_param|
    k1, k2 = k.split(".", 2)

    mod = modules[k1]
    raise Error, "Unknown module: #{k1}" unless mod.is_a?(Module)

    param = mod.named_parameters[k2]
    # Report the full key: k1 only names the (existing) module, not the
    # parameter that could not be found.
    raise Error, "Unknown parameter: #{k}" unless param.is_a?(Parameter)

    Torch.no_grad do
      param.copy!(input_param)
    end
  end

  # TODO return missing keys and unexpected keys
  nil
end
118
139
 
119
140
  def parameters
data/lib/torch/tensor.rb CHANGED
@@ -37,6 +37,10 @@ module Torch
37
37
  to("cpu")
38
38
  end
39
39
 
40
# Delegates to #to with the "cuda" device string, mirroring how #cpu uses
# to("cpu").
def cuda
  device = "cuda"
  to(device)
end
43
+
40
44
  def size(dim = nil)
41
45
  if dim
42
46
  _size_int(dim)
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
module Torch
  # Gem version for this release (bumped from 0.2.0).
  VERSION = "0.2.1"
end
data/lib/torch.rb CHANGED
@@ -317,12 +317,11 @@ module Torch
317
317
  end
318
318
 
319
319
# Converts obj to an IValue (see to_ivalue for supported types), serializes it
# with _save (torch::pickle_save in the extension) and writes the bytes to f.
def save(obj, f)
  ivalue = to_ivalue(obj)
  File.binwrite(f, _save(ivalue))
end

# Reads the bytes at f, deserializes them with _load (torch::pickle_load in
# the extension) and converts the resulting IValue back to Ruby (see to_ruby).
def load(f)
  bytes = File.binread(f)
  to_ruby(_load(bytes))
end
327
326
 
328
327
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
@@ -447,6 +446,94 @@ module Torch
447
446
 
448
447
  private
449
448
 
449
# Recursively converts a Ruby value into a Torch::IValue for serialization.
#
# Handled types:
#   - String, Integer, Tensor, Float
#   - Hash (keys and values converted recursively)
#   - true/false
#   - nil (a default-constructed IValue)
# Any other class raises Error.
def to_ivalue(obj)
  case obj
  when String then IValue.from_string(obj)
  when Integer then IValue.from_int(obj)
  when Tensor then IValue.from_tensor(obj)
  when Float then IValue.from_double(obj)
  when Hash
    converted = obj.each_with_object({}) do |(key, value), memo|
      memo[to_ivalue(key)] = to_ivalue(value)
    end
    IValue.from_dict(converted)
  when TrueClass, FalseClass then IValue.from_bool(obj)
  when NilClass then IValue.new
  else
    raise Error, "Unknown type: #{obj.class.name}"
  end
end
473
+
474
# Inverse of to_ivalue: converts an IValue back into Ruby objects.
#
# Supported tags:
#   - bool, double, int, string, tensor
#   - none (becomes nil)
#   - generic dict (keys and values converted recursively)
# Any other tag raises Error naming the first matching IValue type predicate,
# or "Unknown" if none match.
def to_ruby(ivalue)
  return ivalue.to_bool if ivalue.bool?
  return ivalue.to_double if ivalue.double?
  return ivalue.to_int if ivalue.int?
  return nil if ivalue.none?
  return ivalue.to_string_ref if ivalue.string?
  return ivalue.to_tensor if ivalue.tensor?

  if ivalue.generic_dict?
    result = {}
    ivalue.to_generic_dict.each { |key, value| result[to_ruby(key)] = to_ruby(value) }
    return result
  end

  # Probed in this order; the first predicate that matches names the type in
  # the error message.
  unsupported_types = [
    [:capsule?, "Capsule"],
    [:custom_class?, "CustomClass"],
    [:tuple?, "Tuple"],
    [:future?, "Future"],
    [:r_ref?, "RRef"],
    [:int_list?, "IntList"],
    [:double_list?, "DoubleList"],
    [:bool_list?, "BoolList"],
    [:tensor_list?, "TensorList"],
    [:list?, "List"],
    [:object?, "Object"],
    [:module?, "Module"],
    [:py_object?, "PyObject"],
    [:scalar?, "Scalar"],
    [:device?, "Device"],
    # [:generator?, "Generator"],
    [:ptr_type?, "PtrType"]
  ]
  match = unsupported_types.find { |predicate, _name| ivalue.public_send(predicate) }
  type = match ? match[1] : "Unknown"

  raise Error, "Unsupported type: #{type}"
end
536
+
450
537
  def tensor_size(size)
451
538
  size.flatten
452
539
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-23 00:00:00.000000000 Z
11
+ date: 2020-04-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -120,17 +120,10 @@ files:
120
120
  - README.md
121
121
  - ext/torch/ext.cpp
122
122
  - ext/torch/extconf.rb
123
- - ext/torch/nn_functions.cpp
124
- - ext/torch/nn_functions.hpp
125
123
  - ext/torch/templates.cpp
126
124
  - ext/torch/templates.hpp
127
- - ext/torch/tensor_functions.cpp
128
- - ext/torch/tensor_functions.hpp
129
- - ext/torch/torch_functions.cpp
130
- - ext/torch/torch_functions.hpp
131
125
  - lib/torch-rb.rb
132
126
  - lib/torch.rb
133
- - lib/torch/ext.bundle
134
127
  - lib/torch/hub.rb
135
128
  - lib/torch/inspector.rb
136
129
  - lib/torch/native/dispatcher.rb