torch-rb 0.2.6 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -83,6 +83,8 @@ module Torch
       else
         v.is_a?(Integer)
       end
+    when "int?"
+      v.is_a?(Integer) || v.nil?
     when "float"
       v.is_a?(Numeric)
     when /int\[.*\]/
@@ -126,9 +128,11 @@ module Torch
     end

     func = candidates.first
+    args = func.args.map { |a| final_values[a[:name]] }
+    args << TensorOptions.new.dtype(6) if func.tensor_options
     {
       name: func.cpp_name,
-      args: func.args.map { |a| final_values[a[:name]] }
+      args: args
     }
   end
 end
@@ -18,7 +18,6 @@ module Torch
     F.conv2d(input, @weight, @bias, @stride, @padding, @dilation, @groups)
   end

-  # TODO add more parameters
   def extra_inspect
     s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
     s += ", padding: %{padding}" if @padding != [0] * @padding.size
@@ -373,7 +373,8 @@ module Torch
   end

   # weight and input swapped
-  Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+  ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+  ret
 end

 # distance functions
@@ -426,6 +427,9 @@ module Torch
   end

   def mse_loss(input, target, reduction: "mean")
+    if target.size != input.size
+      warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+    end
     NN.mse_loss(input, target, reduction)
   end

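
As a quick illustration of the new check (a minimal sketch; Torch::NN::Functional is assumed to be the module this hunk patches, and Torch.ones/Torch.zeros are only used to build example tensors):

    input = Torch.ones(3, 1)
    target = Torch.zeros(3)

    # 0.3.3 warns that the (3) target will broadcast against the (3, 1) input
    Torch::NN::Functional.mse_loss(input, target)
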
@@ -286,8 +286,11 @@ module Torch
     str % vars
   end

+  # used for format
+  # remove tensors for performance
+  # so we can skip call to inspect
   def dict
-    instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+    instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
   end
  end
 end
@@ -32,9 +32,11 @@ module Torch
   end

   def state_dict
+    raise NotImplementedYet
+
     pack_group = lambda do |group|
-      packed = group.select { |k, _| k != :params }.to_h
-      packed[:params] = group[:params].map { |p| p.object_id }
+      packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
+      packed["params"] = group[:params].map { |p| p.object_id }
       packed
     end

@@ -42,8 +44,8 @@ module Torch
     packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h

     {
-      state: packed_state,
-      param_groups: param_groups
+      "state" => packed_state,
+      "param_groups" => param_groups
     }
   end

@@ -1,6 +1,7 @@
 module Torch
   class Tensor
     include Comparable
+    include Enumerable
     include Inspector

     alias_method :requires_grad?, :requires_grad
@@ -25,6 +26,14 @@ module Torch
     inspect
   end

+  def each
+    return enum_for(:each) unless block_given?
+
+    size(0).times do |i|
+      yield self[i]
+    end
+  end
+
   # TODO make more performant
   def to_a
     arr = _flat_data
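
With Enumerable included and each iterating over the first dimension, tensors can now be treated as Ruby collections. A small sketch with illustrative values:

    x = Torch.tensor([[1, 2], [3, 4], [5, 6]])

    # each yields one sub-tensor per index along dim 0
    x.each { |row| p row.to_a }

    # Enumerable methods come for free once each is defined
    x.map { |row| row.sum }   # => array of scalar tensors
    x.first(2)                # => first two rows
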
@@ -38,10 +47,15 @@ module Torch
     end
   end

-  # TODO support dtype
-  def to(device, non_blocking: false, copy: false)
+  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+    device ||= self.device
     device = Device.new(device) if device.is_a?(String)
-    _to(device, _dtype, non_blocking, copy)
+
+    dtype ||= self.dtype
+    enum = DTYPE_TO_ENUM[dtype]
+    raise Error, "Unknown type: #{dtype}" unless enum
+
+    _to(device, enum, non_blocking, copy)
   end

   def cpu
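
The reworked to accepts an optional device plus a dtype keyword, so a cast and a move can happen in one call. A minimal sketch, assuming the usual torch.rb dtype symbols (e.g. :float64) as keys of DTYPE_TO_ENUM:

    x = Torch.tensor([1, 2, 3])

    x.to(dtype: :float64)         # cast only, device unchanged
    x.to("cpu", dtype: :float32)  # move and cast together
    # an unrecognized dtype raises Torch::Error ("Unknown type: ...")
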
@@ -89,8 +103,9 @@ module Torch
     Torch.empty(0, dtype: dtype)
   end

-  def backward(gradient = nil)
-    _backward(gradient)
+  def backward(gradient = nil, retain_graph: nil, create_graph: false)
+    retain_graph = create_graph if retain_graph.nil?
+    _backward(gradient, retain_graph, create_graph)
   end

   # TODO read directly from memory
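
backward now exposes retain_graph and create_graph, with retain_graph defaulting to create_graph when omitted. A brief sketch of retaining the graph for a second backward pass (requires_grad: is assumed to be accepted by Torch.tensor as in earlier releases):

    x = Torch.tensor([2.0, 3.0], requires_grad: true)
    y = (x * x).sum

    y.backward(retain_graph: true)  # keep the graph alive
    y.backward                      # second pass now succeeds
    x.grad                          # gradients accumulated from both passes
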
@@ -153,6 +168,18 @@ module Torch
     neg
   end

+  def &(other)
+    logical_and(other)
+  end
+
+  def |(other)
+    logical_or(other)
+  end
+
+  def ^(other)
+    logical_xor(other)
+  end
+
   # TODO better compare?
   def <=>(other)
     item <=> other
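
The new &, |, and ^ operators delegate to logical_and, logical_or, and logical_xor, which makes combining masks read naturally. A short sketch using gt to build boolean masks:

    a = Torch.tensor([1, -2, 3, -4])
    b = Torch.tensor([1, 2, -3, -4])

    a.gt(0) & b.gt(0)  # logical_and: both positive
    a.gt(0) | b.gt(0)  # logical_or: either positive
    a.gt(0) ^ b.gt(0)  # logical_xor: exactly one positive
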
@@ -161,49 +188,15 @@ module Torch
   # based on python_variable_indexing.cpp and
   # https://pytorch.org/cppdocs/notes/tensor_indexing.html
   def [](*indexes)
-    result = self
-    dim = 0
-    indexes.each do |index|
-      if index.is_a?(Numeric)
-        result = result._select_int(dim, index)
-      elsif index.is_a?(Range)
-        finish = index.end
-        finish += 1 unless index.exclude_end?
-        result = result._slice_tensor(dim, index.begin, finish, 1)
-        dim += 1
-      elsif index.is_a?(Tensor)
-        result = result.index([index])
-      elsif index.nil?
-        result = result.unsqueeze(dim)
-        dim += 1
-      elsif index == true
-        result = result.unsqueeze(dim)
-        # TODO handle false
-      else
-        raise Error, "Unsupported index type: #{index.class.name}"
-      end
-    end
-    result
+    _index(tensor_indexes(indexes))
   end

   # based on python_variable_indexing.cpp and
   # https://pytorch.org/cppdocs/notes/tensor_indexing.html
-  def []=(index, value)
+  def []=(*indexes, value)
     raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
     value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
-    if index.is_a?(Numeric)
-      copy_to(_select_int(0, index), value)
-    elsif index.is_a?(Range)
-      finish = index.end
-      finish += 1 unless index.exclude_end?
-      copy_to(_slice_tensor(0, index.begin, finish, 1), value)
-    elsif index.is_a?(Tensor)
-      index_put!([index], value)
-    else
-      raise Error, "Unsupported index type: #{index.class.name}"
-    end
+    _index_put_custom(tensor_indexes(indexes), value)
   end

   # native functions that need manually defined
@@ -217,13 +210,13 @@ module Torch
       end
     end

-    # native functions overlap, so need to handle manually
+    # parser can't handle overlap, so need to handle manually
     def random!(*args)
       case args.size
       when 1
         _random__to(*args)
       when 2
-        _random__from_to(*args)
+        _random__from(*args)
       else
         _random_(*args)
       end
@@ -236,8 +229,29 @@ module Torch

     private

-    def copy_to(dst, src)
-      dst.copy!(src)
+    def tensor_indexes(indexes)
+      indexes.map do |index|
+        case index
+        when Integer
+          TensorIndex.integer(index)
+        when Range
+          finish = index.end
+          if finish == -1 && !index.exclude_end?
+            finish = nil
+          else
+            finish += 1 unless index.exclude_end?
+          end
+          TensorIndex.slice(index.begin, finish)
+        when Tensor
+          TensorIndex.tensor(index)
+        when nil
+          TensorIndex.none
+        when true, false
+          TensorIndex.boolean(index)
+        else
+          raise Error, "Unsupported index type: #{index.class.name}"
+        end
+      end
     end
   end
 end
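
Indexing is now translated into TensorIndex values and handed to _index/_index_put_custom in a single call, which is what lets []= accept multiple indexes. A sketch of the supported index types (Torch.arange/reshape are only used to build an example tensor):

    x = Torch.arange(12).reshape(3, 4)

    x[0]           # Integer -> single row
    x[0, 1]        # several indexes at once
    x[0, 1..-1]    # Range; an inclusive -1 end now means "to the end"
    x[nil, 0]      # nil inserts a new axis
    x[x.gt(5)]     # Tensor (boolean mask) indexing

    # []= now takes multiple indexes before the value
    x[0, 1] = 99
    x[1, 0..1] = Torch.tensor([7, 8])
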
@@ -0,0 +1,23 @@
+module Torch
+  module Utils
+    module Data
+      class << self
+        def random_split(dataset, lengths)
+          if lengths.sum != dataset.length
+            raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
+          end
+
+          indices = Torch.randperm(lengths.sum).to_a
+          _accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
+        end
+
+        private
+
+        def _accumulate(iterable)
+          sum = 0
+          iterable.map { |x| sum += x }
+        end
+      end
+    end
+  end
+end
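
The new Torch::Utils::Data.random_split shuffles indices with Torch.randperm and wraps each slice in the Subset class added below. A minimal sketch using a plain array as the dataset, since only length and [] are needed here:

    dataset = (0...10).map { |i| [Torch.tensor([i]), i] }

    train_set, test_set = Torch::Utils::Data.random_split(dataset, [8, 2])
    train_set.length  # => 8
    test_set.length   # => 2
    train_set[0]      # a randomly chosen sample from the original dataset
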
@@ -6,10 +6,22 @@ module Torch

   attr_reader :dataset

-  def initialize(dataset, batch_size: 1, shuffle: false)
+  def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
     @dataset = dataset
     @batch_size = batch_size
     @shuffle = shuffle
+
+    @batch_sampler = nil
+
+    if collate_fn.nil?
+      if auto_collation?
+        collate_fn = method(:default_collate)
+      else
+        collate_fn = method(:default_convert)
+      end
+    end
+
+    @collate_fn = collate_fn
   end

   def each
@@ -25,8 +37,8 @@ module Torch
     end

     indexes.each_slice(@batch_size) do |idx|
-      batch = idx.map { |i| @dataset[i] }
-      yield collate(batch)
+      # TODO improve performance
+      yield @collate_fn.call(idx.map { |i| @dataset[i] })
     end
   end

@@ -36,7 +48,7 @@ module Torch

   private

-  def collate(batch)
+  def default_convert(batch)
     elem = batch[0]
     case elem
     when Tensor
@@ -44,11 +56,15 @@ module Torch
     when Integer
       Torch.tensor(batch)
     when Array
-      batch.transpose.map { |v| collate(v) }
+      batch.transpose.map { |v| default_convert(v) }
     else
-      raise NotImpelmentYet
+      raise NotImplementedYet
     end
   end
+
+  def auto_collation?
+    !@batch_sampler.nil?
+  end
 end
 end
 end
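
DataLoader now takes a collate_fn; without one it falls back to default_convert, since auto_collation? is only true once a batch sampler is set (and @batch_sampler is always nil here). A sketch with an illustrative custom collate_fn that stacks samples into batch tensors:

    dataset = (1..10).map { |i| [Torch.tensor([i, i]), i] }

    collate = lambda do |batch|
      xs = batch.map { |x, _| x }
      ys = batch.map { |_, y| y }
      [Torch.stack(xs), Torch.tensor(ys)]
    end

    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 4, collate_fn: collate)
    loader.each do |x, y|
      # x is a 4x2 tensor, y a tensor of 4 labels (2 in the final batch)
    end
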
@@ -0,0 +1,25 @@
+module Torch
+  module Utils
+    module Data
+      class Subset < Dataset
+        def initialize(dataset, indices)
+          @dataset = dataset
+          @indices = indices
+        end
+
+        def [](idx)
+          @dataset[@indices[idx]]
+        end
+
+        def length
+          @indices.length
+        end
+        alias_method :size, :length
+
+        def to_a
+          @indices.map { |i| @dataset[i] }
+        end
+      end
+    end
+  end
+end
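
Subset simply re-maps indices onto the wrapped dataset; random_split returns these, but they can also be built directly. A tiny sketch:

    data = ["a", "b", "c", "d", "e"]
    subset = Torch::Utils::Data::Subset.new(data, [4, 0, 2])

    subset.length  # => 3
    subset[0]      # => "e"
    subset.to_a    # => ["e", "a", "c"]
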
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.2.6"
+  VERSION = "0.3.3"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.6
+  version: 0.3.3
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-06-29 00:00:00.000000000 Z
+date: 2020-08-26 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -259,8 +259,10 @@ files:
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
 - lib/torch/tensor.rb
+- lib/torch/utils/data.rb
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/dataset.rb
+- lib/torch/utils/data/subset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb