torch-rb 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
4
- data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
3
+ metadata.gz: 97908e85a67729120f763bb4140323505b77a831d5648e9d2d0961259e3d300c
4
+ data.tar.gz: f366548f9880dac7dffce6305e192f75a7467526ae55ae13af05d355918375ba
5
5
  SHA512:
6
- metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
7
- data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
6
+ metadata.gz: cc32ddbc43131175452a8b62df5d1eac6bc8450eea174018affd5cd073f81e2c9825d613014c62c5f8137cf5dddd1ab6ab6de60a4b3a67a757387446dbc1efad
7
+ data.tar.gz: c322e0b7ec7f03f12311d737034dad45037d2ad7710974e24250e11f4a0db14e221e870ddc52c0f2b723476be6f41fca8a8719068b0d0b7d8974d2080e9be6dc
@@ -1,3 +1,12 @@
1
+ ## 0.3.2 (2020-08-24)
2
+
3
+ - Added `enable_grad` method
4
+ - Added `random_split` method
5
+ - Added `collate_fn` option to `DataLoader`
6
+ - Added `grad=` method to `Tensor`
7
+ - Fixed error with `grad` method when empty
8
+ - Fixed `EmbeddingBag`
9
+
1
10
  ## 0.3.1 (2020-08-17)
2
11
 
3
12
  - Added `create_graph` and `retain_graph` options to `backward` method
@@ -358,7 +358,13 @@ void Init_ext()
358
358
  .define_method(
359
359
  "grad",
360
360
  *[](Tensor& self) {
361
- return self.grad();
361
+ auto grad = self.grad();
362
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
363
+ })
364
+ .define_method(
365
+ "grad=",
366
+ *[](Tensor& self, torch::Tensor& grad) {
367
+ self.grad() = grad;
362
368
  })
363
369
  .define_method(
364
370
  "_dtype",
@@ -580,6 +586,11 @@ void Init_ext()
580
586
  *[](Parameter& self) {
581
587
  auto grad = self.grad();
582
588
  return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
589
+ })
590
+ .define_method(
591
+ "grad=",
592
+ *[](Parameter& self, torch::Tensor& grad) {
593
+ self.grad() = grad;
583
594
  });
584
595
 
585
596
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -179,8 +179,10 @@ require "torch/nn/functional"
179
179
  require "torch/nn/init"
180
180
 
181
181
  # utils
182
+ require "torch/utils/data"
182
183
  require "torch/utils/data/data_loader"
183
184
  require "torch/utils/data/dataset"
185
+ require "torch/utils/data/subset"
184
186
  require "torch/utils/data/tensor_dataset"
185
187
 
186
188
  # hub
@@ -316,6 +318,16 @@ module Torch
316
318
  end
317
319
  end
318
320
 
321
+ def enable_grad
322
+ previous_value = grad_enabled?
323
+ begin
324
+ _set_grad_enabled(true)
325
+ yield
326
+ ensure
327
+ _set_grad_enabled(previous_value)
328
+ end
329
+ end
330
+
319
331
  def device(str)
320
332
  Device.new(str)
321
333
  end
@@ -7,25 +7,26 @@ module Torch
7
7
 
8
8
  def download_url_to_file(url, dst)
9
9
  uri = URI(url)
10
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
10
+ tmp = nil
11
11
  location = nil
12
12
 
13
+ puts "Downloading #{url}..."
13
14
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
14
15
  request = Net::HTTP::Get.new(uri)
15
16
 
16
- puts "Downloading #{url}..."
17
- File.open(tmp, "wb") do |f|
18
- http.request(request) do |response|
19
- case response
20
- when Net::HTTPRedirection
21
- location = response["location"]
22
- when Net::HTTPSuccess
17
+ http.request(request) do |response|
18
+ case response
19
+ when Net::HTTPRedirection
20
+ location = response["location"]
21
+ when Net::HTTPSuccess
22
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
+ File.open(tmp, "wb") do |f|
23
24
  response.read_body do |chunk|
24
25
  f.write(chunk)
25
26
  end
26
- else
27
- raise Error, "Bad response"
28
27
  end
28
+ else
29
+ raise Error, "Bad response"
29
30
  end
30
31
  end
31
32
  end
@@ -373,7 +373,8 @@ module Torch
373
373
  end
374
374
 
375
375
  # weight and input swapped
376
- Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
376
+ ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
377
+ ret
377
378
  end
378
379
 
379
380
  # distance functions
@@ -426,6 +427,9 @@ module Torch
426
427
  end
427
428
 
428
429
  def mse_loss(input, target, reduction: "mean")
430
+ if target.size != input.size
431
+ warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
432
+ end
429
433
  NN.mse_loss(input, target, reduction)
430
434
  end
431
435
 
@@ -0,0 +1,23 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class << self
5
+ def random_split(dataset, lengths)
6
+ if lengths.sum != dataset.length
7
+ raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
8
+ end
9
+
10
+ indices = Torch.randperm(lengths.sum).to_a
11
+ _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
12
+ end
13
+
14
+ private
15
+
16
+ def _accumulate(iterable)
17
+ sum = 0
18
+ iterable.map { |x| sum += x }
19
+ end
20
+ end
21
+ end
22
+ end
23
+ end
@@ -6,10 +6,22 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1, shuffle: false)
9
+ def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
10
10
  @dataset = dataset
11
11
  @batch_size = batch_size
12
12
  @shuffle = shuffle
13
+
14
+ @batch_sampler = nil
15
+
16
+ if collate_fn.nil?
17
+ if auto_collation?
18
+ collate_fn = method(:default_collate)
19
+ else
20
+ collate_fn = method(:default_convert)
21
+ end
22
+ end
23
+
24
+ @collate_fn = collate_fn
13
25
  end
14
26
 
15
27
  def each
@@ -25,8 +37,8 @@ module Torch
25
37
  end
26
38
 
27
39
  indexes.each_slice(@batch_size) do |idx|
28
- batch = idx.map { |i| @dataset[i] }
29
- yield collate(batch)
40
+ # TODO improve performance
41
+ yield @collate_fn.call(idx.map { |i| @dataset[i] })
30
42
  end
31
43
  end
32
44
 
@@ -36,7 +48,7 @@ module Torch
36
48
 
37
49
  private
38
50
 
39
- def collate(batch)
51
+ def default_convert(batch)
40
52
  elem = batch[0]
41
53
  case elem
42
54
  when Tensor
@@ -44,11 +56,15 @@ module Torch
44
56
  when Integer
45
57
  Torch.tensor(batch)
46
58
  when Array
47
- batch.transpose.map { |v| collate(v) }
59
+ batch.transpose.map { |v| default_convert(v) }
48
60
  else
49
- raise NotImpelmentYet
61
+ raise NotImplementedYet
50
62
  end
51
63
  end
64
+
65
+ def auto_collation?
66
+ !@batch_sampler.nil?
67
+ end
52
68
  end
53
69
  end
54
70
  end
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class Subset < Dataset
5
+ def initialize(dataset, indices)
6
+ @dataset = dataset
7
+ @indices = indices
8
+ end
9
+
10
+ def [](idx)
11
+ @dataset[@indices[idx]]
12
+ end
13
+
14
+ def length
15
+ @indices.length
16
+ end
17
+ alias_method :size, :length
18
+
19
+ def to_a
20
+ @indices.map { |i| @dataset[i] }
21
+ end
22
+ end
23
+ end
24
+ end
25
+ end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.1"
2
+ VERSION = "0.3.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.1
4
+ version: 0.3.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-17 00:00:00.000000000 Z
11
+ date: 2020-08-24 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -259,8 +259,10 @@ files:
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
261
  - lib/torch/tensor.rb
262
+ - lib/torch/utils/data.rb
262
263
  - lib/torch/utils/data/data_loader.rb
263
264
  - lib/torch/utils/data/dataset.rb
265
+ - lib/torch/utils/data/subset.rb
264
266
  - lib/torch/utils/data/tensor_dataset.rb
265
267
  - lib/torch/version.rb
266
268
  homepage: https://github.com/ankane/torch.rb