torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +10 -5
- torchax/__init__.py +92 -65
- torchax/amp.py +14 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +14 -0
- torchax/decompositions.py +14 -0
- torchax/device_module.py +14 -0
- torchax/export.py +14 -0
- torchax/flax.py +14 -0
- torchax/interop.py +44 -31
- torchax/mesh_util.py +14 -0
- torchax/ops/__init__.py +14 -0
- torchax/ops/jaten.py +3985 -3686
- torchax/ops/jax_reimplement.py +14 -0
- torchax/ops/jc10d.py +14 -0
- torchax/ops/jimage.py +14 -0
- torchax/ops/jlibrary.py +14 -0
- torchax/ops/jtorch.py +364 -309
- torchax/ops/jtorchvision_nms.py +14 -0
- torchax/ops/mappings.py +26 -4
- torchax/ops/op_base.py +14 -0
- torchax/ops/ops_registry.py +14 -0
- torchax/tensor.py +38 -13
- torchax/train.py +112 -97
- torchax/types.py +14 -0
- torchax/util.py +14 -0
- torchax/view.py +14 -0
- torchax-0.0.10.dev20251116.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251116.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251116.dist-info/licenses/LICENSE +201 -0
- torchax/configuration.py +0 -30
- torchax/environment.py +0 -1
- torchax/tf_integration.py +0 -119
- torchax-0.0.6.dist-info/METADATA +0 -307
- torchax-0.0.6.dist-info/RECORD +0 -33
- torchax-0.0.6.dist-info/licenses/LICENSE +0 -28
- {torchax-0.0.6.dist-info → torchax-0.0.10.dev20251116.dist-info}/WHEEL +0 -0
torchax/ops/jtorchvision_nms.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """
 Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
 """
torchax/ops/mappings.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from jax import dlpack as jaxdl
 import jax.numpy as jnp
 import numpy
@@ -6,6 +20,14 @@ import torch.func
 import torch.utils.dlpack as torchdl
 import torch.utils._mode_utils as mode_utils
 
+NUMPY_UNSUPPORTED_DTYPES = {
+    torch.bfloat16: jnp.bfloat16,
+    torch.float8_e4m3fn: jnp.float8_e4m3fn,
+    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+    torch.float8_e5m2: jnp.float8_e5m2,
+    torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+}
+
 
 def t2j(t, use_dlpack=True):
   is_bool = False
@@ -28,14 +50,14 @@ def t2j(t, use_dlpack=True):
   if res is None:
     # https://github.com/google/jax/issues/7657
     # https://github.com/google/jax/issues/17784
-    if t.dtype
+    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
       nparray = (t.cpu().detach().to(torch.float32).numpy()
-      ) #
+      ) # handle dtypes not supported by numpy
     else:
       nparray = t.cpu().detach().numpy()
     res = jnp.asarray(nparray)
-    if t.dtype
-      res = res.astype(
+    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+      res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
 
   if is_bool:
     res = res.astype(jnp.bool_)
torchax/ops/op_base.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 import jax
 import jax.numpy as jnp
torchax/ops/ops_registry.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import dataclasses
 import logging
 from torchax.types import JaxCallable, TorchCallable
torchax/tensor.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import threading
 import logging
 import sys
@@ -357,16 +371,26 @@ class Environment(contextlib.ContextDecorator):
 
     _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
     self._property = threading.local()
-    self.
-
-        mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
-    ]
+    self._initial_content = RuntimeProperty(
+        mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
 
   @property
   def param(self):
+    if not hasattr(self._property, 'content'):
+      self._property.content = [
+          self._initial_content
+      ]
     return self._property.content[-1]
 
   def manual_seed(self, key):
+    if isinstance(key, torch.Tensor):
+      assert key.ndim == 0, 'manual seed can only take scalars'
+      assert not key.dtype.is_floating_point, 'manual seed can only be integers'
+
+      if isinstance(key, Tensor):
+        key = key._elem
+      else:
+        key = key.item()
     jax_key = jax.random.PRNGKey(key)
     new_prop = self.param.override(prng=jax_key)
     self._property.content.append(new_prop)
@@ -469,12 +493,12 @@ class Environment(contextlib.ContextDecorator):
       arr = self.t2j_copy(the_tensor)
       res = Tensor(arr, self, the_tensor.requires_grad)
 
-    if new_dtype is not None and new_dtype !=
-      if isinstance(
+    if new_dtype is not None and new_dtype != res.dtype:
+      if isinstance(res, Tensor):
         res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
       else:
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-          return
+          return res.to(device=new_device, dtype=new_dtype)
     return res
 
   def get_and_rotate_prng_key(self,
@@ -634,14 +658,14 @@ class Environment(contextlib.ContextDecorator):
 
   def t2j_iso(self, torchtensors):
     """Convert torchax Tensor to jax array.
-
+
     This function will not copy, will just unwrap the inner jax array out.
     Note: iso is short for "isomorphic"
     """
 
     def to_jax(x):
       if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
-          x, Tensor):
+          x, Tensor) and not isinstance(x, View):
         if x.squeeze().ndim == 0:
           return x.item()
       if isinstance(
@@ -667,7 +691,7 @@ class Environment(contextlib.ContextDecorator):
 
   def j2t_iso(self, jaxarray):
     """Convert jax array to torchax Tensor.
-
+
     This function will not copy, will just wrap the jax array with a torchax Tensor
     Note: iso is short for "isomorphic"
     """
@@ -676,7 +700,7 @@ class Environment(contextlib.ContextDecorator):
 
   def j2t_copy(self, args):
     """Convert torch.Tensor in cpu to a jax array
-
+
     This might involves copying the data (depending if dlpack is enabled)
     """
     return torch_pytree.tree_map_only(
@@ -686,7 +710,7 @@ class Environment(contextlib.ContextDecorator):
 
   def t2j_copy(self, args):
     """Convert jax array to torch.Tensor in cpu.
-
+
     This might involves copying the data (depending if dlpack is enabled)
     """
     return torch_pytree.tree_map_only(
@@ -694,13 +718,14 @@ class Environment(contextlib.ContextDecorator):
         lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
         args)
 
-  def override_op_definition(self, op_to_override, op_impl):
+  def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
     self._ops[op_to_override] = ops_registry.Operator(
         op_to_override,
         op_impl,
        is_jax_function=False,
        is_user_defined=True,
        needs_env=False,
+       is_view_op=is_view_op,
     )
 
   @contextlib.contextmanager
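
Two behavioral notes on the tensor.py changes: the per-thread property list is now created lazily, so threads other than the one that constructed the Environment get their own copy of the initial RuntimeProperty on first access, and manual_seed now accepts a 0-dim integer tensor in addition to a plain int. A small sketch of the latter; the seed value is illustrative:

# Sketch: both calls below should now be equivalent ways to seed the
# torchax environment. A 0-dim floating-point tensor would be rejected
# by the new assertions.
import torch
import torchax

env = torchax.default_env()
env.manual_seed(42)                # plain int, as before
env.manual_seed(torch.tensor(42))  # 0-dim integer tensor, newly accepted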
torchax/train.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import collections
 import functools
 import torch
@@ -12,106 +26,107 @@ mark_sharding = torch_view(jax.lax.with_sharding_constraint)
 
 
 def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
- [39 removed lines, old 15-53, not preserved in this view]
+  """Make a function that do one train step given model and loss.
+
+  model_fn: a function representing the model's forward:
+    i.e. has signature Callable[weights, buffers, args] -> result. Where,
+    weights is a pytree of trainable parameters
+    buffers is a pytree of non-trainable parameters / constants
+    args is the input data loaded from the data set
+    result is the return value of the model
+  loss_fn: a function to compute loss.
+    i.e. it has signature of Callable[result, label] -> loss
+    where, result is what model_fn returned
+    loss is loaded from the dataloader.
+  optax_optimizer: the optimizer from optax library. for example, optax.adam
+  remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
+    to do gradient checkpointing. If None, then it means checkpoint everything.
+  """
+  env = torchax.default_env()
+
+  def loss(weights, buffers, args, label):  # inputs are XLATensor
+    with env, jax.named_scope("compute_loss"):
+      res = model_fn(weights, buffers, args)
+      l = loss_fn(res, label)
+      return l
+
+  # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+  grad_fn = interop.jax_value_and_grad(loss)
+
+  def step(weights, buffers, opt_state, args, label):  # inputs are array
+    with jax.named_scope("compute_gradient"):
+      loss, gradient = grad_fn(weights, buffers, args, label)
+
+    with jax.named_scope("optimizer_updates"):
+      updates, opt_state = interop.call_jax(
+          optax_optimizer.update, gradient, opt_state, weights
+      )
+      weights = interop.call_jax(optax.apply_updates, weights, updates)
+    return loss, weights, opt_state
+
+  # TODO: apply jax.jit so the user don't have to.
+  return step
 
 
 class Container:
- [1 removed line, old 57, not preserved in this view]
+  pass
 
 
 class ScannedModule(torch.nn.Module):
- [57 removed lines, old 61-117, not preserved in this view]
+  def __init__(self, module_list, checkpoint_policy=None):
+    super().__init__()
+
+    self.c = None
+    assert module_list
+    self.c = Container()
+    self.c.one_mod = module_list[0]
+    self.checkpoint_policy = checkpoint_policy
+
+    weights = self._stack_layer_weights(module_list)
+    self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+    self.params = torch.nn.ParameterDict(
+        {self._param_name_new(k): v for k, v in weights.items()}
+    )
+
+  def _stack_layer_weights(self, module_list):
+    # Create weights such that, for every [n, m] weights
+    # becomes [k, n, m] where k is number of layer
+    # i.e. stacking layer weights together
+    temp = collections.defaultdict(list)
+    for m in module_list:
+      for k, v in m.state_dict().items():
+        temp[k].append(v)
+    res = {k: torch.stack(v) for k, v in temp.items()}
+    return res
+
+  def _param_name_new(self, old):
+    return "___".join(old.split("."))
+
+  def _param_name_old(self, new):
+    return ".".join(new.split("___"))
+
+  def forward(self, *args, **kwargs):
+    assert not kwargs
+    weights = {
+        k: self.params[self._param_name_new(k)]
+        for k in self.layer_weights_keys
+    }
+    scan = interop.torch_view(jax.lax.scan)
+
+    def eval_one_layer(args, weight):
+      # unpack args
+      h, *rest = args
+      newh = torch.func.functional_call(self.c.one_mod, weight, args)
+      # next layer's input; and residual to be added to list
+      return (newh, *rest), None
+
+    _eval_one_layer = interop.gradient_checkpoint(
+        eval_one_layer,
+        kwargs={"policy": self.checkpoint_policy},
+    )
+    h, _ = scan(
+        _eval_one_layer,
+        args,
+        weights,
+    )
+    return h[0]
torchax/types.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Callable, Any, Union, ParamSpec, TypeAlias
 import torch
 import jax
torchax/util.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any, Callable
 
 
torchax/view.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.utils._pytree as torch_pytree
 import jax