gimlet-api 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,8 +9,10 @@ Classifier: Typing :: Typed
9
9
  Requires-Python: >=3
10
10
  Requires-Dist: protobuf
11
11
  Requires-Dist: grpcio
12
- Requires-Dist: torch
13
- Requires-Dist: torch_mlir_gml
14
- Version: 0.0.5
12
+ Requires-Dist: torch>=2.3.0
13
+ Requires-Dist: torch_mlir_gml==0.0.2
14
+ Requires-Dist: numpy<2.0.0
15
+ Requires-Dist: transformers>=4.43.3
16
+ Version: 0.0.6
15
17
 
16
18
  UNKNOWN
@@ -1,11 +1,13 @@
1
- gml/__init__.py,sha256=e81afOgzDBpBNVb3pv_9Y0waGPuX95Y73ofWsq-UVak,718
1
+ gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
2
2
  gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
3
- gml/client.py,sha256=4wJ5JeyhMc7UTJazD0GvcHH90Oqiu-x7UY-Q1Zl2Wxo,10640
4
- gml/compile.py,sha256=pS_mjATIvAH-zxSbIgZeizmpi08571zpOkn5bWRyzM0,2087
5
- gml/model.py,sha256=Nwf44iTZjMOwFL05JKHaEvTd1QAkJdtgaAuFs6RVUp8,5256
6
- gml/model_utils.py,sha256=67w32zLI7rjlNsKW6nvCUp-HUjYNMCKMDhjymgJ6EOU,1275
7
- gml/pipelines.py,sha256=RObsuECms-dGFAMskxlknxfV5-9d5sEOll7McUP5reo,3737
8
- gml/preprocessing.py,sha256=w-D_GzRndY-Hec8O5UdMSVDXMWtohbwPRxizUjhHsFA,3004
3
+ gml/client.py,sha256=5QDKljltBeBTCd2hH38--fTSP0bVVcAvSnWsA9YEFQc,13819
4
+ gml/compile.py,sha256=hR8u3LaMiIW8d12FHrvtmtgzUNQq48DYxe8bW-wJ_VY,6054
5
+ gml/device.py,sha256=VUZc6m8QalJ7G9KBKjCY4cIcv2VBd6zAT3ysnh_m1Z0,2585
6
+ gml/hf.py,sha256=GRvEEl9zSIv0iWN91Z6ykFYZ2VdNAVABjZrrzYWUFw4,17792
7
+ gml/model.py,sha256=nXUV6-L4TIkQHCWUpWyG7QJ6YKTZb7eauW9F4pzVTII,6566
8
+ gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
9
+ gml/pipelines.py,sha256=Bha8J3b5uW8COIejiH12NNF0Tc0XDBt2B3Dez5Jxt4s,5314
10
+ gml/preprocessing.py,sha256=STQDSA1_jXPTenJotNtsNMXOc9h1x_wJyQ100LXS6-g,3209
9
11
  gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
10
12
  gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
11
13
  gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
@@ -23,21 +25,23 @@ gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5
23
25
  gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
24
26
  gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
25
27
  gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
26
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
27
- gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=yNTIivI9LChFifs90eO5dKwVmAR3WGPYQWXHOuXGUD4,6528
28
- gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=U8ujw_7ijKrULiCqxJ4WY9VWiUtIemfIy2e6h1NPrto,25106
28
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=u41Sohshi6gBfeZO5VnQzfRStFADFzT1Um5mDY9chcg,15309
29
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=1qA-ElTgWeGv3oevYlIjK1TIRSgWbR1TTWxA6Q3SOXk,7224
30
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=1DM58lSFgfoHk0ui3ZTjDfifgp4dhE7nHvhMwmInpsA,27103
29
31
  gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
30
32
  gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
31
33
  gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
34
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=tkJFPWpndKZy19TFuLKlBfWW1fUQPj0lJLiQ9HfugZU,3213
35
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
32
36
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
33
37
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
34
38
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
35
39
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
36
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=AvPGThkFumcElIYY9fjqWdTX9J44RCEJbaYi1F7gn4c,5661
37
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
40
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA5qCcMALT6PS47LYmmVdBz9U47WFLs5Ayg,6330
41
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
38
42
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
39
43
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
40
- gml/tensor.py,sha256=8Ne95YFAkid_jjbtuXFtn_Eu0Hn9u5IaBRuwR7BpzxI,11827
41
- gimlet_api-0.0.5.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
42
- gimlet_api-0.0.5.dist-info/METADATA,sha256=t2CgErN9FNsBNkrCtmXNZcFlTnWzB3cA1vtL4jA1Y-c,429
43
- gimlet_api-0.0.5.dist-info/RECORD,,
44
+ gml/tensor.py,sha256=753IsMFYZD7p_f0cQPt4nTIBo5p5S5ELqwCuoHORdMk,14823
45
+ gimlet_api-0.0.6.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
46
+ gimlet_api-0.0.6.dist-info/METADATA,sha256=mF73_t-Tn5NPVxLnJBOTCQkKvb0weobtQLmOqSkc4B0,506
47
+ gimlet_api-0.0.6.dist-info/RECORD,,
gml/__init__.py CHANGED
@@ -15,4 +15,5 @@
15
15
  # SPDX-License-Identifier: Apache-2.0
16
16
 
17
17
  from gml.client import Client # noqa
18
+ from gml.hf import import_huggingface_pipeline # noqa
18
19
  from gml.model import ModelFromFiles, TorchModel # noqa
gml/client.py CHANGED
@@ -21,6 +21,8 @@ from typing import BinaryIO, List, Optional, TextIO, Union
21
21
 
22
22
  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
23
23
  import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
24
+ import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2 as cpb
25
+ import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2_grpc as cpb_grpc
24
26
  import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
25
27
  import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
26
28
  import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
@@ -31,6 +33,7 @@ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
31
33
  import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
32
34
  import grpc
33
35
  from gml._utils import chunk_file, sha256sum
36
+ from gml.device import DeviceCapabilities
34
37
  from gml.model import Model
35
38
  from gml.pipelines import Pipeline
36
39
 
@@ -65,6 +68,10 @@ class FileAlreadyExists(Exception):
65
68
  pass
66
69
 
67
70
 
71
class ModelAlreadyExists(Exception):
    """Signals that a model with the requested name already exists.

    NOTE(review): nothing in the visible code raises this; `create_model`
    prints a warning and returns the existing model instead.
    """
73
+
74
+
68
75
  class OrgNotSet(Exception):
69
76
  pass
70
77
 
@@ -106,6 +113,7 @@ class Client:
106
113
 
107
114
  self._org_id_cache: Optional[uuidpb.UUID] = None
108
115
  self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
116
+ self._cs_stub_cache: Optional[cpb_grpc.CompilerServiceStub] = None
109
117
  self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
110
118
  self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
111
119
  self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
@@ -137,6 +145,13 @@ class Client:
137
145
  )
138
146
  return self._ods_stub_cache
139
147
 
148
+ def _cs_stub(self):
149
+ if self._cs_stub_cache is None:
150
+ self._cs_stub_cache = cpb_grpc.CompilerServiceStub(
151
+ self._channel_factory.get_grpc_channel()
152
+ )
153
+ return self._cs_stub_cache
154
+
140
155
  def _ms_stub(self):
141
156
  if self._ms_stub_cache is None:
142
157
  self._ms_stub_cache = mpb_grpc.ModelServiceStub(
@@ -185,7 +200,7 @@ class Client:
185
200
  file_id: uuidpb.UUID,
186
201
  sha256: str,
187
202
  file: TextIO | BinaryIO,
188
- chunk_size=64 * 1024,
203
+ chunk_size=1024 * 1024,
189
204
  ):
190
205
  def chunked_requests():
191
206
  file.seek(0)
@@ -206,7 +221,7 @@ class Client:
206
221
  name: str,
207
222
  file: TextIO | BinaryIO,
208
223
  sha256: Optional[str] = None,
209
- chunk_size=64 * 1024,
224
+ chunk_size=1024 * 1024,
210
225
  ) -> ftpb.FileInfo:
211
226
  file_info = self._create_file(name)
212
227
 
@@ -237,7 +252,21 @@ class Client:
237
252
  raise Exception("file status is deleted or unknown, cannot re-upload")
238
253
  return file_info
239
254
 
240
- def _create_model(self, model_info: modelexecpb.ModelInfo):
255
def _get_model_if_exists(self, name: str) -> Optional[cpb.Model]:
    """Look up the model called ``name`` in the client's org.

    Returns:
        A ``cpb.Model`` built from the server's ``id`` and ``model_info``, or
        ``None`` when the ModelService answers with ``NOT_FOUND``.

    Raises:
        grpc.RpcError: for any gRPC status other than ``NOT_FOUND``.
    """
    request = mpb.GetModelRequest(name=name, org_id=self._get_org_id())
    model_stub = self._ms_stub()
    try:
        response = model_stub.GetModel(
            request, metadata=self._get_request_metadata()
        )
        return cpb.Model(id=response.id, info=response.model_info)
    except grpc.RpcError as err:
        if err.code() == grpc.StatusCode.NOT_FOUND:
            return None
        raise
268
+
269
+ def _create_model(self, model_info: modelexecpb.ModelInfo) -> cpb.Model:
241
270
  req = mpb.CreateModelRequest(
242
271
  org_id=self._get_org_id(),
243
272
  name=model_info.name,
@@ -247,26 +276,36 @@ class Client:
247
276
  resp = stub.CreateModel(
248
277
  req, metadata=self._get_request_metadata(idempotent=True)
249
278
  )
250
- return resp.id
279
+ return cpb.Model(id=resp.id, info=model_info)
280
+
281
def create_model(self, model: Model) -> cpb.Model:
    """Upload ``model``'s assets and register it with the ModelService.

    If a model with the same name already exists in the org, a warning is
    printed and the existing model is returned without uploading anything.

    Each asset produced by ``model.collect_assets()`` is either a path
    (``str`` / ``Path``), which is opened in binary mode here, or a file
    object. Every asset file is closed after it has been processed, including
    when hashing or uploading raises.

    Returns:
        The ``cpb.Model`` for the existing or newly created model.
    """
    existing_model = self._get_model_if_exists(model.name)
    if existing_model is not None:
        print(
            'warning: model "{}" already exists and will not be uploaded.'.format(
                model.name
            )
        )
        return existing_model

    model_info = model.to_proto()
    with model.collect_assets() as model_assets:
        for asset_name, file in model_assets.items():
            if isinstance(file, (Path, str)):
                file = open(file, "rb")
            # Close the asset even if hashing/uploading fails so a failed
            # upload does not leak open file handles.
            try:
                sha256 = sha256sum(file)

                upload_name = model.name
                if asset_name:
                    upload_name += ":" + asset_name
                print(f"Uploading {upload_name}...")

                # NOTE(review): the file is uploaded under its sha256 digest;
                # `upload_name` is only used for the progress message above.
                file_info = self._upload_file_if_not_exists(sha256, file, sha256)
                model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
            finally:
                file.close()

    return self._create_model(model_info)
272
311
 
@@ -285,9 +324,7 @@ class Client:
285
324
  if isinstance(pipeline, Pipeline):
286
325
  if self._org_name is None:
287
326
  raise ValueError("must set `org` to upload a pipeline")
288
- yaml = pipeline.to_yaml(
289
- [model.name for model in models], self._org_name
290
- )
327
+ yaml = pipeline.to_yaml(models, self._org_name)
291
328
  else:
292
329
  yaml = pipeline
293
330
  else:
@@ -306,3 +343,52 @@ class Client:
306
343
  req, metadata=self._get_request_metadata(idempotent=True)
307
344
  )
308
345
  return resp.id
346
+
347
def check_compile(
    self,
    *,
    models: List[Model],
    pipeline_file: Optional[Path] = None,
    pipeline: Optional[Union[str, Pipeline]] = None,
    runtimes: Optional[List[str]] = None,
    cameras: Optional[List[str]] = None,
):
    """Check that a pipeline and its models compile for a device configuration.

    The pipeline YAML is resolved first. Every model in ``models`` is then
    uploaded (or reused if it already exists), and the YAML is parsed by the
    LogicalPipelineService. Finally the CompilerService is asked to compile it
    for a device with the given runtimes and cameras. The outcome is printed;
    compilation failures reported via gRPC are printed, not raised.

    Args:
        models: Models referenced by the pipeline.
        pipeline_file: Path to a pipeline YAML file. Takes precedence over
            ``pipeline`` when both are given.
        pipeline: Pipeline YAML string or ``Pipeline`` object. A ``Pipeline``
            object requires the client's org to be set.
        runtimes: Runtime names ("tensorrt", "openvino"). Defaults to none.
        cameras: Camera driver names ("argus", "v4l2"). Defaults to none.

    Raises:
        ValueError: if neither ``pipeline_file`` nor ``pipeline`` is given,
            or a ``Pipeline`` object is given while no org is set.
    """
    # Resolve the YAML before uploading anything so that invalid arguments
    # fail fast without side effects.
    if pipeline_file is not None:
        with open(pipeline_file, "r") as f:
            pipeline_yaml = f.read()
    elif isinstance(pipeline, Pipeline):
        if self._org_name is None:
            raise ValueError("must set `org` to compile a Pipeline object")
        pipeline_yaml = pipeline.to_yaml(models, self._org_name)
    elif pipeline is None:
        raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
    else:
        pipeline_yaml = pipeline

    model_with_assets = [self.create_model(model) for model in models]

    lps_stub = self._lps_stub()
    parse_req = lppb.ParseLogicalPipelineYAMLRequest(
        yaml=pipeline_yaml,
    )
    parse_resp: lppb.ParseLogicalPipelineYAMLResponse = (
        lps_stub.ParseLogicalPipelineYAML(
            parse_req, metadata=self._get_request_metadata()
        )
    )

    # The previous signature used `List[str]` as a *default value*; omitted
    # arguments now mean "no runtimes" / "no cameras".
    capabilities = DeviceCapabilities(
        runtimes=runtimes if runtimes is not None else [],
        cameras=cameras if cameras is not None else [],
    )

    cs_stub = self._cs_stub()
    compile_req = cpb.CompileRequest(
        logical_pipeline=parse_resp.logical_pipeline,
        device_capabilities=capabilities.to_proto(),
        models=model_with_assets,
    )
    try:
        cs_stub.Compile(compile_req, metadata=self._get_request_metadata())
        print("Compilation successful")
    except grpc.RpcError as e:
        print("Compilation failed, code {0}: {1}".format(e.code(), e.details()))
gml/compile.py CHANGED
@@ -14,13 +14,111 @@
14
14
  #
15
15
  # SPDX-License-Identifier: Apache-2.0
16
16
 
17
- from torch.fx.experimental.proxy_tensor import make_fx
18
- from torch_mlir import ExampleArgs, OutputType
19
- from torch_mlir import compile as torch_mlir_compile
20
- from torch_mlir.dynamo import _get_decomposition_table
17
+ import contextlib
18
+ import functools
19
+ from typing import Dict, List, Optional, Sequence, Union
21
20
 
21
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
22
+ import torch
23
+ import torch_mlir # noqa
24
+
25
+ try:
26
+ import torch_mlir.fx # noqa
27
+ from torch.export import export # noqa
28
+
29
+ has_fx_importer_torch_export = True
30
+ except ImportError:
31
+ has_fx_importer_torch_export = False
32
+
33
+
34
+ def _default_decomposition_denylist():
35
+ """These ops will not be decomposed by default."""
36
+ return [
37
+ torch.ops.aten.full.default,
38
+ torch.ops.aten.upsample_bilinear2d.vec,
39
+ ]
40
+
41
+
42
+ @contextlib.contextmanager
43
+ def _patch_aot_export_module():
44
+ """This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
45
+
46
+ This patch is necessary because not all callers of `aot_export_module` expose the pre_dispatch flag.
47
+ For example, `ExportedProgram.run_decompositions` which is called by `torch_mlir.fx.export_and_import` doesn't
48
+ expose the pre_dispatch flag.
49
+
50
+ Without setting `pre_dispatch=True`, PyTorch dispatch will run before tracing which causes certain operations to be decomposed.
51
+ For example, `upsample_nearest2d` will be decomposed into aten.index.Tensor calls. This is undesirable for runtimes that provide
52
+ optimized implementations of the equivalent of `upsample_nearest2d`.
53
+ """
54
+ import torch._functorch.aot_autograd
55
+
56
+ orig = torch._functorch.aot_autograd.aot_export_module
57
+ torch._functorch.aot_autograd.aot_export_module = functools.partial(
58
+ orig, pre_dispatch=True
59
+ )
60
+ yield
61
+ torch._functorch.aot_autograd.aot_export_module = orig
62
+
63
+
64
def _normalize_dynamic_shapes(dynamic_shapes):
    """Return a copy of ``dynamic_shapes`` with string dim names turned into Dims.

    Dict entries are copied, and every value that is not already a
    ``torch.export`` Dim is replaced by ``torch.export.Dim(value)``. Non-dict
    entries pass through unchanged. A tuple input yields a tuple; anything
    else yields a list. The caller's containers are never modified.
    """
    normalized = []
    for shape in dynamic_shapes:
        if isinstance(shape, dict):
            shape = {
                idx: (
                    dim
                    if isinstance(dim, torch.export.dynamic_shapes._Dim)
                    else torch.export.Dim(dim)
                )
                for idx, dim in shape.items()
            }
        normalized.append(shape)
    return tuple(normalized) if isinstance(dynamic_shapes, tuple) else normalized


def to_torch_mlir_w_torch_export(
    model: torch.nn.Module,
    example_inputs: Sequence[torch.Tensor],
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
    decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
):
    """Export ``model`` with ``torch.export`` and import it via torch-mlir's FX importer.

    Args:
        model: Module to convert. It is put in eval mode and moved to CPU
            (``model.eval().to("cpu")``) before export.
        example_inputs: Positional example inputs used for tracing.
        dynamic_shapes: Optional per-input mapping from dimension index to a
            ``torch.export`` Dim or a string name. String names are converted
            with ``torch.export.Dim(name)``. The caller's dicts are left
            unmodified.
        decomposition_denylist: Ops removed from torch-mlir's decomposition
            table. Defaults to ``_default_decomposition_denylist()``.

    Returns:
        The result of ``torch_mlir.fx.export_and_import`` for the exported
        program.
    """
    from torch._decomp import remove_decompositions
    from torch.export._trace import _export
    from torch_mlir.extras.fx_decomp_util import get_decomposition_table
    from torch_mlir.fx import export_and_import

    if dynamic_shapes is not None:
        dynamic_shapes = _normalize_dynamic_shapes(dynamic_shapes)

    if decomposition_denylist is None:
        decomposition_denylist = _default_decomposition_denylist()

    model = model.eval().to("cpu")

    try:
        # Running the model a few times on the inputs leads to more
        # consistent compiled results.
        for _ in range(2):
            model(*example_inputs)
    except Exception:
        # Best effort: running can fail, e.g. when the model has
        # data-dependent branches. Export is still attempted below.
        pass

    prog = _export(
        model,
        tuple(example_inputs),
        pre_dispatch=False,
        strict=False,
        dynamic_shapes=dynamic_shapes,
    )
    decomp_table = get_decomposition_table()
    remove_decompositions(decomp_table, decomposition_denylist)
    with _patch_aot_export_module():
        return export_and_import(
            prog,
            *example_inputs,
            decomposition_table=decomp_table,
        )
114
+
115
+
116
+ def to_torch_mlir_fallback(model, example_inputs):
117
+ from torch.fx.experimental.proxy_tensor import make_fx
118
+ from torch_mlir import ExampleArgs, OutputType
119
+ from torch_mlir import compile as torch_mlir_compile
120
+ from torch_mlir.dynamo import _get_decomposition_table
22
121
 
23
- def to_torch_mlir(model, example_inputs):
24
122
  example_args = ExampleArgs.get(example_inputs)
25
123
  args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
26
124
  "forward"
@@ -60,3 +158,23 @@ def to_torch_mlir(model, example_inputs):
60
158
  )
61
159
 
62
160
  return compiled
161
+
162
+
163
def to_torch_mlir(
    model,
    example_inputs,
    dynamic_shapes: Optional[
        Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
    ] = None,
    decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
):
    """Convert ``model`` to torch-mlir with the best available importer.

    When ``has_fx_importer_torch_export`` is true (``torch_mlir.fx`` and
    ``torch.export`` imported successfully), ``to_torch_mlir_w_torch_export``
    is used. Otherwise the code falls back to ``to_torch_mlir_fallback``, which
    is built on ``torch_mlir.compile``.

    ``dynamic_shapes`` and ``decomposition_denylist`` only affect the
    ``torch.export`` path; the fallback ignores them.
    """
    if has_fx_importer_torch_export:
        return to_torch_mlir_w_torch_export(
            model,
            example_inputs,
            dynamic_shapes,
            decomposition_denylist=decomposition_denylist,
        )
    return to_torch_mlir_fallback(model, example_inputs)
174
+
175
+
176
def torch_mlir_output_kind():
    """Return the ``ModelInfo`` model kind matching the active conversion path.

    ``MODEL_KIND_TORCH`` when the ``torch.export`` importer is available
    (``has_fx_importer_torch_export``), otherwise ``MODEL_KIND_TORCHSCRIPT``.
    """
    model_info = modelexecpb.ModelInfo
    return (
        model_info.MODEL_KIND_TORCH
        if has_fx_importer_torch_export
        else model_info.MODEL_KIND_TORCHSCRIPT
    )
gml/device.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from typing import List
18
+
19
+ import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
20
+
21
+
22
class DeviceCapabilities:
    """Model runtimes and cameras available on a target device.

    Args:
        runtimes: Runtime names; see ``_runtime_str_to_runtime_protos`` for
            the accepted (case-insensitive) values.
        cameras: Camera driver names; see
            ``_camera_driver_str_to_camera_driver_protos``. Each camera's
            position in this sequence becomes its camera id.

    Both arguments are copied into new lists. Any iterable (including a
    one-shot generator) can be passed, and later changes to the caller's
    objects do not affect this instance.
    """

    def __init__(self, runtimes: List[str], cameras: List[str]):
        self.runtimes = list(runtimes)
        self.cameras = list(cameras)

    def __repr__(self) -> str:
        return "{}(runtimes={!r}, cameras={!r})".format(
            type(self).__name__, self.runtimes, self.cameras
        )

    def to_proto(self) -> "cpedgepb.DeviceCapabilities":
        """Build the ``cp_edge_pb2.DeviceCapabilities`` message.

        Raises:
            ValueError: if a runtime or camera driver name is not recognized.
        """
        return cpedgepb.DeviceCapabilities(
            model_runtimes=[
                cpedgepb.DeviceCapabilities.ModelRuntimeInfo(
                    type=_runtime_str_to_runtime_protos(runtime)
                )
                for runtime in self.runtimes
            ],
            cameras=[
                cpedgepb.DeviceCapabilities.CameraInfo(
                    driver=_camera_driver_str_to_camera_driver_protos(camera),
                    # Cameras are identified by their index in `self.cameras`.
                    camera_id=str(idx),
                )
                for idx, camera in enumerate(self.cameras)
            ],
        )
43
+
44
+
45
+ def _runtime_str_to_runtime_protos(
46
+ runtime: str,
47
+ ) -> cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType:
48
+ match runtime.lower():
49
+ case "tensorrt":
50
+ return (
51
+ cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
52
+ )
53
+ case "openvino":
54
+ return (
55
+ cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
56
+ )
57
+ case _:
58
+ raise ValueError("invalid runtime: {}".format(runtime))
59
+
60
+
61
+ def _camera_driver_str_to_camera_driver_protos(
62
+ driver: str,
63
+ ) -> cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver:
64
+ match driver.lower():
65
+ case "argus":
66
+ return (
67
+ cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_ARGUS
68
+ )
69
+ case "v4l2":
70
+ return (
71
+ cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_V4L2
72
+ )
73
+ case _:
74
+ raise ValueError("invalid driver: {}".format(driver))