torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/CONTRIBUTING.md CHANGED
@@ -8,7 +8,7 @@ If you plan to contribute new features, utility functions or extensions to the c
 # Developer setup
 
 ## Mac setup:
-@qihqi
+@qihqi
 
 I am able to develop directly on mac (m1) laptop for most of parts. Using steps
 in README.md works. The condensed version for easy copy & paste:
@@ -24,7 +24,7 @@ pytest test
 
 ### VSCode
 
-I use vscode on my Mac. I loosely followed instruction in
+I use vscode on my Mac. I loosely followed instruction in
 https://code.visualstudio.com/docs/python/python-tutorial
 to setup a proper python environment.
 
torchax/__init__.py CHANGED
@@ -6,29 +6,31 @@ import os
 import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
-from torchax import distributed  # noqa: F401
+from contextlib import contextmanager
 
-__version__ = "0.0.4"
+__version__ = "0.0.6"
 VERSION = __version__
 
 __all__ = [
-    'default_env',
-    'extract_jax',
-    'enable_globally',
+    'default_env',
+    'extract_jax',
+    'enable_globally',
 ]
 
 from jax._src import xla_bridge
+
 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
 
 # torchax:oss-begin
 if getattr(jax.config, 'jax_pjrt_client_create_options', None):
   jax.config.update(
-      'jax_pjrt_client_create_options',
-      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
-  )
+      'jax_pjrt_client_create_options',
+      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
 # torchax:oss-end
 
 env = None
+
+
 def default_env():
   global env
 
@@ -37,53 +39,53 @@ def default_env():
   return env
 
 
-
 def extract_jax(mod: torch.nn.Module, env=None):
   """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
-  states = mod.state_dict()
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+  states = env.t2j_copy(states)
 
   #@jax.jit
-  def jax_func(states, inputs):
-    (states, inputs) = env.j2t_iso((states, inputs))
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
     with env:
-      res = torch.func.functional_call(mod, states, inputs, tie_weights=False)
+      res = torch.func.functional_call(
+          mod, states, args, kwargs, tie_weights=False)
     return env.t2j_iso(res)
 
   return states, jax_func
 
+
 def enable_globally():
   env = default_env().enable_torch_modes()
   return env
 
+
 def disable_globally():
-  global env
+  global env
   default_env().disable_torch_modes()
 
+
 @contextlib.contextmanager
 def disable_temporarily():
   prev = default_env().enabled
   if prev:
     disable_globally()
-  yield()
+  yield ()
   if prev:
     enable_globally()
 
 
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
-torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True, for_module=True, for_storage=True,
-    unsupported_dtype=unsupported_dtype)
 
 import jax
 import torchax.device_module
-torch._register_device_module('jax', torchax.device_module)
-
 
+torch._register_device_module('jax', torchax.device_module)
 
 
 def enable_accuracy_mode():
@@ -98,13 +100,13 @@ def enable_performance_mode():
   default_env().config.internal_respect_torch_return_dtypes = False
 
 
-
 @dataclasses.dataclass
 class CompileOptions:
   # only valid if compiling nn.Module
-  methods_to_compile: List[str] = dataclasses.field(default_factory=lambda: ['forward'])
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ['forward'])
   jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-  mode: str = 'jax' # or dynamo or export
+  mode: str = 'jax'  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
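
Note that the extract_jax change above also changes the returned callable's signature: jax_func(states, args, kwargs=None) replaces jax_func(states, inputs). A minimal usage sketch follows; the model and input shape are illustrative, not part of the package.

import torch
import jax.numpy as jnp
import torchax

# Hypothetical module; extract_jax works the same way for any nn.Module.
model = torch.nn.Linear(4, 2)

# states is a pytree of jax arrays; jax_func is a pure function over them.
states, jax_func = torchax.extract_jax(model)

# Positional inputs go in as a tuple; keyword inputs (if any) as a dict.
x = jnp.ones((16, 4), dtype=jnp.float32)
out = jax_func(states, (x,))
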
torchax/amp.py ADDED
@@ -0,0 +1,332 @@
+import contextlib
+import enum
+import torch
+from torch.utils import _pytree as pytree
+
+
+# enum class CastPolicy : uint8_t {
+#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
+#                            // running the op. Currently, lower_precision_fp is
+#                            // fp16 for AutocastCUDA, and is defined by user
+#                            // (default bf16) for AutocastCPU or other device.
+#   fp32,  // Cast all inputs to at::kFloat before running the op.
+#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
+#                        // 1. we'd like to run in fp32 and
+#                        // 2. have a std::optional<ScalarType> arg that controls
+#                        // the output type.
+#                        // fp32_set_opt_dtype wrappers' policy is: if the output
+#                        // type is already set, don't touch it, otherwise, set
+#                        // it to at::kFloat.
+#   fp32_append_dtype,  // Treats functions (like norm) that
+#                       // 1. we'd like to run in fp32 and
+#                       // 2. have some overloads that accept an output type and
+#                       // other overloads that don't.
+#                       // fp32_append_dtype wrappers wrap the overloads that don't
+#                       // have an output dtype.
+#                       // The wrapper policy is: append at::kFloat to the args,
+#                       // and redispatch to the type-aware overload.
+#   promote,  // Run in the widest dtype among several args.
+# };
+class CastPolicy(enum.Enum):
+  LOWER_PRECISION_FP = 0
+  FP32 = 1
+  FP32_SET_OPT_DTYPE = 2
+  FP32_APPEND_DTYPE = 3
+  PROMOTE = 4
+
+
+def execute_policy(policy, args, kwargs, target_lower_fp):
+
+  def is_float(a):
+    return isinstance(a, torch.Tensor) and a.is_floating_point()
+  match policy:
+    case CastPolicy.LOWER_PRECISION_FP:
+      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
+                                  (args, kwargs))
+    case CastPolicy.FP32:
+      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
+                                  (args, kwargs))
+    case CastPolicy.PROMOTE:
+      dtypes = set(a.dtype for a in args)
+      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
+      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
+                                  (args, kwargs))
+    case _:
+      raise AssertionError(f'Policy {policy} not implemented yet.')
+
+
+@contextlib.contextmanager
+def autocast(device, dtype=torch.bfloat16, env=None):
+  del device
+  if env is None:
+    import torchax
+    env = torchax.default_env()
+  with env.override_property(autocast_dtype=dtype):
+    yield
+
+
+# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
+autocast_policy = {
+    torch.ops.aten.conv1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv1d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.bmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linalg_vecdot.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.baddbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._addmm_activation.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linear.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._convolution.deprecated:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.matmul.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_tbc.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mkldnn_rnn_layer.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose2d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose3d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.prelu.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.scaled_dot_product_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._native_multi_head_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+
+    # fp32 cast policy
+    torch.ops.aten.avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler.default:
+        CastPolicy.FP32,
+    torch.ops.aten.polar.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_int:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_Dimname:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.center:
+        CastPolicy.FP32,
+    torch.ops.aten.cdist.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten._grid_sampler_2d_cpu_fallback.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.trace.default:
+        CastPolicy.FP32,
+    torch.ops.aten.view_as_complex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.lu_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.orgqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ormqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.pinverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.mse_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cosine_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.hinge_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.poisson_nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.smooth_l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cross_entropy_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.huber_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.margin_ranking_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.soft_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triplet_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multi_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.IntList:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.Tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.kl_div.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy_with_logits.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_hfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ihfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.p_str:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.tol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svdvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvalsh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_householder_product.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorinv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorsolve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fake_quantize_per_tensor_affine.default:
+        CastPolicy.FP32,
+    torch.ops.aten.geqrf.default:
+        CastPolicy.FP32,
+    torch.ops.aten._lu_with_info.default:
+        CastPolicy.FP32,
+    torch.ops.aten.qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triangular_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss_forward.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky_ex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eig.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_lstsq.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv_ex.default:
+        CastPolicy.FP32,
+
+    # promote
+    torch.ops.aten.stack.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.cat.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.dimname:
+        CastPolicy.PROMOTE,
+}
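
The new amp.py mirrors PyTorch's native autocast policy table: each aten overload maps to a CastPolicy, and execute_policy casts the op's floating-point tensor arguments accordingly before dispatch. A small sketch that exercises the table directly; calling it by hand like this is only for illustration, and the tensors are made up:

import torch
from torchax import amp

# mm is listed under LOWER_PRECISION_FP, so float inputs are cast to the
# requested lower-precision dtype (bfloat16 here) before the op runs.
policy = amp.autocast_policy[torch.ops.aten.mm.default]
args, kwargs = amp.execute_policy(
    policy, (torch.randn(2, 3), torch.randn(3, 4)), {}, torch.bfloat16)
print([a.dtype for a in args])  # [torch.bfloat16, torch.bfloat16]

In normal use the autocast context manager added in this file is the entry point: it sets autocast_dtype on the torchax environment (via env.override_property) so the table above can be consulted while ops run under that environment.
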
torchax/config.py CHANGED
@@ -3,17 +3,28 @@ import dataclasses
 
 @dataclasses.dataclass
 class Configuration:
-  debug_print_each_op: bool = False
-  debug_accuracy_for_each_op: bool = False
-  debug_mixed_tensor: bool = False
-  debug_print_each_op_operands: bool = False
-  use_int32_for_index: bool = False
-
-  # Flash attention
-  use_tpu_flash_attention: bool = False
-  shmap_flash_attention: bool = False
-
-  # device
-  treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = True
-  internal_respect_torch_return_dtypes: bool = False
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
+
+  use_int32_for_index: bool = False
+
+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  internal_respect_torch_return_dtypes: bool = False
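
The new Configuration fields above are plain dataclass flags accessed through the environment's config attribute, as the enable_performance_mode hunk in __init__.py shows. A short sketch of toggling the 0.0.6 additions; the comments restate the field descriptions from the diff, and the chosen values are arbitrary:

import torchax

env = torchax.default_env()

# Eagerly convert Views into torchax.Tensors.
env.config.force_materialize_views = True

# Use DLPack when converting between jax.Array and torch.Tensor.
env.config.use_dlpack_for_data_conversion = True

# Allow math that mixes a scalar CPU torch.Tensor with a torchax.Tensor.
env.config.allow_mixed_math_with_scalar_tensor = True
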