tensorflow-ruby 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +18 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +104 -0
  5. data/lib/datasets/download_manager.rb +49 -0
  6. data/lib/datasets/images/mnist.rb +54 -0
  7. data/lib/datasets/resource.rb +19 -0
  8. data/lib/tensorflow-ruby.rb +182 -0
  9. data/lib/tensorflow.rb +1 -0
  10. data/lib/tensorflow/batchable_type_spec.rb +4 -0
  11. data/lib/tensorflow/core/framework/allocation_description_pb.rb +21 -0
  12. data/lib/tensorflow/core/framework/api_def_pb.rb +59 -0
  13. data/lib/tensorflow/core/framework/attr_value_pb.rb +46 -0
  14. data/lib/tensorflow/core/framework/cost_graph_pb.rb +49 -0
  15. data/lib/tensorflow/core/framework/device_attributes_pb.rb +37 -0
  16. data/lib/tensorflow/core/framework/function_pb.rb +38 -0
  17. data/lib/tensorflow/core/framework/graph_pb.rb +22 -0
  18. data/lib/tensorflow/core/framework/graph_transfer_info_pb.rb +73 -0
  19. data/lib/tensorflow/core/framework/kernel_def_pb.rb +31 -0
  20. data/lib/tensorflow/core/framework/log_memory_pb.rb +53 -0
  21. data/lib/tensorflow/core/framework/node_def_pb.rb +27 -0
  22. data/lib/tensorflow/core/framework/op_def_pb.rb +58 -0
  23. data/lib/tensorflow/core/framework/reader_base_pb.rb +19 -0
  24. data/lib/tensorflow/core/framework/remote_fused_graph_execute_info_pb.rb +30 -0
  25. data/lib/tensorflow/core/framework/resource_handle_pb.rb +28 -0
  26. data/lib/tensorflow/core/framework/step_stats_pb.rb +72 -0
  27. data/lib/tensorflow/core/framework/summary_pb.rb +71 -0
  28. data/lib/tensorflow/core/framework/tensor_description_pb.rb +21 -0
  29. data/lib/tensorflow/core/framework/tensor_pb.rb +41 -0
  30. data/lib/tensorflow/core/framework/tensor_shape_pb.rb +22 -0
  31. data/lib/tensorflow/core/framework/tensor_slice_pb.rb +23 -0
  32. data/lib/tensorflow/core/framework/types_pb.rb +62 -0
  33. data/lib/tensorflow/core/framework/variable_pb.rb +45 -0
  34. data/lib/tensorflow/core/framework/versions_pb.rb +18 -0
  35. data/lib/tensorflow/core/lib/core/error_codes_pb.rb +35 -0
  36. data/lib/tensorflow/core/protobuf/cluster_pb.rb +21 -0
  37. data/lib/tensorflow/core/protobuf/config_pb.rb +180 -0
  38. data/lib/tensorflow/core/protobuf/control_flow_pb.rb +48 -0
  39. data/lib/tensorflow/core/protobuf/conv_autotuning_pb.rb +29 -0
  40. data/lib/tensorflow/core/protobuf/critical_section_pb.rb +21 -0
  41. data/lib/tensorflow/core/protobuf/debug_pb.rb +38 -0
  42. data/lib/tensorflow/core/protobuf/device_properties_pb.rb +33 -0
  43. data/lib/tensorflow/core/protobuf/eager_service_pb.rb +112 -0
  44. data/lib/tensorflow/core/protobuf/graph_debug_info_pb.rb +29 -0
  45. data/lib/tensorflow/core/protobuf/master_pb.rb +123 -0
  46. data/lib/tensorflow/core/protobuf/master_service_pb.rb +15 -0
  47. data/lib/tensorflow/core/protobuf/meta_graph_pb.rb +95 -0
  48. data/lib/tensorflow/core/protobuf/named_tensor_pb.rb +18 -0
  49. data/lib/tensorflow/core/protobuf/queue_runner_pb.rb +21 -0
  50. data/lib/tensorflow/core/protobuf/replay_log_pb.rb +48 -0
  51. data/lib/tensorflow/core/protobuf/rewriter_config_pb.rb +81 -0
  52. data/lib/tensorflow/core/protobuf/saved_model_pb.rb +18 -0
  53. data/lib/tensorflow/core/protobuf/saved_object_graph_pb.rb +87 -0
  54. data/lib/tensorflow/core/protobuf/saver_pb.rb +28 -0
  55. data/lib/tensorflow/core/protobuf/struct_pb.rb +81 -0
  56. data/lib/tensorflow/core/protobuf/tensor_bundle_pb.rb +37 -0
  57. data/lib/tensorflow/core/protobuf/tensorflow_server_pb.rb +22 -0
  58. data/lib/tensorflow/core/protobuf/trace_events_pb.rb +39 -0
  59. data/lib/tensorflow/core/protobuf/trackable_object_graph_pb.rb +40 -0
  60. data/lib/tensorflow/core/protobuf/transport_options_pb.rb +16 -0
  61. data/lib/tensorflow/core/protobuf/verifier_config_pb.rb +23 -0
  62. data/lib/tensorflow/core/protobuf/worker_pb.rb +246 -0
  63. data/lib/tensorflow/core/protobuf/worker_service_pb.rb +15 -0
  64. data/lib/tensorflow/core/util/event_pb.rb +93 -0
  65. data/lib/tensorflow/core/util/memmapped_file_system_pb.rb +22 -0
  66. data/lib/tensorflow/core/util/saved_tensor_slice_pb.rb +40 -0
  67. data/lib/tensorflow/data/batch_dataset.rb +18 -0
  68. data/lib/tensorflow/data/dataset.rb +106 -0
  69. data/lib/tensorflow/data/fixed_length_record_dataset.rb +27 -0
  70. data/lib/tensorflow/data/iterator.rb +76 -0
  71. data/lib/tensorflow/data/map_dataset.rb +17 -0
  72. data/lib/tensorflow/data/repeat_dataset.rb +16 -0
  73. data/lib/tensorflow/data/shuffle_dataset.rb +23 -0
  74. data/lib/tensorflow/data/tensor_dataset.rb +19 -0
  75. data/lib/tensorflow/data/tensor_slice_dataset.rb +15 -0
  76. data/lib/tensorflow/data/tf_record_dataset.rb +18 -0
  77. data/lib/tensorflow/data/zip_dataset.rb +24 -0
  78. data/lib/tensorflow/decorators.rb +53 -0
  79. data/lib/tensorflow/eager/context.rb +120 -0
  80. data/lib/tensorflow/eager/operation.rb +219 -0
  81. data/lib/tensorflow/eager/tensor_handle.rb +87 -0
  82. data/lib/tensorflow/error.rb +54 -0
  83. data/lib/tensorflow/execution_context.rb +62 -0
  84. data/lib/tensorflow/extensions/arg_def.rb +58 -0
  85. data/lib/tensorflow/extensions/array.rb +17 -0
  86. data/lib/tensorflow/extensions/boolean.rb +25 -0
  87. data/lib/tensorflow/extensions/narray.rb +7 -0
  88. data/lib/tensorflow/ffi.rb +291 -0
  89. data/lib/tensorflow/graph/function.rb +33 -0
  90. data/lib/tensorflow/graph/function_def.rb +62 -0
  91. data/lib/tensorflow/graph/gradients.rb +120 -0
  92. data/lib/tensorflow/graph/graph.rb +252 -0
  93. data/lib/tensorflow/graph/graph_def_options.rb +24 -0
  94. data/lib/tensorflow/graph/graph_keys.rb +50 -0
  95. data/lib/tensorflow/graph/operation.rb +176 -0
  96. data/lib/tensorflow/graph/operation_attr.rb +153 -0
  97. data/lib/tensorflow/graph/operation_description.rb +255 -0
  98. data/lib/tensorflow/graph/operation_output.rb +49 -0
  99. data/lib/tensorflow/graph/session.rb +156 -0
  100. data/lib/tensorflow/keras/datasets/boston_housing.rb +32 -0
  101. data/lib/tensorflow/keras/datasets/cifar10.rb +11 -0
  102. data/lib/tensorflow/keras/datasets/cifar100.rb +11 -0
  103. data/lib/tensorflow/keras/datasets/fashion_mnist.rb +44 -0
  104. data/lib/tensorflow/keras/datasets/imdb.rb +30 -0
  105. data/lib/tensorflow/keras/datasets/mnist.rb +18 -0
  106. data/lib/tensorflow/keras/datasets/reuters.rb +28 -0
  107. data/lib/tensorflow/keras/layers/conv.rb +14 -0
  108. data/lib/tensorflow/keras/layers/conv2d.rb +11 -0
  109. data/lib/tensorflow/keras/layers/dense.rb +68 -0
  110. data/lib/tensorflow/keras/layers/dropout.rb +27 -0
  111. data/lib/tensorflow/keras/layers/flatten.rb +25 -0
  112. data/lib/tensorflow/keras/losses/sparse_categorical_crossentropy.rb +14 -0
  113. data/lib/tensorflow/keras/metrics/mean.rb +30 -0
  114. data/lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb +17 -0
  115. data/lib/tensorflow/keras/model.rb +6 -0
  116. data/lib/tensorflow/keras/models/sequential.rb +56 -0
  117. data/lib/tensorflow/keras/optimizers/adam.rb +8 -0
  118. data/lib/tensorflow/keras/preprocessing/image.rb +22 -0
  119. data/lib/tensorflow/keras/utils.rb +83 -0
  120. data/lib/tensorflow/name_scope.rb +57 -0
  121. data/lib/tensorflow/op_def_builder.rb +49 -0
  122. data/lib/tensorflow/ops/audio.rb +13 -0
  123. data/lib/tensorflow/ops/bitwise.rb +29 -0
  124. data/lib/tensorflow/ops/control.rb +13 -0
  125. data/lib/tensorflow/ops/gradients.rb +21 -0
  126. data/lib/tensorflow/ops/image.rb +218 -0
  127. data/lib/tensorflow/ops/io.rb +123 -0
  128. data/lib/tensorflow/ops/linalg.rb +131 -0
  129. data/lib/tensorflow/ops/math.rb +493 -0
  130. data/lib/tensorflow/ops/nn.rb +286 -0
  131. data/lib/tensorflow/ops/operators.rb +31 -0
  132. data/lib/tensorflow/ops/ops.rb +102 -0
  133. data/lib/tensorflow/ops/random.rb +18 -0
  134. data/lib/tensorflow/ops/raw_ops.rb +5179 -0
  135. data/lib/tensorflow/ops/raw_ops.rb.erb +38 -0
  136. data/lib/tensorflow/printers/graph.erb +80 -0
  137. data/lib/tensorflow/printers/graph.rb +26 -0
  138. data/lib/tensorflow/printers/graph_def.erb +109 -0
  139. data/lib/tensorflow/printers/graph_def.rb +26 -0
  140. data/lib/tensorflow/python_compatiblity.rb +55 -0
  141. data/lib/tensorflow/resource_summary_writer.rb +78 -0
  142. data/lib/tensorflow/status.rb +49 -0
  143. data/lib/tensorflow/stream_executor/dnn_pb.rb +90 -0
  144. data/lib/tensorflow/strings.rb +100 -0
  145. data/lib/tensorflow/summary.rb +13 -0
  146. data/lib/tensorflow/tensor.rb +133 -0
  147. data/lib/tensorflow/tensor_data.rb +310 -0
  148. data/lib/tensorflow/tensor_mixin.rb +32 -0
  149. data/lib/tensorflow/tensor_spec.rb +10 -0
  150. data/lib/tensorflow/tensorflow/core/util/event_pb.rb +93 -0
  151. data/lib/tensorflow/train/gradient_descent_optimizer.rb +33 -0
  152. data/lib/tensorflow/train/optimizer.rb +158 -0
  153. data/lib/tensorflow/type_spec.rb +4 -0
  154. data/lib/tensorflow/variable.rb +127 -0
  155. data/lib/tensorflow/version.rb +3 -0
  156. metadata +308 -0
@@ -0,0 +1,17 @@
1
class Array
  # Test-friendly equality: lets a plain Array compare equal to a Numo::NArray
  # holding the same elements, so this works:
  #
  #   assert_equal([1, 2], Numo::Int32[1, 2])
  #
  # instead of having to write:
  #
  #   assert_equal([1, 2], Numo::Int32[1, 2].to_a)
  #
  # The stock Array#== stays reachable as #original_equals. The alias is only
  # created once: re-loading this file must not point #original_equals at the
  # override below, which would recurse forever.
  alias_method :original_equals, :== unless method_defined?(:original_equals)

  # @param other [Object] value to compare against
  # @return [Boolean]
  def ==(other)
    # Guard with defined? so Array#== keeps working when Numo is not loaded.
    if defined?(Numo::NArray) && other.is_a?(Numo::NArray)
      # Use Array#== semantics (1 == 1.0) rather than #eql?, so integer
      # literals match float-typed NArrays.
      original_equals(other.to_a)
    else
      original_equals(other)
    end
  end
end
@@ -0,0 +1,25 @@
1
+ # encoding: UTF-8
2
+
3
module Kernel
  # Converts +value+ into a strict true/false.
  #
  # Textual forms are accepted because some sources (the original author notes
  # Rails when it lacks column type information) hand back booleans as strings
  # such as 't'/'f'.
  #
  # @param value [Object] anything; it is compared via its #to_s
  # @return [Boolean] true only when +value+ is truthy and its string form is
  #   exactly one of t, true, 1, yes, y (case-insensitive); false otherwise
  def Boolean(value)
    return false unless value

    # \A and \z anchor the whole string; ^/$ would only anchor a single line,
    # wrongly accepting inputs like "nope\ntrue".
    value.to_s.match?(/\A(?:t|true|1|yes|y)\z/i)
  end
end
18
+
19
class TrueClass
  # Integer view of +true+: always 1.
  def to_i
    1
  end
end

class FalseClass
  # Integer view of +false+: always 0.
  def to_i
    0
  end
end
@@ -0,0 +1,7 @@
1
module Numo
  class NArray
    # String form of the array: delegates to #inspect, so interpolation and
    # +puts+ render the same text as #inspect does.
    def to_s
      inspect
    end
  end
end
@@ -0,0 +1,291 @@
1
module Tensorflow
  # Ruby-FFI bindings for the TensorFlow C API (libtensorflow).
  #
  # Each attach_function mirrors a prototype in the C header linked above its
  # group. Argument counts and return types must match the C declaration
  # exactly: a missing argument, or a non-void return type on a void C
  # function, makes the native call read or write garbage.
  module FFI
    extend ::FFI::Library

    begin
      ffi_lib ["tensorflow", "libtensorflow"]
    rescue LoadError => e
      # Set TENSORFLOW_DEBUG to surface the original loader error instead of
      # the generic message below.
      raise e if ENV["TENSORFLOW_DEBUG"]
      raise LoadError, "Could not find Tensorflow"
    end

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_attrtype.h
    AttrType = enum(:string, :int, :float, :bool, :type, :shape, :tensor, :placeholder, :func)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_datatype.h
    # Numbering starts at 1 (:float); each following symbol takes the next value.
    DataType = enum(:float, 1, :double, :int32, :uint8, :int16, :int8, :string, :complex64, :int64,
                    :bool, :qint8, :quint8, :qint32, :bfloat16, :qint16, :quint16, :uint16,
                    :complex128, :half, :resource, :variant, :uint32, :uint64)

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    # Mirrors TF_Buffer.
    class Buffer < ::FFI::Struct
      layout :data, :pointer,
             :length, :size_t,
             :data_deallocator, :pointer
    end
    attach_function :TF_NewBuffer, [], :pointer
    attach_function :TF_DeleteBuffer, [:pointer], :void
    # TF_GetBuffer returns a TF_Buffer *by value*, not a pointer.
    attach_function :TF_GetBuffer, [:pointer], Buffer.by_value

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h
    attach_function :TF_Version, [], :string
    attach_function :TF_GetAllOpList, [], Buffer.by_ref

    # Mirrors TF_Input: an (operation, input index) pair.
    class Input < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    # Mirrors TF_Output: an (operation, output index) pair.
    #
    # NOTE(review): several bindings below list Output/Input as a parameter
    # where the C API takes the struct by value. Ruby-FFI may treat a bare
    # Struct class in a parameter list as a pointer. Confirm whether these
    # should be Output.by_value / Input.by_value.
    class Output < ::FFI::Struct
      layout :oper, :pointer,
             :index, :int

      # Copies +outputs+ into a newly allocated, contiguous native TF_Output array.
      # @param outputs [Array<Output>]
      # @return [::FFI::MemoryPointer]
      def self.array_to_ptr(outputs)
        result = ::FFI::MemoryPointer.new(self, outputs.length)
        outputs.each_with_index do |output, i|
          copy_output = new(result[i])
          copy_output[:oper] = output[:oper]
          copy_output[:index] = output[:index]
        end
        result
      end

      def to_s
        "#{self[:oper]}: #{self[:index]}"
      end
    end

    attach_function :TF_NewGraph, [], :pointer
    attach_function :TF_DeleteGraph, [:pointer], :void
    attach_function :TF_GraphGetOpDef, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphGetTensorNumDims, [:pointer, Output, :pointer], :int
    attach_function :TF_GraphGetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void
    attach_function :TF_GraphSetTensorShape, [:pointer, Output, :pointer, :int, :pointer], :void

    attach_function :TF_NewOperation, [:pointer, :string, :string], :pointer
    attach_function :TF_FinishOperation, [:pointer, :pointer], :pointer
    attach_function :TF_SetDevice, [:pointer, :string], :void
    attach_function :TF_SetAttrBool, [:pointer, :string, :uchar], :void
    attach_function :TF_SetAttrBoolList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrInt, [:pointer, :string, :int64], :void
    attach_function :TF_SetAttrIntList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFloat, [:pointer, :string, :float], :void
    attach_function :TF_SetAttrFloatList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrFuncName, [:pointer, :string, :string, :size_t], :void
    attach_function :TF_SetAttrPlaceholder, [:pointer, :string, :string], :void
    attach_function :TF_SetAttrShape, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrShapeList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrString, [:pointer, :string, :pointer, :size_t], :void
    attach_function :TF_SetAttrStringList, [:pointer, :string, :pointer, :pointer, :int], :void
    attach_function :TF_SetAttrType, [:pointer, :string, DataType], :void
    attach_function :TF_SetAttrTypeList, [:pointer, :string, :pointer, :int], :void
    attach_function :TF_SetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_SetAttrTensorList, [:pointer, :string, :pointer, :int, :pointer], :void

    attach_function :TF_AddInput, [:pointer, Output], :void
    attach_function :TF_AddInputList, [:pointer, :pointer, :int], :void

    attach_function :TF_AddControlInput, [:pointer, :pointer], :void
    attach_function :TF_OperationNumControlInputs, [:pointer], :int
    attach_function :TF_OperationGetControlInputs, [:pointer, :pointer, :int], :int
    attach_function :TF_OperationNumControlOutputs, [:pointer], :int
    attach_function :TF_OperationGetControlOutputs, [:pointer, :pointer, :int], :int

    attach_function :TF_OperationToNodeDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_OperationNumInputs, [:pointer], :int
    attach_function :TF_OperationInputType, [Input], DataType
    attach_function :TF_OperationInputListLength, [:pointer, :string, :pointer], :int
    attach_function :TF_OperationAllInputs, [:pointer, :pointer, :int], :void

    attach_function :TF_OperationNumOutputs, [:pointer], :int
    attach_function :TF_OperationOutputType, [Output], DataType
    attach_function :TF_OperationOutputListLength, [:pointer, :string, :pointer], :int

    attach_function :TF_OperationOutputNumConsumers, [Output], :int
    attach_function :TF_OperationOutputConsumers, [Output, :pointer, :int], :int

    # Mirrors TF_AttrMetadata.
    class AttrMetadata < ::FFI::Struct
      layout :is_list, :uchar,
             :list_size, :int64,
             :type, AttrType,
             :total_size, :int64
    end

    attach_function :TF_OperationGetAttrMetadata, [:pointer, :string, :pointer], AttrMetadata.by_value
    attach_function :TF_OperationGetAttrBool, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrBoolList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrFloat, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrFloatList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrInt, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrIntList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShape, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrShapeList,
                    [:pointer, :string, :pointer, :pointer, :int, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrString, [:pointer, :string, :pointer, :size_t, :pointer], :void
    # C signature: (oper, attr_name, values, lengths, max_values, storage,
    # storage_size, status). The trailing TF_Status* is required.
    attach_function :TF_OperationGetAttrStringList,
                    [:pointer, :string, :pointer, :pointer, :int, :pointer, :size_t, :pointer], :void
    attach_function :TF_OperationGetAttrTensor, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrType, [:pointer, :string, :pointer, :pointer], :void
    attach_function :TF_OperationGetAttrTypeList, [:pointer, :string, :pointer, :int, :pointer], :void
    attach_function :TF_OperationGetAttrValueProto, [:pointer, :string, :pointer, :pointer], :void

    attach_function :TF_GraphOperationByName, [:pointer, :string], :pointer
    attach_function :TF_GraphNextOperation, [:pointer, :pointer], :pointer

    attach_function :TF_OperationName, [:pointer], :string
    attach_function :TF_OperationOpType, [:pointer], :string
    attach_function :TF_OperationDevice, [:pointer], :string

    attach_function :TF_AddGradients,
                    [:pointer, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void
    attach_function :TF_AddGradientsWithPrefix,
                    [:pointer, :string, :pointer, :int, :pointer, :int, :pointer, :pointer, :pointer], :void

    attach_function :TF_NewSessionOptions, [], :pointer
    attach_function :TF_SetTarget, [:pointer, :string], :void
    attach_function :TF_SetConfig, [:pointer, :pointer, :size_t, :pointer], :void
    attach_function :TF_DeleteSessionOptions, [:pointer], :void

    attach_function :TF_NewSession, [:pointer, :pointer, :pointer], :pointer
    attach_function :TF_CloseSession, [:pointer, :pointer], :void
    attach_function :TF_DeleteSession, [:pointer, :pointer], :void
    attach_function :TF_SessionRun, [:pointer, Buffer,
                                     :pointer, :pointer, :int,
                                     :pointer, :pointer, :int,
                                     :pointer, :int,
                                     Buffer,
                                     :pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_status.h
    StatusCode = enum(:tf_ok, 0,
                      :tf_cancelled, 1,
                      :tf_unknown, 2,
                      :tf_invalid_argument, 3,
                      :tf_deadline_exceeded, 4,
                      :tf_not_found, 5,
                      :tf_already_exists, 6,
                      :tf_permission_denied, 7,
                      :tf_unauthenticated, 16,
                      :tf_resource_exhausted, 8,
                      :tf_failed_precondition, 9,
                      :tf_aborted, 10,
                      :tf_out_of_range, 11,
                      :tf_unimplemented, 12,
                      :tf_internal, 13,
                      :tf_unavailable, 14,
                      :tf_data_loss, 15)

    attach_function :TF_NewStatus, [], :pointer
    attach_function :TF_DeleteStatus, [:pointer], :void
    attach_function :TF_GetCode, [:pointer], StatusCode
    attach_function :TF_Message, [:pointer], :string
    attach_function :TF_SetStatus, [:pointer, StatusCode, :string], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/tf_tensor.h
    callback :tensor_deallocator, [:pointer, :size_t, :pointer], :void
    attach_function :TF_NewTensor, [DataType, :pointer, :int, :pointer, :size_t, :tensor_deallocator, :pointer], :pointer
    attach_function :TF_DeleteTensor, [:pointer], :void
    attach_function :TF_TensorType, [:pointer], DataType
    attach_function :TF_NumDims, [:pointer], :int
    attach_function :TF_Dim, [:pointer, :int], :int64
    attach_function :TF_TensorByteSize, [:pointer], :size_t
    attach_function :TF_TensorElementCount, [:pointer], :int64
    attach_function :TF_TensorData, [:pointer], :pointer
    attach_function :TF_StringEncode, [:pointer, :size_t, :pointer, :size_t, :pointer], :size_t
    attach_function :TF_StringDecode, [:pointer, :size_t, :pointer, :pointer, :pointer], :size_t
    attach_function :TF_StringEncodedSize, [:size_t], :size_t

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/eager/c_api.h
    ContextDevicePlacementPolicy = enum(:explicit, :warn, :silent, :silent_for_int32)

    attach_function :TFE_NewContextOptions, [], :pointer
    attach_function :TFE_ContextOptionsSetAsync, [:pointer, :char], :void
    attach_function :TFE_DeleteContextOptions, [:pointer], :void
    attach_function :TFE_NewContext, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteContext, [:pointer], :void
    attach_function :TFE_ContextListDevices, [:pointer, :pointer], :pointer
    attach_function :TFE_ContextGetDevicePlacementPolicy, [:pointer], :int
    attach_function :TFE_ContextAddFunction, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_ContextHasFunction, [:pointer, :string], :uchar
    attach_function :TFE_ContextRemoveFunction, [:pointer, :string, :pointer], :void

    attach_function :TFE_NewTensorHandle, [:pointer, :pointer], :pointer
    attach_function :TFE_DeleteTensorHandle, [:pointer], :void
    attach_function :TFE_TensorHandleDataType, [:pointer], DataType
    attach_function :TFE_TensorHandleNumDims, [:pointer, :pointer], :int
    attach_function :TFE_TensorHandleNumElements, [:pointer, :pointer], :int64
    attach_function :TFE_TensorHandleDim, %i[pointer int pointer], :int64
    attach_function :TFE_TensorHandleDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleBackingDeviceName, [:pointer, :pointer], :string
    attach_function :TFE_TensorHandleResolve, [:pointer, :pointer], :pointer
    attach_function :TFE_NewOp, [:pointer, :string, :pointer], :pointer
    attach_function :TFE_DeleteOp, [:pointer], :void
    attach_function :TFE_OpSetDevice, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpGetDevice, [:pointer, :pointer], :string
    attach_function :TFE_OpAddInput, [:pointer, :pointer, :pointer], :void
    attach_function :TFE_OpAddInputList, %i[pointer pointer int pointer], :void
    attach_function :TFE_OpGetAttrType, %i[pointer string pointer pointer], AttrType
    attach_function :TFE_OpSetAttrString, %i[pointer string pointer size_t], :void
    attach_function :TFE_OpSetAttrInt, %i[pointer string int64_t], :void
    attach_function :TFE_OpSetAttrFloat, %i[pointer string float], :void
    attach_function :TFE_OpSetAttrFunction, [:pointer, :string, :pointer], :void
    attach_function :TFE_OpSetAttrFunctionName, [:pointer, :string, :string, :size_t], :void
    attach_function :TFE_OpSetAttrFunctionList, [:pointer, :string, :pointer, :int], :void
    attach_function :TFE_OpSetAttrBool, %i[pointer string uint8], :void
    attach_function :TFE_OpSetAttrTensor, %i[pointer string pointer pointer], :void
    attach_function :TFE_OpSetAttrType, %i[pointer string int], :void
    attach_function :TFE_OpSetAttrShape, %i[pointer string pointer int pointer], :void
    attach_function :TFE_OpSetAttrIntList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrFloatList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrTypeList, %i[pointer string pointer int], :void
    attach_function :TFE_OpSetAttrShapeList, %i[pointer string pointer pointer int pointer], :void
    attach_function :TFE_Execute, %i[pointer pointer pointer pointer], :void
    attach_function :TFE_ContextEnableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextDisableRunMetadata, [:pointer], :void
    attach_function :TFE_ContextStartStep, [:pointer], :void
    attach_function :TFE_ContextEndStep, [:pointer], :void

    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/ops.h
    attach_function :TF_NewOpDefinitionBuilder, [:string], :pointer
    attach_function :TF_RegisterOpDefinition, [:pointer, :pointer], :void
    attach_function :TF_DeleteOpDefinitionBuilder, [:pointer], :void
    attach_function :TF_OpDefinitionBuilderAddAttr, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddInput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderAddOutput, [:pointer, :string], :void
    attach_function :TF_OpDefinitionBuilderSetIsCommutative, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetIsAggregate, [:pointer, :bool], :void
    attach_function :TF_OpDefinitionBuilderSetShapeInferenceFunction, [:pointer, :pointer], :void

    attach_function :TF_GraphToFunction, [:pointer, :string, :uchar,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :int, :pointer,
                                          :pointer, :pointer, :string, :pointer], :pointer
    # Returns [String, Pointer]; the char* is owned by the TF_Function.
    attach_function :TF_FunctionName, [:pointer], :strptr
    # void in C: it writes into the TF_Buffer argument and returns nothing.
    attach_function :TF_FunctionToFunctionDef, [:pointer, :pointer, :pointer], :void
    attach_function :TF_GraphCopyFunction, [:pointer, :pointer, :pointer, :pointer], :void

    attach_function :TF_GraphToGraphDef, [:pointer, :pointer, :pointer], :void

    attach_function :TF_NewImportGraphDefOptions, [], :pointer
    attach_function :TF_DeleteImportGraphDefOptions, [:pointer], :void
    attach_function :TF_ImportGraphDefOptionsSetPrefix, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetDefaultDevice, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyNames, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsSetUniquifyPrefix, [:pointer, :uchar], :void
    attach_function :TF_ImportGraphDefOptionsAddInputMapping, [:pointer, :string, :int, Output], :void
    attach_function :TF_ImportGraphDefOptionsRemapControlDependency, [:pointer, :string, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddControlDependency, [:pointer, :pointer], :void
    attach_function :TF_ImportGraphDefOptionsAddReturnOutput, [:pointer, :string, :int], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOutputs, [:pointer], :int
    attach_function :TF_ImportGraphDefOptionsAddReturnOperation, [:pointer, :string], :void
    attach_function :TF_ImportGraphDefOptionsNumReturnOperations, [:pointer], :int

    attach_function :TF_GraphImportGraphDef, [:pointer, Buffer, :pointer, :pointer], :void
  end
end
@@ -0,0 +1,33 @@
1
module Tensorflow
  module Graph
    # Ruby wrapper around a native TF_Function handle, together with the
    # output types and shapes supplied by whoever built it.
    class Function
      attr_reader :output_types, :output_shapes

      # @param pointer native TF_Function* handle
      # @param output_types types of the function's outputs (stored as given)
      # @param output_shapes shapes of the function's outputs (stored as given)
      def initialize(pointer, output_types, output_shapes)
        @pointer = pointer
        @output_types = output_types
        @output_shapes = output_shapes
      end

      # Ruby-FFI calls #to_ptr when this object is passed as a :pointer
      # argument, which is how the FFI calls below can pass +self+.
      def to_ptr
        @pointer
      end

      # @return [String] the function's name as reported by TF_FunctionName
      def name
        # TF_FunctionName is bound as :strptr, which yields [String, Pointer];
        # only the Ruby string copy is needed.
        FFI.TF_FunctionName(self).first
      end

      # Serializes the native function into a FunctionDef protobuf message.
      #
      # @return [Tensorflow::FunctionDef] the decoded message
      # Propagates any error raised by Status.check when the C call fails.
      def function_def
        buffer_ptr = FFI.TF_NewBuffer
        Status.check do |status|
          FFI.TF_FunctionToFunctionDef(self, buffer_ptr, status)
        end
        buffer = FFI::Buffer.new(buffer_ptr)
        Tensorflow::FunctionDef.decode(buffer[:data].read_string(buffer[:length]))
      ensure
        # Free the native buffer on every path, including when Status.check
        # raises before the Struct wrapper has been created.
        FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
      end
    end
  end
end
@@ -0,0 +1,62 @@
1
module Tensorflow
  module Graph
    # Replaces a Ruby instance method with a wrapper that traces the original
    # method into a graph function instead of running it eagerly.
    class FunctionDef
      attr_reader :ruby_method, :signatures

      # dtype/shape pair describing one positional parameter of the wrapped method.
      Signature = Struct.new(:dtype, :shape)

      # @param ruby_method [UnboundMethod] method to wrap; its owner is modified in place
      # @param input_signatures [Array] one [dtype, shape] pair per method parameter
      # @raise [Error::InvalidArgumentError] when the pair count differs from the
      #   method's parameter count (raised before the owner is modified)
      def initialize(ruby_method, input_signatures = [])
        @ruby_method = ruby_method
        process_signatures(ruby_method, input_signatures)
        wrap_ruby_method
      end

      # Validates +input_signatures+ against +ruby_method+ and stores them as Signatures.
      def process_signatures(ruby_method, input_signatures)
        unless input_signatures.length == ruby_method.parameters.length
          raise(Error::InvalidArgumentError, "Must specify input signature for each method parameter")
        end

        @signatures = input_signatures.map { |dtype, shape| Signature.new(dtype, shape) }
      end

      # Name under which the untouched Ruby implementation is preserved.
      def aliased_name
        "#{ruby_method.original_name}_original"
      end

      # Keeps the original implementation under #aliased_name, then redefines the
      # method so every call builds a function, registers it with the current
      # ExecutionContext and returns it. Call arguments are ignored; placeholders
      # built from the signatures are used instead.
      def wrap_ruby_method
        target_name = ruby_method.original_name
        backup_name = aliased_name
        owner = ruby_method.owner
        function_def = self

        owner.instance_eval { alias_method(backup_name, target_name) }

        owner.instance_eval do
          define_method(target_name) do |*_args|
            built = function_def.build_function(self)
            ExecutionContext.current.add_function(built)
            built
          end
        end
      end

      # Traces the original method, bound to +object+, inside a fresh graph.
      #
      # @param object receiver the original method is bound to
      # @return whatever Graph#to_function returns for the traced graph
      def build_function(object)
        Graph.new.as_default do |graph|
          placeholders = ruby_method.parameters.each_with_index.map do |param, position|
            sig = signatures[position]
            Tensorflow.placeholder(sig.dtype, name: param.last, shape: sig.shape)
          end

          # Running the original body against placeholders records its ops in +graph+.
          traced = ruby_method.bind(object).call(*placeholders)

          graph.to_function(ruby_method.original_name.to_s, nil, placeholders, Array(traced))
        end
      end
    end
  end
end
@@ -0,0 +1,120 @@
1
+ require 'set'
2
+
3
module Tensorflow
  module Graph
    # Builds gradient subgraphs by walking backwards from an operation towards
    # the requested inputs, dispatching each op type to a registered gradient
    # implementation or, by default, to the C API (#add_api_gradients).
    class Gradients
      attr_reader :graph

      # Registry of gradient implementations keyed by op type. Lookups for
      # unregistered op types return the UnboundMethod #add_api_gradients.
      def self.gradients
        @gradients ||= Hash.new(instance_method(:add_api_gradients))
      end

      # Registers +block+ as the gradient implementation for +op_type+. The block
      # receives (gradient, outputs, inputs) and must return one gradient per input.
      def self.register(op_type, &block)
        gradients[op_type] = block
      end

      def initialize(graph)
        @graph = graph
      end

      # Operations that lie both downstream of +input+ and upstream of +output+.
      def path(output, input)
        downstream = graph.forward(input)
        upstream = graph.backward(output)
        downstream.intersection(upstream)
      end

      # One all-ones tensor per output of +operation+, shaped like that output.
      def default_gradient(operation)
        operation.outputs.each_with_index.map do |output, position|
          dims = Tensorflow.shape(output, :int32)
          one = Tensorflow.constant(1, name: "grad_ys_#{position}", dtype: operation.output_types[position])
          Tensorflow.fill(dims, one)
        end
      end

      # Gradients of +output+ (an operation) with respect to each of +inputs+.
      #
      # @param grad_ys initial upstream gradient; defaults to ones shaped like
      #   the first output of +output+
      # @param name name scope the gradient ops are created under
      # @param stop_operations operations the backward walk must not pass through
      # @return [Array] flattened gradients; inputs not connected to +output+ are skipped
      def gradients(output, inputs, grad_ys: nil, name: "gradients", stop_operations: Set.new)
        grad_ys ||= default_gradient(output).first

        graph.name_scope(name) do
          per_input = inputs.map do |input|
            operations_path = path(output, input)
            derivative(grad_ys, output, stop_operations, operations_path) unless operations_path.empty?
          end
          per_input.flatten.compact
        end
      end

      # Recursively backpropagates +gradient+ through +operation+.
      #
      # Naming follows the C API conventions:
      #
      #   x ------> y (forward)
      #   dy <----- dx (backward)
      #
      # Returns +gradient+ unchanged when +operation+ is off the path, is a stop
      # operation, or has no on-path inputs. Otherwise returns an array with one
      # entry per on-path input (nil where the computed gradient has no operation).
      def derivative(gradient, operation, stop_operations, operations_path)
        return gradient unless operations_path.include?(operation)
        return gradient if stop_operations.include?(operation)

        inputs = operation.inputs.reject do |input|
          !operations_path.include?(input.operation) || stop_operations.include?(input.operation)
        end
        return gradient if inputs.empty?

        last_operation = operations_path.first
        # The final operation has no consumers, so keep all of its outputs;
        # earlier operations only contribute outputs that something consumes.
        outputs = operation.outputs.select do |output|
          consumers = operation.output_consumers(output)
          operation == last_operation || consumers.count > 0
        end

        handler = self.class.gradients[operation.op_type]
        dy = case handler
             when UnboundMethod then handler.bind(self).call(gradient, outputs, inputs)
             else handler.call(gradient, outputs, inputs)
             end

        # This operation is done; continue into each input's producing operation.
        inputs.each_with_index.map do |input, position|
          upstream = dy[position]
          next if upstream.output[:oper].null?

          derivative(upstream.operation, input.operation, stop_operations, operations_path)
        end
      end

      # Default gradient implementation backed by TF_AddGradientsWithPrefix.
      #
      # @param gradient operation whose outputs are the incoming gradient, or nil
      # @param outputs outputs of the operation being differentiated (y)
      # @param inputs inputs of that operation (x)
      # @return [Array<OperationOutput>] one gradient per entry of +inputs+
      def add_api_gradients(gradient, outputs, inputs)
        y = FFI::Output.array_to_ptr(outputs.map(&:output))
        x = FFI::Output.array_to_ptr(inputs.map(&:output))
        dx = gradient ? FFI::Output.array_to_ptr(gradient.outputs.map(&:output)) : nil

        # Zero-filled slots the C API writes the computed gradients (dy) into.
        dy = ::FFI::MemoryPointer.new(FFI::Output, inputs.length, true)

        prefix = graph.scoped_name(inputs.first.operation.name)
        Status.check do |status|
          FFI.TF_AddGradientsWithPrefix(graph, prefix,
                                        y, outputs.length,
                                        x, inputs.length,
                                        dx, status, dy)
        end

        Array.new(inputs.length) { |position| OperationOutput.from_graph(graph, dy[position]) }
      end
    end
  end
end