tensor_stream 0.9.8 → 0.9.9

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +31 -14
  3. data/lib/tensor_stream.rb +4 -0
  4. data/lib/tensor_stream/constant.rb +41 -0
  5. data/lib/tensor_stream/control_flow.rb +2 -1
  6. data/lib/tensor_stream/dynamic_stitch.rb +3 -1
  7. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +4 -4
  8. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +74 -23
  9. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +45 -43
  10. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +31 -30
  11. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +6 -6
  12. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +46 -111
  13. data/lib/tensor_stream/graph.rb +61 -12
  14. data/lib/tensor_stream/graph_builder.rb +3 -3
  15. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +38 -0
  16. data/lib/tensor_stream/graph_serializers/packer.rb +8 -0
  17. data/lib/tensor_stream/graph_serializers/pbtext.rb +62 -27
  18. data/lib/tensor_stream/graph_serializers/serializer.rb +2 -2
  19. data/lib/tensor_stream/graph_serializers/yaml.rb +27 -0
  20. data/lib/tensor_stream/helpers/infer_shape.rb +15 -9
  21. data/lib/tensor_stream/helpers/op_helper.rb +17 -6
  22. data/lib/tensor_stream/helpers/string_helper.rb +32 -1
  23. data/lib/tensor_stream/helpers/tensor_mixins.rb +135 -0
  24. data/lib/tensor_stream/math_gradients.rb +19 -12
  25. data/lib/tensor_stream/monkey_patches/float.rb +7 -0
  26. data/lib/tensor_stream/monkey_patches/integer.rb +7 -0
  27. data/lib/tensor_stream/monkey_patches/patch.rb +8 -8
  28. data/lib/tensor_stream/nn/nn_ops.rb +1 -1
  29. data/lib/tensor_stream/operation.rb +98 -36
  30. data/lib/tensor_stream/ops.rb +65 -13
  31. data/lib/tensor_stream/placeholder.rb +2 -2
  32. data/lib/tensor_stream/session.rb +15 -3
  33. data/lib/tensor_stream/tensor.rb +15 -172
  34. data/lib/tensor_stream/tensor_shape.rb +3 -1
  35. data/lib/tensor_stream/train/saver.rb +12 -10
  36. data/lib/tensor_stream/trainer.rb +7 -2
  37. data/lib/tensor_stream/utils.rb +13 -11
  38. data/lib/tensor_stream/utils/freezer.rb +37 -0
  39. data/lib/tensor_stream/variable.rb +17 -11
  40. data/lib/tensor_stream/variable_scope.rb +3 -1
  41. data/lib/tensor_stream/version.rb +1 -1
  42. data/samples/iris.rb +3 -4
  43. data/samples/linear_regression.rb +9 -5
  44. data/samples/logistic_regression.rb +11 -9
  45. data/samples/mnist_data.rb +8 -10
  46. metadata +8 -4
@@ -7,21 +7,22 @@ module TensorStream
7
7
  target_var, learning_rate, delta = inputs
8
8
  assign = tensor.inputs[0] || tensor
9
9
 
10
- assign.value = process_vector_math_op(tensor, target_var, delta, context, ->(t, u) { t - u * learning_rate })
11
- assign.value
10
+ assign.container = process_vector_math_op(tensor, target_var, delta, context, ->(t, u) { t - u * learning_rate })
11
+ assign.container
12
12
  end
13
13
 
14
- register_op :apply_momentum do |context, tensor, inputs|
14
+ register_op :apply_momentum do |_context, tensor, inputs|
15
15
  target_var, momentum_var, learning_rate, grad, momentum = inputs
16
16
  assign = tensor.inputs[0] || tensor
17
17
  assign_acc = tensor.inputs[1]
18
- assign_acc.value = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
- assign.value = if tensor.options[:use_nesterov]
20
- multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
- else
22
- multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
- end
24
- assign.value
18
+ assign_acc.container = multi_array_op(->(t, u) { t * momentum + u }, momentum_var, grad)
19
+ assign.container = if tensor.options[:use_nesterov]
20
+ multi_array_op(->(v, g, acc) { v - (g * learning_rate + acc * momentum * learning_rate) }, target_var, grad, momentum_var)
21
+ else
22
+ multi_array_op(->(v, acc) { v - acc * learning_rate }, target_var, momentum_var)
23
+ end
24
+
25
+ assign.container
25
26
  end
26
27
 
27
28
  register_op :apply_adadelta do |_context, tensor, inputs|
@@ -29,19 +30,19 @@ module TensorStream
29
30
  assign = tensor.inputs[0] || tensor
30
31
  assign_acc = tensor.inputs[1]
31
32
  assign_acc_update = tensor.inputs[2]
32
- assign_acc.value = multi_array_op(->(acc_t, grad_t) { acc_t * rho + (grad_t * grad_t) * (1.0 - rho) }, accum, grad)
33
- update = multi_array_op(->(acc_update_t, acc_t, grad_t) { Math.sqrt(acc_update_t + epsilon) * (1.0 / Math.sqrt(acc_t + epsilon)) * grad_t }, accum_update, assign_acc.value, grad)
34
- assign.value = multi_array_op(->(v, u) { v - (u * lr) }, target_var, update)
35
- assign_acc_update.value = multi_array_op(->(acc_update_t, u) { acc_update_t * rho + (u * u) * (1.0 - rho) }, accum_update, update)
33
+ assign_acc.container = multi_array_op(->(acc_t, grad_t) { acc_t * rho + (grad_t * grad_t) * (1.0 - rho) }, accum, grad)
34
+ update = multi_array_op(->(acc_update_t, acc_t, grad_t) { Math.sqrt(acc_update_t + epsilon) * (1.0 / Math.sqrt(acc_t + epsilon)) * grad_t }, accum_update, assign_acc.container, grad)
35
+ assign.container = multi_array_op(->(v, u) { v - (u * lr) }, target_var, update)
36
+ assign_acc_update.container = multi_array_op(->(acc_update_t, u) { acc_update_t * rho + (u * u) * (1.0 - rho) }, accum_update, update)
36
37
 
37
- assign.value
38
+ assign.container
38
39
  end
39
40
 
40
41
  register_op :apply_adagrad do |_context, tensor, inputs|
41
42
  target_var, accum, lr, grad = inputs
42
43
  assign = tensor.inputs[0] || tensor
43
- assign.value = multi_array_op(->(v, a, g) { v - (g * lr * (1.0 / Math.sqrt(a))) }, target_var, accum, grad)
44
- assign.value
44
+ assign.container = multi_array_op(->(v, a, g) { v - (g * lr * (1.0 / Math.sqrt(a))) }, target_var, accum, grad)
45
+ assign.container
45
46
  end
46
47
 
47
48
  register_op :apply_adam do |_context, tensor, inputs|
@@ -51,10 +52,10 @@ module TensorStream
51
52
  assign_m = tensor.inputs[1]
52
53
  assign_v = tensor.inputs[2]
53
54
 
54
- assign_m.value = multi_array_op(->(u_d , g) { u_d + (g - u_d) * (1.0 - beta1_t) }, m, grad)
55
- assign_v.value = multi_array_op(->(u_d , v_d) { u_d + (v_d**2 - u_d) * (1.0 - beta2_t)}, v, grad)
56
- assign.value = multi_array_op(->(t, m_d , v_d) { t - ((m_d * alpha) / (Math.sqrt(v_d) + epsilon_t)) }, target_var, assign_m.value, assign_v.value)
57
- assign.value
55
+ assign_m.container = multi_array_op(->(u_d , g) { u_d + (g - u_d) * (1.0 - beta1_t) }, m, grad)
56
+ assign_v.container = multi_array_op(->(u_d , v_d) { u_d + (v_d**2 - u_d) * (1.0 - beta2_t)}, v, grad)
57
+ assign.container = multi_array_op(->(t, m_d , v_d) { t - ((m_d * alpha) / (Math.sqrt(v_d) + epsilon_t)) }, target_var, assign_m.container, assign_v.container)
58
+ assign.container
58
59
  end
59
60
 
60
61
  register_op :apply_rms_prop do |_context, tensor, inputs|
@@ -62,9 +63,9 @@ module TensorStream
62
63
  assign = tensor.inputs[0]
63
64
  assign_ms = tensor.inputs[1]
64
65
  assign_mom = tensor.inputs[2]
65
- assign_ms.value = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho)}, grad, ms)
66
- assign_mom.value = multi_array_op(->(mom_t, g, m) { mom_t * momentum + (g * lr) / Math.sqrt(m + epsilon)}, mom, grad, assign_ms.value)
67
- assign.value = multi_array_op(->(v, m) { v - m }, var, assign_mom.value)
66
+ assign_ms.container = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho)}, grad, ms)
67
+ assign_mom.container = multi_array_op(->(mom_t, g, m) { mom_t * momentum + (g * lr) / Math.sqrt(m + epsilon)}, mom, grad, assign_ms.container)
68
+ assign.container = multi_array_op(->(v, m) { v - m }, var, assign_mom.container)
68
69
  end
69
70
 
70
71
  register_op :apply_centered_rms_prop do |_context, tensor, inputs|
@@ -74,11 +75,11 @@ module TensorStream
74
75
  assign_ms = tensor.inputs[2]
75
76
  assign_mom = tensor.inputs[3]
76
77
 
77
- assign_ms.value = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho) }, grad, ms)
78
- assign_mg.value = multi_array_op(->(g, mg_t) { (g - mg_t) * (1.0 - rho) }, grad, mg)
79
- denom = multi_array_op(->(s, mg_t) { (s - mg_t * mg_t) + epsilon }, assign_ms.value, mg)
80
- assign_mom.value = multi_array_op(->(mom_t, g, d) { mom_t * momentum + (g * lr) / Math.sqrt(d)}, mom, grad, denom)
81
- assign.value = multi_array_op(->(v, m) { v - m }, var, assign_mom.value)
78
+ assign_ms.container = multi_array_op(->(g, m) { m + (g * g - m) * (1.0 - rho) }, grad, ms)
79
+ assign_mg.container = multi_array_op(->(g, mg_t) { (g - mg_t) * (1.0 - rho) }, grad, mg)
80
+ denom = multi_array_op(->(s, mg_t) { (s - mg_t * mg_t) + epsilon }, assign_ms.container, mg)
81
+ assign_mom.container = multi_array_op(->(mom_t, g, d) { mom_t * momentum + (g * lr) / Math.sqrt(d)}, mom, grad, denom)
82
+ assign.container = multi_array_op(->(v, m) { v - m }, var, assign_mom.container)
82
83
  end
83
84
 
84
85
  register_op %i[softmax_cross_entropy_with_logits_v2 softmax_cross_entropy_with_logits] do |_context, tensor, inputs|
@@ -103,7 +104,7 @@ module TensorStream
103
104
  else
104
105
  losses = []
105
106
  backprobs = []
106
- arr = last_dimen_list.zip(labels).each do |list, label|
107
+ last_dimen_list.zip(labels).each do |list, label|
107
108
  loss, prob = func.call(list, label)
108
109
  losses << loss
109
110
  backprobs << prob
@@ -54,9 +54,9 @@ module TensorStream
54
54
  random = _get_randomizer(tensor, seed)
55
55
  generator = -> { r.rand }
56
56
  shape = inputs[0] || tensor.shape.shape
57
- random_values = Array.new(shape.reduce(:*) || 1) {
57
+ random_values = Array.new(shape.reduce(:*) || 1) do
58
58
  generator.call
59
- }
59
+ end
60
60
  mean = random_values.reduce(:+) / random_values.size
61
61
 
62
62
  # standard deviation
@@ -80,20 +80,20 @@ module TensorStream
80
80
  cutoff = 2.0 * Math.exp( 0.5 + (norm_min * (norm_min - sqrt_factor)) / 4.0 ) / (norm_min + sqrt_factor)
81
81
  diff = norm_max - norm_min;
82
82
 
83
- val = random_values.map { |v|
83
+ val = random_values.map do |v|
84
84
  iterations = 0
85
85
  pick = v
86
- while ( (pick > norm_max) || (pick < norm_min) )
86
+ while (pick > norm_max) || (pick < norm_min)
87
87
  pick = generator.call
88
88
  iterations += 1
89
- if iterations > 100
89
+ if iterations > max_iterations
90
90
  pick = v
91
91
  break
92
92
  end
93
93
  end
94
94
 
95
95
  pick
96
- }
96
+ end
97
97
 
98
98
  TensorShape.reshape(val, shape)
99
99
  end
@@ -50,14 +50,10 @@ module TensorStream
50
50
  child_context = execution_context.dup
51
51
  res = if tensor.is_a?(Operation)
52
52
  eval_operation(tensor, child_context)
53
- elsif tensor.is_a?(Variable)
54
- eval_variable(tensor, child_context)
55
- elsif tensor.is_a?(Placeholder)
56
- resolve_placeholder(tensor, child_context)
57
- elsif tensor.is_a?(OutputGroup)
58
- tensor.outputs[0]
53
+ elsif !tensor.is_a?(Tensor)
54
+ tensor
59
55
  else
60
- eval_tensor(tensor, child_context)
56
+ tensor.op
61
57
  end
62
58
  execution_context.deep_merge!(returns: child_context[:returns])
63
59
  res
@@ -88,7 +84,6 @@ module TensorStream
88
84
 
89
85
  def prepare_input(tensor, context, options = {})
90
86
  return nil unless tensor
91
- tensor = resolve_placeholder(tensor)
92
87
  if options[:noop]
93
88
  tensor
94
89
  elsif options[:no_eval]
@@ -98,27 +93,16 @@ module TensorStream
98
93
  end
99
94
  end
100
95
 
101
- def eval_variable(tensor, child_context)
102
- value = tensor.read_value
103
- raise "variable #{tensor.name} not initalized" if value.nil?
104
-
105
- eval_tensor(value, child_context).tap do |val|
106
- child_context[:returns] ||= {}
107
- child_context[:returns][:vars] ||= []
108
- child_context[:returns][:vars] << { name: tensor.name, val: val }
109
- end
110
- end
111
-
112
96
  register_op(:no_op, no_eval: true) do |_context, _tensor, inputs|
113
97
  inputs
114
98
  end
115
99
 
116
- register_op(:const) do |_context, _tensor, inputs|
117
- inputs[0]
100
+ register_op(:const) do |_context, tensor, _inputs|
101
+ tensor.options[:value]
118
102
  end
119
103
 
120
104
  register_op(:cast) do |context, tensor, inputs|
121
- call_op(tensor, inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
105
+ call_op(inputs[0], context, ->(t, _b) { Tensor.cast_dtype(t, tensor.data_type) })
122
106
  end
123
107
 
124
108
  register_op(:sign) do |context, tensor, inputs|
@@ -134,7 +118,7 @@ module TensorStream
134
118
  end
135
119
  }
136
120
 
137
- call_op(tensor, inputs[0], context, func)
121
+ call_op(inputs[0], context, func)
138
122
  end
139
123
 
140
124
  register_op(:logical_and) do |context, tensor, inputs|
@@ -149,14 +133,33 @@ module TensorStream
149
133
  call_vector_op(tensor, :not_equal, inputs[0], inputs[1], context, ->(t, u) { t != u })
150
134
  end
151
135
 
152
- def merge_dynamic_stitch(merged, indexes, data)
153
- indexes.each_with_index do |ind, m|
154
- if ind.is_a?(Array)
155
- merge_dynamic_stitch(merged, ind, data[m])
156
- else
157
- merged[ind] = data[m]
136
+ register_op :placeholder, no_eval: true do |context, tensor, _inputs|
137
+ ph = @context[tensor.name.to_sym].tap do |c|
138
+ raise TensorStream::ValueError, "missing placeholder #{tensor.name}" if c.nil?
139
+
140
+ if tensor.shape.shape
141
+ value_shape = shape_eval(c)
142
+ placeholder_shape = tensor.shape.shape
143
+ placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
144
+ next if p_shape.nil?
145
+ raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
146
+ end
158
147
  end
159
148
  end
149
+ if ph.is_a?(Tensor)
150
+ raise TensorStream::ValueError, "placeholder expects type #{tensor.data_type}, got #{ph.data_type}" if ph.data_type != tensor.data_type
151
+
152
+ global_eval(tensor, ph, context)
153
+ else
154
+ global_eval(tensor, Tensor.cast_dtype(ph, dtype: tensor.data_type), context)
155
+ end
156
+ end
157
+
158
+ register_op :variable_v2, no_eval: true do |_context, tensor, _inputs|
159
+ value = tensor.options[:container].read_value
160
+ raise "variable #{tensor.options[:container].name} not initalized" if value.nil?
161
+
162
+ value
160
163
  end
161
164
 
162
165
  register_op :stop_gradient, no_eval: true do |_context, _tensor, inputs|
@@ -165,38 +168,22 @@ module TensorStream
165
168
 
166
169
  register_op :assign, noop: true do |context, tensor, _inputs|
167
170
  assign = tensor.inputs[0] || tensor
168
- assign.value = global_eval(tensor, tensor.inputs[1], context)
169
- assign.value
171
+ assign.container = global_eval(tensor, tensor.inputs[1], context)
172
+ assign.container
170
173
  end
171
174
 
172
175
  register_op :assign_add, noop: true do |context, tensor, _inputs|
173
- tensor.inputs[0].value = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t + u })
174
- tensor.inputs[0].value
175
- end
176
+ assign = tensor.inputs[0] || tensor
176
177
 
177
- register_op :variable, noop: true do |_context, tensor, _inputs|
178
- tensor.inputs[0].value
178
+ assign.container = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t + u })
179
+ assign.container
179
180
  end
180
181
 
181
182
  register_op :assign_sub, noop: true do |context, tensor, _inputs|
182
- tensor.inputs[0].value = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t - u })
183
- tensor.inputs[0].value
184
- end
185
-
186
- register_op :transpose do |_context, _tensor, inputs|
187
- shape = shape_eval(inputs[0])
188
- rank = get_rank(inputs[0])
189
- perm = inputs[1] || (0...rank).to_a.reverse
190
- if rank == 2 && perm.nil? # use native transpose for general case
191
- inputs[0].transpose
192
- else
193
- arr = inputs[0].flatten
183
+ assign = tensor.inputs[0] || tensor
194
184
 
195
- new_shape = perm.map { |p| shape[p] }
196
- new_arr = Array.new(shape.reduce(:*)) { 0 }
197
- transpose_with_perm(arr, new_arr, shape, new_shape, perm)
198
- TensorShape.reshape(new_arr, new_shape)
199
- end
185
+ assign.container = process_vector_math_op(tensor, tensor.inputs[0], tensor.inputs[1], context, ->(t, u) { t - u })
186
+ assign.container
200
187
  end
201
188
 
202
189
  register_op :less do |context, tensor, inputs|
@@ -277,7 +264,7 @@ module TensorStream
277
264
  raise TensorStream::InvalidArgumentError, "#{message} Invalid argument" if t.nan? || t.infinite?
278
265
  t
279
266
  }
280
- call_op(tensor, inputs[0], context, f)
267
+ call_op(inputs[0], context, f)
281
268
  end
282
269
 
283
270
  def eval_operation(tensor, child_context)
@@ -289,14 +276,13 @@ module TensorStream
289
276
  # assertions to make sure inferred shapes == actual evaluated shapes
290
277
  if tensor.shape.known? && (result.is_a?(Array) || result.is_a?(Float) || result.is_a?(Integer))
291
278
  if shape_eval(result) != tensor.shape.shape
292
-
293
279
  raise "assert error #{tensor.name} #{shape_eval(result)} != #{tensor.shape.shape}"
294
280
  end
295
281
  end
296
282
 
297
283
  if tensor.breakpoint
298
- a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
299
- b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
284
+ a = tensor.inputs[0] if tensor.inputs && tensor.inputs[0]
285
+ b = tensor.inputs[1] if tensor.inputs && tensor.inputs[1]
300
286
  a = complete_eval(a, child_context)
301
287
  b = complete_eval(b, child_context)
302
288
  tensor.breakpoint.call(tensor, a, b, complete_eval(result, child_context))
@@ -318,9 +304,7 @@ module TensorStream
318
304
  rescue TensorStreamError => e
319
305
  raise e, "error #{e.message} while evaluating #{tensor.name} defined at #{tensor.source}"
320
306
  rescue StandardError => e
321
- # a = resolve_placeholder(tensor.inputs[0], child_context) if tensor.inputs && tensor.inputs[0]
322
- # b = resolve_placeholder(tensor.inputs[1], child_context) if tensor.inputs && tensor.inputs[1]
323
- puts e.message
307
+ puts e.message
324
308
  puts e.backtrace.join("\n")
325
309
  # shape_a = a.shape.shape if a
326
310
  # shape_b = b.shape.shape if b
@@ -338,25 +322,6 @@ module TensorStream
338
322
  raise EvaluatorExcecutionException.new(e, tensor), "error #{e.message} while evaluating #{tensor.name} : #{tensor.to_math(true, 1)} defined at #{tensor.source}"
339
323
  end
340
324
 
341
- def eval_tensor(tensor, child_context)
342
- return tensor unless tensor.is_a?(Tensor)
343
-
344
- cache_key = "#{tensor.graph.object_id}_ruby_#{tensor.name}"
345
- return @context[cache_key] if @context.key?(cache_key)
346
- return @context[:_cache][cache_key] if @context[:_cache] && @context[:_cache].key?(tensor.name)
347
-
348
- if tensor.value.is_a?(Array)
349
- tensor.value.collect do |input|
350
- input.is_a?(Tensor) ? run(input, child_context) : input
351
- end
352
- else
353
- tensor.value.is_a?(Tensor) ? run(tensor.value, child_context) : tensor.value
354
- end.tap do |result|
355
- @context[cache_key] = result
356
- @context[:_cache][cache_key] = result if @context[:_cache] && tensor.is_const
357
- end
358
- end
359
-
360
325
  def convert_from_buffer(_tensor, result)
361
326
  result.buffer
362
327
  end
@@ -377,14 +342,7 @@ module TensorStream
377
342
  end
378
343
  end
379
344
 
380
- def reduction(child_context, tensor, func)
381
- val = global_eval(tensor, tensor.inputs[0], child_context)
382
- axis = global_eval(tensor, tensor.inputs[1], child_context)
383
- keep_dims = global_eval(tensor, tensor.options[:keepdims], child_context)
384
- reduce(val, axis, keep_dims, func)
385
- end
386
-
387
- def call_op(op, a, child_context, func)
345
+ def call_op(a, child_context, func)
388
346
  a = complete_eval(a, child_context)
389
347
  process_function_op(a, func)
390
348
  end
@@ -419,7 +377,7 @@ module TensorStream
419
377
  def multi_array_op(func, *args)
420
378
  elem = args[0]
421
379
  if (elem.is_a?(Array))
422
- elem.each_with_index.collect do |item, index|
380
+ elem.each_with_index.collect do |_item, index|
423
381
  indexed_args = args.collect { |a| a[index] }
424
382
  multi_array_op(func, *indexed_args)
425
383
  end
@@ -452,29 +410,6 @@ module TensorStream
452
410
  end
453
411
  end
454
412
 
455
- def resolve_placeholder(placeholder, _execution_context = {})
456
- return nil if placeholder.nil?
457
-
458
- var = if placeholder.is_a?(Placeholder)
459
- @context[placeholder.name.to_sym].tap do |c|
460
- raise TensorStream::ValueError, "missing placeholder #{placeholder.name}" if c.nil?
461
- if placeholder.shape.shape
462
- value_shape = shape_eval(c)
463
- placeholder_shape = placeholder.shape.shape
464
- placeholder_shape.zip(value_shape).each do |p_shape, v_shape|
465
- next if p_shape.nil?
466
- raise TensorStream::ValueError, "placeholder expects #{placeholder_shape}, got #{value_shape}" if p_shape != v_shape
467
- end
468
- end
469
- end
470
- else
471
- placeholder
472
- end
473
-
474
- return var unless placeholder.is_a?(Tensor)
475
- Tensor.cast_dtype(var, placeholder.data_type)
476
- end
477
-
478
413
  # handle 3 tensor math operations
479
414
  def call_3way_vector_op(v_a, v_b, v_c, child_context, op = ->(a, b, c) { a + b + c })
480
415
  return op.call(v_a, v_b, v_c) unless v_a.is_a?(Array)
@@ -1,7 +1,10 @@
1
1
  module TensorStream
2
2
  # A class that defines a TensorStream graph
3
3
  class Graph
4
- attr_accessor :nodes, :node_keys, :collections, :eager_execution, :random_seed, :constants
4
+ include OpHelper
5
+
6
+ attr_accessor :nodes, :collections, :eager_execution, :random_seed, :constants
7
+ attr_reader :node_keys
5
8
 
6
9
  def initialize
7
10
  @eager_execution = false
@@ -30,8 +33,12 @@ module TensorStream
30
33
  end
31
34
 
32
35
  def as_default
36
+ Thread.current[:tensor_stream_current_graph_queue] ||= []
37
+ Thread.current[:tensor_stream_current_graph_queue] << Graph.get_default_graph
38
+
33
39
  Thread.current[:tensor_stream_current_graph] = self
34
40
  yield(self) if block_given?
41
+ Thread.current[:tensor_stream_current_graph] = Thread.current[:tensor_stream_current_graph_queue].pop
35
42
  self
36
43
  end
37
44
 
@@ -77,24 +84,24 @@ module TensorStream
77
84
  @collections[collection_name.to_sym] << val
78
85
  end
79
86
 
80
- def add_node(node)
87
+ def add_node(node, name = nil)
81
88
  raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)
82
89
 
83
- node.name = if @nodes[node.name]
84
- uniqunify(node.name)
85
- else
86
- node.name
87
- end
90
+ if name.nil?
91
+ node.name = if @nodes[node.name]
92
+ uniqunify(node.name)
93
+ else
94
+ node.name
95
+ end
96
+ end
88
97
 
89
98
  node.device = get_device_scope
90
99
  @node_keys << node.name
91
100
  @nodes[node.name] = node
92
101
  @constants[node.name] = node if node.is_const
93
- # puts "adding node"
102
+
94
103
  node.send(:propagate_outputs)
95
104
  node.send(:propagate_consumer, node)
96
- # puts "#{node.name}"
97
- node.value = node.eval if @eager_execution
98
105
  end
99
106
 
100
107
  def node_added?(name)
@@ -107,6 +114,7 @@ module TensorStream
107
114
 
108
115
  def get_tensor_by_name(name)
109
116
  raise TensorStream::KeyError, "#{name} not found" unless @nodes.key?(name)
117
+
110
118
  get_node(name)
111
119
  end
112
120
 
@@ -115,6 +123,39 @@ module TensorStream
115
123
  node
116
124
  end
117
125
 
126
+ def add_op(operation, *args)
127
+ options = if args.last.is_a?(Hash)
128
+ args.pop
129
+ else
130
+ {}
131
+ end
132
+
133
+ inputs = args.map { |i| TensorStream.convert_to_tensor(i) }.map { |i| i ? i.op : nil }
134
+
135
+ new_op = Operation.new(self, inputs: inputs, options: options)
136
+ new_op.source = format_source(caller_locations)
137
+ new_op.operation = operation
138
+ new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
139
+ new_op.rank = new_op.shape.rank
140
+ new_op.name = options[:internal_name] || [get_name_scope, options[:name] || set_operation_name(new_op)].compact.reject(&:empty?).join('/')
141
+ new_op.internal = options[:internal]
142
+
143
+ new_op.data_type = new_op.set_data_type(options[:data_type])
144
+ new_op.is_const = new_op.infer_const
145
+
146
+ new_op.given_name = new_op.name
147
+
148
+ new_op
149
+ end
150
+
151
+ def add_op!(operation, *args)
152
+ add_op(operation, *args).tap { |node| add_node(node) }
153
+ end
154
+
155
+ def set_operation_name(op)
156
+ op.operation.to_s
157
+ end
158
+
118
159
  def add_variable(node, options = {})
119
160
  scope = _variable_scope
120
161
 
@@ -129,13 +170,21 @@ module TensorStream
129
170
 
130
171
  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
131
172
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?
132
- add_node(node)
173
+
174
+ node
175
+ end
176
+
177
+ def add_variable!(node, options = {})
178
+ node = add_variable(node, options)
179
+ op = Graph.get_default_graph.add_op!(:variable_v2, container: node, internal_name: node.name, shape: options[:shape], data_type: options[:data_type])
180
+ node.name = op.name
181
+ op
133
182
  end
134
183
 
135
184
  def control_dependencies(control_inputs = [])
136
185
  Thread.current["ts_graph_#{object_id}"] ||= {}
137
186
  Thread.current["ts_graph_#{object_id}"][:control_dependencies] ||= []
138
- Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Operation.new(:no_op, *control_inputs)
187
+ Thread.current["ts_graph_#{object_id}"][:control_dependencies] << Graph.get_default_graph.add_op!(:no_op, *control_inputs)
139
188
  begin
140
189
  yield
141
190
  ensure