torch-rb 0.2.2 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,9 +11,6 @@ module Torch
11
11
  end
12
12
 
13
13
  def step(closure = nil)
14
- # TODO implement []=
15
- raise NotImplementedYet
16
-
17
14
  loss = nil
18
15
  if closure
19
16
  loss = closure.call
@@ -1,6 +1,7 @@
1
1
  module Torch
2
2
  class Tensor
3
3
  include Comparable
4
+ include Enumerable
4
5
  include Inspector
5
6
 
6
7
  alias_method :requires_grad?, :requires_grad
@@ -25,8 +26,25 @@ module Torch
25
26
  inspect
26
27
  end
27
28
 
29
+ def each
30
+ return enum_for(:each) unless block_given?
31
+
32
+ size(0).times do |i|
33
+ yield self[i]
34
+ end
35
+ end
36
+
37
+ # TODO make more performant
28
38
  def to_a
29
- reshape_arr(_flat_data, shape)
39
+ arr = _flat_data
40
+ if shape.empty?
41
+ arr
42
+ else
43
+ shape[1..-1].reverse.each do |dim|
44
+ arr = arr.each_slice(dim)
45
+ end
46
+ arr.to_a
47
+ end
30
48
  end
31
49
 
32
50
  # TODO support dtype
@@ -64,7 +82,15 @@ module Torch
64
82
  if numel != 1
65
83
  raise Error, "only one element tensors can be converted to Ruby scalars"
66
84
  end
67
- _flat_data.first
85
+ to_a.first
86
+ end
87
+
88
+ def to_i
89
+ item.to_i
90
+ end
91
+
92
+ def to_f
93
+ item.to_f
68
94
  end
69
95
 
70
96
  # unsure if this is correct
@@ -80,7 +106,7 @@ module Torch
80
106
  def numo
81
107
  cls = Torch._dtype_to_numo[dtype]
82
108
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
83
- cls.cast(_flat_data).reshape(*shape)
109
+ cls.from_string(_data_str).reshape(*shape)
84
110
  end
85
111
 
86
112
  def new_ones(*size, **options)
@@ -108,15 +134,6 @@ module Torch
108
134
  _view(size)
109
135
  end
110
136
 
111
- # value and other are swapped for some methods
112
- def add!(value = 1, other)
113
- if other.is_a?(Numeric)
114
- _add__scalar(other, value)
115
- else
116
- _add__tensor(other, value)
117
- end
118
- end
119
-
120
137
  def +(other)
121
138
  add(other)
122
139
  end
@@ -145,12 +162,25 @@ module Torch
145
162
  neg
146
163
  end
147
164
 
165
+ def &(other)
166
+ logical_and(other)
167
+ end
168
+
169
+ def |(other)
170
+ logical_or(other)
171
+ end
172
+
173
+ def ^(other)
174
+ logical_xor(other)
175
+ end
176
+
148
177
  # TODO better compare?
149
178
  def <=>(other)
150
179
  item <=> other
151
180
  end
152
181
 
153
- # based on python_variable_indexing.cpp
182
+ # based on python_variable_indexing.cpp and
183
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
154
184
  def [](*indexes)
155
185
  result = self
156
186
  dim = 0
@@ -162,6 +192,8 @@ module Torch
162
192
  finish += 1 unless index.exclude_end?
163
193
  result = result._slice_tensor(dim, index.begin, finish, 1)
164
194
  dim += 1
195
+ elsif index.is_a?(Tensor)
196
+ result = result.index([index])
165
197
  elsif index.nil?
166
198
  result = result.unsqueeze(dim)
167
199
  dim += 1
@@ -175,12 +207,12 @@ module Torch
175
207
  result
176
208
  end
177
209
 
178
- # TODO
179
- # based on python_variable_indexing.cpp
210
+ # based on python_variable_indexing.cpp and
211
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
180
212
  def []=(index, value)
181
213
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
182
214
 
183
- value = Torch.tensor(value) unless value.is_a?(Tensor)
215
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
184
216
 
185
217
  if index.is_a?(Numeric)
186
218
  copy_to(_select_int(0, index), value)
@@ -188,13 +220,39 @@ module Torch
188
220
  finish = index.end
189
221
  finish += 1 unless index.exclude_end?
190
222
  copy_to(_slice_tensor(0, index.begin, finish, 1), value)
223
+ elsif index.is_a?(Tensor)
224
+ index_put!([index], value)
191
225
  else
192
226
  raise Error, "Unsupported index type: #{index.class.name}"
193
227
  end
194
228
  end
195
229
 
196
- def random!(from = 0, to)
197
- _random__from_to(from, to)
230
+ # native functions that need to be defined manually
231
+
232
+ # value and other are swapped for some methods
233
+ def add!(value = 1, other)
234
+ if other.is_a?(Numeric)
235
+ _add__scalar(other, value)
236
+ else
237
+ _add__tensor(other, value)
238
+ end
239
+ end
240
+
241
+ # the native function variants overlap, so dispatch manually on the number of arguments
242
+ def random!(*args)
243
+ case args.size
244
+ when 1
245
+ _random__to(*args)
246
+ when 2
247
+ _random__from_to(*args)
248
+ else
249
+ _random_(*args)
250
+ end
251
+ end
252
+
253
+ def clamp!(min, max)
254
+ _clamp_min_(min)
255
+ _clamp_max_(max)
198
256
  end
199
257
 
200
258
  private
@@ -202,17 +260,5 @@ module Torch
202
260
  def copy_to(dst, src)
203
261
  dst.copy!(src)
204
262
  end
205
-
206
- def reshape_arr(arr, dims)
207
- if dims.empty?
208
- arr
209
- else
210
- arr = arr.flatten
211
- dims[1..-1].reverse.each do |dim|
212
- arr = arr.each_slice(dim)
213
- end
214
- arr.to_a
215
- end
216
- end
217
263
  end
218
264
  end
@@ -6,21 +6,49 @@ module Torch
6
6
 
7
7
  attr_reader :dataset
8
8
 
9
- def initialize(dataset, batch_size: 1)
9
+ def initialize(dataset, batch_size: 1, shuffle: false)
10
10
  @dataset = dataset
11
11
  @batch_size = batch_size
12
+ @shuffle = shuffle
12
13
  end
13
14
 
14
15
  def each
15
- size.times do |i|
16
- start_index = i * @batch_size
17
- yield @dataset[start_index...(start_index + @batch_size)]
16
+ # try to keep the random number generator in sync with Python
17
+ # this makes it easy to compare results
18
+ base_seed = Torch.empty([], dtype: :int64).random!.item
19
+
20
+ indexes =
21
+ if @shuffle
22
+ Torch.randperm(@dataset.size).to_a
23
+ else
24
+ @dataset.size.times
25
+ end
26
+
27
+ indexes.each_slice(@batch_size) do |idx|
28
+ batch = idx.map { |i| @dataset[i] }
29
+ yield collate(batch)
18
30
  end
19
31
  end
20
32
 
21
33
  def size
22
34
  (@dataset.size / @batch_size.to_f).ceil
23
35
  end
36
+
37
+ private
38
+
39
+ def collate(batch)
40
+ elem = batch[0]
41
+ case elem
42
+ when Tensor
43
+ Torch.stack(batch, 0)
44
+ when Integer
45
+ Torch.tensor(batch)
46
+ when Array
47
+ batch.transpose.map { |v| collate(v) }
48
+ else
49
+ raise NotImpelmentYet
50
+ end
51
+ end
24
52
  end
25
53
  end
26
54
  end
@@ -0,0 +1,8 @@
1
+ module Torch
2
+ module Utils
3
+ module Data
4
+ class Dataset
5
+ end
6
+ end
7
+ end
8
+ end
@@ -1,7 +1,7 @@
1
1
  module Torch
2
2
  module Utils
3
3
  module Data
4
- class TensorDataset
4
+ class TensorDataset < Dataset
5
5
  def initialize(*tensors)
6
6
  unless tensors.all? { |t| t.size(0) == tensors[0].size(0) }
7
7
  raise Error, "Tensors must all have same dim 0 size"
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.2"
2
+ VERSION = "0.2.7"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.2
4
+ version: 0.2.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-27 00:00:00.000000000 Z
11
+ date: 2020-06-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -95,19 +95,19 @@ dependencies:
95
95
  - !ruby/object:Gem::Version
96
96
  version: '0'
97
97
  - !ruby/object:Gem::Dependency
98
- name: npy
98
+ name: torchvision
99
99
  requirement: !ruby/object:Gem::Requirement
100
100
  requirements:
101
101
  - - ">="
102
102
  - !ruby/object:Gem::Version
103
- version: '0'
103
+ version: 0.1.1
104
104
  type: :development
105
105
  prerelease: false
106
106
  version_requirements: !ruby/object:Gem::Requirement
107
107
  requirements:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
- version: '0'
110
+ version: 0.1.1
111
111
  description:
112
112
  email: andrew@chartkick.com
113
113
  executables: []
@@ -258,9 +258,9 @@ files:
258
258
  - lib/torch/optim/rmsprop.rb
259
259
  - lib/torch/optim/rprop.rb
260
260
  - lib/torch/optim/sgd.rb
261
- - lib/torch/random.rb
262
261
  - lib/torch/tensor.rb
263
262
  - lib/torch/utils/data/data_loader.rb
263
+ - lib/torch/utils/data/dataset.rb
264
264
  - lib/torch/utils/data/tensor_dataset.rb
265
265
  - lib/torch/version.rb
266
266
  homepage: https://github.com/ankane/torch.rb
@@ -1,10 +0,0 @@
1
- module Torch
2
- module Random
3
- class << self
4
- # not available through LibTorch
5
- def initial_seed
6
- raise NotImplementedYet
7
- end
8
- end
9
- end
10
- end