onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/export/api.py +35 -5
- onnx_diagnostic/export/control_flow.py +511 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/ext_test_case.py +33 -9
- onnx_diagnostic/helpers/cache_helper.py +217 -203
- onnx_diagnostic/helpers/helper.py +2 -0
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +54 -2
- onnx_diagnostic/helpers/onnx_helper.py +12 -15
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_helper.py +5 -0
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +561 -59
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +24 -22
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
4
5
|
import numpy as np
|
|
5
6
|
import onnx
|
|
@@ -491,3 +492,549 @@ def onnx_generate_with_genai(
|
|
|
491
492
|
if return_session:
|
|
492
493
|
return input_ids, session
|
|
493
494
|
return input_ids
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
_mapping_types = {
|
|
498
|
+
"float": "F",
|
|
499
|
+
"double": "D",
|
|
500
|
+
"float16": "H",
|
|
501
|
+
"uint8": "U8",
|
|
502
|
+
"uint16": "U16",
|
|
503
|
+
"uint32": "U32",
|
|
504
|
+
"uint64": "U64",
|
|
505
|
+
"int8": "I8",
|
|
506
|
+
"int16": "I16",
|
|
507
|
+
"int32": "I32",
|
|
508
|
+
"int64": "I64",
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _process_shape(shape_df):
|
|
513
|
+
if isinstance(shape_df, float) or len(shape_df) == 0:
|
|
514
|
+
return ""
|
|
515
|
+
values = []
|
|
516
|
+
for val in shape_df:
|
|
517
|
+
if len(val) != 1:
|
|
518
|
+
raise ValueError(f"Unable to process shape {val!r} from {values!r}.")
|
|
519
|
+
for _k, _v in val.items():
|
|
520
|
+
k, v = _k, _v
|
|
521
|
+
break
|
|
522
|
+
if v:
|
|
523
|
+
vs = "x".join(map(str, v))
|
|
524
|
+
values.append(f"{_mapping_types.get(k,k)}[{vs}]")
|
|
525
|
+
else:
|
|
526
|
+
values.append(f"{_mapping_types.get(k,k)}")
|
|
527
|
+
return "+".join(values)
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def post_process_df_profile(
    df: "pandas.DataFrame",  # noqa: F821
    first_it_out: bool = False,
    agg: bool = False,
    agg_op_name: bool = True,
    with_shape: bool = False,
) -> "pandas.DataFrame":  # noqa: F821
    """
    Post-processes a dataframe obtained after profiling onnxruntime.
    It adds a column ``event_name`` with a more explicit event name
    (``kernel_time``, ``fence_before``, ``fence_after`` or the full name)
    and a column ``iteration`` with the iteration number: every event
    ``SequentialExecutor::Execute`` starts a new iteration, events logged
    before the first one get ``-1``.

    :param df: dataframe, usually built by :func:`js_profile_to_dataframe`
    :param first_it_out: adds a column ``it==0`` (1 for iterations <= 0, 0 otherwise);
        if aggregated, this column is a grouping key so the first iteration
        stays out of the aggregation of the other ones
    :param agg: aggregate the result (sum of ``dur``)
    :param agg_op_name: if True, the node index is not part of the aggregation
        keys (operators are grouped by name), otherwise it is
    :param with_shape: keep the shapes (converted into strings) and use the
        input shapes as an aggregation key, drop the shape columns otherwise
    :return: DataFrame
    """
    events = {"kernel_time", "fence_after", "fence_before"}

    def sep_event(s):
        for e in events:
            if s.endswith(e):
                return e
        return s

    df = df.copy()
    df["event_name"] = df["name"].apply(sep_event)
    # A cumulative count of iteration starts works whatever the index of df is
    # (a positional loop with .loc assumed a default RangeIndex).
    is_start = df["name"] == "SequentialExecutor::Execute"
    df["iteration"] = is_start.cumsum().astype("int64") - 1

    shape_cols = ["args_input_type_shape", "args_output_type_shape"]
    if with_shape:
        for c in shape_cols:
            df[c] = df[c].apply(_process_shape)
    else:
        # errors="ignore": profiles without shape information are accepted
        df = df.drop(shape_cols, axis=1, errors="ignore")

    if first_it_out:
        df["it==0"] = (df["iteration"] <= 0).astype(int)

    if not agg:
        return df

    agg_cols = ["cat", "args_node_index", "args_op_name", "args_provider", "event_name"]
    if with_shape:
        agg_cols.append("args_input_type_shape")
    if first_it_out:
        agg_cols.insert(0, "it==0")
    if agg_op_name:
        agg_cols.remove("args_node_index")
    # groupby drops NaN keys, missing values are replaced to keep every event
    for c in agg_cols:
        df[c] = df[c].fillna("")
    df["dur"] = df["dur"].fillna(0)
    return df[[*agg_cols, "dur"]].groupby(agg_cols).sum()
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
def js_profile_to_dataframe(
    filename: str,
    as_df: bool = True,
    first_it_out: bool = False,
    agg: bool = False,
    agg_op_name: bool = False,
    with_shape: bool = False,
) -> Union[List, "pandas.DataFrame"]:  # noqa: F821
    """
    Loads the profiling produced by onnxruntime (json format) and flattens it.
    The dictionary ``args`` of every event is flattened into ``args_<key>``
    entries and, when the event name ends with ``_kernel_time``,
    ``_fence_before`` or ``_fence_after``, the name without that suffix
    is stored in ``op_name``.

    :param filename: filename holding the profiling stored in json format
    :param as_df: if True, returns a DataFrame post-processed by
        :func:`post_process_df_profile`, otherwise the list of flattened rows
    :param first_it_out: if aggregated, leaves the first iteration out
    :param agg: aggregate by event
    :param agg_op_name: aggregate on operator name or operator index
    :param with_shape: keep the shape before aggregating
    :return: DataFrame or list of dictionaries
    """
    # explicit encoding: the platform default encoding may not be able to
    # decode node names
    with open(filename, "r", encoding="utf-8") as f:
        js = json.load(f)

    suffixes = ("_kernel_time", "_fence_before", "_fence_after")
    rows = []
    for row in js:
        if isinstance(row.get("args"), dict):
            for k, v in row.pop("args").items():
                row[f"args_{k}"] = v
        name = row["name"]
        for suf in suffixes:
            if name.endswith(suf):
                row["op_name"] = name[: -len(suf)]
                break
        rows.append(row)

    if as_df:
        import pandas

        return post_process_df_profile(
            pandas.DataFrame(rows),
            first_it_out=first_it_out,
            agg=agg,
            agg_op_name=agg_op_name,
            with_shape=with_shape,
        )
    return rows
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _preprocess_graph1(df):
|
|
647
|
+
df = df.copy()
|
|
648
|
+
df["args_provider"] = df["args_provider"].apply(
|
|
649
|
+
lambda s: s.replace("ExecutionProvider", "") if isinstance(s, str) else s
|
|
650
|
+
)
|
|
651
|
+
agg_cols = ["dur", "args_op_name", "args_provider"]
|
|
652
|
+
for c in ["it==0", "args_input_type_shape"]:
|
|
653
|
+
if c in df.columns:
|
|
654
|
+
agg_cols.append(c)
|
|
655
|
+
if "it==0" in df.columns:
|
|
656
|
+
vs = ["t>=1", "t=0"]
|
|
657
|
+
df["it==0"] = df["it==0"].apply(lambda v: vs[v])
|
|
658
|
+
gr_dur = df[agg_cols].groupby(agg_cols[1:]).sum().sort_values("dur")
|
|
659
|
+
gr_n = df[agg_cols].groupby(agg_cols[1:]).count()
|
|
660
|
+
gr_n = gr_n.loc[gr_dur.index, :]
|
|
661
|
+
gr_n.columns = ["count"]
|
|
662
|
+
gr = gr_dur.merge(gr_n, left_index=True, right_index=True, how="outer")
|
|
663
|
+
gr["ratio"] = gr["dur"] / gr["dur"].sum()
|
|
664
|
+
return gr_dur, gr_n, gr
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def _preprocess_graph2(df):
|
|
668
|
+
df = df.reset_index(drop=False).copy()
|
|
669
|
+
df["args_node_index"] = df["args_node_index"].apply(
|
|
670
|
+
lambda i: int(i) if i not in {None, ""} else -1
|
|
671
|
+
)
|
|
672
|
+
df["args_provider"] = df["args_provider"].apply(
|
|
673
|
+
lambda s: s.replace("ExecutionProvider", "") if isinstance(s, str) else s
|
|
674
|
+
)
|
|
675
|
+
df = df[(df["cat"] == "Node") & (df["event_name"] == "kernel_time")]
|
|
676
|
+
agg_cols = ["dur", "args_node_index", "args_op_name", "args_provider"]
|
|
677
|
+
for c in ["it==0", "args_input_type_shape"]:
|
|
678
|
+
if c in df.columns:
|
|
679
|
+
agg_cols.append(c)
|
|
680
|
+
if "it==0" in df.columns:
|
|
681
|
+
vs = ["t>=1", "t=0"]
|
|
682
|
+
df["it==0"] = df["it==0"].apply(lambda v: vs[v])
|
|
683
|
+
df = df[agg_cols].groupby(agg_cols[1:]).sum()
|
|
684
|
+
df = df.sort_index(ascending=False)
|
|
685
|
+
df["ratio"] = df["dur"] / df["dur"].sum()
|
|
686
|
+
return df
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def plot_ort_profile(
    df: "pandas.DataFrame",  # noqa: F821
    ax0: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    ax1: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    title: Optional[str] = None,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Plots time spent in computation based on a dataframe
    produced by function :func:`js_profile_to_dataframe`.

    If ``df`` has a column ``args_provider`` (profile not aggregated),
    durations are summed per operator type and provider and drawn on ``ax0``,
    the number of occurrences is drawn on ``ax1`` if specified.
    Otherwise (aggregated profile, the keys are in the index), only the kernel
    time of nodes is drawn on ``ax0`` and ``ax1`` is not used.

    :param df: dataframe
    :param ax0: first axis to draw time, ``plt.gca()`` if not specified
    :param ax1: second axis to draw occurrences
    :param title: graph title
    :return: the graph

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)
        print(df.head())

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        fig.tight_layout()

    With ``agg=True``:

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True, agg=True)
        print(df.head())

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        fig.tight_layout()
    """
    fontsize = 10
    if ax0 is None:
        import matplotlib.pyplot as plt

        ax0 = plt.gca()

    def _shrink_labels(ax, **ytick_kwargs):
        # Re-applies the current tick labels only to change their font,
        # matplotlib may warn about it, hence the silenced warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ax.set_xticklabels(ax.get_xticklabels(), fontsize=fontsize)
            ax.get_yaxis().set_label_text("")
            ax.set_yticklabels(ax.get_yticklabels(), fontsize=fontsize, **ytick_kwargs)

    if "args_provider" in df.columns:
        # Aggregation by operator
        gr_dur, gr_n, _ = _preprocess_graph1(df)
        gr_dur.plot.barh(ax=ax0)
        _shrink_labels(ax0, rotation=45, ha="right")
        if title is not None:
            ax0.set_title(title)
        if ax1 is not None:
            gr_n.plot.barh(ax=ax1)
            ax1.set_title("n occurrences")
            _shrink_labels(ax1, rotation=45, ha="right")
        return ax0

    df = _preprocess_graph2(df)
    df[["dur"]].plot.barh(ax=ax0)
    if title is not None:
        ax0.set_title(title)
    _shrink_labels(ax0)
    return ax0
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def plot_ort_profile_timeline(
    df: "pandas.DataFrame",  # noqa: F821
    ax: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    iteration: int = -2,
    title: Optional[str] = None,
    quantile: float = 0.5,
    fontsize: int = 12,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Creates a timeline based on a dataframe
    produced by function :func:`js_profile_to_dataframe`.

    Every kernel executed during the selected iteration is drawn as a rectangle
    covering its execution interval (time goes downward) and annotated with
    ``index:provider:op_type-node_name``.

    :param df: dataframe (not aggregated)
    :param ax: axis to draw on, ``plt.gca()`` if not specified
    :param iteration: iteration to plot, negative value to start from the end
    :param title: graph title
    :param quantile: operators whose duration is below this quantile of all
        durations are drawn with a different pair of colors
    :param fontsize: font size
    :return: the graph

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import (
            js_profile_to_dataframe,
            plot_ort_profile_timeline,
        )


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)
        print(df.head())

        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
        fig.tight_layout()
    """
    if ax is None:
        import matplotlib.pyplot as plt

        ax = plt.gca()

    df = df.copy()
    df["iteration"] = df["iteration"].astype(int)
    iterations = set(df["iteration"])
    n_iter = iteration if iteration >= 0 else max(iterations) + 1 + iteration
    dfi = df[df["iteration"] == n_iter]
    assert dfi.shape[0] > 0, f"Iteration {iteration} cannot be found in {iterations}."

    if title is not None:
        # title was documented but never drawn before
        ax.set_title(title)

    if "fence_before" in set(dfi["event_name"]):
        # Format with fence events: fence_before opens an operator,
        # kernel_time gives its kernel, fence_after closes it.
        started = {}
        data = []
        for irow in dfi.iterrows():
            assert isinstance(irow, tuple), f"pandas has changed its api, type is {type(irow)}"
            assert len(irow) == 2, f"pandas has changed its api, row is {irow}"
            row = irow[1]
            it = row["iteration"]
            op_type = row["args_op_name"]
            op_name = row["op_name"]
            event_name = row["event_name"]
            provider = row["args_provider"]
            ts = float(row["ts"])
            dur = float(row["dur"])
            if event_name == "fence_before":
                started[op_type, op_name, it] = dict(
                    op_name=op_name, op_type=op_type, begin=ts
                )
            elif event_name == "kernel_time":
                obs = started[op_type, op_name, it]
                obs["duration"] = dur
                obs["begin_kernel"] = ts
                obs["provider"] = provider
            elif event_name == "fence_after":
                obs = started[op_type, op_name, it]
                obs["end"] = ts
                data.append(obs)
                del started[op_type, op_name, it]
            else:
                assert event_name in {
                    "SequentialExecutor::Execute",
                    "model_run",
                }, f"Unexpected event_name={event_name!r}, row={row}"
    else:
        # New format
        data = []
        for irow in dfi.iterrows():
            row = irow[1]
            if row["event_name"] != "kernel_time":
                continue
            obs = dict(
                duration=float(row["dur"]),
                op_name=row["op_name"],
                op_type=row["args_op_name"],
                provider=row["args_provider"],
                begin=float(row["ts"]),
                end=float(row["ts"]) + float(row["dur"]),
                begin_kernel=float(row["ts"]),
            )
            data.append(obs)

    if not data:
        # no kernel recorded during this iteration, nothing to draw
        return ax

    # durations, the index is clamped so that quantile=1 or a tiny profile works
    data_dur = sorted(d["duration"] for d in data)
    position = min(max(int(quantile * len(data_dur)), 0), len(data_dur) - 1)
    threshold = data_dur[position]
    origin = dfi["ts"].min()

    colors = ["blue", "green", "red", "orange"]

    import matplotlib.patches as mpatches

    # cs[cat] alternates the two colors of a category between consecutive operators
    cs = [0, 0]
    for i, obs in enumerate(data):
        dur = obs["duration"]
        cat = int(dur >= threshold)

        # color
        color = colors[cat * 2 + cs[cat] % 2]
        cs[cat] += 1

        # rectangle
        t1 = obs["begin"] - origin
        t2 = obs["end"] - origin
        shape = mpatches.Rectangle((0, t1), 1, t2 - t1, ec="none", color=color)
        ax.add_artist(shape)
        tk1 = obs["begin_kernel"] - origin
        tk2 = (obs["begin_kernel"] + obs["duration"]) - origin
        ax.plot([0, 1], [tk1, tk1], "b--")
        ax.plot([0, 1], [tk2, tk2], "b--")
        if i == 0:
            ax.plot([0, 2], [tk1, tk1], "b")
        elif i == len(data) - 1:
            ax.plot([0, 2], [tk2, tk2], "b")

        # text
        y = (tk1 + tk2) / 2
        text = obs["op_type"]
        prov = obs["provider"].replace("ExecutionProvider", "")
        name = obs["op_name"]
        if len(name) >= 10:
            # keep the beginning and the end of long node names
            name = name[:5] + "..." + name[-5:]
        ax.text(1, y, f"{i}:{prov}:{text}-{name}", fontsize=fontsize, va="center")

    ax.invert_yaxis()
    return ax
|
|
@@ -450,6 +450,11 @@ def fake_torchdynamo_exporting():
|
|
|
450
450
|
"""
|
|
451
451
|
memorize = torch.compiler._is_exporting_flag
|
|
452
452
|
torch.compiler._is_exporting_flag = True
|
|
453
|
+
assert torch.compiler.is_exporting(), (
|
|
454
|
+
f"Changes not detected "
|
|
455
|
+
f"torch.compiler._is_exporting_flag={torch.compiler._is_exporting_flag} "
|
|
456
|
+
f"and torch.compiler.is_exporting()={torch.compiler.is_exporting()}"
|
|
457
|
+
)
|
|
453
458
|
try:
|
|
454
459
|
yield
|
|
455
460
|
finally:
|
|
@@ -311,7 +311,11 @@ def get_inputs_default(
|
|
|
311
311
|
attention_mask=torch.cat(
|
|
312
312
|
[
|
|
313
313
|
torch.ones((batch_size, sequence_length), dtype=torch.int64),
|
|
314
|
-
|
|
314
|
+
(
|
|
315
|
+
torch.ones(input_ids.shape)
|
|
316
|
+
if pad_token_id is None
|
|
317
|
+
else input_ids.ne(pad_token_id)
|
|
318
|
+
).to(torch.int64),
|
|
315
319
|
],
|
|
316
320
|
axis=-1,
|
|
317
321
|
),
|
|
@@ -570,6 +570,34 @@ class ControlFlowScanDecomposition_151564(torch.nn.Module):
|
|
|
570
570
|
_dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
|
|
571
571
|
|
|
572
572
|
|
|
573
|
+
class ControlFlowWhileDec(torch.nn.Module):
    """
    Loop built with ``torch._higher_order_ops.while_loop``: the counter ``ci``
    decreases until it reaches 0 and every step maps ``(a, b)`` to
    ``(a + b, b - a)``.
    """

    def forward(self, ci, a, b):
        def keep_going(counter, x, y):
            return counter > 0

        def step(counter, x, y):
            # all three outputs are computed from the previous values
            return counter - 1, x + y, y - x

        return torch._higher_order_ops.while_loop(keep_going, step, [ci, a, b])

    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
    # dynamic dimensions for (ci, a, b), DYN is defined elsewhere in the module
    _dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class ControlFlowWhileInc(torch.nn.Module):
    """
    Loop built with ``torch._higher_order_ops.while_loop``: the counter ``ci``
    increases while it is lower than the first dimension of the running ``a``
    and every step maps ``(a, b)`` to ``(a + b, b - a)``.
    """

    def forward(self, ci, a, b):
        def keep_going(counter, x, y):
            return counter < x.size(0)

        def step(counter, x, y):
            # all three outputs are computed from the previous values
            return counter + 1, x + y, y - x

        return torch._higher_order_ops.while_loop(keep_going, step, [ci, a, b])

    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
    # dynamic dimensions for (ci, a, b), DYN is defined elsewhere in the module
    _dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
|
|
599
|
+
|
|
600
|
+
|
|
573
601
|
class SignatureInt1(torch.nn.Module):
|
|
574
602
|
def __init__(self, n_dims: int = 3, n_targets: int = 1):
|
|
575
603
|
super().__init__()
|
|
@@ -32,7 +32,7 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
|
|
|
32
32
|
v = getattr(mod, k)
|
|
33
33
|
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
|
|
34
34
|
to_patch.append(v)
|
|
35
|
-
|
|
35
|
+
elif v.__doc__:
|
|
36
36
|
# a function
|
|
37
37
|
doc = v.__doc__.lstrip()
|
|
38
38
|
if doc.startswith("manual patch"):
|
|
@@ -4,14 +4,18 @@ import packaging.version as pv
|
|
|
4
4
|
import optree
|
|
5
5
|
import torch
|
|
6
6
|
import transformers
|
|
7
|
-
from transformers.cache_utils import
|
|
8
|
-
DynamicCache,
|
|
9
|
-
EncoderDecoderCache,
|
|
10
|
-
HybridCache,
|
|
11
|
-
SlidingWindowCache,
|
|
12
|
-
StaticCache,
|
|
13
|
-
)
|
|
7
|
+
from transformers.cache_utils import DynamicCache, StaticCache
|
|
14
8
|
|
|
9
|
+
try:
|
|
10
|
+
from transformers.cache_utils import (
|
|
11
|
+
EncoderDecoderCache,
|
|
12
|
+
HybridCache,
|
|
13
|
+
SlidingWindowCache,
|
|
14
|
+
)
|
|
15
|
+
except ImportError:
|
|
16
|
+
EncoderDecoderCache = None
|
|
17
|
+
HybridCache = None
|
|
18
|
+
SlidingWindowCache = None
|
|
15
19
|
from ..helpers import string_type
|
|
16
20
|
from .serialization import _lower_name_with_
|
|
17
21
|
|