tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,106 @@
1
module Tensorflow
  module Data
    # Base class for all datasets.
    #
    # A dataset wraps the variant tensor returned by one of the RawOps dataset
    # ops. Subclasses compute that tensor, set @output_types / @output_shapes,
    # and then call +super(variant_tensor)+. Including Enumerable layers eager
    # iteration helpers (map, select, first, ...) on top of #each.
    class Dataset
      # Copied from Python code
      DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB

      include Enumerable

      # TODO remove
      attr_reader :output_types, :output_shapes, :variant_tensor

      # Normalizes +values+ into an Array of tensor-like objects.
      #
      # @param values [Numo::NArray, Tensor, Array, Graph::Operation]
      # @return [Array] an NArray, and every non-Tensor item of an Array, is
      #   wrapped with Tensor.new; Tensors and Graph::Operations pass through.
      # @raise [Error::UnimplementedError] for any other input
      def self.to_tensor_array(values)
        case values
        when Numo::NArray
          [Tensor.new(values)]
        when Tensor
          [values]
        when Array
          values.map { |v| v.is_a?(Tensor) ? v : Tensor.new(v) }
        when Graph::Operation
          [values]
        else
          raise(Error::UnimplementedError, "Unsupported dataset element: #{values}")
        end
      end

      # Builds a TensorDataset from +tensors+.
      def self.from_tensors(tensors)
        TensorDataset.new(tensors)
      end

      # Builds a TensorSliceDataset, which slices +tensors+ along their first dimension.
      def self.from_tensor_slices(tensors)
        TensorSliceDataset.new(tensors)
      end

      # @param variant_tensor the dataset handle produced by a RawOps dataset op
      def initialize(variant_tensor)
        @variant_tensor = variant_tensor
      end

      # Pointer of the underlying variant tensor (lets the dataset be passed to FFI calls).
      def to_ptr
        @variant_tensor.to_ptr
      end

      # Placeholder for tf.data options; currently a no-op that returns nil.
      def with_options(options)
      end

      # Groups consecutive elements into batches of +batch_size+.
      def batch(batch_size, drop_remainder: false)
        BatchDataset.new(self, batch_size, drop_remainder)
      end

      # Shuffles elements using a buffer of +buffer_size+ elements.
      def shuffle(buffer_size)
        ShuffleDataset.new(self, buffer_size)
      end

      def make_one_shot_iterator
        OneShotIterator.new(self)
      end

      def make_initializable_iterator(shared_name: '')
        InitializableIterator.new(self, shared_name: shared_name)
      end

      # Eagerly iterates the dataset, yielding each element as returned by
      # RawOps.iterator_get_next_sync.
      #
      # @return [Enumerator] when called without a block
      # @return [self] after the dataset has been exhausted
      def each
        return enum_for(:each) unless block_given?

        iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
        RawOps.make_iterator(@variant_tensor, iterator)
        begin
          loop do
            yield RawOps.iterator_get_next_sync(iterator, output_types: @output_types, output_shapes: @output_shapes)
          end
        rescue Error::OutOfRangeError
          # The iterator reports exhaustion with OutOfRangeError; that is the
          # normal end of iteration, not a failure.
        end
        self
      ensure
        # Release the anonymous iterator even if the caller breaks out early or
        # an op raises. iterator is nil when no iterator was ever created.
        RawOps.delete_iterator(iterator, deleter) if iterator
      end

      # !!! DEBUG method. You don't want to use this method it because it iterates over
      # the entire dataset and reads it into a ruby array in memory
      def data
        map do |slice|
          slice.is_a?(Array) ? slice.map(&:value) : slice.value
        end
      end

      # Applies +func+ to every element, returning a MapDataset.
      def map_func(func)
        MapDataset.new(self, func)
      end

      # Repeats the dataset +count+ times.
      #
      # @param count [Integer] number of repetitions; -1 (the default) asks the
      #   RepeatDataset op to repeat indefinitely
      def repeat(count = -1)
        RepeatDataset.new(self, count)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module Tensorflow
  module Data
    # Dataset of fixed-length records read from one or more files via
    # RawOps.fixed_length_record_dataset_v2. Every element is a single
    # :string scalar.
    #
    # NOTE(review): num_parallel_reads is accepted but currently unused.
    class FixedLengthRecordDataset < Dataset
      def initialize(filenames, record_bytes, header_bytes: 0, footer_bytes: 0,
                     buffer_size: DEFAULT_READER_BUFFER_SIZE_BYTES, compression_type: '', num_parallel_reads: 0)
        @output_types = [:string]
        @output_shapes = [[]]

        # All size arguments are handed to the op as int64 scalar tensors.
        as_int64 = ->(value) { Tensor.new(value, dtype: :int64) }
        record = as_int64.call(record_bytes)
        header = as_int64.call(header_bytes)
        footer = as_int64.call(footer_bytes)
        buffer = as_int64.call(buffer_size)

        # Op argument order: filenames, header, record, footer, buffer, compression.
        handle = RawOps.fixed_length_record_dataset_v2(filenames, header, record, footer, buffer, compression_type)
        super(handle)
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
module Tensorflow
  module Data
    # Base class for dataset iterators. Subclasses store the iterator resource
    # in @iterator; #get_next reads the next element from it.
    class Iterator
      attr_reader :output_types, :output_shapes

      # Returns a ReinitializableIterator with the given element structure that
      # is not yet bound to a dataset (see ReinitializableIterator#make_initializer).
      def self.from_structure(output_types, output_shapes = [], shared_name: '')
        ReinitializableIterator.new(output_types, output_shapes, shared_name: shared_name)
      end

      def initialize(output_types, output_shapes = [])
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # Returns the result of RawOps.iterator_get_next for this iterator.
      def get_next
        RawOps.iterator_get_next(@iterator, output_types: output_types, output_shapes: output_shapes)
      end

      private

      # Creates an iterator resource (RawOps.iterator_v2) with this iterator's
      # element structure.
      def build_iterator_resource(shared_name)
        RawOps.iterator_v2(shared_name: shared_name, output_types: output_types, output_shapes: output_shapes)
      end
    end

    # Iterator backed by RawOps.one_shot_iterator, whose dataset factory is a
    # graph function registered with the current execution context.
    class OneShotIterator < Iterator
      def initialize(dataset)
        super(dataset.output_types, dataset.output_shapes)
        create_one_shot_iterator(dataset)
      end

      private

      def create_one_shot_iterator(dataset)
        factory = make_dataset_function(dataset)
        ExecutionContext.current.add_function(factory)
        @iterator = RawOps.one_shot_iterator(dataset_factory: factory,
                                             output_types: output_types,
                                             output_shapes: output_shapes)
      end

      # Builds a 'MakeDataset' graph function whose single output is +dataset+
      # passed through RawOps.optimize_dataset with 'noop_elimination'.
      def make_dataset_function(dataset)
        Graph::Graph.new.as_default do |func_graph|
          optimized = RawOps.optimize_dataset(dataset.variant_tensor, ['noop_elimination'],
                                              output_types: output_types, output_shapes: output_shapes)
          func_graph.to_function('MakeDataset', nil, nil, [optimized])
        end
      end
    end

    # Iterator bound to one dataset; run #initializer before reading elements.
    class InitializableIterator < Iterator
      attr_reader :initializer

      def initialize(dataset, shared_name: '')
        super(dataset.output_types, dataset.output_shapes)
        create_initializable_iterator(dataset, shared_name)
      end

      private

      def create_initializable_iterator(dataset, shared_name)
        @iterator = build_iterator_resource(shared_name)
        @initializer = RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end
    end

    # Iterator defined only by its element structure; it can be (re)bound to
    # any compatible dataset through #make_initializer.
    class ReinitializableIterator < Iterator
      def initialize(output_types, output_shapes, shared_name: '')
        super(output_types, output_shapes)
        create_iterator_from_structure(shared_name)
      end

      # Returns the op that binds this iterator to +dataset+.
      def make_initializer(dataset)
        RawOps.make_iterator(dataset.variant_tensor, @iterator)
      end

      private

      def create_iterator_from_structure(shared_name)
        @iterator = build_iterator_resource(shared_name)
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
module Tensorflow
  module Data
    # Dataset produced by applying +function+ to every element of
    # +input_dataset+ (RawOps.map_dataset). The element structure comes from
    # the function's declared output types and shapes.
    class MapDataset < Dataset
      def initialize(input_dataset, function, other_arguments: [])
        @output_types = function.output_types
        @output_shapes = function.output_shapes

        mapped = RawOps.map_dataset(
          input_dataset.variant_tensor,
          other_arguments,
          f: function,
          output_types: @output_types,
          output_shapes: @output_shapes
        )
        super(mapped)
      end
    end
  end
end
@@ -0,0 +1,16 @@
1
module Tensorflow
  module Data
    # Dataset that replays +dataset+ +count+ times (RawOps.repeat_dataset).
    # The element structure is inherited unchanged from +dataset+.
    class RepeatDataset < Dataset
      def initialize(dataset, count)
        @output_types = dataset.output_types
        @output_shapes = dataset.output_shapes

        source = dataset.variant_tensor
        times = Tensor.new(count, dtype: :int64)
        super(RawOps.repeat_dataset(source, times, output_types: @output_types, output_shapes: @output_shapes))
      end
    end
  end
end
@@ -0,0 +1,23 @@
1
module Tensorflow
  module Data
    # Dataset that shuffles +input_dataset+ through a buffer of +buffer_size+
    # elements (RawOps.shuffle_dataset). Both seeds are drawn from Ruby's
    # global Random on every construction, so each instance gets its own order.
    class ShuffleDataset < Dataset
      def initialize(input_dataset, buffer_size)
        @input_dataset = input_dataset
        @output_types = input_dataset.output_types
        @output_shapes = input_dataset.output_shapes

        buffer = Tensor.new(buffer_size, dtype: :int64)
        seed, seed2 = Array.new(2) { Tensor.new(::Random.rand(100_000_000), dtype: :int64) }

        shuffled = RawOps.shuffle_dataset(input_dataset.variant_tensor, buffer, seed, seed2,
                                          output_types: @output_types, output_shapes: @output_shapes)
        super(shuffled)
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
module Tensorflow
  module Data
    # Dataset built from +elements+ as whole tensors (RawOps.tensor_dataset):
    # each component keeps its full shape, unlike TensorSliceDataset, which
    # drops the leading dimension.
    class TensorDataset < Dataset
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map(&:shape)

        super(RawOps.tensor_dataset(@tensors,
                                    toutput_types: @output_types,
                                    output_shapes: @output_shapes))
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
module Tensorflow
  module Data
    # Dataset that slices each component of +elements+ along its first
    # dimension (RawOps.tensor_slice_dataset). The element shape of each
    # component is its tensor shape minus that leading dimension.
    class TensorSliceDataset < Dataset
      def initialize(elements)
        @tensors = self.class.to_tensor_array(elements)
        @output_types = @tensors.map(&:dtype)
        @output_shapes = @tensors.map { |component| component.shape[1..] }

        sliced = RawOps.tensor_slice_dataset(@tensors, toutput_types: @output_types, output_shapes: @output_shapes)
        super(sliced)
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
module Tensorflow
  module Data
    # Dataset of raw records read from TFRecord files (RawOps.tf_record_dataset).
    # Every element is a single :string scalar.
    class TfRecordDataset < Dataset
      DEFAULT_BUFFER_SIZE = 256 * 1_048_576 # 256 MB

      # @param filenames [String, Array<String>] a single path is wrapped in an Array
      # @param buffer_size [Integer, nil] a falsy value is forwarded to the op as-is
      def initialize(filenames, compression_type = '', buffer_size = DEFAULT_BUFFER_SIZE)
        paths = Array(filenames)
        @output_types = [:string]
        @output_shapes = [[]]

        buffer = buffer_size && Tensor.new(buffer_size, dtype: :int64)
        super(RawOps.tf_record_dataset(paths, compression_type, buffer))
      end
    end
  end
end
@@ -0,0 +1,24 @@
1
module Tensorflow
  module Data
    # Dataset whose elements combine the corresponding elements of several
    # datasets (RawOps.zip_dataset). Accepts the datasets either as separate
    # arguments or as one Array. The zipped structure concatenates each input's
    # output types and output shapes in order.
    class ZipDataset < Dataset
      def initialize(*datasets)
        inputs = datasets.flatten(1)
        handles = inputs.map(&:variant_tensor)

        @output_types = inputs.map(&:output_types).flatten
        @output_shapes = inputs.map(&:output_shapes).flatten(1)

        zipped = RawOps.zip_dataset(handles, n: handles.size,
                                             output_types: @output_types,
                                             output_shapes: @output_shapes)
        super(zipped)
      end
    end
  end
end
@@ -0,0 +1,53 @@
1
module Tensorflow
  # Method-decorator support.
  #
  # Extending this module stores the module in the extender's @tf instance
  # variable. Calling @tf.function(signatures) then marks the next method defined
  # on the extender (instance or singleton) to be wrapped in a Graph::FunctionDef.
  #
  # NOTE(review): the method_added / singleton_method_added hooks discard the
  # value returned by wrap_method. The FunctionDef is therefore built, but the
  # method itself is not replaced. Confirm this is intended.
  module Decorator
    # A pending decoration request: the input signatures to use when wrapping.
    class Function
      attr_reader :input_signatures

      def initialize(input_signatures = [])
        @input_signatures = input_signatures
      end

      # Builds a Graph::FunctionDef for +method+ using the stored signatures.
      def wrap(method)
        Graph::FunctionDef.new(method, input_signatures)
      end
    end

    # Ruby hook, called when +klass+ extends this module.
    def self.extended(klass)
      @waiting_for_method = false
      decorator = self
      klass.instance_eval { @tf = decorator }

      # Top-level `def` defines methods on Object. When the top-level `main`
      # object extends us, extend Object too so those definitions fire the hooks.
      return unless klass.is_a?(Object) && klass.to_s == 'main'

      klass.class.extend(self)
    end

    # Queues a decoration request for the next method definition.
    def self.function(input_signature = [])
      @current_function = Function.new(input_signature)
    end

    # Wraps +method+ if a decoration request is pending; returns nil otherwise.
    # Wrapping a method triggers method_added, so the pending request is
    # cleared *before* wrapping to avoid re-wrapping from inside that hook.
    def self.wrap_method(method)
      pending = @current_function
      return unless pending

      @current_function = nil
      pending.wrap(method)
    end

    # Ruby hook: a singleton (class-level) method was defined on the extender.
    def singleton_method_added(method_name)
      super(method_name)
      defined = method(method_name)
      @tf.wrap_method(defined)
    end

    # Ruby hook: an instance method was defined on the extender.
    def method_added(method_name)
      super(method_name)
      defined = instance_method(method_name)
      @tf.wrap_method(defined)
    end
  end
end
@@ -0,0 +1,120 @@
1
+ module Tensorflow
2
module Eager
  # Owns a native TFE_Context: creates it, executes eager operations against
  # it, and manages the functions registered with it.
  class Context
    extend Forwardable
    def_delegators :@name_scope, :name_scope, :scoped_name, :unique_name

    # Number of output-handle slots allocated for each TFE_Execute call.
    MAX_RETVALS = 10

    # Shared context, created lazily on first use.
    def self.default
      @default ||= Context.new
    end

    # Returns the finalizer proc that deletes the native context. It is built
    # in a class method so the proc closes over only +pointer+, not the instance.
    def self.finalize(pointer)
      proc { FFI.TFE_DeleteContext(pointer) }
    end

    # Creates the native context. The temporary TFE_ContextOptions are freed
    # even when context creation fails inside Status.check.
    def initialize
      @name_scope = NameScope.new
      options = FFI.TFE_NewContextOptions
      begin
        Status.check do |status|
          @pointer = FFI.TFE_NewContext(options, status)
        end
      ensure
        FFI.TFE_DeleteContextOptions(options)
      end
      ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
    end

    # Pushes this context onto ExecutionContext for the duration of the block
    # and yields it; the context is popped even if the block raises.
    # @raise [Error::InvalidArgumentError] when no block is given
    def as_default
      raise(Error::InvalidArgumentError, "Must provide block") unless block_given?
      ExecutionContext.push(self)
      begin
        yield self
      ensure
        ExecutionContext.pop
      end
    end

    def create_operation(op_type, inputs=[], attrs={})
      Operation.new(self, op_type, inputs, attrs)
    end

    # Executes +operation+ and wraps its outputs in TensorHandles.
    #
    # @return [nil] when the op produced no outputs
    # @return [TensorHandle] when it produced exactly one output
    # @return [Array<TensorHandle>] when it produced several
    def execute(operation)
      # TODO decide how many retvals to allocate
      retvals = ::FFI::MemoryPointer.new(:pointer, MAX_RETVALS)
      num_retvals = ::FFI::MemoryPointer.new(:int)
      # TFE_Execute reads *num_retvals as the capacity of +retvals+ in elements.
      # MemoryPointer#size is in bytes and would overstate the capacity.
      num_retvals.write_int(MAX_RETVALS)

      Status.check do |status|
        FFI.TFE_Execute(operation, retvals, num_retvals, status)
      end

      count = num_retvals.read_int
      return if count <= 0

      handles = retvals.read_array_of_pointer(count).map do |handle|
        TensorHandle.new(self, handle)
      end

      # TODO handle case where n = 1 and still want an array for retvals
      count == 1 ? handles.first : handles
    end

    def device_policy
      FFI::ContextDevicePlacementPolicy[FFI.TFE_ContextGetDevicePlacementPolicy(@pointer)]
    end

    def enable_run_metadata
      FFI.TFE_ContextEnableRunMetadata(@pointer)
    end

    def disable_run_metadata
      FFI.TFE_ContextDisableRunMetadata(@pointer)
    end

    def start_step
      FFI.TFE_ContextStartStep(@pointer)
    end

    def end_step
      FFI.TFE_ContextEndStep(@pointer)
    end

    # Native TFE_Context pointer; lets the context be passed directly to FFI calls.
    def to_ptr
      @pointer
    end

    def shared_name
      # hard-coded in Python library
      "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
    end

    # Mimic graph api (no-op in eager mode)
    def add_to_collection(name, value)
    end

    # Mimic graph api (no-op in eager mode)
    def add_to_collections(names, value)
    end

    # Mimic graph api; always returns nil in eager mode
    def get_collection_ref(name)
    end

    # Registers +function+ with the native context.
    def add_function(function)
      Status.check do |status|
        FFI.TFE_ContextAddFunction(self, function, status)
      end
    end

    # Unregisters a function, given as a Graph::Function or as its name.
    def remove_function(function)
      name = function.is_a?(Graph::Function) ? function.name : function
      Status.check do |status|
        FFI.TFE_ContextRemoveFunction(self, name, status)
      end
    end

    # True if a function (Graph::Function or name) is registered with this context.
    def function?(function)
      name = function.is_a?(Graph::Function) ? function.name : function
      # result is uchar
      FFI.TFE_ContextHasFunction(self, name) != 0
    end
  end
end
+ end