torch-rb 0.2.3 → 0.2.4

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
4
- data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
3
+ metadata.gz: 67c5a0cf556399dc32d73e8793e3aa794c181150f0f42dfa810c4b98a5acf6f2
4
+ data.tar.gz: 0a23f6a42595fb9d599962e88438b964180583ead5b9cce934cc447951b4a389
5
5
  SHA512:
6
- metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
7
- data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
6
+ metadata.gz: c0f8e9e3395d196d7ea6fa4b40d128284d768033e02f4ed7d2dc9adc985015fd0a80d601601dd97438b803b6a3bd7b81f5dbda353bb5dee4247503a24cd755d7
7
+ data.tar.gz: c32a22ebbe1b4dfd77324f62a72d6a128639aac0a99d4c5255b16c606e6f961ae2c8b0dbab5012a9b21faa7409511b79a50676bc8314f181c85f90433433fa8b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.2.4 (2020-04-29)
2
+
3
+ - Added `to_i` and `to_f` to tensors
4
+ - Added `shuffle` option to data loader
5
+ - Fixed `modules` and `named_modules` for nested modules
6
+
1
7
  ## 0.2.3 (2020-04-28)
2
8
 
3
9
  - Added `show_config` and `parallel_info` methods
data/lib/torch/hub.rb CHANGED
@@ -4,6 +4,14 @@ module Torch
4
4
  def list(github, force_reload: false)
5
5
  raise NotImplementedYet
6
6
  end
7
+
8
+ def download_url_to_file(url)
9
+ raise NotImplementedYet
10
+ end
11
+
12
+ def load_state_dict_from_url(url)
13
+ raise NotImplementedYet
14
+ end
7
15
  end
8
16
  end
9
17
  end
@@ -25,7 +25,7 @@ module Torch
25
25
  end
26
26
 
27
27
  if floating_point?
28
- sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
28
+ sci = max > 1e8 || max < 1e-4
29
29
 
30
30
  all_int = values.all? { |v| v.finite? && v == v.to_i }
31
31
  decimal = all_int ? 1 : 4
@@ -70,6 +70,11 @@ module Torch
70
70
  momentum: exponential_average_factor, eps: @eps
71
71
  )
72
72
  end
73
+
74
+ def extra_inspect
75
+ s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
76
+ format(s, **dict)
77
+ end
73
78
  end
74
79
  end
75
80
  end
@@ -20,7 +20,14 @@ module Torch
20
20
 
21
21
  # TODO add more parameters
22
22
  def extra_inspect
23
- format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
23
+ s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
24
+ s += ", padding: %{padding}" if @padding != [0] * @padding.size
25
+ s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
26
+ s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
27
+ s += ", groups: %{groups}" if @groups != 1
28
+ s += ", bias: false" unless @bias
29
+ s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
30
+ format(s, **dict)
24
31
  end
25
32
  end
26
33
  end
@@ -23,7 +23,7 @@ module Torch
23
23
  if bias
24
24
  @bias = Parameter.new(Tensor.new(out_channels))
25
25
  else
26
- raise NotImplementedError
26
+ register_parameter("bias", nil)
27
27
  end
28
28
  reset_parameters
29
29
  end
@@ -12,7 +12,8 @@ module Torch
12
12
  end
13
13
 
14
14
  def extra_inspect
15
- format("kernel_size: %s", @kernel_size)
15
+ s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
16
+ format(s, **dict)
16
17
  end
17
18
  end
18
19
  end
@@ -186,8 +186,22 @@ module Torch
186
186
  named_modules.values
187
187
  end
188
188
 
189
- def named_modules
190
- {"" => self}.merge(named_children)
189
+ # TODO return enumerator?
190
+ def named_modules(memo: nil, prefix: "")
191
+ ret = {}
192
+ memo ||= Set.new
193
+ unless memo.include?(self)
194
+ memo << self
195
+ ret[prefix] = self
196
+ named_children.each do |name, mod|
197
+ next unless mod.is_a?(Module)
198
+ submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
199
+ mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
200
+ ret[m[0]] = m[1]
201
+ end
202
+ end
203
+ end
204
+ ret
191
205
  end
192
206
 
193
207
  def train(mode = true)
@@ -230,7 +244,9 @@ module Torch
230
244
  str = String.new
231
245
  str << "#{name}(\n"
232
246
  named_children.each do |name, mod|
233
- str << " (#{name}): #{mod.inspect}\n"
247
+ mod_str = mod.inspect
248
+ mod_str = mod_str.lines.join(" ")
249
+ str << " (#{name}): #{mod_str}\n"
234
250
  end
235
251
  str << ")"
236
252
  end
data/lib/torch/tensor.rb CHANGED
@@ -67,6 +67,14 @@ module Torch
67
67
  _flat_data.first
68
68
  end
69
69
 
70
+ def to_i
71
+ item.to_i
72
+ end
73
+
74
+ def to_f
75
+ item.to_f
76
+ end
77
+
70
78
  # unsure if this is correct
71
79
  def new
72
80
  Torch.empty(0, dtype: dtype)
@@ -6,9 +6,10 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1)
9
+ def initialize(dataset, batch_size: 1, shuffle: false)
10
10
  @dataset = dataset
11
11
  @batch_size = batch_size
12
+ @shuffle = shuffle
12
13
  end
13
14
 
14
15
  def each
@@ -16,11 +17,15 @@ module Torch
16
17
  # this makes it easy to compare results
17
18
  base_seed = Torch.empty([], dtype: :int64).random!.item
18
19
 
19
- max_size = @dataset.size
20
- size.times do |i|
21
- start_index = i * @batch_size
22
- end_index = [start_index + @batch_size, max_size].min
23
- batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
20
+ indexes =
21
+ if @shuffle
22
+ Torch.randperm(@dataset.size).to_a
23
+ else
24
+ @dataset.size.times
25
+ end
26
+
27
+ indexes.each_slice(@batch_size) do |idx|
28
+ batch = idx.map { |i| @dataset[i] }
24
29
  yield collate(batch)
25
30
  end
26
31
  end
@@ -0,0 +1,8 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class Dataset
5
+ end
6
+ end
7
+ end
8
+ end
@@ -1,7 +1,7 @@
1
1
  module Torch
2
2
  module Utils
3
3
  module Data
4
- class TensorDataset
4
+ class TensorDataset < Dataset
5
5
  def initialize(*tensors)
6
6
  unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
7
7
  raise Error, "Tensors must all have same dim 0 size"
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.3"
2
+ VERSION = "0.2.4"
3
3
  end
data/lib/torch.rb CHANGED
@@ -174,6 +174,7 @@ require "torch/nn/init"
174
174
 
175
175
  # utils
176
176
  require "torch/utils/data/data_loader"
177
+ require "torch/utils/data/dataset"
177
178
  require "torch/utils/data/tensor_dataset"
178
179
 
179
180
  # hub
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.2.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-28 00:00:00.000000000 Z
11
+ date: 2020-04-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -95,19 +95,19 @@ dependencies:
95
95
  - !ruby/object:Gem::Version
96
96
  version: '0'
97
97
  - !ruby/object:Gem::Dependency
98
- name: npy
98
+ name: torchvision
99
99
  requirement: !ruby/object:Gem::Requirement
100
100
  requirements:
101
101
  - - ">="
102
102
  - !ruby/object:Gem::Version
103
- version: '0'
103
+ version: 0.1.1
104
104
  type: :development
105
105
  prerelease: false
106
106
  version_requirements: !ruby/object:Gem::Requirement
107
107
  requirements:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
- version: '0'
110
+ version: 0.1.1
111
111
  description:
112
112
  email: andrew@chartkick.com
113
113
  executables: []
@@ -260,6 +260,7 @@ files:
260
260
  - lib/torch/optim/sgd.rb
261
261
  - lib/torch/tensor.rb
262
262
  - lib/torch/utils/data/data_loader.rb
263
+ - lib/torch/utils/data/dataset.rb
263
264
  - lib/torch/utils/data/tensor_dataset.rb
264
265
  - lib/torch/version.rb
265
266
  homepage: https://github.com/ankane/torch.rb