red-chainer 0.4.0 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/chainer.rb +2 -0
- data/lib/chainer/functions/connection/embed_id.rb +49 -0
- data/lib/chainer/links/connection/embed_id.rb +23 -0
- data/lib/chainer/serializers/marshal.rb +2 -2
- data/lib/chainer/utils/array.rb +21 -17
- data/lib/chainer/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d05c5d3cb55a9e6c7e45afffd5483b9968959749d911a2e6292f5bea1e90ef40
|
4
|
+
data.tar.gz: fcdbb11f8b64a3f54a68a1629c40e9ffce774b551e79d21689757f1a495a7d66
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e0a1423530e8b62dacd6ae876957cd00e604a659815dd1bd2e3ce6a7c30d3f22f68f718451b261e559fb3719211067a79c78a0b1f229e96c96db4dc7d3a7004a
|
7
|
+
data.tar.gz: da1bc06481526f5424705a8d68f0c4a1e98decaf30c81e217c5390ab96c4d4850eeecdec7cf88287c4581cc5556bbd24adbf26c3007145b9bf9e007e40047ae1
|
data/lib/chainer.rb
CHANGED
@@ -22,6 +22,7 @@ require 'chainer/initializers/uniform'
|
|
22
22
|
require 'chainer/iterators/serial_iterator'
|
23
23
|
require 'chainer/link'
|
24
24
|
require 'chainer/links/connection/convolution_2d'
|
25
|
+
require 'chainer/links/connection/embed_id'
|
25
26
|
require 'chainer/links/connection/linear'
|
26
27
|
require 'chainer/links/normalization/batch_normalization'
|
27
28
|
require 'chainer/links/model/classifier'
|
@@ -56,6 +57,7 @@ require 'chainer/functions/loss/softmax_cross_entropy'
|
|
56
57
|
require 'chainer/functions/connection/convolution_2d'
|
57
58
|
require 'chainer/functions/connection/deconvolution_2d'
|
58
59
|
require 'chainer/functions/connection/convolution_2d_grad_w'
|
60
|
+
require 'chainer/functions/connection/embed_id'
|
59
61
|
require 'chainer/functions/connection/linear'
|
60
62
|
require 'chainer/functions/noise/dropout'
|
61
63
|
require 'chainer/functions/normalization/batch_normalization'
|
data/lib/chainer/functions/connection/embed_id.rb
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
module Chainer
  module Functions
    module Connection
      # Looks up rows of the matrix +w+ using the integer IDs in +x+.
      #
      # Each ID in +x+ selects one row of +w+ (via Utils::Array.take along axis 0),
      # so the output is shaped like +x+ with one extra trailing axis of size
      # +w.shape.last+. When +ignore_label+ is set, every position whose ID equals
      # it produces a zero vector and contributes no gradient to +w+.
      class EmbedIDFunction < Chainer::Function
        # @param [Integer, nil] ignore_label ID whose output vector is forced to zero.
        def initialize(ignore_label: nil)
          @ignore_label = ignore_label
        end

        # Builds the function and applies it to +x+ and +w+.
        #
        # @param x ID array.
        # @param w Matrix whose rows are looked up.
        # @param [Integer, nil] ignore_label ID to be ignored (see #initialize).
        def self.embed_id(x, w, ignore_label: nil)
          self.new(ignore_label: ignore_label).(x, w)
        end

        # @param inputs [Array] [x, w]
        # @return [Array] one-element array holding the looked-up rows.
        def forward(inputs)
          x, w = inputs

          # Fast paths: no ignore label, or no position actually uses it.
          return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

          valid_x = x.ne(@ignore_label)
          return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

          # Replace ignored IDs with 0 so the lookup never indexes with the
          # ignore label itself; those rows are zeroed right after.
          x *= valid_x
          y = Chainer::Utils::Array.take(w, x, axis: 0).dup

          # One row per ID. x.size also works for a 0-dim x, where the product of
          # y's leading dimensions would be taken over an empty list.
          y = y.reshape(x.size, true)
          # where2 returns [indices of true bits, indices of false bits];
          # the false bits are the ignored positions.
          y[valid_x.where2.last, true] = 0

          [y.reshape(*x.shape, true)]
        end

        # @param inputs [Array] [x, w]
        # @param grad_outputs [Array] gradient w.r.t. the forward output.
        # @return [Array] [nil, gradient w.r.t. w] (IDs are not differentiable).
        def backward(inputs, grad_outputs)
          x, w = inputs
          gy = grad_outputs[0].reshape(x.size, true)
          gw = w.class.zeros(w.shape).reshape(w.shape.take(w.shape.size - 1).reduce(&:*), true)

          # Scatter-add row by row: the same ID may appear several times in x,
          # so its gradients must be accumulated rather than assigned.
          x.reshape(x.size).each_with_index do |ix, i|
            next if ix == @ignore_label

            gw[ix, true] += gy[i, true]
          end

          [nil, gw.reshape(*w.shape)]
        end
      end
    end
  end
end
|
data/lib/chainer/links/connection/embed_id.rb
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
module Chainer
  module Links
    module Connection
      # Link that owns a lookup matrix +w+ and applies EmbedIDFunction to ID arrays.
      class EmbedID < ::Chainer::Link
        attr_reader :w

        # @param [Integer] in_size Number of rows of +w+.
        # @param [Integer] out_size Number of columns of +w+.
        # @param initial_w Initializer for +w+; a Normal initializer with
        #   scale 1.0 is used when nil.
        # @param [Integer, nil] ignore_label ID mapped to a zero vector.
        def initialize(in_size, out_size, initial_w: nil, ignore_label: nil)
          super()
          @ignore_label = ignore_label

          initializer = initial_w || Chainer::Initializers::Normal.new(scale: 1.0)
          init_scope do
            @w = Chainer::Parameter.new(initializer: initializer, shape: [in_size, out_size])
          end
        end

        # Looks up the rows of +w+ selected by the IDs in +x+.
        def call(x)
          embed_fn = Chainer::Functions::Connection::EmbedIDFunction
          embed_fn.embed_id(x, @w, ignore_label: @ignore_label)
        end
      end
    end
  end
end
|
data/lib/chainer/serializers/marshal.rb
CHANGED
@@ -46,9 +46,9 @@ module Chainer
|
|
46
46
|
#
|
47
47
|
# @param [string] filename Name of the file to be loaded.
|
48
48
|
# @param [object] obj Object to be deserialized. It must support serialization protocol.
|
49
|
-
# @param [String] path Forwarded to the deserializer's constructor
#   (presumably a key prefix inside the loaded data; confirm against #initialize).
# @param [Boolean] strict Forwarded to the deserializer's constructor.
# @return whatever the deserializer's #load returns for +obj+.
# NOTE: Marshal.load can instantiate arbitrary objects; only load trusted files.
def self.load_file(filename, obj, path: '', strict: true)
  # Binary mode: Marshal data must not go through text-mode newline/EOF
  # translation, which would corrupt it on some platforms (e.g. Windows).
  File.open(filename, 'rb') do |f|
    d = self.new(Marshal.load(f), path: path, strict: strict)
    d.load(obj)
  end
end
|
data/lib/chainer/utils/array.rb
CHANGED
@@ -18,28 +18,32 @@ module Chainer
|
|
18
18
|
end
|
19
19
|
end
|
20
20
|
|
21
|
-
# Enumerates every index tuple of an array with the given +shape+, in
# row-major order (the last axis varies fastest), like numpy.ndindex.
#
# @param [Array<Integer>] shape
# @return [Array<Array<Integer>>] e.g. ndindex([2, 2]) #=> [[0, 0], [0, 1], [1, 0], [1, 1]].
#   An empty shape yields a single empty index: ndindex([]) #=> [[]].
def self.ndindex(shape)
  # Row-major strides, computed once instead of per element and per axis.
  strides = shape.each_index.map { |j| shape.drop(j + 1).reduce(1, :*) }
  # reduce(1, :*) so an empty shape gives 1 element rather than nil.
  shape.reduce(1, :*).times.map do |flat|
    shape.each_index.map { |j| (flat / strides[j]) % shape[j] }
  end
end
|
27
28
|
|
28
|
-
# Takes elements of +x+ at +indices+, optionally along a single axis
# (in the spirit of numpy.take).
#
# Without +axis+ this is plain +x[indices]+. With +axis+, the result is shaped
# x.shape[0...axis] + indices.shape + x.shape[axis+1..]; scalar or 1-D
# +indices+ are passed straight to Numo indexing on that axis.
#
# @param x [Numo::NArray] source array.
# @param indices [Integer, Array, Numo::NArray] integer indices.
# @param axis [Integer, nil] axis to take along; negative values count from the end.
# @return [Numo::NArray]
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  ndim = x.shape.size
  axis += ndim if axis < 0
  leading = ::Array.new(axis, true)
  trailing = ::Array.new(ndim - axis - 1, true)

  indices_narray = Numo::Int32.cast(indices)
  # Scalar / 1-D indices: Numo indexing on the axis already yields the right shape.
  return x[*leading, indices, *trailing] if indices_narray.shape.size <= 1

  # Multi-dimensional indices: keep the leading dimensions of x (previously
  # dropped for axis > 0) and fill the result one index position at a time.
  y = x.class.zeros(*x.shape.take(axis), *indices_narray.shape, *x.shape.drop(axis + 1))
  self.ndindex(indices_narray.shape).each do |ndidx|
    y[*leading, *ndidx, *trailing] = x[*leading, indices_narray[*ndidx], *trailing]
  end
  y
end
|
45
49
|
|
data/lib/chainer/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: red-chainer
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.0
|
4
|
+
version: 0.4.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Yusaku Hatanaka
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-04-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -145,6 +145,7 @@ files:
|
|
145
145
|
- lib/chainer/functions/connection/convolution_2d.rb
|
146
146
|
- lib/chainer/functions/connection/convolution_2d_grad_w.rb
|
147
147
|
- lib/chainer/functions/connection/deconvolution_2d.rb
|
148
|
+
- lib/chainer/functions/connection/embed_id.rb
|
148
149
|
- lib/chainer/functions/connection/linear.rb
|
149
150
|
- lib/chainer/functions/evaluation/accuracy.rb
|
150
151
|
- lib/chainer/functions/loss/mean_squared_error.rb
|
@@ -169,6 +170,7 @@ files:
|
|
169
170
|
- lib/chainer/iterators/serial_iterator.rb
|
170
171
|
- lib/chainer/link.rb
|
171
172
|
- lib/chainer/links/connection/convolution_2d.rb
|
173
|
+
- lib/chainer/links/connection/embed_id.rb
|
172
174
|
- lib/chainer/links/connection/linear.rb
|
173
175
|
- lib/chainer/links/model/classifier.rb
|
174
176
|
- lib/chainer/links/normalization/batch_normalization.rb
|