torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

@@ -16,12 +16,10 @@
16
16
  Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
17
17
  """
18
18
 
19
- import functools
20
- from typing import List, Union, Optional, Tuple
21
-
19
+ import jax.numpy as jnp
22
20
  import torch
23
21
  from jax import lax
24
- import jax.numpy as jnp
22
+
25
23
  from . import ops_registry
26
24
 
27
25
  _NMS_TILE_SIZE = 256
@@ -38,9 +36,11 @@ def _bbox_overlap(boxes, gt_boxes):
38
36
  iou: Intersection over union matrix of all input bounding boxes
39
37
  """
40
38
  bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
41
- ary=boxes, indices_or_sections=4, axis=2)
39
+ ary=boxes, indices_or_sections=4, axis=2
40
+ )
42
41
  gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
43
- ary=gt_boxes, indices_or_sections=4, axis=2)
42
+ ary=gt_boxes, indices_or_sections=4, axis=2
43
+ )
44
44
 
45
45
  # Calculates the intersection area.
46
46
  i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))
@@ -64,11 +64,16 @@ def _bbox_overlap(boxes, gt_boxes):
64
64
  def _self_suppression(in_args):
65
65
  iou, _, iou_sum = in_args
66
66
  batch_size = iou.shape[0]
67
- can_suppress_others = jnp.reshape(
68
- jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
69
- iou_suppressed = jnp.reshape(
70
- (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
71
- iou.dtype), [batch_size, -1, 1]) * iou
67
+ can_suppress_others = jnp.reshape(jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(
68
+ iou.dtype
69
+ )
70
+ iou_suppressed = (
71
+ jnp.reshape(
72
+ (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
73
+ [batch_size, -1, 1],
74
+ )
75
+ * iou
76
+ )
72
77
  iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
73
78
  return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new
74
79
 
@@ -76,11 +81,14 @@ def _self_suppression(in_args):
76
81
  def _cross_suppression(in_args):
77
82
  boxes, box_slice, iou_threshold, inner_idx = in_args
78
83
  batch_size = boxes.shape[0]
79
- new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
80
- [batch_size, _NMS_TILE_SIZE, 4])
84
+ new_slice = lax.dynamic_slice(
85
+ boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
86
+ )
81
87
  iou = _bbox_overlap(new_slice, box_slice)
82
- ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
83
- box_slice.dtype), 2) * box_slice
88
+ ret_slice = (
89
+ jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2)
90
+ * box_slice
91
+ )
84
92
  return boxes, ret_slice, iou_threshold, inner_idx + 1
85
93
 
86
94
 
@@ -101,38 +109,44 @@ def _suppression_loop_body(in_args):
101
109
  batch_size = boxes.shape[0]
102
110
 
103
111
  # Iterates over tiles that can possibly suppress the current tile.
104
- box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
105
- [batch_size, _NMS_TILE_SIZE, 4])
112
+ box_slice = lax.dynamic_slice(
113
+ boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
114
+ )
106
115
 
107
116
  def _loop_cond(in_args):
108
117
  _, _, _, inner_idx = in_args
109
118
  return inner_idx < idx
110
119
 
111
- _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
112
- (boxes, box_slice, iou_threshold, 0))
120
+ _, box_slice, _, _ = lax.while_loop(
121
+ _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0)
122
+ )
113
123
 
114
124
  # Iterates over the current tile to compute self-suppression.
115
125
  iou = _bbox_overlap(box_slice, box_slice)
116
126
  mask = jnp.expand_dims(
117
- jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
118
- > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
127
+ jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
128
+ > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]),
129
+ 0,
130
+ )
119
131
  iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)
120
132
 
121
133
  def _loop_cond2(in_args):
122
134
  _, loop_condition, _ = in_args
123
135
  return loop_condition
124
136
 
125
- suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
126
- (iou, True, jnp.sum(iou, [1, 2])))
137
+ suppressed_iou, _, _ = lax.while_loop(
138
+ _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2]))
139
+ )
127
140
  suppressed_box = jnp.sum(suppressed_iou, 1) > 0
128
141
  box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)
129
142
 
130
143
  # Uses box_slice to update the input boxes.
131
- mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
132
- idx)).astype(boxes.dtype), [1, -1, 1, 1])
133
- boxes = jnp.tile(jnp.expand_dims(
134
- box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
135
- boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
144
+ mask = jnp.reshape(
145
+ (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), [1, -1, 1, 1]
146
+ )
147
+ boxes = jnp.tile(
148
+ jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1]
149
+ ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
136
150
  boxes = jnp.reshape(boxes, [batch_size, -1, 4])
137
151
 
138
152
  # Updates output_size.
@@ -193,8 +207,7 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
193
207
  """
194
208
  batch_size = boxes.shape[0]
195
209
  num_boxes = boxes.shape[1]
196
- pad = int(jnp.ceil(
197
- float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
210
+ pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
198
211
  boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
199
212
  scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
200
213
  num_boxes += pad
@@ -202,29 +215,37 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
202
215
  def _loop_cond(in_args):
203
216
  unused_boxes, unused_threshold, output_size, idx = in_args
204
217
  return jnp.logical_and(
205
- jnp.min(output_size) < max_output_size, idx
206
- < num_boxes // _NMS_TILE_SIZE)
218
+ jnp.min(output_size) < max_output_size, idx < num_boxes // _NMS_TILE_SIZE
219
+ )
207
220
 
208
221
  selected_boxes, _, output_size, _ = lax.while_loop(
209
- _loop_cond, _suppression_loop_body,
210
- (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
222
+ _loop_cond,
223
+ _suppression_loop_body,
224
+ (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0),
225
+ )
211
226
  idx = num_boxes - lax.top_k(
212
- jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
213
- jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
214
- max_output_size)[0].astype(jnp.int32)
227
+ jnp.any(selected_boxes > 0, [2]).astype(jnp.int32)
228
+ * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
229
+ max_output_size,
230
+ )[0].astype(jnp.int32)
215
231
  idx = jnp.minimum(idx, num_boxes - 1)
216
232
  idx = jnp.reshape(
217
- idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
233
+ idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]
234
+ )
218
235
 
219
236
  return idx
220
- boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
221
- [batch_size, max_output_size, 4])
222
- boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
223
- < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
224
- scores = jnp.reshape(
225
- jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
226
- scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
227
- < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
237
+ boxes = jnp.reshape(
238
+ (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4]
239
+ )
240
+ boxes = boxes * (
241
+ jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
242
+ < jnp.reshape(output_size, [-1, 1, 1])
243
+ ).astype(boxes.dtype)
244
+ scores = jnp.reshape(jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
245
+ scores = scores * (
246
+ jnp.reshape(jnp.arange(max_output_size), [1, -1])
247
+ < jnp.reshape(output_size, [-1, 1])
248
+ ).astype(scores.dtype)
228
249
  return scores, boxes
229
250
 
230
251
 
@@ -235,14 +256,13 @@ def nms(boxes, scores, iou_threshold):
235
256
  max_output_size = boxes.shape[0]
236
257
  boxes = boxes.reshape((1, *boxes.shape))
237
258
  scores = scores.reshape((1, *scores.shape))
238
- res = non_max_suppression_padded(scores, boxes, max_output_size,
239
- iou_threshold)
259
+ res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
240
260
  return res
241
261
 
242
262
 
243
263
  try:
244
264
  import torch
245
- import torchvision
265
+
246
266
  ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
247
267
  except Exception:
248
268
  pass
torchax/ops/mappings.py CHANGED
@@ -12,20 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from jax import dlpack as jaxdl
16
15
  import jax.numpy as jnp
17
16
  import numpy
18
17
  import torch
19
18
  import torch.func
20
- import torch.utils.dlpack as torchdl
21
19
  import torch.utils._mode_utils as mode_utils
20
+ import torch.utils.dlpack as torchdl
21
+ from jax import dlpack as jaxdl
22
22
 
23
23
  NUMPY_UNSUPPORTED_DTYPES = {
24
- torch.bfloat16: jnp.bfloat16,
25
- torch.float8_e4m3fn: jnp.float8_e4m3fn,
26
- torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
27
- torch.float8_e5m2: jnp.float8_e5m2,
28
- torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
24
+ torch.bfloat16: jnp.bfloat16,
25
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
26
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
27
+ torch.float8_e5m2: jnp.float8_e5m2,
28
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
29
29
  }
30
30
 
31
31
 
@@ -51,8 +51,9 @@ def t2j(t, use_dlpack=True):
51
51
  # https://github.com/google/jax/issues/7657
52
52
  # https://github.com/google/jax/issues/17784
53
53
  if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
54
- nparray = (t.cpu().detach().to(torch.float32).numpy()
55
- ) # handle dtypes not supported by numpy
54
+ nparray = (
55
+ t.cpu().detach().to(torch.float32).numpy()
56
+ ) # handle dtypes not supported by numpy
56
57
  else:
57
58
  nparray = t.cpu().detach().numpy()
58
59
  res = jnp.asarray(nparray)
@@ -91,71 +92,49 @@ def j2t(x, use_dlpack=True):
91
92
 
92
93
 
93
94
  TORCH_DTYPE_TO_JAX = {
94
- # NO_MAPPING : jnp.float0.dtype (signless scalar int),
95
- torch.bool:
96
- jnp.bool_.dtype,
97
- # NO_MAPPING : jnp.int4.dtype,
98
- torch.int8:
99
- jnp.int8.dtype,
100
- torch.int16:
101
- jnp.int16.dtype,
102
- torch.int32:
103
- jnp.int32.dtype,
104
- torch.int64:
105
- jnp.int64.dtype,
106
- torch.long:
107
- jnp.int64.dtype,
108
- # NO_MAPPING : jnp.uint4
109
- torch.uint8:
110
- jnp.uint8.dtype,
111
- torch.uint16:
112
- jnp.uint16.dtype,
113
- torch.uint32:
114
- jnp.uint32.dtype,
115
- torch.uint64:
116
- jnp.uint64.dtype,
117
- # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
118
- torch.float8_e4m3fn:
119
- jnp.float8_e4m3fn.dtype,
120
- # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
121
- torch.float8_e5m2:
122
- jnp.float8_e5m2.dtype,
123
- # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
124
- torch.bfloat16:
125
- jnp.bfloat16.dtype,
126
- torch.half:
127
- jnp.float16.dtype,
128
- torch.float16:
129
- jnp.float16.dtype,
130
- torch.float32:
131
- jnp.float32.dtype,
132
- torch.float64:
133
- jnp.float64.dtype,
134
- torch.double:
135
- jnp.double.dtype,
136
- torch.complex64:
137
- jnp.complex64.dtype,
138
- torch.complex128:
139
- jnp.complex128.dtype,
140
- None:
141
- None,
95
+ # NO_MAPPING : jnp.float0.dtype (signless scalar int),
96
+ torch.bool: jnp.bool_.dtype,
97
+ # NO_MAPPING : jnp.int4.dtype,
98
+ torch.int8: jnp.int8.dtype,
99
+ torch.int16: jnp.int16.dtype,
100
+ torch.int32: jnp.int32.dtype,
101
+ torch.int64: jnp.int64.dtype,
102
+ torch.long: jnp.int64.dtype,
103
+ # NO_MAPPING : jnp.uint4
104
+ torch.uint8: jnp.uint8.dtype,
105
+ torch.uint16: jnp.uint16.dtype,
106
+ torch.uint32: jnp.uint32.dtype,
107
+ torch.uint64: jnp.uint64.dtype,
108
+ torch.float4_e2m1fn_x2: jnp.float4_e2m1fn.dtype,
109
+ # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
110
+ torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype,
111
+ # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
112
+ torch.float8_e5m2: jnp.float8_e5m2.dtype,
113
+ # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
114
+ torch.bfloat16: jnp.bfloat16.dtype,
115
+ torch.half: jnp.float16.dtype,
116
+ torch.float16: jnp.float16.dtype,
117
+ torch.float32: jnp.float32.dtype,
118
+ torch.float64: jnp.float64.dtype,
119
+ torch.double: jnp.double.dtype,
120
+ torch.complex64: jnp.complex64.dtype,
121
+ torch.complex128: jnp.complex128.dtype,
122
+ None: None,
142
123
  }
143
124
 
144
125
  JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
145
126
  # Add imprecise mappings for some JAX dtypes which don't have torch analogues
146
- JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
147
- JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8
127
+ JAX_DTYPE_TO_TORCH[jnp.dtype("int4")] = torch.int8
128
+ JAX_DTYPE_TO_TORCH[jnp.dtype("uint4")] = torch.uint8
148
129
 
149
130
 
150
131
  def t2j_dtype(dtype):
151
132
  if dtype not in TORCH_DTYPE_TO_JAX:
152
- raise RuntimeError(
153
- f'Attempting to convert unknown type: {dtype} to jax type,')
133
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to jax type,")
154
134
  return TORCH_DTYPE_TO_JAX[dtype]
155
135
 
156
136
 
157
137
  def j2t_dtype(dtype):
158
138
  if dtype not in JAX_DTYPE_TO_TORCH:
159
- raise RuntimeError(
160
- f'Attempting to convert unknown type: {dtype} to torch type,')
139
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to torch type,")
161
140
  return JAX_DTYPE_TO_TORCH[dtype]
torchax/ops/op_base.py CHANGED
@@ -13,25 +13,23 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import functools
16
+ from collections.abc import Callable
17
+ from typing import Concatenate, ParamSpec
18
+
16
19
  import jax
17
20
  import jax.numpy as jnp
18
21
  import numpy as np
19
22
  import torch
23
+
24
+ from torchax import types
20
25
  from torchax.ops import mappings
21
26
  from torchax.view import View
22
- from torchax import types
23
- import sys
24
-
25
- from typing import Callable, Optional, ParamSpec, Concatenate
26
27
 
27
28
 
28
29
  class InplaceOp:
29
-
30
- def __init__(self,
31
- functional_op,
32
- replace=False,
33
- position_to_mutate=0,
34
- is_jax_func=False):
30
+ def __init__(
31
+ self, functional_op, replace=False, position_to_mutate=0, is_jax_func=False
32
+ ):
35
33
  self.functional = functional_op
36
34
  self.replace = replace
37
35
  self.position_to_mutate = position_to_mutate
@@ -65,15 +63,14 @@ class InplaceOp:
65
63
 
66
64
 
67
65
  class OutVariant:
68
-
69
66
  def __call__(self, *args, **kwargs):
70
- to_mutate = kwargs['out']
71
- del kwargs['out']
67
+ to_mutate = kwargs["out"]
68
+ del kwargs["out"]
72
69
  to_mutate._elem = self.functional(*args, **kwargs)._elem
73
70
  return to_mutate
74
71
 
75
72
 
76
- P = ParamSpec('P')
73
+ P = ParamSpec("P")
77
74
 
78
75
 
79
76
  def convert_dtype(use_default_dtype: bool = True):
@@ -87,11 +84,8 @@ def convert_dtype(use_default_dtype: bool = True):
87
84
  """
88
85
 
89
86
  def decorator(func: types.TorchCallable):
90
-
91
87
  @functools.wraps(func)
92
- def wrapper(*args: P.args,
93
- dtype: Optional[torch.dtype] = None,
94
- **kwargs: P.kwargs):
88
+ def wrapper(*args: P.args, dtype: torch.dtype | None = None, **kwargs: P.kwargs):
95
89
  if not dtype and use_default_dtype:
96
90
  dtype = torch.get_default_dtype()
97
91
  if isinstance(dtype, torch.dtype):
@@ -106,8 +100,7 @@ def convert_dtype(use_default_dtype: bool = True):
106
100
  return decorator
107
101
 
108
102
 
109
- def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
110
- dtype: Optional[jnp.dtype]):
103
+ def maybe_convert_constant_dtype(val: types.JaxValue | None, dtype: jnp.dtype | None):
111
104
  """Optionally converts scalar constant's dtype using `numpy`
112
105
 
113
106
  Use in cases where you require a constant and can't handle a traced array.
@@ -134,12 +127,11 @@ def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
134
127
  return wrapper
135
128
 
136
129
 
137
- def foreach_loop(seq: jax.Array,
138
- fn: Callable[[jax.Array, jax.Array], jax.Array],
139
- init_val=0.0):
130
+ def foreach_loop(
131
+ seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
132
+ ):
140
133
  """Run `fn` for each element of 1D array `seq`.
141
134
 
142
135
  Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
143
136
  assert len(seq.shape) == 1
144
- return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
145
- init_val)
137
+ return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val)
torchax/ops/ops_registry.py CHANGED
@@ -14,56 +14,61 @@
14
14
 
15
15
  import dataclasses
16
16
  import logging
17
- from torchax.types import JaxCallable, TorchCallable
18
17
 
19
- from typing import Union, Dict
18
+ from torchax.types import JaxCallable, TorchCallable
20
19
 
21
20
 
22
21
  @dataclasses.dataclass
23
22
  class Operator:
24
23
  torch_op: TorchCallable
25
- func: Union[TorchCallable, JaxCallable]
24
+ func: TorchCallable | JaxCallable
26
25
  is_jax_function: bool
27
26
  is_user_defined: bool
28
27
  needs_env: bool
29
28
  is_view_op: bool
30
29
 
31
30
 
32
- all_aten_ops: Dict[TorchCallable, Operator] = {}
33
- all_torch_functions: Dict[TorchCallable, Operator] = {}
31
+ all_aten_ops: dict[TorchCallable, Operator] = {}
32
+ all_torch_functions: dict[TorchCallable, Operator] = {}
34
33
 
35
34
 
36
- def register_torch_dispatch_op(aten_op,
37
- impl_callable,
38
- is_jax_function=True,
39
- is_user_defined=False,
40
- needs_env=False,
41
- is_view_op=False):
35
+ def register_torch_dispatch_op(
36
+ aten_op,
37
+ impl_callable,
38
+ is_jax_function=True,
39
+ is_user_defined=False,
40
+ needs_env=False,
41
+ is_view_op=False,
42
+ ):
42
43
  op = Operator(
43
- aten_op,
44
- impl_callable,
45
- is_jax_function=is_jax_function,
46
- is_user_defined=is_user_defined,
47
- needs_env=needs_env,
48
- is_view_op=is_view_op)
44
+ aten_op,
45
+ impl_callable,
46
+ is_jax_function=is_jax_function,
47
+ is_user_defined=is_user_defined,
48
+ needs_env=needs_env,
49
+ is_view_op=is_view_op,
50
+ )
49
51
  if aten_op in all_aten_ops:
50
- logging.warning(f'Duplicate op registration for {aten_op}')
52
+ logging.warning(f"Duplicate op registration for {aten_op}")
51
53
  all_aten_ops[aten_op] = op
52
54
  return impl_callable
53
55
 
54
56
 
55
- def register_torch_function_op(torch_func,
56
- impl_callable,
57
- is_jax_function=True,
58
- is_user_defined=False,
59
- needs_env=False,
60
- is_view_op=False):
57
+ def register_torch_function_op(
58
+ torch_func,
59
+ impl_callable,
60
+ is_jax_function=True,
61
+ is_user_defined=False,
62
+ needs_env=False,
63
+ is_view_op=False,
64
+ ):
61
65
  op = Operator(
62
- torch_func,
63
- impl_callable,
64
- is_jax_function=is_jax_function,
65
- is_user_defined=is_user_defined,
66
- needs_env=needs_env,
67
- is_view_op=is_view_op)
66
+ torch_func,
67
+ impl_callable,
68
+ is_jax_function=is_jax_function,
69
+ is_user_defined=is_user_defined,
70
+ needs_env=needs_env,
71
+ is_view_op=is_view_op,
72
+ )
68
73
  all_torch_functions[torch_func] = op
69
74
  return impl_callable