tensor_stream 0.9.8 → 0.9.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -7,21 +7,22 @@ module TensorStream
7
7
  target_var, learning_rate, delta = inputs
8
8
  assign = tensor.inputs[0] || tensor
9
9
 
10
- assign.value = process_vector_math_op(tensor, target_var, delta, context, ->(t, u) { t - u * learning_rate })
11
- assign.value
10
+ assign.container = process_vector_math_op(tensor, target_var, delta, context, ->(t, u) { t - u * learning_rate })
11
+ assign.container
12
12
  end
13
13
 
14
- register_op :apply_momentum do |context, tensor, inputs|
14
+ register_op :apply_momentum do |_context, tensor, inputs|
15
15
  target_var, momentum_var, learning_rate, grad, momentum = inputs
16
16
  assign = tensor.inputs[0] || tensor
17
17
  assign_acc = tensor.inputs[1]
18
- assign_acc.value = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
- assign.value = if tensor.options[:use_nesterov]
20
- multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
- else
22
- multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
- end
24
- assign.value
18
+ assign_acc.container = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
+ assign.container = if tensor.options[:use_nesterov]
20
+ multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
+ else
22
+ multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
+ end
24
+
25
+ assign.container
25
26
  end
26
27
 
27
28
  register_op :apply_adadelta do |_context, tensor, inputs|
@@ -29,19 +30,19 @@ module TensorStream
29
30
  assign = tensor.inputs[0] || tensor
30
31
  assign_acc = tensor.inputs[1]
31
32
  assign_acc_update = tensor.inputs[2]
32
- assign_acc.value = multi_array_op(->(acc_t, grad_t) { acc_t * rho + (grad_t * grad_t) * (1.0 - rho) }, accum, grad)
33
- update = multi_array_op(->(acc_update_t, acc_t, grad_t) { Math.sqrt(acc_update_t + epsilon) * (1.0 / Math.sqrt(acc_t + epsilon)) * grad_t }, accum_update, assign_acc.value, grad)
34
- assign.value = multi_array_op(->(v, u) { v - (u * lr) }, target_var, update)
35
- assign_acc_update.value = multi_array_op(->(acc_update_t, u) { acc_update_t * rho + (u * u) * (1.0 - rho) }, accum_update, update)
33
+ assign_acc.container = multi_array_op(->(acc_t, grad_t) { acc_t * rho + (grad_t * grad_t) * (1.0 - rho) }, accum, grad)
34
+ update = multi_array_op(->(acc_update_t, acc_t, grad_t) { Math.sqrt(acc_update_t + epsilon) * (1.0 / Math.sqrt(acc_t + epsilon)) * grad_t }, accum_update, assign_acc.container, grad)
35
+ assign.container = multi_array_op(->(v, u) { v - (u * lr) }, target_var, update)
36
+ assign_acc_update.container = multi_array_op(->(acc_update_t, u) { acc_update_t * rho + (u * u) * (1.0 - rho) }, accum_update, update)
36
37
 
37
- assign.value
38
+ assign.container
38
39
  end
39
40
 
40
41
  register_op :apply_adagrad do |_context, tensor, inputs|
41
42
  target_var, accum, lr, grad = inputs
42
43
  assign = tensor.inputs[0] || tensor
43
- assign.value = multi_array_op(->(v, a, g) { v - (g * lr * (1.0 / Math.sqrt(a))) }, target_var, accum, grad)
44
- assign.value
44
+ assign.container = multi_array_op(->(v, a, g) { v - (g * lr * (1.0 / Math.sqrt(a))) }, target_var, accum, grad)
45
+ assign.container
45
46
  end
46
47
 
47
48
  register_op :apply_adam do |_context, tensor, inputs|
@@ -51,10 +52,10 @@ module TensorStream
51
52
  assign_m = tensor.inputs[1]
52
53
  assign_v = tensor.inputs[2]
53
54
 
54
- assign_m.value = multi_array_op(->(u_d , g) { u_d + (g - u_d) * (1.0 - beta1_t) }, m, grad)
55
- assign_v.value = multi_array_op(->(u_d , v_d) { u_d + (v_d**2 - u_d) * (1.0 - beta2_t)}, v, grad)
56
- assign.value = multi_array_op(->(t, m_d , v_d) { t - ((m_d * alpha) / (Math.sqrt(v_d) + epsilon_t)) }, target_var, assign_m.value, assign_v.value)
57
- assign.value
55
+ assign_m.container = multi_array_op(->(u_d , g) { u_d + (g - u_d) * (1.0 - beta1_t) }, m, grad)
56
+ assign_v.container = multi_array_op(->(u_d , v_d) { u_d + (v_d**2 - u_d) * (1.0 - beta2_t)}, v, grad)
57
+ assign.container = multi_array_op(->(t, m_d , v_d) { t - ((m_d * alpha) / (Math.sqrt(v_d) + epsilon_t)) }, target_var, assign_m.container, assign_v.container)
58
+ assign.container
58
59
  end
59
60
 
60
61
  register_op :apply_rms_prop do |_context, tensor, inputs|
@@ -62,9 +63,9 @@ module TensorStream
62
63
  assign = tensor.inputs[0]
63
64
  assign_ms = tensor.inputs[1]
64
65
  assign_mom = tensor.inputs[2]
65
- assign_ms.value = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho)}, grad, ms)
66
- assign_mom.value = multi_array_op(->(mom_t, g, m) { mom_t * momentum + (g * lr) / Math.sqrt(m + epsilon)}, mom, grad, assign_ms.value)
67
- assign.value = multi_array_op(->(v, m) { v - m }, var, assign_mom.value)
66
+ assign_ms.container = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho)}, grad, ms)
67
+ assign_mom.container = multi_array_op(->(mom_t, g, m) { mom_t * momentum + (g * lr) / Math.sqrt(m + epsilon)}, mom, grad, assign_ms.container)
68
+ assign.container = multi_array_op(->(v, m) { v - m }, var, assign_mom.container)
68
69
  end
69
70
 
70
71
  register_op :apply_centered_rms_prop do |_context, tensor, inputs|
@@ -74,11 +75,11 @@ module TensorStream
74
75
  assign_ms = tensor.inputs[2]
75
76
  assign_mom = tensor.inputs[3]
76
77
 
77
- assign_ms.value = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho) }, grad, ms)
78
- assign_mg.value = multi_array_op(->(g, mg_t) { (g - mg_t) * (1.0 - rho) }, grad, mg)
79
- denom = multi_array_op(->(s, mg_t) { (s - mg_t * mg_t) + epsilon }, assign_ms.value, mg)
80
- assign_mom.value = multi_array_op(->(mom_t, g, d) { mom_t * momentum + (g * lr) / Math.sqrt(d)}, mom, grad, denom)
81
- assign.value = multi_array_op(->(v, m) { v - m }, var, assign_mom.value)
78
+ assign_ms.container = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho) }, grad, ms)
79
+ assign_mg.container = multi_array_op(->(g, mg_t) { (g - mg_t) * (1.0 - rho) }, grad, mg)
80
+ denom = multi_array_op(->(s, mg_t) { (s - mg_t * mg_t) + epsilon }, assign_ms.container, mg)
81
+ assign_mom.container = multi_array_op(->(mom_t, g, d) { mom_t * momentum + (g * lr) / Math.sqrt(d)}, mom, grad, denom)
82
+ assign.container = multi_array_op(->(v, m) { v - m }, var, assign_mom.container)
82
83
  end
83
84
 
84
85
  register_op %i[softmax_cross_entropy_with_logits_v2 softmax_cross_entropy_with_logits] do |_context, tensor, inputs|
@@ -103,7 +104,7 @@ module TensorStream
103
104
  else
104
105
  losses = []
105
106
  backprobs = []
106
- arr = last_dimen_list.zip(labels).each do |list, label|
107
+ last_dimen_list.zip(labels).each do |list, label|
107
108
  loss, prob = func.call(list, label)
108
109
  losses << loss
109
110
  backprobs << prob
@@ -54,9 +54,9 @@ module TensorStream
54
54
  random = _get_randomizer(tensor, seed)
55
55
  generator = -> { r.rand }
56
56
  shape = inputs[0] || tensor.shape.shape
57
- random_values = Array.new(shape.reduce(:*) || 1) {
57
+ random_values = Array.new(shape.reduce(:*) || 1) do
58
58
  generator.call
59
- }
59
+ end
60
60
  mean = random_values.reduce(:+) / random_values.size
61
61
 
62
62
  # standard deviation
@@ -80,20 +80,20 @@ module TensorStream
80
80
  cutoff = 2.0 * Math.exp( 0.5 + (norm_min * (norm_min - sqrt_factor)) / 4.0 ) / (norm_min + sqrt_factor)
81
81
  diff = norm_max - norm_min;
82
82
 
83
- val = random_values.map { |v|
83
+ val = random_values.map do |v|
84
84
  iterations = 0
85
85
  pick = v
86
- while ( (pick > norm_max) || (pick < norm_min) )
86
+ while (pick > norm_max) || (pick < norm_min)
87
87
  pick = generator.call
88
88
  iterations += 1
89
- if iterations > 100
89
+ if iterations > max_iterations
90
90
  pick = v
91
91
  break
92
92
  end
93
93
  end
94
94
 
95
95
  pick
96
- }
96
+ end
97
97
 
98
98
  TensorShape.reshape(val, shape)
99
99
  end
@@ -50,14 +50,10 @@ module TensorStream
50
50
  child_context = execution_context.dup
51
51
  res = if tensor.is_a?(Operation)
52
52
  eval_operation(tensor, child_context)
53
- elsif tensor.is_a?(Variable)
54
- eval_variable(tensor, child_context)
55
- elsif tensor.is_a?(Placeholder)
56
- resolve_placeholder(tensor, child_context)
57
- elsif tensor.is_a?(OutputGroup)
58
- tensor.outputs[0]
53
+ elsif !tensor.is_a?(Tensor)
54
+ tensor
59
55
  else
60
- eval_tensor(tensor, child_context)
56
+ tensor.op
61
57
  end
62
58
  execution_context.deep_merge!(returns: child_context[:returns])
63
59
  res
@@ -88,7 +84,6 @@ module TensorStream
88
84
 
89
85
  def prepare_input(tensor, context, options = {})
90
86
  return nil unless tensor
91
- tensor = resolve_placeholder(tensor)
92
87
  if options[:noop]
93
88
  tensor
94
89
  elsif options[:no_eval]
@@ -98,27 +93,16 @@ module TensorStream
98
93
  end
99
94
  end
100
95
 
101
- def eval_variable(tensor, child_context)
102
- value = tensor.read_value
103
- raise "variable #{tensor.name} not initalized" if value.nil?
104
-
105
- eval_tensor(value, child_context).tap do |val|
106
- child_context[:returns] ||= {}
107
- child_context[:returns][:vars] ||= []
108
- child_context[:returns][:vars] << { name: tensor.name, val: val }
109
- end
110
- end
111
-
112
96
  register_op(:no_op, no_eval: true) do |_context, _tensor, inputs|
113
97
  inputs
114
98
  end
115
99
 
116
- register_op(:const) do |_context, _tensor, inputs|
117
- inputs[0]
100
+ register_op(:const) do |_context, tensor, _inputs|
101
+ tensor.options[:value]
118
102
  end
119
103
 
120
104
  register_op(:cast) do |context, tensor, inputs|
121
- call_op(tensor, inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
105
+ call_op(inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
122
106
  end
123
107
 
124
108
  register_op(:sign) do |context, tensor, inputs|
@@ -134,7 +118,7 @@ module TensorStream
134
118
  end
135
119
  }
136
120
 
137
- call_op(tensor, inputs[0], context, func)
121
+ call_op(inputs[0], context, func)
138
122
  end
139
123
 
140
124
  register_op(:logical_and) do |context, tensor, inputs|
@@ -149,14 +133,33 @@ module TensorStream
149
133
  call_vector_op(tensor, :not_equal, inputs[0], inputs[1], context, ->(t, u) { t != u })
150
134
  end
151
135
 
152
- def merge_dynamic_stitch(merged, indexes, data)
153
- indexes.each_with_index do |ind, m|
154
- if ind.is_a?(Array)
155
- merge_dynamic_stitch(merged, ind, data[m])
156
- else
157
- merged[ind] = data[m]
136
+ register_op :placeholder, no_eval: true do |context, tensor, _inputs|
137
+ ph = @context[tensor.name.to_sym].tap do |c|
138
+ raise TensorStream::ValueError, "missing placeholder #{tensor.name}" if c.nil?
139
+
140
+ if tensor.shape.shape
141
+ value_shape = shape_eval(c)
142
+ placeholder_shape = tensor.shape.shape
143
+ placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
144
+ next if p_shape.nil?
145
+ raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
146
+ end
158
147
  end
159
148
  end
149
+ if ph.is_a?(Tensor)
150
+ raise TensorStream::ValueError, "placeholder expects type #{tensor.data_type}, got #{ph.data_type}" if ph.data_type != tensor.data_type
151
+
152
+ global_eval(tensor, ph, context)
153
+ else
154
+ global_eval(tensor, Tensor.cast_dtype(ph, dtype: tensor.data_type), context)
155
+ end
156
+ end
157
+
158
+ register_op :variable_v2, no_eval: true do |_context, tensor, _inputs|
159
+ value = tensor.options[:container].read_value
160
+ raise "variable #{tensor.options[:container].name} not initalized" if value.nil?
161
+
162
+ value
160
163
  end
161
164
 
162
165
  register_op :stop_gradient, no_eval: true do |_context, _tensor, inputs|
@@ -165,38 +168,22 @@ module TensorStream
165
168
 
166
169
  register_op :assign, noop: true do |context, tensor, _inputs|
167
170
  assign = tensor.inputs[0] || tensor
168
- assign.value = global_eval(tensor, tensor.inputs[1], context)
169
- assign.value
171
+ assign.container = global_eval(tensor, tensor.inputs[1], context)
172
+ assign.container
170
173
  end
171
174
 
172
175
  register_op :assign_add, noop: true do |context, tensor, _inputs|
173
- tensor.inputs[0].value = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t + u })
174
- tensor.inputs[0].value
175
- end
176
+ assign = tensor.inputs[0] || tensor
176
177
 
177
- register_op :variable, noop: true do |_context, tensor, _inputs|
178
- tensor.inputs[0].value
178
+ assign.container = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t + u })
179
+ assign.container
179
180
  end
180
181
 
181
182
  register_op :assign_sub, noop: true do |context, tensor, _inputs|
182
- tensor.inputs[0].value = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t - u })
183
- tensor.inputs[0].value
184
- end
185
-
186
- register_op :transpose do |_context, _tensor, inputs|
187
- shape = shape_eval(inputs[0])
188
- rank = get_rank(inputs[0])
189
- perm = inputs[1] || (0...rank).to_a.reverse
190
- if rank == 2 && perm.nil? # use native transpose for general case
191
- inputs[0].transpose
192
- else
193
- arr = inputs[0].flatten
183
+ assign = tensor.inputs[0] || tensor
194
184
 
195
- new_shape = perm.map { |p| shape[p] }
196
- new_arr = Array.new(shape.reduce(:*)) { 0 }
197
- transpose_with_perm(arr, new_arr, shape, new_shape, perm)
198
- TensorShape.reshape(new_arr, new_shape)
199
- end
185
+ assign.container = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t - u })
186
+ assign.container
200
187
  end
201
188
 
202
189
  register_op :less do |context, tensor, inputs|
@@ -277,7 +264,7 @@ module TensorStream
277
264
  raise TensorStream::InvalidArgumentError, "#{message} Invalid argument" if t.nan? || t.infinite?
278
265
  t
279
266
  }
280
- call_op(tensor, inputs[0], context, f)
267
+ call_op(inputs[0], context, f)
281
268
  end
282
269
 
283
270
  def eval_operation(tensor, child_context)
@@ -289,14 +276,13 @@ module TensorStream
289
276
  # assertions to make sure inferred shapes == actual evaluated shapes
290
277
  if tensor.shape.known? && (result.is_a?(Array) || result.is_a?(Float) || result.is_a?(Integer))
291
278
  if shape_eval(result) != tensor.shape.shape
292
-
293
279
  raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
294
280
  end
295
281
  end
296
282
 
297
283
  if tensor.breakpoint
298
- a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
299
- b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
284
+ a = tensor.inputs[0] if tensor.inputs && tensor.inputs[0]
285
+ b = tensor.inputs[1] if tensor.inputs && tensor.inputs[1]
300
286
  a = complete_eval(a, child_context)
301
287
  b = complete_eval(b, child_context)
302
288
  tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
@@ -318,9 +304,7 @@ module TensorStream
318
304
  rescue TensorStreamError => e
319
305
  raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
320
306
  rescue StandardError => e
321
- # a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
322
- # b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
323
- puts e.message
307
+ puts e.message
324
308
  puts e.backtrace.join("\n")
325
309
  # shape_a = a.shape.shape if a
326
310
  # shape_b = b.shape.shape if b
@@ -338,25 +322,6 @@ module TensorStream
338
322
  raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
339
323
  end
340
324
 
341
- def eval_tensor(tensor, child_context)
342
- return tensor unless tensor.is_a?(Tensor)
343
-
344
- cache_key = "#{tensor.graph.object_id}_ruby_#{tensor.name}"
345
- return @context[cache_key] if @context.key?(cache_key)
346
- return @context[:_cache][cache_key] if @context[:_cache] && @context[:_cache].key?(tensor.name)
347
-
348
- if tensor.value.is_a?(Array)
349
- tensor.value.collect do |input|
350
- input.is_a?(Tensor) ? run(input, child_context) : input
351
- end
352
- else
353
- tensor.value.is_a?(Tensor) ? run(tensor.value, child_context) : tensor.value
354
- end.tap do |result|
355
- @context[cache_key] = result
356
- @context[:_cache][cache_key] = result if @context[:_cache] && tensor.is_const
357
- end
358
- end
359
-
360
325
  def convert_from_buffer(_tensor, result)
361
326
  result.buffer
362
327
  end
@@ -377,14 +342,7 @@ module TensorStream
377
342
  end
378
343
  end
379
344
 
380
- def reduction(child_context, tensor, func)
381
- val = global_eval(tensor, tensor.inputs[0], child_context)
382
- axis = global_eval(tensor, tensor.inputs[1], child_context)
383
- keep_dims = global_eval(tensor, tensor.options[:keepdims], child_context)
384
- reduce(val, axis, keep_dims, func)
385
- end
386
-
387
- def call_op(op, a, child_context, func)
345
+ def call_op(a, child_context, func)
388
346
  a = complete_eval(a, child_context)
389
347
  process_function_op(a, func)
390
348
  end
@@ -419,7 +377,7 @@ module TensorStream
419
377
  def multi_array_op(func, *args)
420
378
  elem = args[0]
421
379
  if (elem.is_a?(Array))
422
- elem.each_with_index.collect do |item, index|
380
+ elem.each_with_index.collect do |_item, index|
423
381
  indexed_args = args.collect { |a| a[index] }
424
382
  multi_array_op(func, *indexed_args)
425
383
  end
@@ -452,29 +410,6 @@ module TensorStream
452
410
  end
453
411
  end
454
412
 
455
- def resolve_placeholder(placeholder, _execution_context = {})
456
- return nil if placeholder.nil?
457
-
458
- var = if placeholder.is_a?(Placeholder)
459
- @context[placeholder.name.to_sym].tap do |c|
460
- raise TensorStream::ValueError, "missing placeholder #{placeholder.name}" if c.nil?
461
- if placeholder.shape.shape
462
- value_shape = shape_eval(c)
463
- placeholder_shape = placeholder.shape.shape
464
- placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
465
- next if p_shape.nil?
466
- raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
467
- end
468
- end
469
- end
470
- else
471
- placeholder
472
- end
473
-
474
- return var unless placeholder.is_a?(Tensor)
475
- Tensor.cast_dtype(var, placeholder.data_type)
476
- end
477
-
478
413
  # handle 3 tensor math operations
479
414
  def call_3way_vector_op(v_a, v_b, v_c, child_context, op = ->(a, b, c) { a + b + c })
480
415
  return op.call(v_a, v_b, v_c) unless v_a.is_a?(Array)
@@ -1,7 +1,10 @@
1
1
  module TensorStream
2
2
  # A class that defines a TensorStream graph
3
3
  class Graph
4
- attr_accessor :nodes, :node_keys, :collections, :eager_execution, :random_seed, :constants
4
+ include OpHelper
5
+
6
+ attr_accessor :nodes, :collections, :eager_execution, :random_seed, :constants
7
+ attr_reader :node_keys
5
8
 
6
9
  def initialize
7
10
  @eager_execution = false
@@ -30,8 +33,12 @@ module TensorStream
30
33
  end
31
34
 
32
35
  def as_default
36
+ Thread.current[:tensor_stream_current_graph_queue] ||= []
37
+ Thread.current[:tensor_stream_current_graph_queue] << Graph.get_default_graph
38
+
33
39
  Thread.current[:tensor_stream_current_graph] = self
34
40
  yield(self) if block_given?
41
+ Thread.current[:tensor_stream_current_graph] = Thread.current[:tensor_stream_current_graph_queue].pop
35
42
  self
36
43
  end
37
44
 
@@ -77,24 +84,24 @@ module TensorStream
77
84
  @collections[collection_name.to_sym] << val
78
85
  end
79
86
 
80
- def add_node(node)
87
+ def add_node(node, name = nil)
81
88
  raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
82
89
 
83
- node.name = if @nodes[node.name]
84
- uniqunify(node.name)
85
- else
86
- node.name
87
- end
90
+ if name.nil?
91
+ node.name = if @nodes[node.name]
92
+ uniqunify(node.name)
93
+ else
94
+ node.name
95
+ end
96
+ end
88
97
 
89
98
  node.device = get_device_scope
90
99
  @node_keys << node.name
91
100
  @nodes[node.name] = node
92
101
  @constants[node.name] = node if node.is_const
93
- # puts "adding node"
102
+
94
103
  node.send(:propagate_outputs)
95
104
  node.send(:propagate_consumer, node)
96
- # puts "#{node.name}"
97
- node.value = node.eval if @eager_execution
98
105
  end
99
106
 
100
107
  def node_added?(name)
@@ -107,6 +114,7 @@ module TensorStream
107
114
 
108
115
  def get_tensor_by_name(name)
109
116
  raise TensorStream::KeyError, "#{name} not found" unless @nodes.key?(name)
117
+
110
118
  get_node(name)
111
119
  end
112
120
 
@@ -115,6 +123,39 @@ module TensorStream
115
123
  node
116
124
  end
117
125
 
126
+ def add_op(operation, *args)
127
+ options = if args.last.is_a?(Hash)
128
+ args.pop
129
+ else
130
+ {}
131
+ end
132
+
133
+ inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }
134
+
135
+ new_op = Operation.new(self, inputs: inputs, options: options)
136
+ new_op.source = format_source(caller_locations)
137
+ new_op.operation = operation
138
+ new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
139
+ new_op.rank = new_op.shape.rank
140
+ new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join('/')
141
+ new_op.internal = options[:internal]
142
+
143
+ new_op.data_type = new_op.set_data_type(options[:data_type])
144
+ new_op.is_const = new_op.infer_const
145
+
146
+ new_op.given_name = new_op.name
147
+
148
+ new_op
149
+ end
150
+
151
+ def add_op!(operation, *args)
152
+ add_op(operation, *args).tap { |node| add_node(node) }
153
+ end
154
+
155
+ def set_operation_name(op)
156
+ op.operation.to_s
157
+ end
158
+
118
159
  def add_variable(node, options = {})
119
160
  scope = _variable_scope
120
161
 
@@ -129,13 +170,21 @@ module TensorStream
129
170
 
130
171
  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
131
172
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?
132
- add_node(node)
173
+
174
+ node
175
+ end
176
+
177
+ def add_variable!(node, options = {})
178
+ node = add_variable(node, options)
179
+ op = Graph.get_default_graph.add_op!(:variable_v2, container: node, internal_name: node.name, shape: options[:shape], data_type: options[:data_type])
180
+ node.name = op.name
181
+ op
133
182
  end
134
183
 
135
184
  def control_dependencies(control_inputs = [])
136
185
  Thread.current["ts_graph_#{object_id}"] ||= {}
137
186
  Thread.current["ts_graph_#{object_id}"][:control_dependencies] ||= []
138
- Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Operation.new(:no_op, *control_inputs)
187
+ Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Graph.get_default_graph.add_op!(:no_op, *control_inputs)
139
188
  begin
140
189
  yield
141
190
  ensure