onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import os
3
+ import warnings
3
4
  from typing import Any, Dict, List, Optional, Tuple, Union
4
5
  import numpy as np
5
6
  import onnx
@@ -491,3 +492,549 @@ def onnx_generate_with_genai(
491
492
  if return_session:
492
493
  return input_ids, session
493
494
  return input_ids
495
+
496
+
497
# Short aliases used to render the element types found in the
# ``args_*_type_shape`` entries of an onnxruntime profile;
# unknown element types are kept verbatim.
_mapping_types = {
    "float": "F",
    "double": "D",
    "float16": "H",
    "uint8": "U8",
    "uint16": "U16",
    "uint32": "U32",
    "uint64": "U64",
    "int8": "I8",
    "int16": "I16",
    "int32": "I32",
    "int64": "I64",
}


def _process_shape(shape_df):
    """
    Converts the type/shape description of one profiling event into a
    compact string such as ``F[2x3]+I64``.

    :param shape_df: list of one-entry dictionaries ``{element_type: dims}``
        (the ``args_input_type_shape`` or ``args_output_type_shape`` field
        of an event) or a missing value (``None`` or a float NaN as filled
        in by pandas)
    :return: one ``TYPE[d1xd2...]`` item per entry (``TYPE`` alone when
        *dims* is empty) joined with ``+``, an empty string when there is
        nothing to describe
    :raises ValueError: if an entry does not hold exactly one element type
    """
    # pandas fills missing cells with NaN, which is a float
    if shape_df is None or isinstance(shape_df, float) or len(shape_df) == 0:
        return ""
    values = []
    for val in shape_df:
        if len(val) != 1:
            raise ValueError(f"Unable to process shape {val!r} from {shape_df!r}.")
        ((k, v),) = val.items()
        short = _mapping_types.get(k, k)
        if v:
            dims = "x".join(map(str, v))
            values.append(f"{short}[{dims}]")
        else:
            values.append(f"{short}")
    return "+".join(values)
528
+
529
+
530
def post_process_df_profile(
    df: "pandas.DataFrame",  # noqa: F821
    first_it_out: bool = False,
    agg: bool = False,
    agg_op_name: bool = True,
    with_shape: bool = False,
) -> "pandas.DataFrame":  # noqa: F821
    """
    Post-processes a dataframe obtained after profiling onnxruntime.

    It adds a column ``event_name`` with a more explicit event name
    (``kernel_time``, ``fence_before``, ``fence_after`` or the raw name).
    It also adds a column ``iteration`` holding the iteration number.
    That number is incremented every time an event named
    ``SequentialExecutor::Execute`` is met; events before the first such
    event get ``-1``.

    :param df: dataframe built from the profiling events,
        see :func:`js_profile_to_dataframe`
    :param first_it_out: adds a column ``it==0``, equal to 1 for the first
        iteration (and the events before it) and 0 otherwise; it becomes
        the first aggregation key when *agg* is True
    :param agg: aggregate the durations
    :param agg_op_name: aggregate on operator name (True)
        or operator index (False)
    :param with_shape: keep the shapes (converted into short strings),
        the input shapes become an aggregation key when *agg* is True
    :return: DataFrame, the summed durations per key if *agg* is True
    """
    events = ("kernel_time", "fence_after", "fence_before")

    def sep_event(s):
        for e in events:
            if s.endswith(e):
                return e
        return s

    df = df.copy()
    df["event_name"] = df["name"].apply(sep_event)
    # Vectorized and independent of the index labels: a positional loop on
    # ``df.loc[i, ...]`` fails as soon as the index is not 0..n-1.
    is_start = (df["name"] == "SequentialExecutor::Execute").astype(int)
    df["iteration"] = is_start.cumsum() - 1

    shape_cols = ["args_input_type_shape", "args_output_type_shape"]
    if with_shape:
        for c in shape_cols:
            df[c] = df[c].apply(_process_shape)
    else:
        # these columns do not exist when the profile holds no node event
        df = df.drop(shape_cols, axis=1, errors="ignore")
    if first_it_out:
        df["it==0"] = (df["iteration"] <= 0).astype(int)

    if not agg:
        return df

    agg_cols = ["cat", "args_node_index", "args_op_name", "args_provider", "event_name"]
    if with_shape:
        agg_cols.append("args_input_type_shape")
    if first_it_out:
        agg_cols.insert(0, "it==0")
    if agg_op_name:
        agg_cols.remove("args_node_index")
    # groupby drops rows with missing keys, empty strings keep them
    for c in agg_cols:
        df[c] = df[c].fillna("")
    df["dur"] = df["dur"].fillna(0)
    return df[[*agg_cols, "dur"]].groupby(agg_cols).sum()
594
+
595
+
596
def js_profile_to_dataframe(
    filename: str,
    as_df: bool = True,
    first_it_out: bool = False,
    agg: bool = False,
    agg_op_name: bool = False,
    with_shape: bool = False,
) -> Union[List, "pandas.DataFrame"]:  # noqa: F821
    """
    Loads a profiling file produced by onnxruntime (json format)
    and flattens its events.

    Every key of an event's ``args`` dictionary is moved to a top-level
    key ``args_<key>``. Events whose name ends with ``_kernel_time``,
    ``_fence_before`` or ``_fence_after`` receive a key ``op_name``
    holding the name without that suffix.

    :param filename: filename holding the profiling stored in json format,
        for example the value returned by ``InferenceSession.end_profiling()``
    :param as_df: returns a DataFrame post-processed by
        :func:`post_process_df_profile` if True,
        the list of flattened events otherwise
    :param first_it_out: if aggregated, leaves the first iteration out
    :param agg: aggregate by event
    :param agg_op_name: aggregate on operator name or operator index
    :param with_shape: keep the shape before aggregating
    :return: DataFrame or list of dictionaries
    """
    # json is utf-8 by specification, do not depend on the locale encoding
    with open(filename, "r", encoding="utf-8") as f:
        js = json.load(f)

    suffixes = ("_kernel_time", "_fence_before", "_fence_after")
    rows = []
    for row in js:
        if "args" in row and isinstance(row["args"], dict):
            for k, v in row["args"].items():
                row[f"args_{k}"] = v
            del row["args"]
        name = row["name"]
        for suf in suffixes:
            if name.endswith(suf):
                row["op_name"] = name[: -len(suf)]
                break
        rows.append(row)
    if as_df:
        import pandas

        return post_process_df_profile(
            pandas.DataFrame(rows),
            first_it_out=first_it_out,
            agg=agg,
            agg_op_name=agg_op_name,
            with_shape=with_shape,
        )
    return rows
644
+
645
+
646
def _preprocess_graph1(df):
    """
    Sums the durations of a non-aggregated profile per operator type and
    provider (plus ``it==0`` and ``args_input_type_shape`` when present).

    :param df: dataframe produced by :func:`js_profile_to_dataframe`
    :return: tuple ``(durations sorted in ascending order,
        number of occurrences in the same order,
        both merged with an extra column ratio)``
    """

    def _short_provider(provider):
        if not isinstance(provider, str):
            return provider
        return provider.replace("ExecutionProvider", "")

    data = df.copy()
    data["args_provider"] = data["args_provider"].apply(_short_provider)

    keys = ["args_op_name", "args_provider"]
    keys.extend(c for c in ("it==0", "args_input_type_shape") if c in data.columns)

    if "it==0" in data.columns:
        labels = ["t>=1", "t=0"]
        data["it==0"] = data["it==0"].apply(lambda flag: labels[flag])

    grouped = data[["dur", *keys]].groupby(keys)
    total = grouped.sum().sort_values("dur")
    counts = grouped.count().loc[total.index, :]
    counts.columns = ["count"]

    merged = total.merge(counts, left_index=True, right_index=True, how="outer")
    merged["ratio"] = merged["dur"] / merged["dur"].sum()
    return total, counts, merged
665
+
666
+
667
def _preprocess_graph2(df):
    """
    Prepares the data plotted by :func:`plot_ort_profile` for an
    aggregated profile (:func:`post_process_df_profile` with ``agg=True``,
    whose aggregation keys are stored in the index).

    Only the ``kernel_time`` events of category ``Node`` are kept, the
    durations are summed per node index, operator name, provider
    (and ``it==0`` / ``args_input_type_shape`` when present).

    :param df: aggregated profiling dataframe
    :return: dataframe indexed by those keys, sorted in descending order,
        with columns ``dur`` and ``ratio`` (share of the total duration)
    """

    def _node_index(i):
        # Missing node indices become -1: "" after aggregation (fillna),
        # but also None or NaN. NaN is the only float not equal to itself.
        if i is None or (isinstance(i, str) and i == ""):
            return -1
        if isinstance(i, float) and i != i:
            return -1
        return int(i)

    df = df.reset_index(drop=False).copy()
    df["args_node_index"] = df["args_node_index"].apply(_node_index)
    df["args_provider"] = df["args_provider"].apply(
        lambda s: s.replace("ExecutionProvider", "") if isinstance(s, str) else s
    )
    # explicit copy: columns are assigned below and assigning into a
    # filtered slice triggers pandas' SettingWithCopyWarning
    df = df[(df["cat"] == "Node") & (df["event_name"] == "kernel_time")].copy()
    agg_cols = ["dur", "args_node_index", "args_op_name", "args_provider"]
    for c in ["it==0", "args_input_type_shape"]:
        if c in df.columns:
            agg_cols.append(c)
    if "it==0" in df.columns:
        vs = ["t>=1", "t=0"]
        df["it==0"] = df["it==0"].apply(lambda v: vs[v])
    df = df[agg_cols].groupby(agg_cols[1:]).sum()
    df = df.sort_index(ascending=False)
    df["ratio"] = df["dur"] / df["dur"].sum()
    return df
687
+
688
+
689
def plot_ort_profile(
    df: "pandas.DataFrame",  # noqa: F821
    ax0: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    ax1: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    title: Optional[str] = None,
    fontsize: int = 10,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Plots the time spent in computation based on a dataframe
    produced by function :func:`js_profile_to_dataframe`.

    If ``args_provider`` is a column of *df* (profile not aggregated),
    the durations are summed per operator type and provider. In that
    case *ax1*, if given, receives the number of occurrences.
    Otherwise (``agg=True``, keys stored in the index), the
    ``kernel_time`` durations are plotted per node and *ax1* is not used.

    :param df: dataframe
    :param ax0: first axis to draw time, the current axis if None
    :param ax1: second axis to draw occurrences
    :param title: graph title
    :param fontsize: font size of the tick labels
    :return: the graph (*ax0*)

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)
        print(df.head())

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        fig.tight_layout()

    With ``agg=True``:

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import js_profile_to_dataframe, plot_ort_profile


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True, agg=True)
        print(df.head())

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        fig.tight_layout()
    """
    if ax0 is None:
        import matplotlib.pyplot as plt

        ax0 = plt.gca()

    if "args_provider" in df.columns:
        # Aggregation by operator
        gr_dur, gr_n, _ = _preprocess_graph1(df)
        gr_dur.plot.barh(ax=ax0)
        with warnings.catch_warnings():
            # presumably silences matplotlib's warning about setting tick
            # labels without a FixedLocator
            warnings.simplefilter("ignore")
            ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
            ax0.get_yaxis().set_label_text("")
            ax0.set_yticklabels(
                ax0.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
            )
        if title is not None:
            ax0.set_title(title)
        if ax1 is not None:
            gr_n.plot.barh(ax=ax1)
            ax1.set_title("n occurrences")
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=fontsize)
                ax1.get_yaxis().set_label_text("")
                ax1.set_yticklabels(
                    ax1.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
                )
        return ax0

    # Aggregated profile: one bar per node
    df = _preprocess_graph2(df)
    df[["dur"]].plot.barh(ax=ax0)
    if title is not None:
        ax0.set_title(title)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
        ax0.get_yaxis().set_label_text("")
        ax0.set_yticklabels(ax0.get_yticklabels(), fontsize=fontsize)
    return ax0
856
+
857
+
858
def plot_ort_profile_timeline(
    df: "pandas.DataFrame",  # noqa: F821
    ax: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    iteration: int = -2,
    title: Optional[str] = None,
    quantile: float = 0.5,
    fontsize: int = 12,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Creates a timeline based on a dataframe
    produced by function :func:`js_profile_to_dataframe` (not aggregated).

    :param df: dataframe
    :param ax: axis to draw on, the current axis if None
    :param iteration: iteration to plot, negative value to start from the end
    :param title: graph title
    :param quantile: operators whose kernel duration is below this quantile
        of all kernel durations are drawn in blue/green, the others in
        red/orange
    :param fontsize: font size
    :return: the graph

    .. plot::
        :include-source:

        import numpy as np
        from onnx import TensorProto
        import onnx.helper as oh
        from onnx.checker import check_model
        from onnx.numpy_helper import from_array
        import matplotlib.pyplot as plt
        from onnxruntime import InferenceSession, SessionOptions
        from onnx_diagnostic.helpers.rt_helper import (
            js_profile_to_dataframe,
            plot_ort_profile_timeline,
        )


        def get_model():
            model_def0 = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "init1"], ["X1"]),
                        oh.make_node("Abs", ["X"], ["X2"]),
                        oh.make_node("Add", ["X", "init3"], ["inter"]),
                        oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                        oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                    ],
                    "test",
                    [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
                    [oh.make_tensor_value_info("final", TensorProto.FLOAT, [None])],
                    [
                        from_array(np.array([1], dtype=np.float32), name="init1"),
                        from_array(np.array([3], dtype=np.float32), name="init3"),
                    ],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            check_model(model_def0)
            return model_def0


        sess_options = SessionOptions()
        sess_options.enable_profiling = True
        sess = InferenceSession(
            get_model().SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)
        print(df.head())

        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
        fig.tight_layout()
    """
    if ax is None:
        import matplotlib.pyplot as plt

        ax = plt.gca()

    df = df.copy()
    df["iteration"] = df["iteration"].astype(int)
    iterations = set(df["iteration"])
    n_iter = iteration if iteration >= 0 else max(iterations) + 1 + iteration
    dfi = df[df["iteration"] == n_iter]
    assert dfi.shape[0] > 0, f"Iteration {iteration} cannot be found in {iterations}."

    if "fence_before" in set(dfi["event_name"]):
        # Old format: every node emits fence_before, kernel_time, fence_after.
        started = {}
        data = []
        for irow in dfi.iterrows():
            assert isinstance(irow, tuple), f"pandas has changed its api, type is {type(irow)}"
            assert len(irow) == 2, f"pandas has changed its api, row is {irow}"
            row = irow[1]
            it = row["iteration"]
            op_type = row["args_op_name"]
            op_name = row["op_name"]
            event_name = row["event_name"]
            provider = row["args_provider"]
            ts = float(row["ts"])
            dur = float(row["dur"])
            if event_name == "fence_before":
                started[op_type, op_name, it] = dict(
                    op_name=op_name, op_type=op_type, begin=ts
                )
            elif event_name == "kernel_time":
                obs = started[op_type, op_name, it]
                obs["duration"] = dur
                obs["begin_kernel"] = ts
                obs["provider"] = provider
            elif event_name == "fence_after":
                obs = started[op_type, op_name, it]
                obs["end"] = ts
                data.append(obs)
                del started[op_type, op_name, it]
            else:
                assert event_name in {
                    "SequentialExecutor::Execute",
                    "model_run",
                }, f"Unexpected event_name={event_name!r}, row={row}"
    else:
        # New format: only kernel_time events, no fences
        data = []
        for irow in dfi.iterrows():
            row = irow[1]
            if row["event_name"] != "kernel_time":
                continue
            obs = dict(
                duration=float(row["dur"]),
                op_name=row["op_name"],
                op_type=row["args_op_name"],
                provider=row["args_provider"],
                begin=float(row["ts"]),
                end=float(row["ts"]) + float(row["dur"]),
                begin_kernel=float(row["ts"]),
            )
            data.append(obs)

    # durations
    data_dur = sorted(d["duration"] for d in data)
    if data_dur:
        # clamped: quantile=1 would index past the end of the list
        threshold = data_dur[min(int(quantile * len(data_dur)), len(data_dur) - 1)]
    else:
        threshold = 0.0
    origin = dfi["ts"].min()

    colors = ["blue", "green", "red", "orange"]

    import matplotlib.patches as mpatches

    cs = [0, 0]
    for i, obs in enumerate(data):
        dur = obs["duration"]
        cat = int(dur >= threshold)

        # color: alternates between two colors inside each category
        color = colors[cat * 2 + cs[cat] % 2]
        cs[cat] += 1

        # rectangle
        t1 = obs["begin"] - origin
        t2 = obs["end"] - origin
        shape = mpatches.Rectangle((0, t1), 1, t2 - t1, ec="none", color=color)
        ax.add_artist(shape)
        tk1 = obs["begin_kernel"] - origin
        tk2 = (obs["begin_kernel"] + obs["duration"]) - origin
        ax.plot([0, 1], [tk1, tk1], "b--")
        ax.plot([0, 1], [tk2, tk2], "b--")
        if i == 0:
            ax.plot([0, 2], [tk1, tk1], "b")
        elif i == len(data) - 1:
            ax.plot([0, 2], [tk2, tk2], "b")

        # text
        y = (tk1 + tk2) / 2
        text = obs["op_type"]
        provider = obs["provider"]
        prov = (
            provider.replace("ExecutionProvider", "") if isinstance(provider, str) else ""
        )
        name = obs["op_name"]
        # keep the first and last 5 characters, only when it shortens the name
        if len(name) > 13:
            name = name[:5] + "..." + name[-5:]
        ax.text(1, y, f"{i}:{prov}:{text}-{name}", fontsize=fontsize, va="center")

    if title is not None:
        ax.set_title(title)
    ax.invert_yaxis()
    return ax
@@ -450,6 +450,11 @@ def fake_torchdynamo_exporting():
450
450
  """
451
451
  memorize = torch.compiler._is_exporting_flag
452
452
  torch.compiler._is_exporting_flag = True
453
+ assert torch.compiler.is_exporting(), (
454
+ f"Changes not detected "
455
+ f"torch.compiler._is_exporting_flag={torch.compiler._is_exporting_flag} "
456
+ f"and torch.compiler.is_exporting()={torch.compiler.is_exporting()}"
457
+ )
453
458
  try:
454
459
  yield
455
460
  finally:
@@ -311,7 +311,11 @@ def get_inputs_default(
311
311
  attention_mask=torch.cat(
312
312
  [
313
313
  torch.ones((batch_size, sequence_length), dtype=torch.int64),
314
- input_ids.ne(pad_token_id).to(torch.int64),
314
+ (
315
+ torch.ones(input_ids.shape)
316
+ if pad_token_id is None
317
+ else input_ids.ne(pad_token_id)
318
+ ).to(torch.int64),
315
319
  ],
316
320
  axis=-1,
317
321
  ),
@@ -570,6 +570,34 @@ class ControlFlowScanDecomposition_151564(torch.nn.Module):
570
570
  _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
571
571
 
572
572
 
573
class ControlFlowWhileDec(torch.nn.Module):
    """
    Runs ``torch._higher_order_ops.while_loop`` with a decreasing counter.

    Each step maps ``(i, x, y)`` to ``(i - 1, x + y, y - x)``; the loop
    continues while the counter is strictly positive.
    """

    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
    # one entry per input: ci (no dynamic dimension), a, b
    # NOTE(review): b only declares dim 0 dynamic although ``x + y`` ties its
    # shape to a's — confirm this is intended.
    _dynamic = ({}, {0: DYN, 1: DYN}, {0: DYN})

    def forward(self, ci, a, b):
        def keep_going(counter, lhs, rhs):
            return counter > 0

        def step(counter, lhs, rhs):
            return counter - 1, lhs + rhs, rhs - lhs

        carried = [ci, a, b]
        return torch._higher_order_ops.while_loop(keep_going, step, carried)
585
+
586
+
587
class ControlFlowWhileInc(torch.nn.Module):
    """
    Runs ``torch._higher_order_ops.while_loop`` with an increasing counter.

    Each step maps ``(i, x, y)`` to ``(i + 1, x + y, y - x)``; the loop
    continues while the counter is below the first dimension of ``x``.
    """

    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
    # one entry per input: ci (no dynamic dimension), a, b
    # NOTE(review): b only declares dim 0 dynamic although ``x + y`` ties its
    # shape to a's — confirm this is intended.
    _dynamic = ({}, {0: DYN, 1: DYN}, {0: DYN})

    def forward(self, ci, a, b):
        def below_limit(counter, lhs, rhs):
            return counter < lhs.size(0)

        def step(counter, lhs, rhs):
            return counter + 1, lhs + rhs, rhs - lhs

        loop_state = [ci, a, b]
        return torch._higher_order_ops.while_loop(below_limit, step, loop_state)
599
+
600
+
573
601
  class SignatureInt1(torch.nn.Module):
574
602
  def __init__(self, n_dims: int = 3, n_targets: int = 1):
575
603
  super().__init__()
@@ -32,7 +32,7 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
32
32
  v = getattr(mod, k)
33
33
  if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
34
34
  to_patch.append(v)
35
- else:
35
+ elif v.__doc__:
36
36
  # a function
37
37
  doc = v.__doc__.lstrip()
38
38
  if doc.startswith("manual patch"):
@@ -4,14 +4,18 @@ import packaging.version as pv
4
4
  import optree
5
5
  import torch
6
6
  import transformers
7
- from transformers.cache_utils import (
8
- DynamicCache,
9
- EncoderDecoderCache,
10
- HybridCache,
11
- SlidingWindowCache,
12
- StaticCache,
13
- )
7
+ from transformers.cache_utils import DynamicCache, StaticCache
14
8
 
9
+ try:
10
+ from transformers.cache_utils import (
11
+ EncoderDecoderCache,
12
+ HybridCache,
13
+ SlidingWindowCache,
14
+ )
15
+ except ImportError:
16
+ EncoderDecoderCache = None
17
+ HybridCache = None
18
+ SlidingWindowCache = None
15
19
  from ..helpers import string_type
16
20
  from .serialization import _lower_name_with_
17
21