onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +35 -5
  4. onnx_diagnostic/export/control_flow.py +511 -0
  5. onnx_diagnostic/export/control_flow_research.py +135 -0
  6. onnx_diagnostic/ext_test_case.py +33 -9
  7. onnx_diagnostic/helpers/cache_helper.py +217 -203
  8. onnx_diagnostic/helpers/helper.py +6 -2
  9. onnx_diagnostic/helpers/log_helper.py +39 -5
  10. onnx_diagnostic/helpers/memory_peak.py +2 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
  12. onnx_diagnostic/helpers/onnx_helper.py +13 -16
  13. onnx_diagnostic/helpers/rt_helper.py +579 -15
  14. onnx_diagnostic/helpers/torch_helper.py +5 -0
  15. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  16. onnx_diagnostic/tasks/text2text_generation.py +1 -0
  17. onnx_diagnostic/tasks/text_generation.py +84 -54
  18. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
  23. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  25. onnx_diagnostic/torch_models/validate.py +620 -213
  26. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
 import json
 import os
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
@@ -10,13 +11,9 @@ from .ort_session import InferenceSessionForTorch


 def name_type_to_onnx_dtype(name: str) -> int:
-    if name == "tensor(int64)":
-        return onnx.TensorProto.INT64
-    if name == "tensor(float)":
-        return onnx.TensorProto.FLOAT
-    if name == "tensor(float16)":
-        return onnx.TensorProto.FLOAT16
-    raise AssertionError(f"Unexpected value {name!r}")
+    assert name.startswith("tensor(") and name.endswith(")"), f"Invalid value name={name!r}"
+    look = name[7:-1]
+    return getattr(onnx.TensorProto, look.upper())


 def make_feeds(
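
The rewritten helper resolves any ``tensor(<dtype>)`` string by looking the dtype up on ``onnx.TensorProto``, so types absent from the old hard-coded list are now handled. A minimal sketch of the expected behavior, assuming the function is imported from ``onnx_diagnostic.helpers.rt_helper``:

    import onnx
    from onnx_diagnostic.helpers.rt_helper import name_type_to_onnx_dtype

    # Previously supported names still map to the same enum values.
    assert name_type_to_onnx_dtype("tensor(float)") == onnx.TensorProto.FLOAT
    # Names outside the old hard-coded list now resolve through getattr.
    assert name_type_to_onnx_dtype("tensor(bfloat16)") == onnx.TensorProto.BFLOAT16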
@@ -153,7 +150,7 @@ def make_empty_cache(
 def generate_and_validate(
     model,
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens: int = 100,
     session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
     atol: float = 0.1,
@@ -262,10 +259,10 @@ def generate_and_validate(
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
-    eos_token_id: int,
+    eos_token_id: int = 2,
     max_new_tokens=100,
     return_session: bool = False,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch, Dict[str, Any]]]:
     """
     Implements a simple method ``generate`` for an ONNX model.
     The function does not expect any ``position_ids`` as input.
@@ -277,7 +274,7 @@ def onnx_generate(
     :param return_session: returns the instance of class
         :class:`InferenceSessionForTorch
        <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
-        created if necessary
+        created if necessary; the function also returns the feeds for the next iteration
     :return: input tokens concatenated with new tokens

     .. runpython::
@@ -353,12 +350,19 @@ def onnx_generate(
     input_shapes = session.input_shapes
     input_names = session.input_names
     input_types = session.input_types
+    has_position_ids = "position_ids" in session.input_names

     assert (
         len(input_names) > 2
         and input_names[:2] == ["input_ids", "attention_mask"]
-        and input_names[2].startswith("past_key_values")
-    ), f"Only text generation is supported but input_names == {input_names}"
+        and input_names[3 if has_position_ids else 2].startswith("past_key_values")
+    ), (
+        f"Only text generation is supported but input_names == {input_names}, "
+        f"has_position_ids={has_position_ids}"
+    )
+    assert (
+        not has_position_ids or input_names[2] == "position_ids"
+    ), f"position_ids must be the third input but input_names={input_names}"

     # First call: prefill
     feeds = dict(
@@ -370,6 +374,10 @@ def onnx_generate(
             input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
         ),
     )
+    if has_position_ids:
+        feeds["position_ids"] = torch.unsqueeze(
+            torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
+        )

     outputs = session.run(None, feeds)

@@ -393,11 +401,21 @@ def onnx_generate(
                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
             ),
         )
-        feeds.update(dict(zip(input_names[2:], outputs[1:])))
+        if has_position_ids:
+            feeds["position_ids"] = torch.unsqueeze(
+                torch.arange(
+                    input_ids.shape[1],
+                    input_ids.shape[1] + 1,
+                    dtype=torch.int64,
+                    device=input_ids.device,
+                ),
+                0,
+            )
+        feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
         outputs = session.run(None, feeds)

     if return_session:
-        return input_ids, session
+        return input_ids, session, feeds
     return input_ids

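With these changes, ``onnx_generate`` feeds ``position_ids`` automatically when the model declares it as its third input, and ``return_session=True`` now yields a third element, the feeds prepared for the next decoding step. A hedged sketch of the new calling convention (the model path and token ids below are placeholders):

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    input_ids = torch.tensor([[1, 25, 412]], dtype=torch.int64)  # hypothetical prompt
    # The session and the last feeds are returned alongside the generated tokens.
    tokens, session, feeds = onnx_generate(
        "model.onnx", input_ids, eos_token_id=2, max_new_tokens=8, return_session=True
    )
    # 'feeds' can be inspected or reused to resume decoding with the same session.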
@@ -474,3 +492,549 @@ def onnx_generate_with_genai(
     if return_session:
         return input_ids, session
     return input_ids
+
+
+_mapping_types = {
+    "float": "F",
+    "double": "D",
+    "float16": "H",
+    "uint8": "U8",
+    "uint16": "U16",
+    "uint32": "U32",
+    "uint64": "U64",
+    "int8": "I8",
+    "int16": "I16",
+    "int32": "I32",
+    "int64": "I64",
+}
+
+
+def _process_shape(shape_df):
+    if isinstance(shape_df, float) or len(shape_df) == 0:
+        return ""
+    values = []
+    for val in shape_df:
+        if len(val) != 1:
+            raise ValueError(f"Unable to process shape {val!r} from {values!r}.")
+        for _k, _v in val.items():
+            k, v = _k, _v
+            break
+        if v:
+            vs = "x".join(map(str, v))
+            values.append(f"{_mapping_types.get(k,k)}[{vs}]")
+        else:
+            values.append(f"{_mapping_types.get(k,k)}")
+    return "+".join(values)
+
+
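``_process_shape`` compresses onnxruntime's ``args_input_type_shape`` entries, a list of one-key dictionaries, into a compact string using the letter codes of ``_mapping_types``. An illustrative call, assuming the private helper is invoked directly:

    from onnx_diagnostic.helpers.rt_helper import _process_shape

    # [{"float": [1, 8, 128]}, {"int64": [1]}] describes two typed inputs.
    print(_process_shape([{"float": [1, 8, 128]}, {"int64": [1]}]))
    # -> 'F[1x8x128]+I64[1]'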
+def post_process_df_profile(
+    df: "pandas.DataFrame",  # noqa: F821
+    first_it_out: bool = False,
+    agg: bool = False,
+    agg_op_name: bool = True,
+    with_shape: bool = False,
+) -> "pandas.DataFrame":  # noqa: F821
+    """
+    Post-processes a dataframe obtained after profiling onnxruntime.
+    It adds a column with a more explicit event name and a column
+    with the iteration number.
+
+    :param agg: aggregate the result
+    :param first_it_out: leave the first iteration
+        out of the aggregation
+    :param agg_op_name: aggregate on operator name or operator index
+    :param with_shape: keep the shape to aggregate
+    :return: DataFrame
+    """
+    events = {"kernel_time", "fence_after", "fence_before"}
+
+    def sep_event(s):
+        for e in events:
+            if s.endswith(e):
+                return e
+        return s
+
+    df = df.copy()
+    df["event_name"] = df["name"].apply(sep_event)
+    df["iteration"] = -1
+    current = -1
+    for i in range(df.shape[0]):
+        if df.loc[i, "name"] == "SequentialExecutor::Execute":
+            current += 1
+        df.loc[i, "iteration"] = current
+
+    if not agg:
+        if with_shape:
+            df["args_input_type_shape"] = df["args_input_type_shape"].apply(_process_shape)
+            df["args_output_type_shape"] = df["args_output_type_shape"].apply(_process_shape)
+        else:
+            df = df.drop(["args_input_type_shape", "args_output_type_shape"], axis=1)
+        if first_it_out:
+            df["it==0"] = (df["iteration"] <= 0).astype(int)
+        return df
+
+    agg_cols = ["cat", "args_node_index", "args_op_name", "args_provider", "event_name"]
+    if with_shape:
+        agg_cols.append("args_input_type_shape")
+        df["args_input_type_shape"] = df["args_input_type_shape"].apply(_process_shape)
+        df["args_output_type_shape"] = df["args_output_type_shape"].apply(_process_shape)
+    else:
+        df = df.drop(["args_input_type_shape", "args_output_type_shape"], axis=1)
+
+    if first_it_out:
+        df["it==0"] = (df["iteration"] <= 0).astype(int)
+        agg_cols.insert(0, "it==0")
+    if agg_op_name:
+        del agg_cols[agg_cols.index("args_node_index")]
+    for c in agg_cols:
+        df[c] = df[c].fillna("")
+    df["dur"] = df["dur"].fillna(0)
+    agg = df[[*agg_cols, "dur"]].groupby(agg_cols).sum()
+    return agg
+
+
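``post_process_df_profile`` is normally reached through ``js_profile_to_dataframe`` below, but it can be applied to any dataframe exposing onnxruntime's profiling columns. A minimal sketch with hypothetical rows mimicking a flattened profile:

    import pandas
    from onnx_diagnostic.helpers.rt_helper import post_process_df_profile

    rows = [  # hypothetical rows; real ones come from onnxruntime's json profile
        dict(name="SequentialExecutor::Execute", cat="Session", dur=30,
             args_node_index="", args_op_name="", args_provider="",
             args_input_type_shape=float("nan"), args_output_type_shape=float("nan")),
        dict(name="Add_kernel_time", cat="Node", dur=10, args_node_index="0",
             args_op_name="Add", args_provider="CPUExecutionProvider",
             args_input_type_shape=[{"float": [2, 3]}],
             args_output_type_shape=[{"float": [2, 3]}]),
    ]
    # With agg=True, durations are summed over (cat, args_op_name, args_provider, event_name).
    print(post_process_df_profile(pandas.DataFrame(rows), agg=True))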
+def js_profile_to_dataframe(
+    filename: str,
+    as_df: bool = True,
+    first_it_out: bool = False,
+    agg: bool = False,
+    agg_op_name: bool = False,
+    with_shape: bool = False,
+) -> Union[List, "pandas.DataFrame"]:  # noqa: F821
+    """
+    Loads a profile produced by onnxruntime and converts it
+    into a dataframe or a list of rows.
+
+    :param filename: filename holding the profiling stored in json format
+    :param as_df: returns a DataFrame if True, the raw list of rows otherwise
+    :param first_it_out: if aggregated, leaves the first iteration out
+    :param agg: aggregate by event
+    :param agg_op_name: aggregate on operator name or operator index
+    :param with_shape: keep the shape before aggregating
+    :return: DataFrame or list of dictionaries
+    """
+    with open(filename, "r") as f:
+        content = f.read()
+    js = json.loads(content)
+
+    suffixes = ["_kernel_time", "_fence_before", "_fence_after"]
+    rows = []
+    for row in js:
+        if "args" in row and isinstance(row["args"], dict):
+            for k, v in row["args"].items():
+                row[f"args_{k}"] = v
+            del row["args"]
+        name = row["name"]
+        for suf in suffixes:
+            if name.endswith(suf):
+                changed = name[: -len(suf)]
+                row["op_name"] = changed
+                break
+        rows.append(row)
+    if as_df:
+        import pandas
+
+        return post_process_df_profile(
+            pandas.DataFrame(rows),
+            first_it_out=first_it_out,
+            agg=agg,
+            agg_op_name=agg_op_name,
+            with_shape=with_shape,
+        )
+    return rows
+
+
+def _preprocess_graph1(df):
+    df = df.copy()
+    df["args_provider"] = df["args_provider"].apply(
+        lambda s: s.replace("ExecutionProvider", "") if isinstance(s, str) else s
+    )
+    agg_cols = ["dur", "args_op_name", "args_provider"]
+    for c in ["it==0", "args_input_type_shape"]:
+        if c in df.columns:
+            agg_cols.append(c)
+    if "it==0" in df.columns:
+        vs = ["t>=1", "t=0"]
+        df["it==0"] = df["it==0"].apply(lambda v: vs[v])
+    gr_dur = df[agg_cols].groupby(agg_cols[1:]).sum().sort_values("dur")
+    gr_n = df[agg_cols].groupby(agg_cols[1:]).count()
+    gr_n = gr_n.loc[gr_dur.index, :]
+    gr_n.columns = ["count"]
+    gr = gr_dur.merge(gr_n, left_index=True, right_index=True, how="outer")
+    gr["ratio"] = gr["dur"] / gr["dur"].sum()
+    return gr_dur, gr_n, gr
+
+
+def _preprocess_graph2(df):
+    df = df.reset_index(drop=False).copy()
+    df["args_node_index"] = df["args_node_index"].apply(
+        lambda i: int(i) if i not in {None, ""} else -1
+    )
+    df["args_provider"] = df["args_provider"].apply(
+        lambda s: s.replace("ExecutionProvider", "") if isinstance(s, str) else s
+    )
+    df = df[(df["cat"] == "Node") & (df["event_name"] == "kernel_time")]
+    agg_cols = ["dur", "args_node_index", "args_op_name", "args_provider"]
+    for c in ["it==0", "args_input_type_shape"]:
+        if c in df.columns:
+            agg_cols.append(c)
+    if "it==0" in df.columns:
+        vs = ["t>=1", "t=0"]
+        df["it==0"] = df["it==0"].apply(lambda v: vs[v])
+    df = df[agg_cols].groupby(agg_cols[1:]).sum()
+    df = df.sort_index(ascending=False)
+    df["ratio"] = df["dur"] / df["dur"].sum()
+    return df
+
+
+def plot_ort_profile(
+    df: "pandas.DataFrame",  # noqa: F821
+    ax0: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
+    ax1: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
+    title: Optional[str] = None,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Plots the time spent in computation based on a dataframe
+    produced by function :func:`js_profile_to_dataframe`.
+
+    :param df: dataframe
+    :param ax0: first axis to draw time
+    :param ax1: second axis to draw occurrences
+    :param title: graph title
+    :return: the graph
+
+    .. plot::
+        :include-source:
+
+        import numpy as np
+        from onnx import TensorProto
+        import onnx.helper as oh
+        from onnx.checker import check_model
+        from onnx.numpy_helper import from_array
+        import matplotlib.pyplot as plt
+        from onnxruntime import InferenceSession, SessionOptions
+        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile
+
+
+        def get_model():
+            model_def0 = oh.make_model(
+                oh.make_graph(
+                    [
+                        oh.make_node("Add", ["X", "init1"], ["X1"]),
+                        oh.make_node("Abs", ["X"], ["X2"]),
+                        oh.make_node("Add", ["X", "init3"], ["inter"]),
+                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
+                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
+                    ],
+                    "test",
+                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
+                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
+                    [
+                        from_array(np.array([1], dtype=np.float32), name="init1"),
+                        from_array(np.array([3], dtype=np.float32), name="init3"),
+                    ],
+                ),
+                opset_imports=[oh.make_opsetid("", 18)],
+                ir_version=9,
+            )
+            check_model(model_def0)
+            return model_def0
+
+
+        sess_options = SessionOptions()
+        sess_options.enable_profiling = True
+        sess = InferenceSession(
+            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+        )
+        for _ in range(11):
+            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
+        prof = sess.end_profiling()
+
+        df = js_profile_to_dataframe(prof, first_it_out=True)
+        print(df.head())
+
+        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+        plot_ort_profile(df, ax[0], ax[1], "test_title")
+        fig.tight_layout()
+
+    With ``agg=True``:
+
+    .. plot::
+        :include-source:
+
+        import numpy as np
+        from onnx import TensorProto
+        import onnx.helper as oh
+        from onnx.checker import check_model
+        from onnx.numpy_helper import from_array
+        import matplotlib.pyplot as plt
+        from onnxruntime import InferenceSession, SessionOptions
+        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile
+
+
+        def get_model():
+            model_def0 = oh.make_model(
+                oh.make_graph(
+                    [
+                        oh.make_node("Add", ["X", "init1"], ["X1"]),
+                        oh.make_node("Abs", ["X"], ["X2"]),
+                        oh.make_node("Add", ["X", "init3"], ["inter"]),
+                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
+                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
+                    ],
+                    "test",
+                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
+                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
+                    [
+                        from_array(np.array([1], dtype=np.float32), name="init1"),
+                        from_array(np.array([3], dtype=np.float32), name="init3"),
+                    ],
+                ),
+                opset_imports=[oh.make_opsetid("", 18)],
+                ir_version=9,
+            )
+            check_model(model_def0)
+            return model_def0
+
+
+        sess_options = SessionOptions()
+        sess_options.enable_profiling = True
+        sess = InferenceSession(
+            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+        )
+        for _ in range(11):
+            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
+        prof = sess.end_profiling()
+
+        df = js_profile_to_dataframe(prof, first_it_out=True, agg=True)
+        print(df.head())
+
+        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+        plot_ort_profile(df, ax[0], ax[1], "test_title")
+        fig.tight_layout()
+    """
+    fontsize = 10
+    if ax0 is None:
+        import matplotlib.pyplot as plt
+
+        ax0 = plt.gca()
+
+    if "args_provider" in df.columns:
+        # Aggregation by operator
+        gr_dur, gr_n, _ = _preprocess_graph1(df)
+        gr_dur.plot.barh(ax=ax0)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
+            ax0.get_yaxis().set_label_text("")
+            ax0.set_yticklabels(
+                ax0.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
+            )
+        if title is not None:
+            ax0.set_title(title)
+        if ax1 is not None:
+            gr_n.plot.barh(ax=ax1)
+            ax1.set_title("n occurrences")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=fontsize)
+                ax1.get_yaxis().set_label_text("")
+                ax1.set_yticklabels(
+                    ax1.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
+                )
+        return ax0
+
+    df = _preprocess_graph2(df)
+    df[["dur"]].plot.barh(ax=ax0)
+    if title is not None:
+        ax0.set_title(title)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
+        ax0.get_yaxis().set_label_text("")
+        ax0.set_yticklabels(ax0.get_yticklabels(), fontsize=fontsize)
+    return ax0
+
+
+def plot_ort_profile_timeline(
+    df: "pandas.DataFrame",  # noqa: F821
+    ax: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
+    iteration: int = -2,
+    title: Optional[str] = None,
+    quantile: float = 0.5,
+    fontsize: int = 12,
+) -> "matplotlib.axes.Axes":  # noqa: F821
+    """
+    Creates a timeline based on a dataframe
+    produced by function :func:`js_profile_to_dataframe`.
+
+    :param df: dataframe
+    :param ax: first axis to draw time
+    :param iteration: iteration to plot, negative value to start from the end
+    :param title: graph title
+    :param quantile: operators whose duration falls below this quantile
+        are drawn in a different color
+    :param fontsize: font size
+    :return: the graph
+
+    .. plot::
+        :include-source:
+
+        import numpy as np
+        from onnx import TensorProto
+        import onnx.helper as oh
+        from onnx.checker import check_model
+        from onnx.numpy_helper import from_array
+        import matplotlib.pyplot as plt
+        from onnxruntime import InferenceSession, SessionOptions
+        from onnx_diagnostic.helpers.rt_helper import (
+            js_profile_to_dataframe,
+            plot_ort_profile_timeline,
+        )
+
+
+        def get_model():
+            model_def0 = oh.make_model(
+                oh.make_graph(
+                    [
+                        oh.make_node("Add", ["X", "init1"], ["X1"]),
+                        oh.make_node("Abs", ["X"], ["X2"]),
+                        oh.make_node("Add", ["X", "init3"], ["inter"]),
+                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
+                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
+                    ],
+                    "test",
+                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
+                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
+                    [
+                        from_array(np.array([1], dtype=np.float32), name="init1"),
+                        from_array(np.array([3], dtype=np.float32), name="init3"),
+                    ],
+                ),
+                opset_imports=[oh.make_opsetid("", 18)],
+                ir_version=9,
+            )
+            check_model(model_def0)
+            return model_def0
+
+
+        sess_options = SessionOptions()
+        sess_options.enable_profiling = True
+        sess = InferenceSession(
+            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+        )
+        for _ in range(11):
+            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
+        prof = sess.end_profiling()
+
+        df = js_profile_to_dataframe(prof, first_it_out=True)
+        print(df.head())
+
+        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
+        plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
+        fig.tight_layout()
+    """
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        ax = plt.gca()
+
+    df = df.copy()
+    df["iteration"] = df["iteration"].astype(int)
+    iterations = set(df["iteration"])
+    n_iter = iteration if iteration >= 0 else max(iterations) + 1 + iteration
+    dfi = df[df["iteration"] == n_iter]
+    assert dfi.shape[0] > 0, f"Iteration {iteration} cannot be found in {iterations}."
+
+    if "fence_before" in set(dfi["event_name"]):
+        started = {}
+        data = []
+        for irow in dfi.iterrows():
+            assert isinstance(irow, tuple), f"pandas has changed its api, type is {type(irow)}"
+            assert len(irow) == 2, f"pandas has changed its api, row is {irow}"
+            row = irow[1]
+            it = row["iteration"]
+            op_type = row["args_op_name"]
+            op_name = row["op_name"]
+            event_name = row["event_name"]
+            provider = row["args_provider"]
+            ts = float(row["ts"])
+            dur = float(row["dur"])
+            if event_name == "fence_before":
+                started[op_type, op_name, it] = dict(
+                    op_name=op_name, op_type=op_type, begin=ts
+                )
+            elif event_name == "kernel_time":
+                obs = started[op_type, op_name, it]
+                obs["duration"] = dur
+                obs["begin_kernel"] = ts
+                obs["provider"] = provider
+            elif event_name == "fence_after":
+                obs = started[op_type, op_name, it]
+                obs["end"] = ts
+                data.append(obs)
+                del started[op_type, op_name, it]
+            else:
+                assert event_name in {
+                    "SequentialExecutor::Execute",
+                    "model_run",
+                }, f"Unexpected event_name={event_name!r}, row={row}"
+    else:
+        # New format
+        data = []
+        for irow in dfi.iterrows():
+            row = irow[1]
+            if row["event_name"] != "kernel_time":
+                continue
+            obs = dict(
+                duration=float(row["dur"]),
+                op_name=row["op_name"],
+                op_type=row["args_op_name"],
+                provider=row["args_provider"],
+                begin=float(row["ts"]),
+                end=float(row["ts"]) + float(row["dur"]),
+                begin_kernel=float(row["ts"]),
+            )
+            data.append(obs)
+
+    # durations
+    data_dur = list(sorted(d["duration"] for d in data))
+    threshold = data_dur[int(quantile * len(data_dur))]
+    origin = dfi["ts"].min()
+
+    colors = ["blue", "green", "red", "orange"]
+
+    import matplotlib.patches as mpatches
+
+    cs = [0, 0]
+    for i, obs in enumerate(data):
+        dur = obs["duration"]
+        cat = int(dur >= threshold)
+
+        # color
+        color = colors[cat * 2 + cs[cat] % 2]
+        cs[cat] += 1
+
+        # rectangle
+        t1 = obs["begin"] - origin
+        t2 = obs["end"] - origin
+        shape = mpatches.Rectangle((0, t1), 1, t2 - t1, ec="none", color=color)
+        ax.add_artist(shape)
+        tk1 = obs["begin_kernel"] - origin
+        tk2 = (obs["begin_kernel"] + obs["duration"]) - origin
+        ax.plot([0, 1], [tk1, tk1], "b--")
+        ax.plot([0, 1], [tk2, tk2], "b--")
+        if i == 0:
+            ax.plot([0, 2], [tk1, tk1], "b")
+        elif i == len(data) - 1:
+            ax.plot([0, 2], [tk2, tk2], "b")
+
+        # text
+        y = (tk1 + tk2) / 2
+        text = obs["op_type"]
+        prov = obs["provider"].replace("ExecutionProvider", "")
+        name = obs["op_name"]
+        if len(name) >= 10:
+            name = name[:5] + "..." + name[5:]
+        ax.text(1, y, f"{i}:{prov}:{text}-{name}", fontsize=fontsize, va="center")
+
+    ax.invert_yaxis()
+    return ax
@@ -450,6 +450,11 @@ def fake_torchdynamo_exporting():
     """
     memorize = torch.compiler._is_exporting_flag
     torch.compiler._is_exporting_flag = True
+    assert torch.compiler.is_exporting(), (
+        f"Changes not detected "
+        f"torch.compiler._is_exporting_flag={torch.compiler._is_exporting_flag} "
+        f"and torch.compiler.is_exporting()={torch.compiler.is_exporting()}"
+    )
     try:
         yield
     finally:
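
The new assertion guards against a torch release changing how the private exporting flag is read. A hedged usage sketch of the context manager, assuming it lives in ``onnx_diagnostic.helpers.torch_helper``:

    import torch
    from onnx_diagnostic.helpers.torch_helper import fake_torchdynamo_exporting

    with fake_torchdynamo_exporting():
        # Inside the context, code behaves as if torch.export were running.
        assert torch.compiler.is_exporting()
    # On exit, the previous value of the flag is restored.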
@@ -311,7 +311,11 @@ def get_inputs_default(
         attention_mask=torch.cat(
             [
                 torch.ones((batch_size, sequence_length), dtype=torch.int64),
-                input_ids.ne(pad_token_id).to(torch.int64),
+                (
+                    torch.ones(input_ids.shape)
+                    if pad_token_id is None
+                    else input_ids.ne(pad_token_id)
+                ).to(torch.int64),
             ],
             axis=-1,
         ),
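
This guards against configurations without a ``pad_token_id``: ``Tensor.ne(None)`` is not a valid call, so the mask falls back to all ones. A minimal standalone reproduction of the two branches (names are illustrative):

    import torch

    input_ids = torch.tensor([[5, 7, 0]])
    pad_token_id = None
    mask = (
        torch.ones(input_ids.shape)
        if pad_token_id is None
        else input_ids.ne(pad_token_id)
    ).to(torch.int64)
    print(mask)  # tensor([[1, 1, 1]]) -- every token attended when no pad id exists
    # With pad_token_id = 0, the last position would be masked: tensor([[1, 1, 0]])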
@@ -151,6 +151,7 @@ def get_inputs(
     assert (
         add_second_input > 0
     ), f"Not implemented for add_second_input={add_second_input}."
+    res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
     res["inputs2"] = get_inputs(
         model=model,
         config=config,