gimlet-api 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.7.dist-info}/METADATA +6 -3
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.7.dist-info}/RECORD +24 -17
- gml/__init__.py +1 -0
- gml/asset_manager.py +75 -0
- gml/client.py +105 -19
- gml/compile.py +296 -34
- gml/device.py +77 -0
- gml/hf.py +524 -0
- gml/model.py +64 -18
- gml/model_utils.py +2 -0
- gml/pipelines.py +67 -40
- gml/preprocessing.py +5 -2
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +7 -3
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +77 -63
- gml/proto/src/api/corepb/v1/device_info_pb2.py +43 -0
- gml/proto/src/api/corepb/v1/gem_config_pb2.py +45 -0
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +33 -25
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +105 -101
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py +40 -0
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py +66 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +7 -3
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +33 -0
- gml/tensor.py +115 -34
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.7.dist-info}/WHEEL +0 -0
gml/compile.py
CHANGED
@@ -14,49 +14,311 @@
|
|
14
14
|
#
|
15
15
|
# SPDX-License-Identifier: Apache-2.0
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
from
|
20
|
-
from torch_mlir.dynamo import _get_decomposition_table
|
17
|
+
import contextlib
|
18
|
+
import functools
|
19
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
21
20
|
|
21
|
+
import safetensors_mlir
|
22
|
+
import torch
|
23
|
+
import torch_mlir
|
24
|
+
from gml.asset_manager import AssetManager
|
25
|
+
from mlir.ir import (
|
26
|
+
BF16Type,
|
27
|
+
ComplexType,
|
28
|
+
Context,
|
29
|
+
F16Type,
|
30
|
+
F32Type,
|
31
|
+
F64Type,
|
32
|
+
IntegerType,
|
33
|
+
Operation,
|
34
|
+
RankedTensorType,
|
35
|
+
Value,
|
36
|
+
)
|
37
|
+
from safetensors.torch import save_file
|
38
|
+
from torch._decomp import remove_decompositions
|
39
|
+
from torch.export._trace import _export
|
40
|
+
from torch_mlir.dialects import torch as torch_d
|
41
|
+
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
42
|
+
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
|
43
|
+
from torch_mlir.fx import export_and_import
|
22
44
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
45
|
+
|
46
|
+
def _default_decomposition_denylist():
    """Return a fresh list of the aten ops that are left un-decomposed by default."""
    aten = torch.ops.aten
    return [aten.full.default, aten.upsample_bilinear2d.vec]
|
52
|
+
|
53
|
+
|
54
|
+
@contextlib.contextmanager
|
55
|
+
def _patch_aot_export_module():
|
56
|
+
"""This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
|
57
|
+
|
58
|
+
This patch is necessary because not all callers of `aot_export_module` expose the pre_dispatch flag.
|
59
|
+
For example, `ExportedProgram.run_decompositions` which is called by `torch_mlir.fx.export_and_import` doesn't
|
60
|
+
expose the pre_dispatch flag.
|
61
|
+
|
62
|
+
Without setting `pre_dispatch=True`, PyTorch dispatch will run before tracing which causes certain operations to be decomposed.
|
63
|
+
For example, `upsample_nearest2d` will be decomposed into aten.index.Tensor calls. This is undesirable for runtimes that provide
|
64
|
+
optimized implementations of the equivalent of `upsample_nearest2d`.
|
65
|
+
"""
|
66
|
+
import torch._functorch.aot_autograd
|
67
|
+
|
68
|
+
orig = torch._functorch.aot_autograd.aot_export_module
|
69
|
+
torch._functorch.aot_autograd.aot_export_module = functools.partial(
|
70
|
+
orig, pre_dispatch=True
|
71
|
+
)
|
72
|
+
yield
|
73
|
+
torch._functorch.aot_autograd.aot_export_module = orig
|
74
|
+
|
75
|
+
|
76
|
+
# Maps each supported torch dtype to a zero-argument factory that builds the
# corresponding MLIR builtin element type. Callers look up the factory and then
# call it to obtain the type.
# NOTE(review): the values are lambdas rather than type instances, presumably
# because constructing MLIR types requires an active Context - confirm.
# uint8/quint8 map to unsigned 8-bit integers; every other integer dtype (and
# bool, as a 1-bit integer) maps to a signless integer of the matching width.
_torch_dtype_to_builtin_element_type = {
    torch.float16: lambda: F16Type.get(),
    torch.bfloat16: lambda: BF16Type.get(),
    torch.float32: lambda: F32Type.get(),
    torch.float64: lambda: F64Type.get(),
    torch.uint8: lambda: IntegerType.get_unsigned(8),
    torch.int8: lambda: IntegerType.get_signless(8),
    torch.int16: lambda: IntegerType.get_signless(16),
    torch.int32: lambda: IntegerType.get_signless(32),
    torch.int64: lambda: IntegerType.get_signless(64),
    torch.bool: lambda: IntegerType.get_signless(1),
    torch.qint8: lambda: IntegerType.get_signless(8),
    torch.quint8: lambda: IntegerType.get_unsigned(8),
    torch.complex32: lambda: ComplexType.get(F16Type.get()),
    torch.complex64: lambda: ComplexType.get(F32Type.get()),
    torch.complex128: lambda: ComplexType.get(F64Type.get()),
}
|
93
|
+
|
94
|
+
|
95
|
+
def _get_unique_(tensors, name):
    """Return the first `<name>_<index>` that is not already a key of `tensors`.

    Args:
        tensors: Container supporting `in` (e.g. a dict keyed by tensor name).
        name: Base name to derive the unique name from.

    Returns:
        `name` suffixed with `_0`, `_1`, ... using the lowest index that is free.
    """
    base_name = name
    index = 0
    candidate = "{}_{}".format(base_name, index)
    while candidate in tensors:
        index += 1
        # Always format from the base name; re-formatting the previous
        # candidate would pile up suffixes ("w_0_1", "w_0_1_2", ...).
        candidate = "{}_{}".format(base_name, index)
    return candidate
|
102
|
+
|
103
|
+
|
104
|
+
class TensorSet:
    """A name -> tensor registry in which equal tensors share a single name.

    Names have the form `torch_tensor_<dims joined by "_">_<dtype>_<index>`.
    """

    def __init__(self):
        self._tensors: Dict[str, torch.Tensor] = dict()

    def add(self, tensor: torch.Tensor) -> str:
        """Register `tensor` and return its name.

        The index starts at 0 and is bumped until the name is either unused or
        already holds a tensor equal (per `torch.equal`) to `tensor`.
        """
        dims = "_".join(str(dim) for dim in tensor.shape)
        prefix = f"torch_tensor_{dims}_{tensor.dtype!s}"

        suffix = 0
        while True:
            candidate = "{}_{}".format(prefix, suffix)
            if candidate not in self._tensors:
                break
            if torch.equal(tensor, self._tensors[candidate]):
                break
            suffix += 1

        self._tensors[candidate] = tensor
        return candidate

    def tensors(self) -> Dict[str, torch.Tensor]:
        """Return the underlying name -> tensor dict (not a copy)."""
        return self._tensors
|
123
|
+
|
124
|
+
|
125
|
+
class SafetensorImporterHooks(FxImporterHooks):
    """FX importer hooks that replace tensor literals with references into a safetensors asset.

    Every tensor literal seen during import is recorded in a `TensorSet`; call
    `save_tensors` afterwards to write the collected tensors to the asset.
    """

    def __init__(self, asset_manager: AssetManager):
        self._asset_mgr = asset_manager
        # TODO(james): shard weights into multiple shards.
        self.asset_name = "weights.shard0"
        self._tensors = TensorSet()

    def resolve_literal(
        self, gni: "torch_mlir.extras.fx_importer.GraphNodeImporter", literal: Any
    ) -> Optional[Value]:
        """Emit a safetensors-backed value for a tensor literal.

        Returns None for non-tensor literals. Raises ValueError when the
        tensor's dtype has no entry in `_torch_dtype_to_builtin_element_type`.
        """
        if not isinstance(literal, torch.Tensor):
            return None

        mlir_ctx = gni._c

        # Record the tensor first; equal tensors resolve to the same name.
        name_in_file = self._tensors.add(literal)

        asset_attr = safetensors_mlir.FileAttr.get(mlir_ctx, self.asset_name)

        make_elem_type = _torch_dtype_to_builtin_element_type.get(literal.dtype)
        if make_elem_type is None:
            raise ValueError("unsupported torch dtype: {}".format(literal.dtype))
        ranked_type = RankedTensorType.get(tuple(literal.size()), make_elem_type())

        ref_attr = safetensors_mlir.TensorAttr.get(ranked_type, asset_attr, name_in_file)
        builtin_value = safetensors_mlir.tensor_ref(ranked_type, ref_attr)

        # Convert the builtin tensor back into the torch dialect's vtensor type.
        result_type = gni._cc.tensor_to_vtensor_type(literal)
        conversion = Operation.create(
            name="torch_c.from_builtin_tensor",
            results=[result_type],
            operands=[builtin_value],
        )
        return conversion.result

    def save_tensors(self):
        """Make every collected tensor contiguous (in place in the set) and save them to the asset."""
        destination = self._asset_mgr.add_asset(self.asset_name)
        collected = self._tensors.tensors()
        for key, value in list(collected.items()):
            collected[key] = value.contiguous()
        save_file(collected, destination)
|
167
|
+
|
168
|
+
|
169
|
+
def _to_module_list(val):
    """Convert a list/tuple of modules (or nested containers) to a `torch.nn.ModuleList`.

    Returns `val` unchanged when it is already a module, and None when the
    sequence is empty or any item cannot be converted.
    """
    if isinstance(val, torch.nn.Module):
        return val

    converted = []
    for item in val:
        c = _to_module_container(item)
        if c is None:
            return None
        converted.append(c)
    if not converted:
        return None
    return torch.nn.ModuleList(converted)


def _to_module_dict(val):
    """Convert a dict of modules (or nested containers) to a `torch.nn.ModuleDict`.

    Returns `val` unchanged when it is already a module, and None when the dict
    is empty, has a non-string key, or any value cannot be converted.
    """
    if isinstance(val, torch.nn.Module):
        return val

    converted = dict()
    for k, v in val.items():
        # ModuleDict only accepts string keys; leave such dicts untouched.
        if not isinstance(k, str):
            return None
        c = _to_module_container(v)
        if c is None:
            return None
        # Store the converted value, not the raw one, so nested lists/dicts
        # become ModuleList/ModuleDict instead of being rejected by ModuleDict.
        converted[k] = c
    if not converted:
        return None
    return torch.nn.ModuleDict(converted)


def _to_module_container(val, root=False):
    """Convert `val` to a torch module container if it is a dict/list/tuple of modules.

    Args:
        val: Candidate value.
        root: When True, a plain module is not returned as-is (only containers
            are converted), so module attributes are not reported as replacements.

    Returns:
        The module or converted container, or None when `val` is not convertible.
    """
    if isinstance(val, torch.nn.Module) and not root:
        return val
    if isinstance(val, dict):
        return _to_module_dict(val)
    if isinstance(val, (list, tuple)):
        return _to_module_list(val)

    return None
|
208
|
+
|
209
|
+
|
210
|
+
def _replace_containers_with_torch_containers(mod: torch.nn.Module):
    """Swap list/dict attributes of `mod` that hold modules for torch.nn.ModuleList/torch.nn.ModuleDict.

    Nested combinations of lists and dicts are handled as well. This fixes some
    `module is not installed as a submodule` errors.
    """
    skipped = {"_modules"}

    # Collect first, assign afterwards: setattr changes `mod.__dict__`, which
    # must not happen while it is being iterated.
    pending = {}
    for attr_name, attr_val in mod.__dict__.items():
        if attr_name in skipped:
            continue
        container = _to_module_container(attr_val, root=True)
        if container is not None:
            pending[attr_name] = container

    for attr_name, container in pending.items():
        setattr(mod, attr_name, container)
|
227
|
+
|
228
|
+
|
229
|
+
def _ensure_submodules_accessed_through_getattr(mod: torch.nn.Module):
    """Drop from `mod.__dict__` every name that is also registered in `mod._modules`.

    With the shadowing instance attribute gone, submodule access goes through
    torch's __getattr__ infra, preventing certain cases of
    `module is not installed as a submodule` errors. Objects without a
    `_modules` attribute are left untouched.
    """
    if hasattr(mod, "_modules"):
        instance_attrs = mod.__dict__
        for child_name in mod._modules:
            instance_attrs.pop(child_name, None)
|
240
|
+
|
241
|
+
|
242
|
+
def _submodule_registration_workarounds(mod: torch.nn.Module):
    """Apply submodule registration workarounds recursively to all submodules of `mod`.

    For `mod` itself and then for each direct child, shadowing `__dict__`
    entries are removed and list/dict attributes holding modules are replaced
    by torch containers.
    """
    _ensure_submodules_accessed_through_getattr(mod)
    _replace_containers_with_torch_containers(mod)
    # We intentionally don't use `mod.modules()` (which returns all recursive submodules) here because we want only
    # the direct dependencies of `mod`. So that we get a pre-order traversal, ensuring the workarounds are applied
    # before we check for submodules.
    for submod in mod._modules.values():
        # torch allows a submodule slot to be registered as None; there is
        # nothing to fix up there, and recursing would fail on `None.__dict__`.
        if submod is None or submod is mod:
            continue
        _submodule_registration_workarounds(submod)
|
253
|
+
|
254
|
+
|
255
|
+
def to_torch_mlir(
    model: torch.nn.Module,
    example_inputs: Sequence[torch.Tensor],
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
    decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
    weight_manager: Optional[AssetManager] = None,
):
    """Export `model` with torch.export and import it as a torch-mlir module.

    Args:
        model: Module to convert. It is put in eval mode and moved to CPU.
        example_inputs: Positional example inputs used for the warm-up runs and the export.
        dynamic_shapes: Optional per-input dicts mapping a dimension index to a
            `torch.export.Dim` or to a dimension name; names are turned into
            `torch.export.Dim` objects. The caller's containers are not modified.
        decomposition_denylist: Ops to remove from the decomposition table.
            Defaults to `_default_decomposition_denylist()`.
        weight_manager: When given, tensor literals are referenced through a
            safetensors asset (via `SafetensorImporterHooks`) which is saved
            after the import.

    Returns:
        The MLIR module produced by `export_and_import`, after verification.

    Raises:
        Exception: If the imported module fails MLIR verification.
    """
    dynamic_shapes = _normalize_dynamic_shapes(dynamic_shapes)

    if decomposition_denylist is None:
        decomposition_denylist = _default_decomposition_denylist()

    model = model.eval().to("cpu")

    _submodule_registration_workarounds(model)

    try:
        # Running the model a few times on the inputs, leads to more consistent compiled results.
        for _ in range(2):
            _ = model(*example_inputs)
    except Exception:
        # Ignore errors running the model. This can happen when the model has data dependent branches.
        # Only `Exception` is swallowed so Ctrl-C / SystemExit still propagate.
        pass

    prog = _export(
        model,
        tuple(example_inputs),
        pre_dispatch=False,
        strict=False,
        dynamic_shapes=dynamic_shapes,
    )
    decomp_table = get_decomposition_table()
    remove_decompositions(decomp_table, decomposition_denylist)
    hooks = None
    if weight_manager is not None:
        hooks = SafetensorImporterHooks(weight_manager)

    context = Context()
    torch_d.register_dialect(context)
    safetensors_mlir.register_dialect(context)
    fx_importer = FxImporter(context=context, hooks=hooks)

    with _patch_aot_export_module():
        module = export_and_import(
            prog,
            *example_inputs,
            decomposition_table=decomp_table,
            fx_importer=fx_importer,
        )

    if hooks is not None:
        hooks.save_tensors()

    try:
        module.operation.verify()
    except Exception as exc:
        raise Exception(
            "failed to verify converted torch model MLIR module: {}".format(module)
        ) from exc

    return module


def _normalize_dynamic_shapes(dynamic_shapes):
    """Return `dynamic_shapes` with dimension names replaced by `torch.export.Dim` objects.

    None and dict-form specs are returned unchanged. For sequence-form specs a
    new list (or tuple, if a tuple was given) is built; dict entries are copied
    so the caller's objects are never mutated, and non-dict entries are passed
    through as-is. Equal names share one `Dim` object.
    """
    if dynamic_shapes is None or isinstance(dynamic_shapes, dict):
        return dynamic_shapes

    dims_by_name = {}

    def as_dim(value):
        if isinstance(value, torch.export.dynamic_shapes._Dim):
            return value
        if value not in dims_by_name:
            dims_by_name[value] = torch.export.Dim(value)
        return dims_by_name[value]

    normalized = [
        {idx: as_dim(dim) for idx, dim in shape.items()}
        if isinstance(shape, dict)
        else shape
        for shape in dynamic_shapes
    ]
    return tuple(normalized) if isinstance(dynamic_shapes, tuple) else normalized
|
gml/device.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
from typing import List
|
18
|
+
|
19
|
+
import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
|
20
|
+
import gml.proto.src.api.corepb.v1.device_info_pb2 as deviceinfopb
|
21
|
+
|
22
|
+
|
23
|
+
class DeviceCapabilities:
    """Model runtimes and camera drivers of a device, given by name."""

    def __init__(self, runtimes: List[str], cameras: List[str]):
        self.runtimes = runtimes
        self.cameras = cameras

    def to_proto(self) -> cpedgepb.DeviceCapabilities:
        """Build the `DeviceCapabilities` proto for this device.

        Each camera yields both a `CameraInfo` (whose `camera_id` is its index
        in `self.cameras`, as a string) and a `CameraDriverInfo`. Unknown
        runtime or driver names raise ValueError from the conversion helpers.
        """
        runtime_infos = [
            deviceinfopb.ModelRuntimeInfo(type=_runtime_str_to_runtime_protos(name))
            for name in self.runtimes
        ]

        camera_infos = []
        for position, camera in enumerate(self.cameras):
            camera_infos.append(
                deviceinfopb.CameraInfo(
                    driver=_camera_driver_str_to_camera_driver_protos(camera),
                    camera_id=str(position),
                )
            )

        driver_infos = []
        for camera in self.cameras:
            driver_infos.append(
                deviceinfopb.CameraDriverInfo(
                    driver=_camera_driver_str_to_camera_driver_protos(camera)
                )
            )

        return cpedgepb.DeviceCapabilities(
            model_runtimes=runtime_infos,
            cameras=camera_infos,
            camera_drivers=driver_infos,
        )
|
50
|
+
|
51
|
+
|
52
|
+
def _runtime_str_to_runtime_protos(
    runtime: str,
) -> deviceinfopb.ModelRuntimeInfo.ModelRuntimeType:
    """Translate a runtime name (case-insensitive) to its `ModelRuntimeType` enum value.

    Accepts "tensorrt" and "openvino"; any other name raises ValueError.
    """
    normalized = runtime.lower()
    if normalized == "tensorrt":
        return deviceinfopb.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
    if normalized == "openvino":
        return deviceinfopb.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
    raise ValueError("invalid runtime: {}".format(runtime))
|
66
|
+
|
67
|
+
|
68
|
+
def _camera_driver_str_to_camera_driver_protos(
    driver: str,
) -> deviceinfopb.CameraDriver:
    """Translate a camera driver name (case-insensitive) to its `CameraDriver` enum value.

    Accepts "argus" and "v4l2"; any other name raises ValueError.
    """
    normalized = driver.lower()
    if normalized == "argus":
        return deviceinfopb.CameraDriver.CAMERA_DRIVER_ARGUS
    if normalized == "v4l2":
        return deviceinfopb.CameraDriver.CAMERA_DRIVER_V4L2
    raise ValueError("invalid driver: {}".format(driver))
|