gimlet-api 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gml/compile.py CHANGED
@@ -14,49 +14,311 @@
14
14
  #
15
15
  # SPDX-License-Identifier: Apache-2.0
16
16
 
17
- from torch.fx.experimental.proxy_tensor import make_fx
18
- from torch_mlir import ExampleArgs, OutputType
19
- from torch_mlir import compile as torch_mlir_compile
20
- from torch_mlir.dynamo import _get_decomposition_table
17
+ import contextlib
18
+ import functools
19
+ from typing import Any, Dict, List, Optional, Sequence, Union
21
20
 
21
+ import safetensors_mlir
22
+ import torch
23
+ import torch_mlir
24
+ from gml.asset_manager import AssetManager
25
+ from mlir.ir import (
26
+ BF16Type,
27
+ ComplexType,
28
+ Context,
29
+ F16Type,
30
+ F32Type,
31
+ F64Type,
32
+ IntegerType,
33
+ Operation,
34
+ RankedTensorType,
35
+ Value,
36
+ )
37
+ from safetensors.torch import save_file
38
+ from torch._decomp import remove_decompositions
39
+ from torch.export._trace import _export
40
+ from torch_mlir.dialects import torch as torch_d
41
+ from torch_mlir.extras.fx_decomp_util import get_decomposition_table
42
+ from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
43
+ from torch_mlir.fx import export_and_import
22
44
 
23
- def to_torch_mlir(model, example_inputs):
24
- example_args = ExampleArgs.get(example_inputs)
25
- args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
26
- "forward"
45
+
46
+ def _default_decomposition_denylist():
47
+ """These ops will not be decomposed by default."""
48
+ return [
49
+ torch.ops.aten.full.default,
50
+ torch.ops.aten.upsample_bilinear2d.vec,
27
51
  ]
52
+
53
+
54
+ @contextlib.contextmanager
55
+ def _patch_aot_export_module():
56
+ """This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
57
+
58
+ This patch is necessary because not all callers of `aot_export_module` expose the pre_dispatch flag.
59
+ For example, `ExportedProgram.run_decompositions` which is called by `torch_mlir.fx.export_and_import` doesn't
60
+ expose the pre_dispatch flag.
61
+
62
+ Without setting `pre_dispatch=True`, PyTorch dispatch will run before tracing which causes certain operations to be decomposed.
63
+ For example, `upsample_nearest2d` will be decomposed into aten.index.Tensor calls. This is undesirable for runtimes that provide
64
+ optimized implementations of the equivalent of `upsample_nearest2d`.
65
+ """
66
+ import torch._functorch.aot_autograd
67
+
68
+ orig = torch._functorch.aot_autograd.aot_export_module
69
+ torch._functorch.aot_autograd.aot_export_module = functools.partial(
70
+ orig, pre_dispatch=True
71
+ )
72
+ yield
73
+ torch._functorch.aot_autograd.aot_export_module = orig
74
+
75
+
76
# Maps a torch dtype to a zero-argument factory that builds the matching MLIR builtin
# element type; looked up by SafetensorImporterHooks.resolve_literal. Factories are stored
# instead of the types themselves so construction is deferred to lookup time --
# presumably because the MLIR type getters need an active Context when called (TODO confirm).
# Note: unsigned/quantized-unsigned dtypes (uint8, quint8) use unsigned integer types,
# while the signed ones (int*, qint8) and bool use signless integer types.
_torch_dtype_to_builtin_element_type = {
    torch.float16: lambda: F16Type.get(),
    torch.bfloat16: lambda: BF16Type.get(),
    torch.float32: lambda: F32Type.get(),
    torch.float64: lambda: F64Type.get(),
    torch.uint8: lambda: IntegerType.get_unsigned(8),
    torch.int8: lambda: IntegerType.get_signless(8),
    torch.int16: lambda: IntegerType.get_signless(16),
    torch.int32: lambda: IntegerType.get_signless(32),
    torch.int64: lambda: IntegerType.get_signless(64),
    torch.bool: lambda: IntegerType.get_signless(1),
    # Quantized dtypes map onto their 8-bit storage integer types.
    torch.qint8: lambda: IntegerType.get_signless(8),
    torch.quint8: lambda: IntegerType.get_unsigned(8),
    # complex32/64/128 are complex numbers over f16/f32/f64 components respectively.
    torch.complex32: lambda: ComplexType.get(F16Type.get()),
    torch.complex64: lambda: ComplexType.get(F32Type.get()),
    torch.complex128: lambda: ComplexType.get(F64Type.get()),
}
93
+
94
+
95
+ def _get_unique_(tensors, name):
96
+ index = 0
97
+ name = "{}_{}".format(name, index)
98
+ while name in tensors:
99
+ index += 1
100
+ name = "{}_{}".format(name, index)
101
+ return name
102
+
103
+
104
class TensorSet:
    """Registry of tensors keyed by generated names, reusing a name for equal contents."""

    def __init__(self):
        # Generated name -> registered tensor.
        self._tensors: Dict[str, torch.Tensor] = {}

    def add(self, tensor: torch.Tensor) -> str:
        """Register `tensor` and return the name it is stored under.

        Names look like ``torch_tensor_<dims joined by "_">_<dtype>_<n>``. Candidates
        are tried for n = 0, 1, ...; the first one that is either unused or already
        holds a tensor equal to `tensor` (per ``torch.equal``) is taken.
        """
        dims = "_".join(str(size) for size in tensor.shape)
        prefix = f"torch_tensor_{dims}_{tensor.dtype}"

        suffix = 0
        while True:
            candidate = f"{prefix}_{suffix}"
            existing = self._tensors.get(candidate)
            if existing is None or torch.equal(tensor, existing):
                break
            suffix += 1

        self._tensors[candidate] = tensor
        return candidate

    def tensors(self) -> Dict[str, torch.Tensor]:
        """Return the live name -> tensor mapping (not a copy)."""
        return self._tensors
123
+
124
+
125
class SafetensorImporterHooks(FxImporterHooks):
    """FX importer hooks that store tensor literals in a safetensors weights asset.

    Each tensor literal seen during import is registered in a `TensorSet` and replaced in
    the MLIR module by a `safetensors_mlir.tensor_ref` into the asset `self.asset_name`,
    wrapped in a `torch_c.from_builtin_tensor` op. Call `save_tensors()` after import to
    write the registered tensors to that asset.
    """

    def __init__(self, asset_manager: AssetManager):
        self._asset_mgr = asset_manager
        # TODO(james): shard weights into multiple shards.
        self.asset_name = "weights.shard0"
        self._tensors = TensorSet()

    def resolve_literal(
        self, gni: "torch_mlir.extras.fx_importer.GraphNodeImporter", literal: Any
    ) -> Optional[Value]:
        """Return an MLIR value that references `literal` in the weights asset.

        Non-tensor literals return None (presumably leaving them to the importer's
        default handling -- confirm against FxImporterHooks).

        Raises:
            ValueError: If the tensor dtype has no entry in
                `_torch_dtype_to_builtin_element_type`.
        """
        if not isinstance(literal, torch.Tensor):
            return None
        tensor = literal

        # Validate before registering so an unsupported tensor never lands in the set
        # that `save_tensors` writes out.
        make_elem_type = _torch_dtype_to_builtin_element_type.get(tensor.dtype)
        if make_elem_type is None:
            raise ValueError("unsupported torch dtype: {}".format(tensor.dtype))

        ctx = gni._c
        tensor_name = self._tensors.add(tensor)
        file_attr = safetensors_mlir.FileAttr.get(ctx, self.asset_name)

        elem_type = make_elem_type()
        tensor_type = RankedTensorType.get(tuple(tensor.size()), elem_type)

        tensor_attr = safetensors_mlir.TensorAttr.get(
            tensor_type, file_attr, tensor_name
        )
        builtin_tensor = safetensors_mlir.tensor_ref(tensor_type, tensor_attr)

        # Bridge the builtin tensor value back into the torch dialect's value tensor type.
        vtensor_type = gni._cc.tensor_to_vtensor_type(tensor)
        return Operation.create(
            name="torch_c.from_builtin_tensor",
            results=[vtensor_type],
            operands=[builtin_tensor],
        ).result

    def save_tensors(self):
        """Write every registered tensor to the weights asset as a safetensors file."""
        file_path = self._asset_mgr.add_asset(self.asset_name)
        # safetensors' save_file expects contiguous tensors; build a separate dict so the
        # registered tensors themselves are left untouched.
        tensors = {
            name: tensor.contiguous()
            for name, tensor in self._tensors.tensors().items()
        }
        save_file(tensors, file_path)
167
+
168
+
169
+ def _to_module_list(val):
170
+ if isinstance(val, torch.nn.Module):
171
+ return val
172
+
173
+ converted = []
174
+ for item in val:
175
+ c = _to_module_container(item)
176
+ if c is None:
177
+ return None
178
+ converted.append(c)
179
+ if not converted:
180
+ return None
181
+ return torch.nn.ModuleList(converted)
182
+
183
+
184
+ def _to_module_dict(val):
185
+ if isinstance(val, torch.nn.Module):
186
+ return val
187
+
188
+ converted = dict()
189
+ for k, v in val.items():
190
+ c = _to_module_container(v)
191
+ if c is None:
192
+ return None
193
+ converted[k] = v
194
+ if not converted:
195
+ return None
196
+ return torch.nn.ModuleDict(converted)
197
+
198
+
199
+ def _to_module_container(val, root=False):
200
+ if isinstance(val, torch.nn.Module) and not root:
201
+ return val
202
+ if isinstance(val, dict):
203
+ return _to_module_dict(val)
204
+ if isinstance(val, list) or isinstance(val, tuple):
205
+ return _to_module_list(val)
206
+
207
+ return None
208
+
209
+
210
def _replace_containers_with_torch_containers(mod: torch.nn.Module):
    """Swap plain container attributes of `mod` for torch module containers.

    Every attribute in `mod.__dict__` (other than `_modules`) that is a list, tuple, or
    dict -- possibly nested -- whose leaves are all `torch.nn.Module` instances is
    replaced by the equivalent `torch.nn.ModuleList` / `torch.nn.ModuleDict`.

    This fixes some `module is not installed as a submodule` errors.
    """
    # Gather replacements first and assign afterwards: setattr on a Module may move
    # entries out of mod.__dict__, which must not happen while it is being iterated.
    pending = {
        attr: container
        for attr, value in mod.__dict__.items()
        if attr != "_modules"
        and (container := _to_module_container(value, root=True)) is not None
    }
    for attr, container in pending.items():
        setattr(mod, attr, container)
227
+
228
+
229
+ def _ensure_submodules_accessed_through_getattr(mod: torch.nn.Module):
230
+ """This removes any registered modules from `mod.__dict__`.
231
+
232
+ This ensures that all accesses of submodules go through torch's __getattr__ infra,
233
+ preventing certain cases of `module is not installed as a submodule` errors.
234
+ """
235
+ if not hasattr(mod, "_modules"):
236
+ return
237
+ for name in mod._modules:
238
+ if name in mod.__dict__:
239
+ del mod.__dict__[name]
240
+
241
+
242
def _submodule_registration_workarounds(mod: torch.nn.Module):
    """Apply submodule registration workarounds to `mod` and, recursively, its submodules.

    The traversal walks `_modules` directly (rather than `mod.modules()`) in pre-order, so
    each module's plain containers are converted -- and thereby registered as submodules
    -- before its children are collected. Each module is processed at most once, which
    also guarantees termination for shared submodules and cyclic module graphs. `None`
    entries in `_modules` (torch allows registering None) are skipped.
    """
    visited = set()

    def _visit(module: torch.nn.Module):
        # Track by id(): a Module subclass may define __eq__ without __hash__.
        if id(module) in visited:
            return
        visited.add(id(module))
        _ensure_submodules_accessed_through_getattr(module)
        _replace_containers_with_torch_containers(module)
        for submod in module._modules.values():
            if submod is None:
                continue
            _visit(submod)

    _visit(mod)
253
+
254
+
255
def _normalize_dynamic_shapes(dynamic_shapes):
    """Return a copy of `dynamic_shapes` with dim names wrapped in `torch.export.Dim`.

    In each per-input dict, every value that is not already a
    `torch.export.dynamic_shapes._Dim` is replaced by `torch.export.Dim(value)`; non-dict
    entries pass through unchanged. A dict given at the top level (e.g. keyed by argument
    name) is returned as-is. The caller's objects are never modified.
    """
    if isinstance(dynamic_shapes, dict):
        return dynamic_shapes
    normalized = []
    for shape in dynamic_shapes:
        if not isinstance(shape, dict):
            normalized.append(shape)
            continue
        normalized.append(
            {
                idx: (
                    dim
                    if isinstance(dim, torch.export.dynamic_shapes._Dim)
                    else torch.export.Dim(dim)
                )
                for idx, dim in shape.items()
            }
        )
    # Keep the caller's container kind (tuple vs list); other sequences become lists.
    return tuple(normalized) if isinstance(dynamic_shapes, tuple) else normalized


def to_torch_mlir(
    model: torch.nn.Module,
    example_inputs: Sequence[torch.Tensor],
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
    decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
    weight_manager: Optional[AssetManager] = None,
):
    """Export `model` with torch.export and import it as a torch-dialect MLIR module.

    Args:
        model: Module to convert. It is put in eval mode, moved to CPU, and has the
            submodule-registration workarounds applied in place.
        example_inputs: Positional example inputs used for export.
        dynamic_shapes: Optional per-input mapping of dim index to a dim name or a
            `torch.export` Dim. Names are wrapped with `torch.export.Dim`. The argument
            itself is not modified.
        decomposition_denylist: Ops removed from the decomposition table. Defaults to
            `_default_decomposition_denylist()`.
        weight_manager: If given, tensor literals are routed through
            `SafetensorImporterHooks`. They are written to a safetensors asset managed by
            it and referenced from the module.

    Returns:
        The imported MLIR module, after it passes verification.

    Raises:
        Exception: If the imported module fails MLIR verification.
    """
    if dynamic_shapes is not None:
        dynamic_shapes = _normalize_dynamic_shapes(dynamic_shapes)

    if decomposition_denylist is None:
        decomposition_denylist = _default_decomposition_denylist()

    model = model.eval().to("cpu")

    _submodule_registration_workarounds(model)

    try:
        # Running the model a few times on the inputs, leads to more consistent compiled results.
        for _ in range(2):
            _ = model(*example_inputs)
    except Exception:
        # Ignore errors running the model. This can happen when the model has data dependent branches.
        # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        pass

    prog = _export(
        model,
        tuple(example_inputs),
        pre_dispatch=False,
        strict=False,
        dynamic_shapes=dynamic_shapes,
    )
    decomp_table = get_decomposition_table()
    remove_decompositions(decomp_table, decomposition_denylist)
    hooks = None
    if weight_manager is not None:
        hooks = SafetensorImporterHooks(weight_manager)

    context = Context()
    torch_d.register_dialect(context)
    safetensors_mlir.register_dialect(context)
    fx_importer = FxImporter(context=context, hooks=hooks)

    # Force pre_dispatch=True inside export_and_import's decomposition step.
    with _patch_aot_export_module():
        module = export_and_import(
            prog,
            *example_inputs,
            decomposition_table=decomp_table,
            fx_importer=fx_importer,
        )

    if hooks is not None:
        hooks.save_tensors()

    try:
        module.operation.verify()
    except Exception as exc:
        raise Exception(
            "failed to verify converted torch model MLIR module: {}".format(module)
        ) from exc

    return module
gml/device.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from typing import List
18
+
19
+ import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
20
+ import gml.proto.src.api.corepb.v1.device_info_pb2 as deviceinfopb
21
+
22
+
23
class DeviceCapabilities:
    """Capabilities of a device: its supported model runtimes and its cameras.

    `runtimes` holds runtime names understood by `_runtime_str_to_runtime_protos`
    ("tensorrt", "openvino"; case-insensitive). `cameras` holds one driver name per
    camera, understood by `_camera_driver_str_to_camera_driver_protos` ("argus", "v4l2";
    case-insensitive).
    """

    def __init__(self, runtimes: List[str], cameras: List[str]):
        self.runtimes = runtimes
        self.cameras = cameras

    def to_proto(self) -> cpedgepb.DeviceCapabilities:
        """Build the `cpedgepb.DeviceCapabilities` message for these capabilities.

        Cameras get ids "0", "1", ... in list order, and every camera contributes one
        `CameraDriverInfo` entry (repeated drivers produce repeated entries).

        Raises:
            ValueError: On an unrecognized runtime or camera driver name.
        """
        runtime_infos = [
            deviceinfopb.ModelRuntimeInfo(type=_runtime_str_to_runtime_protos(name))
            for name in self.runtimes
        ]
        camera_infos = [
            deviceinfopb.CameraInfo(
                driver=_camera_driver_str_to_camera_driver_protos(driver_name),
                camera_id=str(position),
            )
            for position, driver_name in enumerate(self.cameras)
        ]
        driver_infos = [
            deviceinfopb.CameraDriverInfo(
                driver=_camera_driver_str_to_camera_driver_protos(driver_name)
            )
            for driver_name in self.cameras
        ]
        return cpedgepb.DeviceCapabilities(
            model_runtimes=runtime_infos,
            cameras=camera_infos,
            camera_drivers=driver_infos,
        )
50
+
51
+
52
def _runtime_str_to_runtime_protos(
    runtime: str,
) -> deviceinfopb.ModelRuntimeInfo.ModelRuntimeType:
    """Map a runtime name ("tensorrt" or "openvino", case-insensitive) to its proto enum.

    Raises:
        ValueError: If `runtime` is not a recognized runtime name.
    """
    runtime_types = deviceinfopb.ModelRuntimeInfo.ModelRuntimeType
    normalized = runtime.lower()
    if normalized == "tensorrt":
        return runtime_types.MODEL_RUNTIME_TYPE_TENSORRT
    if normalized == "openvino":
        return runtime_types.MODEL_RUNTIME_TYPE_OPENVINO
    raise ValueError("invalid runtime: {}".format(runtime))
66
+
67
+
68
def _camera_driver_str_to_camera_driver_protos(
    driver: str,
) -> deviceinfopb.CameraDriver:
    """Map a camera driver name ("argus" or "v4l2", case-insensitive) to its proto enum.

    Raises:
        ValueError: If `driver` is not a recognized camera driver name.
    """
    known_drivers = {
        "argus": deviceinfopb.CameraDriver.CAMERA_DRIVER_ARGUS,
        "v4l2": deviceinfopb.CameraDriver.CAMERA_DRIVER_V4L2,
    }
    proto_driver = known_drivers.get(driver.lower())
    if proto_driver is None:
        raise ValueError("invalid driver: {}".format(driver))
    return proto_driver