torchax 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


torchax/tensor.py ADDED
@@ -0,0 +1,557 @@
1
+ import logging
2
+ import sys
3
+ import contextlib
4
+ from typing import Optional, Any
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy
8
+ import torch
9
+ import torch.distributed._functional_collectives
10
+ import torch.func
11
+ import torch.utils._mode_utils as mode_utils
12
+ import torch.utils._python_dispatch as torch_dispatch
13
+ import torch.utils._pytree as torch_pytree
14
+
15
+ from torchax import config
16
+ from torchax.ops import mappings, ops_registry
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class OperatorNotFound(Exception):
22
+ pass
23
+
24
+
25
+ def wrap(jaxarray):
26
+ return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray)
27
+
28
+
29
+ def unwrap(torchtensors):
30
+ return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
31
+
32
+
33
+ def t2j(t):
34
+ if isinstance(t, Tensor):
35
+ return t._elem
36
+ return mappings.t2j(t)
37
+
38
+
39
+ def j2t(x):
40
+ return mappings.j2t(x)
41
+
42
+
43
+ def t2j_dtype(dtype):
44
+ return mappings.t2j_dtype(dtype)
45
+
46
+
47
+ def j2t_dtype(dtype):
48
+ return mappings.j2t_dtype(dtype)
49
+
50
+
51
+ @contextlib.contextmanager
52
+ def log_nested(env, message):
53
+ if env.config.debug_print_each_op:
54
+ print((' ' * log_nested.level) + message, file=sys.stderr)
55
+ log_nested.level += 1
56
+ yield
57
+ log_nested.level -= 1
58
+
59
+ log_nested.level = 0
60
+
61
+
62
+ class Tensor(torch.Tensor):
63
+
64
+ @staticmethod
65
+ def __new__(cls, elem, env):
66
+ dtype = j2t_dtype(elem.dtype)
67
+ shape = list(elem.shape)
68
+ for i, s in enumerate(shape):
69
+ if not isinstance(s, int):
70
+ shape[i] = 1
71
+ if dtype is None:
72
+ dtype = torch.float32
73
+ return torch.Tensor._make_wrapper_subclass(
74
+ cls,
75
+ shape,
76
+ dtype=dtype,
77
+ device='meta',
78
+ requires_grad=False,
79
+ )
80
+
81
+ def __init__(self, elem: jax.Array, env: 'Environment'):
82
+ super().__init__()
83
+ self._elem = elem
84
+ self._env = env
85
+
86
+ def __str__(self):
87
+ return "Tensor({} {})".format(str(type(self._elem)), str(self._elem))
88
+
89
+ __repr__ = __str__
90
+
91
+ def __jax_array__(self):
92
+ return self._elem
93
+
94
+ @property
95
+ def shape(self):
96
+ return self._elem.shape
97
+
98
+ @property
99
+ def ndim(self):
100
+ return len(self._elem.shape)
101
+
102
+ def flatten(self, start_dim=0, end_dim=-1):
103
+ if end_dim == -1:
104
+ end_dim = self.ndim
105
+ new_shape = (
106
+ self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
107
+ new_elem = jnp.reshape(self._elem, new_shape)
108
+ return Tensor(new_elem, self._env)
109
+ # return torch.reshape(self, new_shape)
110
+
111
+ def __setitem__(self, key, val):
112
+ key, val = self._env.t2j_iso((key, val))
113
+ self._elem = self._elem.at[key].set(val)
114
+
115
+ def type_as(self, other):
116
+ self._elem = self._elem.astype(other._elem.dtype)
117
+ return self
118
+
119
+ __torch_function__ = torch._C._disabled_torch_function_impl
120
+
121
+ @classmethod
122
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
123
+ env = None
124
+ for arg in torch_pytree.arg_tree_leaves(*args, **(kwargs or {})):
125
+ if isinstance(arg, Tensor):
126
+ env = arg._env
127
+ break
128
+
129
+ with env:
130
+ return func(*args, **(kwargs or {}))
131
+
132
+ def detach(self):
133
+ return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
134
+
135
+ def numpy(self) -> numpy.ndarray:
136
+ import numpy as np
137
+
138
+ return np.array(self._elem)
139
+
140
+ def jax(self) -> jax.Array:
141
+ return self._elem
142
+
143
+ def torch(self) -> torch.Tensor:
144
+ return j2t(self.jax())
145
+
146
+ @property
147
+ def dtype(self):
148
+ return j2t_dtype(self._elem.dtype)
149
+
150
+ def dim(self):
151
+ return self.ndim
152
+
153
+ @property
154
+ def device(self):
155
+ return torch.device('jax:0')
156
+
157
+ @property
158
+ def jax_device(self):
159
+ return self._elem.device
160
+
161
+ @property
162
+ def data(self):
163
+ logger.warning("In-place modifications to .data still result in a copy on TPU")
164
+ return self
165
+
166
+ @data.setter
167
+ def data(self, other):
168
+ if isinstance(other, Tensor):
169
+ self._elem = other._elem
170
+
171
+ def apply_jax(self, jax_function, *args, **kwargs):
172
+ # Call a jax function on _elem
173
+ res = jax_function(self._elem, *args, **kwargs)
174
+ return self._env.j2t_iso(res)
175
+
176
+ def apply_jax_(self, jax_function, *args, **kwargs):
177
+ self._elem = jax_function(self._elem, *args, **kwargs)
178
+ return self
179
+
180
+ def tolist(self):
181
+ return self._elem.tolist()
182
+
183
+ def shard_(self, sharding):
184
+ self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
185
+
186
+
187
+ def debug_accuracy(func, args, kwargs, current_output):
188
+ args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
189
+ torch.Tensor, lambda x: j2t(x._elem), (args, kwargs, current_output))
190
+
191
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
192
+ if 'device' in kwargs_torch:
193
+ kwargs_torch['device'] = 'cpu' # do the torch native for comparison
194
+ expected_out = func(*args_torch, **kwargs_torch)
195
+
196
+ flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
197
+ flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)
198
+
199
+ for ex, real in zip(flattened_expected_out, flattened_current_out):
200
+ if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
201
+ ex = ex.to(real.dtype)
202
+ try:
203
+ if (isinstance(ex, torch.Tensor) and
204
+ not torch.allclose(ex, real, atol=1e-3, equal_nan=True)):
205
+ import pdb
206
+
207
+ pdb.set_trace()
208
+ except Exception: # drop into the debugger if the comparison itself fails
209
+ import pdb
210
+
211
+ pdb.set_trace()
212
+
213
+ return True
214
+
215
+ def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
216
+ def _display(a):
217
+ if isinstance(a, torch.Tensor):
218
+ return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
219
+ elif isinstance(a, jax.Array):
220
+ return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
221
+ else:
222
+ return str(a)
223
+
224
+ kwargs = kwargs or {}
225
+ title = 'DISPATCH' if is_dispatch else 'FUNCTION'
226
+ args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
227
+ kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
228
+ return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
229
+
230
+
231
+ class XLAFunctionMode(torch.overrides.TorchFunctionMode):
232
+ """Context manager that dispatches torch function calls to JAX."""
233
+
234
+ def __init__(self, env):
235
+ self.env = env
236
+
237
+ def __torch_function__(self,
238
+ func,
239
+ types,
240
+ args=(),
241
+ kwargs=None) -> torch.Tensor:
242
+ message = f'FUNCTION: {_name_of_func(func)}'
243
+ if self.env.config.debug_print_each_op_operands:
244
+ message = message + 'f'
245
+ message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
246
+ func, args, kwargs)
247
+ with log_nested(self.env, message):
248
+ try:
249
+ return self.env.dispatch(func, types, args, kwargs)
250
+ except OperatorNotFound:
251
+ pass
252
+ if _name_of_func(func) in ('rot90',): # skip rot90 when k % 4 == 0 since it is a no-op
253
+ if len(args) >= 2 and type(args[1]) == int:
254
+ if ((args[1])%4 == 0):
255
+ return args[0]
256
+ return func(*args, **(kwargs or {}))
257
+
258
+
259
+ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
260
+
261
+ def __init__(self, env):
262
+ self.env = env
263
+
264
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
265
+ message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
266
+ func, args, kwargs)
267
+ with log_nested(self.env, message):
268
+ if isinstance(func, torch._ops.OpOverloadPacket):
269
+ with self:
270
+ return func(*args, **kwargs)
271
+ if func.namespace not in ('aten', '_c10d_functional', 'torchvision'):
272
+ return func(*args, **kwargs)
273
+ return self.env.dispatch(func, types, args, kwargs)
274
+
275
+ def _name_of_func(func):
276
+ if hasattr(func, 'name'):
277
+ return func.name()
278
+ return func.__name__
279
+
280
+
281
+ # Constructors that don't take another tensor as input
282
+ TENSOR_CONSTRUCTORS = {
283
+ torch.ones,
284
+ torch.zeros,
285
+ torch.empty,
286
+ torch.empty_strided,
287
+ torch.tensor,
288
+ torch.arange,
289
+ torch.eye,
290
+ torch.randn,
291
+ torch.rand,
292
+ torch.randint,
293
+ torch.full,
294
+ torch.as_tensor,
295
+ }
296
+
297
+
298
+ class Environment(contextlib.ContextDecorator):
299
+ """This class holds a set of configurations and "globals" needed
300
+
301
+ for executing a torch program using jax.
302
+ Things included so far:
303
+
304
+ op registry
305
+ PRNGKey
306
+ Configs
307
+
308
+ Also helper functions to manipulate those.
309
+ """
310
+
311
+ _prng_key: jax.random.PRNGKey
312
+
313
+
314
+ def __init__(self, configuration=None):
315
+ self._function_mode = XLAFunctionMode(self)
316
+ self._dispatch_mode = XLADispatchMode(self)
317
+
318
+ # keys are torch callables
319
+ self._ops = {}
320
+ self.load_ops()
321
+
322
+ self._mesh = None
323
+ self.config = configuration or config.Configuration()
324
+
325
+ self._manually_entered = False
326
+ self.enabled = False
327
+ self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
328
+
329
+ def get_as_jax_device(self, device: Any):
330
+ if device is None:
331
+ device = torch.get_default_device()
332
+
333
+ if isinstance(device, torch.device):
334
+ device = str(device)
335
+
336
+ if (not self.config.use_torch_native_for_cpu_tensor and
337
+ device.startswith('cpu')):
338
+ return jax.devices('cpu')[0]
339
+
340
+ if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
341
+ return jax.local_devices()[0]
342
+
343
+ if device.startswith('jax'):
344
+ return jax.local_devices()[0]
345
+
346
+ return None # fallback to torch
347
+
348
+
349
+
350
+ def load_ops(self):
351
+ from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
352
+ self._ops.update(ops_registry.all_aten_ops)
353
+ self._ops.update(ops_registry.all_torch_functions)
354
+
355
+ decomps = torch._decomp.core_aten_decompositions()
356
+ from torchax.decompositions import EXTRA_DECOMP
357
+ decomps.update(EXTRA_DECOMP)
358
+ for k, v in decomps.items():
359
+ if k not in self._ops:
360
+ self._ops[k] = ops_registry.Operator(
361
+ k,
362
+ v,
363
+ is_jax_function=False,
364
+ is_user_defined=False,
365
+ needs_env=False
366
+ )
367
+
368
+ def _to_copy(self, the_tensor, new_dtype, new_device):
369
+ if isinstance(the_tensor, Tensor):
370
+ arr = the_tensor.jax()
371
+ if new_dtype is not None and new_dtype != arr.dtype:
372
+ arr = arr.astype(mappings.t2j_dtype(new_dtype))
373
+ if new_device is not None:
374
+ # convert xla tensor to other device
375
+ # the only supported target device is CPU
376
+ if str(new_device).startswith('cpu'):
377
+ # converting to a non-jax device: let torch native handle it
378
+ torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
379
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
380
+ return torch_tensor.to(new_device)
381
+ else:
382
+ if new_dtype is not None and new_dtype != the_tensor.dtype:
383
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
384
+ the_tensor = the_tensor.to(new_dtype)
385
+ jax_device = self.get_as_jax_device(new_device)
386
+ if jax_device:
387
+ arr = t2j(the_tensor)
388
+ arr = jax.device_put(arr, jax_device)
389
+ else:
390
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
391
+ return the_tensor.to(new_device)
392
+
393
+ return Tensor(arr, self)
394
+
395
+
396
+ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
397
+ # Always use the default `randint` to get the next seed
398
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
399
+ next_key = torch.randint(
400
+ 0, 2**32, (), dtype=torch.uint32, generator=generator).numpy()
401
+
402
+ return jax.random.key(next_key)
403
+
404
+ def _handle_tensor_constructor(self, func, args, kwargs):
405
+ device = kwargs.get('device')
406
+ jax_device = self.get_as_jax_device(device)
407
+ # TODO(qihqi) figure out better ways for device propagation
408
+ if not self._manually_entered and jax_device is None:
409
+ # let torch handle it
410
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
411
+ return func(*args, **kwargs)
412
+ with jax.default_device(jax_device):
413
+ op = self._ops.get(func)
414
+ if op is None and isinstance(func, torch._ops.OpOverload):
415
+ op = self._ops.get(func.overloadpacket)
416
+ res = op.func(*args, **kwargs)
417
+ if isinstance(res, jax.Array):
418
+ res = Tensor(res, self)
419
+ return res
420
+
421
+ def _torch_Tensor_to(self, args, kwargs):
422
+ the_tensor = args[0]
423
+ args = args[1:]
424
+ if len(args) >= 1 and isinstance(args[0], torch.Tensor):
425
+ dtype = args[0].dtype
426
+ device = args[0].device
427
+ return self._to_copy(the_tensor, dtype, device)
428
+ device = kwargs.get('device')
429
+ dtype = kwargs.get('dtype')
430
+ # args like pin_memory etc that we will ignore
431
+ args = list(filter(lambda x: not isinstance(x, bool), args))
432
+ if len(args) >= 2:
433
+ device, dtype, *_ = args
434
+ elif len(args) == 1 and isinstance(args[0], torch.dtype):
435
+ dtype = args[0]
436
+ elif len(args) == 1:
437
+ device = args[0]
438
+ return self._to_copy(the_tensor, dtype, device)
439
+
440
+
441
+ def dispatch(self, func, types, args, kwargs):
442
+
443
+ kwargs = kwargs or {}
444
+ if func in TENSOR_CONSTRUCTORS:
445
+ return self._handle_tensor_constructor(func, args, kwargs)
446
+ if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default, torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
447
+ return self._torch_Tensor_to(args, kwargs)
448
+
449
+ # If the func doesn't act on Tensor, and is not a tensor constructor,
450
+ # we should skip it and let torch handle it.
451
+
452
+ tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)]
453
+ if tensor_args and all(not isinstance(t, Tensor) for t in tensor_args):
454
+ return func(*args, **kwargs)
455
+
456
+ with jax.named_scope(_name_of_func(func)):
457
+ op = self._ops.get(func)
458
+
459
+ if op is None and isinstance(func, torch._ops.OpOverloadPacket):
460
+ op = self._ops.get(func.default)
461
+
462
+ if op is None and isinstance(func, torch._ops.OpOverload):
463
+ op = self._ops.get(func.overloadpacket)
464
+
465
+ if op is None:
466
+ raise OperatorNotFound(
467
+ f'Operator with name {_name_of_func(func)} has no lowering')
468
+
469
+ old_args, old_kwargs = args, kwargs
470
+ args, kwargs = torch_pytree.tree_map_only(
471
+ torch.distributed._functional_collectives.AsyncCollectiveTensor,
472
+ torch.distributed._functional_collectives.wait_tensor,
473
+ (args, kwargs))
474
+ try:
475
+ if op.is_jax_function:
476
+ args, kwargs = self.t2j_iso((args, kwargs))
477
+ except AssertionError:
478
+ if self.config.debug_mixed_tensor:
479
+ import pdb; pdb.set_trace()
480
+ else:
481
+ raise
482
+
483
+
484
+ if op.needs_env:
485
+ kwargs['env'] = self
486
+
487
+ with self:
488
+ res = op.func(*args, **kwargs)
489
+
490
+ if op.is_jax_function:
491
+ res = self.j2t_iso(res)
492
+
493
+ if self.config.debug_accuracy_for_each_op:
494
+ debug_accuracy(func, old_args, old_kwargs, res)
495
+ return res
496
+
497
+ def enable_torch_modes(self):
498
+ self._dispatch_mode.__enter__()
499
+ self._function_mode.__enter__()
500
+ self.enabled = True
501
+
502
+ def disable_torch_modes(self, *exc):
503
+ if not exc:
504
+ exc = (None, None, None)
505
+ self._function_mode.__exit__(*exc)
506
+ self._dispatch_mode.__exit__(*exc)
507
+ self.enabled = False
508
+
509
+ def __enter__(self):
510
+ self.enable_torch_modes()
511
+ self._manually_entered = True
512
+ return self
513
+
514
+ def __exit__(self, *exc):
515
+ self._manually_entered = False
516
+ self.disable_torch_modes(*exc)
517
+
518
+ def _move_one_value(self, val):
519
+ if isinstance(val, torch.nn.Module):
520
+ with self:
521
+ return val.to('jax')
522
+ if isinstance(val, Tensor):
523
+ return val
524
+ if isinstance(val, torch.Tensor):
525
+ return Tensor(t2j(val), self)
526
+ return val
527
+
528
+ def to_xla(self, torchvalues):
529
+ # tensors are torch.Tensors (not XLATensor)
530
+ res = torch_pytree.tree_map(
531
+ self._move_one_value,
532
+ torchvalues)
533
+ return res
534
+
535
+ def t2j_iso(self, torchtensors):
536
+ def to_jax(x):
537
+ if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
538
+ x = x.wait()
539
+ assert isinstance(x, Tensor), f'Expected a Tensor but got {type(x)}; usually this means an op mixes torchax Tensors with plain torch.Tensors'
540
+ return x.jax()
541
+ return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
542
+
543
+ def j2t_iso(self, jaxarray):
544
+ return torch_pytree.tree_map_only(
545
+ jnp.ndarray, lambda x: Tensor(x, self), jaxarray)
546
+
547
+ def j2t_copy(self, args):
548
+ pass
549
+
550
+ def override_op_definition(self, op_to_override, op_impl):
551
+ self._ops[op_to_override] = ops_registry.Operator(
552
+ op_to_override,
553
+ op_impl,
554
+ is_jax_function=False,
555
+ is_user_defined=True,
556
+ needs_env=False
557
+ )
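
A hedged usage sketch of the Environment and Tensor classes defined in torchax/tensor.py above (illustrative only, not part of the package contents; the tensor values and the relu call are assumptions):

import torch
from torchax import tensor

env = tensor.Environment()

# Wrap a plain torch.Tensor into the JAX-backed Tensor subclass defined above.
x = env.to_xla(torch.randn(2, 3))
print(type(x), x.shape)

# While the environment's torch modes are active, supported torch ops are
# routed through Environment.dispatch to their JAX lowerings.
with env:
    y = torch.nn.functional.relu(x)

# Convert back to a regular CPU torch.Tensor (Tensor.torch -> j2t).
print(y.torch())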
@@ -0,0 +1,119 @@
1
+ # pylint: disable
2
+ import os
3
+ from typing import Any, Tuple
4
+
5
+ from jax.experimental import jax2tf
6
+ import tensorflow as tf
7
+ import torch
8
+ from torchax import export
9
+
10
+
11
+ def exported_program_to_tf_function(ep, enable_xla=True):
12
+ weights, jax_program = export.exported_program_to_jax(ep)
13
+ wrapped = lambda *args: jax_program(weights, (args,))
14
+ avals = export.extract_avals(ep)
15
+ input_signature = [
16
+ tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}")
17
+ for i, t in enumerate(avals)
18
+ ]
19
+ tf_f = tf.function(
20
+ jax2tf.convert(
21
+ wrapped,
22
+ with_gradient=False,
23
+ enable_xla=enable_xla,
24
+ ),
25
+ autograph=False,
26
+ input_signature=input_signature,
27
+ )
28
+ return tf_f
29
+
30
+
31
+ def exported_program_to_tf_module(ep: torch.export.ExportedProgram,
32
+ enable_xla=True) -> tf.Module:
33
+ tfm = tf.Module()
34
+ tfm.f = exported_program_to_tf_function(ep, enable_xla)
35
+ return tfm
36
+
37
+
38
+ def save_exported_program_as_tf_saved_model(
39
+ ep: torch.export.ExportedProgram,
40
+ saved_model_dir: os.PathLike,
41
+ serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
42
+ function_alias: str = "",
43
+ enable_xla=True,
44
+ ):
45
+ """This function will export and save a pytorch ExportedProgram to tf.saved_model format.
46
+
47
+ The resulting tf.saved_model can be used for inference with a tf.serving
48
+ model server,
49
+ or further converted to a tflite flatbuffer for on-device serving.
50
+
51
+ Args:
52
+ ep: torch.export.ExportedProgram - the exported program to save, typically
53
+ obtained from torch.export.export(torch_model, args), i.e.
54
+ torch_model(*args) must run
55
+ saved_model_dir: os.PathLike - location to an empty directory to store the
56
+ saved_model
57
+ serving_key: str - serving key tag, this is used by tf.serving to know
58
+ which function to run.
59
+ function_alias: str - passed through saved_model.save, used to tag a
60
+ function for inference converter or other tools.
61
+ """
62
+ tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla)
63
+ signatures = {
64
+ serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature)
65
+ }
66
+ save_options = tf.saved_model.SaveOptions(function_aliases={
67
+ function_alias: tfm.f,
68
+ })
69
+ tf.saved_model.save(
70
+ tfm,
71
+ saved_model_dir,
72
+ signatures=signatures,
73
+ options=save_options,
74
+ )
75
+
76
+
77
+ def save_torch_module_as_tf_saved_model(
78
+ torch_model: torch.nn.Module,
79
+ args: Tuple[Any],
80
+ saved_model_dir: os.PathLike,
81
+ serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
82
+ function_alias: str = "",
83
+ enable_xla=True,
84
+ ):
85
+ """This function will export and save a pytorch nn.Module to tf.saved_model format.
86
+
87
+ The resulting tf.saved_model can be used for inference with a tf.serving
88
+ model server,
89
+ or further converted to a tflite flatbuffer for on-device serving.
90
+
91
+ Args:
92
+ torch_model: torch.nn.Module - model to export and save
93
+ args: Tuple[Any] - a set of args to trace the model with, i.e.
94
+ torch_model(*args) must run
95
+ saved_model_dir: os.PathLike - location to an empty directory to store the
96
+ saved_model
97
+ serving_key: str - serving key tag, this is used by tf.serving to know
98
+ which function to run.
99
+ function_alias: str - passed through saved_model.save, used to tag a
100
+ function for inference converter or other tools.
101
+ """
102
+ ep = torch.export.export(torch_model, args)
103
+ save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key,
104
+ function_alias, enable_xla)
105
+
106
+
107
+ def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram):
108
+ tfm = exported_program_to_tf_module(ep)
109
+ tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature)
110
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
111
+ [tf_concrete_func], tfm)
112
+ tflite_model = converter.convert()
113
+ return tflite_model
114
+
115
+
116
+ def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module,
117
+ args: Tuple[Any]):
118
+ ep = torch.export.export(torch_model, args)
119
+ return exported_program_to_tflite_flatbuffer(ep)
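
A hedged end-to-end sketch of the saved-model and TFLite helpers defined above (illustrative only, not part of the package contents; TinyModel, the example inputs, the output paths, and the import location of the helpers are assumptions):

import torch
import torch.nn as nn

# Assumption: the helpers above are importable from their module in the
# torchax package, e.g.
# from torchax.tf_integration import (
#     save_torch_module_as_tf_saved_model, torch_module_to_tflite_flatbuffer)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().eval()
example_args = (torch.randn(1, 4),)

# Export via torch.export and save as a tf.saved_model for tf.serving.
save_torch_module_as_tf_saved_model(model, example_args, "/tmp/tiny_saved_model")

# Or convert straight to a TFLite flatbuffer for on-device serving.
flatbuffer = torch_module_to_tflite_flatbuffer(model, example_args)
with open("/tmp/tiny_model.tflite", "wb") as f:
    f.write(flatbuffer)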