torch-rb 0.3.1 → 0.3.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/torch/ext.cpp +12 -1
- data/lib/torch.rb +12 -0
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
|
4
|
+
data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
|
7
|
+
data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.3.2 (2020-08-24)
|
2
|
+
|
3
|
+
- Added `enable_grad` method
|
4
|
+
- Added `random_split` method
|
5
|
+
- Added `collate_fn` option to `DataLoader`
|
6
|
+
- Added `grad=` method to `Tensor`
|
7
|
+
- Fixed error with `grad` method when empty
|
8
|
+
- Fixed `EmbeddingBag`
|
9
|
+
|
1
10
|
## 0.3.1 (2020-08-17)
|
2
11
|
|
3
12
|
- Added `create_graph` and `retain_graph` options to `backward` method
|
data/ext/torch/ext.cpp
CHANGED
@@ -358,7 +358,13 @@ void Init_ext()
|
|
358
358
|
.define_method(
|
359
359
|
"grad",
|
360
360
|
*[](Tensor& self) {
|
361
|
-
|
361
|
+
auto grad = self.grad();
|
362
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
363
|
+
})
|
364
|
+
.define_method(
|
365
|
+
"grad=",
|
366
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
367
|
+
self.grad() = grad;
|
362
368
|
})
|
363
369
|
.define_method(
|
364
370
|
"_dtype",
|
@@ -580,6 +586,11 @@ void Init_ext()
|
|
580
586
|
*[](Parameter& self) {
|
581
587
|
auto grad = self.grad();
|
582
588
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
589
|
+
})
|
590
|
+
.define_method(
|
591
|
+
"grad=",
|
592
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
593
|
+
self.grad() = grad;
|
583
594
|
});
|
584
595
|
|
585
596
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
data/lib/torch.rb
CHANGED
@@ -179,8 +179,10 @@ require "torch/nn/functional"
|
|
179
179
|
require "torch/nn/init"
|
180
180
|
|
181
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
182
183
|
require "torch/utils/data/data_loader"
|
183
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
184
186
|
require "torch/utils/data/tensor_dataset"
|
185
187
|
|
186
188
|
# hub
|
@@ -316,6 +318,16 @@ module Torch
|
|
316
318
|
end
|
317
319
|
end
|
318
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
319
331
|
def device(str)
|
320
332
|
Device.new(str)
|
321
333
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
data/lib/torch/utils/data.rb
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
data/lib/torch/utils/data/subset.rb
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.1
|
4
|
+
version: 0.3.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-17 00:00:00.000000000 Z
|
11
|
+
date: 2020-08-24 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|