torch-rb 0.2.2 → 0.2.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/torch/ext.cpp +24 -0
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch/nn/module.rb +2 -2
- data/lib/torch/tensor.rb +10 -2
- data/lib/torch/utils/data/data_loader.rb +24 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -3
- metadata +2 -3
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
|
4
|
+
data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
|
7
|
+
data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
|
|
1
|
+
## 0.2.3 (2020-04-28)
|
2
|
+
|
3
|
+
- Added `show_config` and `parallel_info` methods
|
4
|
+
- Added `initial_seed` and `seed` methods to `Random`
|
5
|
+
- Improved data loader
|
6
|
+
- Build with MKL-DNN and NNPACK when available
|
7
|
+
- Fixed `inspect` for modules
|
8
|
+
|
1
9
|
## 0.2.2 (2020-04-27)
|
2
10
|
|
3
11
|
- Added support for saving tensor lists
|
data/ext/torch/ext.cpp
CHANGED
@@ -40,6 +40,19 @@ void Init_ext()
|
|
40
40
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
41
41
|
add_nn_functions(rb_mNN);
|
42
42
|
|
43
|
+
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
44
|
+
.define_singleton_method(
|
45
|
+
"initial_seed",
|
46
|
+
*[]() {
|
47
|
+
return at::detail::getDefaultCPUGenerator()->current_seed();
|
48
|
+
})
|
49
|
+
.define_singleton_method(
|
50
|
+
"seed",
|
51
|
+
*[]() {
|
52
|
+
// TODO set for CUDA when available
|
53
|
+
return at::detail::getDefaultCPUGenerator()->seed();
|
54
|
+
});
|
55
|
+
|
43
56
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
44
57
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
45
58
|
.define_constructor(Constructor<torch::IValue>())
|
@@ -177,6 +190,17 @@ void Init_ext()
|
|
177
190
|
*[](uint64_t seed) {
|
178
191
|
return torch::manual_seed(seed);
|
179
192
|
})
|
193
|
+
// config
|
194
|
+
.define_singleton_method(
|
195
|
+
"show_config",
|
196
|
+
*[] {
|
197
|
+
return torch::show_config();
|
198
|
+
})
|
199
|
+
.define_singleton_method(
|
200
|
+
"parallel_info",
|
201
|
+
*[] {
|
202
|
+
return torch::get_parallel_info();
|
203
|
+
})
|
180
204
|
// begin tensor creation
|
181
205
|
.define_singleton_method(
|
182
206
|
"_arange",
|
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
|
|
2
2
|
|
3
3
|
abort "Missing stdc++" unless have_library("stdc++")
|
4
4
|
|
5
|
-
$CXXFLAGS
|
5
|
+
$CXXFLAGS += " -std=c++14"
|
6
6
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
|
-
$CXXFLAGS
|
8
|
+
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
10
|
# TODO check compiler name
|
11
11
|
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
12
12
|
|
13
13
|
# check omp first
|
14
14
|
if have_library("omp") || have_library("gomp")
|
15
|
-
$CXXFLAGS
|
16
|
-
$CXXFLAGS
|
17
|
-
$CXXFLAGS
|
15
|
+
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
+
$CXXFLAGS += " -Xclang" if clang
|
17
|
+
$CXXFLAGS += " -fopenmp"
|
18
18
|
end
|
19
19
|
|
20
20
|
if clang
|
21
21
|
# silence ruby/intern.h warning
|
22
|
-
$CXXFLAGS
|
22
|
+
$CXXFLAGS += " -Wno-deprecated-register"
|
23
23
|
|
24
24
|
# silence torch warnings
|
25
|
-
$CXXFLAGS
|
25
|
+
$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
26
26
|
else
|
27
27
|
# silence rice warnings
|
28
|
-
$CXXFLAGS
|
28
|
+
$CXXFLAGS += " -Wno-noexcept-type"
|
29
29
|
|
30
30
|
# silence torch warnings
|
31
|
-
$CXXFLAGS
|
31
|
+
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
32
32
|
end
|
33
33
|
|
34
34
|
inc, lib = dir_config("torch")
|
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
|
|
39
39
|
cuda_inc ||= "/usr/local/cuda/include"
|
40
40
|
cuda_lib ||= "/usr/local/cuda/lib64"
|
41
41
|
|
42
|
-
$LDFLAGS
|
42
|
+
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
|
43
43
|
abort "LibTorch not found" unless have_library("torch")
|
44
44
|
|
45
|
+
have_library("mkldnn")
|
46
|
+
have_library("nnpack")
|
47
|
+
|
45
48
|
with_cuda = false
|
46
49
|
if Dir["#{lib}/*torch_cuda*"].any?
|
47
|
-
$LDFLAGS
|
50
|
+
$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
48
51
|
with_cuda = have_library("cuda") && have_library("cudnn")
|
49
52
|
end
|
50
53
|
|
51
|
-
$INCFLAGS
|
52
|
-
$INCFLAGS
|
54
|
+
$INCFLAGS += " -I#{inc}"
|
55
|
+
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
56
|
|
54
|
-
$LDFLAGS
|
55
|
-
$LDFLAGS
|
57
|
+
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
58
|
+
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
58
|
-
$LDFLAGS
|
61
|
+
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
|
59
62
|
if with_cuda
|
60
|
-
$LDFLAGS
|
63
|
+
$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
|
61
64
|
# TODO figure out why this is needed
|
62
|
-
$LDFLAGS
|
65
|
+
$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
|
63
66
|
end
|
64
67
|
|
65
68
|
# generate C++ functions
|
data/lib/torch/nn/module.rb
CHANGED
@@ -224,12 +224,12 @@ module Torch
|
|
224
224
|
|
225
225
|
def inspect
|
226
226
|
name = self.class.name.split("::").last
|
227
|
-
if
|
227
|
+
if named_children.empty?
|
228
228
|
"#{name}(#{extra_inspect})"
|
229
229
|
else
|
230
230
|
str = String.new
|
231
231
|
str << "#{name}(\n"
|
232
|
-
|
232
|
+
named_children.each do |name, mod|
|
233
233
|
str << " (#{name}): #{mod.inspect}\n"
|
234
234
|
end
|
235
235
|
str << ")"
|
data/lib/torch/tensor.rb
CHANGED
@@ -193,8 +193,16 @@ module Torch
|
|
193
193
|
end
|
194
194
|
end
|
195
195
|
|
196
|
-
|
197
|
-
|
196
|
+
# native functions overlap, so need to handle manually
|
197
|
+
def random!(*args)
|
198
|
+
case args.size
|
199
|
+
when 1
|
200
|
+
_random__to(*args)
|
201
|
+
when 2
|
202
|
+
_random__from_to(*args)
|
203
|
+
else
|
204
|
+
_random_(*args)
|
205
|
+
end
|
198
206
|
end
|
199
207
|
|
200
208
|
private
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -12,15 +12,38 @@ module Torch
|
|
12
12
|
end
|
13
13
|
|
14
14
|
def each
|
15
|
+
# try to keep the random number generator in sync with Python
|
16
|
+
# this makes it easy to compare results
|
17
|
+
base_seed = Torch.empty([], dtype: :int64).random!.item
|
18
|
+
|
19
|
+
max_size = @dataset.size
|
15
20
|
size.times do |i|
|
16
21
|
start_index = i * @batch_size
|
17
|
-
|
22
|
+
end_index = [start_index + @batch_size, max_size].min
|
23
|
+
batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
|
24
|
+
yield collate(batch)
|
18
25
|
end
|
19
26
|
end
|
20
27
|
|
21
28
|
def size
|
22
29
|
(@dataset.size / @batch_size.to_f).ceil
|
23
30
|
end
|
31
|
+
|
32
|
+
private
|
33
|
+
|
34
|
+
def collate(batch)
|
35
|
+
elem = batch[0]
|
36
|
+
case elem
|
37
|
+
when Tensor
|
38
|
+
Torch.stack(batch, 0)
|
39
|
+
when Integer
|
40
|
+
Torch.tensor(batch)
|
41
|
+
when Array
|
42
|
+
batch.transpose.map { |v| collate(v) }
|
43
|
+
else
|
44
|
+
raise NotImpelmentYet
|
45
|
+
end
|
46
|
+
end
|
24
47
|
end
|
25
48
|
end
|
26
49
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.2
|
4
|
+
version: 0.2.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-04-27 00:00:00.000000000 Z
|
11
|
+
date: 2020-04-28 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -258,7 +258,6 @@ files:
|
|
258
258
|
- lib/torch/optim/rmsprop.rb
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
|
-
- lib/torch/random.rb
|
262
261
|
- lib/torch/tensor.rb
|
263
262
|
- lib/torch/utils/data/data_loader.rb
|
264
263
|
- lib/torch/utils/data/tensor_dataset.rb
|