tensorbored 2.21.0rc1769983804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorbored/__init__.py +112 -0
- tensorbored/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/__init__.py +125 -0
- tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
- tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
- tensorbored/_vendor/bleach/callbacks.py +32 -0
- tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
- tensorbored/_vendor/bleach/linkifier.py +633 -0
- tensorbored/_vendor/bleach/parse_shim.py +1 -0
- tensorbored/_vendor/bleach/sanitizer.py +638 -0
- tensorbored/_vendor/bleach/six_shim.py +19 -0
- tensorbored/_vendor/webencodings/__init__.py +342 -0
- tensorbored/_vendor/webencodings/labels.py +231 -0
- tensorbored/_vendor/webencodings/mklabels.py +59 -0
- tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
- tensorbored/assets.py +36 -0
- tensorbored/auth.py +102 -0
- tensorbored/backend/__init__.py +0 -0
- tensorbored/backend/application.py +604 -0
- tensorbored/backend/auth_context_middleware.py +38 -0
- tensorbored/backend/client_feature_flags.py +113 -0
- tensorbored/backend/empty_path_redirect.py +46 -0
- tensorbored/backend/event_processing/__init__.py +0 -0
- tensorbored/backend/event_processing/data_ingester.py +276 -0
- tensorbored/backend/event_processing/data_provider.py +535 -0
- tensorbored/backend/event_processing/directory_loader.py +142 -0
- tensorbored/backend/event_processing/directory_watcher.py +272 -0
- tensorbored/backend/event_processing/event_accumulator.py +950 -0
- tensorbored/backend/event_processing/event_file_inspector.py +463 -0
- tensorbored/backend/event_processing/event_file_loader.py +292 -0
- tensorbored/backend/event_processing/event_multiplexer.py +521 -0
- tensorbored/backend/event_processing/event_util.py +68 -0
- tensorbored/backend/event_processing/io_wrapper.py +223 -0
- tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
- tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
- tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
- tensorbored/backend/event_processing/reservoir.py +266 -0
- tensorbored/backend/event_processing/tag_types.py +29 -0
- tensorbored/backend/experiment_id.py +71 -0
- tensorbored/backend/experimental_plugin.py +51 -0
- tensorbored/backend/http_util.py +263 -0
- tensorbored/backend/json_util.py +70 -0
- tensorbored/backend/path_prefix.py +67 -0
- tensorbored/backend/process_graph.py +74 -0
- tensorbored/backend/security_validator.py +202 -0
- tensorbored/compat/__init__.py +69 -0
- tensorbored/compat/proto/__init__.py +0 -0
- tensorbored/compat/proto/allocation_description_pb2.py +35 -0
- tensorbored/compat/proto/api_def_pb2.py +82 -0
- tensorbored/compat/proto/attr_value_pb2.py +80 -0
- tensorbored/compat/proto/cluster_pb2.py +58 -0
- tensorbored/compat/proto/config_pb2.py +271 -0
- tensorbored/compat/proto/coordination_config_pb2.py +45 -0
- tensorbored/compat/proto/cost_graph_pb2.py +87 -0
- tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
- tensorbored/compat/proto/debug_pb2.py +65 -0
- tensorbored/compat/proto/event_pb2.py +149 -0
- tensorbored/compat/proto/full_type_pb2.py +74 -0
- tensorbored/compat/proto/function_pb2.py +157 -0
- tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
- tensorbored/compat/proto/graph_pb2.py +41 -0
- tensorbored/compat/proto/histogram_pb2.py +39 -0
- tensorbored/compat/proto/meta_graph_pb2.py +254 -0
- tensorbored/compat/proto/node_def_pb2.py +61 -0
- tensorbored/compat/proto/op_def_pb2.py +81 -0
- tensorbored/compat/proto/resource_handle_pb2.py +48 -0
- tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
- tensorbored/compat/proto/rpc_options_pb2.py +35 -0
- tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
- tensorbored/compat/proto/saver_pb2.py +38 -0
- tensorbored/compat/proto/step_stats_pb2.py +116 -0
- tensorbored/compat/proto/struct_pb2.py +144 -0
- tensorbored/compat/proto/summary_pb2.py +111 -0
- tensorbored/compat/proto/tensor_description_pb2.py +38 -0
- tensorbored/compat/proto/tensor_pb2.py +68 -0
- tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
- tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
- tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
- tensorbored/compat/proto/types_pb2.py +105 -0
- tensorbored/compat/proto/variable_pb2.py +62 -0
- tensorbored/compat/proto/verifier_config_pb2.py +38 -0
- tensorbored/compat/proto/versions_pb2.py +35 -0
- tensorbored/compat/tensorflow_stub/__init__.py +38 -0
- tensorbored/compat/tensorflow_stub/app.py +124 -0
- tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
- tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
- tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
- tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
- tensorbored/compat/tensorflow_stub/errors.py +507 -0
- tensorbored/compat/tensorflow_stub/flags.py +124 -0
- tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
- tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
- tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
- tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
- tensorbored/context.py +129 -0
- tensorbored/data/__init__.py +0 -0
- tensorbored/data/grpc_provider.py +365 -0
- tensorbored/data/ingester.py +46 -0
- tensorbored/data/proto/__init__.py +0 -0
- tensorbored/data/proto/data_provider_pb2.py +517 -0
- tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
- tensorbored/data/provider.py +1365 -0
- tensorbored/data/server_ingester.py +301 -0
- tensorbored/data_compat.py +159 -0
- tensorbored/dataclass_compat.py +224 -0
- tensorbored/default.py +124 -0
- tensorbored/errors.py +130 -0
- tensorbored/lazy.py +99 -0
- tensorbored/main.py +48 -0
- tensorbored/main_lib.py +62 -0
- tensorbored/manager.py +487 -0
- tensorbored/notebook.py +441 -0
- tensorbored/plugin_util.py +266 -0
- tensorbored/plugins/__init__.py +0 -0
- tensorbored/plugins/audio/__init__.py +0 -0
- tensorbored/plugins/audio/audio_plugin.py +229 -0
- tensorbored/plugins/audio/metadata.py +69 -0
- tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
- tensorbored/plugins/audio/summary.py +230 -0
- tensorbored/plugins/audio/summary_v2.py +124 -0
- tensorbored/plugins/base_plugin.py +367 -0
- tensorbored/plugins/core/__init__.py +0 -0
- tensorbored/plugins/core/core_plugin.py +981 -0
- tensorbored/plugins/custom_scalar/__init__.py +0 -0
- tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
- tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
- tensorbored/plugins/custom_scalar/metadata.py +35 -0
- tensorbored/plugins/custom_scalar/summary.py +79 -0
- tensorbored/plugins/debugger_v2/__init__.py +0 -0
- tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
- tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
- tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
- tensorbored/plugins/distribution/__init__.py +0 -0
- tensorbored/plugins/distribution/compressor.py +158 -0
- tensorbored/plugins/distribution/distributions_plugin.py +116 -0
- tensorbored/plugins/distribution/metadata.py +19 -0
- tensorbored/plugins/graph/__init__.py +0 -0
- tensorbored/plugins/graph/graph_util.py +129 -0
- tensorbored/plugins/graph/graphs_plugin.py +336 -0
- tensorbored/plugins/graph/keras_util.py +328 -0
- tensorbored/plugins/graph/metadata.py +42 -0
- tensorbored/plugins/histogram/__init__.py +0 -0
- tensorbored/plugins/histogram/histograms_plugin.py +144 -0
- tensorbored/plugins/histogram/metadata.py +63 -0
- tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
- tensorbored/plugins/histogram/summary.py +234 -0
- tensorbored/plugins/histogram/summary_v2.py +292 -0
- tensorbored/plugins/hparams/__init__.py +14 -0
- tensorbored/plugins/hparams/_keras.py +93 -0
- tensorbored/plugins/hparams/api.py +130 -0
- tensorbored/plugins/hparams/api_pb2.py +208 -0
- tensorbored/plugins/hparams/backend_context.py +606 -0
- tensorbored/plugins/hparams/download_data.py +158 -0
- tensorbored/plugins/hparams/error.py +26 -0
- tensorbored/plugins/hparams/get_experiment.py +71 -0
- tensorbored/plugins/hparams/hparams_plugin.py +206 -0
- tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
- tensorbored/plugins/hparams/json_format_compat.py +38 -0
- tensorbored/plugins/hparams/list_metric_evals.py +57 -0
- tensorbored/plugins/hparams/list_session_groups.py +1040 -0
- tensorbored/plugins/hparams/metadata.py +125 -0
- tensorbored/plugins/hparams/metrics.py +41 -0
- tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
- tensorbored/plugins/hparams/summary.py +205 -0
- tensorbored/plugins/hparams/summary_v2.py +597 -0
- tensorbored/plugins/image/__init__.py +0 -0
- tensorbored/plugins/image/images_plugin.py +232 -0
- tensorbored/plugins/image/metadata.py +65 -0
- tensorbored/plugins/image/plugin_data_pb2.py +34 -0
- tensorbored/plugins/image/summary.py +159 -0
- tensorbored/plugins/image/summary_v2.py +130 -0
- tensorbored/plugins/mesh/__init__.py +14 -0
- tensorbored/plugins/mesh/mesh_plugin.py +292 -0
- tensorbored/plugins/mesh/metadata.py +152 -0
- tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
- tensorbored/plugins/mesh/summary.py +251 -0
- tensorbored/plugins/mesh/summary_v2.py +214 -0
- tensorbored/plugins/metrics/__init__.py +0 -0
- tensorbored/plugins/metrics/metadata.py +17 -0
- tensorbored/plugins/metrics/metrics_plugin.py +623 -0
- tensorbored/plugins/pr_curve/__init__.py +0 -0
- tensorbored/plugins/pr_curve/metadata.py +75 -0
- tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
- tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
- tensorbored/plugins/pr_curve/summary.py +574 -0
- tensorbored/plugins/profile_redirect/__init__.py +0 -0
- tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
- tensorbored/plugins/projector/__init__.py +67 -0
- tensorbored/plugins/projector/metadata.py +26 -0
- tensorbored/plugins/projector/projector_config_pb2.py +54 -0
- tensorbored/plugins/projector/projector_plugin.py +795 -0
- tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
- tensorbored/plugins/scalar/__init__.py +0 -0
- tensorbored/plugins/scalar/metadata.py +60 -0
- tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
- tensorbored/plugins/scalar/scalars_plugin.py +181 -0
- tensorbored/plugins/scalar/summary.py +109 -0
- tensorbored/plugins/scalar/summary_v2.py +124 -0
- tensorbored/plugins/text/__init__.py +0 -0
- tensorbored/plugins/text/metadata.py +62 -0
- tensorbored/plugins/text/plugin_data_pb2.py +34 -0
- tensorbored/plugins/text/summary.py +114 -0
- tensorbored/plugins/text/summary_v2.py +124 -0
- tensorbored/plugins/text/text_plugin.py +288 -0
- tensorbored/plugins/wit_redirect/__init__.py +0 -0
- tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
- tensorbored/program.py +910 -0
- tensorbored/summary/__init__.py +35 -0
- tensorbored/summary/_output.py +124 -0
- tensorbored/summary/_tf/__init__.py +14 -0
- tensorbored/summary/_tf/summary/__init__.py +178 -0
- tensorbored/summary/_writer.py +105 -0
- tensorbored/summary/v1.py +51 -0
- tensorbored/summary/v2.py +25 -0
- tensorbored/summary/writer/__init__.py +13 -0
- tensorbored/summary/writer/event_file_writer.py +291 -0
- tensorbored/summary/writer/record_writer.py +50 -0
- tensorbored/util/__init__.py +0 -0
- tensorbored/util/encoder.py +116 -0
- tensorbored/util/grpc_util.py +311 -0
- tensorbored/util/img_mime_type_detector.py +40 -0
- tensorbored/util/io_util.py +20 -0
- tensorbored/util/lazy_tensor_creator.py +110 -0
- tensorbored/util/op_evaluator.py +104 -0
- tensorbored/util/platform_util.py +20 -0
- tensorbored/util/tb_logging.py +24 -0
- tensorbored/util/tensor_util.py +617 -0
- tensorbored/util/timing.py +122 -0
- tensorbored/version.py +21 -0
- tensorbored/webfiles.zip +0 -0
- tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
- tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
- tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
- tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
- tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
- tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,692 @@
|
|
|
1
|
+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Library of dtypes (Tensor element types)."""
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from . import pywrap_tensorflow
|
|
20
|
+
from tensorbored.compat.proto import types_pb2
|
|
21
|
+
|
|
22
|
+
# Scalar type used as the NumPy counterpart of `bfloat16`, as supplied by the
# pywrap layer. NOTE(review): whether this is a real NumPy-compatible type or a
# placeholder (e.g. None in a stub build) is decided in `pywrap_tensorflow`,
# which is not visible here -- confirm before treating it as a NumPy dtype.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# @tf_export("DType")
|
|
26
|
+
class DType:
    """Represents the type of the elements in a `Tensor`.

    The following `DType` objects are defined:

    * `tf.float16`: 16-bit half-precision floating-point.
    * `tf.float32`: 32-bit single-precision floating-point.
    * `tf.float64`: 64-bit double-precision floating-point.
    * `tf.bfloat16`: 16-bit truncated floating-point.
    * `tf.complex64`: 64-bit single-precision complex.
    * `tf.complex128`: 128-bit double-precision complex.
    * `tf.int8`: 8-bit signed integer.
    * `tf.uint8`: 8-bit unsigned integer.
    * `tf.uint16`: 16-bit unsigned integer.
    * `tf.uint32`: 32-bit unsigned integer.
    * `tf.uint64`: 64-bit unsigned integer.
    * `tf.int16`: 16-bit signed integer.
    * `tf.int32`: 32-bit signed integer.
    * `tf.int64`: 64-bit signed integer.
    * `tf.bool`: Boolean.
    * `tf.string`: String.
    * `tf.qint8`: Quantized 8-bit signed integer.
    * `tf.quint8`: Quantized 8-bit unsigned integer.
    * `tf.qint16`: Quantized 16-bit signed integer.
    * `tf.quint16`: Quantized 16-bit unsigned integer.
    * `tf.qint32`: Quantized 32-bit signed integer.
    * `tf.resource`: Handle to a mutable resource.
    * `tf.variant`: Values of arbitrary types.

    In addition, variants of these types with the `_ref` suffix are
    defined for reference-typed tensors.

    The `tf.as_dtype()` function converts numpy types and string type
    names to a `DType` object.
    """

    def __init__(self, type_enum):
        """Creates a new `DType`.

        NOTE(mrry): In normal circumstances, you should not need to
        construct a `DType` object directly. Instead, use the
        `tf.as_dtype()` function.

        Args:
          type_enum: A `types_pb2.DataType` enum value.

        Raises:
          TypeError: If `type_enum` is not a valid `types_pb2.DataType`
            value, or is `DT_INVALID`.
        """
        # TODO(mrry): Make the necessary changes (using __new__) to ensure
        # that calling this returns one of the interned values.
        type_enum = int(type_enum)
        if (
            type_enum not in types_pb2.DataType.values()
            or type_enum == types_pb2.DT_INVALID
        ):
            raise TypeError(
                "type_enum is not a valid types_pb2.DataType: %s" % type_enum
            )
        self._type_enum = type_enum

    @property
    def _is_ref_dtype(self):
        """Returns `True` if this `DType` represents a reference type."""
        # Reference enum values are offset by 100 from their base type
        # (see `_as_ref` / `base_dtype` below).
        return self._type_enum > 100

    @property
    def _as_ref(self):
        """Returns a reference `DType` based on this `DType`."""
        if self._is_ref_dtype:
            return self
        return _INTERN_TABLE[self._type_enum + 100]

    @property
    def base_dtype(self):
        """Returns a non-reference `DType` based on this `DType`."""
        if self._is_ref_dtype:
            return _INTERN_TABLE[self._type_enum - 100]
        return self

    @property
    def real_dtype(self):
        """Returns the dtype corresponding to this dtype's real part.

        `complex64` maps to `float32` and `complex128` to `float64`; any
        other dtype is returned unchanged.
        """
        base = self.base_dtype
        if base == complex64:
            return float32
        elif base == complex128:
            return float64
        else:
            return self

    @property
    def is_numpy_compatible(self):
        """Returns whether this dtype is not in `_NUMPY_INCOMPATIBLE`."""
        return self._type_enum not in _NUMPY_INCOMPATIBLE

    @property
    def as_numpy_dtype(self):
        """Returns a `numpy.dtype` based on this `DType`."""
        return _TF_TO_NP[self._type_enum]

    @property
    def as_datatype_enum(self):
        """Returns a `types_pb2.DataType` enum value based on this `DType`."""
        return self._type_enum

    @property
    def is_bool(self):
        """Returns whether this is a boolean data type."""
        # `bool` here is the module-level `DType`, not the Python builtin.
        return self.base_dtype == bool

    @property
    def is_integer(self):
        """Returns whether this is a (non-quantized) integer type."""
        return (
            self.is_numpy_compatible
            and not self.is_quantized
            and np.issubdtype(self.as_numpy_dtype, np.integer)
        )

    @property
    def is_floating(self):
        """Returns whether this is a (non-quantized, real) floating point type."""
        return (
            self.is_numpy_compatible
            and np.issubdtype(self.as_numpy_dtype, np.floating)
        ) or self.base_dtype == bfloat16

    @property
    def is_complex(self):
        """Returns whether this is a complex floating point type."""
        return self.base_dtype in (complex64, complex128)

    @property
    def is_quantized(self):
        """Returns whether this is a quantized data type."""
        return self.base_dtype in _QUANTIZED_DTYPES_NO_REF

    @property
    def is_unsigned(self):
        """Returns whether this type is unsigned.

        Non-numeric, unordered, and quantized types are not considered unsigned, and
        this function returns `False`.

        Returns:
          Whether a `DType` is unsigned.
        """
        try:
            return self.min == 0
        except TypeError:
            return False

    @property
    def min(self):
        """Returns the minimum representable value in this data type.

        Raises:
          TypeError: if this is a non-numeric, unordered, or quantized type.
        """
        if self.is_quantized or self.base_dtype in (
            bool,
            string,
            complex64,
            complex128,
        ):
            raise TypeError("Cannot find minimum value of %s." % self)

        # Resolve bfloat16 before asking NumPy: `_np_bfloat16` is whatever the
        # pywrap layer returned and is not guaranteed to be a type `np.finfo`
        # describes correctly (`np.finfo(None)`, for instance, silently reports
        # float64 -- deprecated since NumPy 1.25), which would yield a float64
        # bound instead of the bfloat16 one.
        if self.base_dtype == bfloat16:
            bound = float.fromhex("-0x1.FEp127")
            # Fall back to a plain float when no native bfloat16 type exists.
            return _np_bfloat16(bound) if callable(_np_bfloat16) else bound

        # There is no single NumPy API for the bounds of an arbitrary dtype, so
        # try the floating-point query first and then the integer one. Their
        # failure modes are not documented, hence the broad -- but not bare --
        # `except` (a bare `except` would also swallow KeyboardInterrupt).
        try:
            return np.finfo(self.as_numpy_dtype).min
        except Exception:
            pass
        try:
            return np.iinfo(self.as_numpy_dtype).min
        except Exception:
            pass
        raise TypeError("Cannot find minimum value of %s." % self)

    @property
    def max(self):
        """Returns the maximum representable value in this data type.

        Raises:
          TypeError: if this is a non-numeric, unordered, or quantized type.
        """
        if self.is_quantized or self.base_dtype in (
            bool,
            string,
            complex64,
            complex128,
        ):
            raise TypeError("Cannot find maximum value of %s." % self)

        # See `min` for why bfloat16 is resolved before consulting NumPy.
        if self.base_dtype == bfloat16:
            bound = float.fromhex("0x1.FEp127")
            return _np_bfloat16(bound) if callable(_np_bfloat16) else bound

        # Floating-point query first, then integer; see `min` for the
        # rationale behind `except Exception`.
        try:
            return np.finfo(self.as_numpy_dtype).max
        except Exception:
            pass
        try:
            return np.iinfo(self.as_numpy_dtype).max
        except Exception:
            pass
        raise TypeError("Cannot find maximum value of %s." % self)

    @property
    def limits(self, clip_negative=True):
        """Return intensity limits, i.e. (min, max) tuple, of the dtype.

        NOTE(review): because this is a `property`, callers cannot pass
        `clip_negative`; it always takes its default (`True`), so the returned
        lower limit is always 0. The property form is kept for backward
        compatibility with code that reads `dtype.limits`.

        Args:
          clip_negative: bool, optional. If True, clip the negative range
            (i.e. return 0 for min intensity) even if the image dtype allows
            negative values.

        Returns:
          A `(min, max)` tuple of lower and upper intensity limits.

        Raises:
          KeyError: if `dtype_range` has no entry for this dtype's NumPy type.
        """
        lower, upper = dtype_range[self.as_numpy_dtype]
        if clip_negative:
            lower = 0
        return lower, upper

    def is_compatible_with(self, other):
        """Returns True if the `other` DType will be converted to this DType.

        The conversion rules are as follows:

        ```python
        DType(T)       .is_compatible_with(DType(T))        == True
        DType(T)       .is_compatible_with(DType(T).as_ref) == True
        DType(T).as_ref.is_compatible_with(DType(T))        == False
        DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
        ```

        Args:
          other: A `DType` (or object that may be converted to a `DType`).

        Returns:
          True if a Tensor of the `other` `DType` will be implicitly converted to
          this `DType`.
        """
        other = as_dtype(other)
        return self._type_enum in (
            other.as_datatype_enum,
            other.base_dtype.as_datatype_enum,
        )

    def __eq__(self, other):
        """Returns True iff this DType refers to the same type as `other`.

        `other` is converted with `as_dtype`; values it rejects with a
        `TypeError` (and `None`) compare unequal.
        """
        if other is None:
            return False
        try:
            dtype = as_dtype(other).as_datatype_enum
            return self._type_enum == dtype
        except TypeError:
            return False

    def __ne__(self, other):
        """Returns True iff self != other."""
        return not self.__eq__(other)

    @property
    def name(self):
        """Returns the string name for this `DType`."""
        return _TYPE_TO_STRING[self._type_enum]

    def __int__(self):
        return self._type_enum

    def __str__(self):
        return "<dtype: %r>" % self.name

    def __repr__(self):
        return "tf." + self.name

    def __hash__(self):
        # Consistent with `__eq__`, which compares enum values.
        return self._type_enum

    def __reduce__(self):
        # Pickle by name so unpickling goes back through `as_dtype`.
        return as_dtype, (self.name,)

    @property
    def size(self):
        """Returns the size in bytes of one element of this dtype.

        Variant and resource handles report a size of 1.
        """
        if (
            self._type_enum == types_pb2.DT_VARIANT
            or self._type_enum == types_pb2.DT_RESOURCE
        ):
            return 1
        return np.dtype(self.as_numpy_dtype).itemsize
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# Intensity range of each NumPy scalar type, read by `DType.limits`.
# Boolean and integer types span their full representable range; floating
# point types use the normalized range [-1, 1]. `np.float16` is included so
# that half-precision dtypes do not raise `KeyError` in `DType.limits`.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-(2**63), 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-(2**31), 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float16: (-1, 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1),
}
|
|
337
|
+
|
|
338
|
+
# Standard `DType` wrappers for the `types_pb2.DataType` enum values.
# Each name mirrors the public TensorFlow constant of the same name
# (e.g. `tf.float32`); TensorFlow's `tf_export` registrations are not
# performed in this stub.

# Floating-point types and their conventional aliases.
float16 = DType(types_pb2.DT_HALF)
half = float16
float32 = DType(types_pb2.DT_FLOAT)
float64 = DType(types_pb2.DT_DOUBLE)
double = float64
bfloat16 = DType(types_pb2.DT_BFLOAT16)

# Complex types.
complex64 = DType(types_pb2.DT_COMPLEX64)
complex128 = DType(types_pb2.DT_COMPLEX128)

# Signed integer types.
int8 = DType(types_pb2.DT_INT8)
int16 = DType(types_pb2.DT_INT16)
int32 = DType(types_pb2.DT_INT32)
int64 = DType(types_pb2.DT_INT64)

# Unsigned integer types.
uint8 = DType(types_pb2.DT_UINT8)
uint16 = DType(types_pb2.DT_UINT16)
uint32 = DType(types_pb2.DT_UINT32)
uint64 = DType(types_pb2.DT_UINT64)

# Quantized types.
qint8 = DType(types_pb2.DT_QINT8)
quint8 = DType(types_pb2.DT_QUINT8)
qint16 = DType(types_pb2.DT_QINT16)
quint16 = DType(types_pb2.DT_QUINT16)
qint32 = DType(types_pb2.DT_QINT32)

# Non-numeric and handle types.
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
string = DType(types_pb2.DT_STRING)
resource = DType(types_pb2.DT_RESOURCE)
variant = DType(types_pb2.DT_VARIANT)

# Reference-typed variants of the types above.
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int32_ref = DType(types_pb2.DT_INT32_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
|
|
410
|
+
qint16_ref = DType(types_pb2.DT_QINT16_REF)
|
|
411
|
+
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
|
|
412
|
+
qint32_ref = DType(types_pb2.DT_QINT32_REF)
|
|
413
|
+
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
|
|
414
|
+
|
|
415
|
+
# DataType enums for variants and resource handles (value and reference
# forms); none of them has an entry in the TF -> NumPy mapping.
_NUMPY_INCOMPATIBLE = frozenset(
    (
        types_pb2.DT_VARIANT,
        types_pb2.DT_VARIANT_REF,
        types_pb2.DT_RESOURCE,
        types_pb2.DT_RESOURCE_REF,
    )
)
|
|
423
|
+
|
|
424
|
+
# Intern table: every DataType enum maps to the one shared DType constant
# defined for it, so lookups hand back existing objects instead of creating
# many small new ones.
_INTERN_TABLE = {
    # Value types.
    types_pb2.DT_HALF: float16,
    types_pb2.DT_FLOAT: float32,
    types_pb2.DT_DOUBLE: float64,
    types_pb2.DT_INT32: int32,
    types_pb2.DT_UINT8: uint8,
    types_pb2.DT_UINT16: uint16,
    types_pb2.DT_UINT32: uint32,
    types_pb2.DT_UINT64: uint64,
    types_pb2.DT_INT16: int16,
    types_pb2.DT_INT8: int8,
    types_pb2.DT_STRING: string,
    types_pb2.DT_COMPLEX64: complex64,
    types_pb2.DT_COMPLEX128: complex128,
    types_pb2.DT_INT64: int64,
    types_pb2.DT_BOOL: bool,
    types_pb2.DT_QINT8: qint8,
    types_pb2.DT_QUINT8: quint8,
    types_pb2.DT_QINT16: qint16,
    types_pb2.DT_QUINT16: quint16,
    types_pb2.DT_QINT32: qint32,
    types_pb2.DT_BFLOAT16: bfloat16,
    types_pb2.DT_RESOURCE: resource,
    types_pb2.DT_VARIANT: variant,
}
# Reference types, appended after the value types.
_INTERN_TABLE.update(
    {
        types_pb2.DT_HALF_REF: float16_ref,
        types_pb2.DT_FLOAT_REF: float32_ref,
        types_pb2.DT_DOUBLE_REF: float64_ref,
        types_pb2.DT_INT32_REF: int32_ref,
        types_pb2.DT_UINT32_REF: uint32_ref,
        types_pb2.DT_UINT8_REF: uint8_ref,
        types_pb2.DT_UINT16_REF: uint16_ref,
        types_pb2.DT_INT16_REF: int16_ref,
        types_pb2.DT_INT8_REF: int8_ref,
        types_pb2.DT_STRING_REF: string_ref,
        types_pb2.DT_COMPLEX64_REF: complex64_ref,
        types_pb2.DT_COMPLEX128_REF: complex128_ref,
        types_pb2.DT_INT64_REF: int64_ref,
        types_pb2.DT_UINT64_REF: uint64_ref,
        types_pb2.DT_BOOL_REF: bool_ref,
        types_pb2.DT_QINT8_REF: qint8_ref,
        types_pb2.DT_QUINT8_REF: quint8_ref,
        types_pb2.DT_QINT16_REF: qint16_ref,
        types_pb2.DT_QUINT16_REF: quint16_ref,
        types_pb2.DT_QINT32_REF: qint32_ref,
        types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
        types_pb2.DT_RESOURCE_REF: resource_ref,
        types_pb2.DT_VARIANT_REF: variant_ref,
    }
)
|
|
474
|
+
|
|
475
|
+
# Canonical string name for each DataType enum. Reference types use the
# value type's name with a "_ref" suffix.
_TYPE_TO_STRING = {
    # Value types.
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
}
# Reference types, appended after the value types.
_TYPE_TO_STRING.update(
    {
        types_pb2.DT_HALF_REF: "float16_ref",
        types_pb2.DT_FLOAT_REF: "float32_ref",
        types_pb2.DT_DOUBLE_REF: "float64_ref",
        types_pb2.DT_INT32_REF: "int32_ref",
        types_pb2.DT_UINT32_REF: "uint32_ref",
        types_pb2.DT_UINT8_REF: "uint8_ref",
        types_pb2.DT_UINT16_REF: "uint16_ref",
        types_pb2.DT_INT16_REF: "int16_ref",
        types_pb2.DT_INT8_REF: "int8_ref",
        types_pb2.DT_STRING_REF: "string_ref",
        types_pb2.DT_COMPLEX64_REF: "complex64_ref",
        types_pb2.DT_COMPLEX128_REF: "complex128_ref",
        types_pb2.DT_INT64_REF: "int64_ref",
        types_pb2.DT_UINT64_REF: "uint64_ref",
        types_pb2.DT_BOOL_REF: "bool_ref",
        types_pb2.DT_QINT8_REF: "qint8_ref",
        types_pb2.DT_QUINT8_REF: "quint8_ref",
        types_pb2.DT_QINT16_REF: "qint16_ref",
        types_pb2.DT_QUINT16_REF: "quint16_ref",
        types_pb2.DT_QINT32_REF: "qint32_ref",
        types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
        types_pb2.DT_RESOURCE_REF: "resource_ref",
        types_pb2.DT_VARIANT_REF: "variant_ref",
    }
)
|
|
524
|
+
# String name -> interned DType: invert the enum -> name table, then resolve
# each enum through the intern table.
_STRING_TO_TF = {
    name: _INTERN_TABLE[enum] for enum, name in _TYPE_TO_STRING.items()
}
# Non-canonical spellings accepted in addition to the canonical names.
_STRING_TO_TF.update(
    half=float16,
    half_ref=float16_ref,
    float=float32,
    float_ref=float32_ref,
    double=float64,
    double_ref=float64_ref,
)
|
|
534
|
+
|
|
535
|
+
# NumPy stand-ins for the quantized dtypes: each is a one-field structured
# dtype whose field name is the TF type name and whose field type is the
# underlying integer type.
#
# These field names are magic strings that the swig wrapper uses to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8, _np_quint8, _np_qint16, _np_quint16, _np_qint32 = (
    np.dtype([(field_name, base)])
    for field_name, base in (
        ("qint8", np.int8),
        ("quint8", np.uint8),
        ("qint16", np.int16),
        ("quint16", np.uint16),
        ("qint32", np.int32),
    )
)

# _np_bfloat16 is defined by a module import.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
|
|
551
|
+
|
|
552
|
+
# (NumPy type, DType) pairs. `as_dtype` scans these pairs and compares each
# NumPy key against its argument with `==`.
_NP_TO_TF = frozenset(
    (
        # Floating point.
        (np.float16, float16),
        (np.float32, float32),
        (np.float64, float64),
        # Signed and unsigned integers.
        (np.int32, int32),
        (np.int64, int64),
        (np.uint8, uint8),
        (np.uint16, uint16),
        (np.uint32, uint32),
        (np.uint64, uint64),
        (np.int16, int16),
        (np.int8, int8),
        # Complex, string (object arrays) and boolean.
        (np.complex64, complex64),
        (np.complex128, complex128),
        (np.object_, string),
        (np.bool_, bool),
        # Quantized types via their structured-dtype stand-ins.
        (_np_qint8, qint8),
        (_np_quint8, quint8),
        (_np_qint16, qint16),
        (_np_quint16, quint16),
        (_np_qint32, qint32),
        # TODO(#1677): _np_bfloat16 is defined as 0. This causes `as_dtype` to
        # error. Add below back after we fix `TF_bfloat16_type`.
        # (_np_bfloat16, bfloat16),
    )
)
|
|
580
|
+
# DataType enum -> NumPy type. Each reference type maps to the same NumPy
# type as its value type.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use np.object as it supports variable length
    # strings.
    types_pb2.DT_STRING: np.object_,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: np.bool_,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    # _np_bfloat16 comes from a module import, not from this section.
    types_pb2.DT_BFLOAT16: _np_bfloat16,
}
# Ref types, appended after the value types.
_TF_TO_NP.update(
    {
        types_pb2.DT_HALF_REF: np.float16,
        types_pb2.DT_FLOAT_REF: np.float32,
        types_pb2.DT_DOUBLE_REF: np.float64,
        types_pb2.DT_INT32_REF: np.int32,
        types_pb2.DT_UINT32_REF: np.uint32,
        types_pb2.DT_UINT8_REF: np.uint8,
        types_pb2.DT_UINT16_REF: np.uint16,
        types_pb2.DT_INT16_REF: np.int16,
        types_pb2.DT_INT8_REF: np.int8,
        types_pb2.DT_STRING_REF: np.object_,
        types_pb2.DT_COMPLEX64_REF: np.complex64,
        types_pb2.DT_COMPLEX128_REF: np.complex128,
        types_pb2.DT_INT64_REF: np.int64,
        types_pb2.DT_UINT64_REF: np.uint64,
        types_pb2.DT_BOOL_REF: np.bool_,
        types_pb2.DT_QINT8_REF: _np_qint8,
        types_pb2.DT_QUINT8_REF: _np_quint8,
        types_pb2.DT_QINT16_REF: _np_qint16,
        types_pb2.DT_QUINT16_REF: _np_quint16,
        types_pb2.DT_QINT32_REF: _np_qint32,
        types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
    }
)
|
|
627
|
+
|
|
628
|
+
# Quantized DTypes, kept in two groups (value vs. reference forms) and also
# combined into the public QUANTIZED_DTYPES set.
_QUANTIZED_DTYPES_NO_REF = frozenset((qint8, quint8, qint16, quint16, qint32))
_QUANTIZED_DTYPES_REF = frozenset(
    (qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref)
)
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF | _QUANTIZED_DTYPES_NO_REF
# tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
|
|
634
|
+
|
|
635
|
+
# Python builtin type -> DType.
#
# The module-level name `bool` was rebound earlier in this file to the DT_BOOL
# DType, so writing `{bool: bool}` would key this table by that DType rather
# than by Python's builtin `bool` type, and `as_dtype(bool)` would fall
# through to a TypeError. Reach the builtin explicitly via `builtins`; the
# alias is private so it is not re-exported.
import builtins as _builtins  # noqa: E402 -- builtin `bool` is shadowed here

_PYTHON_TO_TF = {float: float32, _builtins.bool: bool}
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
# @tf_export("as_dtype")
def as_dtype(type_value):
    """Converts the given `type_value` to a `DType`.

    Resolution order: a `DType` is returned unchanged; otherwise the value is
    looked up, in turn, as a `types_pb2.DataType` enum (`_INTERN_TABLE`), a
    string type name (`_STRING_TO_TF`) and a Python type (`_PYTHON_TO_TF`).
    A `numpy.dtype` whose scalar type is `np.bytes_` or `np.str_` maps to
    `string`. Any remaining `type` or `numpy.dtype` is compared with `==`
    against the NumPy keys in `_NP_TO_TF`.

    Args:
      type_value: A value that can be converted to a `tf.DType` object. This may
        currently be a `tf.DType` object, a [`DataType`
        enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
        a string type name, or a `numpy.dtype`.

    Returns:
      A `DType` corresponding to `type_value`.

    Raises:
      TypeError: If `type_value` cannot be converted to a `DType`.
    """
    if isinstance(type_value, DType):
        return type_value

    # Hash-based lookups, in priority order. An unhashable `type_value`
    # (list, ndarray, ...) raises TypeError on lookup; skip it here so the
    # caller gets the descriptive "Cannot convert" TypeError below rather
    # than a bare "unhashable type" error.
    for table in (_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF):
        try:
            return table[type_value]
        except (KeyError, TypeError):
            pass

    if isinstance(type_value, np.dtype):
        # The numpy dtype for strings is variable length. We can not compare
        # dtype with a single constant (np.string does not exist) to decide
        # dtype is a "string" type. We need to compare the dtype.type to be
        # sure it's a string type.
        if type_value.type == np.bytes_ or type_value.type == np.str_:
            return string

    if isinstance(type_value, (type, np.dtype)):
        for key, val in _NP_TO_TF:
            try:
                if key == type_value:
                    return val
            except TypeError as e:
                raise TypeError(
                    "Cannot convert {} to a dtype. {}".format(type_value, e)
                ) from e

    # Wrap in a 1-tuple: a tuple `type_value` would otherwise be consumed as
    # the %-format argument list (wrong message or a formatting error).
    raise TypeError(
        "Cannot convert value %r to a TensorFlow DType." % (type_value,)
    )
|