torch-rb 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/torch/ext.cpp +7 -0
- data/ext/torch/templates.hpp +0 -56
- data/lib/torch.rb +16 -7
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/nn/module.rb +11 -3
- data/lib/torch/version.rb +1 -1
- metadata +5 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 51bcc56112e13ba206402857b379aee0df4c7695f75af354e833760adec67756
|
4
|
+
data.tar.gz: b2ff24940e4c219d88c5a001d4e8b4e44d0e55a35fc266989f0196e696d15bc8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 95506016db5598333f0cb99a435d29951342af91f75ae4b1f01ef11df81891738888b90c7d27317071ad00bd9b81714cf41c0ea635c2578fd756c388b5e1da7f
|
7
|
+
data.tar.gz: 053c9c75e66fe54902f07413687deb6996afc7ae88217bd5dcc852ca59d535c663bb9fb3aed28b20dba953a42e714410867dbd6ecd747f96fe8e8dfd81da8d6c
|
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -105,6 +105,13 @@ void Init_ext()
|
|
105
105
|
return torch::zeros(size, options);
|
106
106
|
})
|
107
107
|
// begin operations
|
108
|
+
.define_singleton_method(
|
109
|
+
"_save",
|
110
|
+
*[](const Tensor &value) {
|
111
|
+
auto v = torch::pickle_save(value);
|
112
|
+
std::string str(v.begin(), v.end());
|
113
|
+
return str;
|
114
|
+
})
|
108
115
|
.define_singleton_method(
|
109
116
|
"_binary_cross_entropy_with_logits",
|
110
117
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
data/ext/torch/templates.hpp
CHANGED
@@ -5,62 +5,6 @@
|
|
5
5
|
|
6
6
|
using namespace Rice;
|
7
7
|
|
8
|
-
template<>
|
9
|
-
inline
|
10
|
-
long long from_ruby<long long>(Object x)
|
11
|
-
{
|
12
|
-
return NUM2LL(x);
|
13
|
-
}
|
14
|
-
|
15
|
-
template<>
|
16
|
-
inline
|
17
|
-
Object to_ruby<long long>(long long const & x)
|
18
|
-
{
|
19
|
-
return LL2NUM(x);
|
20
|
-
}
|
21
|
-
|
22
|
-
template<>
|
23
|
-
inline
|
24
|
-
unsigned long long from_ruby<unsigned long long>(Object x)
|
25
|
-
{
|
26
|
-
return NUM2ULL(x);
|
27
|
-
}
|
28
|
-
|
29
|
-
template<>
|
30
|
-
inline
|
31
|
-
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
32
|
-
{
|
33
|
-
return ULL2NUM(x);
|
34
|
-
}
|
35
|
-
|
36
|
-
template<>
|
37
|
-
inline
|
38
|
-
short from_ruby<short>(Object x)
|
39
|
-
{
|
40
|
-
return NUM2SHORT(x);
|
41
|
-
}
|
42
|
-
|
43
|
-
template<>
|
44
|
-
inline
|
45
|
-
Object to_ruby<short>(short const & x)
|
46
|
-
{
|
47
|
-
return INT2NUM(x);
|
48
|
-
}
|
49
|
-
|
50
|
-
template<>
|
51
|
-
inline
|
52
|
-
unsigned short from_ruby<unsigned short>(Object x)
|
53
|
-
{
|
54
|
-
return NUM2USHORT(x);
|
55
|
-
}
|
56
|
-
|
57
|
-
template<>
|
58
|
-
inline
|
59
|
-
Object to_ruby<unsigned short>(unsigned short const & x)
|
60
|
-
{
|
61
|
-
return UINT2NUM(x);
|
62
|
-
}
|
63
|
-
|
64
8
|
// need to wrap torch::IntArrayRef() since
|
65
9
|
// it doesn't own underlying data
|
66
10
|
class IntArrayRef {
|
data/lib/torch.rb
CHANGED
@@ -300,6 +300,15 @@ module Torch
|
|
300
300
|
Device.new(str)
|
301
301
|
end
|
302
302
|
|
303
|
+
def save(obj, f)
|
304
|
+
raise NotImplementedYet unless obj.is_a?(Tensor)
|
305
|
+
File.binwrite(f, _save(obj))
|
306
|
+
end
|
307
|
+
|
308
|
+
def load(f)
|
309
|
+
raise NotImplementedYet
|
310
|
+
end
|
311
|
+
|
303
312
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
304
313
|
|
305
314
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -383,19 +392,19 @@ module Torch
|
|
383
392
|
# --- begin like ---
|
384
393
|
|
385
394
|
def ones_like(input, **options)
|
386
|
-
ones(input.size, like_options(input, options))
|
395
|
+
ones(input.size, **like_options(input, options))
|
387
396
|
end
|
388
397
|
|
389
398
|
def empty_like(input, **options)
|
390
|
-
empty(input.size, like_options(input, options))
|
399
|
+
empty(input.size, **like_options(input, options))
|
391
400
|
end
|
392
401
|
|
393
402
|
def full_like(input, fill_value, **options)
|
394
|
-
full(input.size, fill_value, like_options(input, options))
|
403
|
+
full(input.size, fill_value, **like_options(input, options))
|
395
404
|
end
|
396
405
|
|
397
406
|
def rand_like(input, **options)
|
398
|
-
rand(input.size, like_options(input, options))
|
407
|
+
rand(input.size, **like_options(input, options))
|
399
408
|
end
|
400
409
|
|
401
410
|
def randint_like(input, low, high = nil, **options)
|
@@ -404,15 +413,15 @@ module Torch
|
|
404
413
|
high = low
|
405
414
|
low = 0
|
406
415
|
end
|
407
|
-
randint(low, high, input.size, like_options(input, options))
|
416
|
+
randint(low, high, input.size, **like_options(input, options))
|
408
417
|
end
|
409
418
|
|
410
419
|
def randn_like(input, **options)
|
411
|
-
randn(input.size, like_options(input, options))
|
420
|
+
randn(input.size, **like_options(input, options))
|
412
421
|
end
|
413
422
|
|
414
423
|
def zeros_like(input, **options)
|
415
|
-
zeros(input.size, like_options(input, options))
|
424
|
+
zeros(input.size, **like_options(input, options))
|
416
425
|
end
|
417
426
|
|
418
427
|
private
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/nn/module.rb
CHANGED
@@ -79,11 +79,19 @@ module Torch
|
|
79
79
|
_apply(convert)
|
80
80
|
end
|
81
81
|
|
82
|
-
def call(*input)
|
83
|
-
forward(*input)
|
82
|
+
def call(*input, **kwargs)
|
83
|
+
forward(*input, **kwargs)
|
84
84
|
end
|
85
85
|
|
86
|
-
def state_dict
|
86
|
+
def state_dict(destination: nil)
|
87
|
+
destination ||= {}
|
88
|
+
named_parameters.each do |k, v|
|
89
|
+
destination[k] = v
|
90
|
+
end
|
91
|
+
destination
|
92
|
+
end
|
93
|
+
|
94
|
+
def load_state_dict(state_dict)
|
87
95
|
raise NotImplementedYet
|
88
96
|
end
|
89
97
|
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2020-01-11 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '
|
19
|
+
version: '2.2'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: '
|
26
|
+
version: '2.2'
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
28
|
name: bundler
|
29
29
|
requirement: !ruby/object:Gem::Requirement
|
@@ -261,7 +261,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
261
261
|
- !ruby/object:Gem::Version
|
262
262
|
version: '0'
|
263
263
|
requirements: []
|
264
|
-
rubygems_version: 3.
|
264
|
+
rubygems_version: 3.1.2
|
265
265
|
signing_key:
|
266
266
|
specification_version: 4
|
267
267
|
summary: Deep learning for Ruby, powered by LibTorch
|