torch-rb 0.2.1 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b37dc4dd7be5806879c2fb5bb52ac94c8b16eba76bc2a3c591ca4cbe51cf8745
4
- data.tar.gz: 6415d14f7cc8baa4db9205c709b70fec2b90b4dff1b60de97c299c2a7edfbf40
3
+ metadata.gz: 8e1f9758c937519ca31d92f3acd35ce0372f8cf57362cd3b50bd45920e7a6763
4
+ data.tar.gz: 4d370857ee758694b974da0f5d0973a687181ef2fedcad5c583f446ceb67dda2
5
5
  SHA512:
6
- metadata.gz: cf43cb21e18171f76f1291f2cccdb8a93141605fe33d4421eaf799a4589638d33da040b2ffad3ddce34fec60ab2b41edf1fa9a247d69de7f31e157063e57f331
7
- data.tar.gz: 3c858de8e7eb6169359fad18104c08d9a4011a305f44c4a8213a289b005dc4439b1959a40effd8b1bba1c2e96a871b28f01695fdda56e6ad0aeed0c2334cfa25
6
+ metadata.gz: b1cbb37019852bfdbfc45b28ac32924b4de313ce21112ffe8bb5ec91fe17898d3a1ceb42d77540e6b7a0d656e9443002fb72a1201453507dca4915db13879167
7
+ data.tar.gz: d1a35689c3ad6a0628485633af4f6d7f613288fbf739f9e94ccfdb72c613b0d4581a21e00aef3a309771967f65086034befc519bc89ecc7884ff1dd142a8289f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.2.2 (2020-04-27)
2
+
3
+ - Added support for saving tensor lists
4
+ - Added `ndim` and `ndimension` methods to tensors
5
+
1
6
  ## 0.2.1 (2020-04-26)
2
7
 
3
8
  - Added support for saving and loading models
data/README.md CHANGED
@@ -426,7 +426,7 @@ Torch::CUDA.available?
426
426
  Move a neural network to a GPU
427
427
 
428
428
  ```ruby
429
- net.to("cuda")
429
+ net.cuda
430
430
  ```
431
431
 
432
432
  ## rbenv
data/ext/torch/ext.cpp CHANGED
@@ -82,6 +82,16 @@ void Init_ext()
82
82
  *[](torch::IValue& self) {
83
83
  return self.toInt();
84
84
  })
85
+ .define_method(
86
+ "to_list",
87
+ *[](torch::IValue& self) {
88
+ auto list = self.toListRef();
89
+ Array obj;
90
+ for (auto& elem : list) {
91
+ obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
92
+ }
93
+ return obj;
94
+ })
85
95
  .define_method(
86
96
  "to_string_ref",
87
97
  *[](torch::IValue& self) {
@@ -96,17 +106,27 @@ void Init_ext()
96
106
  "to_generic_dict",
97
107
  *[](torch::IValue& self) {
98
108
  auto dict = self.toGenericDict();
99
- Hash h;
109
+ Hash obj;
100
110
  for (auto& pair : dict) {
101
- h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
111
+ obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
102
112
  }
103
- return h;
113
+ return obj;
104
114
  })
105
115
  .define_singleton_method(
106
116
  "from_tensor",
107
117
  *[](torch::Tensor& v) {
108
118
  return torch::IValue(v);
109
119
  })
120
+ // TODO create specialized list types?
121
+ .define_singleton_method(
122
+ "from_list",
123
+ *[](Array obj) {
124
+ c10::impl::GenericList list(c10::AnyType::get());
125
+ for (auto entry : obj) {
126
+ list.push_back(from_ruby<torch::IValue>(entry));
127
+ }
128
+ return torch::IValue(list);
129
+ })
110
130
  .define_singleton_method(
111
131
  "from_string",
112
132
  *[](String v) {
data/lib/torch/tensor.rb CHANGED
@@ -4,6 +4,8 @@ module Torch
4
4
  include Inspector
5
5
 
6
6
  alias_method :requires_grad?, :requires_grad
7
+ alias_method :ndim, :dim
8
+ alias_method :ndimension, :dim
7
9
 
8
10
  def self.new(*args)
9
11
  FloatTensor.new(*args)
@@ -143,6 +145,7 @@ module Torch
143
145
  neg
144
146
  end
145
147
 
148
+ # TODO better compare?
146
149
  def <=>(other)
147
150
  item <=> other
148
151
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.2.1"
2
+ VERSION = "0.2.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -466,6 +466,12 @@ module Torch
466
466
  IValue.from_bool(obj)
467
467
  when nil
468
468
  IValue.new
469
+ when Array
470
+ if obj.all? { |v| v.is_a?(Tensor) }
471
+ IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
472
+ else
473
+ raise Error, "Unknown list type"
474
+ end
469
475
  else
470
476
  raise Error, "Unknown type: #{obj.class.name}"
471
477
  end
@@ -490,6 +496,8 @@ module Torch
490
496
  dict[to_ruby(k)] = to_ruby(v)
491
497
  end
492
498
  dict
499
+ elsif ivalue.list?
500
+ ivalue.to_list.map { |v| to_ruby(v) }
493
501
  else
494
502
  type =
495
503
  if ivalue.capsule?
@@ -510,8 +518,6 @@ module Torch
510
518
  "BoolList"
511
519
  elsif ivalue.tensor_list?
512
520
  "TensorList"
513
- elsif ivalue.list?
514
- "List"
515
521
  elsif ivalue.object?
516
522
  "Object"
517
523
  elsif ivalue.module?
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.2.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane