gimlet-api 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,9 +10,10 @@ Requires-Python: >=3
10
10
  Requires-Dist: protobuf
11
11
  Requires-Dist: grpcio
12
12
  Requires-Dist: torch>=2.3.0
13
- Requires-Dist: torch_mlir_gml==0.0.2
13
+ Requires-Dist: torch-mlir-gml
14
14
  Requires-Dist: numpy<2.0.0
15
15
  Requires-Dist: transformers>=4.43.3
16
- Version: 0.0.6
16
+ Requires-Dist: safetensors-mlir
17
+ Version: 0.0.8
17
18
 
18
19
  UNKNOWN
@@ -1,13 +1,14 @@
1
1
  gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
2
2
  gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
3
- gml/client.py,sha256=5QDKljltBeBTCd2hH38--fTSP0bVVcAvSnWsA9YEFQc,13819
4
- gml/compile.py,sha256=hR8u3LaMiIW8d12FHrvtmtgzUNQq48DYxe8bW-wJ_VY,6054
5
- gml/device.py,sha256=VUZc6m8QalJ7G9KBKjCY4cIcv2VBd6zAT3ysnh_m1Z0,2585
6
- gml/hf.py,sha256=GRvEEl9zSIv0iWN91Z6ykFYZ2VdNAVABjZrrzYWUFw4,17792
7
- gml/model.py,sha256=nXUV6-L4TIkQHCWUpWyG7QJ6YKTZb7eauW9F4pzVTII,6566
3
+ gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
4
+ gml/client.py,sha256=YFanzPfP619xqBgdfyN_3-Am-jI9eYGvZT8CxhPLTBg,13860
5
+ gml/compile.py,sha256=Ih43r_zU07p91w9aiA0lrPJfmACpAWg0x_HFddMSy7Q,8346
6
+ gml/device.py,sha256=5ocZU_jGUwMfC6PUyAU328Me61BSLwJp8euJCL3mdzo,2550
7
+ gml/hf.py,sha256=e9tw6UGJ1lEZcCplLKo_LgxwTIDWD84DXtQOWZrTR9A,27698
8
+ gml/model.py,sha256=xESdD7tlqn93ym67Lyyk7TZdM3wUqyn7qWdP2AbgdkI,7261
8
9
  gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
9
- gml/pipelines.py,sha256=Bha8J3b5uW8COIejiH12NNF0Tc0XDBt2B3Dez5Jxt4s,5314
10
- gml/preprocessing.py,sha256=STQDSA1_jXPTenJotNtsNMXOc9h1x_wJyQ100LXS6-g,3209
10
+ gml/pipelines.py,sha256=d5Vm4eW2RfFZ1SpRebOGEic5sfbLIpMyH4NOy8wdVyI,7319
11
+ gml/preprocessing.py,sha256=MaKkEW4ZP9fjpkJQfpc0X3rCUuSuSmJnGMClHamKmZU,3210
11
12
  gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
12
13
  gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
13
14
  gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
@@ -24,14 +25,16 @@ gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmE
24
25
  gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
25
26
  gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
26
27
  gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
27
- gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
28
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=u41Sohshi6gBfeZO5VnQzfRStFADFzT1Um5mDY9chcg,15309
29
- gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=1qA-ElTgWeGv3oevYlIjK1TIRSgWbR1TTWxA6Q3SOXk,7224
30
- gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=1DM58lSFgfoHk0ui3ZTjDfifgp4dhE7nHvhMwmInpsA,27103
28
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=dLZM7rnkWjxHOvayCkK4klFe09GMWLfwPt5MLCkZFzQ,8963
29
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=oIpxq13C1ynK3alzDNZTOL5URxz5qzbDLD9NOM5xxjE,14511
30
+ gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=-z-FOpAOm3NkcNyRFsENpdW_pqYO1JpmIPtlbWNpH_g,4666
31
+ gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=yyEqUqq3-YiX-ByAhbTbZfdh09KuNzEtIYhgk_noJVM,3367
32
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=Un9OwDUmWdqv92QP66K-WVOAzxP_4hMoz33JI4W1G5Y,7868
33
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=ydwuRVWXNV0ceZ3WVvBIh74rwLWCJHYm-FgzJWhUNE4,28976
31
34
  gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
32
35
  gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
33
36
  gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
34
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=tkJFPWpndKZy19TFuLKlBfWW1fUQPj0lJLiQ9HfugZU,3213
37
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=-AfLzprMY7wKKGrNgNYFzSv7OlV3YdYolH-KtrK130s,2839
35
38
  gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
36
39
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
37
40
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
@@ -41,7 +44,8 @@ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA
41
44
  gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
42
45
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
43
46
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
44
- gml/tensor.py,sha256=753IsMFYZD7p_f0cQPt4nTIBo5p5S5ELqwCuoHORdMk,14823
45
- gimlet_api-0.0.6.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
46
- gimlet_api-0.0.6.dist-info/METADATA,sha256=mF73_t-Tn5NPVxLnJBOTCQkKvb0weobtQLmOqSkc4B0,506
47
- gimlet_api-0.0.6.dist-info/RECORD,,
47
+ gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
48
+ gml/tensor.py,sha256=Bv2wshr44ugfdRjWj7JUS8b6_GLBQLZKjGYTojhxm9w,14824
49
+ gimlet_api-0.0.8.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
50
+ gimlet_api-0.0.8.dist-info/METADATA,sha256=P4s8-0QNDA4rY8k0blk87Rh3qHfZr-R5A5yKYkwbl2E,531
51
+ gimlet_api-0.0.8.dist-info/RECORD,,
gml/asset_manager.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import abc
18
+ import tempfile
19
+ from pathlib import Path
20
+ from typing import Dict
21
+
22
+
23
+ class AssetManager:
24
+ @abc.abstractmethod
25
+ def add_asset(self, name: str) -> Path:
26
+ pass
27
+
28
+ @abc.abstractmethod
29
+ def assets(self) -> Dict[str, Path]:
30
+ pass
31
+
32
+ def __enter__(self):
33
+ return self
34
+
35
+ def __exit__(self, exc, value, tb):
36
+ pass
37
+
38
+
39
+ class DirectoryAssetManager(AssetManager):
40
+ def __init__(self, path: str | Path):
41
+ self.path = Path(path)
42
+ self._asset_paths: Dict[str, Path] = dict()
43
+
44
+ def add_asset(self, name: str) -> Path:
45
+ path = self.path / name
46
+ self._asset_paths[name] = path
47
+ return path
48
+
49
+ def assets(self) -> Dict[str, Path]:
50
+ return self._asset_paths
51
+
52
+
53
+ class TempFileAssetManager(AssetManager):
54
+ def __init__(self):
55
+ self._assets = dict()
56
+ self._asset_paths = dict()
57
+
58
+ def add_asset(self, name: str) -> Path:
59
+ tmp = tempfile.NamedTemporaryFile(mode="w")
60
+ self._assets[name] = tmp
61
+ file = tmp.__enter__()
62
+ self._asset_paths[name] = Path(file.name)
63
+ return self._asset_paths[name]
64
+
65
+ def assets(self) -> Dict[str, Path]:
66
+ return self._asset_paths
67
+
68
+ def __enter__(self):
69
+ return self
70
+
71
+ def __exit__(self, exc, value, tb):
72
+ for tmp in self._assets.values():
73
+ tmp.__exit__(exc, value, tb)
74
+ self._assets.clear()
75
+ self._asset_paths.clear()
gml/client.py CHANGED
@@ -19,6 +19,8 @@ import uuid
19
19
  from pathlib import Path
20
20
  from typing import BinaryIO, List, Optional, TextIO, Union
21
21
 
22
+ import grpc
23
+
22
24
  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
23
25
  import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
24
26
  import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2 as cpb
@@ -31,7 +33,6 @@ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
31
33
  import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
32
34
  import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
33
35
  import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
34
- import grpc
35
36
  from gml._utils import chunk_file, sha256sum
36
37
  from gml.device import DeviceCapabilities
37
38
  from gml.model import Model
@@ -252,7 +253,7 @@ class Client:
252
253
  raise Exception("file status is deleted or unknown, cannot re-upload")
253
254
  return file_info
254
255
 
255
- def _get_model_if_exists(self, name: str) -> Optional[cpb.Model]:
256
+ def _get_model_if_exists(self, name: str) -> Optional[modelexecpb.Model]:
256
257
  req = mpb.GetModelRequest(
257
258
  name=name,
258
259
  org_id=self._get_org_id(),
@@ -260,13 +261,13 @@ class Client:
260
261
  stub = self._ms_stub()
261
262
  try:
262
263
  resp = stub.GetModel(req, metadata=self._get_request_metadata())
263
- return cpb.Model(id=resp.id, info=resp.model_info)
264
+ return modelexecpb.Model(id=resp.id, info=resp.model_info)
264
265
  except grpc.RpcError as e:
265
266
  if e.code() != grpc.StatusCode.NOT_FOUND:
266
267
  raise e
267
268
  return None
268
269
 
269
- def _create_model(self, model_info: modelexecpb.ModelInfo) -> cpb.Model:
270
+ def _create_model(self, model_info: modelexecpb.ModelInfo) -> modelexecpb.Model:
270
271
  req = mpb.CreateModelRequest(
271
272
  org_id=self._get_org_id(),
272
273
  name=model_info.name,
@@ -276,9 +277,9 @@ class Client:
276
277
  resp = stub.CreateModel(
277
278
  req, metadata=self._get_request_metadata(idempotent=True)
278
279
  )
279
- return cpb.Model(id=resp.id, info=model_info)
280
+ return modelexecpb.Model(id=resp.id, info=model_info)
280
281
 
281
- def create_model(self, model: Model) -> cpb.Model:
282
+ def create_model(self, model: Model) -> modelexecpb.Model:
282
283
  existing_model = self._get_model_if_exists(model.name)
283
284
  if existing_model is not None:
284
285
  print(
gml/compile.py CHANGED
@@ -16,19 +16,33 @@
16
16
 
17
17
  import contextlib
18
18
  import functools
19
- from typing import Dict, List, Optional, Sequence, Union
19
+ from typing import Any, Dict, List, Optional, Sequence, Union
20
20
 
21
- import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
21
+ import safetensors_mlir
22
22
  import torch
23
- import torch_mlir # noqa
24
-
25
- try:
26
- import torch_mlir.fx # noqa
27
- from torch.export import export # noqa
28
-
29
- has_fx_importer_torch_export = True
30
- except ImportError:
31
- has_fx_importer_torch_export = False
23
+ import torch_mlir
24
+ from mlir.ir import (
25
+ BF16Type,
26
+ ComplexType,
27
+ Context,
28
+ F16Type,
29
+ F32Type,
30
+ F64Type,
31
+ IntegerType,
32
+ Operation,
33
+ RankedTensorType,
34
+ Value,
35
+ )
36
+ from safetensors.torch import save_file
37
+ from torch._decomp import remove_decompositions
38
+ from torch.export._trace import _export
39
+ from torch_mlir.dialects import torch as torch_d
40
+ from torch_mlir.extras.fx_decomp_util import get_decomposition_table
41
+ from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks, InputInfo
42
+ from torch_mlir.fx import export_and_import
43
+
44
+ from gml.asset_manager import AssetManager
45
+ from gml.register_submodules import submodule_registration_workarounds
32
46
 
33
47
 
34
48
  def _default_decomposition_denylist():
@@ -61,33 +75,128 @@ def _patch_aot_export_module():
61
75
  torch._functorch.aot_autograd.aot_export_module = orig
62
76
 
63
77
 
64
- def to_torch_mlir_w_torch_export(
78
+ _torch_dtype_to_builtin_element_type = {
79
+ torch.float16: lambda: F16Type.get(),
80
+ torch.bfloat16: lambda: BF16Type.get(),
81
+ torch.float32: lambda: F32Type.get(),
82
+ torch.float64: lambda: F64Type.get(),
83
+ torch.uint8: lambda: IntegerType.get_unsigned(8),
84
+ torch.int8: lambda: IntegerType.get_signless(8),
85
+ torch.int16: lambda: IntegerType.get_signless(16),
86
+ torch.int32: lambda: IntegerType.get_signless(32),
87
+ torch.int64: lambda: IntegerType.get_signless(64),
88
+ torch.bool: lambda: IntegerType.get_signless(1),
89
+ torch.qint8: lambda: IntegerType.get_signless(8),
90
+ torch.quint8: lambda: IntegerType.get_unsigned(8),
91
+ torch.complex32: lambda: ComplexType.get(F16Type.get()),
92
+ torch.complex64: lambda: ComplexType.get(F32Type.get()),
93
+ torch.complex128: lambda: ComplexType.get(F64Type.get()),
94
+ }
95
+
96
+
97
+ def _get_unique_(tensors, name):
98
+ index = 0
99
+ name = "{}_{}".format(name, index)
100
+ while name in tensors:
101
+ index += 1
102
+ name = "{}_{}".format(name, index)
103
+ return name
104
+
105
+
106
+ class TensorSet:
107
+ def __init__(self):
108
+ self._tensors: Dict[str, torch.Tensor] = dict()
109
+
110
+ def add(self, tensor: torch.Tensor) -> str:
111
+ shape_desc = "_".join([str(d) for d in tensor.shape])
112
+ base_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
113
+
114
+ index = 0
115
+ name = "{}_{}".format(base_name, index)
116
+ while name in self._tensors and not torch.equal(tensor, self._tensors[name]):
117
+ index += 1
118
+ name = "{}_{}".format(base_name, index)
119
+
120
+ self._tensors[name] = tensor
121
+ return name
122
+
123
+ def tensors(self) -> Dict[str, torch.Tensor]:
124
+ return self._tensors
125
+
126
+
127
+ class SafetensorImporterHooks(FxImporterHooks):
128
+ def __init__(self, asset_manager: AssetManager):
129
+ self._asset_mgr = asset_manager
130
+ # TODO(james): shard weights into multiple shards.
131
+ self.asset_name = "weights.shard0"
132
+ self._tensors = TensorSet()
133
+
134
+ def resolve_literal(
135
+ self,
136
+ gni: "torch_mlir.extras.fx_importer.GraphNodeImporter",
137
+ literal: Any,
138
+ info: Optional[InputInfo],
139
+ ) -> Optional[Value]:
140
+ if not isinstance(literal, torch.Tensor):
141
+ return None
142
+ tensor = literal
143
+ ctx = gni._c
144
+
145
+ tensor_name = self._tensors.add(tensor)
146
+
147
+ file_attr = safetensors_mlir.FileAttr.get(ctx, self.asset_name)
148
+
149
+ if tensor.dtype not in _torch_dtype_to_builtin_element_type:
150
+ raise ValueError("unsupported torch dtype: {}".format(tensor.dtype))
151
+ elem_type = _torch_dtype_to_builtin_element_type[tensor.dtype]()
152
+ tensor_type = RankedTensorType.get(tuple(tensor.size()), elem_type)
153
+
154
+ tensor_attr = safetensors_mlir.TensorAttr.get(
155
+ tensor_type, file_attr, tensor_name
156
+ )
157
+ builtin_tensor = safetensors_mlir.tensor_ref(tensor_type, tensor_attr)
158
+
159
+ vtensor_type = gni._cc.tensor_to_vtensor_type(tensor)
160
+ return Operation.create(
161
+ name="torch_c.from_builtin_tensor",
162
+ results=[vtensor_type],
163
+ operands=[builtin_tensor],
164
+ ).result
165
+
166
+ def save_tensors(self):
167
+ file_path = self._asset_mgr.add_asset(self.asset_name)
168
+ tensors = self._tensors.tensors()
169
+ for k in tensors:
170
+ tensors[k] = tensors[k].contiguous()
171
+ save_file(tensors, file_path)
172
+
173
+
174
+ def to_torch_mlir(
65
175
  model: torch.nn.Module,
66
176
  example_inputs: Sequence[torch.Tensor],
67
177
  dynamic_shapes: Optional[
68
178
  Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
69
179
  ] = None,
70
180
  decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
181
+ weight_manager: Optional[AssetManager] = None,
71
182
  ):
72
- from torch._decomp import remove_decompositions
73
- from torch.export._trace import _export
74
- from torch_mlir.extras.fx_decomp_util import get_decomposition_table
75
- from torch_mlir.fx import export_and_import
76
-
77
183
  if dynamic_shapes is not None:
78
184
  for shape in dynamic_shapes:
79
185
  if not isinstance(shape, dict):
80
186
  continue
81
187
  for idx in shape:
82
- if isinstance(shape[idx], torch.export.dynamic_shapes._Dim):
188
+ # Assign the value so that pyright understands the type.
189
+ value = shape[idx]
190
+ if isinstance(value, torch.export.dynamic_shapes._Dim):
83
191
  continue
84
- shape[idx] = torch.export.Dim(shape[idx])
85
-
192
+ shape[idx] = torch.export.Dim(value)
86
193
  if decomposition_denylist is None:
87
194
  decomposition_denylist = _default_decomposition_denylist()
88
195
 
89
196
  model = model.eval().to("cpu")
90
197
 
198
+ submodule_registration_workarounds(model)
199
+
91
200
  try:
92
201
  # Running the model a few times on the inputs, leads to more consistent compiled results.
93
202
  for _ in range(2):
@@ -105,76 +214,31 @@ def to_torch_mlir_w_torch_export(
105
214
  )
106
215
  decomp_table = get_decomposition_table()
107
216
  remove_decompositions(decomp_table, decomposition_denylist)
217
+ hooks = None
218
+ if weight_manager is not None:
219
+ hooks = SafetensorImporterHooks(weight_manager)
220
+
221
+ context = Context()
222
+ torch_d.register_dialect(context)
223
+ safetensors_mlir.register_dialect(context)
224
+ fx_importer = FxImporter(context=context, hooks=hooks)
225
+
108
226
  with _patch_aot_export_module():
109
- return export_and_import(
227
+ module = export_and_import(
110
228
  prog,
111
229
  *example_inputs,
112
230
  decomposition_table=decomp_table,
231
+ fx_importer=fx_importer,
113
232
  )
114
233
 
234
+ if hooks is not None:
235
+ hooks.save_tensors()
115
236
 
116
- def to_torch_mlir_fallback(model, example_inputs):
117
- from torch.fx.experimental.proxy_tensor import make_fx
118
- from torch_mlir import ExampleArgs, OutputType
119
- from torch_mlir import compile as torch_mlir_compile
120
- from torch_mlir.dynamo import _get_decomposition_table
121
-
122
- example_args = ExampleArgs.get(example_inputs)
123
- args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
124
- "forward"
125
- ]
126
237
  try:
127
- # Running the model a few times on the inputs, leads to more consistent compiled results.
128
- for _ in range(2):
129
- _ = model(*args)
130
- except Exception:
131
- # Ignore errors running the model. This can happen when the model has data dependent branches.
132
- pass
133
-
134
- try:
135
- compiled = torch_mlir_compile(
136
- model,
137
- example_inputs,
138
- use_tracing=False,
139
- ignore_traced_shapes=False,
140
- output_type=OutputType.RAW,
141
- use_make_fx=False,
142
- )
143
- return compiled
144
- except Exception:
145
- pass
146
-
147
- # If the module can't be exported directly, we try to create an FX graph and then export it.
148
- model = make_fx(
149
- model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
150
- )(*args)
151
- compiled = torch_mlir_compile(
152
- model,
153
- example_inputs,
154
- use_tracing=False,
155
- ignore_traced_shapes=False,
156
- output_type=OutputType.RAW,
157
- use_make_fx=False,
158
- )
159
-
160
- return compiled
161
-
162
-
163
- def to_torch_mlir(
164
- model,
165
- example_inputs,
166
- dynamic_shapes: Optional[
167
- Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
168
- ] = None,
169
- ):
170
- if has_fx_importer_torch_export:
171
- return to_torch_mlir_w_torch_export(model, example_inputs, dynamic_shapes)
172
- else:
173
- return to_torch_mlir_fallback(model, example_inputs)
174
-
238
+ module.operation.verify()
239
+ except Exception as exc:
240
+ raise Exception(
241
+ "failed to verify converted torch model MLIR module: {}".format(module)
242
+ ) from exc
175
243
 
176
- def torch_mlir_output_kind():
177
- if has_fx_importer_torch_export:
178
- return modelexecpb.ModelInfo.MODEL_KIND_TORCH
179
- else:
180
- return modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT
244
+ return module
gml/device.py CHANGED
@@ -17,6 +17,7 @@
17
17
  from typing import List
18
18
 
19
19
  import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
20
+ import gml.proto.src.api.corepb.v1.device_info_pb2 as deviceinfopb
20
21
 
21
22
 
22
23
  class DeviceCapabilities:
@@ -27,48 +28,46 @@ class DeviceCapabilities:
27
28
  def to_proto(self) -> cpedgepb.DeviceCapabilities:
28
29
  return cpedgepb.DeviceCapabilities(
29
30
  model_runtimes=[
30
- cpedgepb.DeviceCapabilities.ModelRuntimeInfo(
31
+ deviceinfopb.ModelRuntimeInfo(
31
32
  type=_runtime_str_to_runtime_protos(runtime)
32
33
  )
33
34
  for runtime in self.runtimes
34
35
  ],
35
36
  cameras=[
36
- cpedgepb.DeviceCapabilities.CameraInfo(
37
+ deviceinfopb.CameraInfo(
37
38
  driver=_camera_driver_str_to_camera_driver_protos(camera),
38
39
  camera_id=str(idx),
39
40
  )
40
41
  for idx, camera in enumerate(self.cameras)
41
42
  ],
43
+ camera_drivers=[
44
+ deviceinfopb.CameraDriverInfo(
45
+ driver=_camera_driver_str_to_camera_driver_protos(camera)
46
+ )
47
+ for camera in self.cameras
48
+ ],
42
49
  )
43
50
 
44
51
 
45
52
  def _runtime_str_to_runtime_protos(
46
53
  runtime: str,
47
- ) -> cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType:
54
+ ) -> deviceinfopb.ModelRuntimeType:
48
55
  match runtime.lower():
49
56
  case "tensorrt":
50
- return (
51
- cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
52
- )
57
+ return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
53
58
  case "openvino":
54
- return (
55
- cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
56
- )
59
+ return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
57
60
  case _:
58
61
  raise ValueError("invalid runtime: {}".format(runtime))
59
62
 
60
63
 
61
64
  def _camera_driver_str_to_camera_driver_protos(
62
65
  driver: str,
63
- ) -> cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver:
66
+ ) -> deviceinfopb.CameraDriver:
64
67
  match driver.lower():
65
68
  case "argus":
66
- return (
67
- cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_ARGUS
68
- )
69
+ return deviceinfopb.CameraDriver.CAMERA_DRIVER_ARGUS
69
70
  case "v4l2":
70
- return (
71
- cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_V4L2
72
- )
71
+ return deviceinfopb.CameraDriver.CAMERA_DRIVER_V4L2
73
72
  case _:
74
73
  raise ValueError("invalid driver: {}".format(driver))