gimlet-api 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.6.dist-info}/METADATA +5 -3
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.6.dist-info}/RECORD +20 -16
- gml/__init__.py +1 -0
- gml/client.py +105 -19
- gml/compile.py +123 -5
- gml/device.py +74 -0
- gml/hf.py +521 -0
- gml/model.py +47 -16
- gml/model_utils.py +2 -0
- gml/pipelines.py +69 -3
- gml/preprocessing.py +5 -2
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +51 -33
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +33 -25
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +105 -101
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py +40 -0
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py +66 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +7 -3
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +33 -0
- gml/tensor.py +115 -34
- {gimlet_api-0.0.5.dist-info → gimlet_api-0.0.6.dist-info}/WHEEL +0 -0
@@ -9,8 +9,10 @@ Classifier: Typing :: Typed
|
|
9
9
|
Requires-Python: >=3
|
10
10
|
Requires-Dist: protobuf
|
11
11
|
Requires-Dist: grpcio
|
12
|
-
Requires-Dist: torch
|
13
|
-
Requires-Dist: torch_mlir_gml
|
14
|
-
|
12
|
+
Requires-Dist: torch>=2.3.0
|
13
|
+
Requires-Dist: torch_mlir_gml==0.0.2
|
14
|
+
Requires-Dist: numpy<2.0.0
|
15
|
+
Requires-Dist: transformers>=4.43.3
|
16
|
+
Version: 0.0.6
|
15
17
|
|
16
18
|
UNKNOWN
|
@@ -1,11 +1,13 @@
|
|
1
|
-
gml/__init__.py,sha256=
|
1
|
+
gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
|
2
2
|
gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
|
3
|
-
gml/client.py,sha256=
|
4
|
-
gml/compile.py,sha256=
|
5
|
-
gml/
|
6
|
-
gml/
|
7
|
-
gml/
|
8
|
-
gml/
|
3
|
+
gml/client.py,sha256=5QDKljltBeBTCd2hH38--fTSP0bVVcAvSnWsA9YEFQc,13819
|
4
|
+
gml/compile.py,sha256=hR8u3LaMiIW8d12FHrvtmtgzUNQq48DYxe8bW-wJ_VY,6054
|
5
|
+
gml/device.py,sha256=VUZc6m8QalJ7G9KBKjCY4cIcv2VBd6zAT3ysnh_m1Z0,2585
|
6
|
+
gml/hf.py,sha256=GRvEEl9zSIv0iWN91Z6ykFYZ2VdNAVABjZrrzYWUFw4,17792
|
7
|
+
gml/model.py,sha256=nXUV6-L4TIkQHCWUpWyG7QJ6YKTZb7eauW9F4pzVTII,6566
|
8
|
+
gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
|
9
|
+
gml/pipelines.py,sha256=Bha8J3b5uW8COIejiH12NNF0Tc0XDBt2B3Dez5Jxt4s,5314
|
10
|
+
gml/preprocessing.py,sha256=STQDSA1_jXPTenJotNtsNMXOc9h1x_wJyQ100LXS6-g,3209
|
9
11
|
gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
|
10
12
|
gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
|
11
13
|
gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
|
@@ -23,21 +25,23 @@ gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5
|
|
23
25
|
gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
|
24
26
|
gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
|
25
27
|
gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
|
26
|
-
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=
|
27
|
-
gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=
|
28
|
-
gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=
|
28
|
+
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=u41Sohshi6gBfeZO5VnQzfRStFADFzT1Um5mDY9chcg,15309
|
29
|
+
gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=1qA-ElTgWeGv3oevYlIjK1TIRSgWbR1TTWxA6Q3SOXk,7224
|
30
|
+
gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=1DM58lSFgfoHk0ui3ZTjDfifgp4dhE7nHvhMwmInpsA,27103
|
29
31
|
gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
|
30
32
|
gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
|
31
33
|
gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
|
34
|
+
gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=tkJFPWpndKZy19TFuLKlBfWW1fUQPj0lJLiQ9HfugZU,3213
|
35
|
+
gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
|
32
36
|
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
|
33
37
|
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
|
34
38
|
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
|
35
39
|
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
|
36
|
-
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=
|
37
|
-
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256
|
40
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA5qCcMALT6PS47LYmmVdBz9U47WFLs5Ayg,6330
|
41
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
|
38
42
|
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
|
39
43
|
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
|
40
|
-
gml/tensor.py,sha256=
|
41
|
-
gimlet_api-0.0.
|
42
|
-
gimlet_api-0.0.
|
43
|
-
gimlet_api-0.0.
|
44
|
+
gml/tensor.py,sha256=753IsMFYZD7p_f0cQPt4nTIBo5p5S5ELqwCuoHORdMk,14823
|
45
|
+
gimlet_api-0.0.6.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
|
46
|
+
gimlet_api-0.0.6.dist-info/METADATA,sha256=mF73_t-Tn5NPVxLnJBOTCQkKvb0weobtQLmOqSkc4B0,506
|
47
|
+
gimlet_api-0.0.6.dist-info/RECORD,,
|
gml/__init__.py
CHANGED
gml/client.py
CHANGED
@@ -21,6 +21,8 @@ from typing import BinaryIO, List, Optional, TextIO, Union
|
|
21
21
|
|
22
22
|
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
23
23
|
import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
|
24
|
+
import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2 as cpb
|
25
|
+
import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2_grpc as cpb_grpc
|
24
26
|
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
|
25
27
|
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
|
26
28
|
import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
|
@@ -31,6 +33,7 @@ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
|
|
31
33
|
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
|
32
34
|
import grpc
|
33
35
|
from gml._utils import chunk_file, sha256sum
|
36
|
+
from gml.device import DeviceCapabilities
|
34
37
|
from gml.model import Model
|
35
38
|
from gml.pipelines import Pipeline
|
36
39
|
|
@@ -65,6 +68,10 @@ class FileAlreadyExists(Exception):
|
|
65
68
|
pass
|
66
69
|
|
67
70
|
|
71
|
+
class ModelAlreadyExists(Exception):
    """Error type for a model that already exists.

    NOTE(review): nothing in the visible part of this module raises it;
    confirm where it is meant to be used.
    """
|
73
|
+
|
74
|
+
|
68
75
|
class OrgNotSet(Exception):
|
69
76
|
pass
|
70
77
|
|
@@ -106,6 +113,7 @@ class Client:
|
|
106
113
|
|
107
114
|
self._org_id_cache: Optional[uuidpb.UUID] = None
|
108
115
|
self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
|
116
|
+
self._cs_stub_cache: Optional[cpb_grpc.CompilerServiceStub] = None
|
109
117
|
self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
|
110
118
|
self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
|
111
119
|
self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
|
@@ -137,6 +145,13 @@ class Client:
|
|
137
145
|
)
|
138
146
|
return self._ods_stub_cache
|
139
147
|
|
148
|
+
def _cs_stub(self):
|
149
|
+
if self._cs_stub_cache is None:
|
150
|
+
self._cs_stub_cache = cpb_grpc.CompilerServiceStub(
|
151
|
+
self._channel_factory.get_grpc_channel()
|
152
|
+
)
|
153
|
+
return self._cs_stub_cache
|
154
|
+
|
140
155
|
def _ms_stub(self):
|
141
156
|
if self._ms_stub_cache is None:
|
142
157
|
self._ms_stub_cache = mpb_grpc.ModelServiceStub(
|
@@ -185,7 +200,7 @@ class Client:
|
|
185
200
|
file_id: uuidpb.UUID,
|
186
201
|
sha256: str,
|
187
202
|
file: TextIO | BinaryIO,
|
188
|
-
chunk_size=
|
203
|
+
chunk_size=1024 * 1024,
|
189
204
|
):
|
190
205
|
def chunked_requests():
|
191
206
|
file.seek(0)
|
@@ -206,7 +221,7 @@ class Client:
|
|
206
221
|
name: str,
|
207
222
|
file: TextIO | BinaryIO,
|
208
223
|
sha256: Optional[str] = None,
|
209
|
-
chunk_size=
|
224
|
+
chunk_size=1024 * 1024,
|
210
225
|
) -> ftpb.FileInfo:
|
211
226
|
file_info = self._create_file(name)
|
212
227
|
|
@@ -237,7 +252,21 @@ class Client:
|
|
237
252
|
raise Exception("file status is deleted or unknown, cannot re-upload")
|
238
253
|
return file_info
|
239
254
|
|
240
|
-
def
|
255
|
+
def _get_model_if_exists(self, name: str) -> Optional[cpb.Model]:
    """Look up a model by name in the current org.

    Returns:
        A `cpb.Model` built from the GetModel response (`id` and
        `model_info`), or None when the service answers NOT_FOUND.

    Raises:
        grpc.RpcError: for any RPC failure other than NOT_FOUND.
    """
    request = mpb.GetModelRequest(
        name=name,
        org_id=self._get_org_id(),
    )
    model_service = self._ms_stub()
    try:
        response = model_service.GetModel(
            request, metadata=self._get_request_metadata()
        )
        return cpb.Model(id=response.id, info=response.model_info)
    except grpc.RpcError as err:
        # Only "not found" is an expected outcome; everything else propagates.
        if err.code() == grpc.StatusCode.NOT_FOUND:
            return None
        raise err
|
268
|
+
|
269
|
+
def _create_model(self, model_info: modelexecpb.ModelInfo) -> cpb.Model:
|
241
270
|
req = mpb.CreateModelRequest(
|
242
271
|
org_id=self._get_org_id(),
|
243
272
|
name=model_info.name,
|
@@ -247,26 +276,36 @@ class Client:
|
|
247
276
|
resp = stub.CreateModel(
|
248
277
|
req, metadata=self._get_request_metadata(idempotent=True)
|
249
278
|
)
|
250
|
-
return resp.id
|
279
|
+
return cpb.Model(id=resp.id, info=model_info)
|
280
|
+
|
281
|
+
def create_model(self, model: Model) -> cpb.Model:
|
282
|
+
existing_model = self._get_model_if_exists(model.name)
|
283
|
+
if existing_model is not None:
|
284
|
+
print(
|
285
|
+
'warning: model "{}" already exists and will not be uploaded.'.format(
|
286
|
+
model.name
|
287
|
+
)
|
288
|
+
)
|
289
|
+
return existing_model
|
251
290
|
|
252
|
-
def create_model(self, model: Model):
|
253
291
|
model_info = model.to_proto()
|
254
|
-
|
255
|
-
|
256
|
-
file
|
292
|
+
with model.collect_assets() as model_assets:
|
293
|
+
for asset_name, file in model_assets.items():
|
294
|
+
if isinstance(file, Path) or isinstance(file, str):
|
295
|
+
file = open(file, "rb")
|
257
296
|
|
258
|
-
|
297
|
+
sha256 = sha256sum(file)
|
259
298
|
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
299
|
+
upload_name = model.name
|
300
|
+
if asset_name:
|
301
|
+
upload_name += ":" + asset_name
|
302
|
+
print(f"Uploading {upload_name}...")
|
264
303
|
|
265
|
-
|
304
|
+
file_info = self._upload_file_if_not_exists(sha256, file, sha256)
|
266
305
|
|
267
|
-
|
306
|
+
model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
|
268
307
|
|
269
|
-
|
308
|
+
file.close()
|
270
309
|
|
271
310
|
return self._create_model(model_info)
|
272
311
|
|
@@ -285,9 +324,7 @@ class Client:
|
|
285
324
|
if isinstance(pipeline, Pipeline):
|
286
325
|
if self._org_name is None:
|
287
326
|
raise ValueError("must set `org` to upload a pipeline")
|
288
|
-
yaml = pipeline.to_yaml(
|
289
|
-
[model.name for model in models], self._org_name
|
290
|
-
)
|
327
|
+
yaml = pipeline.to_yaml(models, self._org_name)
|
291
328
|
else:
|
292
329
|
yaml = pipeline
|
293
330
|
else:
|
@@ -306,3 +343,52 @@ class Client:
|
|
306
343
|
req, metadata=self._get_request_metadata(idempotent=True)
|
307
344
|
)
|
308
345
|
return resp.id
|
346
|
+
|
347
|
+
def check_compile(
    self,
    *,
    models: List[Model],
    pipeline_file: Optional[Path] = None,
    pipeline: Optional[Union[str, Pipeline]] = None,
    runtimes: Optional[List[str]] = None,
    cameras: Optional[List[str]] = None,
):
    """Check whether a pipeline compiles for a device with the given capabilities.

    Every model is passed through `create_model`, the pipeline YAML is parsed
    by the logical pipeline service, and the parsed pipeline, the models and
    the device capabilities are sent to the compiler service. The outcome is
    printed; an RPC failure of the compile call is reported, not raised.

    Args:
        models: Models referenced by the pipeline.
        pipeline_file: Path of a pipeline YAML file. Takes precedence over
            `pipeline` when both are given.
        pipeline: Pipeline YAML text, or a `Pipeline` object (which requires
            the client's org to be set).
        runtimes: Runtime names handed to `DeviceCapabilities`. None is
            treated as an empty list.
        cameras: Camera driver names handed to `DeviceCapabilities`. None is
            treated as an empty list.

    Raises:
        ValueError: if neither `pipeline_file` nor `pipeline` is given, or if
            a `Pipeline` object is given while no org is set.
    """
    # Reject a call without any pipeline source before models get uploaded.
    if pipeline_file is None and pipeline is None:
        raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")

    # create_model hands back the existing model when the name is already taken.
    model_with_assets = [self.create_model(model) for model in models]

    if pipeline_file is not None:
        with open(pipeline_file, "r") as f:
            yaml = f.read()
    elif isinstance(pipeline, Pipeline):
        if self._org_name is None:
            raise ValueError("must set `org` to compile a Pipeline object")
        yaml = pipeline.to_yaml(models, self._org_name)
    else:
        yaml = pipeline

    stub = self._lps_stub()
    req = lppb.ParseLogicalPipelineYAMLRequest(
        yaml=yaml,
    )
    resp: lppb.ParseLogicalPipelineYAMLResponse = stub.ParseLogicalPipelineYAML(
        req, metadata=self._get_request_metadata()
    )
    pipeline = resp.logical_pipeline

    # The previous signature used `List[str]` as a default value rather than an
    # annotation, so an omitted argument leaked the typing alias into here.
    capabilities = DeviceCapabilities(
        runtimes=[] if runtimes is None else runtimes,
        cameras=[] if cameras is None else cameras,
    )

    stub = self._cs_stub()
    req = cpb.CompileRequest(
        logical_pipeline=pipeline,
        device_capabilities=capabilities.to_proto(),
        models=model_with_assets,
    )
    try:
        stub.Compile(req, metadata=self._get_request_metadata())
        print("Compilation successful")
    except grpc.RpcError as e:
        print("Compilation failed, code {0}: {1}".format(e.code(), e.details()))
|
gml/compile.py
CHANGED
@@ -14,13 +14,111 @@
|
|
14
14
|
#
|
15
15
|
# SPDX-License-Identifier: Apache-2.0
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
from
|
20
|
-
from torch_mlir.dynamo import _get_decomposition_table
|
17
|
+
import contextlib
|
18
|
+
import functools
|
19
|
+
from typing import Dict, List, Optional, Sequence, Union
|
21
20
|
|
21
|
+
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
22
|
+
import torch
|
23
|
+
import torch_mlir # noqa
|
24
|
+
|
25
|
+
try:
|
26
|
+
import torch_mlir.fx # noqa
|
27
|
+
from torch.export import export # noqa
|
28
|
+
|
29
|
+
has_fx_importer_torch_export = True
|
30
|
+
except ImportError:
|
31
|
+
has_fx_importer_torch_export = False
|
32
|
+
|
33
|
+
|
34
|
+
def _default_decomposition_denylist():
|
35
|
+
"""These ops will not be decomposed by default."""
|
36
|
+
return [
|
37
|
+
torch.ops.aten.full.default,
|
38
|
+
torch.ops.aten.upsample_bilinear2d.vec,
|
39
|
+
]
|
40
|
+
|
41
|
+
|
42
|
+
@contextlib.contextmanager
|
43
|
+
def _patch_aot_export_module():
|
44
|
+
"""This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
|
45
|
+
|
46
|
+
This patch is necessary because not all callers of `aot_export_module` expose the pre_dispatch flag.
|
47
|
+
For example, `ExportedProgram.run_decompositions` which is called by `torch_mlir.fx.export_and_import` doesn't
|
48
|
+
expose the pre_dispatch flag.
|
49
|
+
|
50
|
+
Without setting `pre_dispatch=True`, PyTorch dispatch will run before tracing which causes certain operations to be decomposed.
|
51
|
+
For example, `upsample_nearest2d` will be decomposed into aten.index.Tensor calls. This is undesirable for runtimes that provide
|
52
|
+
optimized implementations of the equivalent of `upsample_nearest2d`.
|
53
|
+
"""
|
54
|
+
import torch._functorch.aot_autograd
|
55
|
+
|
56
|
+
orig = torch._functorch.aot_autograd.aot_export_module
|
57
|
+
torch._functorch.aot_autograd.aot_export_module = functools.partial(
|
58
|
+
orig, pre_dispatch=True
|
59
|
+
)
|
60
|
+
yield
|
61
|
+
torch._functorch.aot_autograd.aot_export_module = orig
|
62
|
+
|
63
|
+
|
64
|
+
def to_torch_mlir_w_torch_export(
    model: torch.nn.Module,
    example_inputs: Sequence[torch.Tensor],
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
    decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
):
    """Export `model` with torch.export and import it through torch_mlir.fx.

    Args:
        model: Module to convert. It is switched to eval mode and moved to CPU.
        example_inputs: Positional inputs used for the warm-up runs and export.
        dynamic_shapes: Optional per-input dicts mapping a dimension index to a
            `torch.export.Dim` or to a value passed to `torch.export.Dim(...)`.
            NOTE: dict entries are converted in place, so the caller's dicts
            are modified. Non-dict entries are left untouched.
        decomposition_denylist: Ops removed from the decomposition table.
            Defaults to `_default_decomposition_denylist()`.

    Returns:
        Whatever `torch_mlir.fx.export_and_import` returns for the program.
    """
    from torch._decomp import remove_decompositions
    from torch.export._trace import _export
    from torch_mlir.extras.fx_decomp_util import get_decomposition_table
    from torch_mlir.fx import export_and_import

    if dynamic_shapes is not None:
        for shape in dynamic_shapes:
            if not isinstance(shape, dict):
                continue
            for idx in shape:
                if isinstance(shape[idx], torch.export.dynamic_shapes._Dim):
                    continue
                # Only values are replaced, so iterating the dict stays safe.
                shape[idx] = torch.export.Dim(shape[idx])

    if decomposition_denylist is None:
        decomposition_denylist = _default_decomposition_denylist()

    model = model.eval().to("cpu")

    try:
        # Running the model a few times on the inputs, leads to more consistent compiled results.
        for _ in range(2):
            _ = model(*example_inputs)
    except Exception:  # noqa
        # Ignore errors running the model. This can happen when the model has data dependent branches.
        # `Exception` (not a bare except) so Ctrl-C and SystemExit still propagate.
        pass

    prog = _export(
        model,
        tuple(example_inputs),
        pre_dispatch=False,
        strict=False,
        dynamic_shapes=dynamic_shapes,
    )
    decomp_table = get_decomposition_table()
    remove_decompositions(decomp_table, decomposition_denylist)
    with _patch_aot_export_module():
        return export_and_import(
            prog,
            *example_inputs,
            decomposition_table=decomp_table,
        )
|
114
|
+
|
115
|
+
|
116
|
+
def to_torch_mlir_fallback(model, example_inputs):
|
117
|
+
from torch.fx.experimental.proxy_tensor import make_fx
|
118
|
+
from torch_mlir import ExampleArgs, OutputType
|
119
|
+
from torch_mlir import compile as torch_mlir_compile
|
120
|
+
from torch_mlir.dynamo import _get_decomposition_table
|
22
121
|
|
23
|
-
def to_torch_mlir(model, example_inputs):
|
24
122
|
example_args = ExampleArgs.get(example_inputs)
|
25
123
|
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
|
26
124
|
"forward"
|
@@ -60,3 +158,23 @@ def to_torch_mlir(model, example_inputs):
|
|
60
158
|
)
|
61
159
|
|
62
160
|
return compiled
|
161
|
+
|
162
|
+
|
163
|
+
def to_torch_mlir(
    model,
    example_inputs,
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
):
    """Convert `model` to torch-mlir using the best importer available.

    When the FX importer and `torch.export` could be imported, the
    torch.export path is used and `dynamic_shapes` is forwarded to it.
    Otherwise the fallback path is used, which does not receive
    `dynamic_shapes`.
    """
    if not has_fx_importer_torch_export:
        return to_torch_mlir_fallback(model, example_inputs)
    return to_torch_mlir_w_torch_export(model, example_inputs, dynamic_shapes)
|
174
|
+
|
175
|
+
|
176
|
+
def torch_mlir_output_kind():
    """Return the `ModelInfo` kind produced by the importer currently in use.

    MODEL_KIND_TORCH when the FX importer / torch.export path is available,
    MODEL_KIND_TORCHSCRIPT otherwise.
    """
    kinds = modelexecpb.ModelInfo
    return (
        kinds.MODEL_KIND_TORCH
        if has_fx_importer_torch_export
        else kinds.MODEL_KIND_TORCHSCRIPT
    )
|
gml/device.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
from typing import List
|
18
|
+
|
19
|
+
import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
|
20
|
+
|
21
|
+
|
22
|
+
class DeviceCapabilities:
    """Model runtimes and camera drivers of a device, given by name."""

    def __init__(self, runtimes: List[str], cameras: List[str]):
        # The names are stored as given; they are only checked in to_proto().
        self.runtimes = runtimes
        self.cameras = cameras

    def to_proto(self) -> cpedgepb.DeviceCapabilities:
        """Build the `cpedgepb.DeviceCapabilities` message for this device.

        Each camera gets its position in `self.cameras` (as a string) as its
        `camera_id`. An unknown runtime or driver name surfaces as the
        ValueError raised by the module's name-to-enum helpers.
        """
        runtime_infos = []
        for runtime in self.runtimes:
            runtime_type = _runtime_str_to_runtime_protos(runtime)
            runtime_infos.append(
                cpedgepb.DeviceCapabilities.ModelRuntimeInfo(type=runtime_type)
            )

        camera_infos = []
        for idx, camera in enumerate(self.cameras):
            camera_infos.append(
                cpedgepb.DeviceCapabilities.CameraInfo(
                    driver=_camera_driver_str_to_camera_driver_protos(camera),
                    camera_id=str(idx),
                )
            )

        return cpedgepb.DeviceCapabilities(
            model_runtimes=runtime_infos,
            cameras=camera_infos,
        )
|
43
|
+
|
44
|
+
|
45
|
+
def _runtime_str_to_runtime_protos(
    runtime: str,
) -> cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType:
    """Map a runtime name (case-insensitive) to its ModelRuntimeType value.

    Accepts "tensorrt" and "openvino"; any other name raises ValueError.
    """
    normalized = runtime.lower()
    runtime_types = cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType
    if normalized == "tensorrt":
        return runtime_types.MODEL_RUNTIME_TYPE_TENSORRT
    if normalized == "openvino":
        return runtime_types.MODEL_RUNTIME_TYPE_OPENVINO
    raise ValueError("invalid runtime: {}".format(runtime))
|
59
|
+
|
60
|
+
|
61
|
+
def _camera_driver_str_to_camera_driver_protos(
    driver: str,
) -> cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver:
    """Map a camera driver name (case-insensitive) to its CameraDriver value.

    Accepts "argus" and "v4l2"; any other name raises ValueError.
    """
    normalized = driver.lower()
    camera_drivers = cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver
    if normalized == "argus":
        return camera_drivers.CAMERA_DRIVER_ARGUS
    if normalized == "v4l2":
        return camera_drivers.CAMERA_DRIVER_V4L2
    raise ValueError("invalid driver: {}".format(driver))
|