tensor_stream 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  module TensorStream
2
+ # Class that defines all available ops supported by TensorStream
2
3
  module Ops
3
4
  FLOATING_POINT_TYPES = %w[float32 float64].map(&:to_sym)
4
5
  NUMERIC_TYPES = %w[int32 int64 float32 float64].map(&:to_sym)
@@ -7,27 +8,26 @@ module TensorStream
7
8
  op(:argmax, input, nil, axis: axis, name: name, dimension: dimension, data_type: output_type)
8
9
  end
9
10
 
10
- def gradients(ys, xs, grad_ys: nil,
11
- name: 'gradients',
12
- colocate_gradients_with_ops: false,
13
- gate_gradients: false,
14
- aggregation_method: nil,
15
- stop_gradients: nil
16
- )
17
-
18
- gs = xs.collect do |x|
19
- fail "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
11
+ def gradients(input, wrt_xs, grad_ys: nil,
12
+ name: 'gradients',
13
+ colocate_gradients_with_ops: false,
14
+ gate_gradients: false,
15
+ aggregation_method: nil,
16
+ stop_gradients: nil)
17
+
18
+ gs = wrt_xs.collect do |x|
19
+ raise "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
20
20
 
21
21
  stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
22
- gradient_program_name = "grad_#{ys.name}_#{x.name}_#{stops}".to_sym
22
+ gradient_program_name = "grad_#{input.name}_#{x.name}_#{stops}".to_sym
23
23
 
24
- tensor_program = if ys.graph.node_added?(gradient_program_name)
25
- ys.graph.get_node(gradient_program_name)
24
+ tensor_program = if input.graph.node_added?(gradient_program_name)
25
+ input.graph.get_node(gradient_program_name)
26
26
  else
27
- derivative_ops = TensorStream::MathGradients.derivative(ys, x, graph: ys.graph,
28
- stop_gradients: stop_gradients)
27
+ derivative_ops = TensorStream::MathGradients.derivative(input, x, graph: input.graph,
28
+ stop_gradients: stop_gradients)
29
29
  unit_matrix = op(:ones_like, x)
30
- ys.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
30
+ input.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
31
31
  end
32
32
  tensor_program
33
33
  end
@@ -35,12 +35,12 @@ module TensorStream
35
35
  end
36
36
 
37
37
  def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
38
- options = {shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name}
38
+ options = { shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
39
39
  op(:random_uniform, nil, nil, options)
40
40
  end
41
41
 
42
42
  def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
43
- options = {shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
43
+ options = { shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
44
44
  op(:random_normal, nil, nil, options)
45
45
  end
46
46
 
@@ -53,7 +53,7 @@ module TensorStream
53
53
  end
54
54
 
55
55
  def shape(input, name: nil, out_type: :int32)
56
- op(:shape, input, nil, name: name)
56
+ op(:shape, input, nil, name: name, out_type: out_type)
57
57
  end
58
58
 
59
59
  def rank(input, name: nil)
@@ -63,35 +63,35 @@ module TensorStream
63
63
  def zeros_initializer(options = {})
64
64
  op(:zeros, nil, nil, options)
65
65
  end
66
-
66
+
67
67
  def slice(input, start, size, name: nil)
68
68
  op(:slice, input, start, size: size, name: name)
69
69
  end
70
-
70
+
71
71
  def zeros(shape, dtype: :float32, name: nil)
72
72
  op(:zeros, shape, nil, data_type: dtype, name: name)
73
73
  end
74
-
74
+
75
75
  def ones(shape, dtype: :float32, name: nil)
76
76
  op(:ones, shape, nil, data_type: dtype, name: name)
77
77
  end
78
-
79
- def less(a, b, name: nil)
80
- op(:less, a, b, name: name)
78
+
79
+ def less(input_a, input_b, name: nil)
80
+ op(:less, input_a, input_b, name: name)
81
81
  end
82
-
83
- def greater(a, b, name: nil)
84
- op(:greater, a, b, name: name)
82
+
83
+ def greater(input_a, input_b, name: nil)
84
+ op(:greater, input_a, input_b, name: name)
85
85
  end
86
86
 
87
- def greater_equal(a, b, name: nil)
88
- op(:greater_equal, a, b, name: name)
87
+ def greater_equal(input_a, input_b, name: nil)
88
+ op(:greater_equal, input_a, input_b, name: name)
89
89
  end
90
90
 
91
- def less_equal(a, b, name: nil)
92
- op(:less_equal, a, b, name: name)
91
+ def less_equal(input_a, input_b, name: nil)
92
+ op(:less_equal, input_a, input_b, name: name)
93
93
  end
94
-
94
+
95
95
  def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
96
96
  op(:reduce_mean, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
97
97
  end
@@ -99,66 +99,66 @@ module TensorStream
99
99
  def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
100
100
  op(:reduce_sum, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
101
101
  end
102
-
102
+
103
103
  def reduce_prod(input, axis = nil, keepdims: false, name: nil)
104
104
  op(:reduce_prod, input, nil, axis: axis, keepdims: keepdims, name: name)
105
105
  end
106
-
106
+
107
107
  def concat(values, axis, name: 'concat')
108
108
  op(:concat, values, nil, axis: axis, name: name)
109
109
  end
110
-
110
+
111
111
  def reshape(tensor, shape, name: nil)
112
112
  op(:reshape, tensor, shape, name: name)
113
113
  end
114
-
114
+
115
115
  def square(tensor, name: nil)
116
116
  op(:square, tensor, nil, name: name)
117
117
  end
118
-
118
+
119
119
  def cond(pred, true_fn, false_fn, name: nil)
120
120
  op(:cond, true_fn, false_fn, pred: pred, name: name)
121
121
  end
122
122
 
123
- def where(condition, x = nil, y = nil, name: nil)
124
- op(:where, x, y, pred: condition)
123
+ def where(condition, true_t = nil, false_t = nil, name: nil)
124
+ op(:where, true_t, false_t, pred: condition, name: name)
125
125
  end
126
-
127
- def add(a, b, name: nil)
128
- op(:add, a, b, name: name)
126
+
127
+ def add(input_a, input_b, name: nil)
128
+ op(:add, input_a, input_b, name: name)
129
129
  end
130
-
131
- def sub(a, b, name: nil)
132
- op(:sub, a, b, name: name)
130
+
131
+ def sub(input_a, input_b, name: nil)
132
+ op(:sub, input_a, input_b, name: name)
133
133
  end
134
134
 
135
- def max(a, b, name: nil)
136
- check_allowed_types(a, NUMERIC_TYPES)
137
- check_allowed_types(b, NUMERIC_TYPES)
135
+ def max(input_a, input_b, name: nil)
136
+ check_allowed_types(input_a, NUMERIC_TYPES)
137
+ check_allowed_types(input_b, NUMERIC_TYPES)
138
138
 
139
- op(:max, a, b, name: name)
139
+ op(:max, input_a, input_b, name: name)
140
140
  end
141
141
 
142
142
  def cast(input, dtype, name: nil)
143
143
  op(:cast, input, nil, data_type: dtype, name: name)
144
144
  end
145
-
145
+
146
146
  def print(input, data, message: nil, name: nil)
147
147
  op(:print, input, data, message: message, name: name)
148
148
  end
149
-
150
- def negate(a, options = {})
151
- op(:negate, a, nil, options)
149
+
150
+ def negate(input, options = {})
151
+ op(:negate, input, nil, options)
152
152
  end
153
-
154
- def equal(a, b, name: nil)
155
- op(:equal, a, b, name: name)
153
+
154
+ def equal(input_a, input_b, name: nil)
155
+ op(:equal, input_a, input_b, name: name)
156
156
  end
157
157
 
158
- def not_equal(a, b, name: nil)
159
- op(:not_equal, a, b, name: name)
158
+ def not_equal(input_a, input_b, name: nil)
159
+ op(:not_equal, input_a, input_b, name: name)
160
160
  end
161
-
161
+
162
162
  def zeros_like(tensor, dtype: nil, name: nil)
163
163
  op(:zeros_like, tensor, nil, data_type: dtype, name: name)
164
164
  end
@@ -166,75 +166,75 @@ module TensorStream
166
166
  def ones_like(tensor, dtype: nil, name: nil)
167
167
  op(:ones_like, tensor, nil, data_type: dtype, name: name)
168
168
  end
169
-
169
+
170
170
  def identity(input, name: nil)
171
171
  op(:identity, input, nil, name: name)
172
172
  end
173
-
174
- def multiply(a, b, name: nil)
175
- op(:mul, a, b, name: name)
173
+
174
+ def multiply(input_a, input_b, name: nil)
175
+ op(:mul, input_a, input_b, name: name)
176
176
  end
177
-
178
- def pow(a, e, name: nil)
179
- op(:pow, a, e, name: name)
177
+
178
+ def pow(input_a, input_e, name: nil)
179
+ op(:pow, input_a, input_e, name: name)
180
180
  end
181
-
182
- def abs(x, name: nil)
183
- op(:abs, x, nil, name: name)
181
+
182
+ def abs(input, name: nil)
183
+ op(:abs, input, nil, name: name)
184
184
  end
185
-
186
- def sign(x, name: nil)
187
- op(:sign, x, nil, name: name)
185
+
186
+ def sign(input, name: nil)
187
+ op(:sign, input, nil, name: name)
188
188
  end
189
-
190
- def sin(a, options = {})
189
+
190
+ def sin(input, options = {})
191
191
  options[:data_type] ||= :float32
192
- check_allowed_types(a, FLOATING_POINT_TYPES)
193
- op(:sin, a, nil, options)
192
+ check_allowed_types(input, FLOATING_POINT_TYPES)
193
+ op(:sin, input, nil, options)
194
194
  end
195
-
196
- def cos(a, options = {})
195
+
196
+ def cos(input, options = {})
197
197
  options[:data_type] ||= :float32
198
- check_allowed_types(a, FLOATING_POINT_TYPES)
199
- op(:cos, a, nil, options)
198
+ check_allowed_types(input, FLOATING_POINT_TYPES)
199
+ op(:cos, input, nil, options)
200
200
  end
201
-
202
- def tan(a, options = {})
201
+
202
+ def tan(input, options = {})
203
203
  options[:data_type] ||= :float32
204
- check_allowed_types(a, FLOATING_POINT_TYPES)
205
- op(:tan, a, nil, options)
204
+ check_allowed_types(input, FLOATING_POINT_TYPES)
205
+ op(:tan, input, nil, options)
206
206
  end
207
-
208
- def tanh(a, options = {})
207
+
208
+ def tanh(input, options = {})
209
209
  options[:data_type] ||= :float32
210
- check_allowed_types(a, FLOATING_POINT_TYPES)
211
- op(:tanh, a, nil, options)
210
+ check_allowed_types(input, FLOATING_POINT_TYPES)
211
+ op(:tanh, input, nil, options)
212
212
  end
213
-
214
- def sqrt(a, name: nil)
213
+
214
+ def sqrt(input, name: nil)
215
215
  options = {
216
- data_type: a.data_type,
216
+ data_type: input.data_type,
217
217
  name: name
218
218
  }
219
- check_allowed_types(a, FLOATING_POINT_TYPES)
220
- op(:sqrt, a, nil, options)
219
+ check_allowed_types(input, FLOATING_POINT_TYPES)
220
+ op(:sqrt, input, nil, options)
221
221
  end
222
-
223
- def log(input, options= {})
222
+
223
+ def log(input, options = {})
224
224
  options[:data_type] ||= :float32
225
225
  check_allowed_types(input, FLOATING_POINT_TYPES)
226
226
  op(:log, input, nil, options)
227
227
  end
228
-
229
- def exp(a, options = {})
228
+
229
+ def exp(input, options = {})
230
230
  options[:data_type] ||= :float32
231
- check_allowed_types(a, FLOATING_POINT_TYPES)
232
- op(:exp, a, nil, options)
231
+ check_allowed_types(input, FLOATING_POINT_TYPES)
232
+ op(:exp, input, nil, options)
233
233
  end
234
234
 
235
235
  def matmul(input_a, input_b, transpose_a: false,
236
- transpose_b: false,
237
- name: nil)
236
+ transpose_b: false,
237
+ name: nil)
238
238
  op(:matmul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
239
239
  end
240
240
 
@@ -246,4 +246,4 @@ module TensorStream
246
246
  op(:pad, tensor, nil, paddings: paddings, mode: mode, name: name)
247
247
  end
248
248
  end
249
- end
249
+ end
@@ -1,13 +1,16 @@
1
1
  module TensorStream
2
+ # Class that defines a TensorStream placeholder
2
3
  class Placeholder < Tensor
3
4
  def initialize(data_type, rank, shape, options = {})
5
+ @graph = options[:graph] || TensorStream.get_default_graph
6
+
4
7
  @data_type = data_type
5
8
  @rank = rank
6
9
  @shape = TensorShape.new(shape, rank)
7
10
  @value = nil
8
11
  @is_const = false
9
- @source = set_source(caller_locations)
10
- @graph = options[:graph] || TensorStream.get_default_graph
12
+ @source = format_source(caller_locations)
13
+
11
14
  @name = options[:name] || build_name
12
15
  @graph.add_node(self)
13
16
  end
@@ -15,7 +18,7 @@ module TensorStream
15
18
  private
16
19
 
17
20
  def build_name
18
- "Placeholder#{Tensor.placeholder_name}:#{@rank}"
21
+ "Placeholder#{graph.get_placeholder_counter}:#{@rank}"
19
22
  end
20
23
  end
21
- end
24
+ end
@@ -1,5 +1,7 @@
1
1
  module TensorStream
2
+ # TensorStream class that defines a session
2
3
  class Session
4
+ attr_reader :last_session_context
3
5
  def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
4
6
  @evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
5
7
  @thread_pool = thread_pool_class.new
@@ -9,25 +11,21 @@ module TensorStream
9
11
  @session ||= Session.new
10
12
  end
11
13
 
12
- def last_session_context
13
- @last_session_context
14
- end
15
-
16
14
  def run(*args)
17
15
  options = if args.last.is_a?(Hash)
18
- args.pop
19
- else
20
- {}
21
- end
16
+ args.pop
17
+ else
18
+ {}
19
+ end
22
20
  context = {}
23
21
 
24
22
  # scan for placeholders and assign value
25
- options[:feed_dict].keys.each do |k|
26
- if k.is_a?(Placeholder)
27
- context[k.name.to_sym] = options[:feed_dict][k]
23
+ if options[:feed_dict]
24
+ options[:feed_dict].keys.each do |k|
25
+ context[k.name.to_sym] = options[:feed_dict][k] if k.is_a?(Placeholder)
28
26
  end
29
- end if options[:feed_dict]
30
-
27
+ end
28
+
31
29
  evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), thread_pool: @thread_pool)
32
30
 
33
31
  execution_context = {}
@@ -37,16 +35,16 @@ module TensorStream
37
35
  end
38
36
 
39
37
  def dump_internal_ops(tensor)
40
- dump_ops(tensor, ->(k, n) { n.is_a?(Tensor) && n.internal? } )
38
+ dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
41
39
  end
42
40
 
43
41
  def dump_user_ops(tensor)
44
- dump_ops(tensor, ->(k, n) { n.is_a?(Tensor) && !n.internal? } )
42
+ dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
45
43
  end
46
44
 
47
45
  def dump_ops(tensor, selector)
48
46
  graph = tensor.graph
49
- graph.nodes.select { |k,v| selector.call(k, v) }.collect do |k, node|
47
+ graph.nodes.select { |k, v| selector.call(k, v) }.collect do |k, node|
50
48
  next unless @last_session_context[node.name]
51
49
  "#{k} #{node.to_math(true, 1)} = #{@last_session_context[node.name]}"
52
50
  end.compact
@@ -55,12 +53,12 @@ module TensorStream
55
53
  private
56
54
 
57
55
  def camelize(string, uppercase_first_letter = true)
58
- if uppercase_first_letter
59
- string = string.sub(/^[a-z\d]*/) { $&.capitalize }
60
- else
61
- string = string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
62
- end
56
+ string = if uppercase_first_letter
57
+ string.sub(/^[a-z\d]*/) { $&.capitalize }
58
+ else
59
+ string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
60
+ end
63
61
  string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
64
62
  end
65
63
  end
66
- end
64
+ end
@@ -1,48 +1,19 @@
1
1
  require 'ostruct'
2
2
 
3
3
  module TensorStream
4
+ # Base class that defines a tensor like interface
4
5
  class Tensor
5
6
  include OpHelper
6
7
 
7
8
  attr_accessor :name, :data_type, :shape, :rank, :native_buffer, :is_const, :value, :breakpoint, :internal, :source, :given_name, :graph
8
9
 
9
- def self.const_name
10
- @const_counter ||= 0
11
-
12
- name = if @const_counter == 0
13
- ""
14
- else
15
- "_#{@const_counter}"
16
- end
17
-
18
- @const_counter += 1
19
-
20
- name
21
- end
22
-
23
- def self.var_name
24
- @var_counter ||= 0
25
- @var_counter += 1
26
-
27
- return "" if @var_counter == 1
28
- return "_#{@var_counter}"
29
- end
30
-
31
- def self.placeholder_name
32
- @placeholder_counter ||= 0
33
- @placeholder_counter += 1
34
-
35
- return "" if @placeholder_counter == 1
36
- return "_#{@placeholder_counter}"
37
- end
38
-
39
10
  def initialize(data_type, rank, shape, options = {})
40
11
  @data_type = data_type
41
12
  @rank = rank
42
13
  @breakpoint = false
43
14
  @shape = TensorShape.new(shape, rank)
44
15
  @value = nil
45
- @source = set_source(caller_locations)
16
+ @source = format_source(caller_locations)
46
17
  @is_const = options[:const] || false
47
18
  @internal = options[:internal]
48
19
  @graph = options[:graph] || TensorStream.get_default_graph
@@ -50,16 +21,14 @@ module TensorStream
50
21
  @given_name = @name
51
22
 
52
23
  if options[:value]
53
- if options[:value].kind_of?(Array)
24
+ if options[:value].is_a?(Array)
54
25
  # check if single dimenstion array is passed
55
- if shape.size >= 2 && options[:value].size > 0 && !options[:value][0].kind_of?(Array)
56
- options[:value] = reshape(options[:value], shape.reverse.dup)
57
- end
26
+ options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
58
27
 
59
28
  @value = options[:value].collect do |v|
60
- v.kind_of?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
29
+ v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
61
30
  end
62
- elsif shape.size > 0
31
+ elsif !shape.empty?
63
32
  @value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
64
33
  else
65
34
  @value = Tensor.cast_dtype(options[:value], @data_type)
@@ -83,66 +52,56 @@ module TensorStream
83
52
  @placeholder_counter = 0
84
53
  end
85
54
 
86
- def build_buffer
87
- @native_buffer = if @data_type == :float32 && @rank == 2
88
- NArray.sfloat(@shape.cols * @shape.rows)
89
- elsif @data_type == :float32 && @rank == 0
90
- NArray.sfloat(1)
91
- else
92
- raise "Invalid data type #{@data_type}"
93
- end
94
- end
95
-
96
- def +(operand)
97
- TensorStream::Operation.new(:add, self, auto_wrap(operand))
55
+ def +(other)
56
+ TensorStream::Operation.new(:add, self, auto_wrap(other))
98
57
  end
99
58
 
100
59
  def [](index)
101
60
  TensorStream::Operation.new(:index, self, index)
102
61
  end
103
62
 
104
- def *(operand)
105
- TensorStream::Operation.new(:mul, self, auto_wrap(operand))
63
+ def *(other)
64
+ TensorStream::Operation.new(:mul, self, auto_wrap(other))
106
65
  end
107
66
 
108
- def **(operand)
109
- TensorStream::Operation.new(:pow, self, auto_wrap(operand))
67
+ def **(other)
68
+ TensorStream::Operation.new(:pow, self, auto_wrap(other))
110
69
  end
111
70
 
112
- def /(operand)
113
- TensorStream::Operation.new(:div, self, auto_wrap(operand))
71
+ def /(other)
72
+ TensorStream::Operation.new(:div, self, auto_wrap(other))
114
73
  end
115
74
 
116
- def -(operand)
117
- TensorStream::Operation.new(:sub, self, auto_wrap(operand))
75
+ def -(other)
76
+ TensorStream::Operation.new(:sub, self, auto_wrap(other))
118
77
  end
119
78
 
120
79
  def -@
121
80
  TensorStream::Operation.new(:negate, self, nil)
122
81
  end
123
82
 
124
- def ==(operand)
125
- op(:equal, self, operand)
83
+ def ==(other)
84
+ op(:equal, self, other)
126
85
  end
127
86
 
128
- def <(operand)
129
- op(:less, self, operand)
87
+ def <(other)
88
+ op(:less, self, other)
130
89
  end
131
90
 
132
- def !=(operand)
133
- op(:not_equal, self, operand)
91
+ def !=(other)
92
+ op(:not_equal, self, other)
134
93
  end
135
94
 
136
- def >(operand)
137
- op(:greater, self, operand)
95
+ def >(other)
96
+ op(:greater, self, other)
138
97
  end
139
98
 
140
- def >=(operand)
141
- op(:greater_equal, self, operand)
99
+ def >=(other)
100
+ op(:greater_equal, self, other)
142
101
  end
143
102
 
144
- def <=(operand)
145
- op(:less_equal, self, operand)
103
+ def <=(other)
104
+ op(:less_equal, self, other)
146
105
  end
147
106
 
148
107
  def collect(&block)
@@ -153,23 +112,6 @@ module TensorStream
153
112
  @name
154
113
  end
155
114
 
156
- # def to_ary
157
- # if rank == 2
158
- # @native_buffer.to_a.each_slice(shape.cols).collect { |slice| slice }
159
- # else
160
- # raise "Invalid rank"
161
- # end
162
- # end
163
-
164
- # open cl methods
165
- def open_cl_buffer(context)
166
- @cl_buffer ||= context.create_buffer(@native_buffer.size * @native_buffer.element_size, :flags => OpenCL::Mem::COPY_HOST_PTR, :host_ptr => @native_buffer)
167
- end
168
-
169
- def sync_cl_buffer(queue, events = [])
170
- queue.enqueue_read_buffer(@cl_buffer, @native_buffer, :event_wait_list => events)
171
- end
172
-
173
115
  def eval(options = {})
174
116
  Session.default_session.run(self, options)
175
117
  end
@@ -201,21 +143,21 @@ module TensorStream
201
143
  end
202
144
 
203
145
  def to_math(name_only = false, max_depth = 99)
204
- return @name if max_depth==0 || name_only || @value.nil?
205
-
206
- if @value.kind_of?(Array)
207
- @value.collect { |v| v.kind_of?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
146
+ return @name if max_depth.zero? || name_only || @value.nil?
147
+
148
+ if @value.is_a?(Array)
149
+ @value.collect { |v| v.is_a?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
208
150
  else
209
151
  is_const ? @value : @name
210
152
  end
211
153
  end
212
154
 
213
155
  def auto_math(tensor, name_only = false, max_depth = 99)
214
- tensor.kind_of?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
156
+ tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
215
157
  end
216
158
 
217
159
  def self.detect_type(value)
218
- dtype = if value.is_a?(String)
160
+ if value.is_a?(String)
219
161
  :string
220
162
  elsif value.is_a?(Float)
221
163
  :float32
@@ -257,7 +199,7 @@ module TensorStream
257
199
  when :unknown
258
200
  val
259
201
  else
260
- fail "unknown data_type #{dtype} passed"
202
+ raise "unknown data_type #{dtype} passed"
261
203
  end
262
204
  end
263
205
 
@@ -268,15 +210,15 @@ module TensorStream
268
210
 
269
211
  protected
270
212
 
271
- def set_source(trace)
272
- trace.reject { |c| c.to_s.include?(File.join("lib","tensor_stream")) }.first
213
+ def format_source(trace)
214
+ trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
273
215
  end
274
216
 
275
217
  def hashify_tensor(tensor)
276
- if tensor.kind_of?(Tensor)
218
+ if tensor.is_a?(Tensor)
277
219
  tensor.to_h
278
- elsif tensor.kind_of?(Array)
279
- tensor.collect do |t| hashify_tensor(t) end
220
+ elsif tensor.is_a?(Array)
221
+ tensor.collect { |t| hashify_tensor(t) }
280
222
  else
281
223
  tensor
282
224
  end
@@ -290,11 +232,11 @@ module TensorStream
290
232
  reshape(s, shape)
291
233
  end
292
234
  else
293
- return arr if shape.size < 1
235
+ return arr if shape.empty?
294
236
  slice = shape.shift
295
237
  return arr if slice.nil?
296
238
 
297
- slice.times.collect do |s|
239
+ Array.new(slice) do
298
240
  reshape(arr, shape.dup)
299
241
  end
300
242
  end
@@ -311,7 +253,7 @@ module TensorStream
311
253
  end
312
254
 
313
255
  def build_name
314
- "#{@is_const ? "Const#{Tensor.const_name}:#{@rank}" : "Variable#{Tensor.var_name}:#{@rank}"}"
256
+ @is_const ? "Const#{graph.get_const_counter}:#{@rank}" : "Variable#{graph.get_var_counter}:#{@rank}"
315
257
  end
316
258
  end
317
259
  end