ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in one of the supported public registries. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/loader.py
CHANGED
@@ -101,6 +101,8 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None
 
     ff_up_proj: str = None
     ff_down_proj: str = None
@@ -221,7 +223,8 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
-
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -230,7 +233,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if
+      if ff_config.use_bias:
         converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
             f"{ff_up_proj_name}.bias"
         )
@@ -250,7 +253,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
           f"{ff_gate_proj_name}.weight"
       )
-      if
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )
@@ -289,6 +292,7 @@ class ModelLoader:
      converted_state: Dict[str, torch.Tensor],
  ):
    prefix = f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
    if self._names.attn_fused_qkv_proj:
      fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
@@ -300,13 +304,13 @@ class ModelLoader:
      v_name = self._names.attn_value_proj.format(idx)
      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
          self._fuse_qkv(
-
+              attn_config,
              state.pop(f"{q_name}.weight"),
              state.pop(f"{k_name}.weight"),
              state.pop(f"{v_name}.weight"),
          )
      )
-    if
+    if attn_config.qkv_use_bias:
      if self._names.attn_fused_qkv_proj:
        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
            f"{fused_qkv_name}.bias"
@@ -314,18 +318,29 @@ class ModelLoader:
      else:
        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
            self._fuse_qkv(
-
+                attn_config,
                state.pop(f"{q_name}.bias"),
                state.pop(f"{k_name}.bias"),
                state.pop(f"{v_name}.bias"),
            )
        )
 
+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
    o_name = self._names.attn_output_proj.format(idx)
    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
        state.pop(f"{o_name}.weight")
    )
-    if
+    if attn_config.output_proj_use_bias:
      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
          state.pop(f"{o_name}.bias")
      )
@@ -360,18 +375,16 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    if
-    q_per_kv =
-
-    )
-
-    ks = torch.split(k, config.attn_config.head_dim)
-    vs = torch.split(v, config.attn_config.head_dim)
+    if attn_config.qkv_fused_interleaved:
+      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+      qs = torch.split(q, attn_config.head_dim * q_per_kv)
+      ks = torch.split(k, attn_config.head_dim)
+      vs = torch.split(v, attn_config.head_dim)
       cycled = [t for group in zip(qs, ks, vs) for t in group]
       return torch.cat(cycled)
     else:
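With this change the loader reads per-block settings through `config.block_config(idx)` rather than a single model-wide config, and it can optionally map query/key normalization weights. Below is a minimal sketch of how a model definition might fill in the new `TensorNames` fields; the checkpoint tensor name patterns are hypothetical, not taken from this release.

from ai_edge_torch.generative.utilities import loader as loading_utils

# Hypothetical checkpoint layout; real models supply their own name patterns.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    # New in 0.3.0.dev20240914: optional query/key norm weights per layer.
    attn_query_norm="model.layers.{}.self_attn.q_norm",
    attn_key_norm="model.layers.{}.self_attn.k_norm",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
)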
ai_edge_torch/generative/utilities/t5_loader.py
CHANGED
@@ -279,7 +279,8 @@ class ModelLoader:
     prefix = additional_prefix + f"transformer_blocks.{idx}"
     if names.ff_up_proj is None or names.ff_down_proj is None:
       return
-
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = names.ff_up_proj.format(idx)
       ff_down_proj_name = names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -288,7 +289,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )
@@ -309,7 +310,7 @@ class ModelLoader:
      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
          f"{ff_gate_proj_name}.weight"
      )
-      if
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )
@@ -337,20 +338,21 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
     # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
-
+          attn_config,
           state.pop(f"{q_name}.weight"),
           state.pop(f"{k_name}.weight"),
           state.pop(f"{v_name}.weight"),
       )
-      if
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
-
+            attn_config,
            state.pop(f"{q_name}.bias"),
            state.pop(f"{k_name}.bias"),
            state.pop(f"{v_name}.bias"),
@@ -365,7 +367,7 @@ class ModelLoader:
      converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
      )
-      if
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )
@@ -380,7 +382,7 @@ class ModelLoader:
    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
        state.pop(f"{o_name}.weight")
    )
-    if
+    if attn_config.output_proj_use_bias:
      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
          state.pop(f"{o_name}.bias")
      )
@@ -402,6 +404,7 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.cross_attn_query_proj.format(idx)
     k_name = names.cross_attn_key_proj.format(idx)
     v_name = names.cross_attn_value_proj.format(idx)
@@ -409,16 +412,16 @@ class ModelLoader:
     if fuse_attention:
       converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
           self._fuse_qkv(
-
+              attn_config,
              state.pop(f"{q_name}.weight"),
              state.pop(f"{k_name}.weight"),
              state.pop(f"{v_name}.weight"),
          )
      )
-      if
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
            self._fuse_qkv(
-
+                attn_config,
                state.pop(f"{q_name}.bias"),
                state.pop(f"{k_name}.bias"),
                state.pop(f"{v_name}.bias"),
@@ -434,7 +437,7 @@ class ModelLoader:
      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
          state.pop(f"{v_name}.weight")
      )
-      if
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
            state.pop(f"{q_name}.bias")
        )
@@ -449,7 +452,7 @@ class ModelLoader:
    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
        state.pop(f"{o_name}.weight")
    )
-    if
+    if attn_config.output_proj_use_bias:
      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
          state.pop(f"{o_name}.bias")
      )
@@ -496,16 +499,14 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    q_per_kv =
-
-    )
-
-    ks = torch.split(k, config.attn_config.head_dim)
-    vs = torch.split(v, config.attn_config.head_dim)
+    q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+    qs = torch.split(q, attn_config.head_dim * q_per_kv)
+    ks = torch.split(k, attn_config.head_dim)
+    vs = torch.split(v, attn_config.head_dim)
     cycled = [t for group in zip(qs, ks, vs) for t in group]
     return torch.cat(cycled)
ai_edge_torch/odml_torch/export.py
CHANGED
@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)
 
 
+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )
 
+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():
 
     module = ir.Module.create()
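The new `_convert_i64_to_i32` pass only rewrites `aten.arange.start_step` nodes whose outputs are int64 constants with int32-representable bounds. A small eager-mode illustration of the dtype behavior it targets (plain PyTorch, not code from this package):

import torch

# torch.arange with integer arguments yields int64 by default, which the
# exported graph would otherwise carry through to the runtime.
positions = torch.arange(0, 8)
assert positions.dtype == torch.int64

# The rewritten FX node wraps the original op and casts its result, roughly:
positions_i32 = torch.arange(0, 8).type(torch.int32)
assert positions_i32.dtype == torch.int32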
ai_edge_torch/odml_torch/lowerings/_basic.py
CHANGED
@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)
 
   return stablehlo.divide(x, y)
+
+
+# Schema:
+# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#   start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+# - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
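The new lowering has to reproduce `torch.slice_scatter` semantics with StableHLO pad and select ops. A small reference example of those semantics in eager PyTorch (illustrative only):

import torch

# slice_scatter writes `src` into the strided slice of the input selected by
# (dim, start, end, step) and leaves every other element unchanged.
x = torch.zeros(8)
src = torch.ones(2)
out = torch.slice_scatter(x, src, dim=0, start=2, end=6, step=2)
# Indices 2 and 4 take values from src; the rest keep x's values.
assert out.tolist() == [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]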
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
CHANGED
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
 lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
 lower_by_torch_xla2(torch.ops.aten.native_group_norm)
-lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
 lower_by_torch_xla2(torch.ops.aten.neg)
@@ -204,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
ai_edge_torch/odml_torch/lowerings/_layer_norm.py
ADDED
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides lowering for coreaten to stablehlo for LayerNorm."""
+
+import math
+from typing import Optional
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
+#   Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+@registry.lower(torch.ops.aten.native_layer_norm)
+def _aten_native_layer_norm(
+    lctx,
+    data: ir.Value,
+    normalized_shape: list[int],
+    weight: Optional[ir.Value],
+    bias: Optional[ir.Value],
+    eps: float,
+):
+  data_type: ir.RankedTensorType = data.type
+  unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
+  dest_shape = [
+      1,
+      unnormalized_count,
+      math.prod(normalized_shape),
+  ]
+  dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
+
+  reshaped_data = stablehlo.reshape(dest_type, data)
+
+  one = utils.splat(1, data_type.element_type, [unnormalized_count])
+  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
+  output, mean, var = stablehlo.batch_norm_training(
+      reshaped_data, one, zero, eps, 1
+  )
+  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
+  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
+
+  stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
+      normalized_shape
+  )
+  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
+  mean = stablehlo.reshape(stats_type, mean)
+  rstd = stablehlo.reshape(stats_type, rstd)
+
+  output = stablehlo.reshape(data_type, output)
+
+  data_rank = len(data_type.shape)
+  normalized_rank = len(normalized_shape)
+  if weight is not None:
+    weight = stablehlo.broadcast_in_dim(
+        data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.multiply(weight, output)
+  if bias is not None:
+    bias = stablehlo.broadcast_in_dim(
+        data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.add(bias, output)
+
+  return output, mean, rstd
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240910
+Version: 0.3.0.dev20240914
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI