torch-rb 0.3.1 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
4
- data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
3
+ metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
4
+ data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
5
5
  SHA512:
6
- metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
7
- data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
6
+ metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
7
+ data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
@@ -1,3 +1,12 @@
1
+ ## 0.3.2 (2020-08-24)
2
+
3
+ - Added `enable_grad` method
4
+ - Added `random_split` method
5
+ - Added `collate_fn` option to `DataLoader`
6
+ - Added `grad=` method to `Tensor`
7
+ - Fixed error with `grad` method when empty
8
+ - Fixed `EmbeddingBag`
9
+
1
10
  ## 0.3.1 (2020-08-17)
2
11
 
3
12
  - Added `create_graph` and `retain_graph` options to `backward` method
@@ -358,7 +358,13 @@ void Init_ext()
358
358
  .define_method(
359
359
  "grad",
360
360
  *[](Tensor& self) {
361
- return self.grad();
361
+ auto grad = self.grad();
362
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
363
+ })
364
+ .define_method(
365
+ "grad=",
366
+ *[](Tensor& self, torch::Tensor& grad) {
367
+ self.grad() = grad;
362
368
  })
363
369
  .define_method(
364
370
  "_dtype",
@@ -580,6 +586,11 @@ void Init_ext()
580
586
  *[](Parameter& self) {
581
587
  auto grad = self.grad();
582
588
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
589
+ })
590
+ .define_method(
591
+ "grad=",
592
+ *[](Parameter& self, torch::Tensor& grad) {
593
+ self.grad() = grad;
583
594
  });
584
595
 
585
596
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -179,8 +179,10 @@ require "torch/nn/functional"
179
179
  require "torch/nn/init"
180
180
 
181
181
  # utils
182
+ require "torch/utils/data"
182
183
  require "torch/utils/data/data_loader"
183
184
  require "torch/utils/data/dataset"
185
+ require "torch/utils/data/subset"
184
186
  require "torch/utils/data/tensor_dataset"
185
187
 
186
188
  # hub
@@ -316,6 +318,16 @@ module Torch
316
318
  end
317
319
  end
318
320
 
321
# Runs the given block after calling _set_grad_enabled(true), then restores
# the value grad_enabled? reported beforehand. The restore also happens when
# the block raises. Returns the block's result.
def enable_grad
  was_enabled = grad_enabled?
  begin
    _set_grad_enabled(true)
    yield
  ensure
    # Restore the caller's setting rather than forcing a fixed value.
    _set_grad_enabled(was_enabled)
  end
end
330
+
319
331
  def device(str)
320
332
  Device.new(str)
321
333
  end
@@ -7,25 +7,26 @@ module Torch
7
7
 
8
8
  def download_url_to_file(url, dst)
9
9
  uri = URI(url)
10
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
10
+ tmp = nil
11
11
  location = nil
12
12
 
13
+ puts "Downloading #{url}..."
13
14
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
15
  request = Net::HTTP::Get.new(uri)
15
16
 
16
- puts "Downloading #{url}..."
17
- File.open(tmp, "wb") do |f|
18
- http.request(request) do |response|
19
- case response
20
- when Net::HTTPRedirection
21
- location = response["location"]
22
- when Net::HTTPSuccess
17
+ http.request(request) do |response|
18
+ case response
19
+ when Net::HTTPRedirection
20
+ location = response["location"]
21
+ when Net::HTTPSuccess
22
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
+ File.open(tmp, "wb") do |f|
23
24
  response.read_body do |chunk|
24
25
  f.write(chunk)
25
26
  end
26
- else
27
- raise Error, "Bad response"
28
27
  end
28
+ else
29
+ raise Error, "Bad response"
29
30
  end
30
31
  end
31
32
  end
@@ -373,7 +373,8 @@ module Torch
373
373
  end
374
374
 
375
375
  # weight and input swapped
376
- Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
376
+ ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
377
+ ret
377
378
  end
378
379
 
379
380
  # distance functions
@@ -426,6 +427,9 @@ module Torch
426
427
  end
427
428
 
428
429
  def mse_loss(input, target, reduction: "mean")
430
+ if target.size != input.size
431
+ warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
432
+ end
429
433
  NN.mse_loss(input, target, reduction)
430
434
  end
431
435
 
@@ -0,0 +1,23 @@
1
module Torch
  module Utils
    module Data
      class << self
        # Splits +dataset+ into non-overlapping Subset objects whose sizes are
        # given by +lengths+. Indices are drawn from Torch.randperm over the
        # total length.
        #
        # @raise [ArgumentError] if lengths do not add up to dataset.length
        def random_split(dataset, lengths)
          total = lengths.sum
          unless total == dataset.length
            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
          end

          shuffled = Torch.randperm(total).to_a
          # Each subset takes the slice of shuffled indices that ends at its
          # running total (so consecutive subsets are adjacent slices).
          lengths.zip(_accumulate(lengths)).map do |length, stop|
            Subset.new(dataset, shuffled[(stop - length)...stop])
          end
        end

        private

        # Running totals of +iterable+, e.g. [1, 2, 3] -> [1, 3, 6].
        def _accumulate(iterable)
          running = 0
          iterable.each_with_object([]) { |x, totals| totals << (running += x) }
        end
      end
    end
  end
end
@@ -6,10 +6,22 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1, shuffle: false)
9
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle

  # No batch sampler is configurable yet; while this stays nil,
  # auto_collation? reports false.
  @batch_sampler = nil

  # An explicitly passed collate_fn wins; otherwise pick a built-in method.
  # NOTE(review): :default_collate is not defined in the visible part of this
  # class; that branch is unreachable while @batch_sampler is nil -- confirm
  # it exists before batch samplers are enabled.
  @collate_fn =
    if collate_fn.nil?
      method(auto_collation? ? :default_collate : :default_convert)
    else
      collate_fn
    end
end
14
26
 
15
27
  def each
@@ -25,8 +37,8 @@ module Torch
25
37
  end
26
38
 
27
39
  indexes.each_slice(@batch_size) do |idx|
28
- batch = idx.map { |i| @dataset[i] }
29
- yield collate(batch)
40
+ # TODO improve performance
41
+ yield @collate_fn.call(idx.map { |i| @dataset[i] })
30
42
  end
31
43
  end
32
44
 
@@ -36,7 +48,7 @@ module Torch
36
48
 
37
49
  private
38
50
 
39
- def collate(batch)
51
+ def default_convert(batch)
40
52
  elem = batch[0]
41
53
  case elem
42
54
  when Tensor
@@ -44,11 +56,15 @@ module Torch
44
56
  when Integer
45
57
  Torch.tensor(batch)
46
58
  when Array
47
- batch.transpose.map { |v| collate(v) }
59
+ batch.transpose.map { |v| default_convert(v) }
48
60
  else
49
- raise NotImpelmentYet
61
+ raise NotImplementedYet
50
62
  end
51
63
  end
64
+
65
# True when a batch sampler has been assigned (@batch_sampler is non-nil).
def auto_collation?
  return false if @batch_sampler.nil?

  true
end
52
68
  end
53
69
  end
54
70
  end
@@ -0,0 +1,25 @@
1
module Torch
  module Utils
    module Data
      # A view onto +dataset+ restricted to the positions listed in +indices+.
      class Subset < Dataset
        def initialize(dataset, indices)
          @dataset = dataset
          @indices = indices
        end

        # Element of the underlying dataset at position indices[idx].
        def [](idx)
          position = @indices[idx]
          @dataset[position]
        end

        # Number of elements in the subset (one per stored index).
        def length
          @indices.length
        end
        alias_method :size, :length

        # Materializes the subset as an Array, in index order.
        def to_a
          @indices.map { |position| @dataset[position] }
        end
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.1"
2
+ VERSION = "0.3.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.1
4
+ version: 0.3.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-17 00:00:00.000000000 Z
11
+ date: 2020-08-24 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -259,8 +259,10 @@ files:
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
261
  - lib/torch/tensor.rb
262
+ - lib/torch/utils/data.rb
262
263
  - lib/torch/utils/data/data_loader.rb
263
264
  - lib/torch/utils/data/dataset.rb
265
+ - lib/torch/utils/data/subset.rb
264
266
  - lib/torch/utils/data/tensor_dataset.rb
265
267
  - lib/torch/version.rb
266
268
  homepage: https://github.com/ankane/torch.rb