torch-rb 0.1.6 → 0.1.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/torch/ext.cpp +7 -0
- data/ext/torch/templates.hpp +0 -56
- data/lib/torch.rb +16 -7
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/nn/module.rb +11 -3
- data/lib/torch/version.rb +1 -1
- metadata +5 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 51bcc56112e13ba206402857b379aee0df4c7695f75af354e833760adec67756
|
4
|
+
data.tar.gz: b2ff24940e4c219d88c5a001d4e8b4e44d0e55a35fc266989f0196e696d15bc8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 95506016db5598333f0cb99a435d29951342af91f75ae4b1f01ef11df81891738888b90c7d27317071ad00bd9b81714cf41c0ea635c2578fd756c388b5e1da7f
|
7
|
+
data.tar.gz: 053c9c75e66fe54902f07413687deb6996afc7ae88217bd5dcc852ca59d535c663bb9fb3aed28b20dba953a42e714410867dbd6ecd747f96fe8e8dfd81da8d6c
|
data/CHANGELOG.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -105,6 +105,13 @@ void Init_ext()
|
|
105
105
|
return torch::zeros(size, options);
|
106
106
|
})
|
107
107
|
// begin operations
|
108
|
+
.define_singleton_method(
|
109
|
+
"_save",
|
110
|
+
*[](const Tensor &value) {
|
111
|
+
auto v = torch::pickle_save(value);
|
112
|
+
std::string str(v.begin(), v.end());
|
113
|
+
return str;
|
114
|
+
})
|
108
115
|
.define_singleton_method(
|
109
116
|
"_binary_cross_entropy_with_logits",
|
110
117
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
data/ext/torch/templates.hpp
CHANGED
@@ -5,62 +5,6 @@
|
|
5
5
|
|
6
6
|
using namespace Rice;
|
7
7
|
|
8
|
-
template<>
|
9
|
-
inline
|
10
|
-
long long from_ruby<long long>(Object x)
|
11
|
-
{
|
12
|
-
return NUM2LL(x);
|
13
|
-
}
|
14
|
-
|
15
|
-
template<>
|
16
|
-
inline
|
17
|
-
Object to_ruby<long long>(long long const & x)
|
18
|
-
{
|
19
|
-
return LL2NUM(x);
|
20
|
-
}
|
21
|
-
|
22
|
-
template<>
|
23
|
-
inline
|
24
|
-
unsigned long long from_ruby<unsigned long long>(Object x)
|
25
|
-
{
|
26
|
-
return NUM2ULL(x);
|
27
|
-
}
|
28
|
-
|
29
|
-
template<>
|
30
|
-
inline
|
31
|
-
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
32
|
-
{
|
33
|
-
return ULL2NUM(x);
|
34
|
-
}
|
35
|
-
|
36
|
-
template<>
|
37
|
-
inline
|
38
|
-
short from_ruby<short>(Object x)
|
39
|
-
{
|
40
|
-
return NUM2SHORT(x);
|
41
|
-
}
|
42
|
-
|
43
|
-
template<>
|
44
|
-
inline
|
45
|
-
Object to_ruby<short>(short const & x)
|
46
|
-
{
|
47
|
-
return INT2NUM(x);
|
48
|
-
}
|
49
|
-
|
50
|
-
template<>
|
51
|
-
inline
|
52
|
-
unsigned short from_ruby<unsigned short>(Object x)
|
53
|
-
{
|
54
|
-
return NUM2USHORT(x);
|
55
|
-
}
|
56
|
-
|
57
|
-
template<>
|
58
|
-
inline
|
59
|
-
Object to_ruby<unsigned short>(unsigned short const & x)
|
60
|
-
{
|
61
|
-
return UINT2NUM(x);
|
62
|
-
}
|
63
|
-
|
64
8
|
// need to wrap torch::IntArrayRef() since
|
65
9
|
// it doesn't own underlying data
|
66
10
|
class IntArrayRef {
|
data/lib/torch.rb
CHANGED
@@ -300,6 +300,15 @@ module Torch
|
|
300
300
|
Device.new(str)
|
301
301
|
end
|
302
302
|
|
303
|
+
def save(obj, f)
|
304
|
+
raise NotImplementedYet unless obj.is_a?(Tensor)
|
305
|
+
File.binwrite(f, _save(obj))
|
306
|
+
end
|
307
|
+
|
308
|
+
def load(f)
|
309
|
+
raise NotImplementedYet
|
310
|
+
end
|
311
|
+
|
303
312
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
304
313
|
|
305
314
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -383,19 +392,19 @@ module Torch
|
|
383
392
|
# --- begin like ---
|
384
393
|
|
385
394
|
def ones_like(input, **options)
|
386
|
-
ones(input.size, like_options(input, options))
|
395
|
+
ones(input.size, **like_options(input, options))
|
387
396
|
end
|
388
397
|
|
389
398
|
def empty_like(input, **options)
|
390
|
-
empty(input.size, like_options(input, options))
|
399
|
+
empty(input.size, **like_options(input, options))
|
391
400
|
end
|
392
401
|
|
393
402
|
def full_like(input, fill_value, **options)
|
394
|
-
full(input.size, fill_value, like_options(input, options))
|
403
|
+
full(input.size, fill_value, **like_options(input, options))
|
395
404
|
end
|
396
405
|
|
397
406
|
def rand_like(input, **options)
|
398
|
-
rand(input.size, like_options(input, options))
|
407
|
+
rand(input.size, **like_options(input, options))
|
399
408
|
end
|
400
409
|
|
401
410
|
def randint_like(input, low, high = nil, **options)
|
@@ -404,15 +413,15 @@ module Torch
|
|
404
413
|
high = low
|
405
414
|
low = 0
|
406
415
|
end
|
407
|
-
randint(low, high, input.size, like_options(input, options))
|
416
|
+
randint(low, high, input.size, **like_options(input, options))
|
408
417
|
end
|
409
418
|
|
410
419
|
def randn_like(input, **options)
|
411
|
-
randn(input.size, like_options(input, options))
|
420
|
+
randn(input.size, **like_options(input, options))
|
412
421
|
end
|
413
422
|
|
414
423
|
def zeros_like(input, **options)
|
415
|
-
zeros(input.size, like_options(input, options))
|
424
|
+
zeros(input.size, **like_options(input, options))
|
416
425
|
end
|
417
426
|
|
418
427
|
private
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/nn/module.rb
CHANGED
@@ -79,11 +79,19 @@ module Torch
|
|
79
79
|
_apply(convert)
|
80
80
|
end
|
81
81
|
|
82
|
-
def call(*input)
|
83
|
-
forward(*input)
|
82
|
+
def call(*input, **kwargs)
|
83
|
+
forward(*input, **kwargs)
|
84
84
|
end
|
85
85
|
|
86
|
-
def state_dict
|
86
|
+
def state_dict(destination: nil)
|
87
|
+
destination ||= {}
|
88
|
+
named_parameters.each do |k, v|
|
89
|
+
destination[k] = v
|
90
|
+
end
|
91
|
+
destination
|
92
|
+
end
|
93
|
+
|
94
|
+
def load_state_dict(state_dict)
|
87
95
|
raise NotImplementedYet
|
88
96
|
end
|
89
97
|
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2020-01-11 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '
|
19
|
+
version: '2.2'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: '
|
26
|
+
version: '2.2'
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
28
|
name: bundler
|
29
29
|
requirement: !ruby/object:Gem::Requirement
|
@@ -261,7 +261,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
261
261
|
- !ruby/object:Gem::Version
|
262
262
|
version: '0'
|
263
263
|
requirements: []
|
264
|
-
rubygems_version: 3.
|
264
|
+
rubygems_version: 3.1.2
|
265
265
|
signing_key:
|
266
266
|
specification_version: 4
|
267
267
|
summary: Deep learning for Ruby, powered by LibTorch
|