tensor_stream 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -11,7 +11,7 @@ module TensorStream
11
11
  @options = options
12
12
  @is_const = true
13
13
  @internal = options[:internal]
14
- @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join('/')
14
+ @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join("/")
15
15
  @given_name = @name
16
16
 
17
17
  if options[:value]
@@ -42,4 +42,4 @@ module TensorStream
42
42
  "Const"
43
43
  end
44
44
  end
45
- end
45
+ end
@@ -8,7 +8,7 @@ module TensorStream
8
8
  @options = options
9
9
  @operation = :"flow_#{flow_type}"
10
10
  @inputs = inputs
11
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
11
+ @name = [@graph.get_name_scope, options[:name] || set_name].compact.join("/")
12
12
  @ops = ops
13
13
  @consumers = Set.new
14
14
  @shape = TensorShape.new([inputs.size])
@@ -12,7 +12,7 @@ module TensorStream
12
12
  next input if input.is_a?(Variable)
13
13
 
14
14
  if input.is_a?(Tensor) && TensorStream::Ops::FLOATING_POINT_TYPES.include?(input.data_type)
15
- TensorStream.check_numerics(input, "#{node.name}/#{input.name}", name: "check/#{node.name}/#{input.name}" )
15
+ TensorStream.check_numerics(input, "#{node.name}/#{input.name}", name: "check/#{node.name}/#{input.name}")
16
16
  else
17
17
  input
18
18
  end
@@ -20,4 +20,4 @@ module TensorStream
20
20
  end
21
21
  end
22
22
  end
23
- end
23
+ end
@@ -12,7 +12,7 @@ module TensorStream
12
12
 
13
13
  @consumers = Set.new
14
14
  @data_type = Tensor.detect_type(inputs[1])
15
- @name = [@graph.get_name_scope, options[:name] || set_name].compact.join('/')
15
+ @name = [@graph.get_name_scope, options[:name] || set_name].compact.join("/")
16
16
  @ops = ops
17
17
  @shape = TensorShape.new(nil)
18
18
  @graph.add_node(self)
@@ -26,4 +26,4 @@ module TensorStream
26
26
  eval
27
27
  end
28
28
  end
29
- end
29
+ end
@@ -34,20 +34,20 @@ module TensorStream
34
34
  ##
35
35
  # Query all supported devices
36
36
  def self.query_supported_devices
37
- [Device.new('cpu', :cpu, self)]
37
+ [Device.new("cpu", :cpu, self)]
38
38
  end
39
39
 
40
40
  ##
41
41
  # Select the best device available in the system for this evaluator
42
42
  def self.default_device
43
- Device.new('cpu', :cpu, self)
43
+ Device.new("cpu", :cpu, self)
44
44
  end
45
45
 
46
46
  ##
47
47
  # Selects the best device with the specified query, query can
48
48
  # be evaluator specific
49
49
  def self.fetch_device(_query = [])
50
- Device.new('cpu', :cpu, self)
50
+ Device.new("cpu", :cpu, self)
51
51
  end
52
52
 
53
53
  ##
@@ -56,12 +56,12 @@ module TensorStream
56
56
  return default_device if query.nil? || query == :default
57
57
 
58
58
  all_devices = query_supported_devices
59
- substrs = query.split('/')
59
+ substrs = query.split("/")
60
60
  substrs.each do |q|
61
- components = q.split(':')
61
+ components = q.split(":")
62
62
  next if components.size.zero?
63
63
 
64
- if components[0] == 'device' # use tensorflow convention
64
+ if components[0] == "device" # use tensorflow convention
65
65
  device_type = components[1]
66
66
  select_index = components[2].to_i
67
67
 
@@ -79,7 +79,7 @@ module TensorStream
79
79
 
80
80
  select_index = [devices.size - 1, select_index].min
81
81
  return devices[select_index]
82
- elsif components[0] == 'ts' # tensorstream specific
82
+ elsif components[0] == "ts" # tensorstream specific
83
83
  evaluator_class = TensorStream::Evaluator.evaluators[components[1]][:class]
84
84
  return nil unless self == evaluator_class
85
85
  return evaluator_class.fetch_device(components[2..components.size]) if evaluator_class.respond_to?(:fetch_device)
@@ -95,10 +95,10 @@ module TensorStream
95
95
  @ops ||= {}
96
96
  if opcode.is_a?(Array)
97
97
  opcode.each do |op|
98
- @ops[op.to_sym] = { options: options, block: block }
98
+ @ops[op.to_sym] = {options: options, block: block}
99
99
  end
100
100
  else
101
- @ops[opcode.to_sym] = { options: options, block: block }
101
+ @ops[opcode.to_sym] = {options: options, block: block}
102
102
  end
103
103
  end
104
104
 
@@ -115,7 +115,7 @@ module TensorStream
115
115
  op = self.class.ops[tensor.operation.to_sym]
116
116
  op_options = op[:options]
117
117
 
118
- resolved_inputs = tensor.inputs.map do |i|
118
+ resolved_inputs = tensor.inputs.map { |i|
119
119
  next if i.nil?
120
120
  next i if op_options[:noop]
121
121
 
@@ -124,25 +124,25 @@ module TensorStream
124
124
  end
125
125
 
126
126
  global_eval(tensor, i, execution_context, op_options)
127
- end
127
+ }
128
128
 
129
129
  start_time = if profile_enabled?
130
- time = Time.now
131
- time.to_i * (10**9) + time.nsec
132
- end
130
+ time = Time.now
131
+ time.to_i * (10**9) + time.nsec
132
+ end
133
133
 
134
134
  instance_exec(execution_context, tensor, resolved_inputs, &op[:block]).tap do
135
135
  if profile_enabled?
136
136
  time = Time.now
137
137
  end_time = time.to_i * (10**9) + time.nsec
138
- @context[:profile] ||= { step: 0, operations: {} }
138
+ @context[:profile] ||= {step: 0, operations: {}}
139
139
  @context[:profile][:step] += 1
140
- @context[:profile][:operations][tensor.name] = { op: tensor.operation,
140
+ @context[:profile][:operations][tensor.name] = {op: tensor.operation,
141
141
  step: @context[:profile][:step],
142
142
  eval_time: end_time - start_time,
143
143
  shape: tensor.shape ? tensor.shape.shape : nil,
144
144
  data_type: tensor.data_type,
145
- tensor: tensor }
145
+ tensor: tensor,}
146
146
  end
147
147
  end
148
148
  end
@@ -222,7 +222,7 @@ module TensorStream
222
222
 
223
223
  def self.register_evaluator(klass, name, index = 0)
224
224
  @evaluators ||= {}
225
- @evaluators[name] = { name: name, class: klass, index: index }
225
+ @evaluators[name] = {name: name, class: klass, index: index}
226
226
  end
227
227
 
228
228
  def self.default_evaluators
@@ -12,4 +12,4 @@ module TensorStream
12
12
  buffer
13
13
  end
14
14
  end
15
- end
15
+ end
@@ -1,5 +1,5 @@
1
- require 'tensor_stream/evaluator/ruby_evaluator'
2
- require 'tensor_stream/evaluator/buffer'
1
+ require "tensor_stream/evaluator/ruby_evaluator"
2
+ require "tensor_stream/evaluator/buffer"
3
3
 
4
4
  module TensorStream
5
5
  module Evaluator
@@ -16,10 +16,10 @@ module TensorStream
16
16
  start_index = start.shift
17
17
  current_size = size.shift
18
18
  dimen_size = if current_size == -1
19
- input.size - 1
20
- else
21
- start_index + current_size - 1
22
- end
19
+ input.size - 1
20
+ else
21
+ start_index + current_size - 1
22
+ end
23
23
 
24
24
  input[start_index..dimen_size].collect do |item|
25
25
  if item.is_a?(Array)
@@ -87,9 +87,9 @@ module TensorStream
87
87
  d = dims.shift
88
88
 
89
89
  if input.is_a?(Array) && (get_rank(input) - 1) == dims.size
90
- row_to_dup = input.collect do |item|
90
+ row_to_dup = input.collect { |item|
91
91
  broadcast_dimensions(item, dims.dup)
92
- end
92
+ }
93
93
 
94
94
  row_to_dup + Array.new(d) { row_to_dup }.flatten(1)
95
95
  elsif input.is_a?(Array)
@@ -102,15 +102,15 @@ module TensorStream
102
102
  # handle 2 tensor math operations
103
103
  def vector_op(vector, vector2, switch = false, safe = true, &block)
104
104
  if get_rank(vector) < get_rank(vector2) # upgrade rank of A
105
- duplicated = Array.new(vector2.size) do
105
+ duplicated = Array.new(vector2.size) {
106
106
  vector
107
- end
107
+ }
108
108
  return vector_op(duplicated, vector2, switch, &block)
109
109
  end
110
110
 
111
111
  return yield(vector, vector2) unless vector.is_a?(Array)
112
112
 
113
- vector.each_with_index.collect do |input, index|
113
+ vector.each_with_index.collect { |input, index|
114
114
  next vector_op(input, vector2, switch, &block) if input.is_a?(Array) && get_rank(vector) > get_rank(vector2)
115
115
 
116
116
  if safe && vector2.is_a?(Array)
@@ -118,22 +118,22 @@ module TensorStream
118
118
  end
119
119
 
120
120
  z = if vector2.is_a?(Array)
121
- if index < vector2.size
122
- vector2[index]
123
- else
124
- raise 'incompatible tensor shapes used during op' if vector2.size != 1
125
- vector2[0]
126
- end
127
- else
128
- vector2
129
- end
121
+ if index < vector2.size
122
+ vector2[index]
123
+ else
124
+ raise "incompatible tensor shapes used during op" if vector2.size != 1
125
+ vector2[0]
126
+ end
127
+ else
128
+ vector2
129
+ end
130
130
 
131
131
  if input.is_a?(Array)
132
132
  vector_op(input, z, switch, &block)
133
133
  else
134
134
  switch ? yield(z, input) : yield(input, z)
135
135
  end
136
- end.compact
136
+ }.compact
137
137
  end
138
138
 
139
139
  def shape_diff(shape_a, shape_b)
@@ -142,11 +142,11 @@ module TensorStream
142
142
  reversed_a = shape_a.reverse
143
143
  reversed_b = shape_b.reverse
144
144
 
145
- reversed_a.each_with_index.collect do |s, index|
145
+ reversed_a.each_with_index.collect { |s, index|
146
146
  next s if index >= reversed_b.size
147
147
  return nil if reversed_b[index] > s
148
148
  s - reversed_b[index]
149
- end.reverse
149
+ }.reverse
150
150
  end
151
151
 
152
152
  def tile_arr(input, dimen, multiples)
@@ -155,9 +155,9 @@ module TensorStream
155
155
  return nil if t.zero?
156
156
  input * t # ruby array dup
157
157
  else
158
- new_arr = input.collect do |sub|
158
+ new_arr = input.collect { |sub|
159
159
  tile_arr(sub, dimen + 1, multiples)
160
- end.compact
160
+ }.compact
161
161
 
162
162
  return nil if new_arr.empty?
163
163
 
@@ -234,13 +234,13 @@ module TensorStream
234
234
  # general case transposition with flat arrays
235
235
  def transpose_with_perm(arr, new_arr, shape, new_shape, perm)
236
236
  arr_size = shape.reduce(:*)
237
- divisors = shape.dup.drop(1).reverse.inject([1]) do |a, s|
237
+ divisors = shape.dup.drop(1).reverse.inject([1]) { |a, s|
238
238
  a << s * a.last
239
- end.reverse
239
+ }.reverse
240
240
 
241
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
241
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
242
242
  a << s * a.last
243
- end.reverse
243
+ }.reverse
244
244
 
245
245
  arr_size.times do |p|
246
246
  ptr = p
@@ -267,19 +267,19 @@ module TensorStream
267
267
  def reduce_axis(current_axis, axis, val, keep_dims, &block)
268
268
  return val unless val.is_a?(Array)
269
269
 
270
- r = val.collect do |v|
270
+ r = val.collect { |v|
271
271
  reduce_axis(current_axis + 1, axis, v, keep_dims, &block)
272
- end
272
+ }
273
273
 
274
274
  should_reduce_axis = axis.nil? || (axis.is_a?(Array) && axis.include?(current_axis)) || (current_axis == axis)
275
275
 
276
276
  if should_reduce_axis
277
277
  reduced_val = r[0]
278
278
  if r.size > 1
279
- if block_given?
280
- reduced_val = yield(r[0..val.size])
279
+ reduced_val = if block_given?
280
+ yield(r[0..val.size])
281
281
  else
282
- reduced_val = r[0..val.size].reduce(:+)
282
+ r[0..val.size].reduce(:+)
283
283
  end
284
284
  elsif r.empty?
285
285
  reduced_val = yield(nil)
@@ -292,17 +292,17 @@ module TensorStream
292
292
 
293
293
  def reduce(val, axis, keep_dims, &block)
294
294
  rank = get_rank(val)
295
- return val if axis && axis.is_a?(Array) && axis.empty?
295
+ return val if axis&.is_a?(Array) && axis&.empty?
296
296
 
297
297
  axis = if axis.nil?
298
- nil
299
- elsif axis.is_a?(Array)
300
- return val if axis.empty?
298
+ nil
299
+ elsif axis.is_a?(Array)
300
+ return val if axis.empty?
301
301
 
302
- axis.map { |a| a < 0 ? rank - a.abs : a }
303
- else
304
- axis < 0 ? rank - axis.abs : axis
305
- end
302
+ axis.map { |a| a < 0 ? rank - a.abs : a }
303
+ else
304
+ axis < 0 ? rank - axis.abs : axis
305
+ end
306
306
 
307
307
  reduce_axis(0, axis, val, keep_dims, &block)
308
308
  end
@@ -332,4 +332,4 @@ module TensorStream
332
332
  end
333
333
  end
334
334
  end
335
- end
335
+ end
@@ -6,4 +6,4 @@ module TensorStream
6
6
  1 / (1 + Math.exp(-val))
7
7
  end
8
8
  end
9
- end
9
+ end
@@ -1,6 +1,6 @@
1
1
  module TensorStream
2
2
  module ArrayOps
3
- def ArrayOps.included(klass)
3
+ def self.included(klass)
4
4
  klass.class_eval do
5
5
  register_op :slice do |context, tensor, inputs|
6
6
  input = inputs[0]
@@ -41,17 +41,17 @@ module TensorStream
41
41
  new_shape = [inputs.size]
42
42
  shape.inject(new_shape) { |ns, s| ns << s }
43
43
 
44
- divisors = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
44
+ divisors = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
45
45
  a << s * a.last
46
- end.reverse
46
+ }.reverse
47
47
 
48
48
  axis = rank + axis if axis < 0
49
49
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
50
50
  new_shape = rotated_shape.rotate! + new_shape
51
51
 
52
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
52
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
53
53
  a << s * a.last
54
- end.reverse
54
+ }.reverse
55
55
 
56
56
  inputs.each_with_index do |input, index|
57
57
  raw_input = input.is_a?(Array) ? input.flatten : [input]
@@ -85,18 +85,18 @@ module TensorStream
85
85
  new_shape = shape_eval(inputs[0])
86
86
  rank = new_shape.size - 1
87
87
 
88
- divisors = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
88
+ divisors = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
89
89
  a << s * a.last
90
- end.reverse
90
+ }.reverse
91
91
 
92
92
  axis = rank + axis if axis < 0
93
93
  rotated_shape = Array.new(axis + 1) { new_shape.shift }
94
94
  new_shape = rotated_shape.rotate!(-1) + new_shape
95
95
  output_buffer = Array.new(new_shape.reduce(:*)) { 0 }
96
96
 
97
- multipliers = new_shape.dup.drop(1).reverse.inject([1]) do |a, s|
97
+ multipliers = new_shape.dup.drop(1).reverse.inject([1]) { |a, s|
98
98
  a << s * a.last
99
- end.reverse
99
+ }.reverse
100
100
 
101
101
  inputs.each_with_index do |input, index|
102
102
  raw_input = input.is_a?(Array) ? input.flatten : [input]
@@ -249,16 +249,16 @@ module TensorStream
249
249
 
250
250
  register_op %i[zeros ones zeros_like ones_like] do |_context, tensor, inputs|
251
251
  shape = if %i[zeros_like ones_like].include?(tensor.operation)
252
- shape_eval(inputs[0])
253
- else
254
- inputs[0] || tensor.shape.shape
255
- end
252
+ shape_eval(inputs[0])
253
+ else
254
+ inputs[0] || tensor.shape.shape
255
+ end
256
256
 
257
257
  func = if %i[zeros zeros_like].include?(tensor.operation)
258
- -> { int_type?(tensor.data_type) ? 0 : 0.0 }
259
- else
260
- -> { int_type?(tensor.data_type) ? 1 : 1.0 }
261
- end
258
+ -> { int_type?(tensor.data_type) ? 0 : 0.0 }
259
+ else
260
+ -> { int_type?(tensor.data_type) ? 1 : 1.0 }
261
+ end
262
262
  if shape.is_a?(Array) && shape.size.zero?
263
263
  func.call
264
264
  else
@@ -288,23 +288,23 @@ module TensorStream
288
288
 
289
289
  value_shape = shape_eval(value)
290
290
  res = if num_split.is_a?(Array)
291
- begin_index = 0
292
- num_split.collect do |num|
293
- end_index = begin_index + num
294
- arr = split_tensor(value, begin_index, end_index, axis)
295
- begin_index = end_index
296
- arr
297
- end
298
- else
299
- raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
-
301
- piece_sizes = value_shape[axis] / num_split
302
- Array.new(num_split) do |num|
303
- begin_index = num * piece_sizes
304
- end_index = begin_index + piece_sizes
305
- split_tensor(value, begin_index, end_index, axis)
306
- end
307
- end
291
+ begin_index = 0
292
+ num_split.collect do |num|
293
+ end_index = begin_index + num
294
+ arr = split_tensor(value, begin_index, end_index, axis)
295
+ begin_index = end_index
296
+ arr
297
+ end
298
+ else
299
+ raise TensorStream::ValueError, "#{num_split} does not divide #{value_shape[axis]} evenly" if value_shape[axis] % num_split != 0
300
+
301
+ piece_sizes = value_shape[axis] / num_split
302
+ Array.new(num_split) do |num|
303
+ begin_index = num * piece_sizes
304
+ end_index = begin_index + piece_sizes
305
+ split_tensor(value, begin_index, end_index, axis)
306
+ end
307
+ end
308
308
  TensorStream::Evaluator::OutputGroup.new(res, res.map { tensor.inputs[0].data_type })
309
309
  end
310
310
 
@@ -326,7 +326,7 @@ module TensorStream
326
326
  register_op :tile do |_context, _tensor, inputs|
327
327
  input, multiples = inputs
328
328
  rank = get_rank(input)
329
- raise '1D or higher tensor required' if rank.zero?
329
+ raise "1D or higher tensor required" if rank.zero?
330
330
  raise "invalid multiple size passed #{rank} != #{multiples.size}" if rank != multiples.size
331
331
 
332
332
  tile = tile_arr(input, 0, multiples)
@@ -343,9 +343,9 @@ module TensorStream
343
343
  end
344
344
 
345
345
  register_op :shape_n do |_context, tensor, inputs|
346
- shapes = inputs.collect do |input|
346
+ shapes = inputs.collect { |input|
347
347
  shape_eval(input)
348
- end
348
+ }
349
349
  TensorStream::Evaluator::OutputGroup.new(shapes, shapes.map { tensor.options[:out_type] })
350
350
  end
351
351
 
@@ -372,7 +372,7 @@ module TensorStream
372
372
 
373
373
  if tensor.options[:exclusive]
374
374
  p_true = pred.each_with_index.collect { |p, index| [p, index] }.select { |a| a[0] }
375
- raise TensorStream::ValueError, "more than one predicate returns true pos #{p_true.map { |a| a[1] }.join(',')}" if p_true.size > 1
375
+ raise TensorStream::ValueError, "more than one predicate returns true pos #{p_true.map { |a| a[1] }.join(",")}" if p_true.size > 1
376
376
  end
377
377
 
378
378
  pred.each_with_index do |p, index|
@@ -412,4 +412,4 @@ module TensorStream
412
412
  end
413
413
  end
414
414
  end
415
- end
415
+ end