torch-rb 0.2.0 → 0.2.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -0
- data/ext/torch/ext.cpp +121 -8
- data/ext/torch/extconf.rb +17 -7
- data/lib/torch/nn/module.rb +24 -3
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +90 -3
- metadata +2 -9
- data/ext/torch/nn_functions.cpp +0 -560
- data/ext/torch/nn_functions.hpp +0 -6
- data/ext/torch/tensor_functions.cpp +0 -2085
- data/ext/torch/tensor_functions.hpp +0 -6
- data/ext/torch/torch_functions.cpp +0 -3175
- data/ext/torch/torch_functions.hpp +0 -6
- data/lib/torch/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
+  data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
+  data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -283,6 +283,22 @@ loss.backward
 optimizer.step
 ```
 
+### Saving and Loading Models
+
+Save a model
+
+```ruby
+Torch.save(net.state_dict, "net.pth")
+```
+
+Load a model
+
+```ruby
+net = Net.new
+net.load_state_dict(Torch.load("net.pth"))
+net.eval
+```
+
 ### Tensor Creation
 
 Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -445,6 +461,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
 bundle exec rake test
 ```
 
+You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
+
 Here are some good resources for contributors:
 
 - [PyTorch API](https://pytorch.org/docs/stable/torch.html)
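A note on the new README section above: the object returned by `Torch.load` is an ordinary Ruby hash keyed by dotted parameter names, so it can be inspected before being handed to `load_state_dict`. A minimal sketch (the key names depend on how `Net` defines its layers and are only illustrative):

```ruby
state_dict = Torch.load("net.pth")
state_dict.keys                 # e.g. ["fc1.weight", "fc1.bias", ...] for a model with an fc1 layer
state_dict["fc1.weight"].shape  # each value is a Torch::Tensor

net = Net.new
net.load_state_dict(state_dict)
net.eval
```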
data/ext/torch/ext.cpp
CHANGED
@@ -5,6 +5,7 @@
 #include <rice/Array.hpp>
 #include <rice/Class.hpp>
 #include <rice/Constructor.hpp>
+#include <rice/Hash.hpp>
 
 #include "templates.hpp"
 
@@ -22,6 +23,11 @@ class Parameter: public torch::autograd::Variable {
   Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };
 
+void handle_error(c10::Error const & ex)
+{
+  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
+}
+
 extern "C"
 void Init_ext()
 {
@@ -34,6 +40,108 @@ void Init_ext()
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
   add_nn_functions(rb_mNN);
 
+  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
+  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+    .define_constructor(Constructor<torch::IValue>())
+    .define_method("bool?", &torch::IValue::isBool)
+    .define_method("bool_list?", &torch::IValue::isBoolList)
+    .define_method("capsule?", &torch::IValue::isCapsule)
+    .define_method("custom_class?", &torch::IValue::isCustomClass)
+    .define_method("device?", &torch::IValue::isDevice)
+    .define_method("double?", &torch::IValue::isDouble)
+    .define_method("double_list?", &torch::IValue::isDoubleList)
+    .define_method("future?", &torch::IValue::isFuture)
+    // .define_method("generator?", &torch::IValue::isGenerator)
+    .define_method("generic_dict?", &torch::IValue::isGenericDict)
+    .define_method("list?", &torch::IValue::isList)
+    .define_method("int?", &torch::IValue::isInt)
+    .define_method("int_list?", &torch::IValue::isIntList)
+    .define_method("module?", &torch::IValue::isModule)
+    .define_method("none?", &torch::IValue::isNone)
+    .define_method("object?", &torch::IValue::isObject)
+    .define_method("ptr_type?", &torch::IValue::isPtrType)
+    .define_method("py_object?", &torch::IValue::isPyObject)
+    .define_method("r_ref?", &torch::IValue::isRRef)
+    .define_method("scalar?", &torch::IValue::isScalar)
+    .define_method("string?", &torch::IValue::isString)
+    .define_method("tensor?", &torch::IValue::isTensor)
+    .define_method("tensor_list?", &torch::IValue::isTensorList)
+    .define_method("tuple?", &torch::IValue::isTuple)
+    .define_method(
+      "to_bool",
+      *[](torch::IValue& self) {
+        return self.toBool();
+      })
+    .define_method(
+      "to_double",
+      *[](torch::IValue& self) {
+        return self.toDouble();
+      })
+    .define_method(
+      "to_int",
+      *[](torch::IValue& self) {
+        return self.toInt();
+      })
+    .define_method(
+      "to_string_ref",
+      *[](torch::IValue& self) {
+        return self.toStringRef();
+      })
+    .define_method(
+      "to_tensor",
+      *[](torch::IValue& self) {
+        return self.toTensor();
+      })
+    .define_method(
+      "to_generic_dict",
+      *[](torch::IValue& self) {
+        auto dict = self.toGenericDict();
+        Hash h;
+        for (auto& pair : dict) {
+          h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
+        }
+        return h;
+      })
+    .define_singleton_method(
+      "from_tensor",
+      *[](torch::Tensor& v) {
+        return torch::IValue(v);
+      })
+    .define_singleton_method(
+      "from_string",
+      *[](String v) {
+        return torch::IValue(v.str());
+      })
+    .define_singleton_method(
+      "from_int",
+      *[](int64_t v) {
+        return torch::IValue(v);
+      })
+    .define_singleton_method(
+      "from_double",
+      *[](double v) {
+        return torch::IValue(v);
+      })
+    .define_singleton_method(
+      "from_bool",
+      *[](bool v) {
+        return torch::IValue(v);
+      })
+    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
+    // createGenericDict and toIValue
+    .define_singleton_method(
+      "from_dict",
+      *[](Hash obj) {
+        auto key_type = c10::AnyType::get();
+        auto value_type = c10::AnyType::get();
+        c10::impl::GenericDict elems(key_type, value_type);
+        elems.reserve(obj.size());
+        for (auto entry : obj) {
+          elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
+        }
+        return torch::IValue(std::move(elems));
+      });
+
   rb_mTorch.define_singleton_method(
     "grad_enabled?",
     *[]() {
@@ -113,11 +221,19 @@ void Init_ext()
     // begin operations
     .define_singleton_method(
       "_save",
-      *[](const
+      *[](const torch::IValue &value) {
         auto v = torch::pickle_save(value);
         std::string str(v.begin(), v.end());
         return str;
       })
+    .define_singleton_method(
+      "_load",
+      *[](const std::string &s) {
+        std::vector<char> v;
+        std::copy(s.begin(), s.end(), std::back_inserter(v));
+        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+        return torch::pickle_load(v);
+      })
     .define_singleton_method(
       "_binary_cross_entropy_with_logits",
       *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
@@ -157,6 +273,7 @@ void Init_ext()
     });
 
   rb_cTensor
+    .add_handler<c10::Error>(handle_error)
     .define_method("cuda?", &torch::Tensor::is_cuda)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -288,6 +405,7 @@ void Init_ext()
     });
 
   Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
+    .add_handler<c10::Error>(handle_error)
     .define_constructor(Constructor<torch::TensorOptions>())
     .define_method(
       "dtype",
@@ -311,13 +429,8 @@ void Init_ext()
     .define_method(
       "device",
       *[](torch::TensorOptions& self, std::string device) {
-
-
-        torch::Device d(device);
-        return self.device(d);
-      } catch (const c10::Error& error) {
-        throw std::runtime_error(error.what_without_backtrace());
-      }
+        torch::Device d(device);
+        return self.device(d);
       })
     .define_method(
       "requires_grad",
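The last three hunks replace the per-method `try`/`catch` in `TensorOptions#device` with a Rice error handler registered on both `Tensor` and `TensorOptions`, so any `c10::Error` coming out of LibTorch now surfaces as a Ruby `RuntimeError` carrying the message without the C++ backtrace. A small sketch of the effect:

```ruby
options = Torch::TensorOptions.new

begin
  # an invalid device string makes LibTorch raise c10::Error,
  # which handle_error re-raises as a Ruby RuntimeError
  options.device("cuda:oops")
rescue RuntimeError => e
  puts e.message  # LibTorch's message, minus the C++ backtrace
end
```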
data/ext/torch/extconf.rb
CHANGED
@@ -10,19 +10,24 @@ $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
 # TODO check compiler name
 clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
 
+# check omp first
 if have_library("omp") || have_library("gomp")
   $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
   $CXXFLAGS << " -Xclang" if clang
   $CXXFLAGS << " -fopenmp"
 end
 
-# silence ruby/intern.h warning
-$CXXFLAGS << " -Wno-deprecated-register"
-
-# silence torch warnings
 if clang
+  # silence ruby/intern.h warning
+  $CXXFLAGS << " -Wno-deprecated-register"
+
+  # silence torch warnings
   $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
 else
+  # silence rice warnings
+  $CXXFLAGS << " -Wno-noexcept-type"
+
+  # silence torch warnings
   $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
 end
 
@@ -34,15 +39,20 @@ cuda_inc, cuda_lib = dir_config("cuda")
 cuda_inc ||= "/usr/local/cuda/include"
 cuda_lib ||= "/usr/local/cuda/lib64"
 
-
+$LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
+abort "LibTorch not found" unless have_library("torch")
+
+with_cuda = false
+if Dir["#{lib}/*torch_cuda*"].any?
+  $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
+  with_cuda = have_library("cuda") && have_library("cudnn")
+end
 
 $INCFLAGS << " -I#{inc}"
 $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
 
 $LDFLAGS << " -Wl,-rpath,#{lib}"
 $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
-$LDFLAGS << " -L#{lib}"
-$LDFLAGS << " -L#{cuda_lib}" if with_cuda
 
 # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
 $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
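With the new detection logic, CUDA is linked only when `*torch_cuda*` libraries are present in the LibTorch `lib` directory and both `cuda` and `cudnn` can be found. One way to confirm what the build actually produced is to check availability at runtime; this assumes the gem exposes `Torch::CUDA.available?`, mirroring PyTorch's `torch.cuda.is_available`:

```ruby
require "torch"

# true only if the extension was linked against CUDA during the build
# above and a usable GPU is visible at runtime
puts Torch::CUDA.available?
```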
data/lib/torch/nn/module.rb
CHANGED
@@ -67,8 +67,9 @@ module Torch
       self
     end
 
-
-
+    # TODO add device
+    def cuda
+      _apply ->(t) { t.cuda }
     end
 
     def cpu
@@ -112,8 +113,28 @@ module Torch
       destination
     end
 
+    # TODO add strict option
+    # TODO match PyTorch behavior
     def load_state_dict(state_dict)
-
+      state_dict.each do |k, input_param|
+        k1, k2 = k.split(".", 2)
+        mod = named_modules[k1]
+        if mod.is_a?(Module)
+          param = mod.named_parameters[k2]
+          if param.is_a?(Parameter)
+            Torch.no_grad do
+              param.copy!(input_param)
+            end
+          else
+            raise Error, "Unknown parameter: #{k1}"
+          end
+        else
+          raise Error, "Unknown module: #{k1}"
+        end
+      end
+
+      # TODO return missing keys and unexpected keys
+      nil
     end
 
     def parameters
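`load_state_dict` above splits each key on the first `.`, so a key like `"fc1.weight"` is resolved as submodule `fc1` and parameter `weight`, with each tensor copied in under `Torch.no_grad`. A minimal sketch with a hypothetical one-layer model (class and layer names are illustrative only):

```ruby
# hypothetical model used only to show the key format load_state_dict expects
class TinyNet < Torch::NN::Module
  def initialize
    super
    @fc1 = Torch::NN::Linear.new(4, 2)
  end

  def forward(x)
    @fc1.call(x)
  end
end

net = TinyNet.new
net.state_dict.keys
# => ["fc1.weight", "fc1.bias"]  (submodule name, then parameter name)

other = TinyNet.new
other.load_state_dict(net.state_dict)  # copies each tensor without tracking gradients
```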
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -317,12 +317,11 @@ module Torch
     end
 
     def save(obj, f)
-
-      File.binwrite(f, _save(obj))
+      File.binwrite(f, _save(to_ivalue(obj)))
     end
 
     def load(f)
-
+      to_ruby(_load(File.binread(f)))
     end
 
     # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
@@ -447,6 +446,94 @@ module Torch
 
     private
 
+    def to_ivalue(obj)
+      case obj
+      when String
+        IValue.from_string(obj)
+      when Integer
+        IValue.from_int(obj)
+      when Tensor
+        IValue.from_tensor(obj)
+      when Float
+        IValue.from_double(obj)
+      when Hash
+        dict = {}
+        obj.each do |k, v|
+          dict[to_ivalue(k)] = to_ivalue(v)
+        end
+        IValue.from_dict(dict)
+      when true, false
+        IValue.from_bool(obj)
+      when nil
+        IValue.new
+      else
+        raise Error, "Unknown type: #{obj.class.name}"
+      end
+    end
+
+    def to_ruby(ivalue)
+      if ivalue.bool?
+        ivalue.to_bool
+      elsif ivalue.double?
+        ivalue.to_double
+      elsif ivalue.int?
+        ivalue.to_int
+      elsif ivalue.none?
+        nil
+      elsif ivalue.string?
+        ivalue.to_string_ref
+      elsif ivalue.tensor?
+        ivalue.to_tensor
+      elsif ivalue.generic_dict?
+        dict = {}
+        ivalue.to_generic_dict.each do |k, v|
+          dict[to_ruby(k)] = to_ruby(v)
+        end
+        dict
+      else
+        type =
+          if ivalue.capsule?
+            "Capsule"
+          elsif ivalue.custom_class?
+            "CustomClass"
+          elsif ivalue.tuple?
+            "Tuple"
+          elsif ivalue.future?
+            "Future"
+          elsif ivalue.r_ref?
+            "RRef"
+          elsif ivalue.int_list?
+            "IntList"
+          elsif ivalue.double_list?
+            "DoubleList"
+          elsif ivalue.bool_list?
+            "BoolList"
+          elsif ivalue.tensor_list?
+            "TensorList"
+          elsif ivalue.list?
+            "List"
+          elsif ivalue.object?
+            "Object"
+          elsif ivalue.module?
+            "Module"
+          elsif ivalue.py_object?
+            "PyObject"
+          elsif ivalue.scalar?
+            "Scalar"
+          elsif ivalue.device?
+            "Device"
+          # elsif ivalue.generator?
+          #   "Generator"
+          elsif ivalue.ptr_type?
+            "PtrType"
+          else
+            "Unknown"
+          end
+
+        raise Error, "Unsupported type: #{type}"
+      end
+    end
+
     def tensor_size(size)
       size.flatten
     end
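Taken together, `to_ivalue` and `to_ruby` define what `Torch.save` and `Torch.load` can round-trip in 0.2.1: strings, integers, floats, booleans, `nil`, tensors, and hashes of those (anything else raises `Torch::Error`). A short sketch, with an arbitrary file name:

```ruby
checkpoint = {
  "epoch"   => 5,
  "lr"      => 0.001,
  "done"    => false,
  "weights" => Torch.rand(2, 2)
}

Torch.save(checkpoint, "checkpoint.pth")

restored = Torch.load("checkpoint.pth")
restored["epoch"]          # => 5
restored["weights"].shape  # => [2, 2]

# unsupported objects raise Torch::Error, e.g. Torch.save([1, 2, 3], "a.pth")
# fails with "Unknown type: Array"
```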
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.0
+  version: 0.2.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-04-
+date: 2020-04-27 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -120,17 +120,10 @@ files:
 - README.md
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
-- ext/torch/nn_functions.cpp
-- ext/torch/nn_functions.hpp
 - ext/torch/templates.cpp
 - ext/torch/templates.hpp
-- ext/torch/tensor_functions.cpp
-- ext/torch/tensor_functions.hpp
-- ext/torch/torch_functions.cpp
-- ext/torch/torch_functions.hpp
 - lib/torch-rb.rb
 - lib/torch.rb
-- lib/torch/ext.bundle
 - lib/torch/hub.rb
 - lib/torch/inspector.rb
 - lib/torch/native/dispatcher.rb