tensor_stream 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,4 +1,5 @@
1
1
  module TensorStream
2
+ # Class that defines all available ops supported by TensorStream
2
3
  module Ops
3
4
  FLOATING_POINT_TYPES = %w[float32 float64].map(&:to_sym)
4
5
  NUMERIC_TYPES = %w[int32 int64 float32 float64].map(&:to_sym)
@@ -7,27 +8,26 @@ module TensorStream
7
8
  op(:argmax, input, nil, axis: axis, name: name, dimension: dimension, data_type: output_type)
8
9
  end
9
10
 
10
- def gradients(ys, xs, grad_ys: nil,
11
- name: 'gradients',
12
- colocate_gradients_with_ops: false,
13
- gate_gradients: false,
14
- aggregation_method: nil,
15
- stop_gradients: nil
16
- )
17
-
18
- gs = xs.collect do |x|
19
- fail "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
11
+ def gradients(input, wrt_xs, grad_ys: nil,
12
+ name: 'gradients',
13
+ colocate_gradients_with_ops: false,
14
+ gate_gradients: false,
15
+ aggregation_method: nil,
16
+ stop_gradients: nil)
17
+
18
+ gs = wrt_xs.collect do |x|
19
+ raise "#{x} passed is not a tensor object" unless x.is_a?(Tensor)
20
20
 
21
21
  stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
22
- gradient_program_name = "grad_#{ys.name}_#{x.name}_#{stops}".to_sym
22
+ gradient_program_name = "grad_#{input.name}_#{x.name}_#{stops}".to_sym
23
23
 
24
- tensor_program = if ys.graph.node_added?(gradient_program_name)
25
- ys.graph.get_node(gradient_program_name)
24
+ tensor_program = if input.graph.node_added?(gradient_program_name)
25
+ input.graph.get_node(gradient_program_name)
26
26
  else
27
- derivative_ops = TensorStream::MathGradients.derivative(ys, x, graph: ys.graph,
28
- stop_gradients: stop_gradients)
27
+ derivative_ops = TensorStream::MathGradients.derivative(input, x, graph: input.graph,
28
+ stop_gradients: stop_gradients)
29
29
  unit_matrix = op(:ones_like, x)
30
- ys.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
30
+ input.graph.add_node!(gradient_program_name, unit_matrix * derivative_ops)
31
31
  end
32
32
  tensor_program
33
33
  end
@@ -35,12 +35,12 @@ module TensorStream
35
35
  end
36
36
 
37
37
  def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
38
- options = {shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name}
38
+ options = { shape: shape, dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
39
39
  op(:random_uniform, nil, nil, options)
40
40
  end
41
41
 
42
42
  def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
43
- options = {shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
43
+ options = { shape: shape, dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
44
44
  op(:random_normal, nil, nil, options)
45
45
  end
46
46
 
@@ -53,7 +53,7 @@ module TensorStream
53
53
  end
54
54
 
55
55
  def shape(input, name: nil, out_type: :int32)
56
- op(:shape, input, nil, name: name)
56
+ op(:shape, input, nil, name: name, out_type: out_type)
57
57
  end
58
58
 
59
59
  def rank(input, name: nil)
@@ -63,35 +63,35 @@ module TensorStream
63
63
  def zeros_initializer(options = {})
64
64
  op(:zeros, nil, nil, options)
65
65
  end
66
-
66
+
67
67
  def slice(input, start, size, name: nil)
68
68
  op(:slice, input, start, size: size, name: name)
69
69
  end
70
-
70
+
71
71
  def zeros(shape, dtype: :float32, name: nil)
72
72
  op(:zeros, shape, nil, data_type: dtype, name: name)
73
73
  end
74
-
74
+
75
75
  def ones(shape, dtype: :float32, name: nil)
76
76
  op(:ones, shape, nil, data_type: dtype, name: name)
77
77
  end
78
-
79
- def less(a, b, name: nil)
80
- op(:less, a, b, name: name)
78
+
79
+ def less(input_a, input_b, name: nil)
80
+ op(:less, input_a, input_b, name: name)
81
81
  end
82
-
83
- def greater(a, b, name: nil)
84
- op(:greater, a, b, name: name)
82
+
83
+ def greater(input_a, input_b, name: nil)
84
+ op(:greater, input_a, input_b, name: name)
85
85
  end
86
86
 
87
- def greater_equal(a, b, name: nil)
88
- op(:greater_equal, a, b, name: name)
87
+ def greater_equal(input_a, input_b, name: nil)
88
+ op(:greater_equal, input_a, input_b, name: name)
89
89
  end
90
90
 
91
- def less_equal(a, b, name: nil)
92
- op(:less_equal, a, b, name: name)
91
+ def less_equal(input_a, input_b, name: nil)
92
+ op(:less_equal, input_a, input_b, name: name)
93
93
  end
94
-
94
+
95
95
  def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
96
96
  op(:reduce_mean, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
97
97
  end
@@ -99,66 +99,66 @@ module TensorStream
99
99
  def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
100
100
  op(:reduce_sum, input_tensor, nil, axis: axis, keepdims: keepdims, name: name)
101
101
  end
102
-
102
+
103
103
  def reduce_prod(input, axis = nil, keepdims: false, name: nil)
104
104
  op(:reduce_prod, input, nil, axis: axis, keepdims: keepdims, name: name)
105
105
  end
106
-
106
+
107
107
  def concat(values, axis, name: 'concat')
108
108
  op(:concat, values, nil, axis: axis, name: name)
109
109
  end
110
-
110
+
111
111
  def reshape(tensor, shape, name: nil)
112
112
  op(:reshape, tensor, shape, name: name)
113
113
  end
114
-
114
+
115
115
  def square(tensor, name: nil)
116
116
  op(:square, tensor, nil, name: name)
117
117
  end
118
-
118
+
119
119
  def cond(pred, true_fn, false_fn, name: nil)
120
120
  op(:cond, true_fn, false_fn, pred: pred, name: name)
121
121
  end
122
122
 
123
- def where(condition, x = nil, y = nil, name: nil)
124
- op(:where, x, y, pred: condition)
123
+ def where(condition, true_t = nil, false_t = nil, name: nil)
124
+ op(:where, true_t, false_t, pred: condition, name: name)
125
125
  end
126
-
127
- def add(a, b, name: nil)
128
- op(:add, a, b, name: name)
126
+
127
+ def add(input_a, input_b, name: nil)
128
+ op(:add, input_a, input_b, name: name)
129
129
  end
130
-
131
- def sub(a, b, name: nil)
132
- op(:sub, a, b, name: name)
130
+
131
+ def sub(input_a, input_b, name: nil)
132
+ op(:sub, input_a, input_b, name: name)
133
133
  end
134
134
 
135
- def max(a, b, name: nil)
136
- check_allowed_types(a, NUMERIC_TYPES)
137
- check_allowed_types(b, NUMERIC_TYPES)
135
+ def max(input_a, input_b, name: nil)
136
+ check_allowed_types(input_a, NUMERIC_TYPES)
137
+ check_allowed_types(input_b, NUMERIC_TYPES)
138
138
 
139
- op(:max, a, b, name: name)
139
+ op(:max, input_a, input_b, name: name)
140
140
  end
141
141
 
142
142
  def cast(input, dtype, name: nil)
143
143
  op(:cast, input, nil, data_type: dtype, name: name)
144
144
  end
145
-
145
+
146
146
  def print(input, data, message: nil, name: nil)
147
147
  op(:print, input, data, message: message, name: name)
148
148
  end
149
-
150
- def negate(a, options = {})
151
- op(:negate, a, nil, options)
149
+
150
+ def negate(input, options = {})
151
+ op(:negate, input, nil, options)
152
152
  end
153
-
154
- def equal(a, b, name: nil)
155
- op(:equal, a, b, name: name)
153
+
154
+ def equal(input_a, input_b, name: nil)
155
+ op(:equal, input_a, input_b, name: name)
156
156
  end
157
157
 
158
- def not_equal(a, b, name: nil)
159
- op(:not_equal, a, b, name: name)
158
+ def not_equal(input_a, input_b, name: nil)
159
+ op(:not_equal, input_a, input_b, name: name)
160
160
  end
161
-
161
+
162
162
  def zeros_like(tensor, dtype: nil, name: nil)
163
163
  op(:zeros_like, tensor, nil, data_type: dtype, name: name)
164
164
  end
@@ -166,75 +166,75 @@ module TensorStream
166
166
  def ones_like(tensor, dtype: nil, name: nil)
167
167
  op(:ones_like, tensor, nil, data_type: dtype, name: name)
168
168
  end
169
-
169
+
170
170
  def identity(input, name: nil)
171
171
  op(:identity, input, nil, name: name)
172
172
  end
173
-
174
- def multiply(a, b, name: nil)
175
- op(:mul, a, b, name: name)
173
+
174
+ def multiply(input_a, input_b, name: nil)
175
+ op(:mul, input_a, input_b, name: name)
176
176
  end
177
-
178
- def pow(a, e, name: nil)
179
- op(:pow, a, e, name: name)
177
+
178
+ def pow(input_a, input_e, name: nil)
179
+ op(:pow, input_a, input_e, name: name)
180
180
  end
181
-
182
- def abs(x, name: nil)
183
- op(:abs, x, nil, name: name)
181
+
182
+ def abs(input, name: nil)
183
+ op(:abs, input, nil, name: name)
184
184
  end
185
-
186
- def sign(x, name: nil)
187
- op(:sign, x, nil, name: name)
185
+
186
+ def sign(input, name: nil)
187
+ op(:sign, input, nil, name: name)
188
188
  end
189
-
190
- def sin(a, options = {})
189
+
190
+ def sin(input, options = {})
191
191
  options[:data_type] ||= :float32
192
- check_allowed_types(a, FLOATING_POINT_TYPES)
193
- op(:sin, a, nil, options)
192
+ check_allowed_types(input, FLOATING_POINT_TYPES)
193
+ op(:sin, input, nil, options)
194
194
  end
195
-
196
- def cos(a, options = {})
195
+
196
+ def cos(input, options = {})
197
197
  options[:data_type] ||= :float32
198
- check_allowed_types(a, FLOATING_POINT_TYPES)
199
- op(:cos, a, nil, options)
198
+ check_allowed_types(input, FLOATING_POINT_TYPES)
199
+ op(:cos, input, nil, options)
200
200
  end
201
-
202
- def tan(a, options = {})
201
+
202
+ def tan(input, options = {})
203
203
  options[:data_type] ||= :float32
204
- check_allowed_types(a, FLOATING_POINT_TYPES)
205
- op(:tan, a, nil, options)
204
+ check_allowed_types(input, FLOATING_POINT_TYPES)
205
+ op(:tan, input, nil, options)
206
206
  end
207
-
208
- def tanh(a, options = {})
207
+
208
+ def tanh(input, options = {})
209
209
  options[:data_type] ||= :float32
210
- check_allowed_types(a, FLOATING_POINT_TYPES)
211
- op(:tanh, a, nil, options)
210
+ check_allowed_types(input, FLOATING_POINT_TYPES)
211
+ op(:tanh, input, nil, options)
212
212
  end
213
-
214
- def sqrt(a, name: nil)
213
+
214
+ def sqrt(input, name: nil)
215
215
  options = {
216
- data_type: a.data_type,
216
+ data_type: input.data_type,
217
217
  name: name
218
218
  }
219
- check_allowed_types(a, FLOATING_POINT_TYPES)
220
- op(:sqrt, a, nil, options)
219
+ check_allowed_types(input, FLOATING_POINT_TYPES)
220
+ op(:sqrt, input, nil, options)
221
221
  end
222
-
223
- def log(input, options= {})
222
+
223
+ def log(input, options = {})
224
224
  options[:data_type] ||= :float32
225
225
  check_allowed_types(input, FLOATING_POINT_TYPES)
226
226
  op(:log, input, nil, options)
227
227
  end
228
-
229
- def exp(a, options = {})
228
+
229
+ def exp(input, options = {})
230
230
  options[:data_type] ||= :float32
231
- check_allowed_types(a, FLOATING_POINT_TYPES)
232
- op(:exp, a, nil, options)
231
+ check_allowed_types(input, FLOATING_POINT_TYPES)
232
+ op(:exp, input, nil, options)
233
233
  end
234
234
 
235
235
  def matmul(input_a, input_b, transpose_a: false,
236
- transpose_b: false,
237
- name: nil)
236
+ transpose_b: false,
237
+ name: nil)
238
238
  op(:matmul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
239
239
  end
240
240
 
@@ -246,4 +246,4 @@ module TensorStream
246
246
  op(:pad, tensor, nil, paddings: paddings, mode: mode, name: name)
247
247
  end
248
248
  end
249
- end
249
+ end
@@ -1,13 +1,16 @@
1
1
  module TensorStream
2
+ # Class that defines a TensorStream placeholder
2
3
  class Placeholder < Tensor
3
4
  def initialize(data_type, rank, shape, options = {})
5
+ @graph = options[:graph] || TensorStream.get_default_graph
6
+
4
7
  @data_type = data_type
5
8
  @rank = rank
6
9
  @shape = TensorShape.new(shape, rank)
7
10
  @value = nil
8
11
  @is_const = false
9
- @source = set_source(caller_locations)
10
- @graph = options[:graph] || TensorStream.get_default_graph
12
+ @source = format_source(caller_locations)
13
+
11
14
  @name = options[:name] || build_name
12
15
  @graph.add_node(self)
13
16
  end
@@ -15,7 +18,7 @@ module TensorStream
15
18
  private
16
19
 
17
20
  def build_name
18
- "Placeholder#{Tensor.placeholder_name}:#{@rank}"
21
+ "Placeholder#{graph.get_placeholder_counter}:#{@rank}"
19
22
  end
20
23
  end
21
- end
24
+ end
@@ -1,5 +1,7 @@
1
1
  module TensorStream
2
+ # TensorStream class that defines a session
2
3
  class Session
4
+ attr_reader :last_session_context
3
5
  def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
4
6
  @evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
5
7
  @thread_pool = thread_pool_class.new
@@ -9,25 +11,21 @@ module TensorStream
9
11
  @session ||= Session.new
10
12
  end
11
13
 
12
- def last_session_context
13
- @last_session_context
14
- end
15
-
16
14
  def run(*args)
17
15
  options = if args.last.is_a?(Hash)
18
- args.pop
19
- else
20
- {}
21
- end
16
+ args.pop
17
+ else
18
+ {}
19
+ end
22
20
  context = {}
23
21
 
24
22
  # scan for placeholders and assign value
25
- options[:feed_dict].keys.each do |k|
26
- if k.is_a?(Placeholder)
27
- context[k.name.to_sym] = options[:feed_dict][k]
23
+ if options[:feed_dict]
24
+ options[:feed_dict].keys.each do |k|
25
+ context[k.name.to_sym] = options[:feed_dict][k] if k.is_a?(Placeholder)
28
26
  end
29
- end if options[:feed_dict]
30
-
27
+ end
28
+
31
29
  evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), thread_pool: @thread_pool)
32
30
 
33
31
  execution_context = {}
@@ -37,16 +35,16 @@ module TensorStream
37
35
  end
38
36
 
39
37
  def dump_internal_ops(tensor)
40
- dump_ops(tensor, ->(k, n) { n.is_a?(Tensor) && n.internal? } )
38
+ dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
41
39
  end
42
40
 
43
41
  def dump_user_ops(tensor)
44
- dump_ops(tensor, ->(k, n) { n.is_a?(Tensor) && !n.internal? } )
42
+ dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
45
43
  end
46
44
 
47
45
  def dump_ops(tensor, selector)
48
46
  graph = tensor.graph
49
- graph.nodes.select { |k,v| selector.call(k, v) }.collect do |k, node|
47
+ graph.nodes.select { |k, v| selector.call(k, v) }.collect do |k, node|
50
48
  next unless @last_session_context[node.name]
51
49
  "#{k} #{node.to_math(true, 1)} = #{@last_session_context[node.name]}"
52
50
  end.compact
@@ -55,12 +53,12 @@ module TensorStream
55
53
  private
56
54
 
57
55
  def camelize(string, uppercase_first_letter = true)
58
- if uppercase_first_letter
59
- string = string.sub(/^[a-z\d]*/) { $&.capitalize }
60
- else
61
- string = string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
62
- end
56
+ string = if uppercase_first_letter
57
+ string.sub(/^[a-z\d]*/) { $&.capitalize }
58
+ else
59
+ string.sub(/^(?:(?=\b|[A-Z_])|\w)/) { $&.downcase }
60
+ end
63
61
  string.gsub(/(?:_|(\/))([a-z\d]*)/) { "#{$1}#{$2.capitalize}" }.gsub('/', '::')
64
62
  end
65
63
  end
66
- end
64
+ end
@@ -1,48 +1,19 @@
1
1
  require 'ostruct'
2
2
 
3
3
  module TensorStream
4
+ # Base class that defines a tensor like interface
4
5
  class Tensor
5
6
  include OpHelper
6
7
 
7
8
  attr_accessor :name, :data_type, :shape, :rank, :native_buffer, :is_const, :value, :breakpoint, :internal, :source, :given_name, :graph
8
9
 
9
- def self.const_name
10
- @const_counter ||= 0
11
-
12
- name = if @const_counter == 0
13
- ""
14
- else
15
- "_#{@const_counter}"
16
- end
17
-
18
- @const_counter += 1
19
-
20
- name
21
- end
22
-
23
- def self.var_name
24
- @var_counter ||= 0
25
- @var_counter += 1
26
-
27
- return "" if @var_counter == 1
28
- return "_#{@var_counter}"
29
- end
30
-
31
- def self.placeholder_name
32
- @placeholder_counter ||= 0
33
- @placeholder_counter += 1
34
-
35
- return "" if @placeholder_counter == 1
36
- return "_#{@placeholder_counter}"
37
- end
38
-
39
10
  def initialize(data_type, rank, shape, options = {})
40
11
  @data_type = data_type
41
12
  @rank = rank
42
13
  @breakpoint = false
43
14
  @shape = TensorShape.new(shape, rank)
44
15
  @value = nil
45
- @source = set_source(caller_locations)
16
+ @source = format_source(caller_locations)
46
17
  @is_const = options[:const] || false
47
18
  @internal = options[:internal]
48
19
  @graph = options[:graph] || TensorStream.get_default_graph
@@ -50,16 +21,14 @@ module TensorStream
50
21
  @given_name = @name
51
22
 
52
23
  if options[:value]
53
- if options[:value].kind_of?(Array)
24
+ if options[:value].is_a?(Array)
54
25
  # check if single dimenstion array is passed
55
- if shape.size >= 2 && options[:value].size > 0 && !options[:value][0].kind_of?(Array)
56
- options[:value] = reshape(options[:value], shape.reverse.dup)
57
- end
26
+ options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
58
27
 
59
28
  @value = options[:value].collect do |v|
60
- v.kind_of?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
29
+ v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
61
30
  end
62
- elsif shape.size > 0
31
+ elsif !shape.empty?
63
32
  @value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
64
33
  else
65
34
  @value = Tensor.cast_dtype(options[:value], @data_type)
@@ -83,66 +52,56 @@ module TensorStream
83
52
  @placeholder_counter = 0
84
53
  end
85
54
 
86
- def build_buffer
87
- @native_buffer = if @data_type == :float32 && @rank == 2
88
- NArray.sfloat(@shape.cols * @shape.rows)
89
- elsif @data_type == :float32 && @rank == 0
90
- NArray.sfloat(1)
91
- else
92
- raise "Invalid data type #{@data_type}"
93
- end
94
- end
95
-
96
- def +(operand)
97
- TensorStream::Operation.new(:add, self, auto_wrap(operand))
55
+ def +(other)
56
+ TensorStream::Operation.new(:add, self, auto_wrap(other))
98
57
  end
99
58
 
100
59
  def [](index)
101
60
  TensorStream::Operation.new(:index, self, index)
102
61
  end
103
62
 
104
- def *(operand)
105
- TensorStream::Operation.new(:mul, self, auto_wrap(operand))
63
+ def *(other)
64
+ TensorStream::Operation.new(:mul, self, auto_wrap(other))
106
65
  end
107
66
 
108
- def **(operand)
109
- TensorStream::Operation.new(:pow, self, auto_wrap(operand))
67
+ def **(other)
68
+ TensorStream::Operation.new(:pow, self, auto_wrap(other))
110
69
  end
111
70
 
112
- def /(operand)
113
- TensorStream::Operation.new(:div, self, auto_wrap(operand))
71
+ def /(other)
72
+ TensorStream::Operation.new(:div, self, auto_wrap(other))
114
73
  end
115
74
 
116
- def -(operand)
117
- TensorStream::Operation.new(:sub, self, auto_wrap(operand))
75
+ def -(other)
76
+ TensorStream::Operation.new(:sub, self, auto_wrap(other))
118
77
  end
119
78
 
120
79
  def -@
121
80
  TensorStream::Operation.new(:negate, self, nil)
122
81
  end
123
82
 
124
- def ==(operand)
125
- op(:equal, self, operand)
83
+ def ==(other)
84
+ op(:equal, self, other)
126
85
  end
127
86
 
128
- def <(operand)
129
- op(:less, self, operand)
87
+ def <(other)
88
+ op(:less, self, other)
130
89
  end
131
90
 
132
- def !=(operand)
133
- op(:not_equal, self, operand)
91
+ def !=(other)
92
+ op(:not_equal, self, other)
134
93
  end
135
94
 
136
- def >(operand)
137
- op(:greater, self, operand)
95
+ def >(other)
96
+ op(:greater, self, other)
138
97
  end
139
98
 
140
- def >=(operand)
141
- op(:greater_equal, self, operand)
99
+ def >=(other)
100
+ op(:greater_equal, self, other)
142
101
  end
143
102
 
144
- def <=(operand)
145
- op(:less_equal, self, operand)
103
+ def <=(other)
104
+ op(:less_equal, self, other)
146
105
  end
147
106
 
148
107
  def collect(&block)
@@ -153,23 +112,6 @@ module TensorStream
153
112
  @name
154
113
  end
155
114
 
156
- # def to_ary
157
- # if rank == 2
158
- # @native_buffer.to_a.each_slice(shape.cols).collect { |slice| slice }
159
- # else
160
- # raise "Invalid rank"
161
- # end
162
- # end
163
-
164
- # open cl methods
165
- def open_cl_buffer(context)
166
- @cl_buffer ||= context.create_buffer(@native_buffer.size * @native_buffer.element_size, :flags => OpenCL::Mem::COPY_HOST_PTR, :host_ptr => @native_buffer)
167
- end
168
-
169
- def sync_cl_buffer(queue, events = [])
170
- queue.enqueue_read_buffer(@cl_buffer, @native_buffer, :event_wait_list => events)
171
- end
172
-
173
115
  def eval(options = {})
174
116
  Session.default_session.run(self, options)
175
117
  end
@@ -201,21 +143,21 @@ module TensorStream
201
143
  end
202
144
 
203
145
  def to_math(name_only = false, max_depth = 99)
204
- return @name if max_depth==0 || name_only || @value.nil?
205
-
206
- if @value.kind_of?(Array)
207
- @value.collect { |v| v.kind_of?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
146
+ return @name if max_depth.zero? || name_only || @value.nil?
147
+
148
+ if @value.is_a?(Array)
149
+ @value.collect { |v| v.is_a?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
208
150
  else
209
151
  is_const ? @value : @name
210
152
  end
211
153
  end
212
154
 
213
155
  def auto_math(tensor, name_only = false, max_depth = 99)
214
- tensor.kind_of?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
156
+ tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth) : tensor
215
157
  end
216
158
 
217
159
  def self.detect_type(value)
218
- dtype = if value.is_a?(String)
160
+ if value.is_a?(String)
219
161
  :string
220
162
  elsif value.is_a?(Float)
221
163
  :float32
@@ -257,7 +199,7 @@ module TensorStream
257
199
  when :unknown
258
200
  val
259
201
  else
260
- fail "unknown data_type #{dtype} passed"
202
+ raise "unknown data_type #{dtype} passed"
261
203
  end
262
204
  end
263
205
 
@@ -268,15 +210,15 @@ module TensorStream
268
210
 
269
211
  protected
270
212
 
271
- def set_source(trace)
272
- trace.reject { |c| c.to_s.include?(File.join("lib","tensor_stream")) }.first
213
+ def format_source(trace)
214
+ trace.reject { |c| c.to_s.include?(File.join('lib', 'tensor_stream')) }.first
273
215
  end
274
216
 
275
217
  def hashify_tensor(tensor)
276
- if tensor.kind_of?(Tensor)
218
+ if tensor.is_a?(Tensor)
277
219
  tensor.to_h
278
- elsif tensor.kind_of?(Array)
279
- tensor.collect do |t| hashify_tensor(t) end
220
+ elsif tensor.is_a?(Array)
221
+ tensor.collect { |t| hashify_tensor(t) }
280
222
  else
281
223
  tensor
282
224
  end
@@ -290,11 +232,11 @@ module TensorStream
290
232
  reshape(s, shape)
291
233
  end
292
234
  else
293
- return arr if shape.size < 1
235
+ return arr if shape.empty?
294
236
  slice = shape.shift
295
237
  return arr if slice.nil?
296
238
 
297
- slice.times.collect do |s|
239
+ Array.new(slice) do
298
240
  reshape(arr, shape.dup)
299
241
  end
300
242
  end
@@ -311,7 +253,7 @@ module TensorStream
311
253
  end
312
254
 
313
255
  def build_name
314
- "#{@is_const ? "Const#{Tensor.const_name}:#{@rank}" : "Variable#{Tensor.var_name}:#{@rank}"}"
256
+ @is_const ? "Const#{graph.get_const_counter}:#{@rank}" : "Variable#{graph.get_var_counter}:#{@rank}"
315
257
  end
316
258
  end
317
259
  end