tensor_stream 1.0.0 → 1.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -11,7 +11,7 @@ module TensorStream
11
11
  @options = options
12
12
  @is_const = true
13
13
  @internal = options[:internal]
14
- @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join('/')
14
+ @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join("/")
15
15
  @given_name = @name
16
16
 
17
17
  if options[:value]
@@ -42,4 +42,4 @@ module TensorStream
42
42
  "Const"
43
43
  end
44
44
  end
45
- end
45
+ end
@@ -8,7 +8,7 @@ module TensorStream
8
8
  @options = options
9
9
  @operation = :"flow_#{flow_type}"
10
10
  @inputs = inputs
11
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
11
+ @name = [@graph.get_name_scope, options[:name] || set_name].compact.join("/")
12
12
  @ops = ops
13
13
  @consumers = Set.new
14
14
  @shape = TensorShape.new([inputs.size])
@@ -12,7 +12,7 @@ module TensorStream
12
12
  next input if input.is_a?(Variable)
13
13
 
14
14
  if input.is_a?(Tensor) && TensorStream::Ops::FLOATING_POINT_TYPES.include?(input.data_type)
15
- TensorStream.check_numerics(input, "#{node.name}/#{input.name}", name: "check/#{node.name}/#{input.name}" )
15
+ TensorStream.check_numerics(input, "#{node.name}/#{input.name}", name: "check/#{node.name}/#{input.name}")
16
16
  else
17
17
  input
18
18
  end
@@ -20,4 +20,4 @@ module TensorStream
20
20
  end
21
21
  end
22
22
  end
23
- end
23
+ end
@@ -12,7 +12,7 @@ module TensorStream
12
12
 
13
13
  @consumers = Set.new
14
14
  @data_type = Tensor.detect_type(inputs[1])
15
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
15
+ @name = [@graph.get_name_scope, options[:name] || set_name].compact.join("/")
16
16
  @ops = ops
17
17
  @shape = TensorShape.new(nil)
18
18
  @graph.add_node(self)
@@ -26,4 +26,4 @@ module TensorStream
26
26
  eval
27
27
  end
28
28
  end
29
- end
29
+ end
@@ -34,20 +34,20 @@ module TensorStream
34
34
  ##
35
35
  # Query all supported devices
36
36
  def self.query_supported_devices
37
- [Device.new('cpu', :cpu, self)]
37
+ [Device.new("cpu", :cpu, self)]
38
38
  end
39
39
 
40
40
  ##
41
41
  # Select the best device available in the system for this evaluator
42
42
  def self.default_device
43
- Device.new('cpu', :cpu, self)
43
+ Device.new("cpu", :cpu, self)
44
44
  end
45
45
 
46
46
  ##
47
47
  # Selects the best device with the specified query, query can
48
48
  # be evaluator specific
49
49
  def self.fetch_device(_query = [])
50
- Device.new('cpu', :cpu, self)
50
+ Device.new("cpu", :cpu, self)
51
51
  end
52
52
 
53
53
  ##
@@ -56,12 +56,12 @@ module TensorStream
56
56
  return default_device if query.nil? || query == :default
57
57
 
58
58
  all_devices = query_supported_devices
59
- substrs = query.split('/')
59
+ substrs = query.split("/")
60
60
  substrs.each do |q|
61
- components = q.split(':')
61
+ components = q.split(":")
62
62
  next if components.size.zero?
63
63
 
64
- if components[0] == 'device' # use tensorflow convention
64
+ if components[0] == "device" # use tensorflow convention
65
65
  device_type = components[1]
66
66
  select_index = components[2].to_i
67
67
 
@@ -79,7 +79,7 @@ module TensorStream
79
79
 
80
80
  select_index = [devices.size - 1, select_index].min
81
81
  return devices[select_index]
82
- elsif components[0] == 'ts' # tensorstream specific
82
+ elsif components[0] == "ts" # tensorstream specific
83
83
  evaluator_class = TensorStream::Evaluator.evaluators[components[1]][:class]
84
84
  return nil unless self == evaluator_class
85
85
  return evaluator_class.fetch_device(components[2..components.size]) if evaluator_class.respond_to?(:fetch_device)
@@ -95,10 +95,10 @@ module TensorStream
95
95
  @ops ||= {}
96
96
  if opcode.is_a?(Array)
97
97
  opcode.each do |op|
98
- @ops[op.to_sym] = { options: options, block: block }
98
+ @ops[op.to_sym] = {options: options, block: block}
99
99
  end
100
100
  else
101
- @ops[opcode.to_sym] = { options: options, block: block }
101
+ @ops[opcode.to_sym] = {options: options, block: block}
102
102
  end
103
103
  end
104
104
 
@@ -115,7 +115,7 @@ module TensorStream
115
115
  op = self.class.ops[tensor.operation.to_sym]
116
116
  op_options = op[:options]
117
117
 
118
- resolved_inputs = tensor.inputs.map do |i|
118
+ resolved_inputs = tensor.inputs.map { |i|
119
119
  next if i.nil?
120
120
  next i if op_options[:noop]
121
121
 
@@ -124,25 +124,25 @@ module TensorStream
124
124
  end
125
125
 
126
126
  global_eval(tensor, i, execution_context, op_options)
127
- end
127
+ }
128
128
 
129
129
  start_time = if profile_enabled?
130
- time = Time.now
131
- time.to_i * (10**9) + time.nsec
132
- end
130
+ time = Time.now
131
+ time.to_i * (10**9) + time.nsec
132
+ end
133
133
 
134
134
  instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do
135
135
  if profile_enabled?
136
136
  time = Time.now
137
137
  end_time = time.to_i * (10**9) + time.nsec
138
- @context[:profile] ||= { step: 0, operations: {} }
138
+ @context[:profile] ||= {step: 0, operations: {}}
139
139
  @context[:profile][:step] += 1
140
- @context[:profile][:operations][tensor.name] = { op: tensor.operation,
140
+ @context[:profile][:operations][tensor.name] = {op: tensor.operation,
141
141
  step: @context[:profile][:step],
142
142
  eval_time: end_time - start_time,
143
143
  shape: tensor.shape ? tensor.shape.shape : nil,
144
144
  data_type: tensor.data_type,
145
- tensor: tensor }
145
+ tensor: tensor,}
146
146
  end
147
147
  end
148
148
  end
@@ -222,7 +222,7 @@ module TensorStream
222
222
 
223
223
  def self.register_evaluator(klass, name, index = 0)
224
224
  @evaluators ||= {}
225
- @evaluators[name] = { name: name, class: klass, index: index }
225
+ @evaluators[name] = {name: name, class: klass, index: index}
226
226
  end
227
227
 
228
228
  def self.default_evaluators
@@ -12,4 +12,4 @@ module TensorStream
12
12
  buffer
13
13
  end
14
14
  end
15
- end
15
+ end
@@ -1,5 +1,5 @@
1
- require 'tensor_stream/evaluator/ruby_evaluator'
2
- require 'tensor_stream/evaluator/buffer'
1
+ require "tensor_stream/evaluator/ruby_evaluator"
2
+ require "tensor_stream/evaluator/buffer"
3
3
 
4
4
  module TensorStream
5
5
  module Evaluator
@@ -16,10 +16,10 @@ module TensorStream
16
16
  start_index = start.shift
17
17
  current_size = size.shift
18
18
  dimen_size = if current_size == -1
19
- input.size - 1
20
- else
21
- start_index + current_size - 1
22
- end
19
+ input.size - 1
20
+ else
21
+ start_index + current_size - 1
22
+ end
23
23
 
24
24
  input[start_index..dimen_size].collect do |item|
25
25
  if item.is_a?(Array)
@@ -87,9 +87,9 @@ module TensorStream
87
87
  d = dims.shift
88
88
 
89
89
  if input.is_a?(Array) && (get_rank(input) - 1) == dims.size
90
- row_to_dup = input.collect do |item|
90
+ row_to_dup = input.collect { |item|
91
91
  broadcast_dimensions(item, dims.dup)
92
- end
92
+ }
93
93
 
94
94
  row_to_dup + Array.new(d) { row_to_dup }.flatten(1)
95
95
  elsif input.is_a?(Array)
@@ -102,15 +102,15 @@ module TensorStream
102
102
  # handle 2 tensor math operations
103
103
  def vector_op(vector, vector2, switch = false, safe = true, &block)
104
104
  if get_rank(vector) < get_rank(vector2) # upgrade rank of A
105
- duplicated = Array.new(vector2.size) do
105
+ duplicated = Array.new(vector2.size) {
106
106
  vector
107
- end
107
+ }
108
108
  return vector_op(duplicated, vector2, switch, &block)
109
109
  end
110
110
 
111
111
  return yield(vector, vector2) unless vector.is_a?(Array)
112
112
 
113
- vector.each_with_index.collect do |input, index|
113
+ vector.each_with_index.collect { |input, index|
114
114
  next vector_op(input, vector2, switch, &block) if input.is_a?(Array) && get_rank(vector) > get_rank(vector2)
115
115
 
116
116
  if safe && vector2.is_a?(Array)
@@ -118,22 +118,22 @@ module TensorStream
118
118
  end
119
119
 
120
120
  z = if vector2.is_a?(Array)
121
- if index < vector2.size
122
- vector2[index]
123
- else
124
- raise 'incompatible tensor shapes used during op' if vector2.size != 1
125
- vector2[0]
126
- end
127
- else
128
- vector2
129
- end
121
+ if index < vector2.size
122
+ vector2[index]
123
+ else
124
+ raise "incompatible tensor shapes used during op" if vector2.size != 1
125
+ vector2[0]
126
+ end
127
+ else
128
+ vector2
129
+ end
130
130
 
131
131
  if input.is_a?(Array)
132
132
  vector_op(input, z, switch, &block)
133
133
  else
134
134
  switch ? yield(z, input) : yield(input, z)
135
135
  end
136
- end.compact
136
+ }.compact
137
137
  end
138
138
 
139
139
  def shape_diff(shape_a, shape_b)
@@ -142,11 +142,11 @@ module TensorStream
142
142
  reversed_a = shape_a.reverse
143
143
  reversed_b = shape_b.reverse
144
144
 
145
- reversed_a.each_with_index.collect do |s, index|
145
+ reversed_a.each_with_index.collect { |s, index|
146
146
  next s if index >= reversed_b.size
147
147
  return nil if reversed_b[index] > s
148
148
  s - reversed_b[index]
149
- end.reverse
149
+ }.reverse
150
150
  end
151
151
 
152
152
  def tile_arr(input, dimen, multiples)
@@ -155,9 +155,9 @@ module TensorStream
155
155
  return nil if t.zero?
156
156
  input * t # ruby array dup
157
157
  else
158
- new_arr = input.collect do |sub|
158
+ new_arr = input.collect { |sub|
159
159
  tile_arr(sub, dimen + 1, multiples)
160
- end.compact
160
+ }.compact
161
161
 
162
162
  return nil if new_arr.empty?
163
163
 
@@ -234,13 +234,13 @@ module TensorStream
234
234
  # general case transposition with flat arrays
235
235
  def transpose_with_perm(arr, new_arr, shape, new_shape, perm)
236
236
  arr_size = shape.reduce(:*)
237
- divisors = shape.dup.drop(1).reverse.inject([1]) do |a, s|
237
+ divisors = shape.dup.drop(1).reverse.inject([1]) { |a, s|
238
238
  a << s * a.last
239
- end.reverse
239
+ }.reverse
240
240
 
241
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
241
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
242
242
  a << s * a.last
243
- end.reverse
243
+ }.reverse
244
244
 
245
245
  arr_size.times do |p|
246
246
  ptr = p
@@ -267,19 +267,19 @@ module TensorStream
267
267
  def reduce_axis(current_axis, axis, val, keep_dims, &block)
268
268
  return val unless val.is_a?(Array)
269
269
 
270
- r = val.collect do |v|
270
+ r = val.collect { |v|
271
271
  reduce_axis(current_axis + 1, axis, v, keep_dims, &block)
272
- end
272
+ }
273
273
 
274
274
  should_reduce_axis = axis.nil? || (axis.is_a?(Array) && axis.include?(current_axis)) || (current_axis == axis)
275
275
 
276
276
  if should_reduce_axis
277
277
  reduced_val = r[0]
278
278
  if r.size > 1
279
- if block_given?
280
- reduced_val = yield(r[0..val.size])
279
+ reduced_val = if block_given?
280
+ yield(r[0..val.size])
281
281
  else
282
- reduced_val = r[0..val.size].reduce(:+)
282
+ r[0..val.size].reduce(:+)
283
283
  end
284
284
  elsif r.empty?
285
285
  reduced_val = yield(nil)
@@ -292,17 +292,17 @@ module TensorStream
292
292
 
293
293
  def reduce(val, axis, keep_dims, &block)
294
294
  rank = get_rank(val)
295
- return val if axis && axis.is_a?(Array) && axis.empty?
295
+ return val if axis&.is_a?(Array) && axis&.empty?
296
296
 
297
297
  axis = if axis.nil?
298
- nil
299
- elsif axis.is_a?(Array)
300
- return val if axis.empty?
298
+ nil
299
+ elsif axis.is_a?(Array)
300
+ return val if axis.empty?
301
301
 
302
- axis.map { |a| a < 0 ? rank - a.abs : a }
303
- else
304
- axis < 0 ? rank - axis.abs : axis
305
- end
302
+ axis.map { |a| a < 0 ? rank - a.abs : a }
303
+ else
304
+ axis < 0 ? rank - axis.abs : axis
305
+ end
306
306
 
307
307
  reduce_axis(0, axis, val, keep_dims, &block)
308
308
  end
@@ -332,4 +332,4 @@ module TensorStream
332
332
  end
333
333
  end
334
334
  end
335
- end
335
+ end
@@ -6,4 +6,4 @@ module TensorStream
6
6
  1 / (1 + Math.exp(-val))
7
7
  end
8
8
  end
9
- end
9
+ end
@@ -1,6 +1,6 @@
1
1
  module TensorStream
2
2
  module ArrayOps
3
- def ArrayOps.included(klass)
3
+ def self.included(klass)
4
4
  klass.class_eval do
5
5
  register_op :slice do |context, tensor, inputs|
6
6
  input = inputs[0]
@@ -41,17 +41,17 @@ module TensorStream
41
41
  new_shape = [inputs.size]
42
42
  shape.inject(new_shape) { |ns, s| ns << s }
43
43
 
44
- divisors = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
44
+ divisors = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
45
45
  a << s * a.last
46
- end.reverse
46
+ }.reverse
47
47
 
48
48
  axis = rank + axis if axis < 0
49
49
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
50
50
  new_shape = rotated_shape.rotate! + new_shape
51
51
 
52
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
52
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
53
53
  a << s * a.last
54
- end.reverse
54
+ }.reverse
55
55
 
56
56
  inputs.each_with_index do |input, index|
57
57
  raw_input = input.is_a?(Array) ? input.flatten : [input]
@@ -85,18 +85,18 @@ module TensorStream
85
85
  new_shape = shape_eval(inputs[0])
86
86
  rank = new_shape.size - 1
87
87
 
88
- divisors = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
88
+ divisors = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
89
89
  a << s * a.last
90
- end.reverse
90
+ }.reverse
91
91
 
92
92
  axis = rank + axis if axis < 0
93
93
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
94
94
  new_shape = rotated_shape.rotate!(-1) + new_shape
95
95
  output_buffer = Array.new(new_shape.reduce(:*)) { 0 }
96
96
 
97
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
97
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
98
98
  a << s * a.last
99
- end.reverse
99
+ }.reverse
100
100
 
101
101
  inputs.each_with_index do |input, index|
102
102
  raw_input = input.is_a?(Array) ? input.flatten : [input]
@@ -249,16 +249,16 @@ module TensorStream
249
249
 
250
250
  register_op %i[zeros ones zeros_like ones_like] do |_context, tensor, inputs|
251
251
  shape = if %i[zeros_like ones_like].include?(tensor.operation)
252
- shape_eval(inputs[0])
253
- else
254
- inputs[0] || tensor.shape.shape
255
- end
252
+ shape_eval(inputs[0])
253
+ else
254
+ inputs[0] || tensor.shape.shape
255
+ end
256
256
 
257
257
  func = if %i[zeros zeros_like].include?(tensor.operation)
258
- -> { int_type?(tensor.data_type) ? 0 : 0.0 }
259
- else
260
- -> { int_type?(tensor.data_type) ? 1 : 1.0 }
261
- end
258
+ -> { int_type?(tensor.data_type) ? 0 : 0.0 }
259
+ else
260
+ -> { int_type?(tensor.data_type) ? 1 : 1.0 }
261
+ end
262
262
  if shape.is_a?(Array) && shape.size.zero?
263
263
  func.call
264
264
  else
@@ -288,23 +288,23 @@ module TensorStream
288
288
 
289
289
  value_shape = shape_eval(value)
290
290
  res = if num_split.is_a?(Array)
291
- begin_index = 0
292
- num_split.collect do |num|
293
- end_index = begin_index + num
294
- arr = split_tensor(value, begin_index, end_index, axis)
295
- begin_index = end_index
296
- arr
297
- end
298
- else
299
- raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
-
301
- piece_sizes = value_shape[axis] / num_split
302
- Array.new(num_split) do |num|
303
- begin_index = num * piece_sizes
304
- end_index = begin_index + piece_sizes
305
- split_tensor(value, begin_index, end_index, axis)
306
- end
307
- end
291
+ begin_index = 0
292
+ num_split.collect do |num|
293
+ end_index = begin_index + num
294
+ arr = split_tensor(value, begin_index, end_index, axis)
295
+ begin_index = end_index
296
+ arr
297
+ end
298
+ else
299
+ raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
+
301
+ piece_sizes = value_shape[axis] / num_split
302
+ Array.new(num_split) do |num|
303
+ begin_index = num * piece_sizes
304
+ end_index = begin_index + piece_sizes
305
+ split_tensor(value, begin_index, end_index, axis)
306
+ end
307
+ end
308
308
  TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
309
309
  end
310
310
 
@@ -326,7 +326,7 @@ module TensorStream
326
326
  register_op :tile do |_context, _tensor, inputs|
327
327
  input, multiples = inputs
328
328
  rank = get_rank(input)
329
- raise '1D or higher tensor required' if rank.zero?
329
+ raise "1D or higher tensor required" if rank.zero?
330
330
  raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
331
331
 
332
332
  tile = tile_arr(input, 0, multiples)
@@ -343,9 +343,9 @@ module TensorStream
343
343
  end
344
344
 
345
345
  register_op :shape_n do |_context, tensor, inputs|
346
- shapes = inputs.collect do |input|
346
+ shapes = inputs.collect { |input|
347
347
  shape_eval(input)
348
- end
348
+ }
349
349
  TensorStream::Evaluator::OutputGroup.new(shapes, shapes.map { tensor.options[:out_type] })
350
350
  end
351
351
 
@@ -372,7 +372,7 @@ module TensorStream
372
372
 
373
373
  if tensor.options[:exclusive]
374
374
  p_true = pred.each_with_index.collect { |p, index| [p, index] }.select { |a| a[0] }
375
- raise TensorStream::ValueError, "more than one predicate returns true pos #{p_true.map { |a| a[1] }.join(',')}" if p_true.size > 1
375
+ raise TensorStream::ValueError, "more than one predicate returns true pos #{p_true.map { |a| a[1] }.join(",")}" if p_true.size > 1
376
376
  end
377
377
 
378
378
  pred.each_with_index do |p, index|
@@ -412,4 +412,4 @@ module TensorStream
412
412
  end
413
413
  end
414
414
  end
415
- end
415
+ end