tensorbored 2.21.0rc1769983804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (271) hide show
  1. tensorbored/__init__.py +112 -0
  2. tensorbored/_vendor/__init__.py +0 -0
  3. tensorbored/_vendor/bleach/__init__.py +125 -0
  4. tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
  5. tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
  6. tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
  7. tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
  8. tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
  9. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
  10. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
  11. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
  12. tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
  13. tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
  14. tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
  15. tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
  16. tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
  17. tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
  18. tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
  19. tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
  20. tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
  21. tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
  22. tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
  23. tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
  24. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
  25. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
  26. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
  27. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
  28. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
  29. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
  30. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
  31. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
  32. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
  33. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
  34. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
  35. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
  36. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
  37. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
  38. tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
  39. tensorbored/_vendor/bleach/callbacks.py +32 -0
  40. tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
  41. tensorbored/_vendor/bleach/linkifier.py +633 -0
  42. tensorbored/_vendor/bleach/parse_shim.py +1 -0
  43. tensorbored/_vendor/bleach/sanitizer.py +638 -0
  44. tensorbored/_vendor/bleach/six_shim.py +19 -0
  45. tensorbored/_vendor/webencodings/__init__.py +342 -0
  46. tensorbored/_vendor/webencodings/labels.py +231 -0
  47. tensorbored/_vendor/webencodings/mklabels.py +59 -0
  48. tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
  49. tensorbored/assets.py +36 -0
  50. tensorbored/auth.py +102 -0
  51. tensorbored/backend/__init__.py +0 -0
  52. tensorbored/backend/application.py +604 -0
  53. tensorbored/backend/auth_context_middleware.py +38 -0
  54. tensorbored/backend/client_feature_flags.py +113 -0
  55. tensorbored/backend/empty_path_redirect.py +46 -0
  56. tensorbored/backend/event_processing/__init__.py +0 -0
  57. tensorbored/backend/event_processing/data_ingester.py +276 -0
  58. tensorbored/backend/event_processing/data_provider.py +535 -0
  59. tensorbored/backend/event_processing/directory_loader.py +142 -0
  60. tensorbored/backend/event_processing/directory_watcher.py +272 -0
  61. tensorbored/backend/event_processing/event_accumulator.py +950 -0
  62. tensorbored/backend/event_processing/event_file_inspector.py +463 -0
  63. tensorbored/backend/event_processing/event_file_loader.py +292 -0
  64. tensorbored/backend/event_processing/event_multiplexer.py +521 -0
  65. tensorbored/backend/event_processing/event_util.py +68 -0
  66. tensorbored/backend/event_processing/io_wrapper.py +223 -0
  67. tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
  68. tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
  69. tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
  70. tensorbored/backend/event_processing/reservoir.py +266 -0
  71. tensorbored/backend/event_processing/tag_types.py +29 -0
  72. tensorbored/backend/experiment_id.py +71 -0
  73. tensorbored/backend/experimental_plugin.py +51 -0
  74. tensorbored/backend/http_util.py +263 -0
  75. tensorbored/backend/json_util.py +70 -0
  76. tensorbored/backend/path_prefix.py +67 -0
  77. tensorbored/backend/process_graph.py +74 -0
  78. tensorbored/backend/security_validator.py +202 -0
  79. tensorbored/compat/__init__.py +69 -0
  80. tensorbored/compat/proto/__init__.py +0 -0
  81. tensorbored/compat/proto/allocation_description_pb2.py +35 -0
  82. tensorbored/compat/proto/api_def_pb2.py +82 -0
  83. tensorbored/compat/proto/attr_value_pb2.py +80 -0
  84. tensorbored/compat/proto/cluster_pb2.py +58 -0
  85. tensorbored/compat/proto/config_pb2.py +271 -0
  86. tensorbored/compat/proto/coordination_config_pb2.py +45 -0
  87. tensorbored/compat/proto/cost_graph_pb2.py +87 -0
  88. tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
  89. tensorbored/compat/proto/debug_pb2.py +65 -0
  90. tensorbored/compat/proto/event_pb2.py +149 -0
  91. tensorbored/compat/proto/full_type_pb2.py +74 -0
  92. tensorbored/compat/proto/function_pb2.py +157 -0
  93. tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
  94. tensorbored/compat/proto/graph_pb2.py +41 -0
  95. tensorbored/compat/proto/histogram_pb2.py +39 -0
  96. tensorbored/compat/proto/meta_graph_pb2.py +254 -0
  97. tensorbored/compat/proto/node_def_pb2.py +61 -0
  98. tensorbored/compat/proto/op_def_pb2.py +81 -0
  99. tensorbored/compat/proto/resource_handle_pb2.py +48 -0
  100. tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
  101. tensorbored/compat/proto/rpc_options_pb2.py +35 -0
  102. tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
  103. tensorbored/compat/proto/saver_pb2.py +38 -0
  104. tensorbored/compat/proto/step_stats_pb2.py +116 -0
  105. tensorbored/compat/proto/struct_pb2.py +144 -0
  106. tensorbored/compat/proto/summary_pb2.py +111 -0
  107. tensorbored/compat/proto/tensor_description_pb2.py +38 -0
  108. tensorbored/compat/proto/tensor_pb2.py +68 -0
  109. tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
  110. tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
  111. tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
  112. tensorbored/compat/proto/types_pb2.py +105 -0
  113. tensorbored/compat/proto/variable_pb2.py +62 -0
  114. tensorbored/compat/proto/verifier_config_pb2.py +38 -0
  115. tensorbored/compat/proto/versions_pb2.py +35 -0
  116. tensorbored/compat/tensorflow_stub/__init__.py +38 -0
  117. tensorbored/compat/tensorflow_stub/app.py +124 -0
  118. tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
  119. tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
  120. tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
  121. tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
  122. tensorbored/compat/tensorflow_stub/errors.py +507 -0
  123. tensorbored/compat/tensorflow_stub/flags.py +124 -0
  124. tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
  125. tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
  126. tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
  127. tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
  128. tensorbored/context.py +129 -0
  129. tensorbored/data/__init__.py +0 -0
  130. tensorbored/data/grpc_provider.py +365 -0
  131. tensorbored/data/ingester.py +46 -0
  132. tensorbored/data/proto/__init__.py +0 -0
  133. tensorbored/data/proto/data_provider_pb2.py +517 -0
  134. tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
  135. tensorbored/data/provider.py +1365 -0
  136. tensorbored/data/server_ingester.py +301 -0
  137. tensorbored/data_compat.py +159 -0
  138. tensorbored/dataclass_compat.py +224 -0
  139. tensorbored/default.py +124 -0
  140. tensorbored/errors.py +130 -0
  141. tensorbored/lazy.py +99 -0
  142. tensorbored/main.py +48 -0
  143. tensorbored/main_lib.py +62 -0
  144. tensorbored/manager.py +487 -0
  145. tensorbored/notebook.py +441 -0
  146. tensorbored/plugin_util.py +266 -0
  147. tensorbored/plugins/__init__.py +0 -0
  148. tensorbored/plugins/audio/__init__.py +0 -0
  149. tensorbored/plugins/audio/audio_plugin.py +229 -0
  150. tensorbored/plugins/audio/metadata.py +69 -0
  151. tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
  152. tensorbored/plugins/audio/summary.py +230 -0
  153. tensorbored/plugins/audio/summary_v2.py +124 -0
  154. tensorbored/plugins/base_plugin.py +367 -0
  155. tensorbored/plugins/core/__init__.py +0 -0
  156. tensorbored/plugins/core/core_plugin.py +981 -0
  157. tensorbored/plugins/custom_scalar/__init__.py +0 -0
  158. tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
  159. tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
  160. tensorbored/plugins/custom_scalar/metadata.py +35 -0
  161. tensorbored/plugins/custom_scalar/summary.py +79 -0
  162. tensorbored/plugins/debugger_v2/__init__.py +0 -0
  163. tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
  164. tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
  165. tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
  166. tensorbored/plugins/distribution/__init__.py +0 -0
  167. tensorbored/plugins/distribution/compressor.py +158 -0
  168. tensorbored/plugins/distribution/distributions_plugin.py +116 -0
  169. tensorbored/plugins/distribution/metadata.py +19 -0
  170. tensorbored/plugins/graph/__init__.py +0 -0
  171. tensorbored/plugins/graph/graph_util.py +129 -0
  172. tensorbored/plugins/graph/graphs_plugin.py +336 -0
  173. tensorbored/plugins/graph/keras_util.py +328 -0
  174. tensorbored/plugins/graph/metadata.py +42 -0
  175. tensorbored/plugins/histogram/__init__.py +0 -0
  176. tensorbored/plugins/histogram/histograms_plugin.py +144 -0
  177. tensorbored/plugins/histogram/metadata.py +63 -0
  178. tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
  179. tensorbored/plugins/histogram/summary.py +234 -0
  180. tensorbored/plugins/histogram/summary_v2.py +292 -0
  181. tensorbored/plugins/hparams/__init__.py +14 -0
  182. tensorbored/plugins/hparams/_keras.py +93 -0
  183. tensorbored/plugins/hparams/api.py +130 -0
  184. tensorbored/plugins/hparams/api_pb2.py +208 -0
  185. tensorbored/plugins/hparams/backend_context.py +606 -0
  186. tensorbored/plugins/hparams/download_data.py +158 -0
  187. tensorbored/plugins/hparams/error.py +26 -0
  188. tensorbored/plugins/hparams/get_experiment.py +71 -0
  189. tensorbored/plugins/hparams/hparams_plugin.py +206 -0
  190. tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
  191. tensorbored/plugins/hparams/json_format_compat.py +38 -0
  192. tensorbored/plugins/hparams/list_metric_evals.py +57 -0
  193. tensorbored/plugins/hparams/list_session_groups.py +1040 -0
  194. tensorbored/plugins/hparams/metadata.py +125 -0
  195. tensorbored/plugins/hparams/metrics.py +41 -0
  196. tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
  197. tensorbored/plugins/hparams/summary.py +205 -0
  198. tensorbored/plugins/hparams/summary_v2.py +597 -0
  199. tensorbored/plugins/image/__init__.py +0 -0
  200. tensorbored/plugins/image/images_plugin.py +232 -0
  201. tensorbored/plugins/image/metadata.py +65 -0
  202. tensorbored/plugins/image/plugin_data_pb2.py +34 -0
  203. tensorbored/plugins/image/summary.py +159 -0
  204. tensorbored/plugins/image/summary_v2.py +130 -0
  205. tensorbored/plugins/mesh/__init__.py +14 -0
  206. tensorbored/plugins/mesh/mesh_plugin.py +292 -0
  207. tensorbored/plugins/mesh/metadata.py +152 -0
  208. tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
  209. tensorbored/plugins/mesh/summary.py +251 -0
  210. tensorbored/plugins/mesh/summary_v2.py +214 -0
  211. tensorbored/plugins/metrics/__init__.py +0 -0
  212. tensorbored/plugins/metrics/metadata.py +17 -0
  213. tensorbored/plugins/metrics/metrics_plugin.py +623 -0
  214. tensorbored/plugins/pr_curve/__init__.py +0 -0
  215. tensorbored/plugins/pr_curve/metadata.py +75 -0
  216. tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
  217. tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
  218. tensorbored/plugins/pr_curve/summary.py +574 -0
  219. tensorbored/plugins/profile_redirect/__init__.py +0 -0
  220. tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
  221. tensorbored/plugins/projector/__init__.py +67 -0
  222. tensorbored/plugins/projector/metadata.py +26 -0
  223. tensorbored/plugins/projector/projector_config_pb2.py +54 -0
  224. tensorbored/plugins/projector/projector_plugin.py +795 -0
  225. tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
  226. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
  227. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
  228. tensorbored/plugins/scalar/__init__.py +0 -0
  229. tensorbored/plugins/scalar/metadata.py +60 -0
  230. tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
  231. tensorbored/plugins/scalar/scalars_plugin.py +181 -0
  232. tensorbored/plugins/scalar/summary.py +109 -0
  233. tensorbored/plugins/scalar/summary_v2.py +124 -0
  234. tensorbored/plugins/text/__init__.py +0 -0
  235. tensorbored/plugins/text/metadata.py +62 -0
  236. tensorbored/plugins/text/plugin_data_pb2.py +34 -0
  237. tensorbored/plugins/text/summary.py +114 -0
  238. tensorbored/plugins/text/summary_v2.py +124 -0
  239. tensorbored/plugins/text/text_plugin.py +288 -0
  240. tensorbored/plugins/wit_redirect/__init__.py +0 -0
  241. tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
  242. tensorbored/program.py +910 -0
  243. tensorbored/summary/__init__.py +35 -0
  244. tensorbored/summary/_output.py +124 -0
  245. tensorbored/summary/_tf/__init__.py +14 -0
  246. tensorbored/summary/_tf/summary/__init__.py +178 -0
  247. tensorbored/summary/_writer.py +105 -0
  248. tensorbored/summary/v1.py +51 -0
  249. tensorbored/summary/v2.py +25 -0
  250. tensorbored/summary/writer/__init__.py +13 -0
  251. tensorbored/summary/writer/event_file_writer.py +291 -0
  252. tensorbored/summary/writer/record_writer.py +50 -0
  253. tensorbored/util/__init__.py +0 -0
  254. tensorbored/util/encoder.py +116 -0
  255. tensorbored/util/grpc_util.py +311 -0
  256. tensorbored/util/img_mime_type_detector.py +40 -0
  257. tensorbored/util/io_util.py +20 -0
  258. tensorbored/util/lazy_tensor_creator.py +110 -0
  259. tensorbored/util/op_evaluator.py +104 -0
  260. tensorbored/util/platform_util.py +20 -0
  261. tensorbored/util/tb_logging.py +24 -0
  262. tensorbored/util/tensor_util.py +617 -0
  263. tensorbored/util/timing.py +122 -0
  264. tensorbored/version.py +21 -0
  265. tensorbored/webfiles.zip +0 -0
  266. tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
  267. tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
  268. tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
  269. tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
  270. tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
  271. tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
@@ -0,0 +1,692 @@
1
+ # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Library of dtypes (Tensor element types)."""
16
+
17
+ import numpy as np
18
+
19
+ from . import pywrap_tensorflow
20
+ from tensorbored.compat.proto import types_pb2
21
+
22
# Scalar type used to build bfloat16 values (see `DType.min` / `DType.max`);
# it is obtained from the stubbed `pywrap_tensorflow` module.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
23
+
24
+
25
+ # @tf_export("DType")
26
class DType:
    """Represents the type of the elements in a `Tensor`.

    The following `DType` objects are defined:

    * `tf.float16`: 16-bit half-precision floating-point.
    * `tf.float32`: 32-bit single-precision floating-point.
    * `tf.float64`: 64-bit double-precision floating-point.
    * `tf.bfloat16`: 16-bit truncated floating-point.
    * `tf.complex64`: 64-bit single-precision complex.
    * `tf.complex128`: 128-bit double-precision complex.
    * `tf.int8`: 8-bit signed integer.
    * `tf.uint8`: 8-bit unsigned integer.
    * `tf.uint16`: 16-bit unsigned integer.
    * `tf.uint32`: 32-bit unsigned integer.
    * `tf.uint64`: 64-bit unsigned integer.
    * `tf.int16`: 16-bit signed integer.
    * `tf.int32`: 32-bit signed integer.
    * `tf.int64`: 64-bit signed integer.
    * `tf.bool`: Boolean.
    * `tf.string`: String.
    * `tf.qint8`: Quantized 8-bit signed integer.
    * `tf.quint8`: Quantized 8-bit unsigned integer.
    * `tf.qint16`: Quantized 16-bit signed integer.
    * `tf.quint16`: Quantized 16-bit unsigned integer.
    * `tf.qint32`: Quantized 32-bit signed integer.
    * `tf.resource`: Handle to a mutable resource.
    * `tf.variant`: Values of arbitrary types.

    In addition, variants of these types with the `_ref` suffix are
    defined for reference-typed tensors.

    The `tf.as_dtype()` function converts numpy types and string type
    names to a `DType` object.

    Several members read module-level names that are bound after this class
    definition (`bool`, `string`, `complex64`, `complex128`, `bfloat16`,
    `float32`, `float64`, `dtype_range`, `_INTERN_TABLE`, `_TF_TO_NP`,
    `_TYPE_TO_STRING`, `_NUMPY_INCOMPATIBLE`, `_QUANTIZED_DTYPES_NO_REF`,
    `as_dtype`). In particular `bool` inside the methods is the module-level
    `DType` constant, not the builtin.
    """

    def __init__(self, type_enum):
        """Creates a new `DType`.

        NOTE(mrry): In normal circumstances, you should not need to
        construct a `DType` object directly. Instead, use the
        `tf.as_dtype()` function.

        Args:
          type_enum: A `types_pb2.DataType` enum value; it is passed through
            `int()` before validation.

        Raises:
          TypeError: If `type_enum` is not a value of `types_pb2.DataType`,
            or if it is `types_pb2.DT_INVALID`.
        """
        # TODO(mrry): Make the necessary changes (using __new__) to ensure
        # that calling this returns one of the interned values.
        type_enum = int(type_enum)
        if (
            type_enum not in types_pb2.DataType.values()
            or type_enum == types_pb2.DT_INVALID
        ):
            raise TypeError(
                "type_enum is not a valid types_pb2.DataType: %s" % type_enum
            )
        self._type_enum = type_enum

    @property
    def _is_ref_dtype(self):
        """Returns `True` if this `DType` represents a reference type."""
        # Reference enums are the base enum shifted by 100 (see `_as_ref` and
        # `base_dtype`), so anything above 100 is a reference type.
        return self._type_enum > 100

    @property
    def _as_ref(self):
        """Returns a reference `DType` based on this `DType`."""
        if self._is_ref_dtype:
            return self
        else:
            return _INTERN_TABLE[self._type_enum + 100]

    @property
    def base_dtype(self):
        """Returns a non-reference `DType` based on this `DType`."""
        if self._is_ref_dtype:
            return _INTERN_TABLE[self._type_enum - 100]
        else:
            return self

    @property
    def real_dtype(self):
        """Returns the dtype corresponding to this dtype's real part.

        `complex64` maps to `float32` and `complex128` to `float64` (also for
        their reference variants); every other dtype is returned unchanged.
        """
        base = self.base_dtype
        if base == complex64:
            return float32
        elif base == complex128:
            return float64
        else:
            return self

    @property
    def is_numpy_compatible(self):
        """Returns whether this dtype is absent from `_NUMPY_INCOMPATIBLE`."""
        return self._type_enum not in _NUMPY_INCOMPATIBLE

    @property
    def as_numpy_dtype(self):
        """Returns a `numpy.dtype` based on this `DType`."""
        return _TF_TO_NP[self._type_enum]

    @property
    def as_datatype_enum(self):
        """Returns a `types_pb2.DataType` enum value based on this `DType`."""
        return self._type_enum

    @property
    def is_bool(self):
        """Returns whether this is a boolean data type."""
        # `bool` is the module-level DType constant, not the builtin.
        return self.base_dtype == bool

    @property
    def is_integer(self):
        """Returns whether this is a (non-quantized) integer type."""
        return (
            self.is_numpy_compatible
            and not self.is_quantized
            and np.issubdtype(self.as_numpy_dtype, np.integer)
        )

    @property
    def is_floating(self):
        """Returns whether this is a (non-quantized, real) floating point
        type."""
        # bfloat16 is checked explicitly in addition to the NumPy subtype test.
        return (
            self.is_numpy_compatible
            and np.issubdtype(self.as_numpy_dtype, np.floating)
        ) or self.base_dtype == bfloat16

    @property
    def is_complex(self):
        """Returns whether this is a complex floating point type."""
        return self.base_dtype in (complex64, complex128)

    @property
    def is_quantized(self):
        """Returns whether this is a quantized data type."""
        return self.base_dtype in _QUANTIZED_DTYPES_NO_REF

    @property
    def is_unsigned(self):
        """Returns whether this type is unsigned.

        Non-numeric, unordered, and quantized types are not considered unsigned, and
        this function returns `False`.

        Returns:
          Whether a `DType` is unsigned.
        """
        try:
            return self.min == 0
        except TypeError:
            # `min` raises TypeError for types without a minimum value.
            return False

    @property
    def min(self):
        """Returns the minimum representable value in this data type.

        Raises:
          TypeError: if this is a non-numeric, unordered, or quantized type.
        """
        if self.is_quantized or self.base_dtype in (
            bool,
            string,
            complex64,
            complex128,
        ):
            raise TypeError("Cannot find minimum value of %s." % self)

        # There is no simple way to get the min value of a dtype, so float and
        # int types are probed separately. `Exception` (rather than a bare
        # `except`) is caught because the errors raised by finfo/iinfo are not
        # documented precisely, while KeyboardInterrupt/SystemExit must still
        # propagate.
        try:
            return np.finfo(self.as_numpy_dtype).min
        except Exception:
            try:
                return np.iinfo(self.as_numpy_dtype).min
            except Exception:
                if self.base_dtype == bfloat16:
                    # bfloat16 limit is hard-coded as a hex float literal.
                    return _np_bfloat16(float.fromhex("-0x1.FEp127"))
                raise TypeError("Cannot find minimum value of %s." % self)

    @property
    def max(self):
        """Returns the maximum representable value in this data type.

        Raises:
          TypeError: if this is a non-numeric, unordered, or quantized type.
        """
        if self.is_quantized or self.base_dtype in (
            bool,
            string,
            complex64,
            complex128,
        ):
            raise TypeError("Cannot find maximum value of %s." % self)

        # Same probing strategy as `min`: try finfo, then iinfo, then the
        # hard-coded bfloat16 value. Only `Exception` is caught so that
        # KeyboardInterrupt/SystemExit are not swallowed.
        try:
            return np.finfo(self.as_numpy_dtype).max
        except Exception:
            try:
                return np.iinfo(self.as_numpy_dtype).max
            except Exception:
                if self.base_dtype == bfloat16:
                    return _np_bfloat16(float.fromhex("0x1.FEp127"))
                raise TypeError("Cannot find maximum value of %s." % self)

    @property
    def limits(self, clip_negative=True):
        """Return intensity limits, i.e. (min, max) tuple, of the dtype.

        NOTE: this is a `property`, so attribute access always evaluates it
        with `clip_negative=True`; the parameter is kept only so the
        signature stays unchanged.

        Args:
          clip_negative : bool, optional
              If True, clip the negative range (i.e. return 0 for min intensity)
              even if the image dtype allows negative values.

        Returns:
          min, max : tuple
            Lower and upper intensity limits.

        Raises:
          KeyError: if this dtype has no entry in `dtype_range`.
        """
        lower, upper = dtype_range[self.as_numpy_dtype]
        if clip_negative:
            lower = 0
        return lower, upper

    def is_compatible_with(self, other):
        """Returns True if the `other` DType will be converted to this DType.

        The conversion rules are as follows:

        ```python
        DType(T)        .is_compatible_with(DType(T))         == True
        DType(T)        .is_compatible_with(DType(T)._as_ref) == True
        DType(T)._as_ref.is_compatible_with(DType(T))         == False
        DType(T)._as_ref.is_compatible_with(DType(T)._as_ref) == True
        ```

        Args:
          other: A `DType` (or object that may be converted to a `DType`).

        Returns:
          True if a Tensor of the `other` `DType` will be implicitly converted to
          this `DType`.
        """
        other = as_dtype(other)
        return self._type_enum in (
            other.as_datatype_enum,
            other.base_dtype.as_datatype_enum,
        )

    def __eq__(self, other):
        """Returns True iff this DType refers to the same type as `other`.

        `other` is first converted with `as_dtype`; `None` and anything
        `as_dtype` rejects with a `TypeError` compare unequal.
        """
        if other is None:
            return False
        try:
            dtype = as_dtype(other).as_datatype_enum
            return self._type_enum == dtype
        except TypeError:
            return False

    def __ne__(self, other):
        """Returns True iff self != other."""
        return not self.__eq__(other)

    @property
    def name(self):
        """Returns the string name for this `DType`."""
        return _TYPE_TO_STRING[self._type_enum]

    def __int__(self):
        """Returns the underlying `types_pb2.DataType` enum value."""
        return self._type_enum

    def __str__(self):
        return "<dtype: %r>" % self.name

    def __repr__(self):
        return "tf." + self.name

    def __hash__(self):
        # Hash on the enum value, the same quantity `__eq__` compares.
        return self._type_enum

    def __reduce__(self):
        # Pickle support: the object is rebuilt by calling `as_dtype(name)`.
        return as_dtype, (self.name,)

    @property
    def size(self):
        """Returns the NumPy item size of this dtype in bytes.

        `DT_VARIANT` and `DT_RESOURCE` are reported as size 1 without
        consulting NumPy.
        """
        if (
            self._type_enum == types_pb2.DT_VARIANT
            or self._type_enum == types_pb2.DT_RESOURCE
        ):
            return 1
        return np.dtype(self.as_numpy_dtype).itemsize
321
+
322
+
323
# Intensity limits, as (lowest, highest) pairs, keyed by NumPy scalar type.
# `DType.limits` looks its result up in this table.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 2**8 - 1),
    np.uint16: (0, 2**16 - 1),
    np.int8: (-(2**7), 2**7 - 1),
    np.int16: (-(2**15), 2**15 - 1),
    np.int64: (-(2**63), 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-(2**31), 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    # Floating point entries use the fixed range [-1, 1].
    np.float32: (-1, 1),
    np.float64: (-1, 1),
}
337
+
338
# Module-level `DType` constants wrapping the `types_pb2.DataType` enum values.
# The per-constant `tf_export(...)` lines in the original were comments only,
# so no export registration happens in this module.

# Opaque handle types.
resource = DType(types_pb2.DT_RESOURCE)
variant = DType(types_pb2.DT_VARIANT)

# Floating point types; `half` and `double` are aliases of the same objects.
float16 = DType(types_pb2.DT_HALF)
half = float16
float32 = DType(types_pb2.DT_FLOAT)
float64 = DType(types_pb2.DT_DOUBLE)
double = float64
bfloat16 = DType(types_pb2.DT_BFLOAT16)

# Complex types.
complex64 = DType(types_pb2.DT_COMPLEX64)
complex128 = DType(types_pb2.DT_COMPLEX128)

# Signed integer types.
int8 = DType(types_pb2.DT_INT8)
int16 = DType(types_pb2.DT_INT16)
int32 = DType(types_pb2.DT_INT32)
int64 = DType(types_pb2.DT_INT64)

# Unsigned integer types.
uint8 = DType(types_pb2.DT_UINT8)
uint16 = DType(types_pb2.DT_UINT16)
uint32 = DType(types_pb2.DT_UINT32)
uint64 = DType(types_pb2.DT_UINT64)

# String and boolean types. Note that `bool` deliberately shadows the builtin
# for the remainder of this module.
string = DType(types_pb2.DT_STRING)
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin

# Quantized integer types.
qint8 = DType(types_pb2.DT_QINT8)
quint8 = DType(types_pb2.DT_QUINT8)
qint16 = DType(types_pb2.DT_QINT16)
quint16 = DType(types_pb2.DT_QUINT16)
qint32 = DType(types_pb2.DT_QINT32)

# Reference-typed counterparts, in the same grouping as above.
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)

float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)

complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)

int8_ref = DType(types_pb2.DT_INT8_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int32_ref = DType(types_pb2.DT_INT32_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)

uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)

string_ref = DType(types_pb2.DT_STRING_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)

qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
414
+
415
# Enum values that `DType.is_numpy_compatible` reports as not NumPy-compatible:
# the variant and resource types together with their reference variants.
_NUMPY_INCOMPATIBLE = frozenset(
    {
        types_pb2.DT_VARIANT,
        types_pb2.DT_VARIANT_REF,
        types_pb2.DT_RESOURCE,
        types_pb2.DT_RESOURCE_REF,
    }
)
423
+
424
# Intern table: maps every `types_pb2.DataType` enum value to the single
# module-level `DType` constant for it, so lookups reuse those objects instead
# of creating a large number of small ones. Each key is taken from the
# constant's own `as_datatype_enum`; entries keep their original order
# (base types first, then the reference types).
_INTERN_TABLE = {
    interned.as_datatype_enum: interned
    for interned in (
        float16,
        float32,
        float64,
        int32,
        uint8,
        uint16,
        uint32,
        uint64,
        int16,
        int8,
        string,
        complex64,
        complex128,
        int64,
        bool,
        qint8,
        quint8,
        qint16,
        quint16,
        qint32,
        bfloat16,
        resource,
        variant,
        float16_ref,
        float32_ref,
        float64_ref,
        int32_ref,
        uint32_ref,
        uint8_ref,
        uint16_ref,
        int16_ref,
        int8_ref,
        string_ref,
        complex64_ref,
        complex128_ref,
        int64_ref,
        uint64_ref,
        bool_ref,
        qint8_ref,
        quint8_ref,
        qint16_ref,
        quint16_ref,
        qint32_ref,
        bfloat16_ref,
        resource_ref,
        variant_ref,
    )
}
474
+
475
# Standard mappings between types_pb2.DataType values and string names.
#
# Each reference enum (`DT_*_REF`) is named after its base type with a "_ref"
# suffix. The name -> DType table `_STRING_TO_TF` is derived from this table,
# so every key here must also be a key of `_INTERN_TABLE`.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    # Reference types.
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
524
# Name -> DType lookup: every canonical name from `_TYPE_TO_STRING`, resolved
# to the interned DType registered for the same enum value.
_STRING_TO_TF = {
    type_name: _INTERN_TABLE[type_enum]
    for type_enum, type_name in _TYPE_TO_STRING.items()
}
# Non-canonical aliases accepted in addition to the canonical names.
_STRING_TO_TF.update(
    {
        "half": float16,
        "half_ref": float16_ref,
        "float": float32,
        "float_ref": float32_ref,
        "double": float64,
        "double_ref": float64_ref,
    }
)
534
+
535
# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# Each one is a structured numpy dtype with a single field; the field name is
# the quantized type's name and the field type is the matching numpy integer.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is defined by a module import.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
# Single-field structured dtype: field "resource" of type np.ubyte.
np_resource = np.dtype([("resource", np.ubyte)])
551
+
552
# Standard mappings from numpy types to DTypes.
#
# Stored as a frozenset of (numpy type or numpy dtype, DType) pairs rather
# than a dict: `as_dtype` iterates over the pairs and compares each first
# element with `==` instead of performing a keyed lookup.
_NP_TO_TF = frozenset(
    [
        (np.float16, float16),
        (np.float32, float32),
        (np.float64, float64),
        (np.int32, int32),
        (np.int64, int64),
        (np.uint8, uint8),
        (np.uint16, uint16),
        (np.uint32, uint32),
        (np.uint64, uint64),
        (np.int16, int16),
        (np.int8, int8),
        (np.complex64, complex64),
        (np.complex128, complex128),
        (np.object_, string),
        (np.bool_, bool),
        # Quantized types use the single-field structured dtypes defined above.
        (_np_qint8, qint8),
        (_np_quint8, quint8),
        (_np_qint16, qint16),
        (_np_quint16, quint16),
        (_np_qint32, qint32),
        # TODO(#1677): _np_bfloat16 is defined as 0. This causes `as_dtype` to
        # error. Add below back after we fix `TF_bfloat16_type`.
        # (_np_bfloat16, bfloat16),
    ]
)
580
# Maps each `types_pb2` DataType enum value to the numpy type (or, for the
# quantized types, the structured numpy dtype) associated with it. Reference
# enums map to the same numpy type as their base enum. The resource and
# variant enums have no entry in this table.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use np.object_ as it supports variable
    # length strings.
    types_pb2.DT_STRING: np.object_,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: np.bool_,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,
    # Ref types
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: np.object_,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: np.bool_,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}
627
+
628
# The quantized dtypes, kept in separate sets for the plain and the
# reference variants.
_QUANTIZED_DTYPES_NO_REF = frozenset((qint8, quint8, qint16, quint16, qint32))
_QUANTIZED_DTYPES_REF = frozenset(
    (qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref)
)
# Every quantized dtype, reference or not.
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF | _QUANTIZED_DTYPES_NO_REF
# tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
634
+
635
# Python builtin types accepted by `as_dtype`, mapped to their DType.
#
# In this module the name `bool` refers to the boolean DType (it is used as a
# DType value in `_INTERN_TABLE` and `_NP_TO_TF`), not to Python's builtin
# class. Writing `bool` as the key would therefore register the DType itself
# and leave the builtin class unconvertible, so the builtin is obtained with
# `type(True)` instead.
_PYTHON_TO_TF = {float: float32, type(True): bool}
636
+
637
+
638
+ # @tf_export("as_dtype")
639
def as_dtype(type_value):
    """Converts the given `type_value` to a `DType`.

    Conversion is attempted in this order: an existing `DType` is returned
    unchanged; `type_value` is then looked up as a `DataType` enum value, as
    a type name, and as a Python builtin type; numpy byte-string and
    unicode-string dtypes map to `string`; finally, classes and
    `numpy.dtype` objects are compared against the known numpy types.

    Args:
      type_value: A value that can be converted to a `tf.DType` object. This may
        currently be a `tf.DType` object, a [`DataType`
        enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
        a string type name, or a `numpy.dtype`.

    Returns:
      A `DType` corresponding to `type_value`.

    Raises:
      TypeError: If `type_value` cannot be converted to a `DType`.
    """
    if isinstance(type_value, DType):
        return type_value

    # Exact-key lookups, in priority order: enum value, type name, Python
    # builtin type. An unhashable `type_value` (e.g. a list) makes the lookup
    # itself raise TypeError; treat that as "no match" so the caller gets the
    # descriptive conversion error below rather than "unhashable type".
    for table in (_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF):
        try:
            return table[type_value]
        except (KeyError, TypeError):
            continue

    # The numpy dtype for strings is variable length. We can not compare
    # dtype with a single constant (np.string does not exist) to decide
    # dtype is a "string" type. We need to compare the dtype.type to be
    # sure it's a string type.
    if isinstance(type_value, np.dtype) and type_value.type in (
        np.bytes_,
        np.str_,
    ):
        return string

    if isinstance(type_value, (type, np.dtype)):
        # `_NP_TO_TF` is a collection of (numpy type, DType) pairs; `==` is
        # used so that a numpy dtype instance can match its scalar type.
        for np_type, tf_dtype in _NP_TO_TF:
            try:
                if np_type == type_value:
                    return tf_dtype
            except TypeError as e:
                raise TypeError(
                    "Cannot convert {} to a dtype. {}".format(type_value, e)
                ) from e

    # Wrap in a 1-tuple: a bare tuple `type_value` would otherwise be
    # unpacked as the format arguments and produce a formatting error
    # instead of this message.
    raise TypeError(
        "Cannot convert value %r to a TensorFlow DType." % (type_value,)
    )