tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,219 @@
1
module Tensorflow
  module Eager
    # Builds a TensorFlow eager operation (TFE_Op): resolves its OpDef, then
    # registers the op's inputs and attributes through the eager C API.
    class Operation
      attr_reader :context, :guessed_dtype, :op_def, :status

      # @param context [Eager::Context] context passed to TFE_NewOp
      # @param op_type [String, Graph::Function] op name looked up with
      #   Tensorflow.op_def, or a graph function whose signature is the OpDef
      # @param inputs [Object, Array] input value(s); a non-array is wrapped with Array()
      # @param attrs [Hash{Symbol => Object}] attribute values keyed by attribute
      #   name. A :name entry is ignored. The caller's hash is not modified.
      # @raise [Error::InvalidArgumentError] when no OpDef is found for op_type
      def initialize(context, op_type, inputs, attrs)
        @context = context
        @op_def =
          case op_type
          when Graph::Function
            op_type.function_def.signature
          else
            Tensorflow.op_def(op_type)
          end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        @status = Status.new
        @pointer = FFI.TFE_NewOp(context, self.op_def.name, self.status)
        # Fail here instead of handing a possibly null op to the calls below.
        self.status.check
        # NOTE(review): nothing in this class releases the TFE_Op (no
        # TFE_DeleteOp finalizer); confirm ownership is handled elsewhere.

        # Work on a copy so the caller's hash is left untouched. :name is not an
        # op attribute and must not reach setup_attrs.
        attrs = attrs.dup
        attrs.delete(:name)

        inputs = Array(inputs)
        @guessed_dtype = figure_dtype(attrs, inputs)

        setup_inputs(inputs, attrs)
        setup_attrs(attrs)
      end

      def to_ptr
        @pointer
      end

      # Reads the type of this op's "dtype" attribute via TFE_OpGetAttrType.
      # Raises (through Status#check) if TensorFlow reports an error.
      def dtype
        is_list_ptr = ::FFI::MemoryPointer.new(:int)
        type = FFI.TFE_OpGetAttrType(self, 'dtype', is_list_ptr, self.status)
        self.status.check
        type
      end

      # Best-effort dtype guess. Uses the value given in +attrs+ for the first
      # OpDef attribute of kind 'type'. Otherwise it derives the dtype from the
      # first input that is an Operation or a Variable. Returns nil when
      # neither source yields one.
      def figure_dtype(attrs, inputs)
        type_attr = self.op_def.attr.find { |candidate| candidate.type == 'type' }
        result = type_attr ? attrs[type_attr.name.to_sym] : nil
        return result if result

        inputs.each do |input|
          case input
          when Operation
            # NOTE(review): Eager::Operation defines no #output_types in this
            # file; confirm which class this branch is meant to match.
            return input.output_types.first
          when Variable
            return input.dtype
          end
        end
        nil
      end

      # Sets every attribute whose value is not nil, asking TensorFlow for the
      # attribute's type and whether it is a list. `false` is a real value
      # (e.g. a bool attribute) and is passed through, not skipped.
      def setup_attrs(attrs)
        attrs.each do |attr_name, attr_value|
          next if attr_value.nil?

          attr_name = attr_name.to_s
          list_ptr = ::FFI::MemoryPointer.new(:int)
          type = FFI.TFE_OpGetAttrType(self, attr_name, list_ptr, self.status)
          self.status.check
          is_list = Boolean(list_ptr.read_int)

          if is_list
            add_list_attr(type, attr_name, attr_value)
          else
            add_scalar_attr(type, attr_name, attr_value)
          end
        end
      end

      # Sets a list-valued attribute of the given +type+ (:int, :float, :shape
      # or :type). Raises RuntimeError for any other list type.
      def add_list_attr(type, attr_name, attr_value)
        num_values = attr_value.size

        case type
        when :int
          values = ::FFI::MemoryPointer.new(:int64, num_values)
          values.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrIntList(self, attr_name, values, num_values)
        when :float
          values = ::FFI::MemoryPointer.new(:float, num_values)
          values.write_array_of_float(attr_value)
          FFI.TFE_OpSetAttrFloatList(self, attr_name, values, num_values)
        when :shape
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, num_values)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, num_values)
          # dims_pointer stores only raw addresses, which do not keep the Ruby
          # MemoryPointer objects alive. Hold them in dim_buffers until the C
          # call below has read them; otherwise the GC could free one early.
          dim_buffers = []
          attr_value.each_with_index do |shape, i|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dim_buffers << dim_pointer
            dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
            num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
          end
          FFI.TFE_OpSetAttrShapeList(self, attr_name, dims_pointer, num_dims_pointer, num_values, self.status)
          self.status.check
        when :type
          values = ::FFI::MemoryPointer.new(:int, num_values)
          types = attr_value.map { |v| v.is_a?(Symbol) ? FFI::DataType[v] : v }
          values.write_array_of_int(types)
          FFI.TFE_OpSetAttrTypeList(self, attr_name, values, num_values)
        else
          raise "Unknown list type: #{type}"
        end
      end

      # Sets a scalar attribute of the given +type+. Unsupported types and bad
      # :func values are reported through the status, which is then checked.
      def add_scalar_attr(type, attr_name, attr_value)
        case type
        when :string
          FFI.TFE_OpSetAttrString(self, attr_name, attr_value, attr_value.bytesize)
        when :int
          FFI.TFE_OpSetAttrInt(self, attr_name, attr_value)
        when :float
          FFI.TFE_OpSetAttrFloat(self, attr_name, attr_value)
        when :bool
          FFI.TFE_OpSetAttrBool(self, attr_name, attr_value ? 1 : 0)
        when :type
          attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
          FFI.TFE_OpSetAttrType(self, attr_name, attr_value)
        when :shape
          ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
          ptr.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrShape(self, attr_name, ptr, attr_value.size, self.status)
        when :tensor
          attr_value = TensorHandle.from_value(self.context, attr_value)
          FFI.TFE_OpSetAttrTensor(self, attr_name, attr_value.tensor, self.status)
        # TODO: :placeholder attributes fall through to the unsupported branch.
        when :func
          # The C API takes the name's length in bytes, like the :string branch.
          case attr_value
          when Graph::Function
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value.name, attr_value.name.bytesize)
          when String
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value, attr_value.bytesize)
          else
            self.status.set(:tf_invalid_argument, "Invalid function attribute for attribute: #{attr_name}")
          end
        else
          self.status.set(:tf_unknown, "Unsupported attribute type: #{type}")
        end
        self.status.check
      end

      # Registers each input in order against the OpDef's input_arg list.
      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          setup_input(index, input, attrs)
        end
      end

      # Converts one input value to something the C API accepts. A Variable
      # yields its resource handle for DT_RESOURCE args and its value handle
      # otherwise. Anything else goes through TensorHandle.from_value.
      def check_input(arg_def, input, dtype)
        case input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          TensorHandle.from_value(self.context, input, dtype: dtype)
        end
      end

      # Registers the input at +index+. The OpDef decides whether it is a single
      # input or a list (heterogeneous, homogeneous, or added one by one).
      def setup_input(index, value, attrs)
        if value.nil?
          self.status.set(:tf_invalid_argument, "Argument is unset. Index: #{index}")
          self.status.check
        end

        arg_def = self.op_def.input_arg[index]
        dtype = attrs[arg_def.type_attr.to_sym]

        # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
        list_arg = !arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?
        checked_value =
          if list_arg && value.is_a?(Array)
            value.map { |sub_value| self.check_input(arg_def, sub_value, dtype) }
          else
            self.check_input(arg_def, value, dtype)
          end

        if !arg_def.type_list_attr.empty?
          # This input is a heterogeneous list
          self.add_input_list(checked_value)
        elsif !arg_def.number_attr.empty? && !arg_def.type_attr.empty?
          # This input is a homogeneous list
          self.add_input_list(checked_value)
        elsif !arg_def.number_attr.empty?
          # This is a list but we have to set it up one input at a time
          checked_value.each { |sub_checked_value| self.add_input(sub_checked_value) }
        else
          # This input is a single item
          self.add_input(checked_value)
        end
      end

      # Adds one input. An Array of several values (e.g. multiple outputs of one
      # op) is packed into a single input. A one-element Array is unwrapped,
      # since an Array itself cannot be passed to the C API as a handle.
      def add_input(value)
        value = value.first if value.is_a?(Array) && value.length == 1
        value = Tensorflow.pack(value) if value.is_a?(Array) && value.length > 1
        FFI.TFE_OpAddInput(self, value, self.status)
        self.status.check
      end

      # Adds +values+ as a single list input through TFE_OpAddInputList.
      def add_input_list(values)
        input_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
        input_ptr.write_array_of_pointer(values)
        FFI.TFE_OpAddInputList(self, input_ptr, values.length, self.status)
        self.status.check
      end
    end
  end
end
@@ -0,0 +1,87 @@
1
module Tensorflow
  module Eager
    # Ruby-side owner of a TFE_TensorHandle pointer. The handle is released
    # with TFE_DeleteTensorHandle once this object is garbage collected.
    class TensorHandle
      include TensorMixin
      include Operators

      attr_reader :context

      # Finalizer proc that deletes +pointer+. It is built in a class method so
      # the proc closes over the raw pointer only, never over the handle object.
      def self.finalize(pointer)
        proc { FFI.TFE_DeleteTensorHandle(pointer) }
      end

      # Coerces +value+ into something usable as a tensor handle. Existing
      # handles are returned unchanged. Datasets give their variant tensor and
      # variables their value handle. Tensors and plain Ruby values are wrapped
      # in a new handle (plain values first become a Tensor of +dtype+).
      def self.from_value(context, value, dtype: nil)
        case value
        when TensorHandle then value
        when Data::Dataset then value.variant_tensor
        when Tensor then TensorHandle.new(context, value)
        when Variable then value.value_handle
        else TensorHandle.new(context, Tensor.new(value, dtype: dtype))
        end
      end

      # @param context the eager context this handle belongs to
      # @param value [::FFI::Pointer, Tensor] an existing handle pointer, or a
      #   Tensor to build a new handle from
      # @raise [Error::InvalidArgumentError] for any other kind of value
      def initialize(context, value)
        @context = context

        if value.is_a?(::FFI::Pointer)
          @pointer = value
        elsif value.is_a?(Tensor)
          Status.check { |status| @pointer = FFI.TFE_NewTensorHandle(value, status) }
          # Keep a reference to the source tensor so it stays alive (and is not
          # freed) for as long as this handle exists.
          @tensor = value
        else
          raise(Error::InvalidArgumentError, "Invalid value passed to tensor_handle: #{value}")
        end

        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      def to_ptr
        @pointer
      end

      # Resolves the handle (TFE_TensorHandleResolve) into a Tensor.
      def tensor
        Status.check { |status| Tensor.from_pointer(FFI.TFE_TensorHandleResolve(self, status)) }
      end

      def dtype
        FFI.TFE_TensorHandleDataType(self)
      end

      def element_count
        Status.check { |status| FFI.TFE_TensorHandleNumElements(self, status) }
      end

      # Ruby value of the resolved tensor.
      def value
        tensor.value
      end

      private

      def num_dims
        Status.check { |status| FFI.TFE_TensorHandleNumDims(self, status) }
      end

      def dim(index)
        Status.check { |status| FFI.TFE_TensorHandleDim(self, index, status) }
      end
    end
  end
end
@@ -0,0 +1,54 @@
1
module Tensorflow
  # Exception classes for TensorFlow failures.
  module Error
    # Common base for the TensorFlow error classes below, mirroring Python's
    # tf.errors.OpError hierarchy. `rescue Tensorflow::Error::OpError` catches
    # any of them. It still descends from StandardError, so existing
    # `rescue StandardError` / bare `rescue` handlers behave as before.
    class OpError < StandardError
    end

    class AbortedError < OpError
    end

    class AlreadyExistsError < OpError
    end

    class CancelledError < OpError
    end

    class DataLossError < OpError
    end

    class DeadlineExceededError < OpError
    end

    class FailedPreconditionError < OpError
    end

    class InternalError < OpError
    end

    class InvalidArgumentError < OpError
    end

    class NotFoundError < OpError
    end

    class OutOfRangeError < OpError
    end

    class PermissionDeniedError < OpError
    end

    class ResourceExhaustedError < OpError
    end

    class UnauthenticatedError < OpError
    end

    class UnavailableError < OpError
    end

    class UnimplementedError < OpError
    end

    class UnknownError < OpError
    end
  end
end
@@ -0,0 +1,62 @@
1
module Tensorflow
  # Stack of explicitly entered execution contexts (graphs or eager contexts),
  # plus the rules for choosing which one an operation should run in.
  class ExecutionContext
    class << self
      extend Forwardable
      def_delegators :context, :push, :pop, :current, :eager?, :graph?

      # The ExecutionContext kept in Thread.current[:execution_context],
      # created on first use. Note that Thread#[] storage is fiber-local in Ruby.
      def context
        Thread.current[:execution_context] ||= new
      end
    end

    def initialize
      @stack = []
    end

    # Enters +value+ as the innermost context.
    def push(value)
      @stack << value
    end

    # Leaves the innermost context and returns it (nil when the stack is empty).
    def pop
      @stack.pop
    end

    # Owner of the first (flattened) input that is a Graph::Operation (its
    # graph) or an Eager::TensorHandle (its context). Returns nil when no input
    # matches.
    def figure_from_inputs(inputs = [])
      inputs.flatten.each do |input|
        return input.graph if input.is_a?(Graph::Operation)
        return input.context if input.is_a?(Eager::TensorHandle)
      end
      nil
    end

    # Innermost explicitly entered context, or nil.
    def figure_from_context
      @stack.last
    end

    # Default graph in graph mode, default eager context otherwise.
    def figure_from_execution_mode
      ::Tensorflow.execution_mode == Tensorflow::GRAPH_MODE ? Graph::Graph.default : Eager::Context.default
    end

    # Resolution order: entered context, then the inputs' owner, then the
    # global execution mode.
    def current(inputs = [])
      found = figure_from_context
      found ||= figure_from_inputs(inputs)
      found || figure_from_execution_mode
    end

    def eager?(inputs = [])
      current(inputs).is_a?(Eager::Context)
    end

    def graph?(inputs = [])
      current(inputs).is_a?(Graph::Graph)
    end
  end
end
@@ -0,0 +1,58 @@
1
module Tensorflow
  class OpDef
    class ArgDef
      # DataType enum symbol => short dtype symbol. DT_INVALID is deliberately
      # absent: it, like any unlisted value, maps to nil.
      DTYPE_FOR_ENUM = {
        DT_FLOAT: :float,
        DT_DOUBLE: :double,
        DT_INT32: :int32,
        DT_UINT8: :uint8,
        DT_INT16: :int16,
        DT_INT8: :int8,
        DT_STRING: :string,
        DT_COMPLEX64: :complex64,
        DT_INT64: :int64,
        DT_BOOL: :bool,
        DT_QINT8: :qint8,
        DT_QUINT8: :quint8,
        DT_QINT32: :qint32,
        DT_BFLOAT16: :bfloat16,
        DT_QINT16: :qint16,
        DT_QUINT16: :quint16,
        DT_UINT16: :uint16,
        DT_COMPLEX128: :complex128,
        DT_HALF: :half,
        DT_RESOURCE: :resource,
        DT_VARIANT: :variant,
        DT_UINT32: :uint32,
        DT_UINT64: :uint64
      }.freeze

      # Short dtype symbol for this argument's +type+ (e.g. :float for
      # :DT_FLOAT). Returns nil for :DT_INVALID or any value not in the table.
      def dtype
        DTYPE_FOR_ENUM[self.type]
      end
    end
  end
end