torch-rb 0.8.2 → 0.9.2

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects both versions exactly as they appear in their respective public registries.
data/lib/torch.rb CHANGED
@@ -267,7 +267,7 @@ module Torch
267
267
  args.first.send(dtype).to(device)
268
268
  elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
269
269
  bytes = args.first.bytes
270
- Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
270
+ Torch._from_blob_ref(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
271
271
  elsif args.size == 1 && args.first.is_a?(Array)
272
272
  Torch.tensor(args.first, dtype: dtype, device: device)
273
273
  elsif args.size == 0
@@ -320,12 +320,17 @@ module Torch
320
320
  raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
321
321
  options = tensor_options(device: "cpu", dtype: dtype[0])
322
322
  # TODO pass pointer to array instead of creating string
323
- str = ndarray.to_string
324
- tensor = _from_blob(str, ndarray.shape, options)
323
+ _from_blob_ref(ndarray.to_string, ndarray.shape, options)
324
+ end
325
+
326
+ # private
327
+ # TODO use keepAlive in Rice (currently segfaults)
328
+ def _from_blob_ref(data, size, options)
329
+ tensor = _from_blob(data, size, options)
325
330
  # from_blob does not own the data, so we need to keep
326
331
  # a reference to it for duration of tensor
327
332
  # can remove when passing pointer directly
328
- tensor.instance_variable_set("@_numo_str", str)
333
+ tensor.instance_variable_set("@_numo_data", data)
329
334
  tensor
330
335
  end
331
336
 
@@ -377,8 +382,6 @@ module Torch
377
382
  to_ruby(_load(File.binread(f)))
378
383
  end
379
384
 
380
- # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
381
-
382
385
  def tensor(data, **options)
383
386
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
384
387
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -408,42 +411,13 @@ module Torch
408
411
  end
409
412
  end
410
413
 
411
- _tensor(data, size, tensor_options(**options))
412
- end
413
-
414
- # --- begin like ---
415
-
416
- def ones_like(input, **options)
417
- ones(input.size, **like_options(input, options))
418
- end
419
-
420
- def empty_like(input, **options)
421
- empty(input.size, **like_options(input, options))
422
- end
414
+ # TODO check each dimensions for consistency in future
415
+ raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
423
416
 
424
- def full_like(input, fill_value, **options)
425
- full(input.size, fill_value, **like_options(input, options))
426
- end
417
+ # TOOD move to C++
418
+ data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
427
419
 
428
- def rand_like(input, **options)
429
- rand(input.size, **like_options(input, options))
430
- end
431
-
432
- def randint_like(input, low, high = nil, **options)
433
- # ruby doesn't support input, low = 0, high, ...
434
- if high.nil?
435
- high = low
436
- low = 0
437
- end
438
- randint(low, high, input.size, **like_options(input, options))
439
- end
440
-
441
- def randn_like(input, **options)
442
- randn(input.size, **like_options(input, options))
443
- end
444
-
445
- def zeros_like(input, **options)
446
- zeros(input.size, **like_options(input, options))
420
+ _tensor(data, size, tensor_options(**options))
447
421
  end
448
422
 
449
423
  # center option
@@ -572,13 +546,5 @@ module Torch
572
546
  end
573
547
  options
574
548
  end
575
-
576
- def like_options(input, options)
577
- options = options.dup
578
- options[:dtype] ||= input.dtype
579
- options[:layout] ||= input.layout
580
- options[:device] ||= input.device
581
- options
582
- end
583
549
  end
584
550
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.2
4
+ version: 0.9.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-10-04 00:00:00.000000000 Z
11
+ date: 2022-02-03 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -227,7 +227,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
227
227
  - !ruby/object:Gem::Version
228
228
  version: '0'
229
229
  requirements: []
230
- rubygems_version: 3.2.22
230
+ rubygems_version: 3.3.3
231
231
  signing_key:
232
232
  specification_version: 4
233
233
  summary: Deep learning for Ruby, powered by LibTorch