torch-rb 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -0
- data/ext/torch/ext.cpp +121 -8
- data/ext/torch/extconf.rb +17 -7
- data/lib/torch/nn/module.rb +24 -3
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +90 -3
- metadata +2 -9
- data/ext/torch/nn_functions.cpp +0 -560
- data/ext/torch/nn_functions.hpp +0 -6
- data/ext/torch/tensor_functions.cpp +0 -2085
- data/ext/torch/tensor_functions.hpp +0 -6
- data/ext/torch/torch_functions.cpp +0 -3175
- data/ext/torch/torch_functions.hpp +0 -6
- data/lib/torch/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
|
4
|
+
data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
|
7
|
+
data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -283,6 +283,22 @@ loss.backward
|
|
283
283
|
optimizer.step
|
284
284
|
```
|
285
285
|
|
286
|
+
### Saving and Loading Models
|
287
|
+
|
288
|
+
Save a model
|
289
|
+
|
290
|
+
```ruby
|
291
|
+
Torch.save(net.state_dict, "net.pth")
|
292
|
+
```
|
293
|
+
|
294
|
+
Load a model
|
295
|
+
|
296
|
+
```ruby
|
297
|
+
net = Net.new
|
298
|
+
net.load_state_dict(Torch.load("net.pth"))
|
299
|
+
net.eval
|
300
|
+
```
|
301
|
+
|
286
302
|
### Tensor Creation
|
287
303
|
|
288
304
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -445,6 +461,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
|
445
461
|
bundle exec rake test
|
446
462
|
```
|
447
463
|
|
464
|
+
You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
|
465
|
+
|
448
466
|
Here are some good resources for contributors:
|
449
467
|
|
450
468
|
- [PyTorch API](https://pytorch.org/docs/stable/torch.html)
|
data/ext/torch/ext.cpp
CHANGED
@@ -5,6 +5,7 @@
|
|
5
5
|
#include <rice/Array.hpp>
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
|
+
#include <rice/Hash.hpp>
|
8
9
|
|
9
10
|
#include "templates.hpp"
|
10
11
|
|
@@ -22,6 +23,11 @@ class Parameter: public torch::autograd::Variable {
|
|
22
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
23
24
|
};
|
24
25
|
|
26
|
+
void handle_error(c10::Error const & ex)
|
27
|
+
{
|
28
|
+
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
|
+
}
|
30
|
+
|
25
31
|
extern "C"
|
26
32
|
void Init_ext()
|
27
33
|
{
|
@@ -34,6 +40,108 @@ void Init_ext()
|
|
34
40
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
35
41
|
add_nn_functions(rb_mNN);
|
36
42
|
|
43
|
+
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
44
|
+
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
45
|
+
.define_constructor(Constructor<torch::IValue>())
|
46
|
+
.define_method("bool?", &torch::IValue::isBool)
|
47
|
+
.define_method("bool_list?", &torch::IValue::isBoolList)
|
48
|
+
.define_method("capsule?", &torch::IValue::isCapsule)
|
49
|
+
.define_method("custom_class?", &torch::IValue::isCustomClass)
|
50
|
+
.define_method("device?", &torch::IValue::isDevice)
|
51
|
+
.define_method("double?", &torch::IValue::isDouble)
|
52
|
+
.define_method("double_list?", &torch::IValue::isDoubleList)
|
53
|
+
.define_method("future?", &torch::IValue::isFuture)
|
54
|
+
// .define_method("generator?", &torch::IValue::isGenerator)
|
55
|
+
.define_method("generic_dict?", &torch::IValue::isGenericDict)
|
56
|
+
.define_method("list?", &torch::IValue::isList)
|
57
|
+
.define_method("int?", &torch::IValue::isInt)
|
58
|
+
.define_method("int_list?", &torch::IValue::isIntList)
|
59
|
+
.define_method("module?", &torch::IValue::isModule)
|
60
|
+
.define_method("none?", &torch::IValue::isNone)
|
61
|
+
.define_method("object?", &torch::IValue::isObject)
|
62
|
+
.define_method("ptr_type?", &torch::IValue::isPtrType)
|
63
|
+
.define_method("py_object?", &torch::IValue::isPyObject)
|
64
|
+
.define_method("r_ref?", &torch::IValue::isRRef)
|
65
|
+
.define_method("scalar?", &torch::IValue::isScalar)
|
66
|
+
.define_method("string?", &torch::IValue::isString)
|
67
|
+
.define_method("tensor?", &torch::IValue::isTensor)
|
68
|
+
.define_method("tensor_list?", &torch::IValue::isTensorList)
|
69
|
+
.define_method("tuple?", &torch::IValue::isTuple)
|
70
|
+
.define_method(
|
71
|
+
"to_bool",
|
72
|
+
*[](torch::IValue& self) {
|
73
|
+
return self.toBool();
|
74
|
+
})
|
75
|
+
.define_method(
|
76
|
+
"to_double",
|
77
|
+
*[](torch::IValue& self) {
|
78
|
+
return self.toDouble();
|
79
|
+
})
|
80
|
+
.define_method(
|
81
|
+
"to_int",
|
82
|
+
*[](torch::IValue& self) {
|
83
|
+
return self.toInt();
|
84
|
+
})
|
85
|
+
.define_method(
|
86
|
+
"to_string_ref",
|
87
|
+
*[](torch::IValue& self) {
|
88
|
+
return self.toStringRef();
|
89
|
+
})
|
90
|
+
.define_method(
|
91
|
+
"to_tensor",
|
92
|
+
*[](torch::IValue& self) {
|
93
|
+
return self.toTensor();
|
94
|
+
})
|
95
|
+
.define_method(
|
96
|
+
"to_generic_dict",
|
97
|
+
*[](torch::IValue& self) {
|
98
|
+
auto dict = self.toGenericDict();
|
99
|
+
Hash h;
|
100
|
+
for (auto& pair : dict) {
|
101
|
+
h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
|
102
|
+
}
|
103
|
+
return h;
|
104
|
+
})
|
105
|
+
.define_singleton_method(
|
106
|
+
"from_tensor",
|
107
|
+
*[](torch::Tensor& v) {
|
108
|
+
return torch::IValue(v);
|
109
|
+
})
|
110
|
+
.define_singleton_method(
|
111
|
+
"from_string",
|
112
|
+
*[](String v) {
|
113
|
+
return torch::IValue(v.str());
|
114
|
+
})
|
115
|
+
.define_singleton_method(
|
116
|
+
"from_int",
|
117
|
+
*[](int64_t v) {
|
118
|
+
return torch::IValue(v);
|
119
|
+
})
|
120
|
+
.define_singleton_method(
|
121
|
+
"from_double",
|
122
|
+
*[](double v) {
|
123
|
+
return torch::IValue(v);
|
124
|
+
})
|
125
|
+
.define_singleton_method(
|
126
|
+
"from_bool",
|
127
|
+
*[](bool v) {
|
128
|
+
return torch::IValue(v);
|
129
|
+
})
|
130
|
+
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
131
|
+
// createGenericDict and toIValue
|
132
|
+
.define_singleton_method(
|
133
|
+
"from_dict",
|
134
|
+
*[](Hash obj) {
|
135
|
+
auto key_type = c10::AnyType::get();
|
136
|
+
auto value_type = c10::AnyType::get();
|
137
|
+
c10::impl::GenericDict elems(key_type, value_type);
|
138
|
+
elems.reserve(obj.size());
|
139
|
+
for (auto entry : obj) {
|
140
|
+
elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
|
141
|
+
}
|
142
|
+
return torch::IValue(std::move(elems));
|
143
|
+
});
|
144
|
+
|
37
145
|
rb_mTorch.define_singleton_method(
|
38
146
|
"grad_enabled?",
|
39
147
|
*[]() {
|
@@ -113,11 +221,19 @@ void Init_ext()
|
|
113
221
|
// begin operations
|
114
222
|
.define_singleton_method(
|
115
223
|
"_save",
|
116
|
-
*[](const
|
224
|
+
*[](const torch::IValue &value) {
|
117
225
|
auto v = torch::pickle_save(value);
|
118
226
|
std::string str(v.begin(), v.end());
|
119
227
|
return str;
|
120
228
|
})
|
229
|
+
.define_singleton_method(
|
230
|
+
"_load",
|
231
|
+
*[](const std::string &s) {
|
232
|
+
std::vector<char> v;
|
233
|
+
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
234
|
+
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
235
|
+
return torch::pickle_load(v);
|
236
|
+
})
|
121
237
|
.define_singleton_method(
|
122
238
|
"_binary_cross_entropy_with_logits",
|
123
239
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
@@ -157,6 +273,7 @@ void Init_ext()
|
|
157
273
|
});
|
158
274
|
|
159
275
|
rb_cTensor
|
276
|
+
.add_handler<c10::Error>(handle_error)
|
160
277
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
161
278
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
162
279
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
@@ -288,6 +405,7 @@ void Init_ext()
|
|
288
405
|
});
|
289
406
|
|
290
407
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
408
|
+
.add_handler<c10::Error>(handle_error)
|
291
409
|
.define_constructor(Constructor<torch::TensorOptions>())
|
292
410
|
.define_method(
|
293
411
|
"dtype",
|
@@ -311,13 +429,8 @@ void Init_ext()
|
|
311
429
|
.define_method(
|
312
430
|
"device",
|
313
431
|
*[](torch::TensorOptions& self, std::string device) {
|
314
|
-
|
315
|
-
|
316
|
-
torch::Device d(device);
|
317
|
-
return self.device(d);
|
318
|
-
} catch (const c10::Error& error) {
|
319
|
-
throw std::runtime_error(error.what_without_backtrace());
|
320
|
-
}
|
432
|
+
torch::Device d(device);
|
433
|
+
return self.device(d);
|
321
434
|
})
|
322
435
|
.define_method(
|
323
436
|
"requires_grad",
|
data/ext/torch/extconf.rb
CHANGED
@@ -10,19 +10,24 @@ $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
|
|
10
10
|
# TODO check compiler name
|
11
11
|
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
12
12
|
|
13
|
+
# check omp first
|
13
14
|
if have_library("omp") || have_library("gomp")
|
14
15
|
$CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
|
15
16
|
$CXXFLAGS << " -Xclang" if clang
|
16
17
|
$CXXFLAGS << " -fopenmp"
|
17
18
|
end
|
18
19
|
|
19
|
-
# silence ruby/intern.h warning
|
20
|
-
$CXXFLAGS << " -Wno-deprecated-register"
|
21
|
-
|
22
|
-
# silence torch warnings
|
23
20
|
if clang
|
21
|
+
# silence ruby/intern.h warning
|
22
|
+
$CXXFLAGS << " -Wno-deprecated-register"
|
23
|
+
|
24
|
+
# silence torch warnings
|
24
25
|
$CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
25
26
|
else
|
27
|
+
# silence rice warnings
|
28
|
+
$CXXFLAGS << " -Wno-noexcept-type"
|
29
|
+
|
30
|
+
# silence torch warnings
|
26
31
|
$CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
27
32
|
end
|
28
33
|
|
@@ -34,15 +39,20 @@ cuda_inc, cuda_lib = dir_config("cuda")
|
|
34
39
|
cuda_inc ||= "/usr/local/cuda/include"
|
35
40
|
cuda_lib ||= "/usr/local/cuda/lib64"
|
36
41
|
|
37
|
-
|
42
|
+
$LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
|
43
|
+
abort "LibTorch not found" unless have_library("torch")
|
44
|
+
|
45
|
+
with_cuda = false
|
46
|
+
if Dir["#{lib}/*torch_cuda*"].any?
|
47
|
+
$LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
48
|
+
with_cuda = have_library("cuda") && have_library("cudnn")
|
49
|
+
end
|
38
50
|
|
39
51
|
$INCFLAGS << " -I#{inc}"
|
40
52
|
$INCFLAGS << " -I#{inc}/torch/csrc/api/include"
|
41
53
|
|
42
54
|
$LDFLAGS << " -Wl,-rpath,#{lib}"
|
43
55
|
$LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
44
|
-
$LDFLAGS << " -L#{lib}"
|
45
|
-
$LDFLAGS << " -L#{cuda_lib}" if with_cuda
|
46
56
|
|
47
57
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
48
58
|
$LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
|
data/lib/torch/nn/module.rb
CHANGED
@@ -67,8 +67,9 @@ module Torch
|
|
67
67
|
self
|
68
68
|
end
|
69
69
|
|
70
|
-
|
71
|
-
|
70
|
+
# TODO add device
|
71
|
+
def cuda
|
72
|
+
_apply ->(t) { t.cuda }
|
72
73
|
end
|
73
74
|
|
74
75
|
def cpu
|
@@ -112,8 +113,28 @@ module Torch
|
|
112
113
|
destination
|
113
114
|
end
|
114
115
|
|
116
|
+
# TODO add strict option
|
117
|
+
# TODO match PyTorch behavior
|
115
118
|
def load_state_dict(state_dict)
|
116
|
-
|
119
|
+
state_dict.each do |k, input_param|
|
120
|
+
k1, k2 = k.split(".", 2)
|
121
|
+
mod = named_modules[k1]
|
122
|
+
if mod.is_a?(Module)
|
123
|
+
param = mod.named_parameters[k2]
|
124
|
+
if param.is_a?(Parameter)
|
125
|
+
Torch.no_grad do
|
126
|
+
param.copy!(input_param)
|
127
|
+
end
|
128
|
+
else
|
129
|
+
raise Error, "Unknown parameter: #{k1}"
|
130
|
+
end
|
131
|
+
else
|
132
|
+
raise Error, "Unknown module: #{k1}"
|
133
|
+
end
|
134
|
+
end
|
135
|
+
|
136
|
+
# TODO return missing keys and unexpected keys
|
137
|
+
nil
|
117
138
|
end
|
118
139
|
|
119
140
|
def parameters
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -317,12 +317,11 @@ module Torch
|
|
317
317
|
end
|
318
318
|
|
319
319
|
def save(obj, f)
|
320
|
-
|
321
|
-
File.binwrite(f, _save(obj))
|
320
|
+
File.binwrite(f, _save(to_ivalue(obj)))
|
322
321
|
end
|
323
322
|
|
324
323
|
def load(f)
|
325
|
-
|
324
|
+
to_ruby(_load(File.binread(f)))
|
326
325
|
end
|
327
326
|
|
328
327
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
@@ -447,6 +446,94 @@ module Torch
|
|
447
446
|
|
448
447
|
private
|
449
448
|
|
449
|
+
def to_ivalue(obj)
|
450
|
+
case obj
|
451
|
+
when String
|
452
|
+
IValue.from_string(obj)
|
453
|
+
when Integer
|
454
|
+
IValue.from_int(obj)
|
455
|
+
when Tensor
|
456
|
+
IValue.from_tensor(obj)
|
457
|
+
when Float
|
458
|
+
IValue.from_double(obj)
|
459
|
+
when Hash
|
460
|
+
dict = {}
|
461
|
+
obj.each do |k, v|
|
462
|
+
dict[to_ivalue(k)] = to_ivalue(v)
|
463
|
+
end
|
464
|
+
IValue.from_dict(dict)
|
465
|
+
when true, false
|
466
|
+
IValue.from_bool(obj)
|
467
|
+
when nil
|
468
|
+
IValue.new
|
469
|
+
else
|
470
|
+
raise Error, "Unknown type: #{obj.class.name}"
|
471
|
+
end
|
472
|
+
end
|
473
|
+
|
474
|
+
def to_ruby(ivalue)
|
475
|
+
if ivalue.bool?
|
476
|
+
ivalue.to_bool
|
477
|
+
elsif ivalue.double?
|
478
|
+
ivalue.to_double
|
479
|
+
elsif ivalue.int?
|
480
|
+
ivalue.to_int
|
481
|
+
elsif ivalue.none?
|
482
|
+
nil
|
483
|
+
elsif ivalue.string?
|
484
|
+
ivalue.to_string_ref
|
485
|
+
elsif ivalue.tensor?
|
486
|
+
ivalue.to_tensor
|
487
|
+
elsif ivalue.generic_dict?
|
488
|
+
dict = {}
|
489
|
+
ivalue.to_generic_dict.each do |k, v|
|
490
|
+
dict[to_ruby(k)] = to_ruby(v)
|
491
|
+
end
|
492
|
+
dict
|
493
|
+
else
|
494
|
+
type =
|
495
|
+
if ivalue.capsule?
|
496
|
+
"Capsule"
|
497
|
+
elsif ivalue.custom_class?
|
498
|
+
"CustomClass"
|
499
|
+
elsif ivalue.tuple?
|
500
|
+
"Tuple"
|
501
|
+
elsif ivalue.future?
|
502
|
+
"Future"
|
503
|
+
elsif ivalue.r_ref?
|
504
|
+
"RRef"
|
505
|
+
elsif ivalue.int_list?
|
506
|
+
"IntList"
|
507
|
+
elsif ivalue.double_list?
|
508
|
+
"DoubleList"
|
509
|
+
elsif ivalue.bool_list?
|
510
|
+
"BoolList"
|
511
|
+
elsif ivalue.tensor_list?
|
512
|
+
"TensorList"
|
513
|
+
elsif ivalue.list?
|
514
|
+
"List"
|
515
|
+
elsif ivalue.object?
|
516
|
+
"Object"
|
517
|
+
elsif ivalue.module?
|
518
|
+
"Module"
|
519
|
+
elsif ivalue.py_object?
|
520
|
+
"PyObject"
|
521
|
+
elsif ivalue.scalar?
|
522
|
+
"Scalar"
|
523
|
+
elsif ivalue.device?
|
524
|
+
"Device"
|
525
|
+
# elsif ivalue.generator?
|
526
|
+
# "Generator"
|
527
|
+
elsif ivalue.ptr_type?
|
528
|
+
"PtrType"
|
529
|
+
else
|
530
|
+
"Unknown"
|
531
|
+
end
|
532
|
+
|
533
|
+
raise Error, "Unsupported type: #{type}"
|
534
|
+
end
|
535
|
+
end
|
536
|
+
|
450
537
|
def tensor_size(size)
|
451
538
|
size.flatten
|
452
539
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-04-
|
11
|
+
date: 2020-04-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -120,17 +120,10 @@ files:
|
|
120
120
|
- README.md
|
121
121
|
- ext/torch/ext.cpp
|
122
122
|
- ext/torch/extconf.rb
|
123
|
-
- ext/torch/nn_functions.cpp
|
124
|
-
- ext/torch/nn_functions.hpp
|
125
123
|
- ext/torch/templates.cpp
|
126
124
|
- ext/torch/templates.hpp
|
127
|
-
- ext/torch/tensor_functions.cpp
|
128
|
-
- ext/torch/tensor_functions.hpp
|
129
|
-
- ext/torch/torch_functions.cpp
|
130
|
-
- ext/torch/torch_functions.hpp
|
131
125
|
- lib/torch-rb.rb
|
132
126
|
- lib/torch.rb
|
133
|
-
- lib/torch/ext.bundle
|
134
127
|
- lib/torch/hub.rb
|
135
128
|
- lib/torch/inspector.rb
|
136
129
|
- lib/torch/native/dispatcher.rb
|