tensor_stream 1.0.0 → 1.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +1 -0
  3. data/.rubocop.yml +1 -0
  4. data/Gemfile +1 -1
  5. data/LICENSE.txt +1 -1
  6. data/README.md +34 -34
  7. data/Rakefile +3 -3
  8. data/USAGE_GUIDE.md +235 -0
  9. data/bin/stubgen +20 -0
  10. data/exe/model_utils +2 -2
  11. data/lib/tensor_stream.rb +45 -44
  12. data/lib/tensor_stream/constant.rb +2 -2
  13. data/lib/tensor_stream/control_flow.rb +1 -1
  14. data/lib/tensor_stream/debugging/debugging.rb +2 -2
  15. data/lib/tensor_stream/dynamic_stitch.rb +2 -2
  16. data/lib/tensor_stream/evaluator/base_evaluator.rb +18 -18
  17. data/lib/tensor_stream/evaluator/buffer.rb +1 -1
  18. data/lib/tensor_stream/evaluator/evaluator.rb +2 -2
  19. data/lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb +41 -41
  20. data/lib/tensor_stream/evaluator/operation_helpers/math_helper.rb +1 -1
  21. data/lib/tensor_stream/evaluator/ruby/array_ops.rb +39 -39
  22. data/lib/tensor_stream/evaluator/ruby/check_ops.rb +2 -2
  23. data/lib/tensor_stream/evaluator/ruby/images_ops.rb +18 -18
  24. data/lib/tensor_stream/evaluator/ruby/math_ops.rb +13 -14
  25. data/lib/tensor_stream/evaluator/ruby/nn_ops.rb +33 -36
  26. data/lib/tensor_stream/evaluator/ruby/random_ops.rb +20 -21
  27. data/lib/tensor_stream/evaluator/ruby_evaluator.rb +36 -49
  28. data/lib/tensor_stream/exceptions.rb +1 -1
  29. data/lib/tensor_stream/generated_stub/ops.rb +691 -0
  30. data/lib/tensor_stream/generated_stub/stub_file.erb +24 -0
  31. data/lib/tensor_stream/graph.rb +18 -18
  32. data/lib/tensor_stream/graph_builder.rb +17 -17
  33. data/lib/tensor_stream/graph_deserializers/protobuf.rb +97 -97
  34. data/lib/tensor_stream/graph_deserializers/yaml_loader.rb +1 -1
  35. data/lib/tensor_stream/graph_keys.rb +3 -3
  36. data/lib/tensor_stream/graph_serializers/graphml.rb +33 -33
  37. data/lib/tensor_stream/graph_serializers/packer.rb +23 -23
  38. data/lib/tensor_stream/graph_serializers/pbtext.rb +38 -42
  39. data/lib/tensor_stream/graph_serializers/serializer.rb +3 -2
  40. data/lib/tensor_stream/graph_serializers/yaml.rb +5 -5
  41. data/lib/tensor_stream/helpers/infer_shape.rb +56 -56
  42. data/lib/tensor_stream/helpers/op_helper.rb +8 -9
  43. data/lib/tensor_stream/helpers/string_helper.rb +15 -15
  44. data/lib/tensor_stream/helpers/tensor_mixins.rb +17 -17
  45. data/lib/tensor_stream/images.rb +1 -1
  46. data/lib/tensor_stream/initializer.rb +1 -1
  47. data/lib/tensor_stream/math_gradients.rb +28 -187
  48. data/lib/tensor_stream/monkey_patches/array.rb +1 -1
  49. data/lib/tensor_stream/monkey_patches/float.rb +1 -1
  50. data/lib/tensor_stream/monkey_patches/integer.rb +1 -1
  51. data/lib/tensor_stream/monkey_patches/op_patch.rb +5 -5
  52. data/lib/tensor_stream/monkey_patches/patch.rb +1 -1
  53. data/lib/tensor_stream/nn/nn_ops.rb +17 -15
  54. data/lib/tensor_stream/op_maker.rb +180 -0
  55. data/lib/tensor_stream/operation.rb +17 -17
  56. data/lib/tensor_stream/ops.rb +95 -384
  57. data/lib/tensor_stream/ops/add.rb +23 -0
  58. data/lib/tensor_stream/ops/argmax.rb +14 -0
  59. data/lib/tensor_stream/ops/argmin.rb +14 -0
  60. data/lib/tensor_stream/ops/case.rb +17 -0
  61. data/lib/tensor_stream/ops/cast.rb +15 -0
  62. data/lib/tensor_stream/ops/ceil.rb +15 -0
  63. data/lib/tensor_stream/ops/const.rb +0 -0
  64. data/lib/tensor_stream/ops/cos.rb +10 -0
  65. data/lib/tensor_stream/ops/div.rb +21 -0
  66. data/lib/tensor_stream/ops/equal.rb +15 -0
  67. data/lib/tensor_stream/ops/expand_dims.rb +17 -0
  68. data/lib/tensor_stream/ops/fill.rb +19 -0
  69. data/lib/tensor_stream/ops/floor.rb +15 -0
  70. data/lib/tensor_stream/ops/floor_div.rb +15 -0
  71. data/lib/tensor_stream/ops/greater.rb +11 -0
  72. data/lib/tensor_stream/ops/greater_equal.rb +11 -0
  73. data/lib/tensor_stream/ops/less_equal.rb +15 -0
  74. data/lib/tensor_stream/ops/log.rb +14 -0
  75. data/lib/tensor_stream/ops/mat_mul.rb +60 -0
  76. data/lib/tensor_stream/ops/max.rb +15 -0
  77. data/lib/tensor_stream/ops/min.rb +15 -0
  78. data/lib/tensor_stream/ops/mod.rb +23 -0
  79. data/lib/tensor_stream/ops/mul.rb +21 -0
  80. data/lib/tensor_stream/ops/negate.rb +14 -0
  81. data/lib/tensor_stream/ops/ones_like.rb +19 -0
  82. data/lib/tensor_stream/ops/pow.rb +25 -0
  83. data/lib/tensor_stream/ops/prod.rb +60 -0
  84. data/lib/tensor_stream/ops/random_uniform.rb +18 -0
  85. data/lib/tensor_stream/ops/range.rb +20 -0
  86. data/lib/tensor_stream/ops/rank.rb +13 -0
  87. data/lib/tensor_stream/ops/reshape.rb +24 -0
  88. data/lib/tensor_stream/ops/round.rb +15 -0
  89. data/lib/tensor_stream/ops/shape.rb +14 -0
  90. data/lib/tensor_stream/ops/sigmoid.rb +10 -0
  91. data/lib/tensor_stream/ops/sign.rb +12 -0
  92. data/lib/tensor_stream/ops/sin.rb +10 -0
  93. data/lib/tensor_stream/ops/size.rb +16 -0
  94. data/lib/tensor_stream/ops/sub.rb +24 -0
  95. data/lib/tensor_stream/ops/sum.rb +27 -0
  96. data/lib/tensor_stream/ops/tan.rb +12 -0
  97. data/lib/tensor_stream/ops/tanh.rb +10 -0
  98. data/lib/tensor_stream/ops/tile.rb +19 -0
  99. data/lib/tensor_stream/ops/zeros.rb +15 -0
  100. data/lib/tensor_stream/placeholder.rb +2 -2
  101. data/lib/tensor_stream/profile/report_tool.rb +3 -3
  102. data/lib/tensor_stream/session.rb +36 -38
  103. data/lib/tensor_stream/tensor.rb +2 -2
  104. data/lib/tensor_stream/tensor_shape.rb +4 -4
  105. data/lib/tensor_stream/train/adadelta_optimizer.rb +8 -8
  106. data/lib/tensor_stream/train/adagrad_optimizer.rb +3 -3
  107. data/lib/tensor_stream/train/adam_optimizer.rb +11 -11
  108. data/lib/tensor_stream/train/learning_rate_decay.rb +2 -2
  109. data/lib/tensor_stream/train/momentum_optimizer.rb +7 -7
  110. data/lib/tensor_stream/train/optimizer.rb +9 -9
  111. data/lib/tensor_stream/train/rmsprop_optimizer.rb +16 -16
  112. data/lib/tensor_stream/train/saver.rb +14 -14
  113. data/lib/tensor_stream/train/slot_creator.rb +6 -6
  114. data/lib/tensor_stream/train/utils.rb +12 -12
  115. data/lib/tensor_stream/trainer.rb +10 -10
  116. data/lib/tensor_stream/types.rb +1 -1
  117. data/lib/tensor_stream/utils.rb +33 -32
  118. data/lib/tensor_stream/utils/freezer.rb +5 -5
  119. data/lib/tensor_stream/variable.rb +5 -5
  120. data/lib/tensor_stream/variable_scope.rb +1 -1
  121. data/lib/tensor_stream/version.rb +1 -1
  122. data/samples/{iris.data → datasets/iris.data} +0 -0
  123. data/samples/jupyter_notebooks/linear_regression.ipynb +463 -0
  124. data/samples/{iris.rb → neural_networks/iris.rb} +21 -23
  125. data/samples/{mnist_data.rb → neural_networks/mnist_data.rb} +8 -8
  126. data/samples/neural_networks/raw_neural_net_sample.rb +112 -0
  127. data/samples/{rnn.rb → neural_networks/rnn.rb} +28 -31
  128. data/samples/{nearest_neighbor.rb → others/nearest_neighbor.rb} +12 -12
  129. data/samples/regression/linear_regression.rb +63 -0
  130. data/samples/{logistic_regression.rb → regression/logistic_regression.rb} +14 -16
  131. data/tensor_stream.gemspec +9 -8
  132. metadata +89 -19
  133. data/data_1.json +0 -4764
  134. data/data_2.json +0 -4764
  135. data/data_actual.json +0 -28
  136. data/data_expected.json +0 -28
  137. data/data_input.json +0 -28
  138. data/samples/error.graphml +0 -2755
  139. data/samples/gradient_sample.graphml +0 -1255
  140. data/samples/linear_regression.rb +0 -69
  141. data/samples/multigpu.rb +0 -73
  142. data/samples/raw_neural_net_sample.rb +0 -112
@@ -48,4 +48,4 @@ class Array
48
48
  index(min)
49
49
  end
50
50
  end
51
- end
51
+ end
@@ -7,4 +7,4 @@ class Float
7
7
  data_type = :"float#{width}"
8
8
  TensorStream.placeholder(data_type, name: name, shape: shape)
9
9
  end
10
- end
10
+ end
@@ -7,4 +7,4 @@ class Integer
7
7
  data_type = :"int#{width}"
8
8
  TensorStream.placeholder(data_type, name: name, shape: shape)
9
9
  end
10
- end
10
+ end
@@ -2,10 +2,10 @@ module TensorStream
2
2
  module OpPatch
3
3
  def self.included(klass)
4
4
  ops = if klass == Array
5
- {:+ => 'add', :- => 'sub', :* => 'mul'}
6
- else
7
- {:+ => 'add', :- => 'sub', :/ => 'div', :% => 'mod', :* => 'mul', :** => 'pow' }
8
- end
5
+ {:+ => "add", :- => "sub", :* => "mul"}
6
+ else
7
+ {:+ => "add", :- => "sub", :/ => "div", :% => "mod", :* => "mul", :** => "pow"}
8
+ end
9
9
 
10
10
  ops.each do |m, name|
11
11
  klass.send(:alias_method, :"_tensor_stream_#{name}_orig", m)
@@ -65,4 +65,4 @@ end
65
65
 
66
66
  Integer.include TensorStream::OpPatch
67
67
  Float.include TensorStream::OpPatch
68
- Array.include TensorStream::OpPatch
68
+ Array.include TensorStream::OpPatch
@@ -10,4 +10,4 @@ module TensorStream
10
10
  TensorStream.convert_to_tensor(self, name: name, dtype: dtype)
11
11
  end
12
12
  end
13
- end
13
+ end
@@ -27,7 +27,7 @@ module TensorStream
27
27
  TensorStream.name_scope(name, "dropout", values: [x]) do
28
28
  x = TensorStream.convert_to_tensor(x, name: "x")
29
29
  raise TensorStream::ValueError, "x has to be a floating point tensor since it's going to be scaled. Got a #{x.data_type} tensor instead." unless fp_type?(x.data_type)
30
- raise TensorStream::ValueError, "keep_prob must be a scalar tensor or a float in the range (0, 1], got #{keep_prob}" if keep_prob.is_a?(Float) && !(0 < keep_prob && keep_prob <= 1)
30
+ raise TensorStream::ValueError, "keep_prob must be a scalar tensor or a float in the range (0, 1], got #{keep_prob}" if keep_prob.is_a?(Float) && !(keep_prob > 0 && keep_prob <= 1)
31
31
 
32
32
  return x if keep_prob.is_a?(Float) && keep_prob.to_f == 1.0
33
33
 
@@ -35,10 +35,10 @@ module TensorStream
35
35
  return x if keep_prob.value == 1.0
36
36
 
37
37
  noise_shape = if noise_shape.nil?
38
- TensorStream.shape(x)
39
- else
40
- noise_shape
41
- end
38
+ TensorStream.shape(x)
39
+ else
40
+ noise_shape
41
+ end
42
42
 
43
43
  random_tensor = keep_prob
44
44
  random_tensor += TensorStream.random_uniform(noise_shape, seed: seed, dtype: x.dtype)
@@ -57,10 +57,10 @@ module TensorStream
57
57
  end
58
58
 
59
59
  def softmax_cross_entropy_with_logits_v2(labels: nil, logits: nil, name: nil)
60
- TensorStream.name_scope(name, default: 'softmax_cross_entropy_with_logits', values: [logits, labels]) do
60
+ TensorStream.name_scope(name, default: "softmax_cross_entropy_with_logits", values: [logits, labels]) do
61
61
  ts = TensorStream
62
- logits = ts.convert_to_tensor(logits, name: 'logits')
63
- labels = ts.convert_to_tensor(labels, name: 'labels')
62
+ logits = ts.convert_to_tensor(logits, name: "logits")
63
+ labels = ts.convert_to_tensor(labels, name: "labels")
64
64
  labels = ts.cast(labels, logits.dtype)
65
65
 
66
66
  output = _op(:softmax_cross_entropy_with_logits_v2, logits, labels)
@@ -80,11 +80,13 @@ module TensorStream
80
80
  static_shapes_fully_defined = labels_static_shape.known? && logits.shape.known?
81
81
 
82
82
  raise TensorStream::ValueError, "Logits cannot be scalars - received shape #{logits.shape.shape}." if logits.shape.known? && logits.shape.scalar?
83
- raise TensorStream::ValueError, "Rank mismatch: Rank of labels (received #{labels_static_shape.ndims}) " +
84
- "should equal rank of logits minus 1 (received #{logits.shape.ndims})." if logits.shape.known? && (labels_static_shape.known? && labels_static_shape.ndims != logits.shape.ndims - 1)
83
+ if logits.shape.known? && (labels_static_shape.known? && labels_static_shape.ndims != logits.shape.ndims - 1)
84
+ raise TensorStream::ValueError, "Rank mismatch: Rank of labels (received #{labels_static_shape.ndims}) " \
85
+ "should equal rank of logits minus 1 (received #{logits.shape.ndims})."
86
+ end
85
87
  if logits.shape.ndims == 2
86
88
  cost = _op(:sparse_softmax_cross_entropy_with_logits,
87
- precise_logits, labels, name: name)
89
+ precise_logits, labels, name: name)
88
90
  if logits.data_type == :float16
89
91
  return tf.cast(cost[0], :float16)
90
92
  else
@@ -118,17 +120,17 @@ module TensorStream
118
120
  end
119
121
 
120
122
  def sigmoid_cross_entropy_with_logits(labels: nil, logits: nil, name: nil)
121
- TensorStream.name_scope(name, default: 'logistic_loss', values: [logits, labels]) do |_name|
123
+ TensorStream.name_scope(name, default: "logistic_loss", values: [logits, labels]) do |_name|
122
124
  tf = TensorStream
123
- logits = tf.convert_to_tensor(logits, name: 'logits')
124
- labels = tf.convert_to_tensor(labels, name: 'labels')
125
+ logits = tf.convert_to_tensor(logits, name: "logits")
126
+ labels = tf.convert_to_tensor(labels, name: "labels")
125
127
  zeros = tf.zeros_like(logits, dtype: logits.dtype)
126
128
  cond = (logits >= zeros)
127
129
  relu_logits = tf.where(cond, logits, zeros)
128
130
  neg_abs_logits = tf.where(cond, -logits, logits)
129
131
 
130
132
  tf.add(relu_logits - logits * labels,
131
- tf.log1p(tf.exp(neg_abs_logits)), name: name)
133
+ tf.log1p(tf.exp(neg_abs_logits)), name: name)
132
134
  end
133
135
  end
134
136
 
@@ -0,0 +1,180 @@
1
module TensorStream
  # Declarative builder describing a single TensorStream operation.
  #
  # Op definition files (loaded by .scan) call .define_operation with a block
  # that configures an OpMaker instance: its parameters, keyword options,
  # gradient, shape inference and documentation. The class keeps a registry of
  # op code => OpMaker which is used to look up gradients and shape inference,
  # and (via #generate_body and friends) to render method bodies for the
  # generated op stub.
  class OpMaker
    attr_reader :operation, :description, :parameters,
      :options, :gradient, :check_types,
      :supports_broadcast, :data_type_coercion,
      :aliases, :custom, :infer_type_proc, :exclude

    # @param op [Symbol] the operation code this maker describes
    def initialize(op)
      @operation = op
      @parameters = []
      @options = {}
      @gradient = nil
      @supports_broadcast = false
      @data_type_coercion = false
      @check_types = false
      @exclude = false
      @description = []
      @aliases = []
      @custom = []
      # Default shape inference: nil when there is no first input, the first
      # input's shape for single-input ops, and TensorShape.infer_shape of both
      # input shapes for two-input ops. Any other arity yields nil.
      @infer_type_proc = lambda { |tensor|
        next nil if tensor.inputs[0].nil?
        next tensor.inputs[0].shape.shape if tensor.inputs.size == 1

        TensorStream::TensorShape.infer_shape(tensor.inputs[0].shape.shape, tensor.inputs[1].shape.shape) if tensor.inputs.size == 2 && tensor.inputs[0] && tensor.inputs[1]
      }
    end

    # Registers additional names for this op.
    # @param aliases [Array] names appended to #aliases
    def other_names(aliases)
      @aliases += aliases
    end

    # Appends a raw line of Ruby that #generate_body emits before the _op call.
    def add_custom(custom_code)
      @custom << custom_code
    end

    # Loads every op definition file in the ops/ directory that sits next to
    # this file. The glob is anchored at __dir__ rather than the current
    # working directory, so it also works when the gem is used from another
    # directory.
    def self.scan
      Dir[File.join(__dir__, "ops", "*.rb")].each { |file| load file }
    end

    # Creates an OpMaker for +op_code+, yields it for configuration and
    # registers it (replacing any earlier definition with the same code).
    # @return [TensorStream::OpMaker] the registered maker
    def self.define_operation(op_code, &block)
      op_maker = TensorStream::OpMaker.new(op_code.to_sym)
      block.call(op_maker)
      registry[op_code.to_sym] = op_maker
    end

    # Calls the gradient definition registered for node's operation.
    # The gradient block is instance_exec'd on +context_caller+ with
    # (grad, node, node.inputs).
    # @raise [RuntimeError] when the op is unknown or has no gradient defined
    def self.gradient_op(context_caller, node, grad)
      op = registry[node.operation]
      raise "No derivative op defined for #{node.operation}" if op.nil? || op.gradient.nil?

      context_caller.instance_exec(grad, node, node.inputs, &op.gradient)
    end

    # Runs the registered shape-inference proc for tensor's operation on
    # +context_caller+.
    # @return the inferred shape, or nil when the op is not registered
    def self.infer_shape(context_caller, tensor)
      op = registry[tensor.operation]
      return nil unless op

      context_caller.instance_exec(tensor, &op.infer_type_proc)
    end

    # Yields every registered, non-excluded op sorted by operation name.
    def self.each_op(&block)
      registry.values.sort_by(&:operation).reject(&:exclude).each(&block)
    end

    # Lazily created registry so lookups are safe before any op is defined.
    def self.registry
      @ops ||= {}
    end
    private_class_method :registry

    # Adds a free-text documentation entry.
    def what_it_does(description)
      @description << description
    end

    # Adds a documentation entry rendered as code (wrapped in <tt>).
    def what_it_does_code(description)
      @description << "<tt>#{description}</tt>"
    end

    # Marks the op so that .each_op skips it.
    def exclude!
      @exclude = true
    end

    # All description entries split into individual lines.
    def description_lines
      description.flat_map { |line| line.split("\n") }
    end

    # Renders the method body for the generated stub. In order, it emits
    # type validation, optional data-type coercion, the optional same-dtype
    # check, custom code, and finally the _op call.
    # @return [String] newline-joined, prefixed source lines
    def generate_body
      body = []
      parameters.select { |p| p[:validate] }.each do |p|
        body << "check_allowed_types(#{p[:name]}, TensorStream::Ops::#{p[:validate]})"
      end
      if data_type_coercion?
        body << "#{expand_params(false).join(', ')} = apply_data_type_coercion(#{expand_params(false).join(', ')})"
      end
      if check_types?
        body << "check_data_types(#{expand_params(false).join(', ')})"
      end
      custom.each do |c|
        body << c
      end
      body << "_op(:#{operation}, #{(expand_params(false) + options_call).join(', ')})"
      body.map { |line| " #{line}"}.join("\n")
    end

    ##
    # Adds a positional parameter to the op.
    #
    # +default_value+ nil means "no default"; pass :nil for a literal nil
    # default. +validate+ names a TensorStream::Ops constant of allowed types.
    def parameter(name, description, default_value = nil, validate: nil)
      @parameters << {
        name: name.to_s,
        description: description,
        default_value: default_value,
        validate: validate,
      }
    end

    # Adds a keyword option. +options[:alias]+, when set, is the keyword name
    # used when forwarding the option to _op.
    def option(name, description, default_value = nil, options = {})
      @options[name] = {description: description, default_value: default_value, options: options}
    end

    def define_gradient(&block)
      @gradient = block
    end

    def define_shape(&block)
      @infer_type_proc = block
    end

    # Parameter list for a method signature or call.
    # Defaults are printed whenever one is set. A false default is included
    # too: it is a real default and must not turn the parameter into a
    # required one.
    def expand_params(print_defaults)
      @parameters.map { |param|
        if print_defaults && !param[:default_value].nil?
          "#{param[:name]} = #{default_with_nil(param[:default_value])}"
        else
          param[:name].to_s
        end
      }
    end

    def parameters_must_have_same_data_type!
      @check_types = true
    end

    def apply_data_type_coercion!
      @data_type_coercion = true
    end

    # @raise [RuntimeError] when the op has fewer than two parameters
    def supports_broadcasting!
      raise "Ops with parameters < 2 cannot support broadcasting" if @parameters.size < 2

      @supports_broadcast = true
    end

    def supports_broadcasting?
      @supports_broadcast
    end

    def data_type_coercion?
      @data_type_coercion
    end

    def check_types?
      @check_types
    end

    # Keyword list for a method signature. Options without a default become
    # required keywords ("name:"). A false default is printed like any other.
    def expand_options(print_defaults)
      @options.map { |k, v|
        if print_defaults && !v[:default_value].nil?
          "#{k}: #{default_with_nil(v[:default_value])}"
        else
          "#{k}:"
        end
      }
    end

    # Keyword arguments forwarding each option to _op, honoring :alias.
    def options_call
      @options.map { |k, v|
        aliased = v.dig(:options, :alias)
        aliased ? "#{aliased}: #{k}" : "#{k}: #{k}"
      }
    end

    # Maps the :nil sentinel to the source text "nil"; other values pass through.
    def default_with_nil(v)
      v == :nil ? "nil" : v
    end
  end
end

TensorStream::OpMaker.scan
@@ -1,4 +1,4 @@
1
- require 'tensor_stream/helpers/infer_shape'
1
+ require "tensor_stream/helpers/infer_shape"
2
2
  module TensorStream
3
3
  # TensorStream class that defines an operation
4
4
  class Operation < Tensor
@@ -17,7 +17,7 @@ module TensorStream
17
17
  end
18
18
 
19
19
  def inspect
20
- "Op(#{operation} name: #{name} shape: #{@shape || '?'} data_type: #{data_type})"
20
+ "Op(#{operation} name: #{name} shape: #{@shape || "?"} data_type: #{data_type})"
21
21
  end
22
22
 
23
23
  def to_s
@@ -30,7 +30,7 @@ module TensorStream
30
30
  name: name.to_s,
31
31
  data_type: @data_type,
32
32
  inputs: @inputs.map(&:name),
33
- attrs: serialize_options
33
+ attrs: serialize_options,
34
34
  }
35
35
  end
36
36
 
@@ -145,11 +145,11 @@ module TensorStream
145
145
  when :slice
146
146
  "#{sub_input}[#{sub_input2}]"
147
147
  when :assign_sub
148
- "(#{inputs[0] ? inputs[0].name : 'self'} -= #{auto_math(inputs[1], name_only, 1)})"
148
+ "(#{inputs[0] ? inputs[0].name : "self"} -= #{auto_math(inputs[1], name_only, 1)})"
149
149
  when :assign_add
150
- "(#{inputs[0] ? inputs[0].name : 'self'} += #{auto_math(inputs[1], name_only, 1)})"
150
+ "(#{inputs[0] ? inputs[0].name : "self"} += #{auto_math(inputs[1], name_only, 1)})"
151
151
  when :assign
152
- "(#{inputs[0] ? inputs[0].name : 'self'} = #{auto_math(inputs[1], name_only, 1)})"
152
+ "(#{inputs[0] ? inputs[0].name : "self"} = #{auto_math(inputs[1], name_only, 1)})"
153
153
  when :sin, :cos, :tanh
154
154
  "#{operation}(#{sub_input})"
155
155
  when :add
@@ -193,7 +193,7 @@ module TensorStream
193
193
  when :ones_like
194
194
  "ones_like(#{sub_input})"
195
195
  when :flow_group
196
- "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, cur_depth) }.join(',')})"
196
+ "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, cur_depth) }.join(",")})"
197
197
  when :zeros
198
198
  "zeros(#{sub_input})"
199
199
  when :reshape
@@ -243,8 +243,8 @@ module TensorStream
243
243
  else
244
244
  "#{operation}(#{sub_input})" if sub_input
245
245
  "#{operation}(#{sub_input}, #{sub_input2})" if sub_input && sub_input2
246
- end
247
- ["\n", Array.new(cur_depth + 1) { ' ' }, out].flatten.join
246
+ end
247
+ ["\n", Array.new(cur_depth + 1) { " " }, out].flatten.join
248
248
  end
249
249
 
250
250
  def run
@@ -260,21 +260,21 @@ module TensorStream
260
260
  def serialize_options
261
261
  excludes = %i[internal_name source]
262
262
 
263
- @options.reject { |k, v| excludes.include?(k) || v.nil? }.map do |k, v|
263
+ @options.reject { |k, v| excludes.include?(k) || v.nil? }.map { |k, v|
264
264
  v = case v.class.to_s
265
- when 'TensorStream::TensorShape'
265
+ when "TensorStream::TensorShape"
266
266
  v.shape
267
- when 'Array'
267
+ when "Array"
268
268
  v
269
- when 'String', 'Integer', 'Float', 'Symbol', 'FalseClass', "TrueClass"
269
+ when "String", "Integer", "Float", "Symbol", "FalseClass", "TrueClass"
270
270
  v
271
- when 'TensorStream::Variable'
272
- { name: v.name, options: v.options, shape: v.shape.shape.dup }
271
+ when "TensorStream::Variable"
272
+ {name: v.name, options: v.options, shape: v.shape.shape.dup}
273
273
  else
274
274
  raise "unknown type #{v.class}"
275
- end
275
+ end
276
276
  [k.to_sym, v]
277
- end.to_h
277
+ }.to_h
278
278
  end
279
279
 
280
280
  def add_consumer(consumer)
@@ -1,6 +1,12 @@
1
+
1
2
  module TensorStream
2
3
  # Class that defines all available ops supported by TensorStream
3
4
  module Ops
5
+ if File.exists?(File.join(__dir__, 'generated_stub', 'ops.rb'))
6
+ require 'tensor_stream/generated_stub/ops'
7
+ include TensorStream::OpStub
8
+ end
9
+
4
10
  class OutputHolder
5
11
  def initialize(op)
6
12
  @op = op
@@ -10,30 +16,6 @@ module TensorStream
10
16
  INTEGER_TYPES = %i[uint8 int32 int int16 uint16 int64 uint32 uint64].freeze
11
17
  NUMERIC_TYPES = FLOATING_POINT_TYPES + INTEGER_TYPES
12
18
 
13
- ##
14
- # Returns the index with the largest value across axes of a tensor.
15
- #
16
- # Argmuments
17
- #
18
- # +input+ A Tensor. Must be one of the following types: float32, float64, int32, int16
19
- # +axis+ Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0
20
- # +output_type+ Output data type defaults to int32
21
- def argmax(input, axis = nil, name: nil, dimension: nil, output_type: :int32)
22
- _op(:argmax, input, axis, name: name, dimension: dimension, data_type: output_type)
23
- end
24
-
25
- ##
26
- # Returns the index with the smallest value across axes of a tensor.
27
- #
28
- # Argmuments
29
- #
30
- # +input+ A Tensor. Must be one of the following types: float32, float64, int32, int16
31
- # +axis+ Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0
32
- # +output_type+ Output data type defaults to int32
33
- def argmin(input, axis = nil, name: nil, dimension: nil, output_type: :int32)
34
- _op(:argmin, input, axis, name: name, dimension: dimension, data_type: output_type)
35
- end
36
-
37
19
  ##
38
20
  # Assert the condition x == y holds element-wise.
39
21
  #
@@ -57,46 +39,39 @@ module TensorStream
57
39
  # +tensor_ys+ : A Tensor or list of tensors to be differentiated.
58
40
  # +wrt_xs+ : A Tensor or list of tensors to be used for differentiation.
59
41
  # +stop_gradients+ : Optional. A Tensor or list of tensors not to differentiate through
60
- def gradients(tensor_ys, wrt_xs, name: 'gradients', stop_gradients: nil)
42
+ def gradients(tensor_ys, wrt_xs, name: "gradients", stop_gradients: nil)
61
43
  tensor_ys = tensor_ys.op
62
- gs = wrt_xs.map(&:op).collect do |x|
63
- stops = stop_gradients ? stop_gradients.map(&:name).join('_') : ''
44
+ gs = wrt_xs.map(&:op).collect { |x|
45
+ stops = stop_gradients ? stop_gradients.map(&:name).join("_") : ""
64
46
  gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
65
47
  tensor_graph = tensor_ys.graph
66
48
 
67
49
  tensor_program = if tensor_graph.node_added?(gradient_program_name)
68
- tensor_graph.get_node(gradient_program_name)
69
- else
70
- tensor_graph.name_scope("gradient_wrt_#{x.name}") do
71
- derivative_ops = TensorStream::MathGradients.derivative(tensor_ys, x, graph: tensor_graph,
72
- stop_gradients: stop_gradients)
73
- tensor_graph.add_node!(gradient_program_name, derivative_ops)
74
- end
75
- end
50
+ tensor_graph.get_node(gradient_program_name)
51
+ else
52
+ tensor_graph.name_scope("gradient_wrt_#{x.name}") do
53
+ derivative_ops = TensorStream::MathGradients.derivative(tensor_ys, x, graph: tensor_graph,
54
+ stop_gradients: stop_gradients)
55
+ tensor_graph.add_node!(gradient_program_name, derivative_ops)
56
+ end
57
+ end
76
58
  tensor_program
77
- end
59
+ }
78
60
 
79
61
  gs
80
62
  end
81
63
 
82
- ##
83
- # Outputs random values from a uniform distribution.
84
- def random_uniform(shape, dtype: :float32, minval: 0, maxval: 1, seed: nil, name: nil)
85
- options = { dtype: dtype, minval: minval, maxval: maxval, seed: seed, name: name }
86
- _op(:random_uniform, shape, options)
87
- end
88
-
89
64
  ##
90
65
  # Outputs random values from a normal distribution.
91
66
  def random_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
92
- options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
67
+ options = {dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
93
68
  _op(:random_standard_normal, shape, options)
94
69
  end
95
70
 
96
71
  ##
97
72
  # Outputs random values from a truncated normal distribution.
98
73
  def truncated_normal(shape, dtype: :float32, mean: 0.0, stddev: 1.0, seed: nil, name: nil)
99
- options = { dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name }
74
+ options = {dtype: dtype, mean: mean, stddev: stddev, seed: seed, name: name}
100
75
  _op(:truncated_normal, shape, options)
101
76
  end
102
77
 
@@ -118,14 +93,6 @@ module TensorStream
118
93
  _op(:expand_dims, input, axis, name: name)
119
94
  end
120
95
 
121
- ##
122
- # This operation returns a 1-D integer tensor representing the shape of input
123
- def shape(input, name: nil, out_type: :int32)
124
- return constant(shape_eval(input, out_type), dtype: out_type, name: "Shape/#{name}") if input.is_a?(Array) && !input[0].is_a?(Tensor)
125
- return constant(input.shape.shape, dtype: out_type, name: "Shape/#{input.name}_c") if shape_full_specified(input)
126
-
127
- _op(:shape, input, name: name, out_type: out_type)
128
- end
129
96
 
130
97
  def shape_n(inputs, name: nil, out_type: :int32)
131
98
  shapes_known = true
@@ -146,29 +113,6 @@ module TensorStream
146
113
  end
147
114
  end
148
115
 
149
- ##
150
- # Constructs a tensor by tiling a given tensor.
151
- #
152
- # This operation creates a new tensor by replicating input multiples times.
153
- # The output tensor's i'th dimension has input.dims(i) * multiples[i] elements,
154
- # and the values of input are replicated multiples[i] times along the 'i'th dimension. For example, tiling [a b c d] by [2] produces [a b c d a b c d].
155
- def tile(input, multiples, name: nil)
156
- _op(:tile, input, multiples, name: name)
157
- end
158
-
159
- ##
160
- # Returns the rank of a tensor.
161
- def rank(input, name: nil)
162
- input = convert_to_tensor(input)
163
- return cons(input.shape.ndims) if input.shape.known?
164
-
165
- _op(:rank, input, name: name)
166
- end
167
-
168
- def constant_initializer(value, dtype: nil, verify_shape: false)
169
- TensorStream::Initializer.new(-> { _op(:fill, nil, convert_to_tensor(value, dtype: dtype)) })
170
- end
171
-
172
116
  ##
173
117
  # initializer that generates tensors initialized to 0.
174
118
  #
@@ -183,6 +127,10 @@ module TensorStream
183
127
  TensorStream::Initializer.new(-> { _op(:ones, data_type: dtype) })
184
128
  end
185
129
 
130
+ def constant_initializer(value, dtype: nil, verify_shape: false)
131
+ TensorStream::Initializer.new(-> { _op(:fill, nil, convert_to_tensor(value, dtype: dtype)) })
132
+ end
133
+
186
134
  ##
187
135
  # The Glorot uniform initializer, also called Xavier uniform initializer.
188
136
  #
@@ -209,119 +157,45 @@ module TensorStream
209
157
  _op(:slice, input, start, size: size, name: name)
210
158
  end
211
159
 
212
- ##
213
- # Creates a tensor with all elements set to zero
214
- def zeros(shape, dtype: :float32, name: nil)
215
- _op(:zeros, shape, data_type: dtype, name: name)
216
- end
217
-
218
160
  ##
219
161
  # Creates a tensor with all elements set to 1.
220
162
  def ones(shape, dtype: :float32, name: nil)
221
163
  _op(:ones, shape, data_type: dtype, name: name)
222
164
  end
223
165
 
224
- ##
225
- # Returns element-wise largest integer not greater than x.
226
- def floor(input, name: nil)
227
- check_allowed_types(input, FLOATING_POINT_TYPES)
228
- _op(:floor, input, name: name)
229
- end
230
-
231
- ##
232
- # Returns element-wise smallest integer in not less than x
233
- def ceil(input, name: nil)
234
- check_allowed_types(input, FLOATING_POINT_TYPES)
235
- _op(:ceil, input, name: name)
236
- end
237
-
238
166
  ##
239
167
  # Returns the truth value of (x < y) element-wise.
240
168
  # This operation supports broadcasting
241
169
  def less(input_a, input_b, name: nil)
242
- input_a, input_b = check_data_types(input_a, input_b)
170
+ check_data_types(input_a, input_b)
243
171
  _op(:less, input_a, input_b, name: name)
244
172
  end
245
173
 
246
174
  ##
247
175
  # Returns the truth value of x AND y element-wise.
248
176
  def logical_and(input_a, input_b, name: nil)
249
- input_a, input_b = check_data_types(input_a, input_b)
177
+ check_data_types(input_a, input_b)
250
178
  _op(:logical_and, input_a, input_b, name: name)
251
179
  end
252
180
 
253
- ##
254
- # Returns the truth value of (x > y) element-wise.
255
- # This operation supports broadcasting
256
- def greater(input_a, input_b, name: nil)
257
- input_a, input_b = check_data_types(input_a, input_b)
258
- _op(:greater, input_a, input_b, name: name)
259
- end
260
-
261
- ##
262
- # Returns the truth value of (x >= y) element-wise.
263
- #
264
- # This operation supports broadcasting
265
- def greater_equal(input_a, input_b, name: nil)
266
- input_a, input_b = check_data_types(input_a, input_b)
267
- _op(:greater_equal, input_a, input_b, name: name)
268
- end
269
-
270
- ##
271
- # Returns the truth value of (x <= y) element-wise.
272
- def less_equal(input_a, input_b, name: nil)
273
- input_a, input_b = check_data_types(input_a, input_b)
274
- _op(:less_equal, input_a, input_b, name: name)
275
- end
276
-
277
181
  ##
278
182
  # Computes the mean of elements across dimensions of a tensor.
279
183
  def reduce_mean(input_tensor, axis = nil, keepdims: false, name: nil)
280
184
  reduce(:mean, input_tensor, axis, keepdims: keepdims, name: name)
281
185
  end
282
186
 
283
- ##
284
- # Computes the sum of elements across dimensions of a tensor.
285
- #
286
- # Reduces input_tensor along the dimensions given in axis. Unless keepdims is true,
287
- # the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true,
288
- # the reduced dimensions are retained with length 1.
289
- # If axis has no entries, all dimensions are reduced, and a tensor with a single element
290
- # is returned.
291
- def reduce_sum(input_tensor, axis = nil, keepdims: false, name: nil)
292
- reduce(:sum, input_tensor, axis, keepdims: keepdims, name: name)
293
- end
294
-
295
- ##
296
- # Computes the product of elements across dimensions of a tensor.
297
- #
298
- # Reduces input_tensor along the dimensions given in axis. Unless keepdims is true, the rank of the
299
- # tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are
300
- # retained with length 1.
301
- #
302
- # If axis has no entries, all dimensions are reduced, and a tensor with a single element is returned.
303
- def reduce_prod(input, axis = nil, keepdims: false, name: nil)
304
- reduce(:prod, input, axis, keepdims: keepdims, name: name)
305
- end
306
-
307
187
  def reduce(op, input, axis = nil, keepdims: false, name: nil)
308
188
  input = TensorStream.convert_to_tensor(input)
309
- axis = if !axis.nil?
310
- axis
311
- elsif input.shape.scalar?
312
- op
313
- elsif input.shape.known?
314
- (0...input.shape.ndims).to_a
315
- else
316
- range(0, rank(input))
317
- end
189
+ return input if input.shape.scalar?
190
+
191
+ axis = cast_axis(input, axis)
318
192
 
319
193
  _op(op, input, axis, keepdims: keepdims, name: name)
320
194
  end
321
195
 
322
196
  ##
323
197
  # Concatenates tensors along one dimension.
324
- def concat(values, axis, name: 'concat')
198
+ def concat(values, axis, name: "concat")
325
199
  if values.is_a?(Array)
326
200
  _op(:concat, axis, *values, name: name)
327
201
  else
@@ -329,7 +203,7 @@ module TensorStream
329
203
  end
330
204
  end
331
205
 
332
- def split(value, num_or_size_splits, axis: 0, num: nil, name: 'split')
206
+ def split(value, num_or_size_splits, axis: 0, num: nil, name: "split")
333
207
  value = convert_to_tensor(value)
334
208
  num_or_size_splits = convert_to_tensor(num_or_size_splits)
335
209
  axis = convert_to_tensor(axis)
@@ -339,33 +213,33 @@ module TensorStream
339
213
  res = _op(:split, value, num_or_size_splits, axis, name: name)
340
214
 
341
215
  pieces = if value.shape.known? && num_or_size_splits.is_const && num_or_size_splits.value && axis.is_const
342
- if num_or_size_splits.shape.scalar?
343
- raise TensorStream::ValueError, "num_or_size_splits must divide dimension #{value.shape.shape[axis.value]} evenly" unless (value.shape.shape[axis.value] % num_or_size_splits.value).zero?
344
-
345
- div = num_or_size_splits.value
346
- n = value.shape.shape[axis.value] / div
347
-
348
- Array.new(div) do
349
- new_shape = value.shape.shape.dup
350
- new_shape[axis.value] = n
351
- new_shape
352
- end
353
- elsif num_or_size_splits.shape.ndims == 1
354
- raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{num_or_size_splits.value.reduce(:+)}" if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
355
-
356
- num_or_size_splits.value.collect do |v|
357
- new_shape = value.shape.shape.dup
358
- new_shape[axis.value] = v
359
- new_shape
360
- end
361
- else
362
- raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
363
- end
364
- else
365
- raise TensorStream::ValueError, "Cannot automatically determine num, please specify num: in options" if num.nil?
366
-
367
- Array.new(num) { nil }
368
- end
216
+ if num_or_size_splits.shape.scalar?
217
+ raise TensorStream::ValueError, "num_or_size_splits must divide dimension #{value.shape.shape[axis.value]} evenly" unless (value.shape.shape[axis.value] % num_or_size_splits.value).zero?
218
+
219
+ div = num_or_size_splits.value
220
+ n = value.shape.shape[axis.value] / div
221
+
222
+ Array.new(div) do
223
+ new_shape = value.shape.shape.dup
224
+ new_shape[axis.value] = n
225
+ new_shape
226
+ end
227
+ elsif num_or_size_splits.shape.ndims == 1
228
+ raise TensorStream::ValueError, "Sum of splits do not match total dimen in axis #{value.shape.shape[axis.value]} != #{num_or_size_splits.value.reduce(:+)}" if value.shape.shape[axis.value] != num_or_size_splits.value.reduce(:+)
229
+
230
+ num_or_size_splits.value.collect do |v|
231
+ new_shape = value.shape.shape.dup
232
+ new_shape[axis.value] = v
233
+ new_shape
234
+ end
235
+ else
236
+ raise TensorStream::ValueError, "Scalar or 1D Tensor expected for num_or_size_splits"
237
+ end
238
+ else
239
+ raise TensorStream::ValueError, "Cannot automatically determine num, please specify num: in options" if num.nil?
240
+
241
+ Array.new(num) { nil }
242
+ end
369
243
 
370
244
  pieces.collect.with_index do |shape, i|
371
245
  op = index(res, i, name: "split/index:#{i}")
@@ -395,13 +269,6 @@ module TensorStream
395
269
  _op(:square, tensor, name: name)
396
270
  end
397
271
 
398
- ##
399
- # Rounds the values of a tensor to the nearest integer, element-wise
400
- def round(tensor, name: nil)
401
- check_allowed_types(tensor, FLOATING_POINT_TYPES)
402
- _op(:round, tensor, name: name)
403
- end
404
-
405
272
  ##
406
273
  # Computes the reciprocal of x element-wise.
407
274
  def reciprocal(tensor, name: nil)
@@ -420,15 +287,6 @@ module TensorStream
420
287
  _op(:where, condition, true_t, false_t, name: name)
421
288
  end
422
289
 
423
- ##
424
- # Returns x + y element-wise.
425
- #
426
- # This operation supports broadcasting
427
- def add(input_a, input_b, name: nil)
428
- input_a, input_b = check_data_types(input_a, input_b)
429
- _op(:add, input_a, input_b, name: name)
430
- end
431
-
432
290
  ##
433
291
  # Adds all input tensors element-wise.
434
292
  #
@@ -458,53 +316,20 @@ module TensorStream
458
316
  _op(:atan, input, name: name)
459
317
  end
460
318
 
461
- ##
462
- # Returns x - y element-wise.
463
- #
464
- # This operation supports boradcasting
465
- def sub(input_a, input_b, name: nil)
466
- input_a, input_b = check_data_types(input_a, input_b)
467
- _op(:sub, input_a, input_b, name: name)
468
- end
469
-
470
- ##
471
- # Returns element-wise remainder of division.
472
- def mod(input_a, input_b, name: nil)
473
- input_a = convert_to_tensor(input_a)
474
- input_b = convert_to_tensor(input_b)
475
-
476
- input_a, input_b = check_data_types(input_a, input_b)
477
- _op(:mod, input_a, input_b, name: name)
478
- end
479
-
480
319
  ##
481
320
  # Returns element-wise integer divistion.
482
321
  def floor_div(input_a, input_b, name: nil)
483
- input_a, input_b = check_data_types(input_a, input_b)
322
+ check_data_types(input_a, input_b)
484
323
  _op(:floor_div, input_a, input_b, name: name)
485
324
  end
486
325
 
487
- def range(start, limit, delta = 1, dtype: nil, name: 'range')
488
- _op(:range, start, limit, delta, data_type: dtype, name: name)
489
- end
490
-
491
326
  ##
492
- # Returns x - y element-wise.
493
- #
494
- # This operation supports boradcasting
495
- def subtract(input_a, input_b, name: nil)
496
- input_a, input_b = check_data_types(input_a, input_b)
497
- sub(input_a, input_b, name: name)
498
- end
499
-
500
- ##
501
- # Returns the max of x and y (i.e. x > y ? x : y) element-wise.
502
- def max(input_a, input_b, name: nil)
503
- check_allowed_types(input_a, NUMERIC_TYPES)
504
- check_allowed_types(input_b, NUMERIC_TYPES)
327
+ # Casts a tensor to a new type, if needed
328
+ def cast(input, dtype, name: nil)
329
+ input = convert_to_tensor(input)
330
+ return input if input.data_type == dtype
505
331
 
506
- input_a, input_b = check_data_types(input_a, input_b)
507
- _op(:max, input_a, input_b, name: name)
332
+ _op(:cast, input, data_type: dtype, name: name)
508
333
  end
509
334
 
510
335
  ##
@@ -513,30 +338,12 @@ module TensorStream
513
338
  max(input_a, input_b, name: name)
514
339
  end
515
340
 
516
- ##
517
- # Returns the min of x and y (i.e. x < y ? x : y) element-wise.
518
- def min(input_a, input_b, name: nil)
519
- check_allowed_types(input_a, NUMERIC_TYPES)
520
- check_allowed_types(input_b, NUMERIC_TYPES)
521
- input_a, input_b = check_data_types(input_a, input_b)
522
- _op(:min, input_a, input_b, name: name)
523
- end
524
-
525
341
  ##
526
342
  # Returns the min of x and y (i.e. x < y ? x : y) element-wise.
527
343
  def minimum(input_a, input_b, name: nil)
528
344
  min(input_a, input_b, name: name)
529
345
  end
530
346
 
531
- ##
532
- # Casts a tensor to a new type, if needed
533
- def cast(input, dtype, name: nil)
534
- input = convert_to_tensor(input)
535
- return input if input.data_type == dtype
536
-
537
- _op(:cast, input, data_type: dtype, name: name)
538
- end
539
-
540
347
  ##
541
348
  # Prints a list of tensors.
542
349
  #
@@ -545,30 +352,17 @@ module TensorStream
545
352
  _op(:print, input, data, message: message, name: name)
546
353
  end
547
354
 
548
- ##
549
- # Computes numerical negative value element-wise.
550
- def negate(input, name: nil)
551
- _op(:negate, input, name: name)
552
- end
553
-
554
355
  ##
555
356
  # Computes numerical negative value element-wise.
556
357
  def negative(input, name: nil)
557
358
  negate(input, name: name)
558
359
  end
559
360
 
560
- ##
561
- # Returns the truth value of (x == y) element-wise.
562
- def equal(input_a, input_b, name: nil)
563
- input_a, input_b = check_data_types(input_a, input_b)
564
- _op(:equal, input_a, input_b, name: name)
565
- end
566
-
567
361
  ##
568
362
  # Returns the truth value of (x != y) element-wise.
569
363
  # This ops supports broadcasting
570
364
  def not_equal(input_a, input_b, name: nil)
571
- input_a, input_b = check_data_types(input_a, input_b)
365
+ check_data_types(input_a, input_b)
572
366
  _op(:not_equal, input_a, input_b, name: name)
573
367
  end
574
368
 
@@ -600,33 +394,10 @@ module TensorStream
600
394
  # Returns x * y element-wise.
601
395
  # This operation supports broadcasting
602
396
  def multiply(input_a, input_b, name: nil)
603
- input_a, input_b = check_data_types(input_a, input_b)
397
+ check_data_types(input_a, input_b)
604
398
  _op(:mul, input_a, input_b, name: name)
605
399
  end
606
400
 
607
- ##
608
- # Returns x * y element-wise.
609
- # This operation supports broadcasting
610
- def mul(input_a, input_b, name: nil)
611
- input_a, input_b = check_data_types(input_a, input_b)
612
- _op(:mul, input_a, input_b, name: name)
613
- end
614
-
615
- ##
616
- # Divides x / y elementwise
617
- # This operation supports broadcasting
618
- def div(input_a, input_b, name: nil)
619
- input_a, input_b = check_data_types(input_a, input_b)
620
- _op(:div, input_a, input_b, name: name)
621
- end
622
-
623
- ##
624
- # Computes the power of one value to another.
625
- def pow(input_a, input_e, name: nil)
626
- input_a, input_e = check_data_types(input_a, input_e)
627
- _op(:pow, input_a, input_e, name: name)
628
- end
629
-
630
401
  ##
631
402
  # Computes the absolute value of a tensor.
632
403
  def abs(input, name: nil)
@@ -634,42 +405,6 @@ module TensorStream
634
405
  end
635
406
 
636
407
  ##
637
- # Returns an element-wise indication of the sign of a number.
638
- # y = sign(x) = -1 if x < 0; 0 if x == 0 or tf.is_nan(x); 1 if x > 0.
639
- # Zero is returned for NaN inputs.
640
- def sign(input, name: nil)
641
- _op(:sign, input, name: name)
642
- end
643
-
644
- ##
645
- # Computes sin of input element-wise.
646
- def sin(input, name: nil)
647
- check_allowed_types(input, FLOATING_POINT_TYPES)
648
- _op(:sin, input, name: name)
649
- end
650
-
651
- ##
652
- # Computes cos of input element-wise.
653
- def cos(input, name: nil)
654
- check_allowed_types(input, FLOATING_POINT_TYPES)
655
- _op(:cos, input, name: name)
656
- end
657
-
658
- ##
659
- # Computes tan of input element-wise.
660
- def tan(input, name: nil)
661
- check_allowed_types(input, FLOATING_POINT_TYPES)
662
- _op(:tan, input, name: name)
663
- end
664
-
665
- ##
666
- # Computes tanh of input element-wise.
667
- def tanh(input, name: nil)
668
- check_allowed_types(input, FLOATING_POINT_TYPES)
669
- _op(:tanh, input, name: name)
670
- end
671
-
672
- ##
673
408
  # Computes sec of input element-wise.
674
409
  def sec(input, name: nil)
675
410
  check_allowed_types(input, FLOATING_POINT_TYPES)
@@ -704,46 +439,16 @@ module TensorStream
704
439
  _op(:exp, input, name: name)
705
440
  end
706
441
 
707
- ##
708
- # Creates a tensor filled with a scalar value.
709
- #
710
- # This operation creates a tensor of shape dims and fills it with value.
711
- #
712
- # For example:
713
- # Output tensor has shape [2, 3].
714
- # fill([2, 3], 9) => [[9, 9, 9]
715
- # [9, 9, 9]]
716
- def fill(dims, value, name: nil)
717
- _op(:fill, dims, value, name: name)
718
- end
719
-
720
- ##
721
- # Computes sigmoid of x element-wise.
722
- def sigmoid(input, name: nil)
723
- check_allowed_types(input, FLOATING_POINT_TYPES)
724
- _op(:sigmoid, input, name: name)
725
- end
726
-
727
- ##
728
- # Multiplies matrix a by matrix b, producing a * b.
729
- # The inputs must, following any transpositions, be tensors of rank 2 .
730
- def matmul(input_a, input_b, transpose_a: false,
731
- transpose_b: false,
732
- name: nil)
733
- input_a, input_b = check_data_types(input_a, input_b)
734
- _op(:mat_mul, input_a, input_b, transpose_a: transpose_a, transpose_b: transpose_b, name: name)
735
- end
736
-
737
442
  ##
738
443
  # Transposes a. Permutes the dimensions according to perm.
739
- def transpose(tensor, perm = nil, name: 'transpose')
444
+ def transpose(tensor, perm = nil, name: "transpose")
740
445
  _op(:transpose, tensor, perm, name: name)
741
446
  end
742
447
 
743
448
  ##
744
449
  # Pads a tensor.
745
450
  # This operation pads a tensor according to the paddings you specify.
746
- def pad(tensor, paddings, mode: 'CONSTANT', name: nil)
451
+ def pad(tensor, paddings, mode: "CONSTANT", name: nil)
747
452
  _op(:pad, tensor, paddings, mode: mode, name: name)
748
453
  end
749
454
 
@@ -754,10 +459,6 @@ module TensorStream
754
459
  _op(:check_numerics, tensor, message: message, name: name)
755
460
  end
756
461
 
757
- def size(tensor, name: nil, out_type: :int32)
758
- _op(:size, tensor, name: name, out_type: out_type)
759
- end
760
-
761
462
  def squared_difference(input_a, input_b, name: nil)
762
463
  _op(:squared_difference, input_a, input_b, name: name)
763
464
  end
@@ -771,36 +472,36 @@ module TensorStream
771
472
  # Gather slices from params and axis according to indices.
772
473
  #
773
474
  def gather(params, indices, validate_indices: nil,
774
- name: nil,
775
- axis: 0)
475
+ name: nil,
476
+ axis: 0)
776
477
  _op(:gather, params, indices, validate_indices: validate_indices, name: name, axis: axis)
777
478
  end
778
479
 
779
480
  ##
780
481
  # Stacks a list of rank-R tensors into one rank-(R+1) tensor.
781
482
  #
782
- def stack(values, axis: 0, name: 'stack')
483
+ def stack(values, axis: 0, name: "stack")
783
484
  _op(:stack, *values, axis: axis, name: name)
784
485
  end
785
486
 
786
487
  ##
787
488
  # Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
788
489
  #
789
- def unstack(value, num: nil, axis: 0, name: 'unstack')
490
+ def unstack(value, num: nil, axis: 0, name: "unstack")
790
491
  res = _op(:unstack, value, num: num, axis: axis, name: name)
791
492
 
792
493
  num_vars = if value.shape.known?
793
- new_shape = value.shape.shape.dup
794
- rank = new_shape.size - 1
795
- axis = rank + axis if axis < 0
796
- rotated_shape = Array.new(axis + 1) { new_shape.shift }
797
- new_shape = rotated_shape.rotate!(-1) + new_shape
798
- new_shape[0]
799
- else
800
- raise TensorStream::ValueError, "num is unspecified and cannot be inferred." if num.nil?
801
-
802
- num
803
- end
494
+ new_shape = value.shape.shape.dup
495
+ rank = new_shape.size - 1
496
+ axis = rank + axis if axis < 0
497
+ rotated_shape = Array.new(axis + 1) { new_shape.shift }
498
+ new_shape = rotated_shape.rotate!(-1) + new_shape
499
+ new_shape[0]
500
+ else
501
+ raise TensorStream::ValueError, "num is unspecified and cannot be inferred." if num.nil?
502
+
503
+ num
504
+ end
804
505
 
805
506
  return res[0] if num_vars == 1
806
507
 
@@ -811,14 +512,14 @@ module TensorStream
811
512
 
812
513
  ##
813
514
  # Same as stack
814
- def pack(values, axis: 0, name: 'pack')
515
+ def pack(values, axis: 0, name: "pack")
815
516
  _op(:stack, *values, axis: axis, name: name)
816
517
  end
817
518
 
818
519
  ##
819
520
  # Same as unstack
820
521
  #
821
- def unpack(value, num: nil, axis: 0, name: 'unpack')
522
+ def unpack(value, num: nil, axis: 0, name: "unpack")
822
523
  unstack(value, num: num, axis: axis, name: name)
823
524
  end
824
525
 
@@ -878,5 +579,15 @@ module TensorStream
878
579
  def invert_permutation(x, name: nil)
879
580
  _op(:invert_permutation, x, name: name)
880
581
  end
582
+
583
+ def cast_axis(input, axis)
584
+ if !axis.nil?
585
+ axis
586
+ elsif input.shape.known?
587
+ (0...input.shape.ndims).to_a
588
+ else
589
+ range(0, rank(input))
590
+ end
591
+ end
881
592
  end
882
593
  end