torch-rb 0.2.3 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -75,7 +75,7 @@ module Torch
75
75
  v.is_a?(Tensor)
76
76
  when "Tensor?"
77
77
  v.nil? || v.is_a?(Tensor)
78
- when "Tensor[]"
78
+ when "Tensor[]", "Tensor?[]"
79
79
  v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
80
80
  when "int"
81
81
  if k == "reduction"
@@ -70,6 +70,11 @@ module Torch
70
70
  momentum: exponential_average_factor, eps: @eps
71
71
  )
72
72
  end
73
+
74
+ def extra_inspect
75
+ s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
76
+ format(s, **dict)
77
+ end
73
78
  end
74
79
  end
75
80
  end
@@ -18,9 +18,15 @@ module Torch
18
18
  F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
19
19
  end
20
20
 
21
- # TODO add more parameters
22
21
  def extra_inspect
23
- format("%s, %s, kernel_size: %s, stride: %s", @in_channels, @out_channels, @kernel_size, @stride)
22
+ s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
23
+ s += ", padding: %{padding}" if @padding != [0] * @padding.size
24
+ s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
25
+ s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
26
+ s += ", groups: %{groups}" if @groups != 1
27
+ s += ", bias: false" unless @bias
28
+ s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
29
+ format(s, **dict)
24
30
  end
25
31
  end
26
32
  end
@@ -23,7 +23,7 @@ module Torch
23
23
  if bias
24
24
  @bias = Parameter.new(Tensor.new(out_channels))
25
25
  else
26
- raise NotImplementedError
26
+ register_parameter("bias", nil)
27
27
  end
28
28
  reset_parameters
29
29
  end
@@ -12,7 +12,8 @@ module Torch
12
12
  end
13
13
 
14
14
  def extra_inspect
15
- format("kernel_size: %s", @kernel_size)
15
+ s = "kernel_size: %{kernel_size}, stride: %{stride}, padding: %{padding}, dilation: %{dilation}, ceil_mode: %{ceil_mode}"
16
+ format(s, **dict)
16
17
  end
17
18
  end
18
19
  end
@@ -145,7 +145,7 @@ module Torch
145
145
  params = {}
146
146
  if recurse
147
147
  named_children.each do |name, mod|
148
- params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
148
+ params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
149
149
  end
150
150
  end
151
151
  instance_variables.each do |name|
@@ -186,8 +186,22 @@ module Torch
186
186
  named_modules.values
187
187
  end
188
188
 
189
- def named_modules
190
- {"" => self}.merge(named_children)
189
+ # TODO return enumerator?
190
+ def named_modules(memo: nil, prefix: "")
191
+ ret = {}
192
+ memo ||= Set.new
193
+ unless memo.include?(self)
194
+ memo << self
195
+ ret[prefix] = self
196
+ named_children.each do |name, mod|
197
+ next unless mod.is_a?(Module)
198
+ submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
199
+ mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
200
+ ret[m[0]] = m[1]
201
+ end
202
+ end
203
+ end
204
+ ret
191
205
  end
192
206
 
193
207
  def train(mode = true)
@@ -230,7 +244,9 @@ module Torch
230
244
  str = String.new
231
245
  str << "#{name}(\n"
232
246
  named_children.each do |name, mod|
233
- str << " (#{name}): #{mod.inspect}\n"
247
+ mod_str = mod.inspect
248
+ mod_str = mod_str.lines.join(" ")
249
+ str << " (#{name}): #{mod_str}\n"
234
250
  end
235
251
  str << ")"
236
252
  end
@@ -270,8 +286,11 @@ module Torch
270
286
  str % vars
271
287
  end
272
288
 
289
+ # used for format
290
+ # remove tensors for performance
291
+ # so we can skip call to inspect
273
292
  def dict
274
- instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
293
+ instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
275
294
  end
276
295
  end
277
296
  end
@@ -32,9 +32,11 @@ module Torch
32
32
  end
33
33
 
34
34
  def state_dict
35
+ raise NotImplementedYet
36
+
35
37
  pack_group = lambda do |group|
36
- packed = group.select { |k, _| k != :params }.to_h
37
- packed[:params] = group[:params].map { |p| p.object_id }
38
+ packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
39
+ packed["params"] = group[:params].map { |p| p.object_id }
38
40
  packed
39
41
  end
40
42
 
@@ -42,8 +44,8 @@ module Torch
42
44
  packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
43
45
 
44
46
  {
45
- state: packed_state,
46
- param_groups: param_groups
47
+ "state" => packed_state,
48
+ "param_groups" => param_groups
47
49
  }
48
50
  end
49
51
 
@@ -11,9 +11,6 @@ module Torch
11
11
  end
12
12
 
13
13
  def step(closure = nil)
14
- # TODO implement []=
15
- raise NotImplementedYet
16
-
17
14
  loss = nil
18
15
  if closure
19
16
  loss = closure.call
@@ -1,6 +1,7 @@
1
1
  module Torch
2
2
  class Tensor
3
3
  include Comparable
4
+ include Enumerable
4
5
  include Inspector
5
6
 
6
7
  alias_method :requires_grad?, :requires_grad
@@ -25,14 +26,36 @@ module Torch
25
26
  inspect
26
27
  end
27
28
 
29
+ def each
30
+ return enum_for(:each) unless block_given?
31
+
32
+ size(0).times do |i|
33
+ yield self[i]
34
+ end
35
+ end
36
+
37
+ # TODO make more performant
28
38
  def to_a
29
- reshape_arr(_flat_data, shape)
39
+ arr = _flat_data
40
+ if shape.empty?
41
+ arr
42
+ else
43
+ shape[1..-1].reverse.each do |dim|
44
+ arr = arr.each_slice(dim)
45
+ end
46
+ arr.to_a
47
+ end
30
48
  end
31
49
 
32
- # TODO support dtype
33
- def to(device, non_blocking: false, copy: false)
50
+ def to(device = nil, dtype: nil, non_blocking: false, copy: false)
51
+ device ||= self.device
34
52
  device = Device.new(device) if device.is_a?(String)
35
- _to(device, _dtype, non_blocking, copy)
53
+
54
+ dtype ||= self.dtype
55
+ enum = DTYPE_TO_ENUM[dtype]
56
+ raise Error, "Unknown type: #{dtype}" unless enum
57
+
58
+ _to(device, enum, non_blocking, copy)
36
59
  end
37
60
 
38
61
  def cpu
@@ -64,7 +87,15 @@ module Torch
64
87
  if numel != 1
65
88
  raise Error, "only one element tensors can be converted to Ruby scalars"
66
89
  end
67
- _flat_data.first
90
+ to_a.first
91
+ end
92
+
93
+ def to_i
94
+ item.to_i
95
+ end
96
+
97
+ def to_f
98
+ item.to_f
68
99
  end
69
100
 
70
101
  # unsure if this is correct
@@ -80,7 +111,7 @@ module Torch
80
111
  def numo
81
112
  cls = Torch._dtype_to_numo[dtype]
82
113
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
83
- cls.cast(_flat_data).reshape(*shape)
114
+ cls.from_string(_data_str).reshape(*shape)
84
115
  end
85
116
 
86
117
  def new_ones(*size, **options)
@@ -108,15 +139,6 @@ module Torch
108
139
  _view(size)
109
140
  end
110
141
 
111
- # value and other are swapped for some methods
112
- def add!(value = 1, other)
113
- if other.is_a?(Numeric)
114
- _add__scalar(other, value)
115
- else
116
- _add__tensor(other, value)
117
- end
118
- end
119
-
120
142
  def +(other)
121
143
  add(other)
122
144
  end
@@ -145,12 +167,25 @@ module Torch
145
167
  neg
146
168
  end
147
169
 
170
+ def &(other)
171
+ logical_and(other)
172
+ end
173
+
174
+ def |(other)
175
+ logical_or(other)
176
+ end
177
+
178
+ def ^(other)
179
+ logical_xor(other)
180
+ end
181
+
148
182
  # TODO better compare?
149
183
  def <=>(other)
150
184
  item <=> other
151
185
  end
152
186
 
153
- # based on python_variable_indexing.cpp
187
+ # based on python_variable_indexing.cpp and
188
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
154
189
  def [](*indexes)
155
190
  result = self
156
191
  dim = 0
@@ -162,6 +197,8 @@ module Torch
162
197
  finish += 1 unless index.exclude_end?
163
198
  result = result._slice_tensor(dim, index.begin, finish, 1)
164
199
  dim += 1
200
+ elsif index.is_a?(Tensor)
201
+ result = result.index([index])
165
202
  elsif index.nil?
166
203
  result = result.unsqueeze(dim)
167
204
  dim += 1
@@ -175,24 +212,37 @@ module Torch
175
212
  result
176
213
  end
177
214
 
178
- # TODO
179
- # based on python_variable_indexing.cpp
215
+ # based on python_variable_indexing.cpp and
216
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
180
217
  def []=(index, value)
181
218
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
182
219
 
183
- value = Torch.tensor(value) unless value.is_a?(Tensor)
220
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
184
221
 
185
222
  if index.is_a?(Numeric)
186
- copy_to(_select_int(0, index), value)
223
+ index_put!([Torch.tensor(index)], value)
187
224
  elsif index.is_a?(Range)
188
225
  finish = index.end
189
226
  finish += 1 unless index.exclude_end?
190
- copy_to(_slice_tensor(0, index.begin, finish, 1), value)
227
+ _slice_tensor(0, index.begin, finish, 1).copy!(value)
228
+ elsif index.is_a?(Tensor)
229
+ index_put!([index], value)
191
230
  else
192
231
  raise Error, "Unsupported index type: #{index.class.name}"
193
232
  end
194
233
  end
195
234
 
235
+ # native functions that need manually defined
236
+
237
+ # value and other are swapped for some methods
238
+ def add!(value = 1, other)
239
+ if other.is_a?(Numeric)
240
+ _add__scalar(other, value)
241
+ else
242
+ _add__tensor(other, value)
243
+ end
244
+ end
245
+
196
246
  # native functions overlap, so need to handle manually
197
247
  def random!(*args)
198
248
  case args.size
@@ -205,22 +255,9 @@ module Torch
205
255
  end
206
256
  end
207
257
 
208
- private
209
-
210
- def copy_to(dst, src)
211
- dst.copy!(src)
212
- end
213
-
214
- def reshape_arr(arr, dims)
215
- if dims.empty?
216
- arr
217
- else
218
- arr = arr.flatten
219
- dims[1..-1].reverse.each do |dim|
220
- arr = arr.each_slice(dim)
221
- end
222
- arr.to_a
223
- end
258
+ def clamp!(min, max)
259
+ _clamp_min_(min)
260
+ _clamp_max_(max)
224
261
  end
225
262
  end
226
263
  end
@@ -6,9 +6,10 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1)
9
+ def initialize(dataset, batch_size: 1, shuffle: false)
10
10
  @dataset = dataset
11
11
  @batch_size = batch_size
12
+ @shuffle = shuffle
12
13
  end
13
14
 
14
15
  def each
@@ -16,11 +17,15 @@ module Torch
16
17
  # this makes it easy to compare results
17
18
  base_seed = Torch.empty([], dtype: :int64).random!.item
18
19
 
19
- max_size = @dataset.size
20
- size.times do |i|
21
- start_index = i * @batch_size
22
- end_index = [start_index + @batch_size, max_size].min
23
- batch = (end_index - start_index).times.map { |j| @dataset[start_index + j] }
20
+ indexes =
21
+ if @shuffle
22
+ Torch.randperm(@dataset.size).to_a
23
+ else
24
+ @dataset.size.times
25
+ end
26
+
27
+ indexes.each_slice(@batch_size) do |idx|
28
+ batch = idx.map { |i| @dataset[i] }
24
29
  yield collate(batch)
25
30
  end
26
31
  end
@@ -0,0 +1,8 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class Dataset
5
+ end
6
+ end
7
+ end
8
+ end
@@ -1,7 +1,7 @@
1
1
  module Torch
2
2
  module Utils
3
3
  module Data
4
- class TensorDataset
4
+ class TensorDataset < Dataset
5
5
  def initialize(*tensors)
6
6
  unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
7
7
  raise Error, "Tensors must all have same dim 0 size"
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.3"
2
+ VERSION = "0.3.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-28 00:00:00.000000000 Z
11
+ date: 2020-07-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -95,19 +95,19 @@ dependencies:
95
95
  - !ruby/object:Gem::Version
96
96
  version: '0'
97
97
  - !ruby/object:Gem::Dependency
98
- name: npy
98
+ name: torchvision
99
99
  requirement: !ruby/object:Gem::Requirement
100
100
  requirements:
101
101
  - - ">="
102
102
  - !ruby/object:Gem::Version
103
- version: '0'
103
+ version: 0.1.1
104
104
  type: :development
105
105
  prerelease: false
106
106
  version_requirements: !ruby/object:Gem::Requirement
107
107
  requirements:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
- version: '0'
110
+ version: 0.1.1
111
111
  description:
112
112
  email: andrew@chartkick.com
113
113
  executables: []
@@ -260,6 +260,7 @@ files:
260
260
  - lib/torch/optim/sgd.rb
261
261
  - lib/torch/tensor.rb
262
262
  - lib/torch/utils/data/data_loader.rb
263
+ - lib/torch/utils/data/dataset.rb
263
264
  - lib/torch/utils/data/tensor_dataset.rb
264
265
  - lib/torch/version.rb
265
266
  homepage: https://github.com/ankane/torch.rb