torch-rb 0.3.1 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/torch/ext.cpp +12 -1
- data/lib/torch.rb +12 -0
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/nn/functional.rb +5 -1
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +22 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
|
4
|
+
data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
|
7
|
+
data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.3.2 (2020-08-24)
|
2
|
+
|
3
|
+
- Added `enable_grad` method
|
4
|
+
- Added `random_split` method
|
5
|
+
- Added `collate_fn` option to `DataLoader`
|
6
|
+
- Added `grad=` method to `Tensor`
|
7
|
+
- Fixed error with `grad` method when empty
|
8
|
+
- Fixed `EmbeddingBag`
|
9
|
+
|
1
10
|
## 0.3.1 (2020-08-17)
|
2
11
|
|
3
12
|
- Added `create_graph` and `retain_graph` options to `backward` method
|
data/ext/torch/ext.cpp
CHANGED
@@ -358,7 +358,13 @@ void Init_ext()
|
|
358
358
|
.define_method(
|
359
359
|
"grad",
|
360
360
|
*[](Tensor& self) {
|
361
|
-
|
361
|
+
auto grad = self.grad();
|
362
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
363
|
+
})
|
364
|
+
.define_method(
|
365
|
+
"grad=",
|
366
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
367
|
+
self.grad() = grad;
|
362
368
|
})
|
363
369
|
.define_method(
|
364
370
|
"_dtype",
|
@@ -580,6 +586,11 @@ void Init_ext()
|
|
580
586
|
*[](Parameter& self) {
|
581
587
|
auto grad = self.grad();
|
582
588
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
589
|
+
})
|
590
|
+
.define_method(
|
591
|
+
"grad=",
|
592
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
593
|
+
self.grad() = grad;
|
583
594
|
});
|
584
595
|
|
585
596
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
data/lib/torch.rb
CHANGED
@@ -179,8 +179,10 @@ require "torch/nn/functional"
|
|
179
179
|
require "torch/nn/init"
|
180
180
|
|
181
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
182
183
|
require "torch/utils/data/data_loader"
|
183
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
184
186
|
require "torch/utils/data/tensor_dataset"
|
185
187
|
|
186
188
|
# hub
|
@@ -316,6 +318,16 @@ module Torch
|
|
316
318
|
end
|
317
319
|
end
|
318
320
|
|
321
|
+
def enable_grad
|
322
|
+
previous_value = grad_enabled?
|
323
|
+
begin
|
324
|
+
_set_grad_enabled(true)
|
325
|
+
yield
|
326
|
+
ensure
|
327
|
+
_set_grad_enabled(previous_value)
|
328
|
+
end
|
329
|
+
end
|
330
|
+
|
319
331
|
def device(str)
|
320
332
|
Device.new(str)
|
321
333
|
end
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -373,7 +373,8 @@ module Torch
|
|
373
373
|
end
|
374
374
|
|
375
375
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
376
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
377
|
+
ret
|
377
378
|
end
|
378
379
|
|
379
380
|
# distance functions
|
@@ -426,6 +427,9 @@ module Torch
|
|
426
427
|
end
|
427
428
|
|
428
429
|
def mse_loss(input, target, reduction: "mean")
|
430
|
+
if target.size != input.size
|
431
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
|
+
end
|
429
433
|
NN.mse_loss(input, target, reduction)
|
430
434
|
end
|
431
435
|
|
data/lib/torch/utils/data.rb
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,8 +37,8 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
@@ -36,7 +48,7 @@ module Torch
|
|
36
48
|
|
37
49
|
private
|
38
50
|
|
39
|
-
def
|
51
|
+
def default_convert(batch)
|
40
52
|
elem = batch[0]
|
41
53
|
case elem
|
42
54
|
when Tensor
|
@@ -44,11 +56,15 @@ module Torch
|
|
44
56
|
when Integer
|
45
57
|
Torch.tensor(batch)
|
46
58
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
59
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
60
|
else
|
49
|
-
raise
|
61
|
+
raise NotImplementedYet
|
50
62
|
end
|
51
63
|
end
|
64
|
+
|
65
|
+
def auto_collation?
|
66
|
+
!@batch_sampler.nil?
|
67
|
+
end
|
52
68
|
end
|
53
69
|
end
|
54
70
|
end
|
data/lib/torch/utils/data/subset.rb
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.1
|
4
|
+
version: 0.3.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-17 00:00:00.000000000 Z
|
11
|
+
date: 2020-08-24 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|