torch-rb 0.2.2 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +19 -7
- data/ext/torch/ext.cpp +64 -19
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +6 -3
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +26 -7
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +76 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 3451d6140ae6a6a9294a73571239df703a9dc753911c5d97a83bcb020b9d878d
|
4
|
+
data.tar.gz: 65689090d9fe4d9dee078b2f0f0f56526d76158306390c0988e61b0e2ca98ff1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9f2cc800b8c0e7a3a75bbb9c4705e7e306ed68f52a90530d22659a4d23d8ce0126c1cfd9bc7c33612a842bc20199ba8ec7f488bbba591073d8914f108948e084
|
7
|
+
data.tar.gz: dbf34592bef6e869a3814f20e891d2d566339080a46d335a5e42f114477a5769f63ee18ca3ee8b8f1d031faf898dfe4f6861064f0cb0773b6d75622b4a663e0f
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
|
|
1
|
+
## 0.2.7 (2020-06-29)
|
2
|
+
|
3
|
+
- Made tensors enumerable
|
4
|
+
- Improved performance of `inspect` method
|
5
|
+
|
6
|
+
## 0.2.6 (2020-06-29)
|
7
|
+
|
8
|
+
- Added support for indexing with tensors
|
9
|
+
- Added `contiguous` methods
|
10
|
+
- Fixed named parameters for nested parameters
|
11
|
+
|
12
|
+
## 0.2.5 (2020-06-07)
|
13
|
+
|
14
|
+
- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
|
15
|
+
- Improved error messages
|
16
|
+
- Fixed tensor slicing
|
17
|
+
|
18
|
+
## 0.2.4 (2020-04-29)
|
19
|
+
|
20
|
+
- Added `to_i` and `to_f` to tensors
|
21
|
+
- Added `shuffle` option to data loader
|
22
|
+
- Fixed `modules` and `named_modules` for nested modules
|
23
|
+
|
24
|
+
## 0.2.3 (2020-04-28)
|
25
|
+
|
26
|
+
- Added `show_config` and `parallel_info` methods
|
27
|
+
- Added `initial_seed` and `seed` methods to `Random`
|
28
|
+
- Improved data loader
|
29
|
+
- Build with MKL-DNN and NNPACK when available
|
30
|
+
- Fixed `inspect` for modules
|
31
|
+
|
1
32
|
## 0.2.2 (2020-04-27)
|
2
33
|
|
3
34
|
- Added support for saving tensor lists
|
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
+
For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
|
6
|
+
|
5
7
|
[](https://travis-ci.org/ankane/torch.rb)
|
6
8
|
|
7
9
|
## Installation
|
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
|
|
22
24
|
|
23
25
|
## Getting Started
|
24
26
|
|
27
|
+
Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
|
28
|
+
|
29
|
+
We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
|
30
|
+
|
31
|
+
```text
|
32
|
+
ankane/ml-stack:torch-gpu
|
33
|
+
```
|
34
|
+
|
35
|
+
And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
|
36
|
+
|
37
|
+
## API
|
38
|
+
|
25
39
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
26
40
|
|
27
41
|
- Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
|
@@ -192,7 +206,7 @@ end
|
|
192
206
|
Define a neural network
|
193
207
|
|
194
208
|
```ruby
|
195
|
-
class
|
209
|
+
class MyNet < Torch::NN::Module
|
196
210
|
def initialize
|
197
211
|
super
|
198
212
|
@conv1 = Torch::NN::Conv2d.new(1, 6, 3)
|
@@ -226,7 +240,7 @@ end
|
|
226
240
|
Create an instance of it
|
227
241
|
|
228
242
|
```ruby
|
229
|
-
net =
|
243
|
+
net = MyNet.new
|
230
244
|
input = Torch.randn(1, 1, 32, 32)
|
231
245
|
net.call(input)
|
232
246
|
```
|
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
|
|
294
308
|
Load a model
|
295
309
|
|
296
310
|
```ruby
|
297
|
-
net =
|
311
|
+
net = MyNet.new
|
298
312
|
net.load_state_dict(Torch.load("net.pth"))
|
299
313
|
net.eval
|
300
314
|
```
|
@@ -395,7 +409,7 @@ Here’s the list of compatible versions.
|
|
395
409
|
|
396
410
|
Torch.rb | LibTorch
|
397
411
|
--- | ---
|
398
|
-
0.2.0 | 1.5.0
|
412
|
+
0.2.0+ | 1.5.0+
|
399
413
|
0.1.8 | 1.4.0
|
400
414
|
0.1.0-0.1.7 | 1.3.1
|
401
415
|
|
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
|
|
413
427
|
|
414
428
|
### Linux
|
415
429
|
|
416
|
-
Deep learning is significantly faster on
|
417
|
-
|
418
|
-
Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
430
|
+
Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
419
431
|
|
420
432
|
Check if CUDA is available
|
421
433
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
|
|
23
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
24
24
|
};
|
25
25
|
|
26
|
-
void handle_error(
|
26
|
+
void handle_error(torch::Error const & ex)
|
27
27
|
{
|
28
28
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
29
|
}
|
@@ -32,16 +32,34 @@ extern "C"
|
|
32
32
|
void Init_ext()
|
33
33
|
{
|
34
34
|
Module rb_mTorch = define_module("Torch");
|
35
|
+
rb_mTorch.add_handler<torch::Error>(handle_error);
|
35
36
|
add_torch_functions(rb_mTorch);
|
36
37
|
|
37
38
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
39
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
38
40
|
add_tensor_functions(rb_cTensor);
|
39
41
|
|
40
42
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
43
|
+
rb_mNN.add_handler<torch::Error>(handle_error);
|
41
44
|
add_nn_functions(rb_mNN);
|
42
45
|
|
46
|
+
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
47
|
+
.add_handler<torch::Error>(handle_error)
|
48
|
+
.define_singleton_method(
|
49
|
+
"initial_seed",
|
50
|
+
*[]() {
|
51
|
+
return at::detail::getDefaultCPUGenerator()->current_seed();
|
52
|
+
})
|
53
|
+
.define_singleton_method(
|
54
|
+
"seed",
|
55
|
+
*[]() {
|
56
|
+
// TODO set for CUDA when available
|
57
|
+
return at::detail::getDefaultCPUGenerator()->seed();
|
58
|
+
});
|
59
|
+
|
43
60
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
44
61
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
|
+
.add_handler<torch::Error>(handle_error)
|
45
63
|
.define_constructor(Constructor<torch::IValue>())
|
46
64
|
.define_method("bool?", &torch::IValue::isBool)
|
47
65
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
@@ -177,6 +195,17 @@ void Init_ext()
|
|
177
195
|
*[](uint64_t seed) {
|
178
196
|
return torch::manual_seed(seed);
|
179
197
|
})
|
198
|
+
// config
|
199
|
+
.define_singleton_method(
|
200
|
+
"show_config",
|
201
|
+
*[] {
|
202
|
+
return torch::show_config();
|
203
|
+
})
|
204
|
+
.define_singleton_method(
|
205
|
+
"parallel_info",
|
206
|
+
*[] {
|
207
|
+
return torch::get_parallel_info();
|
208
|
+
})
|
180
209
|
// begin tensor creation
|
181
210
|
.define_singleton_method(
|
182
211
|
"_arange",
|
@@ -293,7 +322,6 @@ void Init_ext()
|
|
293
322
|
});
|
294
323
|
|
295
324
|
rb_cTensor
|
296
|
-
.add_handler<c10::Error>(handle_error)
|
297
325
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
298
326
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
299
327
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
@@ -301,6 +329,11 @@ void Init_ext()
|
|
301
329
|
.define_method("numel", &torch::Tensor::numel)
|
302
330
|
.define_method("element_size", &torch::Tensor::element_size)
|
303
331
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
332
|
+
.define_method(
|
333
|
+
"contiguous?",
|
334
|
+
*[](Tensor& self) {
|
335
|
+
return self.is_contiguous();
|
336
|
+
})
|
304
337
|
.define_method(
|
305
338
|
"addcmul!",
|
306
339
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
@@ -350,6 +383,21 @@ void Init_ext()
|
|
350
383
|
s << self.device();
|
351
384
|
return s.str();
|
352
385
|
})
|
386
|
+
.define_method(
|
387
|
+
"_data_str",
|
388
|
+
*[](Tensor& self) {
|
389
|
+
Tensor tensor = self;
|
390
|
+
|
391
|
+
// move to CPU to get data
|
392
|
+
if (tensor.device().type() != torch::kCPU) {
|
393
|
+
torch::Device device("cpu");
|
394
|
+
tensor = tensor.to(device);
|
395
|
+
}
|
396
|
+
|
397
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
398
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
399
|
+
})
|
400
|
+
// TODO figure out a better way to do this
|
353
401
|
.define_method(
|
354
402
|
"_flat_data",
|
355
403
|
*[](Tensor& self) {
|
@@ -364,46 +412,40 @@ void Init_ext()
|
|
364
412
|
Array a;
|
365
413
|
auto dtype = tensor.dtype();
|
366
414
|
|
415
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
416
|
+
|
367
417
|
// TODO DRY if someone knows C++
|
368
418
|
if (dtype == torch::kByte) {
|
369
|
-
uint8_t* data = tensor.data_ptr<uint8_t>();
|
370
419
|
for (int i = 0; i < tensor.numel(); i++) {
|
371
|
-
a.push(
|
420
|
+
a.push(view[i].item().to<uint8_t>());
|
372
421
|
}
|
373
422
|
} else if (dtype == torch::kChar) {
|
374
|
-
int8_t* data = tensor.data_ptr<int8_t>();
|
375
423
|
for (int i = 0; i < tensor.numel(); i++) {
|
376
|
-
a.push(to_ruby<int>(
|
424
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
377
425
|
}
|
378
426
|
} else if (dtype == torch::kShort) {
|
379
|
-
int16_t* data = tensor.data_ptr<int16_t>();
|
380
427
|
for (int i = 0; i < tensor.numel(); i++) {
|
381
|
-
a.push(
|
428
|
+
a.push(view[i].item().to<int16_t>());
|
382
429
|
}
|
383
430
|
} else if (dtype == torch::kInt) {
|
384
|
-
int32_t* data = tensor.data_ptr<int32_t>();
|
385
431
|
for (int i = 0; i < tensor.numel(); i++) {
|
386
|
-
a.push(
|
432
|
+
a.push(view[i].item().to<int32_t>());
|
387
433
|
}
|
388
434
|
} else if (dtype == torch::kLong) {
|
389
|
-
int64_t* data = tensor.data_ptr<int64_t>();
|
390
435
|
for (int i = 0; i < tensor.numel(); i++) {
|
391
|
-
a.push(
|
436
|
+
a.push(view[i].item().to<int64_t>());
|
392
437
|
}
|
393
438
|
} else if (dtype == torch::kFloat) {
|
394
|
-
float* data = tensor.data_ptr<float>();
|
395
439
|
for (int i = 0; i < tensor.numel(); i++) {
|
396
|
-
a.push(
|
440
|
+
a.push(view[i].item().to<float>());
|
397
441
|
}
|
398
442
|
} else if (dtype == torch::kDouble) {
|
399
|
-
double* data = tensor.data_ptr<double>();
|
400
443
|
for (int i = 0; i < tensor.numel(); i++) {
|
401
|
-
a.push(
|
444
|
+
a.push(view[i].item().to<double>());
|
402
445
|
}
|
403
446
|
} else if (dtype == torch::kBool) {
|
404
|
-
bool* data = tensor.data_ptr<bool>();
|
405
447
|
for (int i = 0; i < tensor.numel(); i++) {
|
406
|
-
a.push(
|
448
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
407
449
|
}
|
408
450
|
} else {
|
409
451
|
throw std::runtime_error("Unsupported type");
|
@@ -425,7 +467,7 @@ void Init_ext()
|
|
425
467
|
});
|
426
468
|
|
427
469
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
428
|
-
.add_handler<
|
470
|
+
.add_handler<torch::Error>(handle_error)
|
429
471
|
.define_constructor(Constructor<torch::TensorOptions>())
|
430
472
|
.define_method(
|
431
473
|
"dtype",
|
@@ -531,6 +573,7 @@ void Init_ext()
|
|
531
573
|
});
|
532
574
|
|
533
575
|
Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
|
576
|
+
.add_handler<torch::Error>(handle_error)
|
534
577
|
.define_method(
|
535
578
|
"grad",
|
536
579
|
*[](Parameter& self) {
|
@@ -540,6 +583,7 @@ void Init_ext()
|
|
540
583
|
|
541
584
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
542
585
|
.define_constructor(Constructor<torch::Device, std::string>())
|
586
|
+
.add_handler<torch::Error>(handle_error)
|
543
587
|
.define_method("index", &torch::Device::index)
|
544
588
|
.define_method("index?", &torch::Device::has_index)
|
545
589
|
.define_method(
|
@@ -551,6 +595,7 @@ void Init_ext()
|
|
551
595
|
});
|
552
596
|
|
553
597
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
598
|
+
.add_handler<torch::Error>(handle_error)
|
554
599
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
555
600
|
.define_singleton_method("device_count", &torch::cuda::device_count);
|
556
601
|
}
|
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
|
|
2
2
|
|
3
3
|
abort "Missing stdc++" unless have_library("stdc++")
|
4
4
|
|
5
|
-
$CXXFLAGS
|
5
|
+
$CXXFLAGS += " -std=c++14"
|
6
6
|
|
7
7
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
|
-
$CXXFLAGS
|
8
|
+
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
9
9
|
|
10
10
|
# TODO check compiler name
|
11
11
|
clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
|
12
12
|
|
13
13
|
# check omp first
|
14
14
|
if have_library("omp") || have_library("gomp")
|
15
|
-
$CXXFLAGS
|
16
|
-
$CXXFLAGS
|
17
|
-
$CXXFLAGS
|
15
|
+
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
16
|
+
$CXXFLAGS += " -Xclang" if clang
|
17
|
+
$CXXFLAGS += " -fopenmp"
|
18
18
|
end
|
19
19
|
|
20
20
|
if clang
|
21
21
|
# silence ruby/intern.h warning
|
22
|
-
$CXXFLAGS
|
22
|
+
$CXXFLAGS += " -Wno-deprecated-register"
|
23
23
|
|
24
24
|
# silence torch warnings
|
25
|
-
$CXXFLAGS
|
25
|
+
$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
26
26
|
else
|
27
27
|
# silence rice warnings
|
28
|
-
$CXXFLAGS
|
28
|
+
$CXXFLAGS += " -Wno-noexcept-type"
|
29
29
|
|
30
30
|
# silence torch warnings
|
31
|
-
$CXXFLAGS
|
31
|
+
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
32
32
|
end
|
33
33
|
|
34
34
|
inc, lib = dir_config("torch")
|
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
|
|
39
39
|
cuda_inc ||= "/usr/local/cuda/include"
|
40
40
|
cuda_lib ||= "/usr/local/cuda/lib64"
|
41
41
|
|
42
|
-
$LDFLAGS
|
42
|
+
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
|
43
43
|
abort "LibTorch not found" unless have_library("torch")
|
44
44
|
|
45
|
+
have_library("mkldnn")
|
46
|
+
have_library("nnpack")
|
47
|
+
|
45
48
|
with_cuda = false
|
46
49
|
if Dir["#{lib}/*torch_cuda*"].any?
|
47
|
-
$LDFLAGS
|
50
|
+
$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
48
51
|
with_cuda = have_library("cuda") && have_library("cudnn")
|
49
52
|
end
|
50
53
|
|
51
|
-
$INCFLAGS
|
52
|
-
$INCFLAGS
|
54
|
+
$INCFLAGS += " -I#{inc}"
|
55
|
+
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
56
|
|
54
|
-
$LDFLAGS
|
55
|
-
$LDFLAGS
|
57
|
+
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
58
|
+
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
58
|
-
$LDFLAGS
|
61
|
+
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
|
59
62
|
if with_cuda
|
60
|
-
$LDFLAGS
|
63
|
+
$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
|
61
64
|
# TODO figure out why this is needed
|
62
|
-
$LDFLAGS
|
65
|
+
$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
|
63
66
|
end
|
64
67
|
|
65
68
|
# generate C++ functions
|
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# stdlib
|
5
|
+
require "fileutils"
|
6
|
+
require "net/http"
|
7
|
+
require "tmpdir"
|
8
|
+
|
4
9
|
# native functions
|
5
10
|
require "torch/native/generator"
|
6
11
|
require "torch/native/parser"
|
@@ -174,11 +179,9 @@ require "torch/nn/init"
|
|
174
179
|
|
175
180
|
# utils
|
176
181
|
require "torch/utils/data/data_loader"
|
182
|
+
require "torch/utils/data/dataset"
|
177
183
|
require "torch/utils/data/tensor_dataset"
|
178
184
|
|
179
|
-
# random
|
180
|
-
require "torch/random"
|
181
|
-
|
182
185
|
# hub
|
183
186
|
require "torch/hub"
|
184
187
|
|
data/lib/torch/hub.rb
CHANGED
@@ -4,6 +4,58 @@ module Torch
|
|
4
4
|
def list(github, force_reload: false)
|
5
5
|
raise NotImplementedYet
|
6
6
|
end
|
7
|
+
|
8
|
+
def download_url_to_file(url, dst)
|
9
|
+
uri = URI(url)
|
10
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
11
|
+
location = nil
|
12
|
+
|
13
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
|
+
request = Net::HTTP::Get.new(uri)
|
15
|
+
|
16
|
+
puts "Downloading #{url}..."
|
17
|
+
File.open(tmp, "wb") do |f|
|
18
|
+
http.request(request) do |response|
|
19
|
+
case response
|
20
|
+
when Net::HTTPRedirection
|
21
|
+
location = response["location"]
|
22
|
+
when Net::HTTPSuccess
|
23
|
+
response.read_body do |chunk|
|
24
|
+
f.write(chunk)
|
25
|
+
end
|
26
|
+
else
|
27
|
+
raise Error, "Bad response"
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
if location
|
34
|
+
download_url_to_file(location, dst)
|
35
|
+
else
|
36
|
+
FileUtils.mv(tmp, dst)
|
37
|
+
nil
|
38
|
+
end
|
39
|
+
end
|
40
|
+
|
41
|
+
def load_state_dict_from_url(url, model_dir: nil)
|
42
|
+
unless model_dir
|
43
|
+
torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
|
44
|
+
model_dir = File.join(torch_home, "checkpoints")
|
45
|
+
end
|
46
|
+
|
47
|
+
FileUtils.mkdir_p(model_dir)
|
48
|
+
|
49
|
+
parts = URI(url)
|
50
|
+
filename = File.basename(parts.path)
|
51
|
+
cached_file = File.join(model_dir, filename)
|
52
|
+
unless File.exist?(cached_file)
|
53
|
+
# TODO support hash_prefix
|
54
|
+
download_url_to_file(url, cached_file)
|
55
|
+
end
|
56
|
+
|
57
|
+
Torch.load(cached_file)
|
58
|
+
end
|
7
59
|
end
|
8
60
|
end
|
9
61
|
end
|