onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,343 @@
1
+ import os
2
+ import textwrap
3
+ import torch
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from ..helpers import flatten_object
6
+ from ..helpers.torch_helper import to_any
7
+ from .hghub.model_inputs import _preprocess_model_id
8
+ from .hghub import get_untrained_model_with_inputs
9
+ from .validate import filter_inputs, make_patch_kwargs
10
+
11
+
# Building blocks reused when assembling a standalone export script.
#   "imports": import lines every generated sample needs.
#   "get_model_with_inputs": source code of a helper inlined into the generated
#       sample; it returns the torch model (pretrained or untrained) for a model id.
# Fix: the two assert messages below were missing the ``f`` prefix (the braces
# were never interpolated) and the second one mentioned ``input_options`` twice.
CODE_SAMPLES = {
    "imports": "from typing import Any\nimport torch",
    "get_model_with_inputs": textwrap.dedent(
        """
        def get_model_with_inputs(
            model_id: str,
            subfolder: str | None = None,
            dtype: str | torch.dtype | None = None,
            device: str | torch.device | None = None,
            same_as_pretrained: bool = False,
            use_pretrained: bool = False,
            input_options: dict[str, Any] | None = None,
            model_options: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            if use_pretrained:
                import transformers

                assert same_as_pretrained, (
                    "same_as_pretrained must be True if use_pretrained is True"
                )
                # tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModel.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    subfolder=subfolder,
                    dtype=dtype,
                    device=device,
                )
                data = {"model": model}
                assert not input_options, f"Not implemented yet with input_options{input_options}"
                assert not model_options, f"Not implemented yet with model_options{model_options}"
            else:
                from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

                data = get_untrained_model_with_inputs(
                    model_id,
                    use_pretrained=use_pretrained,
                    same_as_pretrained=same_as_pretrained,
                    inputs_kwargs=input_options,
                    model_kwargs=model_options,
                    subfolder=subfolder,
                    add_second_input=False,
                )
            if dtype:
                data["model"] = data["model"].to(
                    getattr(torch, dtype) if isinstance(dtype, str) else dtype
                )
            if device:
                data["model"] = data["model"].to(device)
            return data["model"]
        """
    ),
}
63
+
64
+
def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
    """
    Creates a code to generate random inputs.

    :param inputs: dictionary
    :return: code
    """

    def _render(name, value):
        # Render one ``name=<expression>`` fragment for a single input value.
        if isinstance(value, (int, bool, float)):
            return f"{name}={value}"
        if isinstance(value, torch.Tensor):
            dims = tuple(map(int, value.shape))
            if value.dtype in (torch.int32, torch.int64):
                return f"{name}=torch.randint({value.max()}, size={dims}, dtype={value.dtype})"
            if value.dtype in (torch.float32, torch.float16, torch.bfloat16):
                return f"{name}=torch.rand({dims}, dtype={value.dtype})"
            raise ValueError(f"Unexpected dtype = {value.dtype} for k={name!r}")
        if value.__class__.__name__ == "DynamicCache":
            # Flatten the cache, rebuild it in the generated code from random
            # (key, value) pairs: first half of the tensors are keys.
            flat = flatten_object(value)
            rnd = [f"torch.rand({tuple(map(int,_.shape))}, dtype={_.dtype})" for _ in flat]
            half = len(rnd) // 2
            pairs = ", ".join(f"({a},{b})" for a, b in zip(rnd[:half], rnd[half:]))
            return f"{name}=make_dynamic_cache([{pairs}])"
        raise ValueError(f"Unexpected type {type(value)} for k={name!r}")

    body = ", ".join(_render(k, v) for k, v in inputs.items())
    return f"dict({body})"
95
+
96
+
def make_export_code(
    exporter: str,
    optimization: Optional[str] = None,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    stop_if_static: int = 0,
    dump_folder: Optional[str] = None,
    opset: Optional[int] = None,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    output_names: Optional[List[str]] = None,
    verbose: int = 0,
) -> Tuple[str, str]:
    """
    Generates the export section of a code sample.

    :param exporter: ``"custom"`` (experimental_experiment ``to_onnx``) or
        ``"onnx-dynamo"`` (:func:`torch.onnx.export`)
    :param optimization: optimization to apply after the export
    :param patch_kwargs: if not empty, the generated export code is wrapped in
        ``torch_export_patches(**patch_kwargs)``; the dictionary is not modified
    :param stop_if_static: forwarded to ``torch_export_patches`` when patches are used
    :param dump_folder: if set, the generated code creates this folder
        and saves the exported model into it
    :param opset: target onnx opset
    :param dynamic_shapes: dynamic shapes given to the exporter
    :param output_names: output names the exporter should use
    :param verbose: verbosity level (unused here, kept for API symmetry)
    :return: a pair ``(imports, code)``, both multi-line strings
    :raises ValueError: if the exporter is unknown
    """
    args = [f"dynamic_shapes={dynamic_shapes}"]
    if output_names:
        args.append(f"output_names={output_names}")
    code: List[str] = []
    imports: List[str] = []
    # Default filename used when no dump_folder is given: the "custom" branch
    # always needs one (the original code raised NameError in that case).
    filename = "model.onnx"
    if dump_folder:
        # exist_ok=True so the generated sample can be run more than once.
        code.append(f"os.makedirs({dump_folder!r}, exist_ok=True)")
        imports.append("import os")
        filename = os.path.join(dump_folder, "model.onnx")
    if exporter == "custom":
        if opset:
            args.append(f"target_opset={opset}")
        if optimization:
            args.append(f"options=OptimizationOptions(patterns={optimization!r})")
        args.append(f"large_model=True, filename={filename!r}")
        sargs = ", ".join(args)
        imports.extend(
            [
                "from experimental_experiment.torch_interpreter import to_onnx",
                "from experimental_experiment.xbuilder import OptimizationOptions",
            ]
        )
        code.extend([f"onx = to_onnx(model, inputs, {sargs})"])
    elif exporter == "onnx-dynamo":
        if opset:
            args.append(f"opset_version={opset}")
        sargs = ", ".join(args)
        code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
        if optimization:
            imports.append("import onnxscript")
            code.extend(["onnxscript.optimizer.optimize_ir(epo.model)"])
            if "os_ort" in optimization:
                imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
                code.extend(["ort_fusions.optimize_for_ort(epo.model)"])
        if dump_folder:
            code.extend([f"epo.save({filename!r})"])
    else:
        raise ValueError(f"Unexpected exporter {exporter!r}")
    if not patch_kwargs:
        return "\n".join(imports), "\n".join(code)

    imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
    # Work on a copy so the caller's dictionary is not mutated.
    patch_kwargs = dict(patch_kwargs)
    if stop_if_static:
        patch_kwargs["stop_if_static"] = stop_if_static
    sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
    code = [f"with torch_export_patches({sargs}):", *[" " + _ for _ in code]]
    return "\n".join(imports), "\n".join(code)
155
+
156
+
def code_sample(
    model_id: str,
    task: Optional[str] = None,
    do_run: bool = False,
    exporter: Optional[str] = None,
    do_same: bool = False,
    verbose: int = 0,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    optimization: Optional[str] = None,
    quiet: bool = False,
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
    stop_if_static: int = 1,
    dump_folder: Optional[str] = None,
    drop_inputs: Optional[List[str]] = None,
    input_options: Optional[Dict[str, Any]] = None,
    model_options: Optional[Dict[str, Any]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    runtime: str = "onnxruntime",
    output_names: Optional[List[str]] = None,
) -> str:
    """
    This generates a code to export a model with the proper settings.

    :param model_id: model id to validate
    :param task: task used to generate the necessary inputs,
        can be left empty to use the default task for this model
        if it can be determined
    :param do_run: checks the model works with the defined inputs
    :param exporter: exporter the model using this exporter,
        available list: ``export-strict``, ``export-nostrict``, ...
        see below
    :param do_same: checks the discrepancies of the exported model
    :param verbose: verbosity level
    :param dtype: uses this dtype to check the model
    :param device: do the verification on this device
    :param same_as_pretrained: use a model equivalent to the trained,
        this is not always possible
    :param use_pretrained: use the trained model, not the untrained one
    :param optimization: optimization to apply to the exported model,
        depends on the exporter
    :param quiet: if quiet, catches exception if any issue
    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
        if True before exporting
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
        a string can be used to specify only one of them
    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param stop_if_static: stops if a dynamic dimension becomes static,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param dump_folder: dumps everything in a subfolder of this one
    :param drop_inputs: drops this list of inputs (given their names)
    :param input_options: additional options to define the dummy inputs
        used to export
    :param model_options: additional options when creating the model such as
        ``num_hidden_layers`` or ``attn_implementation``
    :param subfolder: version or subfolders to use when retrieving a model id
    :param opset: onnx opset to use for the conversion
    :param runtime: onnx runtime to use to check about discrepancies,
        possible values ``onnxruntime``, ``torch``, ``orteval``,
        ``orteval10``, ``ref`` only if `do_run` is true
    :param output_names: output names the onnx exporter should use
    :return: a code

    .. runpython::
        :showcode:

        from onnx_diagnostic.torch_models.code_sample import code_sample

        print(
            code_sample(
                "arnir0/Tiny-LLM",
                exporter="onnx-dynamo",
                optimization="ir",
                patch=True,
            )
        )
    """
    # NOTE(review): do_run, do_same, quiet and runtime are accepted but never
    # read below; presumably kept for command-line compatibility — confirm.
    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
        model_id,
        subfolder,
        same_as_pretrained=same_as_pretrained,
        use_pretrained=use_pretrained,
    )
    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)

    # Build the model and its dummy inputs (an untrained clone unless
    # use_pretrained is set).
    iop = input_options or {}
    mop = model_options or {}
    data = get_untrained_model_with_inputs(
        model_id,
        verbose=verbose,
        task=task,
        use_pretrained=use_pretrained,
        same_as_pretrained=same_as_pretrained,
        inputs_kwargs=iop,
        model_kwargs=mop,
        subfolder=subfolder,
        add_second_input=False,
    )
    if drop_inputs:
        # Remove the unwanted inputs from every input set and adjust the
        # dynamic shapes accordingly.
        update = {}
        for k in data:
            if k.startswith("inputs"):
                update[k], ds = filter_inputs(
                    data[k],
                    drop_names=drop_inputs,
                    model=data["model"],
                    dynamic_shapes=data["dynamic_shapes"],
                )
                update["dynamic_shapes"] = ds
        data.update(update)

    # Move/cast the dummy inputs to the requested dtype and device.
    update = {}
    for k in data:
        if k.startswith("inputs"):
            v = data[k]
            if dtype:
                update[k] = v = to_any(
                    v, getattr(torch, dtype) if isinstance(dtype, str) else dtype
                )
            if device:
                update[k] = v = to_any(v, device)
    if update:
        data.update(update)

    # Arguments for the generated call to get_model_with_inputs.
    args = [f"{model_id!r}"]
    if subfolder:
        args.append(f"subfolder={subfolder!r}")
    if dtype:
        args.append(f"dtype={dtype!r}")
    if device:
        args.append(f"device={device!r}")
    if same_as_pretrained:
        args.append(f"same_as_pretrained={same_as_pretrained!r}")
    if use_pretrained:
        args.append(f"use_pretrained={use_pretrained!r}")
    if input_options:
        args.append(f"input_options={input_options!r}")
    if model_options:
        args.append(f"model_options={model_options!r}")
    model_args = ", ".join(args)
    # Fix: the fallback must be a pair of strings, not lists, otherwise
    # "\n".join(pieces) below raises TypeError when exporter is None.
    imports, exporter_code = (
        make_export_code(
            exporter=exporter,
            patch_kwargs=patch_kwargs,
            verbose=verbose,
            optimization=optimization,
            stop_if_static=stop_if_static,
            dump_folder=dump_folder,
            opset=opset,
            dynamic_shapes=data["dynamic_shapes"],
            # Fix: output_names was accepted and documented but never forwarded.
            output_names=output_names,
        )
        if exporter is not None
        else ("", "")
    )
    input_code = make_code_for_inputs(data["inputs"])
    cache_import = (
        "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
        if "dynamic_cache" in input_code
        else ""
    )

    pieces = [
        CODE_SAMPLES["imports"],
        imports,
        cache_import,
        CODE_SAMPLES["get_model_with_inputs"],
        textwrap.dedent(
            f"""
            model = get_model_with_inputs({model_args})
            """
        ),
        f"inputs = {input_code}",
        exporter_code,
    ]
    code = "\n".join(pieces)
    try:
        import black
    except ImportError:
        # black is optional: return the unformatted code.
        return code

    return black.format_str(code, mode=black.Mode())
@@ -0,0 +1 @@
1
+ from .model_inputs import get_untrained_model_with_inputs