keras-nightly 3.14.0.dev2026010204__py3-none-any.whl → 3.14.0.dev2026010304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/src/backend/jax/nn.py +26 -9
- keras/src/random/seed_generator.py +2 -2
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/RECORD +7 -7
- {keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py
CHANGED
|
@@ -1471,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
|
|
|
1471
1471
|
# Only support at least Ampere
|
|
1472
1472
|
if not check_compute_capability("8.0"):
|
|
1473
1473
|
raise RuntimeError("Require at least Ampere arch to run")
|
|
1474
|
-
|
|
1474
|
+
|
|
1475
|
+
# Inspect inputs of `check_layout`
|
|
1475
1476
|
check_layout_params = list(
|
|
1476
1477
|
inspect.signature(check_layout).parameters.keys()
|
|
1477
1478
|
)
|
|
1478
1479
|
for known_param in ("query", "key", "value", "bias", "layout"):
|
|
1479
1480
|
check_layout_params.remove(known_param)
|
|
1480
1481
|
# Defaults to `None` when not specified.
|
|
1481
|
-
|
|
1482
|
+
check_layout_kwargs = {key: None for key in check_layout_params}
|
|
1482
1483
|
check_layout(
|
|
1483
|
-
query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
|
|
1484
|
-
)
|
|
1485
|
-
check_is_flash_attention(
|
|
1486
1484
|
query,
|
|
1487
1485
|
key,
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1486
|
+
value,
|
|
1487
|
+
bias,
|
|
1488
|
+
layout=_normalize_layout("BTNH"),
|
|
1489
|
+
**check_layout_kwargs,
|
|
1492
1490
|
)
|
|
1491
|
+
|
|
1492
|
+
# Inspect inputs of `check_is_flash_attention`
|
|
1493
|
+
check_is_flash_attention_params = inspect.signature(
|
|
1494
|
+
check_is_flash_attention
|
|
1495
|
+
).parameters
|
|
1496
|
+
check_is_flash_attention_kwargs = {
|
|
1497
|
+
"query": query,
|
|
1498
|
+
"key": key,
|
|
1499
|
+
"value": value,
|
|
1500
|
+
"layout": _normalize_layout("BTNH"),
|
|
1501
|
+
"cudnn_version": cudnn_version,
|
|
1502
|
+
"has_bias": bias is not None,
|
|
1503
|
+
"is_training": False,
|
|
1504
|
+
}
|
|
1505
|
+
# Remove unsupported arguments
|
|
1506
|
+
for param in list(check_is_flash_attention_kwargs.keys()):
|
|
1507
|
+
if param not in check_is_flash_attention_params:
|
|
1508
|
+
check_is_flash_attention_kwargs.pop(param)
|
|
1509
|
+
check_is_flash_attention(**check_is_flash_attention_kwargs)
|
|
1493
1510
|
return True
|
|
1494
1511
|
except:
|
|
1495
1512
|
if raise_error:
|
|
@@ -29,7 +29,7 @@ class SeedGenerator:
|
|
|
29
29
|
a local `StateGenerator` with either a deterministic or random initial
|
|
30
30
|
state.
|
|
31
31
|
|
|
32
|
-
Remark concerning the JAX
|
|
32
|
+
Remark concerning the JAX backend: Note that the use of a local
|
|
33
33
|
`StateGenerator` as seed argument is required for JIT compilation of
|
|
34
34
|
RNG with the JAX backend, because the use of global state is not
|
|
35
35
|
supported.
|
|
@@ -111,7 +111,7 @@ class SeedGenerator:
|
|
|
111
111
|
return new_seed_value
|
|
112
112
|
|
|
113
113
|
def get_config(self):
|
|
114
|
-
return {"seed": self._initial_seed}
|
|
114
|
+
return {"seed": self._initial_seed, "name": self.name}
|
|
115
115
|
|
|
116
116
|
@classmethod
|
|
117
117
|
def from_config(cls, config):
|
keras/src/version.py
CHANGED
{keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/RECORD
RENAMED
|
@@ -128,7 +128,7 @@ keras/regularizers/__init__.py,sha256=542Shphw7W8h4Dyf2rmqMKUECVZ8IVBvN9g1LWhz-b
|
|
|
128
128
|
keras/saving/__init__.py,sha256=KvL2GZxjvgFgEhvEnkvqjIR9JSNHKz-NWZacXajsjLI,1298
|
|
129
129
|
keras/src/__init__.py,sha256=Gi4S7EiCMkE03PbdGNpFdaUYySWDs_FcAJ8Taz9Y1BE,684
|
|
130
130
|
keras/src/api_export.py,sha256=gXOkBOnmscV013WAc75lc4Up01-Kkg9EylIAT_QWctg,1173
|
|
131
|
-
keras/src/version.py,sha256=
|
|
131
|
+
keras/src/version.py,sha256=g0zaAy91Gg1cY-Ey6cBCCgMseDoEXzoyCPMqyyGWO-g,204
|
|
132
132
|
keras/src/activations/__init__.py,sha256=0nL3IFDB9unlrMz8ninKOWo-uCHasTUpTo1tXZb2u44,4433
|
|
133
133
|
keras/src/activations/activations.py,sha256=mogPggtp4CGldI3VOPNmesRxp6EbiR1_i4KLGaVwzL8,17614
|
|
134
134
|
keras/src/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -170,7 +170,7 @@ keras/src/backend/jax/image.py,sha256=RiYIalbIaUQdDOGpDZUBk5KNsX94Xqg7iyXGATN9V5
|
|
|
170
170
|
keras/src/backend/jax/layer.py,sha256=o6CicT06udwamTRQIjNSDLZLyYHFzBXNbxewXgWe0iw,308
|
|
171
171
|
keras/src/backend/jax/linalg.py,sha256=LDaLZYz49ChE2kJR3YpaM9xuwusvd3krV7nNAAazTWA,2642
|
|
172
172
|
keras/src/backend/jax/math.py,sha256=1IEDpdoF8e5ltu3D4wbDQuihzvJHhMXz8W9Z_E-eJqU,9391
|
|
173
|
-
keras/src/backend/jax/nn.py,sha256=
|
|
173
|
+
keras/src/backend/jax/nn.py,sha256=mrRawNvf9EWe8rdTwK_Auz6xdLkVG6hH0nIAP7hyUDE,60271
|
|
174
174
|
keras/src/backend/jax/numpy.py,sha256=SMa6dH1n7v04SsnEkevCWBqmzj7Ed8TmBASOSrEQIMM,38619
|
|
175
175
|
keras/src/backend/jax/optimizer.py,sha256=5DeXQHcYmUI6F9i1m1VHn3sBt4LEStOeBXnKdESevLM,4134
|
|
176
176
|
keras/src/backend/jax/random.py,sha256=Uk2huGIk_dlzMrx5eDVrrr2TeCEMitn2vr4yzA0NXjs,3594
|
|
@@ -536,7 +536,7 @@ keras/src/quantizers/quantizers.py,sha256=BDD3vi_15lmOY_ybI7oQDgINYlM9CF0QSQuP6k
|
|
|
536
536
|
keras/src/quantizers/utils.py,sha256=i6e5MobXrQeKA6zFenjzUNoDDWRGL9bcfgdbE_-0IeM,672
|
|
537
537
|
keras/src/random/__init__.py,sha256=BmXVYPzxbhADohoLtAEEzB3cesP7YBFDsp1qc6BWWlg,420
|
|
538
538
|
keras/src/random/random.py,sha256=bUADZIVDuCghwIWTk0qBxXTxUdiNGWIdsRi8QJ3ePg4,17581
|
|
539
|
-
keras/src/random/seed_generator.py,sha256
|
|
539
|
+
keras/src/random/seed_generator.py,sha256=-a0CQa7--Xt0g0nfdjLmUzlFElY9Y838VcCx05AcllY,5655
|
|
540
540
|
keras/src/regularizers/__init__.py,sha256=GzK9FTKL2Xxd5H55GfG9gxDqt4eZoVHFWICgb2VW8qM,1731
|
|
541
541
|
keras/src/regularizers/regularizers.py,sha256=urXNmMGuqHT7lOmS-yQPl3At3Ny-37Xlo389ErCg84A,11799
|
|
542
542
|
keras/src/saving/__init__.py,sha256=vnrtfvnzW7Gwtxe5COhaMoEnVYB5iDe2YlqJ-DvqFIk,614
|
|
@@ -614,7 +614,7 @@ keras/utils/bounding_boxes/__init__.py,sha256=jtvQll4u8ZY0Z96HwNhP1nxWEG9FM3gI-6
|
|
|
614
614
|
keras/utils/legacy/__init__.py,sha256=oSYZz6uS8UxSElRaaJYWJEoweJ4GAasZjnn7fNaOlog,342
|
|
615
615
|
keras/visualization/__init__.py,sha256=UKWmiy6sps4SWlmQi9WX8_Z53cPpLlphz2zIeHdwJpQ,722
|
|
616
616
|
keras/wrappers/__init__.py,sha256=QkS-O5K8qGS7C3sytF8MpmO6PasATpNVGF8qtb7Ojsw,407
|
|
617
|
-
keras_nightly-3.14.0.
|
|
618
|
-
keras_nightly-3.14.0.
|
|
619
|
-
keras_nightly-3.14.0.
|
|
620
|
-
keras_nightly-3.14.0.
|
|
617
|
+
keras_nightly-3.14.0.dev2026010304.dist-info/METADATA,sha256=TRFo4Hl7iWbmC4VRcNIC6z3OZn-UvvN1VOu_XgjYt7I,6339
|
|
618
|
+
keras_nightly-3.14.0.dev2026010304.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
619
|
+
keras_nightly-3.14.0.dev2026010304.dist-info/top_level.txt,sha256=ptcw_-QuGZ4ZDjMdwi_Z0clZm8QAqFdvzzFnDEOTs9o,6
|
|
620
|
+
keras_nightly-3.14.0.dev2026010304.dist-info/RECORD,,
|
{keras_nightly-3.14.0.dev2026010204.dist-info → keras_nightly-3.14.0.dev2026010304.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
File without changes
|