torch-rb 0.2.1 → 0.2.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
4
- data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
3
+ metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
4
+ data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
5
5
  SHA512:
6
- metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
7
- data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
6
+ metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
7
+ data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
@@ -1,3 +1,34 @@
1
+ ## 0.2.6 (2020-06-29)
2
+
3
+ - Added support for indexing with tensors
4
+ - Added `contiguous` methods
5
+ - Fixed named parameters for nested parameters
6
+
7
+ ## 0.2.5 (2020-06-07)
8
+
9
+ - Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
10
+ - Improved error messages
11
+ - Fixed tensor slicing
12
+
13
+ ## 0.2.4 (2020-04-29)
14
+
15
+ - Added `to_i` and `to_f` to tensors
16
+ - Added `shuffle` option to data loader
17
+ - Fixed `modules` and `named_modules` for nested modules
18
+
19
+ ## 0.2.3 (2020-04-28)
20
+
21
+ - Added `show_config` and `parallel_info` methods
22
+ - Added `initial_seed` and `seed` methods to `Random`
23
+ - Improved data loader
24
+ - Build with MKL-DNN and NNPACK when available
25
+ - Fixed `inspect` for modules
26
+
27
+ ## 0.2.2 (2020-04-27)
28
+
29
+ - Added support for saving tensor lists
30
+ - Added `ndim` and `ndimension` methods to tensors
31
+
1
32
  ## 0.2.1 (2020-04-26)
2
33
 
3
34
  - Added support for saving and loading models
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
+ For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
6
+
5
7
  [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
6
8
 
7
9
  ## Installation
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
22
24
 
23
25
  ## Getting Started
24
26
 
27
+ Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
28
+
29
+ We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
30
+
31
+ ```text
32
+ ankane/ml-stack:torch-gpu
33
+ ```
34
+
35
+ Leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
36
+
37
+ ## API
38
+
25
39
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
26
40
 
27
41
  - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
@@ -192,7 +206,7 @@ end
192
206
  Define a neural network
193
207
 
194
208
  ```ruby
195
- class Net < Torch::NN::Module
209
+ class MyNet < Torch::NN::Module
196
210
  def initialize
197
211
  super
198
212
  @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +240,7 @@ end
226
240
  Create an instance of it
227
241
 
228
242
  ```ruby
229
- net = Net.new
243
+ net = MyNet.new
230
244
  input = Torch.randn(1, 1, 32, 32)
231
245
  net.call(input)
232
246
  ```
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
294
308
  Load a model
295
309
 
296
310
  ```ruby
297
- net = Net.new
311
+ net = MyNet.new
298
312
  net.load_state_dict(Torch.load("net.pth"))
299
313
  net.eval
300
314
  ```
@@ -395,7 +409,7 @@ Here’s the list of compatible versions.
395
409
 
396
410
  Torch.rb | LibTorch
397
411
  --- | ---
398
- 0.2.0 | 1.5.0
412
+ 0.2.0+ | 1.5.0+
399
413
  0.1.8 | 1.4.0
400
414
  0.1.0-0.1.7 | 1.3.1
401
415
 
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
413
427
 
414
428
  ### Linux
415
429
 
416
- Deep learning is significantly faster on GPUs.
417
-
418
- Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
430
+ Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
419
431
 
420
432
  Check if CUDA is available
421
433
 
@@ -426,7 +438,7 @@ Torch::CUDA.available?
426
438
  Move a neural network to a GPU
427
439
 
428
440
  ```ruby
429
- net.to("cuda")
441
+ net.cuda
430
442
  ```
431
443
 
432
444
  ## rbenv
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
23
23
  Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
24
24
  };
25
25
 
26
- void handle_error(c10::Error const & ex)
26
+ void handle_error(torch::Error const & ex)
27
27
  {
28
28
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
29
29
  }
@@ -32,16 +32,34 @@ extern "C"
32
32
  void Init_ext()
33
33
  {
34
34
  Module rb_mTorch = define_module("Torch");
35
+ rb_mTorch.add_handler<torch::Error>(handle_error);
35
36
  add_torch_functions(rb_mTorch);
36
37
 
37
38
  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
39
+ rb_cTensor.add_handler<torch::Error>(handle_error);
38
40
  add_tensor_functions(rb_cTensor);
39
41
 
40
42
  Module rb_mNN = define_module_under(rb_mTorch, "NN");
43
+ rb_mNN.add_handler<torch::Error>(handle_error);
41
44
  add_nn_functions(rb_mNN);
42
45
 
46
+ Module rb_mRandom = define_module_under(rb_mTorch, "Random")
47
+ .add_handler<torch::Error>(handle_error)
48
+ .define_singleton_method(
49
+ "initial_seed",
50
+ *[]() {
51
+ return at::detail::getDefaultCPUGenerator()->current_seed();
52
+ })
53
+ .define_singleton_method(
54
+ "seed",
55
+ *[]() {
56
+ // TODO set for CUDA when available
57
+ return at::detail::getDefaultCPUGenerator()->seed();
58
+ });
59
+
43
60
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
44
61
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
62
+ .add_handler<torch::Error>(handle_error)
45
63
  .define_constructor(Constructor<torch::IValue>())
46
64
  .define_method("bool?", &torch::IValue::isBool)
47
65
  .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -82,6 +100,16 @@ void Init_ext()
82
100
  *[](torch::IValue& self) {
83
101
  return self.toInt();
84
102
  })
103
+ .define_method(
104
+ "to_list",
105
+ *[](torch::IValue& self) {
106
+ auto list = self.toListRef();
107
+ Array obj;
108
+ for (auto& elem : list) {
109
+ obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
110
+ }
111
+ return obj;
112
+ })
85
113
  .define_method(
86
114
  "to_string_ref",
87
115
  *[](torch::IValue& self) {
@@ -96,17 +124,27 @@ void Init_ext()
96
124
  "to_generic_dict",
97
125
  *[](torch::IValue& self) {
98
126
  auto dict = self.toGenericDict();
99
- Hash h;
127
+ Hash obj;
100
128
  for (auto& pair : dict) {
101
- h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
129
+ obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
102
130
  }
103
- return h;
131
+ return obj;
104
132
  })
105
133
  .define_singleton_method(
106
134
  "from_tensor",
107
135
  *[](torch::Tensor& v) {
108
136
  return torch::IValue(v);
109
137
  })
138
+ // TODO create specialized list types?
139
+ .define_singleton_method(
140
+ "from_list",
141
+ *[](Array obj) {
142
+ c10::impl::GenericList list(c10::AnyType::get());
143
+ for (auto entry : obj) {
144
+ list.push_back(from_ruby<torch::IValue>(entry));
145
+ }
146
+ return torch::IValue(list);
147
+ })
110
148
  .define_singleton_method(
111
149
  "from_string",
112
150
  *[](String v) {
@@ -157,6 +195,17 @@ void Init_ext()
157
195
  *[](uint64_t seed) {
158
196
  return torch::manual_seed(seed);
159
197
  })
198
+ // config
199
+ .define_singleton_method(
200
+ "show_config",
201
+ *[] {
202
+ return torch::show_config();
203
+ })
204
+ .define_singleton_method(
205
+ "parallel_info",
206
+ *[] {
207
+ return torch::get_parallel_info();
208
+ })
160
209
  // begin tensor creation
161
210
  .define_singleton_method(
162
211
  "_arange",
@@ -273,7 +322,6 @@ void Init_ext()
273
322
  });
274
323
 
275
324
  rb_cTensor
276
- .add_handler<c10::Error>(handle_error)
277
325
  .define_method("cuda?", &torch::Tensor::is_cuda)
278
326
  .define_method("sparse?", &torch::Tensor::is_sparse)
279
327
  .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -281,6 +329,11 @@ void Init_ext()
281
329
  .define_method("numel", &torch::Tensor::numel)
282
330
  .define_method("element_size", &torch::Tensor::element_size)
283
331
  .define_method("requires_grad", &torch::Tensor::requires_grad)
332
+ .define_method(
333
+ "contiguous?",
334
+ *[](Tensor& self) {
335
+ return self.is_contiguous();
336
+ })
284
337
  .define_method(
285
338
  "addcmul!",
286
339
  *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -330,6 +383,21 @@ void Init_ext()
330
383
  s << self.device();
331
384
  return s.str();
332
385
  })
386
+ .define_method(
387
+ "_data_str",
388
+ *[](Tensor& self) {
389
+ Tensor tensor = self;
390
+
391
+ // move to CPU to get data
392
+ if (tensor.device().type() != torch::kCPU) {
393
+ torch::Device device("cpu");
394
+ tensor = tensor.to(device);
395
+ }
396
+
397
+ auto data_ptr = (const char *) tensor.data_ptr();
398
+ return std::string(data_ptr, tensor.numel() * tensor.element_size());
399
+ })
400
+ // TODO figure out a better way to do this
333
401
  .define_method(
334
402
  "_flat_data",
335
403
  *[](Tensor& self) {
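The new `_data_str` method moves the tensor to the CPU if needed and returns its raw buffer, `numel * element_size` bytes, as a Ruby string. `Tensor#numo` now passes that string to `from_string` instead of casting `_flat_data` element by element. The sketch below uses `Array#pack` as a stand-in for that buffer so it runs without LibTorch; the little-endian float32 layout is an assumption about the host, not something the diff states.

```ruby
# Stand-in for a float32 tensor with two elements (assumed little-endian host).
values = [1.5, -2.0]
element_size = 4               # bytes per float32
buffer = values.pack("e*")     # "e" = little-endian single-precision float

# The buffer length matches numel * element_size, as in _data_str.
puts buffer.bytesize == values.size * element_size # prints true

# Decoding the raw bytes recovers the values without a per-element round trip.
decoded = buffer.unpack("e*")
p decoded # prints [1.5, -2.0]
```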
@@ -344,46 +412,40 @@ void Init_ext()
344
412
  Array a;
345
413
  auto dtype = tensor.dtype();
346
414
 
415
+ Tensor view = tensor.reshape({tensor.numel()});
416
+
347
417
  // TODO DRY if someone knows C++
348
418
  if (dtype == torch::kByte) {
349
- uint8_t* data = tensor.data_ptr<uint8_t>();
350
419
  for (int i = 0; i < tensor.numel(); i++) {
351
- a.push(data[i]);
420
+ a.push(view[i].item().to<uint8_t>());
352
421
  }
353
422
  } else if (dtype == torch::kChar) {
354
- int8_t* data = tensor.data_ptr<int8_t>();
355
423
  for (int i = 0; i < tensor.numel(); i++) {
356
- a.push(to_ruby<int>(data[i]));
424
+ a.push(to_ruby<int>(view[i].item().to<int8_t>()));
357
425
  }
358
426
  } else if (dtype == torch::kShort) {
359
- int16_t* data = tensor.data_ptr<int16_t>();
360
427
  for (int i = 0; i < tensor.numel(); i++) {
361
- a.push(data[i]);
428
+ a.push(view[i].item().to<int16_t>());
362
429
  }
363
430
  } else if (dtype == torch::kInt) {
364
- int32_t* data = tensor.data_ptr<int32_t>();
365
431
  for (int i = 0; i < tensor.numel(); i++) {
366
- a.push(data[i]);
432
+ a.push(view[i].item().to<int32_t>());
367
433
  }
368
434
  } else if (dtype == torch::kLong) {
369
- int64_t* data = tensor.data_ptr<int64_t>();
370
435
  for (int i = 0; i < tensor.numel(); i++) {
371
- a.push(data[i]);
436
+ a.push(view[i].item().to<int64_t>());
372
437
  }
373
438
  } else if (dtype == torch::kFloat) {
374
- float* data = tensor.data_ptr<float>();
375
439
  for (int i = 0; i < tensor.numel(); i++) {
376
- a.push(data[i]);
440
+ a.push(view[i].item().to<float>());
377
441
  }
378
442
  } else if (dtype == torch::kDouble) {
379
- double* data = tensor.data_ptr<double>();
380
443
  for (int i = 0; i < tensor.numel(); i++) {
381
- a.push(data[i]);
444
+ a.push(view[i].item().to<double>());
382
445
  }
383
446
  } else if (dtype == torch::kBool) {
384
- bool* data = tensor.data_ptr<bool>();
385
447
  for (int i = 0; i < tensor.numel(); i++) {
386
- a.push(data[i] ? True : False);
448
+ a.push(view[i].item().to<bool>() ? True : False);
387
449
  }
388
450
  } else {
389
451
  throw std::runtime_error("Unsupported type");
@@ -405,7 +467,7 @@ void Init_ext()
405
467
  });
406
468
 
407
469
  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
408
- .add_handler<c10::Error>(handle_error)
470
+ .add_handler<torch::Error>(handle_error)
409
471
  .define_constructor(Constructor<torch::TensorOptions>())
410
472
  .define_method(
411
473
  "dtype",
@@ -511,6 +573,7 @@ void Init_ext()
511
573
  });
512
574
 
513
575
  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
576
+ .add_handler<torch::Error>(handle_error)
514
577
  .define_method(
515
578
  "grad",
516
579
  *[](Parameter& self) {
@@ -520,6 +583,7 @@ void Init_ext()
520
583
 
521
584
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
522
585
  .define_constructor(Constructor<torch::Device, std::string>())
586
+ .add_handler<torch::Error>(handle_error)
523
587
  .define_method("index", &torch::Device::index)
524
588
  .define_method("index?", &torch::Device::has_index)
525
589
  .define_method(
@@ -531,6 +595,7 @@ void Init_ext()
531
595
  });
532
596
 
533
597
  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
598
+ .add_handler<torch::Error>(handle_error)
534
599
  .define_singleton_method("available?", &torch::cuda::is_available)
535
600
  .define_singleton_method("device_count", &torch::cuda::device_count);
536
601
  }
@@ -2,33 +2,33 @@ require "mkmf-rice"
2
2
 
3
3
  abort "Missing stdc++" unless have_library("stdc++")
4
4
 
5
- $CXXFLAGS << " -std=c++14"
5
+ $CXXFLAGS += " -std=c++14"
6
6
 
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
- $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=1"
8
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
10
  # TODO check compiler name
11
11
  clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
12
12
 
13
13
  # check omp first
14
14
  if have_library("omp") || have_library("gomp")
15
- $CXXFLAGS << " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS << " -Xclang" if clang
17
- $CXXFLAGS << " -fopenmp"
15
+ $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
+ $CXXFLAGS += " -Xclang" if clang
17
+ $CXXFLAGS += " -fopenmp"
18
18
  end
19
19
 
20
20
  if clang
21
21
  # silence ruby/intern.h warning
22
- $CXXFLAGS << " -Wno-deprecated-register"
22
+ $CXXFLAGS += " -Wno-deprecated-register"
23
23
 
24
24
  # silence torch warnings
25
- $CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
25
+ $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
26
26
  else
27
27
  # silence rice warnings
28
- $CXXFLAGS << " -Wno-noexcept-type"
28
+ $CXXFLAGS += " -Wno-noexcept-type"
29
29
 
30
30
  # silence torch warnings
31
- $CXXFLAGS << " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
31
+ $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
32
32
  end
33
33
 
34
34
  inc, lib = dir_config("torch")
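This extconf.rb hunk swaps `<<` for `+=` on `$CXXFLAGS`, `$LDFLAGS` and `$INCFLAGS`. The practical difference: `<<` appends to the existing String object in place (every other reference to it sees the change, and a frozen string raises), while `+=` builds a new String and rebinds the variable. The diff does not say which of these motivated the change; the sketch below only demonstrates both behaviors, with plain local variables as hypothetical stand-ins for the mkmf globals.

```ruby
# Local stand-ins for mkmf's flag globals (hypothetical values).
flags = String.new("-O2")
shared = flags               # a second reference to the same String object
flags << " -std=c++14"       # mutates the object in place
puts shared                  # prints -O2 -std=c++14 (the other reference changed too)

flags = String.new("-O2")
shared = flags
flags += " -std=c++14"       # allocates a new String and rebinds flags
puts shared                  # prints -O2 (original object untouched)

frozen = "-O2".freeze
begin
  frozen << " -fopenmp"      # in-place append on a frozen string fails
rescue FrozenError => e
  puts e.class               # prints FrozenError
end
frozen += " -fopenmp"        # fine: a new, unfrozen String is created
puts frozen                  # prints -O2 -fopenmp
```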
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
39
39
  cuda_inc ||= "/usr/local/cuda/include"
40
40
  cuda_lib ||= "/usr/local/cuda/lib64"
41
41
 
42
- $LDFLAGS << " -L#{lib}" if Dir.exist?(lib)
42
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
43
43
  abort "LibTorch not found" unless have_library("torch")
44
44
 
45
+ have_library("mkldnn")
46
+ have_library("nnpack")
47
+
45
48
  with_cuda = false
46
49
  if Dir["#{lib}/*torch_cuda*"].any?
47
- $LDFLAGS << " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
50
+ $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
48
51
  with_cuda = have_library("cuda") && have_library("cudnn")
49
52
  end
50
53
 
51
- $INCFLAGS << " -I#{inc}"
52
- $INCFLAGS << " -I#{inc}/torch/csrc/api/include"
54
+ $INCFLAGS += " -I#{inc}"
55
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
53
56
 
54
- $LDFLAGS << " -Wl,-rpath,#{lib}"
55
- $LDFLAGS << ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
57
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
58
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
56
59
 
57
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
58
- $LDFLAGS << " -lc10 -ltorch_cpu -ltorch"
61
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
59
62
  if with_cuda
60
- $LDFLAGS << " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
63
+ $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
61
64
  # TODO figure out why this is needed
62
- $LDFLAGS << " -Wl,--no-as-needed,#{lib}/libtorch.so"
65
+ $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
63
66
  end
64
67
 
65
68
  # generate C++ functions
@@ -1,6 +1,11 @@
1
1
  # ext
2
2
  require "torch/ext"
3
3
 
4
+ # stdlib
5
+ require "fileutils"
6
+ require "net/http"
7
+ require "tmpdir"
8
+
4
9
  # native functions
5
10
  require "torch/native/generator"
6
11
  require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
174
179
 
175
180
  # utils
176
181
  require "torch/utils/data/data_loader"
182
+ require "torch/utils/data/dataset"
177
183
  require "torch/utils/data/tensor_dataset"
178
184
 
179
- # random
180
- require "torch/random"
181
-
182
185
  # hub
183
186
  require "torch/hub"
184
187
 
@@ -466,6 +469,12 @@ module Torch
466
469
  IValue.from_bool(obj)
467
470
  when nil
468
471
  IValue.new
472
+ when Array
473
+ if obj.all? { |v| v.is_a?(Tensor) }
474
+ IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
475
+ else
476
+ raise Error, "Unknown list type"
477
+ end
469
478
  else
470
479
  raise Error, "Unknown type: #{obj.class.name}"
471
480
  end
@@ -490,6 +499,8 @@ module Torch
490
499
  dict[to_ruby(k)] = to_ruby(v)
491
500
  end
492
501
  dict
502
+ elsif ivalue.list?
503
+ ivalue.to_list.map { |v| to_ruby(v) }
493
504
  else
494
505
  type =
495
506
  if ivalue.capsule?
@@ -510,8 +521,6 @@ module Torch
510
521
  "BoolList"
511
522
  elsif ivalue.tensor_list?
512
523
  "TensorList"
513
- elsif ivalue.list?
514
- "List"
515
524
  elsif ivalue.object?
516
525
  "Object"
517
526
  elsif ivalue.module?
@@ -4,6 +4,58 @@ module Torch
4
4
  def list(github, force_reload: false)
5
5
  raise NotImplementedYet
6
6
  end
7
+
8
+ def download_url_to_file(url, dst)
9
+ uri = URI(url)
10
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
11
+ location = nil
12
+
13
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
+ request = Net::HTTP::Get.new(uri)
15
+
16
+ puts "Downloading #{url}..."
17
+ File.open(tmp, "wb") do |f|
18
+ http.request(request) do |response|
19
+ case response
20
+ when Net::HTTPRedirection
21
+ location = response["location"]
22
+ when Net::HTTPSuccess
23
+ response.read_body do |chunk|
24
+ f.write(chunk)
25
+ end
26
+ else
27
+ raise Error, "Bad response"
28
+ end
29
+ end
30
+ end
31
+ end
32
+
33
+ if location
34
+ download_url_to_file(location, dst)
35
+ else
36
+ FileUtils.mv(tmp, dst)
37
+ nil
38
+ end
39
+ end
40
+
41
+ def load_state_dict_from_url(url, model_dir: nil)
42
+ unless model_dir
43
+ torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
44
+ model_dir = File.join(torch_home, "checkpoints")
45
+ end
46
+
47
+ FileUtils.mkdir_p(model_dir)
48
+
49
+ parts = URI(url)
50
+ filename = File.basename(parts.path)
51
+ cached_file = File.join(model_dir, filename)
52
+ unless File.exist?(cached_file)
53
+ # TODO support hash_prefix
54
+ download_url_to_file(url, cached_file)
55
+ end
56
+
57
+ Torch.load(cached_file)
58
+ end
7
59
  end
8
60
  end
9
61
  end
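`load_state_dict_from_url` resolves its cache directory as `TORCH_HOME`, else `$XDG_CACHE_HOME/torch`, else `~/.cache/torch`, appends `checkpoints`, and names the file after the last segment of the URL path; it only calls `download_url_to_file` when that file is missing. The helper below is a hypothetical extraction of just that path logic. It takes the environment as a Hash so it can be checked without touching the network or the real `ENV`.

```ruby
require "uri"

# Hypothetical helper: the cache-path portion of load_state_dict_from_url.
# `env` is any Hash-like object; pass ENV in real use.
def cached_checkpoint_path(url, env = ENV)
  torch_home = env["TORCH_HOME"] || "#{env["XDG_CACHE_HOME"] || "#{env["HOME"]}/.cache"}/torch"
  model_dir = File.join(torch_home, "checkpoints")
  File.join(model_dir, File.basename(URI(url).path))
end

url = "https://example.com/models/net.pth" # placeholder URL
p cached_checkpoint_path(url, {"HOME" => "/home/me"})
# prints "/home/me/.cache/torch/checkpoints/net.pth"
p cached_checkpoint_path(url, {"HOME" => "/home/me", "TORCH_HOME" => "/data/torch"})
# prints "/data/torch/checkpoints/net.pth"
```

Because the file name comes only from the URL path, two different URLs ending in the same file name share one cache entry; the `# TODO support hash_prefix` comment indicates no checksum verification is done yet.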
@@ -1,6 +1,6 @@
1
1
  module Torch
2
2
  module Inspector
3
- # TODO make more performance, especially when summarizing
3
+ # TODO make more performant, especially when summarizing
4
4
  # how? only read data that will be displayed
5
5
  def inspect
6
6
  data =
@@ -14,7 +14,7 @@ module Torch
14
14
  if dtype == :bool
15
15
  fmt = "%s"
16
16
  else
17
- values = to_a.flatten
17
+ values = _flat_data
18
18
  abs = values.select { |v| v != 0 }.map(&:abs)
19
19
  max = abs.max || 1
20
20
  min = abs.min || 1
@@ -25,7 +25,7 @@ module Torch
25
25
  end
26
26
 
27
27
  if floating_point?
28
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
28
+ sci = max > 1e8 || max < 1e-4
29
29
 
30
30
  all_int = values.all? { |v| v.finite? && v == v.to_i }
31
31
  decimal = all_int ? 1 : 4
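The inspector now switches to scientific notation based only on the largest non-zero magnitude (`max > 1e8 || max < 1e-4`) instead of also looking at the ratio between the largest and smallest values. A pure-Ruby sketch comparing the two rules on plain arrays; the method names are hypothetical.

```ruby
# Largest and smallest non-zero magnitudes, defaulting to 1 as in the inspector.
def magnitudes(values)
  abs = values.select { |v| v != 0 }.map(&:abs)
  [abs.max || 1, abs.min || 1]
end

# Rule before this diff: also triggered by a wide spread between values.
def sci_old?(values)
  max, min = magnitudes(values)
  max / min.to_f > 1000 || max > 1e8 || min < 1e-4
end

# Rule after this diff: only the largest magnitude matters.
def sci_new?(values)
  max, _min = magnitudes(values)
  max > 1e8 || max < 1e-4
end

values = [0.001, 10.0]
p sci_old?(values)      # prints true (10 / 0.001 = 10000 > 1000)
p sci_new?(values)      # prints false
p sci_new?([1e-5, 0.0]) # prints true
```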
@@ -66,6 +66,7 @@ module Torch
66
66
  end
67
67
 
68
68
  next if t == "Generator?"
69
+ next if t == "MemoryFormat"
69
70
  next if t == "MemoryFormat?"
70
71
  args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
71
72
  end
@@ -18,7 +18,7 @@ module Torch
18
18
  functions = functions()
19
19
 
20
20
  # skip functions
21
- skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
21
+ skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
22
22
 
23
23
  # remove functions
24
24
  functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
31
31
  todo_functions, functions =
32
32
  functions.partition do |f|
33
33
  f.args.any? do |a|
34
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
34
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
35
35
  skip_args.any? { |sa| a[:type].include?(sa) } ||
36
36
  # native_functions.yaml is missing size argument for normal
37
37
  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
112
112
  "OptionalScalarType"
113
113
  when "Tensor[]"
114
114
  "TensorList"
115
+ when "Tensor?[]"
116
+ # TODO make optional
117
+ "TensorList"
115
118
  when "int"
116
119
  "int64_t"
117
120
  when "float"
@@ -75,7 +75,7 @@ module Torch
75
75
  v.is_a?(Tensor)
76
76
  when "Tensor?"
77
77
  v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]"
78
+ when "Tensor[]", "Tensor?[]"
79
79
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
80
  when "int"
81
81
  if k == "reduction"
@@ -70,6 +70,11 @@ module Torch
70
70
  momentum: exponential_average_factor, eps: @eps
71
71
  )
72
72
  end
73
+
74
+ def extra_inspect
75
+ s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
76
+ format(s, **dict)
77
+ end
73
78
  end
74
79
  end
75
80
  end
@@ -20,7 +20,14 @@ module Torch
20
20
 
21
21
  # TODO add more parameters
22
22
  def extra_inspect
23
- format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
23
+ s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
+ s += ", padding: %{padding}" if @padding != [0] * @padding.size
25
+ s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
26
+ s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
27
+ s += ", groups: %{groups}" if @groups != 1
28
+ s += ", bias: false" unless @bias
29
+ s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
30
+ format(s, **dict)
24
31
  end
25
32
  end
26
33
  end
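The conv `extra_inspect` now builds a `%{name}` template, appends optional fields only when they differ from their defaults, and fills it with `format(s, **dict)`. A self-contained sketch of that pattern; `ConvSketch` and its `dict` method (instance variables as a symbol-keyed Hash) are hypothetical stand-ins for the library's module class and helper.

```ruby
# Hypothetical stand-in for a conv layer; only the fields used by extra_inspect.
class ConvSketch
  def initialize(in_channels, out_channels, kernel_size, stride: [1, 1], padding: [0, 0], groups: 1, bias: true)
    @in_channels = in_channels
    @out_channels = out_channels
    @kernel_size = kernel_size
    @stride = stride
    @padding = padding
    @groups = groups
    @bias = bias
  end

  # Assumed helper: instance variables as a Hash, e.g. {in_channels: 1, ...}
  def dict
    instance_variables.to_h { |k| [k.to_s.delete_prefix("@").to_sym, instance_variable_get(k)] }
  end

  def extra_inspect
    s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
    # Optional fields are appended only when they differ from the defaults.
    s += ", padding: %{padding}" if @padding != [0] * @padding.size
    s += ", groups: %{groups}" if @groups != 1
    s += ", bias: false" unless @bias
    format(s, **dict) # keys not named in the template are ignored
  end
end

puts ConvSketch.new(1, 6, [3, 3]).extra_inspect
# prints 1, 6, kernel_size: [3, 3], stride: [1, 1]
puts ConvSketch.new(1, 6, [3, 3], padding: [1, 1], bias: false).extra_inspect
# prints 1, 6, kernel_size: [3, 3], stride: [1, 1], padding: [1, 1], bias: false
```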
@@ -23,7 +23,7 @@ module Torch
23
23
  if bias
24
24
  @bias = Parameter.new(Tensor.new(out_channels))
25
25
  else
26
- raise NotImplementedError
26
+ register_parameter("bias", nil)
27
27
  end
28
28
  reset_parameters
29
29
  end
@@ -12,7 +12,8 @@ module Torch
12
12
  end
13
13
 
14
14
  def extra_inspect
15
- format("kernel_size: %s", @kernel_size)
15
+ s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
16
+ format(s, **dict)
16
17
  end
17
18
  end
18
19
  end
@@ -145,7 +145,7 @@ module Torch
145
145
  params = {}
146
146
  if recurse
147
147
  named_children.each do |name, mod|
148
- params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
148
+ params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
149
149
  end
150
150
  end
151
151
  instance_variables.each do |name|
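The `named_parameters` fix threads the caller's `prefix` into the recursive call, so parameters two or more levels deep keep their full dotted path; previously only the immediate child's name was prepended. A minimal pure-Ruby tree that mirrors the fixed recursion. `ParamNode` is hypothetical and keeps parameters in a plain Hash rather than in instance variables.

```ruby
# Hypothetical module stand-in: own parameters plus named child nodes.
class ParamNode
  def initialize(params = {}, children = {})
    @params = params
    @children = children
  end

  def named_parameters(prefix: "")
    result = {}
    # Children first, passing the accumulated prefix down (the fixed behavior).
    @children.each do |name, mod|
      result.merge!(mod.named_parameters(prefix: "#{prefix}#{name}."))
    end
    @params.each { |name, value| result["#{prefix}#{name}"] = value }
    result
  end
end

fc = ParamNode.new({"weight" => :w, "bias" => :b})
net = ParamNode.new({"scale" => :s}, {"block" => ParamNode.new({}, {"fc" => fc})})
p net.named_parameters.keys
# prints ["block.fc.weight", "block.fc.bias", "scale"]
```

With the old `prefix: "#{name}."`, the nested keys would have come out as `fc.weight` and `fc.bias`, dropping the `block.` level.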
@@ -186,8 +186,22 @@ module Torch
186
186
  named_modules.values
187
187
  end
188
188
 
189
- def named_modules
190
- {"" => self}.merge(named_children)
189
+ # TODO return enumerator?
190
+ def named_modules(memo: nil, prefix: "")
191
+ ret = {}
192
+ memo ||= Set.new
193
+ unless memo.include?(self)
194
+ memo << self
195
+ ret[prefix] = self
196
+ named_children.each do |name, mod|
197
+ next unless mod.is_a?(Module)
198
+ submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
199
+ mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
200
+ ret[m[0]] = m[1]
201
+ end
202
+ end
203
+ end
204
+ ret
191
205
  end
192
206
 
193
207
  def train(mode = true)
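`named_modules` is now recursive: it walks children depth-first, builds dotted names from `prefix`, and keeps a `memo` Set so a module reachable through two paths is listed only once, under the first path found. A pure-Ruby sketch of the same traversal over a hypothetical `ModNode` tree (`merge!` replaces the explicit copy loop in the diff but has the same effect).

```ruby
require "set"

# Hypothetical module stand-in with named children.
class ModNode
  def initialize(children = {})
    @children = children
  end

  def named_modules(memo: nil, prefix: "")
    ret = {}
    memo ||= Set.new
    unless memo.include?(self)
      memo << self
      ret[prefix] = self
      @children.each do |name, mod|
        submodule_prefix = prefix + (prefix.empty? ? "" : ".") + name
        ret.merge!(mod.named_modules(memo: memo, prefix: submodule_prefix))
      end
    end
    ret
  end
end

shared = ModNode.new
net = ModNode.new("features" => ModNode.new("conv1" => ModNode.new, "act" => shared), "head" => shared)
p net.named_modules.keys
# prints ["", "features", "features.conv1", "features.act"]
```

`head` does not appear because `shared` was already recorded in `memo` when it was reached as `features.act`.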
@@ -224,13 +238,15 @@ module Torch
224
238
 
225
239
  def inspect
226
240
  name = self.class.name.split("::").last
227
- if children.empty?
241
+ if named_children.empty?
228
242
  "#{name}(#{extra_inspect})"
229
243
  else
230
244
  str = String.new
231
245
  str << "#{name}(\n"
232
- children.each do |name, mod|
233
- str << " (#{name}): #{mod.inspect}\n"
246
+ named_children.each do |name, mod|
247
+ mod_str = mod.inspect
248
+ mod_str = mod_str.lines.join(" ")
249
+ str << " (#{name}): #{mod_str}\n"
234
250
  end
235
251
  str << ")"
236
252
  end
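`Module#inspect` now iterates `named_children` and re-joins each child's multi-line `inspect` output with extra indentation, so nested containers print as an indented tree. A sketch with a hypothetical `InspectNode` class, assuming two spaces per nesting level.

```ruby
# Hypothetical module stand-in: a class name, named children, and extra text.
class InspectNode
  def initialize(name, children = {}, extra = "")
    @name = name
    @children = children
    @extra = extra
  end

  def inspect
    return "#{@name}(#{@extra})" if @children.empty?

    str = String.new("#{@name}(\n")
    @children.each do |child_name, mod|
      # Re-join the child's lines so each continuation line gains one indent level.
      mod_str = mod.inspect.lines.join("  ")
      str << "  (#{child_name}): #{mod_str}\n"
    end
    str << ")"
  end
end

body = InspectNode.new("Sequential", {"0" => InspectNode.new("Linear", {}, "4, 2")})
puts InspectNode.new("Net", {"body" => body}).inspect
# Net(
#   (body): Sequential(
#     (0): Linear(4, 2)
#   )
# )
```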
@@ -11,9 +11,6 @@ module Torch
11
11
  end
12
12
 
13
13
  def step(closure = nil)
14
- # TODO implement []=
15
- raise NotImplementedYet
16
-
17
14
  loss = nil
18
15
  if closure
19
16
  loss = closure.call
@@ -4,6 +4,8 @@ module Torch
4
4
  include Inspector
5
5
 
6
6
  alias_method :requires_grad?, :requires_grad
7
+ alias_method :ndim, :dim
8
+ alias_method :ndimension, :dim
7
9
 
8
10
  def self.new(*args)
9
11
  FloatTensor.new(*args)
@@ -23,8 +25,17 @@ module Torch
23
25
  inspect
24
26
  end
25
27
 
28
+ # TODO make more performant
26
29
  def to_a
27
- reshape_arr(_flat_data, shape)
30
+ arr = _flat_data
31
+ if shape.empty?
32
+ arr
33
+ else
34
+ shape[1..-1].reverse.each do |dim|
35
+ arr = arr.each_slice(dim)
36
+ end
37
+ arr.to_a
38
+ end
28
39
  end
29
40
 
30
41
  # TODO support dtype
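`Tensor#to_a` now rebuilds the nested Array directly from `_flat_data` by slicing along every dimension except the first, innermost first, instead of going through the removed `reshape_arr` helper. The same steps on a plain Array; `nest_flat` is a hypothetical name.

```ruby
# Rebuild a nested Array from flat data and a shape, like Tensor#to_a.
def nest_flat(flat, shape)
  return flat if shape.empty? # 0-dim: a one-element flat Array
  arr = flat
  # Slice the innermost dimension first; the first dimension is what remains.
  shape[1..-1].reverse.each { |dim| arr = arr.each_slice(dim) }
  arr.to_a
end

p nest_flat((1..6).to_a, [2, 3])    # prints [[1, 2, 3], [4, 5, 6]]
p nest_flat((1..8).to_a, [2, 2, 2]) # prints [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
```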
@@ -62,7 +73,15 @@ module Torch
62
73
  if numel != 1
63
74
  raise Error, "only one element tensors can be converted to Ruby scalars"
64
75
  end
65
- _flat_data.first
76
+ to_a.first
77
+ end
78
+
79
+ def to_i
80
+ item.to_i
81
+ end
82
+
83
+ def to_f
84
+ item.to_f
66
85
  end
67
86
 
68
87
  # unsure if this is correct
@@ -78,7 +97,7 @@ module Torch
78
97
  def numo
79
98
  cls = Torch._dtype_to_numo[dtype]
80
99
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
81
- cls.cast(_flat_data).reshape(*shape)
100
+ cls.from_string(_data_str).reshape(*shape)
82
101
  end
83
102
 
84
103
  def new_ones(*size, **options)
@@ -106,15 +125,6 @@ module Torch
106
125
  _view(size)
107
126
  end
108
127
 
109
- # value and other are swapped for some methods
110
- def add!(value = 1, other)
111
- if other.is_a?(Numeric)
112
- _add__scalar(other, value)
113
- else
114
- _add__tensor(other, value)
115
- end
116
- end
117
-
118
128
  def +(other)
119
129
  add(other)
120
130
  end
@@ -143,11 +153,13 @@ module Torch
143
153
  neg
144
154
  end
145
155
 
156
+ # TODO better compare?
146
157
  def <=>(other)
147
158
  item <=> other
148
159
  end
149
160
 
150
- # based on python_variable_indexing.cpp
161
+ # based on python_variable_indexing.cpp and
162
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
151
163
  def [](*indexes)
152
164
  result = self
153
165
  dim = 0
@@ -159,6 +171,8 @@ module Torch
159
171
  finish += 1 unless index.exclude_end?
160
172
  result = result._slice_tensor(dim, index.begin, finish, 1)
161
173
  dim += 1
174
+ elsif index.is_a?(Tensor)
175
+ result = result.index([index])
162
176
  elsif index.nil?
163
177
  result = result.unsqueeze(dim)
164
178
  dim += 1
@@ -172,12 +186,12 @@ module Torch
172
186
  result
173
187
  end
174
188
 
175
- # TODO
176
- # based on python_variable_indexing.cpp
189
+ # based on python_variable_indexing.cpp and
190
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
177
191
  def []=(index, value)
178
192
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
179
193
 
180
- value = Torch.tensor(value) unless value.is_a?(Tensor)
194
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
181
195
 
182
196
  if index.is_a?(Numeric)
183
197
  copy_to(_select_int(0, index), value)
@@ -185,13 +199,39 @@ module Torch
185
199
  finish = index.end
186
200
  finish += 1 unless index.exclude_end?
187
201
  copy_to(_slice_tensor(0, index.begin, finish, 1), value)
202
+ elsif index.is_a?(Tensor)
203
+ index_put!([index], value)
188
204
  else
189
205
  raise Error, "Unsupported index type: #{index.class.name}"
190
206
  end
191
207
  end
192
208
 
193
- def random!(from = 0, to)
194
- _random__from_to(from, to)
209
+ # native functions that need to be manually defined
210
+
211
+ # value and other are swapped for some methods
212
+ def add!(value = 1, other)
213
+ if other.is_a?(Numeric)
214
+ _add__scalar(other, value)
215
+ else
216
+ _add__tensor(other, value)
217
+ end
218
+ end
219
+
220
+ # native functions overlap, so need to handle manually
221
+ def random!(*args)
222
+ case args.size
223
+ when 1
224
+ _random__to(*args)
225
+ when 2
226
+ _random__from_to(*args)
227
+ else
228
+ _random_(*args)
229
+ end
230
+ end
231
+
232
+ def clamp!(min, max)
233
+ _clamp_min_(min)
234
+ _clamp_max_(max)
195
235
  end
196
236
 
197
237
  private
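Because the native `random_` overloads collide, `random!` now dispatches on the number of arguments: one argument maps to `_random__to`, two to `_random__from_to`, and anything else to `_random_`. A sketch of that arity dispatch, with hypothetical stub methods that only report which overload was chosen.

```ruby
# Hypothetical stand-in; the stubs replace the generated native bindings.
class RandomDispatchSketch
  def random!(*args)
    case args.size
    when 1 then _random__to(*args)
    when 2 then _random__from_to(*args)
    else _random_(*args)
    end
  end

  private

  def _random__to(to)
    [:to, to]
  end

  def _random__from_to(from, to)
    [:from_to, from, to]
  end

  def _random_(*args)
    [:default, *args]
  end
end

t = RandomDispatchSketch.new
p t.random!        # prints [:default]
p t.random!(10)    # prints [:to, 10]
p t.random!(2, 10) # prints [:from_to, 2, 10]
```

The zero-argument form matters for the new `DataLoader#each`, which calls `Torch.empty([], dtype: :int64).random!` to draw a base seed; the old `random!(from = 0, to)` signature required at least one argument.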
@@ -199,17 +239,5 @@ module Torch
199
239
  def copy_to(dst, src)
200
240
  dst.copy!(src)
201
241
  end
202
-
203
- def reshape_arr(arr, dims)
204
- if dims.empty?
205
- arr
206
- else
207
- arr = arr.flatten
208
- dims[1..-1].reverse.each do |dim|
209
- arr = arr.each_slice(dim)
210
- end
211
- arr.to_a
212
- end
213
- end
214
242
  end
215
243
  end
@@ -6,21 +6,49 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1)
9
+ def initialize(dataset, batch_size: 1, shuffle: false)
10
10
  @dataset = dataset
11
11
  @batch_size = batch_size
12
+ @shuffle = shuffle
12
13
  end
13
14
 
14
15
  def each
15
- size.times do |i|
16
- start_index = i * @batch_size
17
- yield @dataset[start_index...(start_index + @batch_size)]
16
+ # try to keep the random number generator in sync with Python
17
+ # this makes it easy to compare results
18
+ base_seed = Torch.empty([], dtype: :int64).random!.item
19
+
20
+ indexes =
21
+ if @shuffle
22
+ Torch.randperm(@dataset.size).to_a
23
+ else
24
+ @dataset.size.times
25
+ end
26
+
27
+ indexes.each_slice(@batch_size) do |idx|
28
+ batch = idx.map { |i| @dataset[i] }
29
+ yield collate(batch)
18
30
  end
19
31
  end
20
32
 
21
33
  def size
22
34
  (@dataset.size / @batch_size.to_f).ceil
23
35
  end
36
+
37
+ private
38
+
39
+ def collate(batch)
40
+ elem = batch[0]
41
+ case elem
42
+ when Tensor
43
+ Torch.stack(batch, 0)
44
+ when Integer
45
+ Torch.tensor(batch)
46
+ when Array
47
+ batch.transpose.map { |v| collate(v) }
48
+ else
49
+ raise NotImplementedYet
50
+ end
51
+ end
24
52
  end
25
53
  end
26
54
  end
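`DataLoader#each` now picks indexes (shuffled or sequential), slices them into batches, fetches each item with `@dataset[i]`, and collates the batch: tensors are stacked, integers become a tensor, and arrays (such as `[input, label]` pairs) are transposed and collated field by field. A pure-Ruby sketch of that flow. Plain Arrays stand in for `Torch.stack`/`Torch.tensor`, and `Array#shuffle` with a seeded `Random` stands in for `Torch.randperm`, so the shuffled order will not match torch-rb's; `collate` and `batches` here are hypothetical.

```ruby
# Mirrors DataLoader#collate, returning plain Arrays where torch-rb
# would call Torch.tensor (Integers) or Torch.stack (Tensors).
def collate(batch)
  case batch[0]
  when Integer
    batch
  when Array
    # [[x1, y1], [x2, y2]] -> [[x1, x2], [y1, y2]], then collate each field
    batch.transpose.map { |v| collate(v) }
  else
    raise ArgumentError, "unsupported element: #{batch[0].class}"
  end
end

# Hypothetical batching loop; Array#shuffle stands in for Torch.randperm.
def batches(dataset, batch_size:, shuffle: false, rng: Random.new(42))
  indexes = shuffle ? (0...dataset.size).to_a.shuffle(random: rng) : dataset.size.times
  indexes.each_slice(batch_size).map { |idx| collate(idx.map { |i| dataset[i] }) }
end

dataset = [[1, 10], [2, 20], [3, 30]] # (feature, label) pairs
p batches(dataset, batch_size: 2)
# prints [[[1, 2], [10, 20]], [[3], [30]]]
```

The number of batches equals `(dataset.size / batch_size.to_f).ceil`, which is what `DataLoader#size` reports.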
@@ -0,0 +1,8 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class Dataset
5
+ end
6
+ end
7
+ end
8
+ end
@@ -1,7 +1,7 @@
1
1
  module Torch
2
2
  module Utils
3
3
  module Data
4
- class TensorDataset
4
+ class TensorDataset < Dataset
5
5
  def initialize(*tensors)
6
6
  unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
7
7
  raise Error, "Tensors must all have same dim 0 size"
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.1"
2
+ VERSION = "0.2.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.2.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-27 00:00:00.000000000 Z
11
+ date: 2020-06-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -95,19 +95,19 @@ dependencies:
95
95
  - !ruby/object:Gem::Version
96
96
  version: '0'
97
97
  - !ruby/object:Gem::Dependency
98
- name: npy
98
+ name: torchvision
99
99
  requirement: !ruby/object:Gem::Requirement
100
100
  requirements:
101
101
  - - ">="
102
102
  - !ruby/object:Gem::Version
103
- version: '0'
103
+ version: 0.1.1
104
104
  type: :development
105
105
  prerelease: false
106
106
  version_requirements: !ruby/object:Gem::Requirement
107
107
  requirements:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
- version: '0'
110
+ version: 0.1.1
111
111
  description:
112
112
  email: andrew@chartkick.com
113
113
  executables: []
@@ -258,9 +258,9 @@ files:
258
258
  - lib/torch/optim/rmsprop.rb
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
- - lib/torch/random.rb
262
261
  - lib/torch/tensor.rb
263
262
  - lib/torch/utils/data/data_loader.rb
263
+ - lib/torch/utils/data/dataset.rb
264
264
  - lib/torch/utils/data/tensor_dataset.rb
265
265
  - lib/torch/version.rb
266
266
  homepage: https://github.com/ankane/torch.rb
@@ -1,10 +0,0 @@
1
- module Torch
2
- module Random
3
- class << self
4
- # not available through LibTorch
5
- def initial_seed
6
- raise NotImplementedYet
7
- end
8
- end
9
- end
10
- end