ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None
 
     ff_up_proj: str = None
     ff_down_proj: str = None
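The two new name templates let a model definition map per-head query/key normalization weights from a source checkpoint into the converted attention module (they are consumed further down as `query_norm.weight` / `key_norm.weight`). A minimal sketch of how a model might populate them; the checkpoint key patterns below are hypothetical placeholders, not names taken from any particular model:

# Hypothetical example of wiring the new fields into a model's TensorNames.
# The key patterns are illustrative placeholders; a real model must use the
# keys actually present in its source state dict, and would also set the
# embedding / norm / lm_head names omitted here.
from ai_edge_torch.generative.utilities import loader as loading_utils

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    attn_query_norm="model.layers.{}.self_attn.q_norm",  # new field
    attn_key_norm="model.layers.{}.self_attn.k_norm",    # new field
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
)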
@@ -221,7 +223,8 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -230,7 +233,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
             f"{ff_up_proj_name}.bias"
         )
@@ -250,7 +253,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
           f"{ff_gate_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
         converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
             f"{ff_up_proj_name}.bias"
         )
@@ -289,6 +292,7 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     if self._names.attn_fused_qkv_proj:
       fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
@@ -300,13 +304,13 @@ class ModelLoader:
       v_name = self._names.attn_value_proj.format(idx)
       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
           self._fuse_qkv(
-              config,
+              attn_config,
               state.pop(f"{q_name}.weight"),
               state.pop(f"{k_name}.weight"),
               state.pop(f"{v_name}.weight"),
           )
       )
-    if config.attn_config.qkv_use_bias:
+    if attn_config.qkv_use_bias:
       if self._names.attn_fused_qkv_proj:
         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
             f"{fused_qkv_name}.bias"
@@ -314,18 +318,29 @@ class ModelLoader:
       else:
         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
             self._fuse_qkv(
-                config,
+                attn_config,
                 state.pop(f"{q_name}.bias"),
                 state.pop(f"{k_name}.bias"),
                 state.pop(f"{v_name}.bias"),
             )
         )
 
+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
     o_name = self._names.attn_output_proj.format(idx)
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )
@@ -360,18 +375,16 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = (
-          config.attn_config.num_heads // config.attn_config.num_query_groups
-      )
-      qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
-      ks = torch.split(k, config.attn_config.head_dim)
-      vs = torch.split(v, config.attn_config.head_dim)
+    if attn_config.qkv_fused_interleaved:
+      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+      qs = torch.split(q, attn_config.head_dim * q_per_kv)
+      ks = torch.split(k, attn_config.head_dim)
+      vs = torch.split(v, attn_config.head_dim)
       cycled = [t for group in zip(qs, ks, vs) for t in group]
       return torch.cat(cycled)
     else:
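The refactored `_fuse_qkv` now takes the per-block `AttentionConfig` directly, but the interleaving itself is unchanged: query heads are split into groups that share a key/value head, and the groups are concatenated in q/k/v round-robin order. A small standalone illustration of that split/zip/cat pattern with made-up sizes (plain PyTorch, not tied to any real checkpoint):

import torch

# Illustrative GQA sizes only: 4 query heads sharing 2 KV heads, head_dim = 8.
num_heads, num_query_groups, head_dim, model_dim = 4, 2, 8, 16
q_per_kv = num_heads // num_query_groups

q = torch.randn(num_heads * head_dim, model_dim)          # query weight
k = torch.randn(num_query_groups * head_dim, model_dim)   # key weight
v = torch.randn(num_query_groups * head_dim, model_dim)   # value weight

# Same pattern as _fuse_qkv with qkv_fused_interleaved=True: split along
# dim 0 into per-group chunks, then interleave [q_group, k_head, v_head].
qs = torch.split(q, head_dim * q_per_kv)
ks = torch.split(k, head_dim)
vs = torch.split(v, head_dim)
fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])

assert fused.shape == ((num_heads + 2 * num_query_groups) * head_dim, model_dim)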
ai_edge_torch/generative/utilities/t5_loader.py CHANGED
@@ -279,7 +279,8 @@ class ModelLoader:
     prefix = additional_prefix + f"transformer_blocks.{idx}"
     if names.ff_up_proj is None or names.ff_down_proj is None:
       return
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = names.ff_up_proj.format(idx)
       ff_down_proj_name = names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -288,7 +289,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
             f"{ff_up_proj_name}.bias"
         )
@@ -309,7 +310,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
           f"{ff_gate_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
         converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
             f"{ff_up_proj_name}.bias"
         )
@@ -337,20 +338,21 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
     # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
-          config,
+          attn_config,
           state.pop(f"{q_name}.weight"),
           state.pop(f"{k_name}.weight"),
           state.pop(f"{v_name}.weight"),
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
-            config,
+            attn_config,
             state.pop(f"{q_name}.bias"),
             state.pop(f"{k_name}.bias"),
             state.pop(f"{v_name}.bias"),
@@ -365,7 +367,7 @@ class ModelLoader:
       converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
           f"{v_name}.weight"
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
             f"{q_name}.bias"
         )
@@ -380,7 +382,7 @@ class ModelLoader:
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )
@@ -402,6 +404,7 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.cross_attn_query_proj.format(idx)
     k_name = names.cross_attn_key_proj.format(idx)
     v_name = names.cross_attn_value_proj.format(idx)
@@ -409,16 +412,16 @@ class ModelLoader:
     if fuse_attention:
       converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
           self._fuse_qkv(
-              config,
+              attn_config,
              state.pop(f"{q_name}.weight"),
              state.pop(f"{k_name}.weight"),
              state.pop(f"{v_name}.weight"),
           )
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
             self._fuse_qkv(
-                config,
+                attn_config,
                 state.pop(f"{q_name}.bias"),
                 state.pop(f"{k_name}.bias"),
                 state.pop(f"{v_name}.bias"),
@@ -434,7 +437,7 @@ class ModelLoader:
       converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
           state.pop(f"{v_name}.weight")
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
             state.pop(f"{q_name}.bias")
         )
@@ -449,7 +452,7 @@ class ModelLoader:
     converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )
@@ -496,16 +499,14 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    q_per_kv = (
-        config.attn_config.num_heads // config.attn_config.num_query_groups
-    )
-    qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
-    ks = torch.split(k, config.attn_config.head_dim)
-    vs = torch.split(v, config.attn_config.head_dim)
+    q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+    qs = torch.split(q, attn_config.head_dim * q_per_kv)
+    ks = torch.split(k, attn_config.head_dim)
+    vs = torch.split(v, attn_config.head_dim)
     cycled = [t for group in zip(qs, ks, vs) for t in group]
     return torch.cat(cycled)
ai_edge_torch/odml_torch/export.py CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)
 
 
+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )
 
+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():
 
     module = ir.Module.create()
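The new pre-lowering step only rewrites `aten.arange.start_step` nodes whose start/end bounds fit in int32, wrapping the original op so its constant output is cast before decomposition. A self-contained sketch of the same node rewrite on a toy exported program (the `Arange` module and the inline bound check are illustrative; the package's own implementation is `_convert_i64_to_i32` above):

import torch


class Arange(torch.nn.Module):
  """Toy model whose exported graph contains an int64 arange node."""

  def forward(self, x):
    return torch.arange(0, x.shape[0], 1) + x


ep = torch.export.export(Arange(), (torch.zeros(8, dtype=torch.int64),))

# Same node rewrite as _convert_i64_to_i32: wrap aten.arange.start_step so
# the constant it produces is cast to int32 when its bounds fit in 32 bits.
num_rewritten = 0
for node in ep.graph_module.graph.nodes:
  if node.target == torch.ops.aten.arange.start_step:
    start, end = node.args[:2]
    if -(2**31) <= start and end <= 2**31 - 1:
      op = node.target
      node.target = lambda *args, **kwargs: op(*args, **kwargs).type(
          torch.int32
      )
      num_rewritten += 1

print(f"rewrote {num_rewritten} arange node(s) to emit int32")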
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -16,6 +16,7 @@ from . import _basic
 from . import _batch_norm
 from . import _convolution
 from . import _jax_lowerings
+from . import _layer_norm
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)
 
   return stablehlo.divide(x, y)
+
+
+# Schema:
+#   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#       start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+#   - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
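For reference, `aten.slice_scatter` overwrites a (possibly strided) slice of `self` with `src` and leaves every other element untouched; the lowering above reproduces this by padding `src` out to the full shape (interior padding handles the stride) and then selecting between `self` and the padded `src`. A quick eager-mode example of the behavior the lowering has to match:

import torch

dest = torch.zeros(6, 4)
src = torch.ones(2, 4)

# Overwrite rows 1 and 3 (start=1, end=5, step=2) along dim 0 with src,
# i.e. the same result as dest[1:5:2] = src, but without mutating dest.
out = torch.slice_scatter(dest, src, dim=0, start=1, end=5, step=2)
print(out[:, 0])  # tensor([0., 1., 0., 1., 0., 0.])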
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
 lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
 lower_by_torch_xla2(torch.ops.aten.native_group_norm)
-lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
 lower_by_torch_xla2(torch.ops.aten.neg)
@@ -204,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
ai_edge_torch/odml_torch/lowerings/_layer_norm.py ADDED
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides lowering for coreaten to stablehlo for LayerNorm."""
+
+import math
+from typing import Optional
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
+#     Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+@registry.lower(torch.ops.aten.native_layer_norm)
+def _aten_native_layer_norm(
+    lctx,
+    data: ir.Value,
+    normalized_shape: list[int],
+    weight: Optional[ir.Value],
+    bias: Optional[ir.Value],
+    eps: float,
+):
+  data_type: ir.RankedTensorType = data.type
+  unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
+  dest_shape = [
+      1,
+      unnormalized_count,
+      math.prod(normalized_shape),
+  ]
+  dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
+
+  reshaped_data = stablehlo.reshape(dest_type, data)
+
+  one = utils.splat(1, data_type.element_type, [unnormalized_count])
+  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
+  output, mean, var = stablehlo.batch_norm_training(
+      reshaped_data, one, zero, eps, 1
+  )
+  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
+  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
+
+  stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
+      normalized_shape
+  )
+  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
+  mean = stablehlo.reshape(stats_type, mean)
+  rstd = stablehlo.reshape(stats_type, rstd)
+
+  output = stablehlo.reshape(data_type, output)
+
+  data_rank = len(data_type.shape)
+  normalized_rank = len(normalized_shape)
+  if weight is not None:
+    weight = stablehlo.broadcast_in_dim(
+        data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.multiply(weight, output)
+  if bias is not None:
+    bias = stablehlo.broadcast_in_dim(
+        data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.add(bias, output)
+
+  return output, mean, rstd
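As a reference for what this lowering must produce, `aten.native_layer_norm` returns the normalized (and optionally affine-transformed) output together with the per-slice mean and reciprocal standard deviation over the normalized dimensions. A small eager-mode check of those three outputs in plain PyTorch, independent of the StableHLO path:

import torch

x = torch.randn(2, 4, 8)
normalized_shape = [8]
weight = torch.randn(8)
bias = torch.randn(8)
eps = 1e-5

out, mean, rstd = torch.ops.aten.native_layer_norm(
    x, normalized_shape, weight, bias, eps
)

# The same quantities computed by hand: statistics over the normalized
# (trailing) dimensions, then the affine transform.
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
ref_rstd = torch.rsqrt(var + eps)
ref_out = (x - mu) * ref_rstd * weight + bias

torch.testing.assert_close(out, ref_out)
torch.testing.assert_close(mean, mu)
torch.testing.assert_close(rstd, ref_rstd)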
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240910"
+__version__ = "0.3.0.dev20240914"
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240910
+Version: 0.3.0.dev20240914
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI