red-chainer 0.4.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e7ed2df404bfc36275381f523c0439f2e73debcf36f3b5edb063e985502d7a70
4
- data.tar.gz: 357c983134aae985808568113d3f4f82bacebf6d25e5cf7c4f9197b1825455dc
3
+ metadata.gz: d05c5d3cb55a9e6c7e45afffd5483b9968959749d911a2e6292f5bea1e90ef40
4
+ data.tar.gz: fcdbb11f8b64a3f54a68a1629c40e9ffce774b551e79d21689757f1a495a7d66
5
5
  SHA512:
6
- metadata.gz: 40eb83d14d6efd140a4cb9748f04f50cfa325c9831d8020890a20fe88fc1485547f4dcab48cdcadfda317b46b3f4a6bc936eb8204ae39a876e053878caa7359f
7
- data.tar.gz: af4133b975c5b4b5ca6e2ce9fb05eddd2b1de5a8a30df9c776531a5acdcf5bc4d8322dc7d6875c49800587a4d98031d0eb62054dbd87ced964093c501da32c95
6
+ metadata.gz: e0a1423530e8b62dacd6ae876957cd00e604a659815dd1bd2e3ce6a7c30d3f22f68f718451b261e559fb3719211067a79c78a0b1f229e96c96db4dc7d3a7004a
7
+ data.tar.gz: da1bc06481526f5424705a8d68f0c4a1e98decaf30c81e217c5390ab96c4d4850eeecdec7cf88287c4581cc5556bbd24adbf26c3007145b9bf9e007e40047ae1
@@ -22,6 +22,7 @@ require 'chainer/initializers/uniform'
22
22
  require 'chainer/iterators/serial_iterator'
23
23
  require 'chainer/link'
24
24
  require 'chainer/links/connection/convolution_2d'
25
+ require 'chainer/links/connection/embed_id'
25
26
  require 'chainer/links/connection/linear'
26
27
  require 'chainer/links/normalization/batch_normalization'
27
28
  require 'chainer/links/model/classifier'
@@ -56,6 +57,7 @@ require 'chainer/functions/loss/softmax_cross_entropy'
56
57
  require 'chainer/functions/connection/convolution_2d'
57
58
  require 'chainer/functions/connection/deconvolution_2d'
58
59
  require 'chainer/functions/connection/convolution_2d_grad_w'
60
+ require 'chainer/functions/connection/embed_id'
59
61
  require 'chainer/functions/connection/linear'
60
62
  require 'chainer/functions/noise/dropout'
61
63
  require 'chainer/functions/normalization/batch_normalization'
@@ -0,0 +1,49 @@
module Chainer
  module Functions
    module Connection
      # Looks up rows of an embedding matrix +w+ by integer ID.
      #
      # Forward gathers rows of +w+ along axis 0 via Chainer::Utils::Array.take.
      # When +ignore_label+ is set, entries of +x+ equal to it yield all-zero
      # output rows and contribute no gradient to +w+.
      class EmbedIDFunction < Chainer::Function
        # @param [Integer, nil] ignore_label ID whose output rows are zeroed and
        #   whose gradient is skipped; nil disables this behaviour.
        def initialize(ignore_label: nil)
          @ignore_label = ignore_label
        end

        # Builds the function and applies it to +x+ and +w+.
        #
        # @param x integer ID array
        # @param w embedding matrix
        # @param [Integer, nil] ignore_label see #initialize
        # @return the result of calling the function on (x, w)
        def self.embed_id(x, w, ignore_label: nil)
          self.new(ignore_label: ignore_label).(x, w)
        end

        # @param [Array] inputs two-element array [x, w]
        # @return [Array] one-element array holding the gathered rows
        def forward(inputs)
          x, w = inputs
          return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

          valid_x = x.ne(@ignore_label)
          return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

          # Ignored IDs are mapped to row 0 so the gather stays in bounds;
          # their rows are overwritten with zeros below. dup so the zeroing
          # can never write through a view into w.
          y = Chainer::Utils::Array.take(w, x * valid_x, axis: 0).dup
          out_shape = y.shape

          # Exactly one row per element of x, whatever the rank of w
          # (also correct for a 0-d x, where x.size == 1).
          rows = y.reshape(x.size, true)
          valid_x.where2.last.each { |i| rows[i, true] = 0 }

          [rows.reshape(*out_shape)]
        end

        # @param [Array] inputs two-element array [x, w]
        # @param [Array] grad_outputs one-element array holding the output gradient
        # @return [Array] [nil, gradient w.r.t. w]; x receives no gradient
        def backward(inputs, grad_outputs)
          x, w = inputs
          # One gradient row per element of x ...
          gy = grad_outputs[0].reshape(x.size, true)
          # ... scattered into one accumulator row per embedding row of w.
          gw = w.class.zeros(w.shape).reshape(w.shape[0], true)

          x.reshape(x.size).each_with_index do |ix, i|
            next if ix == @ignore_label

            # Accumulate (not assign) so repeated IDs sum their gradients.
            gw[ix, true] = gw[ix, true] + gy[i, true]
          end

          [nil, gw.reshape(*w.shape)]
        end
      end
    end
  end
end
@@ -0,0 +1,23 @@
module Chainer
  module Links
    module Connection
      # Link owning an embedding matrix +w+ of shape [in_size, out_size];
      # calling it maps integer IDs to rows of +w+ through EmbedIDFunction.
      class EmbedID < ::Chainer::Link
        attr_reader :w

        # @param [Integer] in_size number of rows of +w+
        # @param [Integer] out_size number of columns of +w+
        # @param initial_w initializer for +w+; when nil,
        #   Chainer::Initializers::Normal.new(scale: 1.0) is used
        # @param [Integer, nil] ignore_label forwarded to EmbedIDFunction
        def initialize(in_size, out_size, initial_w: nil, ignore_label: nil)
          super()
          @ignore_label = ignore_label

          init_scope do
            initializer = initial_w || Chainer::Initializers::Normal.new(scale: 1.0)
            weight_shape = [in_size, out_size]
            @w = Chainer::Parameter.new(initializer: initializer, shape: weight_shape)
          end
        end

        # @param x integer ID array
        # @return the result of EmbedIDFunction.embed_id on +x+ and +w+
        def call(x)
          embed = Chainer::Functions::Connection::EmbedIDFunction
          embed.embed_id(x, @w, ignore_label: @ignore_label)
        end
      end
    end
  end
end
@@ -46,9 +46,9 @@ module Chainer
#
# @param [string] filename Name of the file to be loaded.
# @param [object] obj Object to be deserialized. It must support serialization protocol.
# @param [String] path Forwarded to the deserializer's constructor as +path:+
#   (presumably the key prefix under which obj's data is stored — confirm).
# @param [Boolean] strict Forwarded to the constructor as +strict:+
#   (presumably whether missing entries raise — confirm).
# @return the value returned by the deserializer's #load
def self.load_file(filename, obj, path: '', strict: true)
  # Marshal output is binary: open in binary mode so platforms with
  # text-mode I/O do not translate newlines or stop at a Ctrl-Z byte.
  # NOTE(review): Marshal.load can instantiate arbitrary objects; only load trusted files.
  File.open(filename, 'rb') do |f|
    d = self.new(Marshal.load(f), path: path, strict: strict)
    d.load(obj)
  end
end
@@ -18,28 +18,32 @@ module Chainer
18
18
  end
19
19
  end
20
20
 
# Enumerates every N-dimensional index of +shape+ in row-major order,
# e.g. ndindex([2, 2]) #=> [[0, 0], [0, 1], [1, 0], [1, 1]].
#
# @param [Array<Integer>] shape
# @return [Array<Array<Integer>>] one index tuple per element; [[]] for an empty shape
def self.ndindex(shape)
  # strides[j] = number of flat positions covered by one step along axis j
  strides = shape.each_index.map { |j| shape.drop(j + 1).reduce(1, :*) }
  ::Array.new(shape.reduce(1, :*)) do |flat|
    shape.each_index.map { |j| (flat / strides[j]) % shape[j] }
  end
end

# Numpy-style take.
#
# With +axis+ nil, returns x[indices]. Otherwise +indices+ selects positions
# along +axis+ and the result has shape
# x.shape[0...axis] + indices.shape + x.shape[axis + 1..].
#
# @param x Numo::NArray source array
# @param indices Integer, Array or Numo integer array of positions
# @param [Integer, nil] axis axis to take along; negative values count from the end
# @return the selected elements
def self.take(x, indices, axis: nil)
  return x[indices] unless axis

  axis += x.shape.size if axis < 0
  selector = ::Array.new(x.shape.size, true)
  indices_narray = Numo::Int32.cast(indices)

  # Scalar or 1-D indices: a single fancy index along +axis+ suffices.
  if indices_narray.shape.size <= 1
    selector[axis] = indices
    return x[*selector]
  end

  # N-D indices: copy one slice of x per index tuple. Leading axes
  # (before +axis+) must be kept, otherwise axis > 0 is mis-shaped.
  leading = ::Array.new(axis, true)
  trailing = ::Array.new(x.shape.size - axis - 1, true)
  y = x.class.zeros(*x.shape.take(axis), *indices_narray.shape, *x.shape.drop(axis + 1))
  ndindex(indices_narray.shape).each do |ndidx|
    selector[axis] = indices_narray[*ndidx]
    y[*leading, *ndidx, *trailing] = x[*selector]
  end
  y
end
45
49
 
@@ -1,4 +1,4 @@
module Chainer
  # Version string of the red-chainer gem.
  VERSION = '0.4.1'
end
4
4
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-chainer
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Yusaku Hatanaka
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-03-28 00:00:00.000000000 Z
11
+ date: 2019-04-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -145,6 +145,7 @@ files:
145
145
  - lib/chainer/functions/connection/convolution_2d.rb
146
146
  - lib/chainer/functions/connection/convolution_2d_grad_w.rb
147
147
  - lib/chainer/functions/connection/deconvolution_2d.rb
148
+ - lib/chainer/functions/connection/embed_id.rb
148
149
  - lib/chainer/functions/connection/linear.rb
149
150
  - lib/chainer/functions/evaluation/accuracy.rb
150
151
  - lib/chainer/functions/loss/mean_squared_error.rb
@@ -169,6 +170,7 @@ files:
169
170
  - lib/chainer/iterators/serial_iterator.rb
170
171
  - lib/chainer/link.rb
171
172
  - lib/chainer/links/connection/convolution_2d.rb
173
+ - lib/chainer/links/connection/embed_id.rb
172
174
  - lib/chainer/links/connection/linear.rb
173
175
  - lib/chainer/links/model/classifier.rb
174
176
  - lib/chainer/links/normalization/batch_normalization.rb