onnxslim 0.1.71__py3-none-any.whl → 0.1.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. onnxslim/core/__init__.py +10 -3
  2. onnxslim/core/optimization/__init__.py +2 -2
  3. onnxslim/core/optimization/dead_node_elimination.py +1 -1
  4. onnxslim/core/pattern/__init__.py +15 -2
  5. onnxslim/core/pattern/elimination/slice.py +1 -1
  6. onnxslim/core/pattern/elimination/unsqueeze.py +2 -2
  7. onnxslim/core/pattern/fusion/concat_reshape.py +1 -1
  8. onnxslim/core/pattern/fusion/convadd.py +1 -1
  9. onnxslim/core/pattern/fusion/convbn.py +1 -1
  10. onnxslim/core/pattern/fusion/gemm.py +1 -1
  11. onnxslim/core/pattern/fusion/padconv.py +1 -2
  12. onnxslim/core/pattern/registry.py +3 -1
  13. onnxslim/misc/tabulate.py +9 -7
  14. onnxslim/third_party/_sympy/printers.py +3 -2
  15. onnxslim/third_party/_sympy/symbol.py +4 -3
  16. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +2 -2
  17. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +1 -1
  18. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +5 -5
  19. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +9 -9
  20. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +10 -10
  21. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +7 -7
  22. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +5 -5
  23. onnxslim/third_party/symbolic_shape_infer.py +111 -111
  24. onnxslim/utils.py +6 -7
  25. onnxslim/version.py +1 -1
  26. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/METADATA +1 -1
  27. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/RECORD +32 -32
  28. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/WHEEL +0 -0
  29. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/entry_points.txt +0 -0
  30. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/licenses/LICENSE +0 -0
  31. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/top_level.txt +0 -0
  32. {onnxslim-0.1.71.dist-info → onnxslim-0.1.72.dist-info}/zip-safe +0 -0
onnxslim/core/__init__.py CHANGED
@@ -1,6 +1,9 @@
+ from __future__ import annotations
+
  import logging
  import os
  import tempfile
+ from typing import Optional

  import numpy as np
  import onnx
@@ -18,6 +21,7 @@ logger = logging.getLogger("onnxslim")

  DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
  AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
+ FORCE_ONNXRUNTIME_SHAPE_INFERENCE = bool(os.getenv("ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"))


  def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
@@ -122,6 +126,9 @@ def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto:
  def shape_infer(model: onnx.ModelProto):
  """Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
  logger.debug("Start shape inference.")
+ if FORCE_ONNXRUNTIME_SHAPE_INFERENCE:
+ logger.debug("force onnxruntime shape infer.")
+ return SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
  try:
  logger.debug("try onnxruntime shape infer.")
  model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
@@ -142,7 +149,7 @@ def shape_infer(model: onnx.ModelProto):
  return model


- def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None, size_threshold: int = None):
+ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str | None = None, size_threshold: int | None = None):
  """Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
  logger.debug("Start converting model to gs.")
  graph = gs.import_onnx(model).toposort()
@@ -171,11 +178,11 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:

  for node in graph.nodes:
  if node.op == "Cast":
- inp_dtype = [input.dtype for input in node.inputs][0]
+ inp_dtype = next(input.dtype for input in node.inputs)
  if inp_dtype in [np.float16, np.float32]:
  node.erase()
  else:
- outp_dtype = [output.dtype for output in node.outputs][0]
+ outp_dtype = next(output.dtype for output in node.outputs)
  if outp_dtype == np.float16:
  node.attrs["to"] = dtype_to_onnx(np.float32)
  node.outputs[0].dtype = np.float32
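The new ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE switch makes shape_infer return the SymbolicShapeInference result directly instead of trying it and falling back. A minimal usage sketch follows; note that the flag is read into FORCE_ONNXRUNTIME_SHAPE_INFERENCE at module import time, so the variable has to be set before onnxslim.core is imported. The model file name below is illustrative.

    import os

    # The flag is evaluated once at import time (bool(os.getenv(...))), so set it first.
    os.environ["ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"] = "1"

    import onnx
    from onnxslim.core import shape_infer  # shape_infer is defined in onnxslim/core/__init__.py

    model = onnx.load("model.onnx")  # illustrative path
    model = shape_infer(model)       # returns SymbolicShapeInference.infer_shapes(...) directly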
onnxslim/core/optimization/__init__.py CHANGED
@@ -51,7 +51,7 @@ class OptimizationSettings:
  return any([getattr(cls, key) for key in cls.keys()])


- def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str = None) -> onnx.ModelProto:
+ def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str | None = None) -> onnx.ModelProto:
  """Optimize and transform the given ONNX model using various fusion patterns and graph rewriting techniques."""
  graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model)
  if OptimizationSettings.graph_fusion:
@@ -85,7 +85,7 @@ def replace_custom_layer(
  inputs,
  outputs: list[str],
  name: str,
- attrs: dict = None,
+ attrs: dict | None = None,
  domain: str = "ai.onnx.contrib",
  ):
  """Replace a custom layer in the computational graph with specified parameters and domain."""
onnxslim/core/optimization/dead_node_elimination.py CHANGED
@@ -29,7 +29,7 @@ def dead_node_elimination(graph, is_subgraph=False):
  node.erase()
  logger.debug(f"removing {node.op} op: {node.name}")
  elif node.op == "Cast":
- inp_dtype = [dtype_to_onnx(input.dtype) for input in node.inputs][0]
+ inp_dtype = next(dtype_to_onnx(input.dtype) for input in node.inputs)
  if inp_dtype == node.attrs["to"]:
  node.erase()
  logger.debug(f"removing {node.op} op: {node.name}")
onnxslim/core/pattern/__init__.py CHANGED
@@ -114,10 +114,17 @@ class PatternMatcher:
  """Retrieve the match point node from the pattern dictionary based on output node input names."""
  return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]

+ def reset_input(self):
+ for k, v in self.pattern_dict.items():
+ if v.op == "input":
+ if hasattr(self, v.name):
+ delattr(self, v.name)
+
  def match(self, node):
  """Match a given node to a pattern by comparing input names with the match point node from the pattern
  dictionary.
  """
+ self.reset_input()
  match_point = self.get_match_point()

  def match_(node, pattern_node):
@@ -125,8 +132,14 @@ class PatternMatcher:
  dictionary.
  """
  if pattern_node.op == "input":
- return True
-
+ if hasattr(self, pattern_node.name):
+ if getattr(self, pattern_node.name) == node:
+ return True
+ else:
+ return False
+ else:
+ setattr(self, pattern_node.name, node)
+ return True
  # node is an input variable
  if not hasattr(node, "op"):
  return False
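This PatternMatcher change is the main behavioural fix in the file: a pattern "input" is no longer matched unconditionally. The first time an input name is seen during a match attempt it is bound (via setattr) to the concrete graph node, and every later occurrence of the same name must resolve to that same node; reset_input clears the bindings before the next attempt. A standalone sketch of that bind-or-check rule, using hypothetical names rather than the onnxslim classes:

    class Bindings:
        """Illustrative stand-in for the attribute-based binding in PatternMatcher."""

        def match_input(self, input_name, node):
            if hasattr(self, input_name):          # already bound in this attempt
                return getattr(self, input_name) == node
            setattr(self, input_name, node)        # first occurrence: bind it
            return True

        def reset(self):                           # analogue of reset_input()
            for name in list(vars(self)):
                delattr(self, name)

    b = Bindings()
    assert b.match_input("x", "node_A")        # binds x -> node_A
    assert not b.match_input("x", "node_B")    # a repeated input must be the same node
    b.reset()
    assert b.match_input("x", "node_B")        # cleared, so x may bind again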
onnxslim/core/pattern/elimination/slice.py CHANGED
@@ -58,7 +58,7 @@ class SlicePatternMatcher(PatternMatcher):
  inputs = []
  inputs.extend(
  (
- list(first_slice_node.inputs)[0],
+ next(iter(first_slice_node.inputs)),
  gs.Constant(
  second_slice_node_inputs[1].name + "_starts",
  values=np.array(new_starts, dtype=np.int64),
onnxslim/core/pattern/elimination/unsqueeze.py CHANGED
@@ -50,14 +50,14 @@ class UnsqueezePatternMatcher(PatternMatcher):
  axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
  ]

- inputs = [list(node_unsqueeze_0.inputs)[0]]
+ inputs = [next(iter(node_unsqueeze_0.inputs))]
  outputs = list(node_unsqueeze_1.outputs)

  index = node_unsqueeze_1.inputs.index(node_unsqueeze_0.outputs[0])
  node_unsqueeze_1.inputs.pop(index)
  for i, item in enumerate(node_unsqueeze_0.inputs):
  node_unsqueeze_1.inputs.insert(index + i, item)
- inputs = [list(node_unsqueeze_1.inputs)[0]]
+ inputs = [next(iter(node_unsqueeze_1.inputs))]
  outputs = list(node_unsqueeze_1.outputs)
  node_unsqueeze_1.inputs.clear()
  node_unsqueeze_1.outputs.clear()
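For context, the axes arithmetic visible at the top of this hunk is what lets two back-to-back Unsqueeze nodes collapse into one: each axis of the first Unsqueeze is shifted by the number of axes of the second Unsqueeze that are less than or equal to it. A quick sanity check of that rule with made-up axes, using NumPy's expand_dims as a stand-in for ONNX Unsqueeze:

    import numpy as np

    x = np.zeros((3, 4))
    axes_0, axes_1 = [1], [0]  # illustrative first / second Unsqueeze axes

    # Applying the two Unsqueezes one after the other...
    two_steps = np.expand_dims(np.expand_dims(x, axes_0), axes_1)

    # ...matches a single Unsqueeze whose axes are axes_1 plus the shifted axes_0.
    shifted_0 = [a + sum(1 for b in axes_1 if b <= a) for a in axes_0]
    one_step = np.expand_dims(x, axes_1 + shifted_0)

    assert two_steps.shape == one_step.shape == (1, 3, 1, 4)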
onnxslim/core/pattern/fusion/concat_reshape.py CHANGED
@@ -36,7 +36,7 @@ class ConcatReshapeMatcher(PatternMatcher):
  def rewrite(self, opset=11):
  match_case = {}
  concat_node = self.concat_0
- index = [idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable)][0]
+ index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
  constant = gs.Constant(
  concat_node.inputs[index].name + "_fixed",
  values=np.array([-1], dtype=np.int64),
onnxslim/core/pattern/fusion/convadd.py CHANGED
@@ -41,7 +41,7 @@ class ConvAddMatcher(PatternMatcher):
  conv_bias = conv_node.inputs[2].values + node.inputs[1].values.squeeze()

  inputs = []
- inputs.append(list(conv_node.inputs)[0])
+ inputs.append(next(iter(conv_node.inputs)))
  inputs.append(conv_weight)
  weight_name = list(conv_node.inputs)[1].name
  if weight_name.endswith("weight"):
onnxslim/core/pattern/fusion/convbn.py CHANGED
@@ -53,7 +53,7 @@ class ConvBatchNormMatcher(PatternMatcher):
  conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias

  inputs = []
- inputs.append(list(conv_transpose_node.inputs)[0])
+ inputs.append(next(iter(conv_transpose_node.inputs)))
  weight_name = list(conv_transpose_node.inputs)[1].name
  if weight_name.endswith("weight"):
  bias_name = f"{weight_name[:-6]}bias"
onnxslim/core/pattern/fusion/gemm.py CHANGED
@@ -105,7 +105,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
  }
  )

- values = list(input_variable.shape[:-1]) + [matmul_bias_variable.values.shape[-1]]
+ values = [*list(input_variable.shape[:-1]), matmul_bias_variable.values.shape[-1]]
  post_reshape_const = gs.Constant(
  f"{matmul_node.name}_post_reshape_in",
  values=np.array(values, dtype=np.int64),
onnxslim/core/pattern/fusion/padconv.py CHANGED
@@ -37,8 +37,7 @@ class PadConvMatcher(PatternMatcher):

  pad_inputs = len(pad_node.inputs)
  if pad_inputs < 3 or (
- pad_inputs >= 3
- and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0)
+ (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
  or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
  ):
  if (
onnxslim/core/pattern/registry.py CHANGED
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from collections import OrderedDict

  DEFAULT_FUSION_PATTERNS = OrderedDict()
@@ -12,7 +14,7 @@ def register_fusion_pattern(fusion_pattern):
  DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern


- def get_fusion_patterns(skip_fusion_patterns: str = None):
+ def get_fusion_patterns(skip_fusion_patterns: str | None = None):
  """Returns a copy of the default fusion patterns, optionally excluding specific patterns."""
  default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
  if skip_fusion_patterns:
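Most files touched in this release pair the implicit-Optional cleanups with `from __future__ import annotations`. That import makes annotations lazy (stored as strings), so the PEP 604 `X | None` spelling in signatures parses even on Python versions older than 3.10. A minimal illustration with a placeholder body, not the library code:

    from __future__ import annotations  # annotations become strings, evaluated lazily


    def get_fusion_patterns(skip_fusion_patterns: str | None = None):
        """Signature-only illustration: `str | None` is accepted here even on Python 3.8/3.9."""
        return {}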
onnxslim/misc/tabulate.py CHANGED
@@ -24,7 +24,7 @@ def _is_file(f):
  return isinstance(f, io.IOBase)


- __all__ = ["tabulate", "tabulate_formats", "simple_separated_format"]
+ __all__ = ["simple_separated_format", "tabulate", "tabulate_formats"]
  try:
  from .version import version as __version__ # noqa: F401
  except ImportError:
@@ -1202,7 +1202,7 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
  if val is None:
  return missingval

- if valtype is str or valtype is not int and valtype is not bytes and valtype is not float:
+ if valtype is str or (valtype is not int and valtype is not bytes and valtype is not float):
  return f"{val}"
  elif valtype is int:
  return format(val, intfmt)
@@ -1276,7 +1276,7 @@ def _prepend_row_index(rows, index):
  index_iter = iter(index)
  for row in sans_rows:
  index_v = next(index_iter)
- new_rows.append([index_v] + list(row))
+ new_rows.append([index_v, *list(row)])
  rows = new_rows
  _reinsert_separating_lines(rows, separating_lines)
  return rows
@@ -1325,7 +1325,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
  """
  try:
  bool(headers)
- is_headers2bool_broken = False # noqa
+ is_headers2bool_broken = False
  except ValueError: # numpy.ndarray, pandas.core.index.Index, ...
  is_headers2bool_broken = True # noqa
  headers = list(headers)
@@ -1423,7 +1423,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
  if headers == "firstrow":
  if len(rows) > 0:
  if index is not None:
- headers = [index[0]] + list(rows[0])
+ headers = [index[0], *list(rows[0])]
  index = index[1:]
  else:
  headers = rows[0]
@@ -2534,8 +2534,10 @@ class _CustomTextWrap(textwrap.TextWrapper):
  if (
  self.max_lines is None
  or len(lines) + 1 < self.max_lines
- or (not chunks or self.drop_whitespace and len(chunks) == 1 and not chunks[0].strip())
- and cur_len <= width
+ or (
+ (not chunks or (self.drop_whitespace and len(chunks) == 1 and not chunks[0].strip()))
+ and cur_len <= width
+ )
  ):
  # Convert current line back to a string and store it in
  # list of all lines (return value).
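The _format and _CustomTextWrap edits above only add parentheses around groupings that Python's precedence rules already imply (`and` binds tighter than `or`), so behaviour is unchanged while the intent becomes explicit. A quick check of the equivalence:

    a, b, c = True, False, False
    # `a or b and c` already parses as `a or (b and c)`, not `(a or b) and c`.
    assert (a or b and c) == (a or (b and c))
    assert (a or b and c) != ((a or b) and c)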
onnxslim/third_party/_sympy/printers.py CHANGED
@@ -1,5 +1,6 @@
+ from __future__ import annotations
+
  import sys
- from typing import Optional

  import sympy
  from sympy.printing.precedence import PRECEDENCE, precedence
@@ -22,7 +23,7 @@ class ExprPrinter(StrPrinter):
  def _print_Not(self, expr: sympy.Expr) -> str:
  return f"not ({self._print(expr.args[0])})"

- def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
+ def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
  return self.stringify(expr.args, " + ", precedence(expr))

  def _print_Relational(self, expr: sympy.Expr) -> str:
onnxslim/third_party/_sympy/symbol.py CHANGED
@@ -12,9 +12,10 @@ You can occasionally test if prefixes have been hardcoded by renaming prefixes
  in this file and seeing what breaks.
  """

+ from __future__ import annotations
+
  from collections.abc import Iterable
  from enum import Enum, auto
- from typing import Union

  import sympy

@@ -88,7 +89,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:

  # This type is a little wider than it should be, because free_symbols says
  # that it contains Basic, rather than Symbol
- def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
+ def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
  assert isinstance(sym, sympy.Symbol)
  name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
  if isinstance(prefix, SymT):
@@ -97,5 +98,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> boo
  return name_str.startswith(tuple(prefix_str[p] for p in prefix))


- def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
+ def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
  return any(symbol_is_type(v, prefix) for v in e.free_symbols)
onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py CHANGED
@@ -278,8 +278,8 @@ class OnnxExporter(BaseExporter):
  def export_graph(
  graph_proto: onnx.GraphProto,
  graph: Graph,
- tensor_map: OrderedDict[str, Tensor] = None,
- subgraph_tensor_map: OrderedDict[str, Tensor] = None,
+ tensor_map: OrderedDict[str, Tensor] | None = None,
+ subgraph_tensor_map: OrderedDict[str, Tensor] | None = None,
  do_type_check=True,
  ) -> None:
  """
onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py CHANGED
@@ -82,7 +82,7 @@ class PatternMapping(dict):
  def __str__(self) -> str:
  """Returns a string representation of the pattern mapping, including inputs, outputs, and constants."""
  if self.onnx_node is None:
- return "{" + str.join(", ", [f"{key}: {str(value)}" for key, value in self.items()]) + "}"
+ return "{" + str.join(", ", [f"{key}: {value!s}" for key, value in self.items()]) + "}"
  return self.onnx_node.name


onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py CHANGED
@@ -357,7 +357,7 @@ class OnnxImporter(BaseImporter):
  @staticmethod
  def import_function(
  onnx_function: onnx.FunctionProto,
- model_opset: int = None,
+ model_opset: int | None = None,
  model_import_domains: onnx.OperatorSetIdProto = None,
  ) -> Function:
  """Imports an ONNX function to a Function object using the model opset and import domains."""
@@ -405,13 +405,13 @@ class OnnxImporter(BaseImporter):
  @staticmethod
  def import_graph(
  onnx_graph: onnx.GraphProto,
- tensor_map: OrderedDict[str, Tensor] = None,
+ tensor_map: OrderedDict[str, Tensor] | None = None,
  opset=None,
  import_domains: onnx.OperatorSetIdProto = None,
  ir_version=None,
- producer_name: str = None,
- producer_version: str = None,
- functions: list[Function] = None,
+ producer_name: str | None = None,
+ producer_version: str | None = None,
+ functions: list[Function] | None = None,
  metadata_props=None,
  ) -> Graph:
  """
onnxslim/third_party/onnx_graphsurgeon/ir/function.py CHANGED
@@ -43,15 +43,15 @@ class Function(Graph):
  def __init__(
  self,
  name: str,
- domain: str = None,
- nodes: Sequence[Node] = None,
- inputs: Sequence[Tensor] = None,
- outputs: Sequence[Tensor] = None,
- doc_string: str = None,
- opset: int = None,
- import_domains: Sequence[onnx.OperatorSetIdProto] = None,
- functions: Sequence[Function] = None,
- attrs: dict = None,
+ domain: str | None = None,
+ nodes: Sequence[Node] | None = None,
+ inputs: Sequence[Tensor] | None = None,
+ outputs: Sequence[Tensor] | None = None,
+ doc_string: str | None = None,
+ opset: int | None = None,
+ import_domains: Sequence[onnx.OperatorSetIdProto] | None = None,
+ functions: Sequence[Function] | None = None,
+ attrs: dict | None = None,
  ):
  """
  Args:
onnxslim/third_party/onnx_graphsurgeon/ir/graph.py CHANGED
@@ -100,17 +100,17 @@ class Graph:

  def __init__(
  self,
- nodes: Sequence[Node] = None,
- inputs: Sequence[Tensor] = None,
- outputs: Sequence[Tensor] = None,
+ nodes: Sequence[Node] | None = None,
+ inputs: Sequence[Tensor] | None = None,
+ outputs: Sequence[Tensor] | None = None,
  name=None,
  doc_string=None,
  opset=None,
  import_domains=None,
  ir_version=None,
- producer_name: str = None,
- producer_version: str = None,
- functions: Sequence[Function] = None,
+ producer_name: str | None = None,
+ producer_version: str | None = None,
+ functions: Sequence[Function] | None = None,
  metadata_props=None,
  ):
  """
@@ -347,7 +347,7 @@ class Graph:
  func_ids.add(func.unique_id)
  return self.functions

- for graph in self.functions + [self]:
+ for graph in [*self.functions, self]:
  for subgraph in graph.subgraphs(recursive=True):
  new_list = absorb_function_list(subgraph.functions)
  subgraph._functions = new_list
@@ -774,7 +774,7 @@ class Graph:
  if len(node.attrs) != 1:
  G_LOGGER.warning("Constant node must contain exactly one attribute")
  continue
- attr_name, attr_val = list(node.attrs.items())[0]
+ attr_name, attr_val = next(iter(node.attrs.items()))
  allowed_attrs = {
  "value",
  "value_float",
@@ -975,7 +975,7 @@ class Graph:

  def get_scalar_value(tensor):
  """Gets the scalar value of a constant tensor with a single item."""
- return list(tensor.values)[0] if tensor.shape else tensor.values
+ return next(iter(tensor.values)) if tensor.shape else tensor.values

  def fold_shape(tensor):
  """Returns the input tensor shape if available, otherwise returns None."""
@@ -1409,7 +1409,7 @@ class Graph:
  self.nodes.append(node)
  return node.outputs

- def copy(self, tensor_map: OrderedDict[str, Tensor] = None):
+ def copy(self, tensor_map: OrderedDict[str, Tensor] | None = None):
  """
  Copy the graph.

onnxslim/third_party/onnx_graphsurgeon/ir/node.py CHANGED
@@ -42,11 +42,11 @@ class Node:
  def __init__(
  self,
  op: str,
- name: str = None,
- attrs: dict[str, object] = None,
- inputs: list[Tensor] = None,
- outputs: list[Tensor] = None,
- domain: str = None,
+ name: str | None = None,
+ attrs: dict[str, object] | None = None,
+ inputs: list[Tensor] | None = None,
+ outputs: list[Tensor] | None = None,
+ domain: str | None = None,
  ):
  """
  A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors.
@@ -152,8 +152,8 @@ class Node:

  def copy(
  self,
- inputs: list[Tensor] = None,
- outputs: list[Tensor] = None,
+ inputs: list[Tensor] | None = None,
+ outputs: list[Tensor] | None = None,
  tensor_map=None,
  ):
  """
onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py CHANGED
@@ -65,7 +65,7 @@ class Tensor:
  def to_constant(
  self,
  values: np.ndarray,
- data_location: int = None,
+ data_location: int | None = None,
  export_dtype: np.dtype | onnx.TensorProto.DataType = None,
  ):
  """
@@ -91,7 +91,7 @@ class Tensor:

  return self

- def to_variable(self, dtype: np.dtype | onnx.TensorProto.DataType = None, shape: Sequence[int | str] = None):
+ def to_variable(self, dtype: np.dtype | onnx.TensorProto.DataType = None, shape: Sequence[int | str] | None = None):
  """
  Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor
  will see the update.
@@ -199,7 +199,7 @@ class Variable(Tensor):
  self,
  name: str,
  dtype: np.dtype | onnx.TensorProto.DataType = None,
- shape: Sequence[int | str] = None,
+ shape: Sequence[int | str] | None = None,
  type: str = "tensor_type",
  ):
  """
@@ -390,7 +390,7 @@ class Constant(Tensor):
  self,
  name: str,
  values: np.ndarray | LazyValues,
- data_location: int = None,
+ data_location: int | None = None,
  export_dtype: np.dtype | onnx.TensorProto.DataType = None,
  ):
  """
@@ -426,7 +426,7 @@ class Constant(Tensor):
  self.data_location = data_location
  self._export_dtype = export_dtype

- def to_variable(self, dtype: np.dtype = None, shape: Sequence[int | str] = None):
+ def to_variable(self, dtype: np.dtype = None, shape: Sequence[int | str] | None = None):
  """Convert instance values to an appropriate variable with specified dtype and shape."""
  if shape is None:
  shape = []