torch-rb 0.2.2 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8e1f9758c937519ca31d92f3acd35ce0372f8cf57362cd3b50bd45920e7a6763
4
- data.tar.gz: 4d370857ee758694b974da0f5d0973a687181ef2fedcad5c583f446ceb67dda2
3
+ metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
4
+ data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
5
5
  SHA512:
6
- metadata.gz: b1cbb37019852bfdbfc45b28ac32924b4de313ce21112ffe8bb5ec91fe17898d3a1ceb42d77540e6b7a0d656e9443002fb72a1201453507dca4915db13879167
7
- data.tar.gz: d1a35689c3ad6a0628485633af4f6d7f613288fbf739f9e94ccfdb72c613b0d4581a21e00aef3a309771967f65086034befc519bc89ecc7884ff1dd142a8289f
6
+ metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
7
+ data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.2.3 (2020-04-28)
2
+
3
+ - Added `show_config` and `parallel_info` methods
4
+ - Added `initial_seed` and `seed` methods to `Random`
5
+ - Improved data loader
6
+ - Build with MKL-DNN and NNPACK when available
7
+ - Fixed `inspect` for modules
8
+
1
9
  ## 0.2.2 (2020-04-27)
2
10
 
3
11
  - Added support for saving tensor lists
data/ext/torch/ext.cpp CHANGED
@@ -40,6 +40,19 @@ void Init_ext()
40
40
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
41
41
  add_nn_functions(rb_mNN);
42
42
 
43
+ Module rb_mRandom = define_module_under(rb_mTorch, "Random")
44
+ .define_singleton_method(
45
+ "initial_seed",
46
+ *[]() {
47
+ return at::detail::getDefaultCPUGenerator()->current_seed();
48
+ })
49
+ .define_singleton_method(
50
+ "seed",
51
+ *[]() {
52
+ // TODO set for CUDA when available
53
+ return at::detail::getDefaultCPUGenerator()->seed();
54
+ });
55
+
43
56
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
44
57
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
45
58
  .define_constructor(Constructor<torch::IValue>())
@@ -177,6 +190,17 @@ void Init_ext()
177
190
  *[](uint64_t seed) {
178
191
  return torch::manual_seed(seed);
179
192
  })
193
+ // config
194
+ .define_singleton_method(
195
+ "show_config",
196
+ *[] {
197
+ return torch::show_config();
198
+ })
199
+ .define_singleton_method(
200
+ "parallel_info",
201
+ *[] {
202
+ return torch::get_parallel_info();
203
+ })
180
204
  // begin tensor creation
181
205
  .define_singleton_method(
182
206
  "_arange",
data/ext/torch/extconf.rb CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
2
2
 
3
3
  abort "Missing stdc++" unless have_library("stdc++")
4
4
 
5
- $CXXFLAGS << " -std=c++14"
5
+ $CXXFLAGS += " -std=c++14"
6
6
 
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
- $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
8
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
13
  # check omp first
14
14
  if have_library("omp") || have_library("gomp")
15
- $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS << " -Xclang" if clang
17
- $CXXFLAGS << " -fopenmp"
15
+ $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
+ $CXXFLAGS += " -Xclang" if clang
17
+ $CXXFLAGS += " -fopenmp"
18
18
  end
19
19
 
20
20
  if clang
21
21
  # silence ruby/intern.h warning
22
- $CXXFLAGS << " -Wno-deprecated-register"
22
+ $CXXFLAGS += " -Wno-deprecated-register"
23
23
 
24
24
  # silence torch warnings
25
- $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
+ $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
26
26
  else
27
27
  # silence rice warnings
28
- $CXXFLAGS << " -Wno-noexcept-type"
28
+ $CXXFLAGS += " -Wno-noexcept-type"
29
29
 
30
30
  # silence torch warnings
31
- $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
31
+ $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
32
32
  end
33
33
 
34
34
  inc, lib = dir_config("torch")
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
39
39
  cuda_inc ||= "/usr/local/cuda/include"
40
40
  cuda_lib ||= "/usr/local/cuda/lib64"
41
41
 
42
- $LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
42
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
43
43
  abort "LibTorch not found" unless have_library("torch")
44
44
 
45
+ have_library("mkldnn")
46
+ have_library("nnpack")
47
+
45
48
  with_cuda = false
46
49
  if Dir["#{lib}/*torch_cuda*"].any?
47
- $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
50
+ $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
48
51
  with_cuda = have_library("cuda") && have_library("cudnn")
49
52
  end
50
53
 
51
- $INCFLAGS << " -I#{inc}"
52
- $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
54
+ $INCFLAGS += " -I#{inc}"
55
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
53
56
 
54
- $LDFLAGS << " -Wl,-rpath,#{lib}"
55
- $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
57
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
58
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
56
59
 
57
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
58
- $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
61
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
59
62
  if with_cuda
60
- $LDFLAGS << " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
63
+ $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
61
64
  # TODO figure out why this is needed
62
- $LDFLAGS << " -Wl,--no-as-needed,#{lib}/libtorch.so"
65
+ $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
63
66
  end
64
67
 
65
68
  # generate C++ functions
@@ -224,12 +224,12 @@ module Torch
224
224
 
225
225
  def inspect
226
226
  name = self.class.name.split("::").last
227
- if children.empty?
227
+ if named_children.empty?
228
228
  "#{name}(#{extra_inspect})"
229
229
  else
230
230
  str = String.new
231
231
  str << "#{name}(\n"
232
- children.each do |name, mod|
232
+ named_children.each do |name, mod|
233
233
  str << " (#{name}): #{mod.inspect}\n"
234
234
  end
235
235
  str << ")"
data/lib/torch/tensor.rb CHANGED
@@ -193,8 +193,16 @@ module Torch
193
193
  end
194
194
  end
195
195
 
196
- def random!(from = 0, to)
197
- _random__from_to(from, to)
196
+ # native functions overlap, so need to handle manually
197
+ def random!(*args)
198
+ case args.size
199
+ when 1
200
+ _random__to(*args)
201
+ when 2
202
+ _random__from_to(*args)
203
+ else
204
+ _random_(*args)
205
+ end
198
206
  end
199
207
 
200
208
  private
@@ -12,15 +12,38 @@ module Torch
12
12
  end
13
13
 
14
14
  def each
15
+ # try to keep the random number generator in sync with Python
16
+ # this makes it easy to compare results
17
+ base_seed = Torch.empty([], dtype: :int64).random!.item
18
+
19
+ max_size = @dataset.size
15
20
  size.times do |i|
16
21
  start_index = i * @batch_size
17
- yield @dataset[start_index...(start_index + @batch_size)]
22
+ end_index = [start_index + @batch_size, max_size].min
23
+ batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
24
+ yield collate(batch)
18
25
  end
19
26
  end
20
27
 
21
28
  def size
22
29
  (@dataset.size / @batch_size.to_f).ceil
23
30
  end
31
+
32
+ private
33
+
34
+ def collate(batch)
35
+ elem = batch[0]
36
+ case elem
37
+ when Tensor
38
+ Torch.stack(batch, 0)
39
+ when Integer
40
+ Torch.tensor(batch)
41
+ when Array
42
+ batch.transpose.map { |v| collate(v) }
43
+ else
44
+ raise NotImpelmentYet
45
+ end
46
+ end
24
47
  end
25
48
  end
26
49
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.2"
2
+ VERSION = "0.2.3"
3
3
  end
data/lib/torch.rb CHANGED
@@ -176,9 +176,6 @@ require "torch/nn/init"
176
176
  require "torch/utils/data/data_loader"
177
177
  require "torch/utils/data/tensor_dataset"
178
178
 
179
- # random
180
- require "torch/random"
181
-
182
179
  # hub
183
180
  require "torch/hub"
184
181
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.2
4
+ version: 0.2.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-27 00:00:00.000000000 Z
11
+ date: 2020-04-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -258,7 +258,6 @@ files:
258
258
  - lib/torch/optim/rmsprop.rb
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
- - lib/torch/random.rb
262
261
  - lib/torch/tensor.rb
263
262
  - lib/torch/utils/data/data_loader.rb
264
263
  - lib/torch/utils/data/tensor_dataset.rb
data/lib/torch/random.rb DELETED
@@ -1,10 +0,0 @@
1
- module Torch
2
- module Random
3
- class << self
4
- # not available through LibTorch
5
- def initial_seed
6
- raise NotImplementedYet
7
- end
8
- end
9
- end
10
- end