torch-rb 0.2.2 → 0.2.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +19 -7
- data/ext/torch/ext.cpp +64 -19
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +6 -3
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +236 -61
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +26 -7
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +76 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3451d6140ae6a6a9294a73571239df703a9dc753911c5d97a83bcb020b9d878d
+  data.tar.gz: 65689090d9fe4d9dee078b2f0f0f56526d76158306390c0988e61b0e2ca98ff1
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 9f2cc800b8c0e7a3a75bbb9c4705e7e306ed68f52a90530d22659a4d23d8ce0126c1cfd9bc7c33612a842bc20199ba8ec7f488bbba591073d8914f108948e084
+  data.tar.gz: dbf34592bef6e869a3814f20e891d2d566339080a46d335a5e42f114477a5769f63ee18ca3ee8b8f1d031faf898dfe4f6861064f0cb0773b6d75622b4a663e0f
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
+## 0.2.7 (2020-06-29)
+
+- Made tensors enumerable
+- Improved performance of `inspect` method
+
+## 0.2.6 (2020-06-29)
+
+- Added support for indexing with tensors
+- Added `contiguous` methods
+- Fixed named parameters for nested parameters
+
+## 0.2.5 (2020-06-07)
+
+- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
+- Improved error messages
+- Fixed tensor slicing
+
+## 0.2.4 (2020-04-29)
+
+- Added `to_i` and `to_f` to tensors
+- Added `shuffle` option to data loader
+- Fixed `modules` and `named_modules` for nested modules
+
+## 0.2.3 (2020-04-28)
+
+- Added `show_config` and `parallel_info` methods
+- Added `initial_seed` and `seed` methods to `Random`
+- Improved data loader
+- Build with MKL-DNN and NNPACK when available
+- Fixed `inspect` for modules
+
 ## 0.2.2 (2020-04-27)
 
 - Added support for saving tensor lists
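A hedged sketch (not taken from this diff) of how a few of the entries above surface in Ruby; iteration over an enumerable tensor is assumed to follow PyTorch's first-dimension semantics.

```ruby
require "torch"

# 0.2.4: to_i / to_f on scalar tensors
Torch.tensor(5).to_i    # => 5
Torch.tensor(2.5).to_f  # => 2.5

# 0.2.7: tensors are enumerable (assumed to iterate over the first dimension)
Torch.tensor([1, 2, 3]).map { |v| v.to_i }  # => [1, 2, 3]

# 0.2.4: shuffle option on the data loader
x = Torch.randn(10, 3)
y = Torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2, shuffle: true)
loader.each { |xb, yb| p xb.shape }
```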
data/README.md
CHANGED
@@ -2,6 +2,8 @@
 
 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
+For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+
 [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
 
 ## Installation
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
 
 ## Getting Started
 
+Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
+
+We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
+
+```text
+ankane/ml-stack:torch-gpu
+```
+
+And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
+
+## API
+
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
 
 - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
@@ -192,7 +206,7 @@ end
 Define a neural network
 
 ```ruby
-class
+class MyNet < Torch::NN::Module
   def initialize
     super
     @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +240,7 @@ end
 Create an instance of it
 
 ```ruby
-net =
+net = MyNet.new
 input = Torch.randn(1, 1, 32, 32)
 net.call(input)
 ```
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
 Load a model
 
 ```ruby
-net =
+net = MyNet.new
 net.load_state_dict(Torch.load("net.pth"))
 net.eval
 ```
@@ -395,7 +409,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.2.0 | 1.5.0
+0.2.0+ | 1.5.0+
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
 
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
 
 ### Linux
 
-Deep learning is significantly faster on
-
-Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
+Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
 
 Check if CUDA is available
 
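As a hedged illustration of the conventions described in the README hunks above (in-place methods end in `!`, plus the "Check if CUDA is available" step mentioned at the end of the section); method availability is assumed from the documented API.

```ruby
x = Torch.ones(3)
y = x.add(1)  # returns a new tensor
x.add!(1)     # modifies x in place (the Ruby spelling of PyTorch's x.add_(1))

# check if CUDA is available
puts Torch::CUDA.available?
```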
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
     Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };
 
-void handle_error(
+void handle_error(torch::Error const & ex)
 {
   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
@@ -32,16 +32,34 @@ extern "C"
 void Init_ext()
 {
   Module rb_mTorch = define_module("Torch");
+  rb_mTorch.add_handler<torch::Error>(handle_error);
   add_torch_functions(rb_mTorch);
 
   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
 
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
+  rb_mNN.add_handler<torch::Error>(handle_error);
   add_nn_functions(rb_mNN);
 
+  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_method(
+      "initial_seed",
+      *[]() {
+        return at::detail::getDefaultCPUGenerator()->current_seed();
+      })
+    .define_singleton_method(
+      "seed",
+      *[]() {
+        // TODO set for CUDA when available
+        return at::detail::getDefaultCPUGenerator()->seed();
+      });
+
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+    .add_handler<torch::Error>(handle_error)
     .define_constructor(Constructor<torch::IValue>())
     .define_method("bool?", &torch::IValue::isBool)
     .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -177,6 +195,17 @@ void Init_ext()
       *[](uint64_t seed) {
        return torch::manual_seed(seed);
       })
+    // config
+    .define_singleton_method(
+      "show_config",
+      *[] {
+        return torch::show_config();
+      })
+    .define_singleton_method(
+      "parallel_info",
+      *[] {
+        return torch::get_parallel_info();
+      })
     // begin tensor creation
     .define_singleton_method(
       "_arange",
@@ -293,7 +322,6 @@ void Init_ext()
      });
 
   rb_cTensor
-    .add_handler<c10::Error>(handle_error)
     .define_method("cuda?", &torch::Tensor::is_cuda)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -301,6 +329,11 @@ void Init_ext()
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "contiguous?",
+      *[](Tensor& self) {
+        return self.is_contiguous();
+      })
     .define_method(
       "addcmul!",
       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -350,6 +383,21 @@ void Init_ext()
        s << self.device();
        return s.str();
       })
+    .define_method(
+      "_data_str",
+      *[](Tensor& self) {
+        Tensor tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        auto data_ptr = (const char *) tensor.data_ptr();
+        return std::string(data_ptr, tensor.numel() * tensor.element_size());
+      })
+    // TODO figure out a better way to do this
     .define_method(
       "_flat_data",
       *[](Tensor& self) {
@@ -364,46 +412,40 @@ void Init_ext()
        Array a;
        auto dtype = tensor.dtype();
 
+        Tensor view = tensor.reshape({tensor.numel()});
+
        // TODO DRY if someone knows C++
        if (dtype == torch::kByte) {
-          uint8_t* data = tensor.data_ptr<uint8_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<uint8_t>());
          }
        } else if (dtype == torch::kChar) {
-          int8_t* data = tensor.data_ptr<int8_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(to_ruby<int>(
+            a.push(to_ruby<int>(view[i].item().to<int8_t>()));
          }
        } else if (dtype == torch::kShort) {
-          int16_t* data = tensor.data_ptr<int16_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int16_t>());
          }
        } else if (dtype == torch::kInt) {
-          int32_t* data = tensor.data_ptr<int32_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int32_t>());
          }
        } else if (dtype == torch::kLong) {
-          int64_t* data = tensor.data_ptr<int64_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int64_t>());
          }
        } else if (dtype == torch::kFloat) {
-          float* data = tensor.data_ptr<float>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<float>());
          }
        } else if (dtype == torch::kDouble) {
-          double* data = tensor.data_ptr<double>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<double>());
          }
        } else if (dtype == torch::kBool) {
-          bool* data = tensor.data_ptr<bool>();
          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<bool>() ? True : False);
          }
        } else {
          throw std::runtime_error("Unsupported type");
@@ -425,7 +467,7 @@ void Init_ext()
     });
 
   Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-    .add_handler<
+    .add_handler<torch::Error>(handle_error)
     .define_constructor(Constructor<torch::TensorOptions>())
     .define_method(
       "dtype",
@@ -531,6 +573,7 @@ void Init_ext()
     });
 
   Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+    .add_handler<torch::Error>(handle_error)
     .define_method(
       "grad",
       *[](Parameter& self) {
@@ -540,6 +583,7 @@ void Init_ext()
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
     .define_constructor(Constructor<torch::Device, std::string>())
+    .add_handler<torch::Error>(handle_error)
     .define_method("index", &torch::Device::index)
     .define_method("index?", &torch::Device::has_index)
     .define_method(
@@ -551,6 +595,7 @@ void Init_ext()
     });
 
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+    .add_handler<torch::Error>(handle_error)
     .define_singleton_method("available?", &torch::cuda::is_available)
     .define_singleton_method("device_count", &torch::cuda::device_count);
 }
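A hedged Ruby-side sketch of the bindings added above (Random seeding, config introspection, and `contiguous?`); return values are illustrative only.

```ruby
Torch::Random.initial_seed     # current seed of the default CPU generator
Torch::Random.seed             # reseeds and returns the new seed

puts Torch.show_config         # LibTorch build configuration
puts Torch.parallel_info       # threading / parallel backend details

Torch.randn(2, 2).contiguous?  # => true for a freshly allocated tensor
```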
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
 
 abort "Missing stdc++" unless have_library("stdc++")
 
-$CXXFLAGS
+$CXXFLAGS += " -std=c++14"
 
 # change to 0 for Linux pre-cxx11 ABI version
-$CXXFLAGS
+$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
 # TODO check compiler name
 clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
 
 # check omp first
 if have_library("omp") || have_library("gomp")
-  $CXXFLAGS
-  $CXXFLAGS
-  $CXXFLAGS
+  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
+  $CXXFLAGS += " -Xclang" if clang
+  $CXXFLAGS += " -fopenmp"
 end
 
 if clang
   # silence ruby/intern.h warning
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-deprecated-register"
 
   # silence torch warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
 else
   # silence rice warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-noexcept-type"
 
   # silence torch warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
 end
 
 inc, lib = dir_config("torch")
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
 cuda_inc ||= "/usr/local/cuda/include"
 cuda_lib ||= "/usr/local/cuda/lib64"
 
-$LDFLAGS
+$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
 abort "LibTorch not found" unless have_library("torch")
 
+have_library("mkldnn")
+have_library("nnpack")
+
 with_cuda = false
 if Dir["#{lib}/*torch_cuda*"].any?
-  $LDFLAGS
+  $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
   with_cuda = have_library("cuda") && have_library("cudnn")
 end
 
-$INCFLAGS
-$INCFLAGS
+$INCFLAGS += " -I#{inc}"
+$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
 
-$LDFLAGS
-$LDFLAGS
+$LDFLAGS += " -Wl,-rpath,#{lib}"
+$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
 
 # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
-$LDFLAGS
+$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
 if with_cuda
-  $LDFLAGS
+  $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
   # TODO figure out why this is needed
-  $LDFLAGS
+  $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
 end
 
 # generate C++ functions
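Because the build script calls `dir_config("torch")` and `dir_config("cuda")`, the LibTorch (and CUDA) locations can be supplied at install time via the standard mkmf flags; the paths below are placeholders, and the exact invocation is an assumption rather than something shown in this diff.

```text
gem install torch-rb -- --with-torch-dir=/path/to/libtorch
# or, with Bundler
bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
```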
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
 # ext
 require "torch/ext"
 
+# stdlib
+require "fileutils"
+require "net/http"
+require "tmpdir"
+
 # native functions
 require "torch/native/generator"
 require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
 
 # utils
 require "torch/utils/data/data_loader"
+require "torch/utils/data/dataset"
 require "torch/utils/data/tensor_dataset"
 
-# random
-require "torch/random"
-
 # hub
 require "torch/hub"
 
data/lib/torch/hub.rb
CHANGED
@@ -4,6 +4,58 @@ module Torch
       def list(github, force_reload: false)
         raise NotImplementedYet
       end
+
+      def download_url_to_file(url, dst)
+        uri = URI(url)
+        tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+        location = nil
+
+        Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+          request = Net::HTTP::Get.new(uri)
+
+          puts "Downloading #{url}..."
+          File.open(tmp, "wb") do |f|
+            http.request(request) do |response|
+              case response
+              when Net::HTTPRedirection
+                location = response["location"]
+              when Net::HTTPSuccess
+                response.read_body do |chunk|
+                  f.write(chunk)
+                end
+              else
+                raise Error, "Bad response"
+              end
+            end
+          end
+        end
+
+        if location
+          download_url_to_file(location, dst)
+        else
+          FileUtils.mv(tmp, dst)
+          nil
+        end
+      end
+
+      def load_state_dict_from_url(url, model_dir: nil)
+        unless model_dir
+          torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+          model_dir = File.join(torch_home, "checkpoints")
+        end
+
+        FileUtils.mkdir_p(model_dir)
+
+        parts = URI(url)
+        filename = File.basename(parts.path)
+        cached_file = File.join(model_dir, filename)
+        unless File.exist?(cached_file)
+          # TODO support hash_prefix
+          download_url_to_file(url, cached_file)
+        end
+
+        Torch.load(cached_file)
+      end
     end
   end
 end
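A hedged usage sketch for the new `Torch::Hub` helpers above; the URL and filenames are placeholders.

```ruby
# download straight to the given path
Torch::Hub.download_url_to_file("https://example.com/resnet.pth", "resnet.pth")

# cache under $TORCH_HOME (default ~/.cache/torch/checkpoints) and load with Torch.load
state_dict = Torch::Hub.load_state_dict_from_url("https://example.com/resnet.pth")
```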