torch-rb 0.2.2 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/torch/ext.cpp +24 -0
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch/nn/module.rb +2 -2
- data/lib/torch/tensor.rb +10 -2
- data/lib/torch/utils/data/data_loader.rb +24 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -3
- metadata +2 -3
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
+  data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
+  data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
+## 0.2.3 (2020-04-28)
+
+- Added `show_config` and `parallel_info` methods
+- Added `initial_seed` and `seed` methods to `Random`
+- Improved data loader
+- Build with MKL-DNN and NNPACK when available
+- Fixed `inspect` for modules
+
 ## 0.2.2 (2020-04-27)
 
 - Added support for saving tensor lists
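For context, a minimal usage sketch of the two new introspection methods (their placement as singleton methods on `Torch` follows the changelog entry and the `ext.cpp` bindings below; the output text depends on how LibTorch was built):

```ruby
require "torch"

# Both return multi-line strings, mirroring PyTorch's torch.__config__ output.
puts Torch.show_config     # compiler, CUDA/MKL/MKL-DNN build options, etc.
puts Torch.parallel_info   # threading backend and thread counts
```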
data/ext/torch/ext.cpp
CHANGED
@@ -40,6 +40,19 @@ void Init_ext()
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
   add_nn_functions(rb_mNN);
 
+  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+    .define_singleton_method(
+      "initial_seed",
+      *[]() {
+        return at::detail::getDefaultCPUGenerator()->current_seed();
+      })
+    .define_singleton_method(
+      "seed",
+      *[]() {
+        // TODO set for CUDA when available
+        return at::detail::getDefaultCPUGenerator()->seed();
+      });
+
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
     .define_constructor(Constructor<torch::IValue>())
@@ -177,6 +190,17 @@ void Init_ext()
       *[](uint64_t seed) {
        return torch::manual_seed(seed);
       })
+    // config
+    .define_singleton_method(
+      "show_config",
+      *[] {
+        return torch::show_config();
+      })
+    .define_singleton_method(
+      "parallel_info",
+      *[] {
+        return torch::get_parallel_info();
+      })
     // begin tensor creation
     .define_singleton_method(
       "_arange",
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
|
|
2
2
|
|
3
3
|
abort "Missing stdc++" unless have_library("stdc++")
|
4
4
|
|
5
|
-
$CXXFLAGS
|
5
|
+
$CXXFLAGS += " -std=c++14"
|
6
6
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
|
-
$CXXFLAGS
|
8
|
+
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
10
|
# TODO check compiler name
|
11
11
|
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
12
12
|
|
13
13
|
# check omp first
|
14
14
|
if have_library("omp") || have_library("gomp")
|
15
|
-
$CXXFLAGS
|
16
|
-
$CXXFLAGS
|
17
|
-
$CXXFLAGS
|
15
|
+
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
+
$CXXFLAGS += " -Xclang" if clang
|
17
|
+
$CXXFLAGS += " -fopenmp"
|
18
18
|
end
|
19
19
|
|
20
20
|
if clang
|
21
21
|
# silence ruby/intern.h warning
|
22
|
-
$CXXFLAGS
|
22
|
+
$CXXFLAGS += " -Wno-deprecated-register"
|
23
23
|
|
24
24
|
# silence torch warnings
|
25
|
-
$CXXFLAGS
|
25
|
+
$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
26
26
|
else
|
27
27
|
# silence rice warnings
|
28
|
-
$CXXFLAGS
|
28
|
+
$CXXFLAGS += " -Wno-noexcept-type"
|
29
29
|
|
30
30
|
# silence torch warnings
|
31
|
-
$CXXFLAGS
|
31
|
+
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
32
32
|
end
|
33
33
|
|
34
34
|
inc, lib = dir_config("torch")
|
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
|
|
39
39
|
cuda_inc ||= "/usr/local/cuda/include"
|
40
40
|
cuda_lib ||= "/usr/local/cuda/lib64"
|
41
41
|
|
42
|
-
$LDFLAGS
|
42
|
+
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
|
43
43
|
abort "LibTorch not found" unless have_library("torch")
|
44
44
|
|
45
|
+
have_library("mkldnn")
|
46
|
+
have_library("nnpack")
|
47
|
+
|
45
48
|
with_cuda = false
|
46
49
|
if Dir["#{lib}/*torch_cuda*"].any?
|
47
|
-
$LDFLAGS
|
50
|
+
$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
48
51
|
with_cuda = have_library("cuda") && have_library("cudnn")
|
49
52
|
end
|
50
53
|
|
51
|
-
$INCFLAGS
|
52
|
-
$INCFLAGS
|
54
|
+
$INCFLAGS += " -I#{inc}"
|
55
|
+
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
56
|
|
54
|
-
$LDFLAGS
|
55
|
-
$LDFLAGS
|
57
|
+
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
58
|
+
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
58
|
-
$LDFLAGS
|
61
|
+
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
|
59
62
|
if with_cuda
|
60
|
-
$LDFLAGS
|
63
|
+
$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
|
61
64
|
# TODO figure out why this is needed
|
62
|
-
$LDFLAGS
|
65
|
+
$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
|
63
66
|
end
|
64
67
|
|
65
68
|
# generate C++ functions
|
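Two practical notes on the build changes above. `dir_config("torch")` is stock mkmf, so the LibTorch location is supplied the usual way at install time (for example `gem install torch-rb -- --with-torch-dir=/path/to/libtorch`); that is standard mkmf behavior, not something added in 0.2.3. And since `have_library("mkldnn")` / `have_library("nnpack")` only link against libraries that are actually present, a rough post-install check is to inspect the new config string; the substring names below are assumptions based on LibTorch's usual `show_config` text:

```ruby
require "torch"

config = Torch.show_config

# Heuristic checks; the exact wording of the config text can vary by LibTorch build.
puts "MKL-DNN: #{config.include?("MKLDNN")}"
puts "NNPACK:  #{config.include?("NNPACK")}"
puts Torch.parallel_info  # e.g. whether OpenMP (AT_PARALLEL_OPENMP) is the threading backend
```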
data/lib/torch/nn/module.rb
CHANGED
@@ -224,12 +224,12 @@ module Torch
 
      def inspect
        name = self.class.name.split("::").last
-        if
+        if named_children.empty?
          "#{name}(#{extra_inspect})"
        else
          str = String.new
          str << "#{name}(\n"
-
+          named_children.each do |name, mod|
            str << "  (#{name}): #{mod.inspect}\n"
          end
          str << ")"
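The effect of the `inspect` fix is easiest to see on a module with named children. A hypothetical model in the style of the gem's README (the per-layer detail comes from each layer's own `extra_inspect`, so exact output may differ):

```ruby
require "torch"

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 8)
    @fc2 = Torch::NN::Linear.new(8, 1)
  end

  def forward(x)
    @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
  end
end

puts TinyNet.new.inspect
# Roughly:
# TinyNet(
#   (fc1): Linear(...)
#   (fc2): Linear(...)
# )
```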
data/lib/torch/tensor.rb
CHANGED
@@ -193,8 +193,16 @@ module Torch
        end
      end
 
-
-
+    # native functions overlap, so need to handle manually
+    def random!(*args)
+      case args.size
+      when 1
+        _random__to(*args)
+      when 2
+        _random__from_to(*args)
+      else
+        _random_(*args)
+      end
     end
 
     private
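From the Ruby side, the dispatch above selects the underlying `random_` overload by arity; a quick sketch (semantics follow the corresponding PyTorch overloads):

```ruby
require "torch"

t = Torch.empty(3)

t.random!(10)     # one arg  -> _random__to:      integers in [0, 10)
t.random!(5, 10)  # two args -> _random__from_to: integers in [5, 10)
t.random!         # no args  -> _random_:         bounded only by the tensor's dtype
```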
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -12,15 +12,38 @@
        end
 
        def each
+          # try to keep the random number generator in sync with Python
+          # this makes it easy to compare results
+          base_seed = Torch.empty([], dtype: :int64).random!.item
+
+          max_size = @dataset.size
           size.times do |i|
             start_index = i * @batch_size
-
+            end_index = [start_index + @batch_size, max_size].min
+            batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+            yield collate(batch)
           end
         end
 
         def size
           (@dataset.size / @batch_size.to_f).ceil
         end
+
+        private
+
+        def collate(batch)
+          elem = batch[0]
+          case elem
+          when Tensor
+            Torch.stack(batch, 0)
+          when Integer
+            Torch.tensor(batch)
+          when Array
+            batch.transpose.map { |v| collate(v) }
+          else
+            raise NotImplementedYet
+          end
+        end
       end
     end
   end
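A minimal sketch of the improved loader in use (class names follow the gem's `lib/torch/utils/data` files; the batch sizes show the new `end_index` handling, where the final batch is smaller):

```ruby
require "torch"

x = Torch.rand(10, 4)
y = Torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader  = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4)

loader.each do |xb, yb|
  # Each dataset element is an [x_i, y_i] pair, so collate transposes the batch
  # and stacks each column: xb is 4x4 (2x4 for the last batch), yb holds the labels.
  p [xb.shape, yb.shape]
end
```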
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.2
+  version: 0.2.3
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-04-27 00:00:00.000000000 Z
+date: 2020-04-28 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -258,7 +258,6 @@ files:
 - lib/torch/optim/rmsprop.rb
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
-- lib/torch/random.rb
 - lib/torch/tensor.rb
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/tensor_dataset.rb