tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,106 @@
1
module Tensorflow
  module Data
    # Base class for datasets. A dataset wraps the variant tensor produced by a
    # RawOps dataset op, together with the dtypes and shapes of the components
    # of each element, and is Enumerable over those elements.
    class Dataset
      # Copied from Python code
      DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB

      include Enumerable

      # TODO remove
      attr_reader :output_types, :output_shapes, :variant_tensor

      # Normalizes dataset input into an array of component tensors.
      #
      # @param values [Numo::NArray, Tensor, Array, Graph::Operation]
      # @return [Array] one entry per component; Array items that are not
      #   already Tensors are wrapped with Tensor.new
      # @raise [Error::UnimplementedError] for any other input type
      def self.to_tensor_array(values)
        case values
        when Numo::NArray
          [Tensor.new(values)]
        when Tensor
          [values]
        when Array
          values.map { |v| v.is_a?(Tensor) ? v : Tensor.new(v) }
        when Graph::Operation
          [values]
        else
          raise(Error::UnimplementedError, "Unsupported dataset element: #{values}")
        end
      end

      def self.from_tensors(tensors)
        TensorDataset.new(tensors)
      end

      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # @param variant_tensor the handle produced by a RawOps dataset op
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      def to_ptr
        @variant_tensor.to_ptr
      end

      # Placeholder for tf.data options. The options are currently ignored;
      # returns self so calls can be chained like the other transformations.
      def with_options(options)
        self
      end

      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      def make_one_shot_iterator
        OneShotIterator.new(self)
      end

      def make_initializable_iterator(shared_name: '')
        InitializableIterator.new(self, shared_name: shared_name)
      end

      # Yields each element produced by an anonymous iterator over this dataset.
      # Iteration stops when the iterator raises Error::OutOfRangeError, and the
      # iterator is always deleted afterwards (including on early break).
      #
      # @return [Enumerator] when no block is given
      # @return [self] after iterating
      def each
        return to_enum(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        RawOps.make_iterator(@variant_tensor, iterator)
        begin
          loop do
            yield RawOps.iterator_get_next_sync(iterator, output_types: @output_types, output_shapes: @output_shapes)
          end
        rescue Error::OutOfRangeError
          # End of the dataset.
        end
        self
      ensure
        RawOps.delete_iterator(iterator, deleter) if iterator
      end

      # !!! DEBUG method. You don't want to use this method because it iterates over
      # the entire dataset and reads it into a ruby array in memory
      def data
        map do |slice|
          if slice.is_a?(Array)
            slice.map(&:value)
          else
            slice.value
          end
        end
      end

      def map_func(func)
        MapDataset.new(self, func)
      end

      # Repeats this dataset.
      #
      # @param count [Integer, nil] number of repetitions; nil (the default) is
      #   passed as -1, which repeats indefinitely
      def repeat(count = nil)
        RepeatDataset.new(self, count || -1)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module Tensorflow
  module Data
    # Dataset built with RawOps.fixed_length_record_dataset_v2. Every element is
    # a single scalar :string tensor.
    class FixedLengthRecordDataset < Dataset
      # @param filenames forwarded unchanged to the raw op
      # @param record_bytes [Integer] converted to an :int64 tensor
      # @param header_bytes [Integer] converted to an :int64 tensor
      # @param footer_bytes [Integer] converted to an :int64 tensor
      # @param buffer_size [Integer] converted to an :int64 tensor
      # @param compression_type [String] forwarded unchanged to the raw op
      # @param num_parallel_reads [Integer] accepted for API parity but not used here
      def initialize(filenames, record_bytes, header_bytes: 0, footer_bytes: 0,
                     buffer_size: DEFAULT_READER_BUFFER_SIZE_BYTES, compression_type: '', num_parallel_reads: 0)
        @output_types = [:string]
        @output_shapes = [[]]

        int64 = ->(value) { Tensor.new(value, dtype: :int64) }
        record = int64.call(record_bytes)
        header = int64.call(header_bytes)
        footer = int64.call(footer_bytes)
        buffer = int64.call(buffer_size)

        # The raw op takes header_bytes before record_bytes.
        super(RawOps.fixed_length_record_dataset_v2(filenames, header, record, footer, buffer, compression_type))
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
module Tensorflow
  module Data
    # Base class for dataset iterators. Subclasses create the underlying
    # iterator resource and store it in @iterator.
    class Iterator
      attr_reader :output_types, :output_shapes

      # Builds an iterator that is not yet bound to a dataset; bind it later
      # with ReinitializableIterator#make_initializer.
      def self.from_structure(output_types, output_shapes = [], shared_name: '')
        ReinitializableIterator.new(output_types, output_shapes, shared_name: shared_name)
      end

      def initialize(output_types, output_shapes = [])
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # Returns the result of RawOps.iterator_get_next on this iterator.
      def get_next
        RawOps.iterator_get_next(@iterator, **element_signature)
      end

      private

      # Keyword arguments describing the element structure, shared by the raw iterator ops.
      def element_signature
        { output_types: output_types, output_shapes: output_shapes }
      end
    end

    # Iterator whose dataset is baked into a registered 'MakeDataset' function.
    class OneShotIterator < Iterator
      def initialize(dataset)
        super(dataset.output_types, dataset.output_shapes)
        create_one_shot_iterator(dataset)
      end

      private

      def create_one_shot_iterator(dataset)
        factory = make_dataset_function(dataset)
        ExecutionContext.current.add_function(factory)
        @iterator = RawOps.one_shot_iterator(dataset_factory: factory, **element_signature)
      end

      # Builds, in a fresh graph, a function whose single output is the dataset
      # passed through RawOps.optimize_dataset with 'noop_elimination'.
      def make_dataset_function(dataset)
        Graph::Graph.new.as_default do |func_graph|
          optimized = RawOps.optimize_dataset(dataset.variant_tensor, ['noop_elimination'], **element_signature)
          func_graph.to_function('MakeDataset', nil, nil, [optimized])
        end
      end
    end

    # Iterator bound to one dataset; run #initializer before reading from it.
    class InitializableIterator < Iterator
      attr_reader :initializer

      def initialize(dataset, shared_name: '')
        super(dataset.output_types, dataset.output_shapes)
        create_initializable_iterator(dataset, shared_name)
      end

      private

      def create_initializable_iterator(dataset, shared_name)
        @iterator = RawOps.iterator_v2(shared_name: shared_name, **element_signature)
        @initializer = RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end
    end

    # Iterator defined only by an element structure; it can be pointed at any
    # dataset with a matching structure via #make_initializer.
    class ReinitializableIterator < Iterator
      def initialize(output_types, output_shapes, shared_name: '')
        super(output_types, output_shapes)
        create_iterator_from_structure(shared_name)
      end

      def make_initializer(dataset)
        RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end

      private

      def create_iterator_from_structure(shared_name)
        @iterator = RawOps.iterator_v2(shared_name: shared_name, **element_signature)
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
module Tensorflow
  module Data
    # Dataset produced by RawOps.map_dataset, applying a function to each
    # element of an input dataset.
    class MapDataset < Dataset
      # @param input_dataset [Dataset] dataset whose elements are mapped
      # @param function passed as the raw op's `f` attr; its output_types and
      #   output_shapes become this dataset's element structure
      # @param other_arguments [Array] extra inputs forwarded to the raw op
      def initialize(input_dataset, function, other_arguments: [])
        @output_types = function.output_types
        @output_shapes = function.output_shapes

        mapped = RawOps.map_dataset(
          input_dataset.variant_tensor,
          other_arguments,
          f: function,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(mapped)
      end
    end
  end
end
@@ -0,0 +1,16 @@
1
module Tensorflow
  module Data
    # Dataset produced by RawOps.repeat_dataset; it keeps the element
    # structure of the dataset it repeats.
    class RepeatDataset < Dataset
      # @param dataset [Dataset] the dataset to repeat
      # @param count [Integer] converted to an :int64 tensor for the raw op
      def initialize(dataset, count)
        @output_types = dataset.output_types
        @output_shapes = dataset.output_shapes

        count_tensor = Tensor.new(count, dtype: :int64)
        super(RawOps.repeat_dataset(dataset.variant_tensor, count_tensor,
                                    output_types: @output_types,
                                    output_shapes: @output_shapes))
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
module Tensorflow
  module Data
    # Dataset produced by RawOps.shuffle_dataset over an input dataset, using
    # two seeds drawn from Ruby's global RNG.
    class ShuffleDataset < Dataset
      # @param input_dataset [Dataset] dataset to shuffle (kept as @input_dataset)
      # @param buffer_size [Integer] converted to an :int64 tensor
      def initialize(input_dataset, buffer_size)
        @input_dataset = input_dataset
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes

        buffer_tensor = Tensor.new(buffer_size, dtype: :int64)
        # ::Random (not a Tensorflow-namespaced Random); seed is drawn before seed2.
        seed, seed2 = Array.new(2) { Tensor.new(::Random.rand(100_000_000), dtype: :int64) }

        super(RawOps.shuffle_dataset(input_dataset.variant_tensor, buffer_tensor, seed, seed2,
                                     output_types: @output_types, output_shapes: @output_shapes))
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
module Tensorflow
  module Data
    # Dataset with a single element made of the given tensors
    # (built with RawOps.tensor_dataset).
    class TensorDataset < Dataset
      # @param elements anything accepted by Dataset.to_tensor_array
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map(&:shape)

        super(RawOps.tensor_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes))
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
module Tensorflow
  module Data
    # Dataset built with RawOps.tensor_slice_dataset. Each component's element
    # shape is its tensor shape with the leading dimension dropped (shape[1..]).
    class TensorSliceDataset < Dataset
      # @param elements anything accepted by Dataset.to_tensor_array
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map { |component| component.shape[1..] }

        sliced = RawOps.tensor_slice_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes)
        super(sliced)
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Tensorflow
  module Data
    # Dataset built with RawOps.tf_record_dataset. Every element is a single
    # scalar :string tensor.
    class TfRecordDataset < Dataset
      # Default read-buffer size in bytes.
      # NOTE(review): this is 256 MB, while Dataset::DEFAULT_READER_BUFFER_SIZE_BYTES
      # (noted as copied from the Python code) is 256 KB — confirm the larger value is intended.
      DEFAULT_BUFFER_SIZE = 256 * 1_048_576 # 256 MB

      # @param filenames [String, Array<String>] one file name or a list of them
      # @param compression_type [String, nil] forwarded to the raw op; nil means ''
      # @param buffer_size [Integer, nil] read-buffer size in bytes; nil means DEFAULT_BUFFER_SIZE
      def initialize(filenames, compression_type = '', buffer_size = DEFAULT_BUFFER_SIZE)
        filenames = Array(filenames)
        @output_types = [:string]
        @output_shapes = [[]]

        # Previously a nil buffer_size was forwarded to the raw op as nil; fall
        # back to the defaults instead, mirroring Python's `None` handling.
        buffer_size_tensor = Tensor.new(buffer_size || DEFAULT_BUFFER_SIZE, dtype: :int64)
        variant_tensor = RawOps.tf_record_dataset(filenames, compression_type || '', buffer_size_tensor)

        super(variant_tensor)
      end
    end
  end
end
@@ -0,0 +1,24 @@
1
module Tensorflow
  module Data
    # Dataset produced by RawOps.zip_dataset over several datasets. Datasets may
    # be passed as separate arguments or as one array; the element structure is
    # the concatenation of each input's output_types and output_shapes.
    class ZipDataset < Dataset
      def initialize(*datasets)
        datasets = datasets.flatten(1)
        variants = datasets.map(&:variant_tensor)

        @output_types = datasets.map(&:output_types).flatten
        @output_shapes = datasets.map(&:output_shapes).flatten(1)

        zipped = RawOps.zip_dataset(variants, n: variants.size, output_types: @output_types, output_shapes: @output_shapes)
        super(zipped)
      end
    end
  end
end
@@ -0,0 +1,53 @@
1
module Tensorflow
  # Mixin that lets an extending class (or the top-level `main` object) have its
  # next defined method passed to a pending Function created by
  # Decorator.function. The pending Function is stored on the Decorator module
  # itself, so it is shared by everything that extends it.
  module Decorator
    # Pending wrapper holding the input signatures handed to Graph::FunctionDef.
    class Function
      attr_reader :input_signatures

      def initialize(input_signatures = [])
        @input_signatures = input_signatures
      end

      # @param method [Method, UnboundMethod] the freshly defined method
      # @return [Graph::FunctionDef]
      def wrap(method)
        Graph::FunctionDef.new(method, input_signatures)
      end
    end

    # Ruby hook run when something extends this module. Stores the module in
    # the extender's @tf so the method hooks below can reach it. For the
    # top-level `main` object, its class is extended too (which re-runs this hook).
    def self.extended(klass)
      @waiting_for_method = false
      klass.instance_variable_set(:@tf, self)

      klass.class.extend(self) if klass.is_a?(Object) && klass.to_s == 'main'
    end

    # Arms a Function to wrap the next method that gets defined.
    def self.function(input_signature = [])
      @current_function = Function.new(input_signature)
    end

    # Wraps `method` with the pending Function, if any, and returns the result;
    # returns nil when nothing is pending. The pending Function is cleared
    # before wrapping because wrapping can trigger method_added again.
    def self.wrap_method(method)
      pending = @current_function
      return unless pending

      @current_function = nil
      pending.wrap(method)
    end

    # Hook for singleton (class-level) method definitions on the extender.
    def singleton_method_added(method_name)
      super
      @tf.wrap_method(method(method_name))
    end

    # Hook for instance method definitions on the extender.
    def method_added(method_name)
      super
      @tf.wrap_method(instance_method(method_name))
    end
  end
end
@@ -0,0 +1,120 @@
1
module Tensorflow
  module Eager
    # Eager execution context wrapping a native TFE_Context pointer. It creates
    # and executes eager operations, manages registered functions, and provides
    # no-op stand-ins for some graph-only collection APIs.
    class Context
      extend Forwardable
      def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

      # Number of output-handle slots allocated for each TFE_Execute call.
      MAX_RETVALS = 10

      # Lazily created default context.
      def self.default
        @default ||= Context.new
      end

      # Builds the finalizer that frees the native context. It is defined on the
      # class so the proc does not capture the instance itself.
      def self.finalize(pointer)
        proc { FFI.TFE_DeleteContext(pointer) }
      end

      def initialize
        @name_scope = NameScope.new
        options = FFI.TFE_NewContextOptions
        begin
          Status.check do |status|
            @pointer = FFI.TFE_NewContext(options, status)
          end
        ensure
          # Free the options even when context creation fails.
          FFI.TFE_DeleteContextOptions(options)
        end
        ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
      end

      # Pushes this context onto ExecutionContext for the duration of the block.
      #
      # @raise [Error::InvalidArgumentError] if no block is given
      def as_default
        raise(Error::InvalidArgumentError, "Must provide block") unless block_given?

        ExecutionContext.push(self)
        begin
          yield self
        ensure
          ExecutionContext.pop
        end
      end

      def create_operation(op_type, inputs = [], attrs = {})
        Operation.new(self, op_type, inputs, attrs)
      end

      # Executes an eager operation.
      #
      # @return [TensorHandle] when the op produced exactly one output
      # @return [Array<TensorHandle>] when it produced several
      # @return [nil] when it produced none
      def execute(operation)
        # TODO decide how many retvals to allocate
        retvals = ::FFI::MemoryPointer.new(:pointer, MAX_RETVALS)
        num_retvals = ::FFI::MemoryPointer.new(:int)
        # TFE_Execute reads this as the capacity of `retvals` in elements.
        # (retvals.size would be the size in bytes, overstating the capacity.)
        num_retvals.write_int(MAX_RETVALS)

        Status.check do |status|
          FFI.TFE_Execute(operation, retvals, num_retvals, status)
        end

        n = num_retvals.read_int
        return unless n > 0

        handles = retvals.read_array_of_pointer(n).map { |handle| TensorHandle.new(self, handle) }
        # TODO handle case where n = 1 and still want an array for retvals
        n == 1 ? handles.first : handles
      end

      def device_policy
        FFI::ContextDevicePlacementPolicy[FFI.TFE_ContextGetDevicePlacementPolicy(@pointer)]
      end

      def enable_run_metadata
        FFI.TFE_ContextEnableRunMetadata(@pointer)
      end

      def disable_run_metadata
        FFI.TFE_ContextDisableRunMetadata(@pointer)
      end

      def start_step
        FFI.TFE_ContextStartStep(@pointer)
      end

      def end_step
        FFI.TFE_ContextEndStep(@pointer)
      end

      def to_ptr
        @pointer
      end

      def shared_name
        # hard-coded in Python library
        "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collection(name, value)
      end

      # Mimic graph api (no-op in eager mode)
      def add_to_collections(names, value)
      end

      # Mimic graph api (always returns nil in eager mode)
      def get_collection_ref(name)
      end

      def add_function(function)
        Status.check do |status|
          FFI.TFE_ContextAddFunction(self, function, status)
        end
      end

      # @param function [Graph::Function, String] a function or its name
      def remove_function(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        Status.check do |status|
          FFI.TFE_ContextRemoveFunction(self, name, status)
        end
      end

      # @param function [Graph::Function, String] a function or its name
      def function?(function)
        name = function.is_a?(Graph::Function) ? function.name : function
        # result is uchar
        FFI.TFE_ContextHasFunction(self, name) != 0
      end
    end
  end
end