torch-rb 0.2.3 → 0.2.4

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: cff805041122544d87342923649010da6981fbdb6d47c73da8cc623ba3856af5
- data.tar.gz: 257c16cbdbc915fe30e7cba5fc0d24ce0337cf88dd420dfaa3a13b8437e04164
+ metadata.gz: 67c5a0cf556399dc32d73e8793e3aa794c181150f0f42dfa810c4b98a5acf6f2
+ data.tar.gz: 0a23f6a42595fb9d599962e88438b964180583ead5b9cce934cc447951b4a389
  SHA512:
- metadata.gz: 36e2e671f3400fdaa513cfa2dd9d07b839b120cd848dc0e28bf8723570c554e6a96e1d4f29f33a1e995e6eb57e6042299321b8903f657006d2c04c10cecc59c2
- data.tar.gz: ed1b17bf30ba5b4342350cf41cc5aa19b64c83cef44fe4f0122874805eabe9e91f31e0d3bc046812fa7ef07031de81f442fe18c231be2324d4da6f677325a54a
+ metadata.gz: c0f8e9e3395d196d7ea6fa4b40d128284d768033e02f4ed7d2dc9adc985015fd0a80d601601dd97438b803b6a3bd7b81f5dbda353bb5dee4247503a24cd755d7
+ data.tar.gz: c32a22ebbe1b4dfd77324f62a72d6a128639aac0a99d4c5255b16c606e6f961ae2c8b0dbab5012a9b21faa7409511b79a50676bc8314f181c85f90433433fa8b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 0.2.4 (2020-04-29)
+
+ - Added `to_i` and `to_f` to tensors
+ - Added `shuffle` option to data loader
+ - Fixed `modules` and `named_modules` for nested modules
+
  ## 0.2.3 (2020-04-28)

  - Added `show_config` and `parallel_info` methods
data/lib/torch/hub.rb CHANGED
@@ -4,6 +4,14 @@ module Torch
        def list(github, force_reload: false)
          raise NotImplementedYet
        end
+
+       def download_url_to_file(url)
+         raise NotImplementedYet
+       end
+
+       def load_state_dict_from_url(url)
+         raise NotImplementedYet
+       end
      end
    end
  end
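
Both new `Torch::Hub` methods are stubs that raise, matching `list`. A minimal sketch of calling one (assuming, as elsewhere in the gem, that the error raised is `Torch::NotImplementedYet`):

```ruby
require "torch"

begin
  Torch::Hub.load_state_dict_from_url("https://example.com/model.pth")
rescue Torch::NotImplementedYet
  # Hub support is stubbed out for now
end
```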
data/lib/torch/inspector.rb CHANGED
@@ -25,7 +25,7 @@ module Torch
        end

        if floating_point?
-         sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
+         sci = max > 1e8 || max < 1e-4

          all_int = values.all? { |v| v.finite? && v == v.to_i }
          decimal = all_int ? 1 : 4
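
The old heuristic compared `max / min`, so a tensor holding ordinary values of mixed magnitude could flip the whole printout into scientific notation. A quick illustration of the two rules on hypothetical extreme values:

```ruby
# max and min stand for the largest and smallest absolute finite
# values in the tensor being printed (hypothetical numbers here)
max = 100.0
min = 0.001

max / min > 1000 || max > 1e8 || min < 1e-4 # => true  (old rule: scientific)
max > 1e8 || max < 1e-4                     # => false (new rule: fixed-point)
```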
data/lib/torch/nn/batch_norm.rb CHANGED
@@ -70,6 +70,11 @@ module Torch
            momentum: exponential_average_factor, eps: @eps
          )
        end
+
+       def extra_inspect
+         s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
+         format(s, **dict)
+       end
      end
    end
  end
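
With `extra_inspect` defined, batch norm layers now describe their configuration when printed. A usage sketch (the output line is indicative, not verbatim):

```ruby
bn = Torch::NN::BatchNorm2d.new(16)
puts bn.inspect
# BatchNorm2d(16, eps: 1.0e-05, momentum: 0.1, affine: true, track_running_stats: true)
```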
data/lib/torch/nn/conv2d.rb CHANGED
@@ -20,7 +20,14 @@ module Torch

        # TODO add more parameters
        def extra_inspect
-         format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
+         s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
+         s += ", padding: %{padding}" if @padding != [0] * @padding.size
+         s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
+         s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
+         s += ", groups: %{groups}" if @groups != 1
+         s += ", bias: false" unless @bias
+         s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
+         format(s, **dict)
        end
      end
    end
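
Optional attributes are appended only when they differ from their defaults, so the common case stays short. A sketch (output lines indicative):

```ruby
puts Torch::NN::Conv2d.new(1, 6, 3).inspect
# Conv2d(1, 6, kernel_size: [3, 3], stride: [1, 1])

puts Torch::NN::Conv2d.new(1, 6, 3, stride: 2, padding: 1, bias: false).inspect
# Conv2d(1, 6, kernel_size: [3, 3], stride: [2, 2], padding: [1, 1], bias: false)
```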
data/lib/torch/nn/convnd.rb CHANGED
@@ -23,7 +23,7 @@ module Torch
          if bias
            @bias = Parameter.new(Tensor.new(out_channels))
          else
-           raise NotImplementedError
+           register_parameter("bias", nil)
          end
          reset_parameters
        end
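
Registering a nil `bias` parameter mirrors PyTorch, so bias-free convolutions now construct instead of raising. A quick check:

```ruby
conv = Torch::NN::Conv2d.new(16, 32, 3, bias: false)
puts conv.inspect
# e.g. Conv2d(16, 32, kernel_size: [3, 3], stride: [1, 1], bias: false)
```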
data/lib/torch/nn/max_poolnd.rb CHANGED
@@ -12,7 +12,8 @@ module Torch
        end

        def extra_inspect
-         format("kernel_size: %s", @kernel_size)
+         s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
+         format(s, **dict)
        end
      end
    end
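
Pooling layers get the same fuller `extra_inspect`. A sketch (output indicative; `stride` defaults to `kernel_size`):

```ruby
pool = Torch::NN::MaxPool2d.new(2)
puts pool.inspect
# MaxPool2d(kernel_size: 2, stride: 2, padding: 0, dilation: 1, ceil_mode: false)
```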
data/lib/torch/nn/module.rb CHANGED
@@ -186,8 +186,22 @@ module Torch
          named_modules.values
        end

-       def named_modules
-         {"" => self}.merge(named_children)
+       # TODO return enumerator?
+       def named_modules(memo: nil, prefix: "")
+         ret = {}
+         memo ||= Set.new
+         unless memo.include?(self)
+           memo << self
+           ret[prefix] = self
+           named_children.each do |name, mod|
+             next unless mod.is_a?(Module)
+             submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+             mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+               ret[m[0]] = m[1]
+             end
+           end
+         end
+         ret
        end

        def train(mode = true)
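
The rewrite recurses through `named_children` with a dotted prefix and a `Set` guard against revisiting shared modules, following PyTorch's `named_modules`. A sketch of the fixed behavior on a nested container:

```ruby
net = Torch::NN::Sequential.new(
  Torch::NN::Sequential.new(
    Torch::NN::Linear.new(10, 5),
    Torch::NN::ReLU.new
  ),
  Torch::NN::Linear.new(5, 1)
)

p net.named_modules.keys
# => ["", "0", "0.0", "0.1", "1"]
# previously only "" and the direct children appeared
```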
@@ -230,7 +244,9 @@ module Torch
          str = String.new
          str << "#{name}(\n"
          named_children.each do |name, mod|
-           str << " (#{name}): #{mod.inspect}\n"
+           mod_str = mod.inspect
+           mod_str = mod_str.lines.join(" ")
+           str << " (#{name}): #{mod_str}\n"
          end
          str << ")"
        end
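
Re-joining the child's `inspect` lines keeps a multi-line child nested under its parent instead of breaking the layout. A sketch (indentation in the output is indicative):

```ruby
net = Torch::NN::Sequential.new(
  Torch::NN::Sequential.new(Torch::NN::Linear.new(4, 2))
)
puts net.inspect
# Sequential(
#   (0): Sequential(
#     (0): Linear(...)
#   )
# )
```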
data/lib/torch/tensor.rb CHANGED
@@ -67,6 +67,14 @@ module Torch
      _flat_data.first
    end

+   def to_i
+     item.to_i
+   end
+
+   def to_f
+     item.to_f
+   end
+
    # unsure if this is correct
    def new
      Torch.empty(0, dtype: dtype)
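
Both methods delegate to `item`, so they work on any single-element tensor:

```ruby
x = Torch.tensor([2.5])
x.item # => 2.5
x.to_i # => 2
x.to_f # => 2.5
```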
data/lib/torch/utils/data/data_loader.rb CHANGED
@@ -6,9 +6,10 @@ module Torch

          attr_reader :dataset

-         def initialize(dataset, batch_size: 1)
+         def initialize(dataset, batch_size: 1, shuffle: false)
            @dataset = dataset
            @batch_size = batch_size
+           @shuffle = shuffle
          end

          def each
@@ -16,11 +17,15 @@ module Torch
            # this makes it easy to compare results
            base_seed = Torch.empty([], dtype: :int64).random!.item

-           max_size = @dataset.size
-           size.times do |i|
-             start_index = i * @batch_size
-             end_index = [start_index + @batch_size, max_size].min
-             batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
+           indexes =
+             if @shuffle
+               Torch.randperm(@dataset.size).to_a
+             else
+               @dataset.size.times
+             end
+
+           indexes.each_slice(@batch_size) do |idx|
+             batch = idx.map { |i| @dataset[i] }
              yield collate(batch)
            end
          end
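
With `shuffle: true`, each pass over the loader draws a fresh `Torch.randperm` of the dataset indexes and slices it into batches; the last batch may be short. A usage sketch:

```ruby
x = Torch.randn(100, 4)
y = Torch.zeros(100)

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32, shuffle: true)

loader.each do |xb, yb|
  # sample order differs on each pass; batch sizes are 32, 32, 32, 4
end
```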
data/lib/torch/utils/data/dataset.rb ADDED
@@ -0,0 +1,8 @@
+ module Torch
+   module Utils
+     module Data
+       class Dataset
+       end
+     end
+   end
+ end
data/lib/torch/utils/data/tensor_dataset.rb CHANGED
@@ -1,7 +1,7 @@
  module Torch
    module Utils
      module Data
-       class TensorDataset
+       class TensorDataset < Dataset
          def initialize(*tensors)
            unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
              raise Error, "Tensors must all have same dim 0 size"
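
`Dataset` is empty for now; making `TensorDataset` inherit from it gives custom datasets a common base class. A quick check:

```ruby
ds = Torch::Utils::Data::TensorDataset.new(Torch.tensor([[1, 2], [3, 4]]))
ds.is_a?(Torch::Utils::Data::Dataset) # => true
```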
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.2.3"
+   VERSION = "0.2.4"
  end
data/lib/torch.rb CHANGED
@@ -174,6 +174,7 @@ require "torch/nn/init"

  # utils
  require "torch/utils/data/data_loader"
+ require "torch/utils/data/dataset"
  require "torch/utils/data/tensor_dataset"

  # hub
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.2.3
+ version: 0.2.4
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-04-28 00:00:00.000000000 Z
+ date: 2020-04-29 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -95,19 +95,19 @@ dependencies:
  - !ruby/object:Gem::Version
  version: '0'
  - !ruby/object:Gem::Dependency
- name: npy
+ name: torchvision
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0'
+ version: 0.1.1
  type: :development
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0'
+ version: 0.1.1
  description:
  email: andrew@chartkick.com
  executables: []
@@ -260,6 +260,7 @@ files:
  - lib/torch/optim/sgd.rb
  - lib/torch/tensor.rb
  - lib/torch/utils/data/data_loader.rb
+ - lib/torch/utils/data/dataset.rb
  - lib/torch/utils/data/tensor_dataset.rb
  - lib/torch/version.rb
  homepage: https://github.com/ankane/torch.rb