torch-rb 0.1.6 → 0.1.7

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9667f9d3256f5e2d39937f17ae8eb00449dd14f79bb01cd647800bd7ed972fc6
4
- data.tar.gz: 54c23612c79355e09c97da5fcf6b97c183da8316d1c2a53d6f8f0463e98342a2
3
+ metadata.gz: 51bcc56112e13ba206402857b379aee0df4c7695f75af354e833760adec67756
4
+ data.tar.gz: b2ff24940e4c219d88c5a001d4e8b4e44d0e55a35fc266989f0196e696d15bc8
5
5
  SHA512:
6
- metadata.gz: bb2c8e5aae436367aeb871a2d19958e59ed9e9c7601b1b8b4473e33094cadf6d657947582b0ec93a29cb08723f8f7c81178a2d50beb23a125d5a356769d92177
7
- data.tar.gz: 62feef39da31a19415e2e6c453aed4972e34db7367161a088944c06a977637a8b25cecc8eb2ad052b3b9deee0707f364e616cc33e7674cf0314899421f18fbee
6
+ metadata.gz: 95506016db5598333f0cb99a435d29951342af91f75ae4b1f01ef11df81891738888b90c7d27317071ad00bd9b81714cf41c0ea635c2578fd756c388b5e1da7f
7
+ data.tar.gz: 053c9c75e66fe54902f07413687deb6996afc7ae88217bd5dcc852ca59d535c663bb9fb3aed28b20dba953a42e714410867dbd6ecd747f96fe8e8dfd81da8d6c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.1.7 (2020-01-10)
2
+
3
+ - Fixed installation error with Ruby 2.7
4
+
1
5
  ## 0.1.6 (2019-12-09)
2
6
 
3
7
  - Added recurrent layers
data/ext/torch/ext.cpp CHANGED
@@ -105,6 +105,13 @@ void Init_ext()
105
105
  return torch::zeros(size, options);
106
106
  })
107
107
  // begin operations
108
+ .define_singleton_method(
109
+ "_save",
110
+ *[](const Tensor &value) {
111
+ auto v = torch::pickle_save(value);
112
+ std::string str(v.begin(), v.end());
113
+ return str;
114
+ })
108
115
  .define_singleton_method(
109
116
  "_binary_cross_entropy_with_logits",
110
117
  *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
@@ -5,62 +5,6 @@
5
5
 
6
6
  using namespace Rice;
7
7
 
8
- template<>
9
- inline
10
- long long from_ruby<long long>(Object x)
11
- {
12
- return NUM2LL(x);
13
- }
14
-
15
- template<>
16
- inline
17
- Object to_ruby<long long>(long long const & x)
18
- {
19
- return LL2NUM(x);
20
- }
21
-
22
- template<>
23
- inline
24
- unsigned long long from_ruby<unsigned long long>(Object x)
25
- {
26
- return NUM2ULL(x);
27
- }
28
-
29
- template<>
30
- inline
31
- Object to_ruby<unsigned long long>(unsigned long long const & x)
32
- {
33
- return ULL2NUM(x);
34
- }
35
-
36
- template<>
37
- inline
38
- short from_ruby<short>(Object x)
39
- {
40
- return NUM2SHORT(x);
41
- }
42
-
43
- template<>
44
- inline
45
- Object to_ruby<short>(short const & x)
46
- {
47
- return INT2NUM(x);
48
- }
49
-
50
- template<>
51
- inline
52
- unsigned short from_ruby<unsigned short>(Object x)
53
- {
54
- return NUM2USHORT(x);
55
- }
56
-
57
- template<>
58
- inline
59
- Object to_ruby<unsigned short>(unsigned short const & x)
60
- {
61
- return UINT2NUM(x);
62
- }
63
-
64
8
  // need to wrap torch::IntArrayRef() since
65
9
  // it doesn't own underlying data
66
10
  class IntArrayRef {
data/lib/torch.rb CHANGED
@@ -300,6 +300,15 @@ module Torch
300
300
  Device.new(str)
301
301
  end
302
302
 
303
+ def save(obj, f)
304
+ raise NotImplementedYet unless obj.is_a?(Tensor)
305
+ File.binwrite(f, _save(obj))
306
+ end
307
+
308
+ def load(f)
309
+ raise NotImplementedYet
310
+ end
311
+
303
312
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
304
313
 
305
314
  def arange(start, finish = nil, step = 1, **options)
@@ -383,19 +392,19 @@ module Torch
383
392
  # --- begin like ---
384
393
 
385
394
  def ones_like(input, **options)
386
- ones(input.size, like_options(input, options))
395
+ ones(input.size, **like_options(input, options))
387
396
  end
388
397
 
389
398
  def empty_like(input, **options)
390
- empty(input.size, like_options(input, options))
399
+ empty(input.size, **like_options(input, options))
391
400
  end
392
401
 
393
402
  def full_like(input, fill_value, **options)
394
- full(input.size, fill_value, like_options(input, options))
403
+ full(input.size, fill_value, **like_options(input, options))
395
404
  end
396
405
 
397
406
  def rand_like(input, **options)
398
- rand(input.size, like_options(input, options))
407
+ rand(input.size, **like_options(input, options))
399
408
  end
400
409
 
401
410
  def randint_like(input, low, high = nil, **options)
@@ -404,15 +413,15 @@ module Torch
404
413
  high = low
405
414
  low = 0
406
415
  end
407
- randint(low, high, input.size, like_options(input, options))
416
+ randint(low, high, input.size, **like_options(input, options))
408
417
  end
409
418
 
410
419
  def randn_like(input, **options)
411
- randn(input.size, like_options(input, options))
420
+ randn(input.size, **like_options(input, options))
412
421
  end
413
422
 
414
423
  def zeros_like(input, **options)
415
- zeros(input.size, like_options(input, options))
424
+ zeros(input.size, **like_options(input, options))
416
425
  end
417
426
 
418
427
  private
data/lib/torch/ext.bundle CHANGED
Binary file
@@ -79,11 +79,19 @@ module Torch
79
79
  _apply(convert)
80
80
  end
81
81
 
82
- def call(*input)
83
- forward(*input)
82
+ def call(*input, **kwargs)
83
+ forward(*input, **kwargs)
84
84
  end
85
85
 
86
- def state_dict
86
+ def state_dict(destination: nil)
87
+ destination ||= {}
88
+ named_parameters.each do |k, v|
89
+ destination[k] = v
90
+ end
91
+ destination
92
+ end
93
+
94
+ def load_state_dict(state_dict)
87
95
  raise NotImplementedYet
88
96
  end
89
97
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.1.6"
2
+ VERSION = "0.1.7"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.6
4
+ version: 0.1.7
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-12-10 00:00:00.000000000 Z
11
+ date: 2020-01-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '0'
19
+ version: '2.2'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '0'
26
+ version: '2.2'
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: bundler
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -261,7 +261,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
261
261
  - !ruby/object:Gem::Version
262
262
  version: '0'
263
263
  requirements: []
264
- rubygems_version: 3.0.3
264
+ rubygems_version: 3.1.2
265
265
  signing_key:
266
266
  specification_version: 4
267
267
  summary: Deep learning for Ruby, powered by LibTorch