tf-keras-nightly 2.19.0.dev2025020210__py3-none-any.whl → 2.19.0.dev2025020410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
27
27
  from tf_keras.src.engine.training import Model
28
28
 
29
29
 
30
- __version__ = "2.19.0.dev2025020210"
30
+ __version__ = "2.19.0.dev2025020410"
@@ -95,7 +95,7 @@ class SpectralNormalization(Wrapper):
95
95
 
96
96
  def call(self, inputs, training=False):
97
97
  if training:
98
- self.normalize_weights()
98
+ self._update_weights()
99
99
 
100
100
  output = self.layer(inputs)
101
101
  return output
@@ -105,35 +105,42 @@ class SpectralNormalization(Wrapper):
105
105
  self.layer.compute_output_shape(input_shape).as_list()
106
106
  )
107
107
 
108
+ def _update_weights(self):
109
+ weights = self.kernel
110
+ vector_u = self.vector_u
111
+
112
+ kernel_weights, vector_u = tf.cond(
113
+ tf.reduce_all(tf.equal(weights, 0)),
114
+ lambda: (weights, vector_u),
115
+ lambda: self.normalize_weights(),
116
+ )
117
+ self.kernel.assign(kernel_weights)
118
+ self.vector_u.assign(vector_u)
119
+
108
120
  def normalize_weights(self):
109
121
  """Generate spectral normalized weights.
110
122
 
111
123
  This method will update the value of `self.kernel` with the
112
124
  spectral normalized value, so that the layer is ready for `call()`.
113
125
  """
114
-
115
- weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
126
+ # Initialize vector_v to hint the compiler it always exist.
116
127
  vector_u = self.vector_u
117
-
118
- # check for zeroes weights
119
- if not tf.reduce_all(tf.equal(weights, 0.0)):
120
- for _ in range(self.power_iterations):
121
- vector_v = tf.math.l2_normalize(
122
- tf.matmul(vector_u, weights, transpose_b=True)
123
- )
124
- vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
125
- vector_u = tf.stop_gradient(vector_u)
126
- vector_v = tf.stop_gradient(vector_v)
127
- sigma = tf.matmul(
128
- tf.matmul(vector_v, weights), vector_u, transpose_b=True
129
- )
130
- self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
131
- self.kernel.assign(
132
- tf.cast(
133
- tf.reshape(self.kernel / sigma, self.kernel_shape),
134
- self.kernel.dtype,
135
- )
128
+ vector_v = self.vector_u
129
+ weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
130
+ for _ in range(self.power_iterations):
131
+ vector_v = tf.math.l2_normalize(
132
+ tf.matmul(vector_u, weights, transpose_b=True)
136
133
  )
134
+ vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
135
+ vector_u = tf.stop_gradient(vector_u)
136
+ vector_v = tf.stop_gradient(vector_v)
137
+ sigma = tf.matmul(
138
+ tf.matmul(vector_v, weights),
139
+ vector_u,
140
+ transpose_b=True,
141
+ )
142
+ weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
143
+ return weights_normalized, vector_u
137
144
 
138
145
  def get_config(self):
139
146
  config = {"power_iterations": self.power_iterations}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: tf_keras-nightly
3
- Version: 2.19.0.dev2025020210
3
+ Version: 2.19.0.dev2025020410
4
4
  Summary: Deep learning for humans.
5
5
  Home-page: https://keras.io/
6
6
  Download-URL: https://github.com/keras-team/tf-keras/tags
@@ -1,4 +1,4 @@
1
- tf_keras/__init__.py,sha256=KyT-o5-RNkDajIiTrrZuaVzlr_-1xRi67kAZISKk_0Q,911
1
+ tf_keras/__init__.py,sha256=2k5wvY93YzoOjWScZm4ph_C597Ec47fCQmi-chjUhIw,911
2
2
  tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
3
3
  tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
4
4
  tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -366,7 +366,7 @@ tf_keras/src/layers/normalization/batch_normalization.py,sha256=RdFwlFhXj4i612oy
366
366
  tf_keras/src/layers/normalization/batch_normalization_v1.py,sha256=7I8SioqbqZzLvCXGRiiSbbiUeeQsNMfrlils1CEm61Y,1191
367
367
  tf_keras/src/layers/normalization/group_normalization.py,sha256=nqAW5vM96uqBcgF0jea-DkPcHixfbbzC3B2lyFHqNEg,10028
368
368
  tf_keras/src/layers/normalization/layer_normalization.py,sha256=YvZsvSwZsBxP9O7K-f4orTSz69ADiPRIygqLq4kUI7k,14022
369
- tf_keras/src/layers/normalization/spectral_normalization.py,sha256=XyxoPHUTJvfFVJagGcaOySeixV6hb53oGx_Fx_fsrhk,4984
369
+ tf_keras/src/layers/normalization/spectral_normalization.py,sha256=38hAYFl_OrntMrbmvqDSFv4gE-T2LcnaipcJ5pHeUpI,5192
370
370
  tf_keras/src/layers/normalization/unit_normalization.py,sha256=zFHpHet8htHl7sLXQJ_nFecyZLU3fMrspq5V8STYgQs,2634
371
371
  tf_keras/src/layers/pooling/__init__.py,sha256=6WvDC0BWmYKwJlurf_1QFRNAHW-kqEy4NI63K4XWzVc,2590
372
372
  tf_keras/src/layers/pooling/average_pooling1d.py,sha256=dIHOp6wvO9JfQ9SzndiElI-oc_TEh4rCNMVBh_zBRB8,4998
@@ -606,7 +606,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
606
606
  tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
607
607
  tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
608
608
  tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
609
- tf_keras_nightly-2.19.0.dev2025020210.dist-info/METADATA,sha256=VQfeHpe_K48x8n1mBvrU1SV8CJrrA1Sno-VA2dRpt7k,1857
610
- tf_keras_nightly-2.19.0.dev2025020210.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
611
- tf_keras_nightly-2.19.0.dev2025020210.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
- tf_keras_nightly-2.19.0.dev2025020210.dist-info/RECORD,,
609
+ tf_keras_nightly-2.19.0.dev2025020410.dist-info/METADATA,sha256=Zhq3yMtFTp5s1UpQp9TDA1ku77Aom0UnPxB8HUJcyB4,1857
610
+ tf_keras_nightly-2.19.0.dev2025020410.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
611
+ tf_keras_nightly-2.19.0.dev2025020410.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
+ tf_keras_nightly-2.19.0.dev2025020410.dist-info/RECORD,,