mttf 1.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mt/keras/__init__.py +8 -0
- mt/keras_src/__init__.py +16 -0
- mt/keras_src/applications_src/__init__.py +33 -0
- mt/keras_src/applications_src/classifier.py +497 -0
- mt/keras_src/applications_src/mobilenet_v3_split.py +544 -0
- mt/keras_src/applications_src/mobilevit.py +292 -0
- mt/keras_src/base.py +28 -0
- mt/keras_src/constraints_src/__init__.py +14 -0
- mt/keras_src/constraints_src/center_around.py +19 -0
- mt/keras_src/layers_src/__init__.py +43 -0
- mt/keras_src/layers_src/counter.py +27 -0
- mt/keras_src/layers_src/floor.py +24 -0
- mt/keras_src/layers_src/identical.py +15 -0
- mt/keras_src/layers_src/image_sizing.py +1605 -0
- mt/keras_src/layers_src/normed_conv2d.py +239 -0
- mt/keras_src/layers_src/simple_mha.py +472 -0
- mt/keras_src/layers_src/soft_bend.py +36 -0
- mt/keras_src/layers_src/transformer_encoder.py +246 -0
- mt/keras_src/layers_src/utils.py +88 -0
- mt/keras_src/layers_src/var_regularizer.py +38 -0
- mt/tf/__init__.py +10 -0
- mt/tf/init.py +25 -0
- mt/tf/keras_applications/__init__.py +5 -0
- mt/tf/keras_layers/__init__.py +5 -0
- mt/tf/mttf_version.py +5 -0
- mt/tf/utils.py +44 -0
- mt/tf/version.py +5 -0
- mt/tfc/__init__.py +291 -0
- mt/tfg/__init__.py +8 -0
- mt/tfp/__init__.py +11 -0
- mt/tfp/real_nvp.py +116 -0
- mttf-1.3.6.data/scripts/dmt_build_package_and_upload_to_nexus.sh +25 -0
- mttf-1.3.6.data/scripts/dmt_pipi.sh +7 -0
- mttf-1.3.6.data/scripts/dmt_twineu.sh +2 -0
- mttf-1.3.6.data/scripts/pipi.sh +7 -0
- mttf-1.3.6.data/scripts/user_build_package_and_upload_to_nexus.sh +25 -0
- mttf-1.3.6.data/scripts/user_pipi.sh +8 -0
- mttf-1.3.6.data/scripts/user_twineu.sh +3 -0
- mttf-1.3.6.data/scripts/wml_build_package_and_upload_to_nexus.sh +25 -0
- mttf-1.3.6.data/scripts/wml_nexus.py +50 -0
- mttf-1.3.6.data/scripts/wml_pipi.sh +7 -0
- mttf-1.3.6.data/scripts/wml_twineu.sh +2 -0
- mttf-1.3.6.dist-info/METADATA +18 -0
- mttf-1.3.6.dist-info/RECORD +47 -0
- mttf-1.3.6.dist-info/WHEEL +5 -0
- mttf-1.3.6.dist-info/licenses/LICENSE +21 -0
- mttf-1.3.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# pylint: disable=invalid-name
|
|
2
|
+
# pylint: disable=missing-function-docstring
|
|
3
|
+
"""MobileViT model.
|
|
4
|
+
|
|
5
|
+
Most of the code here has been adapted and updated from the following
|
|
6
|
+
`Keras tutorial <https://keras.io/examples/vision/mobilevit/>`_. Please refer
|
|
7
|
+
to the `MobileViT ICLR2022 paper <https://arxiv.org/abs/2110.02178>`_ for more details.
|
|
8
|
+
|
|
9
|
+
The paper authors' code is `here <https://github.com/apple/ml-cvnets>`_.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
import tensorflow as tf
|
|
14
|
+
from mt import tp, tfc
|
|
15
|
+
|
|
16
|
+
from .mobilenet_v3_split import (
|
|
17
|
+
MobileNetV3Input,
|
|
18
|
+
_inverted_res_block,
|
|
19
|
+
backend,
|
|
20
|
+
models,
|
|
21
|
+
layers,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Apply one swish-activated, "same"-padded ``Conv2D`` layer to ``x``.

    Parameters
    ----------
    x : tensor
        Input feature map.
    filters : int
        Number of output channels of the convolution.
    kernel_size : int
        Size of the convolution kernel.
    strides : int
        Stride of the convolution.

    Returns
    -------
    tensor
        The convolved feature map.
    """
    return layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
        activation=tf.nn.swish,
    )(x)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Reference: https://git.io/JKgtC
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def inverted_residual_block(
    x, expanded_channels, output_channels, strides=1, block_id=0
):
    """MV2 (inverted residual) block as used by MobileViT.

    Thin wrapper around ``_inverted_res_block`` from ``mobilenet_v3_split``,
    called with a 3x3 kernel, an se_ratio of 0 and swish activation.

    Parameters
    ----------
    x : tensor
        Input feature map.
    expanded_channels : int
        Desired number of channels in the expansion stage. It is turned into an
        integer expansion factor ``expanded_channels // in_channels``.
    output_channels : int
        Number of output channels (``filters`` of ``_inverted_res_block``).
    strides : int
        Stride forwarded to ``_inverted_res_block``.
    block_id : int
        Block index forwarded to ``_inverted_res_block``; must be non-zero.

    Returns
    -------
    tensor
        Output of ``_inverted_res_block``.

    Raises
    ------
    NotImplementedError
        If ``block_id`` is zero.
    """
    if block_id == 0:
        raise NotImplementedError(
            "Zero block id for _inverted_res_block() is not implemented in MobileViT."
        )

    # Read the channel count from whichever axis holds channels.
    if backend.image_data_format() == "channels_first":
        in_channels = backend.int_shape(x)[1]
    else:
        in_channels = backend.int_shape(x)[-1]

    # NOTE(review): floor division, so the effective expansion is smaller than
    # expanded_channels / in_channels when that ratio is not an integer.
    expansion = expanded_channels // in_channels

    return _inverted_res_block(
        x,
        expansion,
        output_channels,
        3,  # kernel_size
        strides,
        0,  # se_ratio
        tf.nn.swish,  # activation
        block_id,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Reference:
|
|
61
|
+
# https://keras.io/examples/vision/image_classification_with_vision_transformer/
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def mlp(x, hidden_units, dropout_rate):
    """Stack of swish ``Dense`` layers, each followed by ``Dropout``.

    Parameters
    ----------
    x : tensor
        Input tensor.
    hidden_units : iterable of int
        Width of each ``Dense`` layer, in order.
    dropout_rate : float
        Rate passed to every ``Dropout`` layer.

    Returns
    -------
    tensor
        Output of the last ``Dropout`` layer (or ``x`` if ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        dense_out = layers.Dense(width, activation=tf.nn.swish)(out)
        out = layers.Dropout(dropout_rate)(dense_out)
    return out
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Stack of pre-norm Transformer encoder layers.

    Each layer applies LayerNorm -> multi-head self-attention -> residual add,
    then LayerNorm -> 2-layer MLP (widths ``2*C`` and ``C``) -> residual add,
    where ``C`` is the last dimension of ``x``.

    Parameters
    ----------
    x : tensor
        Input tokens; the last axis is the feature axis.
    transformer_layers : int
        Number of encoder layers to stack.
    projection_dim : int
        ``key_dim`` of each ``MultiHeadAttention`` layer.
    num_heads : int
        Number of attention heads.

    Returns
    -------
    tensor
        Output of the last encoder layer, same last dimension as ``x``.
    """
    for _layer_idx in range(transformer_layers):
        width = x.shape[-1]

        # Self-attention sub-layer with residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(x)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        after_attention = layers.Add()([attended, x])

        # Feed-forward sub-layer with residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        fed_forward = mlp(
            normed,
            hidden_units=[width * 2, width],
            dropout_rate=0.1,
        )
        x = layers.Add()([fed_forward, after_attention])

    return x
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """MobileViT block: local convolutions fused with a global Transformer.

    The input is projected by two convolutions, unfolded into 2x2 patches,
    processed by ``num_blocks`` Transformer layers, folded back to the original
    spatial layout, projected back to the input channel count, concatenated
    with the input and fused by a final 3x3 convolution.

    Parameters
    ----------
    x : tensor
        Input feature map laid out as (B, H, W, C). H and W must be static and
        even.
    num_blocks : int
        Number of Transformer layers.
    projection_dim : int
        Channel count of the local features, Transformer width and number of
        output channels.
    strides : int
        Must be 1. The folded global features are concatenated with ``x``, so
        any other stride cannot produce matching shapes.

    Returns
    -------
    tensor
        Feature map of shape (B, H, W, projection_dim).

    Raises
    ------
    tfc.ModelSyntaxError
        If ``strides`` is not 1, or if H or W is unknown or not even.
    """
    cell_size = 2  # 2x2 patches for the Transformer block

    # Validate everything before creating any layer. With strides != 1 the
    # fold below would target x's spatial size from a downsampled feature map,
    # which always fails deep inside a Reshape layer; report it clearly instead.
    if strides != 1:
        raise tfc.ModelSyntaxError(
            f"MobileViT block only supports strides=1. Got strides={strides}."
        )
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        # The unfold/fold reshapes need static spatial dimensions.
        raise tfc.ModelSyntaxError(
            f"Input tensor must have a static height and width. Got {x.shape}."
        )
    if height % cell_size != 0:
        raise tfc.ModelSyntaxError(
            f"Input tensor must have height divisible by {cell_size}. Got {x.shape}."
        )
    if width % cell_size != 0:
        raise tfc.ModelSyntaxError(
            f"Input tensor must have width divisible by {cell_size}. Got {x.shape}."
        )

    # Local projection with convolutions (spatial size unchanged at stride 1).
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    z = layers.Reshape(
        (
            height // cell_size,
            cell_size,
            width // cell_size,
            cell_size,
            projection_dim,
        )
    )(local_features)  # (B,H/P,P,W/P,P,C)
    z = tf.transpose(z, perm=[0, 2, 4, 1, 3, 5])  # (B,P,P,H/P,W/P,C)
    non_overlapping_patches = layers.Reshape(
        (
            cell_size * cell_size,
            (height // cell_size) * (width // cell_size),
            projection_dim,
        )
    )(z)  # (B,P*P,H*W/(P*P),C)
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps (inverse of the unfold above).
    z = layers.Reshape(
        (
            cell_size,
            cell_size,
            height // cell_size,
            width // cell_size,
            projection_dim,
        )
    )(global_features)  # (B,P,P,H/P,W/P,C)
    z = tf.transpose(z, perm=[0, 3, 1, 4, 2, 5])  # (B,H/P,P,W/P,P,C)
    folded_feature_map = layers.Reshape((height, width, projection_dim))(z)

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def create_mobilevit(
    input_shape=None,
    model_type: str = "XXS",
    output_all: bool = False,
    name: tp.Optional[str] = None,
):
    """Builds a MobileViT model.

    Parameters
    ----------
    input_shape : tuple, optional
        Shape tuple passed to ``MobileNetV3Input`` to create the model input
        (see that function for how ``None`` is handled). Inputs are rescaled by
        1/255, so pixel values are presumably expected in [0, 255]. Height and
        width should be multiples of 64, e.g. ``(256, 256, 3)``. The features
        are downsampled by 2 five times and every MobileViT block requires an
        even spatial size. Sizes such as 224 or 160 end up with an odd
        (7x7 / 5x5) feature map at the last MobileViT block and are rejected,
        assuming each stride-2 stage halves an even size.
    model_type : {'XXS', 'XS', 'S'}
        One of the 3 variants introduced in the paper.
    output_all : bool
        If True, the model returns 5 tensors: the output of the stem + first
        MV2 block, followed by the output of each of the 4 later stages. Each
        stage's output is taken before the next down-sampling. Otherwise it
        returns a single-element list holding the output of the last stage.
    name : str, optional
        Model name, if any. Defaults to 'MobileViT<model_type>'.

    Returns
    -------
    tensorflow.keras.Model
        The output MobileViT model.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported variants.
    tfc.ModelSyntaxError
        If a MobileViT block receives a feature map with unsupported shape.
    """
    model_types = ["XXS", "XS", "S"]
    if model_type not in model_types:
        raise ValueError(
            f"Unknown MobileViT model type '{model_type}'. "
            f"Expected one of {model_types}."
        )
    model_type_id = model_types.index(model_type)

    expansion_factor = 2 if model_type_id == 0 else 4

    inputs = MobileNetV3Input(input_shape=input_shape)
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x,
        expanded_channels=16 * expansion_factor,
        output_channels=16 if model_type_id == 0 else 32,
        block_id=1,
    )
    outputs = [x]

    # Downsampling with MV2 block.
    output_channels = [24, 48, 64][model_type_id]
    x = inverted_residual_block(
        x,
        expanded_channels=16 * expansion_factor,
        output_channels=output_channels,
        strides=2,
        block_id=2,
    )
    x = inverted_residual_block(
        x,
        expanded_channels=24 * expansion_factor,
        output_channels=output_channels,
        block_id=3,
    )
    x = inverted_residual_block(
        x,
        expanded_channels=24 * expansion_factor,
        output_channels=output_channels,
        block_id=4,
    )
    if output_all:
        outputs.append(x)
    else:
        outputs = [x]

    # First MV2 -> MobileViT block.
    output_channels = [48, 64, 96][model_type_id]
    projection_dim = [64, 96, 144][model_type_id]
    x = inverted_residual_block(
        x,
        expanded_channels=48 * expansion_factor,
        output_channels=output_channels,
        strides=2,
        block_id=5,
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=projection_dim)
    if output_all:
        outputs.append(x)
    else:
        outputs = [x]

    # Second MV2 -> MobileViT block.
    output_channels = [64, 80, 128][model_type_id]
    projection_dim = [80, 120, 192][model_type_id]
    x = inverted_residual_block(
        x,
        expanded_channels=64 * expansion_factor,
        output_channels=output_channels,
        strides=2,
        block_id=6,
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=projection_dim)
    if output_all:
        outputs.append(x)
    else:
        outputs = [x]

    # Third MV2 -> MobileViT block, then a point-wise conv head.
    output_channels = [80, 96, 160][model_type_id]
    projection_dim = [96, 144, 240][model_type_id]
    x = inverted_residual_block(
        x,
        expanded_channels=80 * expansion_factor,
        output_channels=output_channels,
        strides=2,
        block_id=7,
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=projection_dim)
    filters = [320, 384, 640][model_type_id]
    x = conv_block(x, filters=filters, kernel_size=1, strides=1)
    if output_all:
        outputs.append(x)
    else:
        outputs = [x]

    if name is None:
        name = f"MobileViT{model_type}"
    return models.Model(inputs, outputs, name=name)
|
mt/keras_src/base.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Determines the working Keras 2 from the system to be used by mt.keras."""
|
|
2
|
+
|
|
3
|
+
from packaging.version import Version
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
|
|
6
|
+
# Pick the Keras 2 implementation that mt.keras should use, exposing its
# version as ``keras_version`` and the package it comes from as ``keras_source``.
tf_ver = Version(tf.__version__)
if tf_ver >= Version("2.16"):
    # mt.keras needs Keras 2, which for TF >= 2.16 must come from tf_keras.
    try:
        import tf_keras
    except Exception as exc:  # re-raised below as ImportError, with the cause kept
        raise ImportError(
            f"mt.keras can only work with Keras 2. You have TF version {tf_ver}. Please install tf_keras."
        ) from exc
    keras_version = tf_keras.__version__
    keras_source = "tf_keras"
else:
    try:
        import keras

        kr_ver = Version(keras.__version__)
    except ImportError:
        kr_ver = None
    if kr_ver is None or kr_ver >= Version("3.0"):
        # No usable standalone Keras 2: use the Keras bundled inside TensorFlow.
        keras_version = tf.__version__
        keras_source = "tensorflow.python"
    else:
        keras_version = keras.__version__
        keras_source = "keras"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .. import constraints as _constraints
|
|
2
|
+
|
|
3
|
+
# Re-export every public (non-underscore) name of the underlying constraints
# module from this package, and reuse its module docstring.
globals().update(
    {
        _name: _value
        for _name, _value in vars(_constraints).items()
        if not _name.startswith("_")
    }
)
__doc__ = _constraints.__doc__
|
|
8
|
+
|
|
9
|
+
from .center_around import *
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Names defined by this package itself (not re-exported from Keras).
__api__ = ["CenterAround"]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .. import constraints
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CenterAround(constraints.Constraint):
    """Constrains the last axis to have values centered around `ref_value`.

    After the constraint is applied, every slice of ``w`` along its last axis
    has mean ``ref_value`` (up to floating-point error). The shape of ``w`` is
    preserved.

    Parameters
    ----------
    ref_value : float
        Target mean of each slice along the last axis.
    """

    def __init__(self, ref_value: float = 0.0):
        self.ref_value = ref_value

    def __call__(self, w):
        import tensorflow as tf

        # keepdims=True leaves a size-1 last axis that broadcasts directly
        # against ``w``. An extra expand_dims here would add a second axis and
        # make the subtraction broadcast to a larger, wrong shape.
        mean = tf.reduce_mean(w, axis=-1, keepdims=True)
        return w - (mean - self.ref_value)

    def get_config(self):
        return {"ref_value": self.ref_value}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from .. import layers as _layers
|
|
2
|
+
|
|
3
|
+
# Re-export every public (non-underscore) name of the underlying layers module
# from this package, and reuse its module docstring.
globals().update(
    {
        _name: _value
        for _name, _value in vars(_layers).items()
        if not _name.startswith("_")
    }
)
__doc__ = _layers.__doc__
|
|
8
|
+
|
|
9
|
+
from .identical import *
|
|
10
|
+
from .floor import *
|
|
11
|
+
from .var_regularizer import *
|
|
12
|
+
from .simple_mha import *
|
|
13
|
+
from .image_sizing import *
|
|
14
|
+
from .counter import Counter
|
|
15
|
+
from .normed_conv2d import NormedConv2D
|
|
16
|
+
from .soft_bend import SoftBend
|
|
17
|
+
from .transformer_encoder import MTTransformerEncoder
|
|
18
|
+
from .utils import *
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Names defined by this package itself (not re-exported from Keras).
__api__ = [
    "Identical", "Floor", "VarianceRegularizer",
    "SimpleMHA2D", "MHAPool2D", "DUCLayer",
    "Downsize2D", "Upsize2D",
    "Downsize2D_V2", "Upsize2D_V2",
    "Downsize2D_V3", "Downsize2D_V4",
    "DownsizeX2D", "UpsizeX2D",
    "DownsizeY2D", "UpsizeY2D",
    "Counter", "conv2d", "dense2d",
    "SoftBend", "MTTransformerEncoder",
]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import tensorflow as tf
|
|
2
|
+
from .. import layers, initializers
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Counter(layers.Layer):
    """A layer that counts from 0 during training and does nothing during inference.

    The layer owns a single non-trainable float weight ``counter`` of shape
    (1,), initialised to 0. Every call with ``training=True`` increments it by
    1. Calls with ``training=False`` leave it unchanged. The layer always
    returns the current counter value as a tensor of shape (1,). The input
    contributes only a zero term (its first element, gradient-stopped, times
    0), presumably so that the output stays connected to the input in the
    Keras graph.
    """

    def build(self, input_shape):
        initializer = initializers.Constant(value=0.0)
        # The counter is state advanced by ``assign_add`` in ``call``, not a
        # learnable parameter. Keep it out of the trainable weights so that
        # optimizers neither modify it through gradients nor warn about
        # missing gradients for it.
        self.counter = self.add_weight(
            name="counter",
            shape=(1,),
            initializer=initializer,
            trainable=False,
        )
        self.incrementor = tf.constant([1.0])

    def call(self, x, training: bool = False):
        if training:
            self.counter.assign_add(self.incrementor)
        # Zero-valued term built from x that carries no gradient back to x.
        y = tf.reshape(x, [-1])[:1]
        y = tf.stop_gradient(y) * 0.0
        return self.counter + y

    call.__doc__ = layers.Layer.call.__doc__

    def compute_output_shape(self, input_shape):
        return (1,)

    compute_output_shape.__doc__ = layers.Layer.compute_output_shape.__doc__
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import tensorflow as tf
|
|
2
|
+
from .. import layers
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@tf.custom_gradient
def floor(x):
    """Element-wise ``tf.math.floor`` whose gradient is the identity.

    The forward pass rounds ``x`` down. The backward pass returns the upstream
    gradient unchanged (a straight-through estimator), so gradients can flow
    through the otherwise piecewise-constant floor.
    """

    def _straight_through(dy):
        # Identity gradient: pass upstream gradient through untouched.
        return dy

    floored = tf.math.floor(x)
    return floored, _straight_through
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Floor(layers.Layer):
    """Layer applying TensorFlow floor with an identity (straight-through) gradient."""

    def call(self, x):
        # Delegate to the module-level custom-gradient op.
        result = floor(x)
        return result

    call.__doc__ = layers.Layer.call.__doc__

    def compute_output_shape(self, input_shape):
        # Element-wise op: the output shape is the input shape.
        output_shape = input_shape
        return output_shape

    compute_output_shape.__doc__ = layers.Layer.compute_output_shape.__doc__
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .. import layers
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Identical(layers.Layer):
    """Pass-through layer that returns its input unchanged; useful for renaming."""

    def call(self, x):
        # No computation: hand the input straight back.
        output = x
        return output

    call.__doc__ = layers.Layer.call.__doc__

    def compute_output_shape(self, input_shape):
        # Identity mapping, so the shape is unchanged.
        output_shape = input_shape
        return output_shape

    compute_output_shape.__doc__ = layers.Layer.compute_output_shape.__doc__
|