tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,219 @@
1
module Tensorflow
  module Eager
    # Builds a native TFE_Op for eager execution: resolves the op definition,
    # creates the op through FFI, then adds its inputs and attributes.
    class Operation
      attr_reader :context, :guessed_dtype, :op_def, :status

      # @param context [Object] eager context handed to TFE_NewOp and used to wrap inputs
      # @param op_type [String, Graph::Function] op name looked up with Tensorflow.op_def,
      #   or a function whose signature serves as the op definition
      # @param inputs [Object, Array, nil] input value(s); coerced with Array()
      # @param attrs [Hash{Symbol => Object}] attribute values keyed by attribute name.
      #   A :name entry is ignored. The caller's hash is not modified.
      # @raise [Error::InvalidArgumentError] when no op definition exists for op_type
      def initialize(context, op_type, inputs, attrs)
        @context = context
        @op_def = case op_type
                  when Graph::Function
                    op_type.function_def.signature
                  else
                    Tensorflow.op_def(op_type)
                  end
        raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

        @status = Status.new
        @pointer = FFI.TFE_NewOp(context, self.op_def.name, self.status)
        # Surface a TFE_NewOp failure now instead of handing a bad op to later FFI calls.
        self.status.check

        # Work on a copy so the caller's hash is left intact. :name is not an op
        # attribute, so it must not reach setup_attrs.
        attrs = attrs.dup
        attrs.delete(:name)

        inputs = Array(inputs)
        @guessed_dtype = figure_dtype(attrs, inputs)

        setup_inputs(inputs, attrs)
        setup_attrs(attrs)
      end

      # Native TFE_Op pointer; lets this object be passed directly to FFI calls.
      def to_ptr
        @pointer
      end

      # Looks up the 'dtype' attribute through TFE_OpGetAttrType.
      #
      # NOTE(review): setup_attrs compares the result of this same FFI call against
      # :int/:float/:type/..., i.e. it reports the attribute's *kind*, so this most
      # likely returns :type rather than a data type — confirm the intent.
      # Raises (via Status#check) if the lookup reports an error.
      def dtype
        is_list_ptr = ::FFI::MemoryPointer.new(:int)
        attr_kind = FFI.TFE_OpGetAttrType(self, 'dtype', is_list_ptr, self.status)
        self.status.check
        attr_kind
      end

      # Best-effort guess of the op's data type. It is the value supplied for the
      # first attribute whose kind is 'type'. Failing that, it is taken from the
      # first input that is an Operation or a Variable.
      def figure_dtype(attrs, inputs)
        type_attr = self.op_def.attr.detect { |attr_def| attr_def.type == 'type' }
        result = type_attr ? attrs[type_attr.name.to_sym] : nil
        return result if result

        inputs.each do |input|
          case input
          when Operation
            # NOTE(review): inside Eager this constant is Eager::Operation, which
            # defines no #output_types; Graph::Operation may have been intended.
            return input.output_types.first
          when Variable
            return input.dtype
          end
        end
        result
      end

      # Applies every non-nil attribute value to the op, dispatching on the kind
      # and list-ness that TFE_OpGetAttrType reports for the attribute name.
      def setup_attrs(attrs)
        attrs.each do |attr_name, attr_value|
          # Only nil means "not supplied"; false is a legitimate bool attribute value.
          next if attr_value.nil?

          attr_name = attr_name.to_s
          is_list_ptr = ::FFI::MemoryPointer.new(:int)
          type = FFI.TFE_OpGetAttrType(self, attr_name, is_list_ptr, self.status)
          self.status.check

          if Boolean(is_list_ptr.read_int)
            add_list_attr(type, attr_name, attr_value)
          else
            add_scalar_attr(type, attr_name, attr_value)
          end
        end
      end

      # Sets a list-valued attribute of kind +type+.
      # @raise [RuntimeError] for list kinds other than :int, :float, :shape and :type
      def add_list_attr(type, attr_name, attr_value)
        num_values = attr_value.size

        case type
        when :int
          values = ::FFI::MemoryPointer.new(:int64, num_values)
          values.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrIntList(self, attr_name, values, num_values)
        when :float
          values = ::FFI::MemoryPointer.new(:float, num_values)
          values.write_array_of_float(attr_value)
          FFI.TFE_OpSetAttrFloatList(self, attr_name, values, num_values)
        when :shape
          dims_pointer = ::FFI::MemoryPointer.new(:pointer, num_values)
          num_dims_pointer = ::FFI::MemoryPointer.new(:int32, num_values)
          # dims_pointer only stores raw addresses. Keep every per-shape buffer
          # referenced from Ruby until after the FFI call; otherwise the GC may
          # free one that is still pointed to.
          dim_pointers = attr_value.each_with_index.map do |shape, i|
            dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
            dim_pointer.write_array_of_int64(shape)
            dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
            num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
            dim_pointer
          end
          FFI.TFE_OpSetAttrShapeList(self, attr_name, dims_pointer, num_dims_pointer, num_values, self.status)
          dim_pointers.clear # buffers only need to outlive the call above
          self.status.check
        when :type
          values = ::FFI::MemoryPointer.new(:int, num_values)
          types = attr_value.map { |v| v.is_a?(Symbol) ? FFI::DataType[v] : v }
          values.write_array_of_int(types)
          FFI.TFE_OpSetAttrTypeList(self, attr_name, values, num_values)
        else
          raise "Unknown list type: #{type}"
        end
      end

      # Sets a scalar attribute of kind +type+, then checks the status.
      # Unsupported kinds and bad function values are reported through the status.
      def add_scalar_attr(type, attr_name, attr_value)
        case type
        when :string
          FFI.TFE_OpSetAttrString(self, attr_name, attr_value, attr_value.bytesize)
        when :int
          FFI.TFE_OpSetAttrInt(self, attr_name, attr_value)
        when :float
          FFI.TFE_OpSetAttrFloat(self, attr_name, attr_value)
        when :bool
          FFI.TFE_OpSetAttrBool(self, attr_name, attr_value ? 1 : 0)
        when :type
          attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
          FFI.TFE_OpSetAttrType(self, attr_name, attr_value)
        when :shape
          ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
          ptr.write_array_of_int64(attr_value)
          FFI.TFE_OpSetAttrShape(self, attr_name, ptr, attr_value.size, self.status)
        when :tensor
          attr_value = TensorHandle.from_value(self.context, attr_value)
          FFI.TFE_OpSetAttrTensor(self, attr_name, attr_value.tensor, self.status)
        # when :placeholder
        when :func
          # The length argument is a byte count, so use bytesize (as :string does).
          case attr_value
          when Graph::Function
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value.name, attr_value.name.bytesize)
          when String
            FFI.TFE_OpSetAttrFunctionName(self, attr_name, attr_value, attr_value.bytesize)
          else
            self.status.set(:tf_invalid_argument, "Invalid function attribute for attribute: #{attr_name}")
          end
        else
          self.status.set(:tf_unknown, "Unsupported attribute type: #{type}")
        end
        self.status.check
      end

      # Adds each input to the op, in order.
      def setup_inputs(inputs, attrs)
        inputs.each_with_index do |input, index|
          setup_input(index, input, attrs)
        end
      end

      # Converts +input+ into something TFE_OpAddInput accepts. A Variable yields its
      # resource handle for DT_RESOURCE args and its value handle otherwise. Anything
      # else goes through TensorHandle.from_value with +dtype+.
      def check_input(arg_def, input, dtype)
        case input
        when Variable
          arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
        else
          TensorHandle.from_value(self.context, input, dtype: dtype)
        end
      end

      # Converts and adds the input at position +index+ according to its ArgDef.
      # The ArgDef marks it as a single tensor, a homogeneous list (number_attr) or a
      # heterogeneous list (type_list_attr).
      def setup_input(index, value, attrs)
        if value.nil?
          self.status.set(:tf_invalid_argument, "Argument is unset. Index: #{index}")
          self.status.check
        end

        arg_def = self.op_def.input_arg[index]
        dtype = attrs[arg_def.type_attr.to_sym]

        # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
        checked_value = if (!arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?) && value.is_a?(Array)
                          value.map do |sub_value|
                            self.check_input(arg_def, sub_value, dtype)
                          end
                        else
                          self.check_input(arg_def, value, dtype)
                        end

        if !arg_def.type_list_attr.empty?
          # This input is a heterogeneous list
          self.add_input_list(checked_value)
        elsif !arg_def.number_attr.empty? && !arg_def.type_attr.empty?
          # This input is a homogeneous list
          self.add_input_list(checked_value)
        elsif !arg_def.number_attr.empty?
          # This is a list but we have to set it up one input at a time
          checked_value.each do |sub_checked_value|
            self.add_input(sub_checked_value)
          end
        else
          # This input is a single item
          self.add_input(checked_value)
        end
      end

      # Adds a single input. An Array of several values is packed into one tensor first.
      def add_input(value)
        if value.is_a?(Array) && value.length > 1
          packed = Tensorflow.pack(value)
          FFI.TFE_OpAddInput(self, packed, self.status)
        else
          FFI.TFE_OpAddInput(self, value, self.status)
        end
        self.status.check
      end

      # Adds +values+ as one list input through TFE_OpAddInputList.
      def add_input_list(values)
        input_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
        input_ptr.write_array_of_pointer(values)
        FFI.TFE_OpAddInputList(self, input_ptr, values.length, self.status)
        self.status.check
      end
    end
  end
end
@@ -0,0 +1,87 @@
1
module Tensorflow
  module Eager
    # Ruby wrapper around a native TFE_TensorHandle pointer. The handle is released
    # with TFE_DeleteTensorHandle when this object is garbage collected.
    class TensorHandle
      include TensorMixin
      include Operators

      attr_reader :context

      # Builds the finalizer proc for +pointer+. It is created in a class method so
      # the proc closes over the pointer only; a finalizer that references the
      # object itself would keep that object from ever being collected.
      def self.finalize(pointer)
        proc do
          FFI.TFE_DeleteTensorHandle(pointer)
        end
      end

      # Coerces +value+ into a tensor handle.
      # - TensorHandle: returned as-is
      # - Data::Dataset: its variant_tensor
      # - Tensor: wrapped in a new TensorHandle
      # - Variable: its value_handle
      # - anything else: converted with Tensor.new(value, dtype: dtype), then wrapped
      def self.from_value(context, value, dtype: nil)
        case value
        when TensorHandle
          value
        when Data::Dataset
          value.variant_tensor
        when Tensor
          TensorHandle.new(context, value)
        when Variable
          value.value_handle
        else
          TensorHandle.new(context, Tensor.new(value, dtype: dtype))
        end
      end

      # @param context [Object] context the handle belongs to
      # @param value [FFI::Pointer, Tensor] an existing native handle, or a Tensor to
      #   create one from
      # @raise [Error::InvalidArgumentError] for any other value
      def initialize(context, value)
        @context = context
        case value
        when ::FFI::Pointer
          @pointer = value
        when Tensor
          Status.check do |status|
            @pointer = FFI.TFE_NewTensorHandle(value, status)
          end
          # We need to keep the tensor live so that it is not freed!
          @tensor = value
        else
          raise(Error::InvalidArgumentError, "Invalid value passed to tensor_handle: #{value}")
        end

        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Native TFE_TensorHandle pointer; lets this object be passed to FFI calls.
      def to_ptr
        @pointer
      end

      # Resolves the handle into a Tensor.
      def tensor
        # Check the status *before* wrapping. If TFE_TensorHandleResolve fails, its
        # return value must not be handed to Tensor.from_pointer.
        tensor_pointer = Status.check do |status|
          FFI.TFE_TensorHandleResolve(self, status)
        end
        Tensor.from_pointer(tensor_pointer)
      end

      def dtype
        FFI.TFE_TensorHandleDataType(self)
      end

      def element_count
        Status.check do |status|
          FFI.TFE_TensorHandleNumElements(self, status)
        end
      end

      # Ruby value of the resolved tensor.
      def value
        self.tensor.value
      end

      private

      def num_dims
        Status.check do |status|
          FFI.TFE_TensorHandleNumDims(self, status)
        end
      end

      def dim(index)
        Status.check do |status|
          FFI.TFE_TensorHandleDim(self, index, status)
        end
      end
    end
  end
end
@@ -0,0 +1,54 @@
1
module Tensorflow
  # Exception classes named after TensorFlow's status codes.
  module Error
    # Common ancestor of every specific error below. This mirrors Python's
    # tf.errors, where each error subclasses OpError, and lets callers handle any
    # TensorFlow failure with a single `rescue Tensorflow::Error::OpError`.
    # OpError is still a StandardError, so existing rescue clauses keep working.
    class OpError < StandardError
    end

    class AbortedError < OpError
    end

    class AlreadyExistsError < OpError
    end

    class CancelledError < OpError
    end

    class DataLossError < OpError
    end

    class DeadlineExceededError < OpError
    end

    class FailedPreconditionError < OpError
    end

    class InternalError < OpError
    end

    class InvalidArgumentError < OpError
    end

    class NotFoundError < OpError
    end

    class OutOfRangeError < OpError
    end

    class PermissionDeniedError < OpError
    end

    class ResourceExhaustedError < OpError
    end

    class UnauthenticatedError < OpError
    end

    class UnavailableError < OpError
    end

    class UnimplementedError < OpError
    end

    class UnknownError < OpError
    end
  end
end
@@ -0,0 +1,62 @@
1
module Tensorflow
  # Decides where new operations run: an eager context or a graph.
  # Resolution order is:
  # 1. the most recently pushed explicit context;
  # 2. the context or graph owning the first recognised input;
  # 3. the default for the global execution mode.
  class ExecutionContext
    class << self
      extend Forwardable

      # Class-level calls forward to the per-thread instance returned by .context.
      def_delegators :context, :push, :pop, :current, :eager?, :graph?
    end

    # Instance stored under Thread.current[:execution_context], created lazily.
    def self.context
      Thread.current[:execution_context] ||= new
    end

    def initialize
      @stack = []
    end

    # Pushes an explicit context that takes priority in #current.
    def push(value)
      @stack.push(value)
    end

    # Removes and returns the most recently pushed context (nil when empty).
    def pop
      @stack.pop
    end

    # Returns the graph of the first Graph::Operation, or the context of the first
    # Eager::TensorHandle, among the flattened inputs. Returns nil when neither appears.
    def figure_from_inputs(inputs = [])
      source = inputs.flatten.find do |input|
        input.is_a?(Graph::Operation) || input.is_a?(Eager::TensorHandle)
      end

      case source
      when Graph::Operation then source.graph
      when Eager::TensorHandle then source.context
      end
    end

    # Most recently pushed context, or nil.
    def figure_from_context
      @stack.last
    end

    # Default graph in graph mode, otherwise the default eager context.
    def figure_from_execution_mode
      return Graph::Graph.default if ::Tensorflow.execution_mode == Tensorflow::GRAPH_MODE

      Eager::Context.default
    end

    def current(inputs = [])
      figure_from_context || figure_from_inputs(inputs) || figure_from_execution_mode
    end

    def eager?(inputs = [])
      current(inputs).is_a?(Eager::Context)
    end

    def graph?(inputs = [])
      current(inputs).is_a?(Graph::Graph)
    end
  end
end
@@ -0,0 +1,58 @@
1
module Tensorflow
  class OpDef
    class ArgDef
      # Maps TensorFlow DataType enum symbols to the gem's dtype symbols.
      # DT_INVALID maps to nil explicitly; unlisted values also yield nil on lookup.
      DTYPE_SYMBOLS = {
        DT_INVALID: nil,
        DT_FLOAT: :float,
        DT_DOUBLE: :double,
        DT_INT32: :int32,
        DT_UINT8: :uint8,
        DT_INT16: :int16,
        DT_INT8: :int8,
        DT_STRING: :string,
        DT_COMPLEX64: :complex64,
        DT_INT64: :int64,
        DT_BOOL: :bool,
        DT_QINT8: :qint8,
        DT_QUINT8: :quint8,
        DT_QINT32: :qint32,
        DT_BFLOAT16: :bfloat16,
        DT_QINT16: :qint16,
        DT_QUINT16: :quint16,
        DT_UINT16: :uint16,
        DT_COMPLEX128: :complex128,
        DT_HALF: :half,
        DT_RESOURCE: :resource,
        DT_VARIANT: :variant,
        DT_UINT32: :uint32,
        DT_UINT64: :uint64
      }.freeze

      # Returns the dtype symbol for this argument's type (e.g. :DT_FLOAT -> :float).
      # Returns nil for DT_INVALID or any type not in the table.
      def dtype
        DTYPE_SYMBOLS[self.type]
      end
    end
  end
end