tf-keras-nightly 2.19.0.dev2025020210__py3-none-any.whl → 2.19.0.dev2025020410__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
- {tf_keras_nightly-2.19.0.dev2025020210.dist-info → tf_keras_nightly-2.19.0.dev2025020410.dist-info}/METADATA +1 -1
- {tf_keras_nightly-2.19.0.dev2025020210.dist-info → tf_keras_nightly-2.19.0.dev2025020410.dist-info}/RECORD +6 -6
- {tf_keras_nightly-2.19.0.dev2025020210.dist-info → tf_keras_nightly-2.19.0.dev2025020410.dist-info}/WHEEL +0 -0
- {tf_keras_nightly-2.19.0.dev2025020210.dist-info → tf_keras_nightly-2.19.0.dev2025020410.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
@@ -95,7 +95,7 @@ class SpectralNormalization(Wrapper):
|
|
95
95
|
|
96
96
|
def call(self, inputs, training=False):
|
97
97
|
if training:
|
98
|
-
self.
|
98
|
+
self._update_weights()
|
99
99
|
|
100
100
|
output = self.layer(inputs)
|
101
101
|
return output
|
@@ -105,35 +105,42 @@ class SpectralNormalization(Wrapper):
|
|
105
105
|
self.layer.compute_output_shape(input_shape).as_list()
|
106
106
|
)
|
107
107
|
|
108
|
+
def _update_weights(self):
|
109
|
+
weights = self.kernel
|
110
|
+
vector_u = self.vector_u
|
111
|
+
|
112
|
+
kernel_weights, vector_u = tf.cond(
|
113
|
+
tf.reduce_all(tf.equal(weights, 0)),
|
114
|
+
lambda: (weights, vector_u),
|
115
|
+
lambda: self.normalize_weights(),
|
116
|
+
)
|
117
|
+
self.kernel.assign(kernel_weights)
|
118
|
+
self.vector_u.assign(vector_u)
|
119
|
+
|
108
120
|
def normalize_weights(self):
|
109
121
|
"""Generate spectral normalized weights.
|
110
122
|
|
111
123
|
This method will update the value of `self.kernel` with the
|
112
124
|
spectral normalized value, so that the layer is ready for `call()`.
|
113
125
|
"""
|
114
|
-
|
115
|
-
weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
|
126
|
+
# Initialize vector_v to hint the compiler it always exist.
|
116
127
|
vector_u = self.vector_u
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
tf.matmul(vector_u, weights, transpose_b=True)
|
123
|
-
)
|
124
|
-
vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
|
125
|
-
vector_u = tf.stop_gradient(vector_u)
|
126
|
-
vector_v = tf.stop_gradient(vector_v)
|
127
|
-
sigma = tf.matmul(
|
128
|
-
tf.matmul(vector_v, weights), vector_u, transpose_b=True
|
129
|
-
)
|
130
|
-
self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
|
131
|
-
self.kernel.assign(
|
132
|
-
tf.cast(
|
133
|
-
tf.reshape(self.kernel / sigma, self.kernel_shape),
|
134
|
-
self.kernel.dtype,
|
135
|
-
)
|
128
|
+
vector_v = self.vector_u
|
129
|
+
weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
|
130
|
+
for _ in range(self.power_iterations):
|
131
|
+
vector_v = tf.math.l2_normalize(
|
132
|
+
tf.matmul(vector_u, weights, transpose_b=True)
|
136
133
|
)
|
134
|
+
vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
|
135
|
+
vector_u = tf.stop_gradient(vector_u)
|
136
|
+
vector_v = tf.stop_gradient(vector_v)
|
137
|
+
sigma = tf.matmul(
|
138
|
+
tf.matmul(vector_v, weights),
|
139
|
+
vector_u,
|
140
|
+
transpose_b=True,
|
141
|
+
)
|
142
|
+
weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
|
143
|
+
return weights_normalized, vector_u
|
137
144
|
|
138
145
|
def get_config(self):
|
139
146
|
config = {"power_iterations": self.power_iterations}
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tf_keras/__init__.py,sha256=
|
1
|
+
tf_keras/__init__.py,sha256=2k5wvY93YzoOjWScZm4ph_C597Ec47fCQmi-chjUhIw,911
|
2
2
|
tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
|
3
3
|
tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
|
4
4
|
tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
|
@@ -366,7 +366,7 @@ tf_keras/src/layers/normalization/batch_normalization.py,sha256=RdFwlFhXj4i612oy
|
|
366
366
|
tf_keras/src/layers/normalization/batch_normalization_v1.py,sha256=7I8SioqbqZzLvCXGRiiSbbiUeeQsNMfrlils1CEm61Y,1191
|
367
367
|
tf_keras/src/layers/normalization/group_normalization.py,sha256=nqAW5vM96uqBcgF0jea-DkPcHixfbbzC3B2lyFHqNEg,10028
|
368
368
|
tf_keras/src/layers/normalization/layer_normalization.py,sha256=YvZsvSwZsBxP9O7K-f4orTSz69ADiPRIygqLq4kUI7k,14022
|
369
|
-
tf_keras/src/layers/normalization/spectral_normalization.py,sha256=
|
369
|
+
tf_keras/src/layers/normalization/spectral_normalization.py,sha256=38hAYFl_OrntMrbmvqDSFv4gE-T2LcnaipcJ5pHeUpI,5192
|
370
370
|
tf_keras/src/layers/normalization/unit_normalization.py,sha256=zFHpHet8htHl7sLXQJ_nFecyZLU3fMrspq5V8STYgQs,2634
|
371
371
|
tf_keras/src/layers/pooling/__init__.py,sha256=6WvDC0BWmYKwJlurf_1QFRNAHW-kqEy4NI63K4XWzVc,2590
|
372
372
|
tf_keras/src/layers/pooling/average_pooling1d.py,sha256=dIHOp6wvO9JfQ9SzndiElI-oc_TEh4rCNMVBh_zBRB8,4998
|
@@ -606,7 +606,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
|
|
606
606
|
tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
|
607
607
|
tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
|
608
608
|
tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
|
609
|
-
tf_keras_nightly-2.19.0.
|
610
|
-
tf_keras_nightly-2.19.0.
|
611
|
-
tf_keras_nightly-2.19.0.
|
612
|
-
tf_keras_nightly-2.19.0.
|
609
|
+
tf_keras_nightly-2.19.0.dev2025020410.dist-info/METADATA,sha256=Zhq3yMtFTp5s1UpQp9TDA1ku77Aom0UnPxB8HUJcyB4,1857
|
610
|
+
tf_keras_nightly-2.19.0.dev2025020410.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
611
|
+
tf_keras_nightly-2.19.0.dev2025020410.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
|
612
|
+
tf_keras_nightly-2.19.0.dev2025020410.dist-info/RECORD,,
|
File without changes
|
File without changes
|