torch-rb 0.2.2 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8e1f9758c937519ca31d92f3acd35ce0372f8cf57362cd3b50bd45920e7a6763
4
- data.tar.gz: 4d370857ee758694b974da0f5d0973a687181ef2fedcad5c583f446ceb67dda2
3
+ metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
4
+ data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
5
5
  SHA512:
6
- metadata.gz: b1cbb37019852bfdbfc45b28ac32924b4de313ce21112ffe8bb5ec91fe17898d3a1ceb42d77540e6b7a0d656e9443002fb72a1201453507dca4915db13879167
7
- data.tar.gz: d1a35689c3ad6a0628485633af4f6d7f613288fbf739f9e94ccfdb72c613b0d4581a21e00aef3a309771967f65086034befc519bc89ecc7884ff1dd142a8289f
6
+ metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
7
+ data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.2.3 (2020-04-28)
2
+
3
+ - Added `show_config` and `parallel_info` methods
4
+ - Added `initial_seed` and `seed` methods to `Random`
5
+ - Improved data loader
6
+ - Build with MKL-DNN and NNPACK when available
7
+ - Fixed `inspect` for modules
8
+
1
9
  ## 0.2.2 (2020-04-27)
2
10
 
3
11
  - Added support for saving tensor lists
data/ext/torch/ext.cpp CHANGED
@@ -40,6 +40,19 @@ void Init_ext()
40
40
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
41
41
  add_nn_functions(rb_mNN);
42
42
 
43
+ Module rb_mRandom = define_module_under(rb_mTorch, "Random")
44
+ .define_singleton_method(
45
+ "initial_seed",
46
+ *[]() {
47
+ return at::detail::getDefaultCPUGenerator()->current_seed();
48
+ })
49
+ .define_singleton_method(
50
+ "seed",
51
+ *[]() {
52
+ // TODO set for CUDA when available
53
+ return at::detail::getDefaultCPUGenerator()->seed();
54
+ });
55
+
43
56
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
44
57
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
45
58
  .define_constructor(Constructor<torch::IValue>())
@@ -177,6 +190,17 @@ void Init_ext()
177
190
  *[](uint64_t seed) {
178
191
  return torch::manual_seed(seed);
179
192
  })
193
+ // config
194
+ .define_singleton_method(
195
+ "show_config",
196
+ *[] {
197
+ return torch::show_config();
198
+ })
199
+ .define_singleton_method(
200
+ "parallel_info",
201
+ *[] {
202
+ return torch::get_parallel_info();
203
+ })
180
204
  // begin tensor creation
181
205
  .define_singleton_method(
182
206
  "_arange",
data/ext/torch/extconf.rb CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
2
2
 
3
3
  abort "Missing stdc++" unless have_library("stdc++")
4
4
 
5
- $CXXFLAGS << " -std=c++14"
5
+ $CXXFLAGS += " -std=c++14"
6
6
 
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
- $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
8
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
13
  # check omp first
14
14
  if have_library("omp") || have_library("gomp")
15
- $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS << " -Xclang" if clang
17
- $CXXFLAGS << " -fopenmp"
15
+ $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
+ $CXXFLAGS += " -Xclang" if clang
17
+ $CXXFLAGS += " -fopenmp"
18
18
  end
19
19
 
20
20
  if clang
21
21
  # silence ruby/intern.h warning
22
- $CXXFLAGS << " -Wno-deprecated-register"
22
+ $CXXFLAGS += " -Wno-deprecated-register"
23
23
 
24
24
  # silence torch warnings
25
- $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
+ $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
26
26
  else
27
27
  # silence rice warnings
28
- $CXXFLAGS << " -Wno-noexcept-type"
28
+ $CXXFLAGS += " -Wno-noexcept-type"
29
29
 
30
30
  # silence torch warnings
31
- $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
31
+ $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
32
32
  end
33
33
 
34
34
  inc, lib = dir_config("torch")
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
39
39
  cuda_inc ||= "/usr/local/cuda/include"
40
40
  cuda_lib ||= "/usr/local/cuda/lib64"
41
41
 
42
- $LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
42
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
43
43
  abort "LibTorch not found" unless have_library("torch")
44
44
 
45
+ have_library("mkldnn")
46
+ have_library("nnpack")
47
+
45
48
  with_cuda = false
46
49
  if Dir["#{lib}/*torch_cuda*"].any?
47
- $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
50
+ $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
48
51
  with_cuda = have_library("cuda") && have_library("cudnn")
49
52
  end
50
53
 
51
- $INCFLAGS << " -I#{inc}"
52
- $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
54
+ $INCFLAGS += " -I#{inc}"
55
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
53
56
 
54
- $LDFLAGS << " -Wl,-rpath,#{lib}"
55
- $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
57
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
58
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
56
59
 
57
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
58
- $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
61
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
59
62
  if with_cuda
60
- $LDFLAGS << " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
63
+ $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
61
64
  # TODO figure out why this is needed
62
- $LDFLAGS << " -Wl,--no-as-needed,#{lib}/libtorch.so"
65
+ $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
63
66
  end
64
67
 
65
68
  # generate C++ functions
@@ -224,12 +224,12 @@ module Torch
224
224
 
225
225
  def inspect
226
226
  name = self.class.name.split("::").last
227
- if children.empty?
227
+ if named_children.empty?
228
228
  "#{name}(#{extra_inspect})"
229
229
  else
230
230
  str = String.new
231
231
  str << "#{name}(\n"
232
- children.each do |name, mod|
232
+ named_children.each do |name, mod|
233
233
  str << " (#{name}): #{mod.inspect}\n"
234
234
  end
235
235
  str << ")"
data/lib/torch/tensor.rb CHANGED
@@ -193,8 +193,16 @@ module Torch
193
193
  end
194
194
  end
195
195
 
196
- def random!(from = 0, to)
197
- _random__from_to(from, to)
196
+ # native functions overlap, so need to handle manually
197
+ def random!(*args)
198
+ case args.size
199
+ when 1
200
+ _random__to(*args)
201
+ when 2
202
+ _random__from_to(*args)
203
+ else
204
+ _random_(*args)
205
+ end
198
206
  end
199
207
 
200
208
  private
@@ -12,15 +12,38 @@ module Torch
12
12
  end
13
13
 
14
14
  def each
15
+ # try to keep the random number generator in sync with Python
16
+ # this makes it easy to compare results
17
+ base_seed = Torch.empty([], dtype: :int64).random!.item
18
+
19
+ max_size = @dataset.size
15
20
  size.times do |i|
16
21
  start_index = i * @batch_size
17
- yield @dataset[start_index...(start_index + @batch_size)]
22
+ end_index = [start_index + @batch_size, max_size].min
23
+ batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
24
+ yield collate(batch)
18
25
  end
19
26
  end
20
27
 
21
28
  def size
22
29
  (@dataset.size / @batch_size.to_f).ceil
23
30
  end
31
+
32
+ private
33
+
34
+ def collate(batch)
35
+ elem = batch[0]
36
+ case elem
37
+ when Tensor
38
+ Torch.stack(batch, 0)
39
+ when Integer
40
+ Torch.tensor(batch)
41
+ when Array
42
+ batch.transpose.map { |v| collate(v) }
43
+ else
44
+ raise NotImpelmentYet
45
+ end
46
+ end
24
47
  end
25
48
  end
26
49
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.2"
2
+ VERSION = "0.2.3"
3
3
  end
data/lib/torch.rb CHANGED
@@ -176,9 +176,6 @@ require "torch/nn/init"
176
176
  require "torch/utils/data/data_loader"
177
177
  require "torch/utils/data/tensor_dataset"
178
178
 
179
- # random
180
- require "torch/random"
181
-
182
179
  # hub
183
180
  require "torch/hub"
184
181
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.2
4
+ version: 0.2.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-27 00:00:00.000000000 Z
11
+ date: 2020-04-28 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -258,7 +258,6 @@ files:
258
258
  - lib/torch/optim/rmsprop.rb
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
- - lib/torch/random.rb
262
261
  - lib/torch/tensor.rb
263
262
  - lib/torch/utils/data/data_loader.rb
264
263
  - lib/torch/utils/data/tensor_dataset.rb
data/lib/torch/random.rb DELETED
@@ -1,10 +0,0 @@
1
- module Torch
2
- module Random
3
- class << self
4
- # not available through LibTorch
5
- def initial_seed
6
- raise NotImplementedYet
7
- end
8
- end
9
- end
10
- end