torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/train.py CHANGED
@@ -7,14 +7,11 @@ from torchax import interop
 from torchax.interop import torch_view, jax_view
 import optax

-
 remat = torch_view(jax.remat)
 mark_sharding = torch_view(jax.lax.with_sharding_constraint)


-def make_train_step(model_fn,
-                    loss_fn, optax_optimizer,
-                    remat_policy=None):
+def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
   """Make a function that do one train step given model and loss.

   model_fn: a function representing the model's forward:
@@ -32,7 +29,8 @@ def make_train_step(model_fn,
     to do gradient checkpointing. If None, then it means checkpoint everything.
   """
   env = torchax.default_env()
-  def loss(weights, buffers, args, label): # inputs are XLATensor
+
+  def loss(weights, buffers, args, label):  # inputs are XLATensor
     with env, jax.named_scope('compute_loss'):
       res = model_fn(weights, buffers, args)
       l = loss_fn(res, label)
@@ -41,26 +39,24 @@ def make_train_step(model_fn,
     loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
   grad_fn = interop.jax_value_and_grad(loss)

-  def step(weights, buffers, opt_state, args, label): #inputs are array
+  def step(weights, buffers, opt_state, args, label):  #inputs are array
     with jax.named_scope('compute_gradient'):
-      loss, gradient = grad_fn(weights, buffers, args, label)
+      loss, gradient = grad_fn(weights, buffers, args, label)

     with jax.named_scope("optimizer_updates"):
-      updates, opt_state = interop.call_jax(
-          optax_optimizer.update,
-          gradient, opt_state, weights)
-      weights = interop.call_jax(optax.apply_updates, weights, updates)
+      updates, opt_state = interop.call_jax(optax_optimizer.update, gradient,
+                                            opt_state, weights)
+      weights = interop.call_jax(optax.apply_updates, weights, updates)
     return loss, weights, opt_state

   # TODO: apply jax.jit so the user don't have to.
   return step

-
-
-
+
 class Container:
   pass

+
 class ScannedModule(torch.nn.Module):

   def __init__(self, module_list, checkpoint_policy=None):
@@ -75,9 +71,9 @@ class ScannedModule(torch.nn.Module):
     weights = self._stack_layer_weights(module_list)
     self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
     self.params = torch.nn.ParameterDict({
-      self._param_name_new(k): v for k, v in weights.items()
+        self._param_name_new(k): v for k, v in weights.items()
     })
-
+
   def _stack_layer_weights(self, module_list):
     # Create weights such that, for every [n, m] weights
     # becomes [k, n, m] where k is number of layer
@@ -85,36 +81,37 @@ class ScannedModule(torch.nn.Module):
     temp = collections.defaultdict(list)
     for m in module_list:
       for k, v in m.state_dict().items():
-        temp[k].append(v)
+        temp[k].append(v)
     res = {k: torch.stack(v) for k, v in temp.items()}
     return res

-
   def _param_name_new(self, old):
-    return '___'.join(old.split('.'))
+    return '___'.join(old.split('.'))

   def _param_name_old(self, new):
-    return '.'.join(new.split('___'))
+    return '.'.join(new.split('___'))

   def forward(self, *args, **kwargs):
-    assert not kwargs
-    weights = {k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys}
-    scan = interop.torch_view(jax.lax.scan)
-
-    def eval_one_layer(args, weight):
-      # unpack args
-      h, *rest = args
-      newh = torch.func.functional_call(self.c.one_mod, weight, args)
-      # next layer's input; and residual to be added to list
-      return (newh, *rest), None
-
-    _eval_one_layer = interop.gradient_checkpoint(
-        eval_one_layer,
-        kwargs={'policy': self.checkpoint_policy},
-    )
-    h, _ = scan(
-        _eval_one_layer,
-        args,
-        weights,
-    )
-    return h[0]
+    assert not kwargs
+    weights = {
+        k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys
+    }
+    scan = interop.torch_view(jax.lax.scan)
+
+    def eval_one_layer(args, weight):
+      # unpack args
+      h, *rest = args
+      newh = torch.func.functional_call(self.c.one_mod, weight, args)
+      # next layer's input; and residual to be added to list
+      return (newh, *rest), None
+
+    _eval_one_layer = interop.gradient_checkpoint(
+        eval_one_layer,
+        kwargs={'policy': self.checkpoint_policy},
+    )
+    h, _ = scan(
+        _eval_one_layer,
+        args,
+        weights,
+    )
+    return h[0]
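
A minimal usage sketch of the make_train_step API shown above. Only make_train_step's signature and the step function's inputs/outputs come from the diff; the toy model, the functional forward wrapper, and the interop.call_jax call used to initialize the optimizer state are illustrative assumptions.

# Sketch: one-step training loop built around make_train_step (hypothetical
# model and data; not code from the package).
import optax
import torch
import torchax
from torchax import interop
from torchax.train import make_train_step

env = torchax.default_env()
with env:
  model = torch.nn.Linear(8, 2)               # placeholder model
  weights = dict(model.named_parameters())
  buffers = dict(model.named_buffers())

  def model_fn(weights, buffers, args):
    # Forward pass in functional form, as make_train_step expects.
    return torch.func.functional_call(model, {**weights, **buffers}, args)

  optimizer = optax.adam(1e-3)
  # Assumption: optimizer state can be built through interop.call_jax, the
  # same bridge the diff uses for optimizer.update.
  opt_state = interop.call_jax(optimizer.init, weights)

  step = make_train_step(model_fn, torch.nn.functional.mse_loss, optimizer)
  # Per the TODO in the diff, the caller is expected to jit `step` themselves;
  # it returns (loss, new_weights, new_opt_state).
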
torchax/util.py ADDED
@@ -0,0 +1,88 @@
+from typing import Any, Callable
+
+
+def partition(original: list[Any],
+              func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]:
+  """Partitions elements into two parallel lists based on a predicate function.
+
+  Iterates through the 'original' list, applying 'func' to each element 'a'.
+  - If `func(a)` returns True, 'a' is appended to the first list ('truthy')
+    and `None` is appended to the second list ('falsy').
+  - If `func(a)` returns False, `None` is appended to the first list ('truthy')
+    and 'a' is appended to the second list ('falsy').
+
+  The result is two lists of the same length as the 'original' list, acting
+  as parallel representations of the partitioned elements, using `None` as
+  placeholders.
+
+  This is useful when we want to mark a group of elements as static (via passing
+  static_argnums) or donated (via donate_argnums) when combining with jax.jit
+  and friends.
+
+  Args:
+    original: The list of elements to partition.
+    func: A callable (function or lambda) that accepts an element from
+      'original' and returns a boolean value (True or False).
+
+  Returns:
+    A tuple containing two lists (`truthy`, `falsy`), both of the same
+    length as `original`:
+    - The first list contains elements `x` where `func(x)` was True, and
+      `None` otherwise.
+    - The second list contains elements `x` where `func(x)` was False, and
+      `None` otherwise.
+
+  Example:
+    >>> def is_even(n): return n % 2 == 0
+    >>> nums = [1, 2, 3, 4, 5, 6]
+    >>> truthy_list, falsy_list = partition(nums, is_even)
+    >>> truthy_list
+    [None, 2, None, 4, None, 6]
+    >>> falsy_list
+    [1, None, 3, None, 5, None]
+  """
+  truthy = []
+  falsy = []
+  for a in original:
+    t, f = (a, None) if func(a) else (None, a)
+    truthy.append(t)
+    falsy.append(f)
+  return truthy, falsy
+
+
+def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
+  """Merges two lists element-wise, prioritizing non-None elements from list1.
+
+  Creates a new list where each element is taken from the corresponding position
+  in 'list1', unless that element is None, in which case the element from the
+  corresponding position in 'list2' is used. Assumes both lists have the
+  same length.
+
+  Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate
+
+  Args:
+    list1: The primary list. Its elements are preferred unless they are None.
+    list2: The secondary list. Its elements are used as fallbacks when the
+      corresponding element in list1 is None.
+
+  Returns:
+    A new list representing the merged result.
+
+  Raises:
+    AssertionError: If 'list1' and 'list2' do not have the same length.
+
+  Example:
+    >>> l1 = [1, None, 3, None]
+    >>> l2 = [None, 2, None, 4]
+    >>> merge(l1, l2)
+    [1, 2, 3, 4]
+    >>> l3 = [None, 'b', None]
+    >>> l4 = ['a', None, 'c']
+    >>> merge(l3, l4)
+    ['a', 'b', 'c']
+  """
+  assert len(list1) == len(list2)
+  res = []
+  for a, b in zip(list1, list2):
+    res.append(b if a is None else a)
+  return res
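
The partition/merge pair round-trips a mixed argument list, which is what makes it convenient for marking some arguments static or donated before jax.jit, as the docstrings note. A small sketch (the argument list and predicate are made up):

from torchax.util import partition, merge

args = ['adam', 0.9, 'cosine', 128]   # mixed static (str) and dynamic values
static, dynamic = partition(args, lambda a: isinstance(a, str))
# static  == ['adam', None, 'cosine', None]
# dynamic == [None, 0.9, None, 128]

# The invariant stated in merge()'s docstring: the split is lossless.
assert merge(static, dynamic) == args
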
torchax/view.py ADDED
@@ -0,0 +1,377 @@
+import torch
+import torch.utils._pytree as torch_pytree
+import jax
+from enum import Enum
+from typing import Union, List, Tuple, Optional, Any, cast
+from abc import ABC, abstractmethod
+
+# Reference to original PyTorch native functions
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+
+
+class ViewInfoType(Enum):
+  INVALID = 0
+  NARROW = 1
+  NO_OP = 2
+  PERMUTE = 3
+  RESHAPE = 4
+  RESIZE = 5
+  SELECT = 6
+  AS_STRIDED = 7
+  DIAGONAL = 8
+
+
+class ViewInfo(ABC):
+  """
+  Abstract base class for all view operations.
+  Defines the interface for applying and updating view transformations.
+  """
+
+  def __init__(
+      self,
+      view_info_type: ViewInfoType = ViewInfoType.INVALID,
+  ):
+    """
+    Initialize a ViewInfo object.
+
+    Args:
+      view_info_type: The type of view operation
+    """
+    self.view_info_type = view_info_type
+
+  @abstractmethod
+  def update_tensor(self, new_value: jax.Array,
+                    jax_array: jax.Array) -> jax.Array:
+    """
+    Apply this view transformation to a JAX array and update its value.
+
+    Args:
+      new_value: The new values to set in the view
+      jax_array: The parent array to update
+
+    Returns:
+      Updated array
+    """
+    pass
+
+  @abstractmethod
+  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+    """
+    Apply this view transformation to a JAX array.
+
+    Args:
+      jax_array: The array to transform
+
+    Returns:
+      Transformed array
+    """
+    pass
+
+  @abstractmethod
+  def calculate_output_shape(self, source: jax.Array) -> List[int]:
+    """
+    Calculate the resulting shape after applying this view.
+
+    Args:
+      source: Original jax array before transformation
+
+    Returns:
+      Resulting shape after transformation
+    """
+    pass
+
+
+class NarrowInfo(ViewInfo):
+  """
+  Represents a slicing operation on a tensor.
+  Handles operations like tensor[1:3, :, 2:5:2].
+  """
+
+  def __init__(self, slices: Union[slice, Tuple[slice]]) -> None:
+    """
+    Args:
+      slices: The slice(s) to apply to the tensor.
+        E.g. jax_array.at[slices] will return the transformed tensor.
+    """
+    super().__init__(ViewInfoType.NARROW)
+    self.slices = slices
+
+  def __eq__(self, other: object) -> bool:
+    if not isinstance(other, NarrowInfo):
+      return False
+    return self.slices == other.slices
+
+  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+    try:
+      return jax_array[self.slices]
+    except IndexError as e:
+      raise IndexError("Invalid slice operation") from e
+
+  def update_tensor(self, new_value: jax.Array,
+                    jax_array: jax.Array) -> jax.Array:
+    return jax_array.at[self.slices].set(new_value)
+
+  def calculate_output_shape(self, source: jax.Array) -> List[int]:
+    return source[self.slices].shape
+
+
+class SelectInfo(ViewInfo):
+  """
+  Represents a selection operation on a tensor.
+  Typically used for indexing operations that select specific elements.
+  """
+
+  def __init__(self,
+               dim: int = 0,
+               start: int = 0,
+               end: int = 0,
+               stride: int = 0) -> None:
+    super().__init__(ViewInfoType.SELECT)
+    self.dim: int = dim
+    self.start: int = start
+    self.end: int = end
+    self.stride: int = stride
+
+  def __eq__(self, other: object) -> bool:
+    if not isinstance(other, SelectInfo):
+      return False
+    return (self.dim == other.dim and self.start == other.start and
+            self.end == other.end and self.stride == other.stride)
+
+  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("SelectInfo.apply not implemented")
+
+  def update_tensor(self, new_value: jax.Array,
+                    jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("SelectInfo.update not implemented")
+
+  def calculate_output_shape(self, source: jax.Array) -> List[int]:
+    raise NotImplementedError(
+        "SelectInfo.calculate_output_shape not implemented")
+
+
+class AsStridedInfo(ViewInfo):
+  """
+  Information for as_strided operations.
+  """
+
+  def __init__(self, stride: List[int], offset: int = 0) -> None:
+    super().__init__(ViewInfoType.AS_STRIDED)
+    self.stride: List[int] = stride
+    self.offset: int = offset
+
+  def __eq__(self, other: object) -> bool:
+    if not isinstance(other, AsStridedInfo):
+      return False
+    return self.offset == other.offset and self.stride == other.stride
+
+  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("AsStridedInfo.apply not implemented")
+
+  def update_tensor(self, new_value: jax.Array,
+                    jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("AsStridedInfo.update not implemented")
+
+  def calculate_output_shape(self, source: jax.Array) -> List[int]:
+    raise NotImplementedError(
+        "AsStridedInfo.calculate_output_shape not implemented")
+
+
+class DiagonalInfo(ViewInfo):
+  """
+  Information for diagonal operations.
+  Extracts diagonal elements from a tensor.
+  """
+
+  def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
+    """
+    Args:
+      offset: Offset from the main diagonal
+      dim1: First dimension for diagonal extraction
+      dim2: Second dimension for diagonal extraction
+    """
+    super().__init__(ViewInfoType.DIAGONAL)
+    self.offset: int = offset
+    self.dim1: int = dim1
+    self.dim2: int = dim2
+
+  def __eq__(self, other: object) -> bool:
+    if not isinstance(other, DiagonalInfo):
+      return False
+    return (self.offset == other.offset and self.dim1 == other.dim1 and
+            self.dim2 == other.dim2)
+
+  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("DiagonalInfo.apply not implemented")
+
+  def update_tensor(self, new_value: jax.Array,
+                    jax_array: jax.Array) -> jax.Array:
+    raise NotImplementedError("DiagonalInfo.update not implemented")
+
+  def calculate_output_shape(self, source: jax.Array) -> List[int]:
+    raise NotImplementedError(
+        "DiagonalInfo.calculate_output_shape not implemented")
+
+
+class View(torch.Tensor):
+  """
+  A View is a reference to another Tensor or another View,
+  with a transformation applied to it.
+  """
+
+  @staticmethod
+  def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo,
+              env: Any) -> "View":
+    """
+    Args:
+      parent: Parent tensor or view
+      view_info: Information about the view transformation
+      env: Environment for tensor operations
+    """
+    shape = view_info.calculate_output_shape(parent.jax())
+    return torch.Tensor._make_wrapper_subclass(
+        cls,
+        shape,
+        device="meta",
+        dtype=parent.dtype,
+        requires_grad=False,
+    )
+
+  def __init__(self, parent: Union["torchax.Tensor", "View"],
+               view_info: ViewInfo, env: Any) -> None:
+    super().__init__()
+    self.parent = parent
+    self.view_info = view_info
+    self._env = env
+
+  def get_transformation_chain(self) -> List[ViewInfo]:
+    """
+    Get all view transformations from the source tensor to this view.
+    """
+    if isinstance(self.parent, View):
+      transformations = self.parent.get_transformation_chain()
+      transformations.append(self.view_info)
+      return transformations
+    else:
+      return [self.view_info]
+
+  __torch_function__ = torch._C._disabled_torch_function_impl
+
+  def source_jax(self) -> jax.Array:
+    """
+    Returns the source tensor.
+    """
+    if isinstance(self.parent, View):
+      return self.parent.source_jax()
+    else:
+      return self.parent.jax()
+
+  def replace_source_jax(self, new_value: jax.Array) -> None:
+    """
+    Update the source tensor with new values.
+    """
+    if isinstance(self.parent, View):
+      self.parent.replace_source_jax(new_value)
+    else:
+      assert new_value.shape == self.parent._elem.shape
+      self.parent._elem = new_value
+
+  def torch(self) -> "torchax.Tensor":
+    """
+    Returns a Torchax tensor representing this view after all transformations
+    """
+    from torchax.tensor import Tensor
+
+    return Tensor(self.jax(), self._env)
+
+  def update(
+      self,
+      new_values: Union[jax.Array, "View", "torchax.Tensor"],
+      view_infos: Optional[List[ViewInfo]] = None,
+  ) -> None:
+    """
+    Update this view with new values, propagating changes back to source.
+    If view_infos is None, it will use the transformation chain
+    from the source tensor.
+    """
+    if view_infos is None:
+      view_infos = self.get_transformation_chain()
+
+    # Get the source JAX array
+    source_array = self.source_jax()
+
+    # Get the new value
+    from torchax.tensor import Tensor
+
+    if isinstance(new_values, View) or isinstance(new_values, Tensor):
+      new_values = new_values.jax()
+
+    # Apply all view transformations to the source array
+    # And store intermediate values
+    intermediate_values = [source_array]
+    for view_info in view_infos[:-1]:
+      intermediate_values.append(
+          view_info.transform_tensor(intermediate_values[-1]))
+
+    # TODO: Investigate efficiency of this algorithm
+    # Update the source array with the new value by
+    # applying inverse transformations in reverse order
+    for view_info, parent_array in zip(
+        reversed(view_infos), reversed(intermediate_values)):
+      # Apply the inverse transformation to propagate changes back
+      new_values = view_info.update_tensor(new_values, parent_array)
+
+    # Update the source tensor with the new values
+    self.replace_source_jax(new_values)
+
+  @classmethod
+  def __torch_dispatch__(
+      cls,
+      func: Any,
+      types: Tuple[Any, ...],
+      args: Tuple[Any, ...] = (),
+      kwargs: Optional[dict] = None,
+  ) -> Any:
+    raise AssertionError(
+        'torchax Tensors can only do math within the torchax environment.'
+        'Please wrap your code with `with torchax.default_env()` or '
+        'call torchax.enable_globally() before.')
+
+  def create_sub_view(self, view_info: ViewInfo) -> "View":
+    """
+    Create a new view that is a child of this view.
+    """
+    return View(self, view_info, self._env)
+
+  def __str__(self) -> str:
+    return f"View({self.torch()})"
+
+  def jax(self) -> jax.Array:
+    """
+    Returns a copy of the source tensor after transformations.
+    """
+    result = self.source_jax()
+    for view_info in self.get_transformation_chain():
+      result = view_info.transform_tensor(result)
+    return result
+
+  def __setitem__(self, indexes, val):
+    view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
+    self.update(view_infos=view_infos, new_values=val)
+
+  def dim(self):
+    return self.ndim
+
+  @property
+  def device(self):
+    return torch.device("jax:0")
+
+  @property
+  def jax_device(self):
+    return self.jax().device
+
+  @property
+  def ndim(self):
+    return len(self.shape)
+
+  __repr__ = __str__
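
A rough sketch of how View composes with NarrowInfo and writes updates back to its parent. The Tensor(jax_array, env) constructor and the .jax() method are used as they appear inside view.py itself; the shapes and values are illustrative assumptions.

# Sketch: a slice view over a torchax Tensor, with a write-back through
# __setitem__ (hypothetical values; not code from the package).
import jax.numpy as jnp
import torchax
from torchax.tensor import Tensor
from torchax.view import NarrowInfo, View

env = torchax.default_env()
base = Tensor(jnp.zeros((4, 4)), env)          # source tensor holding a jax.Array
row1 = View(base, NarrowInfo(slices=1), env)   # view of base[1], shape (4,)

# __setitem__ appends a NarrowInfo to the transformation chain and propagates
# the new values back through update_tensor(), ending in replace_source_jax().
row1[0:2] = Tensor(jnp.ones(2), env)

print(base.jax()[1, :2])   # now all ones; row1.jax() reflects the same values
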