torch-rb 0.2.5 → 0.2.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
-  data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
+  metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
+  data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
 SHA512:
-  metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
-  data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
+  metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
+  data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+## 0.2.6 (2020-06-29)
+
+- Added support for indexing with tensors
+- Added `contiguous` methods
+- Fixed named parameters for nested parameters
+
 ## 0.2.5 (2020-06-07)
 
 - Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
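A quick, hedged sketch of the headline 0.2.6 addition in use (toy values, not taken from the gem's docs; each change also appears next to its hunk below):

```ruby
require "torch"

x = Torch.tensor([[1, 2], [3, 4]])

x[Torch.tensor([1, 0])]       # indexing with a tensor: rows 1 and 0
x.transpose(0, 1).contiguous? # => false (transpose returns a strided view)
```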
data/README.md CHANGED
@@ -409,7 +409,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.2.0 | 1.5.0
+0.2.0+ | 1.5.0+
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
 
ext/torch/ext.cpp CHANGED
@@ -329,6 +329,11 @@ void Init_ext()
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "contiguous?",
+      *[](Tensor& self) {
+        return self.is_contiguous();
+      })
     .define_method(
       "addcmul!",
       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
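A hedged usage sketch of the new predicate; the non-bang `contiguous` method comes from the changelog entry above, and the values are illustrative:

```ruby
require "torch"

x = Torch.rand(2, 3)
y = x.transpose(0, 1)  # a view with swapped strides, not contiguous

x.contiguous?          # => true
y.contiguous?          # => false

z = y.contiguous       # copies the data into contiguous memory
z.contiguous?          # => true
```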
@@ -66,6 +66,7 @@ module Torch
       end
 
       next if t == "Generator?"
+      next if t == "MemoryFormat"
       next if t == "MemoryFormat?"
       args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
     end
@@ -18,7 +18,7 @@ module Torch
     functions = functions()
 
     # skip functions
-    skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
+    skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
 
     # remove functions
     functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
     todo_functions, functions =
       functions.partition do |f|
         f.args.any? do |a|
-          a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
+          a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
             skip_args.any? { |sa| a[:type].include?(sa) } ||
             # native_functions.yaml is missing size argument for normal
             # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
       "OptionalScalarType"
     when "Tensor[]"
       "TensorList"
+    when "Tensor?[]"
+      # TODO make optional
+      "TensorList"
     when "int"
       "int64_t"
     when "float"
@@ -75,7 +75,7 @@ module Torch
           v.is_a?(Tensor)
         when "Tensor?"
          v.nil? || v.is_a?(Tensor)
-        when "Tensor[]"
+        when "Tensor[]", "Tensor?[]"
           v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
         when "int"
           if k == "reduction"
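Taken together, the generator and parser changes above let generated functions whose native signature takes `Tensor?[]` accept a plain Ruby array of tensors. `index` and `index_put!`, which the tensor indexing hunks below delegate to, are the motivating cases; a hedged sketch with toy values:

```ruby
require "torch"

x = Torch.tensor([[1, 2], [3, 4]])
i = Torch.tensor([1, 0])

# one index tensor per dimension, passed as a Ruby array
x.index([i])                                            # rows 1 and 0
x.index_put!([Torch.tensor([0])], Torch.tensor([9, 9])) # overwrite row 0
```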
lib/torch/nn/module.rb CHANGED
@@ -145,7 +145,8 @@ module Torch
       params = {}
       if recurse
         named_children.each do |name, mod|
-          params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         end
       end
       instance_variables.each do |name|
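The fix above threads the accumulated `prefix` through the recursion, so parameters of doubly nested modules keep their full dotted path. A hedged sketch with hypothetical module names:

```ruby
require "torch"

class Block < Torch::NN::Module
  def initialize
    super()
    @fc = Torch::NN::Linear.new(4, 4)
  end
end

class Net < Torch::NN::Module
  def initialize
    super()
    @block = Block.new
  end
end

Net.new.named_parameters.keys
# with the fix:   ["block.fc.weight", "block.fc.bias"]
# before the fix, the "block." prefix was dropped one level down,
# yielding ["fc.weight", "fc.bias"]
```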
@@ -11,9 +11,6 @@ module Torch
     end
 
     def step(closure = nil)
-      # TODO implement []=
-      raise NotImplementedYet
-
       loss = nil
       if closure
         loss = closure.call
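With the `raise NotImplementedYet` guard removed (the tensor `[]=` it was waiting on now exists, per the hunks below), this optimizer's `step` runs, including the optional closure form shared by the `Torch::Optim` classes. A hedged sketch; `SGD`, `model`, `inputs`, and `targets` are stand-ins, not tied to this particular file:

```ruby
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01)
criterion = Torch::NN::MSELoss.new

# the closure recomputes the loss; step calls it and returns the result
loss = optimizer.step(-> {
  optimizer.zero_grad
  output = model.call(inputs)
  l = criterion.call(output, targets)
  l.backward
  l
})
```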
lib/torch/tensor.rb CHANGED
@@ -158,7 +158,8 @@ module Torch
       item <=> other
     end
 
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
       result = self
       dim = 0
@@ -170,6 +171,8 @@ module Torch
         finish += 1 unless index.exclude_end?
         result = result._slice_tensor(dim, index.begin, finish, 1)
         dim += 1
+      elsif index.is_a?(Tensor)
+        result = result.index([index])
       elsif index.nil?
         result = result.unsqueeze(dim)
         dim += 1
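The new branch makes tensor-valued indexes work in `[]` by delegating to the generated `index` function. A hedged sketch with toy values:

```ruby
x = Torch.tensor([[10, 20], [30, 40], [50, 60]])

x[Torch.tensor([0, 2])]  # select rows 0 and 2 in one call
x[0, 1]                  # integer and range indexing still work as before
```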
@@ -183,12 +186,12 @@ module Torch
       result
     end
 
-    # TODO
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(index, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
-      value = Torch.tensor(value) unless value.is_a?(Tensor)
+      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
       if index.is_a?(Numeric)
         copy_to(_select_int(0, index), value)
@@ -196,6 +199,8 @@ module Torch
         finish = index.end
         finish += 1 unless index.exclude_end?
         copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+      elsif index.is_a?(Tensor)
+        index_put!([index], value)
       else
         raise Error, "Unsupported index type: #{index.class.name}"
       end
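Assignment with a tensor index now routes through `index_put!`, and the `dtype:` change above coerces non-tensor values to the target tensor's dtype instead of whatever `Torch.tensor` would infer. A hedged sketch:

```ruby
x = Torch.tensor([1.0, 2.0, 3.0, 4.0])

x[Torch.tensor([0, 2])] = Torch.tensor([10.0, 30.0])

# 0 is wrapped as Torch.tensor(0, dtype: x.dtype), so the float
# tensor is not polluted by an int-typed value
x[Torch.tensor([1])] = 0
```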
@@ -224,6 +229,11 @@ module Torch
       end
     end
 
+    def clamp!(min, max)
+      _clamp_min_(min)
+      _clamp_max_(max)
+    end
+
     private
 
     def copy_to(dst, src)
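The new `clamp!` composes the two in-place native ops: values below `min` are raised to `min`, then values above `max` are lowered to `max`. A hedged sketch:

```ruby
x = Torch.tensor([-1.5, 0.25, 2.0])
x.clamp!(0, 1)
# x => [0.0, 0.25, 1.0]
```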
lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.2.5"
+  VERSION = "0.2.6"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.5
+  version: 0.2.6
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-06-07 00:00:00.000000000 Z
+date: 2020-06-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice