tensorflow-ruby 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,17 @@
class Array
  # Test helper: lets an Array be compared directly against a Numo::NArray
  # (of any Numo numeric type), so this works:
  #
  #   assert_equal([1, 2], Numo::Int32[1, 2])
  #
  # instead of having to write:
  #
  #   assert_equal([1, 2], Numo::Int32[1, 2].to_a)

  # Keep a handle on the built-in Array#==. The guard stops a second load of
  # this file from aliasing the already-patched method onto itself, which
  # would make #== recurse forever.
  alias_method :original_equals, :== unless method_defined?(:original_equals)

  # @param other [Object] another Array, a Numo::NArray, or anything else
  # @return [Boolean] for a Numo::NArray, the result of comparing against
  #   other.to_a with the standard element-wise Array#== (so [1, 2] equals a
  #   float array [1.0, 2.0]); otherwise the built-in Array#== result
  def ==(other)
    if other.is_a?(Numo::NArray)
      original_equals(other.to_a)
    else
      original_equals(other)
    end
  end
end
@@ -0,0 +1,25 @@
# encoding: UTF-8

module Kernel
  # Converts a loosely-typed flag value into true or false.
  #
  # Rails converts true/false to 't' and 'f' in this case because it does not
  # have data dictionary information for these fields and doesn't seem to be
  # able to figure it out from the query results.
  #
  # @param value [Object] nil/false, or any object whose #to_s is inspected
  # @return [Boolean] true only when value.to_s is (case-insensitively) one of
  #   "t", "true", "1", "yes" or "y"; false otherwise
  def Boolean(value)
    # \A and \Z anchor the whole string (tolerating one trailing newline).
    # The previous ^/$ anchors matched any single line, so "junk\ntrue" was true.
    if value && value.to_s =~ /\A(t|true|1|yes|y)\Z/i
      true
    else
      false
    end
  end
end

class TrueClass
  # @return [Integer] 1
  def to_i
    1
  end
end

class FalseClass
  # @return [Integer] 0
  def to_i
    0
  end
end
@@ -0,0 +1,7 @@
module Numo
  # Reopened so that converting an NArray to a String (e.g. via string
  # interpolation) yields the same text as #inspect.
  class NArray
    # @return [String] whatever #inspect returns for this array
    def to_s
      inspect
    end
  end
end
@@ -0,0 +1,291 @@
module Tensorflow
  # Low-level bindings to the libtensorflow C API, attached with the ffi gem.
  # Each section links the C header its signatures are taken from; return and
  # parameter types must match those headers exactly, since ffi cannot check them.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib ["tensorflow", "libtensorflow"]
    rescue LoadError => e
      # Set TENSORFLOW_DEBUG to surface the underlying loader error instead of
      # the generic message below.
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find Tensorflow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    # Numbering starts at 1 (TF_FLOAT); the remaining members follow consecutively.
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64, :bool,
                    :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16, :complex128, :half,
                    :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    # Mirrors TF_Buffer.
    class Buffer < ::FFI::Struct
      layout :data, :pointer,
             :length, :size_t,
             :data_deallocator, :pointer
    end
    attach_function :TF_NewBuffer, [], :pointer
    attach_function :TF_DeleteBuffer, [:pointer], :void
    # TF_GetBuffer returns the TF_Buffer struct by value, not a pointer to it.
    attach_function :TF_GetBuffer, [:pointer], Buffer.by_value

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, [], :string
    attach_function :TF_GetAllOpList, [], Buffer.by_ref

    # Mirrors TF_Input: an operation plus the index of one of its inputs.
    class Input < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    # Mirrors TF_Output: an operation plus the index of one of its outputs.
    class Output < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      # Copies the given Output structs into one contiguous native array, as
      # expected by C functions taking `const TF_Output*`.
      #
      # @param outputs [Array<Output>]
      # @return [::FFI::MemoryPointer]
      def self.array_to_ptr(outputs)
        result = ::FFI::MemoryPointer.new(self, outputs.length)
        outputs.each_with_index do |output, i|
          copy_output = new(result[i])
          copy_output[:oper] = output[:oper]
          copy_output[:index] = output[:index]
        end
        result
      end

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    # Graphs
    attach_function :TF_NewGraph, [], :pointer
    attach_function :TF_DeleteGraph, [:pointer], :void
    attach_function :TF_GraphGetOpDef, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphGetTensorNumDims, [:pointer, Output, :pointer], :int
    attach_function :TF_GraphGetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void
    attach_function :TF_GraphSetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void

    # Operation descriptions
    attach_function :TF_NewOperation, [:pointer, :string, :string], :pointer
    attach_function :TF_FinishOperation, [:pointer, :pointer], :pointer
    attach_function :TF_SetDevice, [:pointer, :string], :void
    attach_function :TF_SetAttrBool, [:pointer, :string, :uchar], :void
    attach_function :TF_SetAttrBoolList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrInt, [:pointer, :string, :int64], :void
    attach_function :TF_SetAttrIntList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFloat, [:pointer, :string, :float], :void
    attach_function :TF_SetAttrFloatList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFuncName, [:pointer, :string, :string, :size_t], :void
    attach_function :TF_SetAttrPlaceholder, [:pointer, :string, :string], :void
    attach_function :TF_SetAttrShape, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrShapeList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrString, [:pointer, :string, :pointer, :size_t], :void
    attach_function :TF_SetAttrStringList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrType, [:pointer, :string, DataType], :void
    attach_function :TF_SetAttrTypeList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_SetAttrTensorList, [:pointer, :string, :pointer, :int, :pointer], :void

    attach_function :TF_AddInput, [:pointer, Output], :void
    attach_function :TF_AddInputList, [:pointer, :pointer, :int], :void

    attach_function :TF_AddControlInput, [:pointer, :pointer], :void
    attach_function :TF_OperationNumControlInputs, [:pointer], :int
    attach_function :TF_OperationGetControlInputs, [:pointer, :pointer, :int], :int
    attach_function :TF_OperationNumControlOutputs, [:pointer], :int
    attach_function :TF_OperationGetControlOutputs, [:pointer, :pointer, :int], :int

    # Operation inspection
    attach_function :TF_OperationToNodeDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_OperationNumInputs, [:pointer], :int
    attach_function :TF_OperationInputType, [Input], DataType
    attach_function :TF_OperationInputListLength, [:pointer, :string, :pointer], :int
    attach_function :TF_OperationAllInputs, [:pointer, :pointer, :int], :void

    attach_function :TF_OperationNumOutputs, [:pointer], :int
    attach_function :TF_OperationOutputType, [Output], DataType
    attach_function :TF_OperationOutputListLength, [:pointer, :string, :pointer], :int

    attach_function :TF_OperationOutputNumConsumers, [Output], :int
    attach_function :TF_OperationOutputConsumers, [Output, :pointer, :int], :int

    # Mirrors TF_AttrMetadata (returned by value from TF_OperationGetAttrMetadata).
    class AttrMetadata < ::FFI::Struct
      layout :is_list, :uchar,
             :list_size, :int64,
             :type, AttrType,
             :total_size, :int64
    end

    attach_function :TF_OperationGetAttrMetadata, [:pointer, :string, :pointer], AttrMetadata.by_value
    attach_function :TF_OperationGetAttrBool, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrBoolList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrFloat, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrFloatList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrInt, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrIntList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShape, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShapeList,
                    [:pointer, :string, :pointer, :pointer, :int, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrString, [:pointer, :string, :pointer, :size_t, :pointer], :void
    # (oper, attr_name, values, lengths, max_values, storage, storage_size, status):
    # the trailing TF_Status* was previously missing from this binding.
    attach_function :TF_OperationGetAttrStringList,
                    [:pointer, :string, :pointer, :pointer, :int, :pointer, :size_t, :pointer], :void
    attach_function :TF_OperationGetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrType, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrTypeList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrValueProto, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphOperationByName, [:pointer, :string], :pointer
    attach_function :TF_GraphNextOperation, [:pointer, :pointer], :pointer

    attach_function :TF_OperationName, [:pointer], :string
    attach_function :TF_OperationOpType, [:pointer], :string
    attach_function :TF_OperationDevice, [:pointer], :string

    # Gradients
    attach_function :TF_AddGradients, [:pointer, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void
    attach_function :TF_AddGradientsWithPrefix,
                    [:pointer, :string, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void

    # Sessions
    attach_function :TF_NewSessionOptions, [], :pointer
    attach_function :TF_SetTarget, [:pointer, :string], :void
    attach_function :TF_SetConfig, [:pointer, :pointer, :size_t, :pointer], :void
    attach_function :TF_DeleteSessionOptions, [:pointer], :void

    attach_function :TF_NewSession, [:pointer, :pointer, :pointer], :pointer
    attach_function :TF_CloseSession, [:pointer, :pointer], :void
    attach_function :TF_DeleteSession, [:pointer, :pointer], :void
    attach_function :TF_SessionRun, [:pointer, Buffer,
                                     :pointer, :pointer, :int,
                                     :pointer, :pointer, :int,
                                     :pointer, :int,
                                     Buffer,
                                     :pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    StatusCode = enum(:tf_ok, 0,
                      :tf_cancelled, 1,
                      :tf_unknown, 2,
                      :tf_invalid_argument, 3,
                      :tf_deadline_exceeded, 4,
                      :tf_not_found, 5,
                      :tf_already_exists, 6,
                      :tf_permission_denied, 7,
                      :tf_unauthenticated, 16,
                      :tf_resource_exhausted, 8,
                      :tf_failed_precondition, 9,
                      :tf_aborted, 10,
                      :tf_out_of_range, 11,
                      :tf_unimplemented, 12,
                      :tf_internal, 13,
                      :tf_unavailable, 14,
                      :tf_data_loss, 15)

    attach_function :TF_NewStatus, [], :pointer
    attach_function :TF_DeleteStatus, [:pointer], :void
    attach_function :TF_GetCode, [:pointer], StatusCode
    attach_function :TF_Message, [:pointer], :string
    attach_function :TF_SetStatus, [:pointer, StatusCode, :string], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    callback :tensor_deallocator, [:pointer, :size_t, :pointer], :void
    attach_function :TF_NewTensor, [DataType, :pointer, :int, :pointer, :size_t, :tensor_deallocator, :pointer], :pointer
    attach_function :TF_DeleteTensor, [:pointer], :void
    attach_function :TF_TensorType, [:pointer], DataType
    attach_function :TF_NumDims, [:pointer], :int
    attach_function :TF_Dim, [:pointer, :int], :int64
    attach_function :TF_TensorByteSize, [:pointer], :size_t
    attach_function :TF_TensorElementCount, [:pointer], :int64
    attach_function :TF_TensorData, [:pointer], :pointer
    attach_function :TF_StringEncode, [:pointer, :size_t, :pointer, :size_t, :pointer], :size_t
    attach_function :TF_StringDecode, [:pointer, :size_t, :pointer, :pointer, :pointer], :size_t
    attach_function :TF_StringEncodedSize, [:size_t], :size_t

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    ContextDevicePlacementPolicy = enum(:explicit, :warn, :silent, :silent_for_int32)

    attach_function :TFE_NewContextOptions, [], :pointer
    attach_function :TFE_ContextOptionsSetAsync, [:pointer, :uchar], :void
    attach_function :TFE_DeleteContextOptions, [:pointer], :void
    attach_function :TFE_NewContext, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteContext, [:pointer], :void
    attach_function :TFE_ContextListDevices, [:pointer, :pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, [:pointer], :int
    attach_function :TFE_ContextAddFunction, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_ContextHasFunction, [:pointer, :string], :uchar
    attach_function :TFE_ContextRemoveFunction, [:pointer, :string, :pointer], :void
    attach_function :TFE_ContextEnableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextDisableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextStartStep, [:pointer], :void
    attach_function :TFE_ContextEndStep, [:pointer], :void

    attach_function :TFE_NewTensorHandle, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, [:pointer], :void
    attach_function :TFE_TensorHandleDataType, [:pointer], DataType
    attach_function :TFE_TensorHandleNumDims, [:pointer, :pointer], :int
    attach_function :TFE_TensorHandleNumElements, [:pointer, :pointer], :int64
    attach_function :TFE_TensorHandleDim, [:pointer, :int, :pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleResolve, [:pointer, :pointer], :pointer

    attach_function :TFE_NewOp, [:pointer, :string, :pointer], :pointer
    attach_function :TFE_DeleteOp, [:pointer], :void
    attach_function :TFE_OpSetDevice, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpGetDevice, [:pointer, :pointer], :string
    attach_function :TFE_OpAddInput, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_OpAddInputList, [:pointer, :pointer, :int, :pointer], :void
    attach_function :TFE_OpGetAttrType, [:pointer, :string, :pointer, :pointer], AttrType
    attach_function :TFE_OpSetAttrString, [:pointer, :string, :pointer, :size_t], :void
    attach_function :TFE_OpSetAttrInt, [:pointer, :string, :int64], :void
    attach_function :TFE_OpSetAttrFloat, [:pointer, :string, :float], :void
    attach_function :TFE_OpSetAttrFunction, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpSetAttrFunctionName, [:pointer, :string, :string, :size_t], :void
    attach_function :TFE_OpSetAttrFunctionList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrBool, [:pointer, :string, :uint8], :void
    attach_function :TFE_OpSetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TFE_OpSetAttrType, [:pointer, :string, :int], :void
    attach_function :TFE_OpSetAttrShape, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TFE_OpSetAttrIntList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrFloatList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrTypeList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrShapeList, [:pointer, :string, :pointer, :pointer, :int, :pointer], :void
    attach_function :TFE_Execute, [:pointer, :pointer, :pointer, :pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/ops.h
    attach_function :TF_NewOpDefinitionBuilder, [:string], :pointer
    attach_function :TF_RegisterOpDefinition, [:pointer, :pointer], :void
    attach_function :TF_DeleteOpDefinitionBuilder, [:pointer], :void
    attach_function :TF_OpDefinitionBuilderAddAttr, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddInput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddOutput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderSetIsCommutative, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetIsAggregate, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetShapeInferenceFunction, [:pointer, :pointer], :void

    # Functions
    attach_function :TF_GraphToFunction, [:pointer, :string, :uchar,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :pointer, :pointer, :string, :pointer], :pointer
    # :strptr makes ffi return [String, FFI::Pointer].
    attach_function :TF_FunctionName, [:pointer], :strptr
    # Returns void in C; the FunctionDef is written into the TF_Buffer argument.
    attach_function :TF_FunctionToFunctionDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_GraphCopyFunction, [:pointer, :pointer, :pointer, :pointer], :void

    attach_function :TF_GraphToGraphDef, [:pointer, :pointer, :pointer], :void

    # GraphDef import
    attach_function :TF_NewImportGraphDefOptions, [], :pointer
    attach_function :TF_DeleteImportGraphDefOptions, [:pointer], :void
    attach_function :TF_ImportGraphDefOptionsSetPrefix, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetDefaultDevice, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyNames, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyPrefix, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsAddInputMapping, [:pointer, :string, :int, Output], :void
    attach_function :TF_ImportGraphDefOptionsRemapControlDependency, [:pointer, :string, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddControlDependency, [:pointer, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddReturnOutput, [:pointer, :string, :int], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOutputs, [:pointer], :int
    attach_function :TF_ImportGraphDefOptionsAddReturnOperation, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOperations, [:pointer], :int

    attach_function :TF_GraphImportGraphDef, [:pointer, Buffer, :pointer, :pointer], :void
  end
end
@@ -0,0 +1,33 @@
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Function handle plus the output
    # types/shapes it was built with.
    class Function
      attr_reader :output_types, :output_shapes

      # @param pointer [Object] native TF_Function handle, returned by #to_ptr
      # @param output_types [Object] stored as-is and exposed via #output_types
      # @param output_shapes [Object] stored as-is and exposed via #output_shapes
      def initialize(pointer, output_types, output_shapes)
        @pointer = pointer
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # Lets this object be passed directly wherever ffi expects a pointer.
      def to_ptr
        @pointer
      end

      # @return [String] the function's name, as reported by TF_FunctionName
      def name
        # TF_FunctionName is bound with a :strptr return ([string, pointer]);
        # only the string is needed.
        function_name, _pointer = FFI.TF_FunctionName(self)
        function_name
      end

      # Serializes this function and decodes it into a FunctionDef message.
      #
      # @return the result of Tensorflow::FunctionDef.decode on the serialized bytes
      # Errors raised by Status.check (on a failed TF_FunctionToFunctionDef) propagate.
      def function_def
        buffer_ptr = FFI.TF_NewBuffer
        Status.check do |status|
          FFI.TF_FunctionToFunctionDef(self, buffer_ptr, status)
        end
        buffer = FFI::Buffer.new(buffer_ptr)
        Tensorflow::FunctionDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        # Free the native buffer we allocated, including when Status.check raised
        # (the Ruby struct wrapper does not exist yet in that case).
        FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
      end
    end
  end
end
@@ -0,0 +1,62 @@
module Tensorflow
  module Graph
    # Decorates a Ruby method so that calling it builds a graph function from
    # the method body instead of running it eagerly.
    class FunctionDef
      attr_reader :ruby_method, :signatures

      # dtype and shape for the placeholder fed to one method parameter.
      Signature = Struct.new(:dtype, :shape)

      # @param ruby_method [UnboundMethod] method to wrap; must expose #owner,
      #   #original_name, #parameters and #bind
      # @param input_signatures [Array<Array>] one [dtype, shape] pair per parameter
      # @raise [Error::InvalidArgumentError] when the counts differ
      def initialize(ruby_method, input_signatures = [])
        @ruby_method = ruby_method
        process_signatures(ruby_method, input_signatures)
        wrap_ruby_method
      end

      # Validates that every method parameter has a signature and stores them.
      def process_signatures(ruby_method, input_signatures)
        unless input_signatures.length == ruby_method.parameters.length
          raise(Error::InvalidArgumentError, "Must specify input signature for each method parameter")
        end

        @signatures = input_signatures.map { |dtype, shape| Signature.new(dtype, shape) }
      end

      # @return [String] name under which the undecorated method is kept
      def aliased_name
        "#{ruby_method.original_name}_original"
      end

      # Keeps the original method under #aliased_name, then redefines the method
      # so each call builds a function and registers it with
      # ExecutionContext.current. Any arguments the caller passes are ignored.
      def wrap_ruby_method
        target = ruby_method.owner
        method_name = ruby_method.original_name
        backup_name = aliased_name
        definition = self

        target.instance_eval do
          alias_method(backup_name, method_name)

          define_method(method_name) do |*args|
            built = definition.build_function(self)
            ExecutionContext.current.add_function(built)
            built
          end
        end
      end

      # Traces the original method on +object+ in a fresh graph, feeding one
      # placeholder per parameter, and converts the graph to a function.
      def build_function(object)
        Graph.new.as_default do |graph|
          placeholders = ruby_method.parameters.zip(signatures).map do |param, signature|
            Tensorflow.placeholder(signature.dtype, name: param.last, shape: signature.shape)
          end

          # ruby_method still refers to the undecorated body, so binding it
          # runs the user's code to build the graph.
          outputs = Array(ruby_method.bind(object).call(*placeholders))
          graph.to_function(ruby_method.original_name.to_s, nil, placeholders, outputs)
        end
      end
    end
  end
end
@@ -0,0 +1,120 @@
require 'set'

module Tensorflow
  module Graph
    # Builds symbolic gradients for a graph by walking backwards from an output
    # operation towards each requested input operation.
    class Gradients
      attr_reader :graph

      # Registry of gradient implementations keyed by op type. Op types with no
      # registered entry resolve to the unbound #add_api_gradients method, which
      # delegates to TF_AddGradientsWithPrefix.
      #
      # @return [Hash] op type => Proc or UnboundMethod
      def self.gradients
        @gradients ||= Hash.new(instance_method(:add_api_gradients))
      end

      # Registers a custom gradient for +op_type+. The block is called with
      # (gradient, outputs, inputs). Its result is indexed once per input, and
      # each entry is expected to respond to #output and #operation.
      def self.register(op_type, &block)
        gradients[op_type] = block
      end

      # @param graph [Object] graph providing forward/backward traversal,
      #   name_scope and scoped_name
      def initialize(graph)
        @graph = graph
      end

      # Operations reachable forwards from +input+ that also lead back to +output+.
      def path(output, input)
        graph.forward(input).intersection(graph.backward(output))
      end

      # Seed gradient for each output of +operation+: a tensor of ones shaped
      # like that output, with the output's dtype.
      def default_gradient(operation)
        operation.outputs.each_with_index.map do |output, index|
          shape = Tensorflow.shape(output, :int32)
          ones = Tensorflow.constant(1, name: "grad_ys_#{index}", dtype: operation.output_types[index])
          Tensorflow.fill(shape, ones)
        end
      end

      # Computes gradients of +output+ with respect to each of +inputs+.
      #
      # @param grad_ys [Object, nil] seed gradient; defaults to the first entry
      #   of #default_gradient(output)
      # @param stop_operations [Set] operations that backpropagation must not cross
      # @return [Array] flattened gradients; inputs with no path to +output+ are dropped
      def gradients(output, inputs, grad_ys: nil, name: "gradients", stop_operations: Set.new)
        grad_ys ||= default_gradient(output).first

        graph.name_scope(name) do
          per_input = inputs.map do |input|
            operations_path = path(output, input)
            # An input that cannot reach the output contributes nothing (nil).
            derivative(grad_ys, output, stop_operations, operations_path) unless operations_path.empty?
          end
          per_input.flatten.compact
        end
      end

      # Backpropagates +gradient+ through +operation+ and recursively into its
      # on-path inputs.
      #
      # Naming follows the C API (TF_AddGradients): y are an operation's outputs,
      # x its inputs, dx the gradient flowing back into y, and dy the gradient
      # computed for x.
      def derivative(gradient, operation, stop_operations, operations_path)
        reachable = lambda do |op|
          operations_path.include?(op) && !stop_operations.include?(op)
        end

        # Off the path or explicitly stopped: pass the incoming gradient through.
        return gradient unless reachable.call(operation)

        inputs = operation.inputs.select { |input| reachable.call(input.operation) }
        return gradient if inputs.empty?

        outputs = operation.outputs.select do |output|
          consumers = operation.output_consumers(output)
          # The last operation we evaluate is not hooked up to any consumers, so
          # all of its outputs are kept; earlier operations skip unused outputs.
          operation == operations_path.first || consumers.count > 0
        end

        handler = self.class.gradients[operation.op_type]
        dy = case handler
             when UnboundMethod then handler.bind(self).call(gradient, outputs, inputs)
             else handler.call(gradient, outputs, inputs)
             end

        # Done with this operation: recurse into each input operation. A null
        # gradient yields nil for that input.
        inputs.each_with_index.map do |input, index|
          input_gradient = dy[index]
          next if input_gradient.output[:oper].null?

          derivative(input_gradient.operation, input.operation, stop_operations, operations_path)
        end
      end

      # Default gradient implementation backed by TF_AddGradientsWithPrefix.
      #
      # @param gradient [Object, nil] operation whose outputs are the incoming
      #   gradients (dx); nil is passed to the C API as NULL
      # @param outputs [Array] outputs of the operation (y)
      # @param inputs [Array] inputs of the operation (x)
      # @return [Array<OperationOutput>] one gradient (dy) per input
      def add_api_gradients(gradient, outputs, inputs)
        y = FFI::Output.array_to_ptr(outputs.map(&:output))
        x = FFI::Output.array_to_ptr(inputs.map(&:output))
        dx = gradient ? FFI::Output.array_to_ptr(gradient.outputs.map(&:output)) : nil

        # Zero-filled array the C API writes the resulting gradients into.
        dy = ::FFI::MemoryPointer.new(FFI::Output, inputs.length, true)

        prefix = graph.scoped_name(inputs.first.operation.name)
        Status.check do |status|
          FFI.TF_AddGradientsWithPrefix(graph, prefix,
                                        y, outputs.length,
                                        x, inputs.length,
                                        dx, status, dy)
        end

        Array.new(inputs.length) { |i| OperationOutput.from_graph(graph, dy[i]) }
      end
    end
  end
end