gimlet-api 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gimlet_api-0.0.6.dist-info → gimlet_api-0.0.8.dist-info}/METADATA +3 -2
- {gimlet_api-0.0.6.dist-info → gimlet_api-0.0.8.dist-info}/RECORD +20 -16
- gml/asset_manager.py +75 -0
- gml/client.py +7 -6
- gml/compile.py +148 -84
- gml/device.py +15 -16
- gml/hf.py +299 -34
- gml/model.py +28 -12
- gml/pipelines.py +120 -40
- gml/preprocessing.py +2 -1
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +37 -18
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +67 -77
- gml/proto/src/api/corepb/v1/device_info_pb2.py +51 -0
- gml/proto/src/api/corepb/v1/gem_config_pb2.py +45 -0
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +23 -19
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +127 -112
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py +7 -11
- gml/register_submodules.py +134 -0
- gml/tensor.py +2 -1
- {gimlet_api-0.0.6.dist-info → gimlet_api-0.0.8.dist-info}/WHEEL +0 -0
@@ -10,9 +10,10 @@ Requires-Python: >=3
|
|
10
10
|
Requires-Dist: protobuf
|
11
11
|
Requires-Dist: grpcio
|
12
12
|
Requires-Dist: torch>=2.3.0
|
13
|
-
Requires-Dist:
|
13
|
+
Requires-Dist: torch-mlir-gml
|
14
14
|
Requires-Dist: numpy<2.0.0
|
15
15
|
Requires-Dist: transformers>=4.43.3
|
16
|
-
|
16
|
+
Requires-Dist: safetensors-mlir
|
17
|
+
Version: 0.0.8
|
17
18
|
|
18
19
|
UNKNOWN
|
@@ -1,13 +1,14 @@
|
|
1
1
|
gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
|
2
2
|
gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
|
3
|
-
gml/
|
4
|
-
gml/
|
5
|
-
gml/
|
6
|
-
gml/
|
7
|
-
gml/
|
3
|
+
gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
|
4
|
+
gml/client.py,sha256=YFanzPfP619xqBgdfyN_3-Am-jI9eYGvZT8CxhPLTBg,13860
|
5
|
+
gml/compile.py,sha256=Ih43r_zU07p91w9aiA0lrPJfmACpAWg0x_HFddMSy7Q,8346
|
6
|
+
gml/device.py,sha256=5ocZU_jGUwMfC6PUyAU328Me61BSLwJp8euJCL3mdzo,2550
|
7
|
+
gml/hf.py,sha256=e9tw6UGJ1lEZcCplLKo_LgxwTIDWD84DXtQOWZrTR9A,27698
|
8
|
+
gml/model.py,sha256=xESdD7tlqn93ym67Lyyk7TZdM3wUqyn7qWdP2AbgdkI,7261
|
8
9
|
gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
|
9
|
-
gml/pipelines.py,sha256=
|
10
|
-
gml/preprocessing.py,sha256=
|
10
|
+
gml/pipelines.py,sha256=d5Vm4eW2RfFZ1SpRebOGEic5sfbLIpMyH4NOy8wdVyI,7319
|
11
|
+
gml/preprocessing.py,sha256=MaKkEW4ZP9fjpkJQfpc0X3rCUuSuSmJnGMClHamKmZU,3210
|
11
12
|
gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
|
12
13
|
gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
|
13
14
|
gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
|
@@ -24,14 +25,16 @@ gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmE
|
|
24
25
|
gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
|
25
26
|
gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
|
26
27
|
gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
|
27
|
-
gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=
|
28
|
-
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=
|
29
|
-
gml/proto/src/api/corepb/v1/
|
30
|
-
gml/proto/src/api/corepb/v1/
|
28
|
+
gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=dLZM7rnkWjxHOvayCkK4klFe09GMWLfwPt5MLCkZFzQ,8963
|
29
|
+
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=oIpxq13C1ynK3alzDNZTOL5URxz5qzbDLD9NOM5xxjE,14511
|
30
|
+
gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=-z-FOpAOm3NkcNyRFsENpdW_pqYO1JpmIPtlbWNpH_g,4666
|
31
|
+
gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=yyEqUqq3-YiX-ByAhbTbZfdh09KuNzEtIYhgk_noJVM,3367
|
32
|
+
gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=Un9OwDUmWdqv92QP66K-WVOAzxP_4hMoz33JI4W1G5Y,7868
|
33
|
+
gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=ydwuRVWXNV0ceZ3WVvBIh74rwLWCJHYm-FgzJWhUNE4,28976
|
31
34
|
gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
|
32
35
|
gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
|
33
36
|
gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
|
34
|
-
gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256
|
37
|
+
gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=-AfLzprMY7wKKGrNgNYFzSv7OlV3YdYolH-KtrK130s,2839
|
35
38
|
gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
|
36
39
|
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
|
37
40
|
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
|
@@ -41,7 +44,8 @@ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA
|
|
41
44
|
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
|
42
45
|
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
|
43
46
|
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
|
44
|
-
gml/
|
45
|
-
|
46
|
-
gimlet_api-0.0.
|
47
|
-
gimlet_api-0.0.
|
47
|
+
gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
|
48
|
+
gml/tensor.py,sha256=Bv2wshr44ugfdRjWj7JUS8b6_GLBQLZKjGYTojhxm9w,14824
|
49
|
+
gimlet_api-0.0.8.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
|
50
|
+
gimlet_api-0.0.8.dist-info/METADATA,sha256=P4s8-0QNDA4rY8k0blk87Rh3qHfZr-R5A5yKYkwbl2E,531
|
51
|
+
gimlet_api-0.0.8.dist-info/RECORD,,
|
gml/asset_manager.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
import abc
|
18
|
+
import tempfile
|
19
|
+
from pathlib import Path
|
20
|
+
from typing import Dict
|
21
|
+
|
22
|
+
|
23
|
+
class AssetManager:
|
24
|
+
@abc.abstractmethod
|
25
|
+
def add_asset(self, name: str) -> Path:
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abc.abstractmethod
|
29
|
+
def assets(self) -> Dict[str, Path]:
|
30
|
+
pass
|
31
|
+
|
32
|
+
def __enter__(self):
|
33
|
+
return self
|
34
|
+
|
35
|
+
def __exit__(self, exc, value, tb):
|
36
|
+
pass
|
37
|
+
|
38
|
+
|
39
|
+
class DirectoryAssetManager(AssetManager):
|
40
|
+
def __init__(self, path: str | Path):
|
41
|
+
self.path = Path(path)
|
42
|
+
self._asset_paths: Dict[str, Path] = dict()
|
43
|
+
|
44
|
+
def add_asset(self, name: str) -> Path:
|
45
|
+
path = self.path / name
|
46
|
+
self._asset_paths[name] = path
|
47
|
+
return path
|
48
|
+
|
49
|
+
def assets(self) -> Dict[str, Path]:
|
50
|
+
return self._asset_paths
|
51
|
+
|
52
|
+
|
53
|
+
class TempFileAssetManager(AssetManager):
|
54
|
+
def __init__(self):
|
55
|
+
self._assets = dict()
|
56
|
+
self._asset_paths = dict()
|
57
|
+
|
58
|
+
def add_asset(self, name: str) -> Path:
|
59
|
+
tmp = tempfile.NamedTemporaryFile(mode="w")
|
60
|
+
self._assets[name] = tmp
|
61
|
+
file = tmp.__enter__()
|
62
|
+
self._asset_paths[name] = Path(file.name)
|
63
|
+
return self._asset_paths[name]
|
64
|
+
|
65
|
+
def assets(self) -> Dict[str, Path]:
|
66
|
+
return self._asset_paths
|
67
|
+
|
68
|
+
def __enter__(self):
|
69
|
+
return self
|
70
|
+
|
71
|
+
def __exit__(self, exc, value, tb):
|
72
|
+
for tmp in self._assets.values():
|
73
|
+
tmp.__exit__(exc, value, tb)
|
74
|
+
self._assets.clear()
|
75
|
+
self._asset_paths.clear()
|
gml/client.py
CHANGED
@@ -19,6 +19,8 @@ import uuid
|
|
19
19
|
from pathlib import Path
|
20
20
|
from typing import BinaryIO, List, Optional, TextIO, Union
|
21
21
|
|
22
|
+
import grpc
|
23
|
+
|
22
24
|
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
23
25
|
import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
|
24
26
|
import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2 as cpb
|
@@ -31,7 +33,6 @@ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
|
|
31
33
|
import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
|
32
34
|
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
|
33
35
|
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
|
34
|
-
import grpc
|
35
36
|
from gml._utils import chunk_file, sha256sum
|
36
37
|
from gml.device import DeviceCapabilities
|
37
38
|
from gml.model import Model
|
@@ -252,7 +253,7 @@ class Client:
|
|
252
253
|
raise Exception("file status is deleted or unknown, cannot re-upload")
|
253
254
|
return file_info
|
254
255
|
|
255
|
-
def _get_model_if_exists(self, name: str) -> Optional[
|
256
|
+
def _get_model_if_exists(self, name: str) -> Optional[modelexecpb.Model]:
|
256
257
|
req = mpb.GetModelRequest(
|
257
258
|
name=name,
|
258
259
|
org_id=self._get_org_id(),
|
@@ -260,13 +261,13 @@ class Client:
|
|
260
261
|
stub = self._ms_stub()
|
261
262
|
try:
|
262
263
|
resp = stub.GetModel(req, metadata=self._get_request_metadata())
|
263
|
-
return
|
264
|
+
return modelexecpb.Model(id=resp.id, info=resp.model_info)
|
264
265
|
except grpc.RpcError as e:
|
265
266
|
if e.code() != grpc.StatusCode.NOT_FOUND:
|
266
267
|
raise e
|
267
268
|
return None
|
268
269
|
|
269
|
-
def _create_model(self, model_info: modelexecpb.ModelInfo) ->
|
270
|
+
def _create_model(self, model_info: modelexecpb.ModelInfo) -> modelexecpb.Model:
|
270
271
|
req = mpb.CreateModelRequest(
|
271
272
|
org_id=self._get_org_id(),
|
272
273
|
name=model_info.name,
|
@@ -276,9 +277,9 @@ class Client:
|
|
276
277
|
resp = stub.CreateModel(
|
277
278
|
req, metadata=self._get_request_metadata(idempotent=True)
|
278
279
|
)
|
279
|
-
return
|
280
|
+
return modelexecpb.Model(id=resp.id, info=model_info)
|
280
281
|
|
281
|
-
def create_model(self, model: Model) ->
|
282
|
+
def create_model(self, model: Model) -> modelexecpb.Model:
|
282
283
|
existing_model = self._get_model_if_exists(model.name)
|
283
284
|
if existing_model is not None:
|
284
285
|
print(
|
gml/compile.py
CHANGED
@@ -16,19 +16,33 @@
|
|
16
16
|
|
17
17
|
import contextlib
|
18
18
|
import functools
|
19
|
-
from typing import Dict, List, Optional, Sequence, Union
|
19
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
20
20
|
|
21
|
-
import
|
21
|
+
import safetensors_mlir
|
22
22
|
import torch
|
23
|
-
import torch_mlir
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
23
|
+
import torch_mlir
|
24
|
+
from mlir.ir import (
|
25
|
+
BF16Type,
|
26
|
+
ComplexType,
|
27
|
+
Context,
|
28
|
+
F16Type,
|
29
|
+
F32Type,
|
30
|
+
F64Type,
|
31
|
+
IntegerType,
|
32
|
+
Operation,
|
33
|
+
RankedTensorType,
|
34
|
+
Value,
|
35
|
+
)
|
36
|
+
from safetensors.torch import save_file
|
37
|
+
from torch._decomp import remove_decompositions
|
38
|
+
from torch.export._trace import _export
|
39
|
+
from torch_mlir.dialects import torch as torch_d
|
40
|
+
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
41
|
+
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks, InputInfo
|
42
|
+
from torch_mlir.fx import export_and_import
|
43
|
+
|
44
|
+
from gml.asset_manager import AssetManager
|
45
|
+
from gml.register_submodules import submodule_registration_workarounds
|
32
46
|
|
33
47
|
|
34
48
|
def _default_decomposition_denylist():
|
@@ -61,33 +75,128 @@ def _patch_aot_export_module():
|
|
61
75
|
torch._functorch.aot_autograd.aot_export_module = orig
|
62
76
|
|
63
77
|
|
64
|
-
|
78
|
+
_torch_dtype_to_builtin_element_type = {
|
79
|
+
torch.float16: lambda: F16Type.get(),
|
80
|
+
torch.bfloat16: lambda: BF16Type.get(),
|
81
|
+
torch.float32: lambda: F32Type.get(),
|
82
|
+
torch.float64: lambda: F64Type.get(),
|
83
|
+
torch.uint8: lambda: IntegerType.get_unsigned(8),
|
84
|
+
torch.int8: lambda: IntegerType.get_signless(8),
|
85
|
+
torch.int16: lambda: IntegerType.get_signless(16),
|
86
|
+
torch.int32: lambda: IntegerType.get_signless(32),
|
87
|
+
torch.int64: lambda: IntegerType.get_signless(64),
|
88
|
+
torch.bool: lambda: IntegerType.get_signless(1),
|
89
|
+
torch.qint8: lambda: IntegerType.get_signless(8),
|
90
|
+
torch.quint8: lambda: IntegerType.get_unsigned(8),
|
91
|
+
torch.complex32: lambda: ComplexType.get(F16Type.get()),
|
92
|
+
torch.complex64: lambda: ComplexType.get(F32Type.get()),
|
93
|
+
torch.complex128: lambda: ComplexType.get(F64Type.get()),
|
94
|
+
}
|
95
|
+
|
96
|
+
|
97
|
+
def _get_unique_(tensors, name):
|
98
|
+
index = 0
|
99
|
+
name = "{}_{}".format(name, index)
|
100
|
+
while name in tensors:
|
101
|
+
index += 1
|
102
|
+
name = "{}_{}".format(name, index)
|
103
|
+
return name
|
104
|
+
|
105
|
+
|
106
|
+
class TensorSet:
|
107
|
+
def __init__(self):
|
108
|
+
self._tensors: Dict[str, torch.Tensor] = dict()
|
109
|
+
|
110
|
+
def add(self, tensor: torch.Tensor) -> str:
|
111
|
+
shape_desc = "_".join([str(d) for d in tensor.shape])
|
112
|
+
base_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
|
113
|
+
|
114
|
+
index = 0
|
115
|
+
name = "{}_{}".format(base_name, index)
|
116
|
+
while name in self._tensors and not torch.equal(tensor, self._tensors[name]):
|
117
|
+
index += 1
|
118
|
+
name = "{}_{}".format(base_name, index)
|
119
|
+
|
120
|
+
self._tensors[name] = tensor
|
121
|
+
return name
|
122
|
+
|
123
|
+
def tensors(self) -> Dict[str, torch.Tensor]:
|
124
|
+
return self._tensors
|
125
|
+
|
126
|
+
|
127
|
+
class SafetensorImporterHooks(FxImporterHooks):
|
128
|
+
def __init__(self, asset_manager: AssetManager):
|
129
|
+
self._asset_mgr = asset_manager
|
130
|
+
# TODO(james): shard weights into multiple shards.
|
131
|
+
self.asset_name = "weights.shard0"
|
132
|
+
self._tensors = TensorSet()
|
133
|
+
|
134
|
+
def resolve_literal(
|
135
|
+
self,
|
136
|
+
gni: "torch_mlir.extras.fx_importer.GraphNodeImporter",
|
137
|
+
literal: Any,
|
138
|
+
info: Optional[InputInfo],
|
139
|
+
) -> Optional[Value]:
|
140
|
+
if not isinstance(literal, torch.Tensor):
|
141
|
+
return None
|
142
|
+
tensor = literal
|
143
|
+
ctx = gni._c
|
144
|
+
|
145
|
+
tensor_name = self._tensors.add(tensor)
|
146
|
+
|
147
|
+
file_attr = safetensors_mlir.FileAttr.get(ctx, self.asset_name)
|
148
|
+
|
149
|
+
if tensor.dtype not in _torch_dtype_to_builtin_element_type:
|
150
|
+
raise ValueError("unsupported torch dtype: {}".format(tensor.dtype))
|
151
|
+
elem_type = _torch_dtype_to_builtin_element_type[tensor.dtype]()
|
152
|
+
tensor_type = RankedTensorType.get(tuple(tensor.size()), elem_type)
|
153
|
+
|
154
|
+
tensor_attr = safetensors_mlir.TensorAttr.get(
|
155
|
+
tensor_type, file_attr, tensor_name
|
156
|
+
)
|
157
|
+
builtin_tensor = safetensors_mlir.tensor_ref(tensor_type, tensor_attr)
|
158
|
+
|
159
|
+
vtensor_type = gni._cc.tensor_to_vtensor_type(tensor)
|
160
|
+
return Operation.create(
|
161
|
+
name="torch_c.from_builtin_tensor",
|
162
|
+
results=[vtensor_type],
|
163
|
+
operands=[builtin_tensor],
|
164
|
+
).result
|
165
|
+
|
166
|
+
def save_tensors(self):
|
167
|
+
file_path = self._asset_mgr.add_asset(self.asset_name)
|
168
|
+
tensors = self._tensors.tensors()
|
169
|
+
for k in tensors:
|
170
|
+
tensors[k] = tensors[k].contiguous()
|
171
|
+
save_file(tensors, file_path)
|
172
|
+
|
173
|
+
|
174
|
+
def to_torch_mlir(
|
65
175
|
model: torch.nn.Module,
|
66
176
|
example_inputs: Sequence[torch.Tensor],
|
67
177
|
dynamic_shapes: Optional[
|
68
178
|
Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
|
69
179
|
] = None,
|
70
180
|
decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
|
181
|
+
weight_manager: Optional[AssetManager] = None,
|
71
182
|
):
|
72
|
-
from torch._decomp import remove_decompositions
|
73
|
-
from torch.export._trace import _export
|
74
|
-
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
75
|
-
from torch_mlir.fx import export_and_import
|
76
|
-
|
77
183
|
if dynamic_shapes is not None:
|
78
184
|
for shape in dynamic_shapes:
|
79
185
|
if not isinstance(shape, dict):
|
80
186
|
continue
|
81
187
|
for idx in shape:
|
82
|
-
|
188
|
+
# Assign the value so that pyright understands the type.
|
189
|
+
value = shape[idx]
|
190
|
+
if isinstance(value, torch.export.dynamic_shapes._Dim):
|
83
191
|
continue
|
84
|
-
shape[idx] = torch.export.Dim(
|
85
|
-
|
192
|
+
shape[idx] = torch.export.Dim(value)
|
86
193
|
if decomposition_denylist is None:
|
87
194
|
decomposition_denylist = _default_decomposition_denylist()
|
88
195
|
|
89
196
|
model = model.eval().to("cpu")
|
90
197
|
|
198
|
+
submodule_registration_workarounds(model)
|
199
|
+
|
91
200
|
try:
|
92
201
|
# Running the model a few times on the inputs, leads to more consistent compiled results.
|
93
202
|
for _ in range(2):
|
@@ -105,76 +214,31 @@ def to_torch_mlir_w_torch_export(
|
|
105
214
|
)
|
106
215
|
decomp_table = get_decomposition_table()
|
107
216
|
remove_decompositions(decomp_table, decomposition_denylist)
|
217
|
+
hooks = None
|
218
|
+
if weight_manager is not None:
|
219
|
+
hooks = SafetensorImporterHooks(weight_manager)
|
220
|
+
|
221
|
+
context = Context()
|
222
|
+
torch_d.register_dialect(context)
|
223
|
+
safetensors_mlir.register_dialect(context)
|
224
|
+
fx_importer = FxImporter(context=context, hooks=hooks)
|
225
|
+
|
108
226
|
with _patch_aot_export_module():
|
109
|
-
|
227
|
+
module = export_and_import(
|
110
228
|
prog,
|
111
229
|
*example_inputs,
|
112
230
|
decomposition_table=decomp_table,
|
231
|
+
fx_importer=fx_importer,
|
113
232
|
)
|
114
233
|
|
234
|
+
if hooks is not None:
|
235
|
+
hooks.save_tensors()
|
115
236
|
|
116
|
-
def to_torch_mlir_fallback(model, example_inputs):
|
117
|
-
from torch.fx.experimental.proxy_tensor import make_fx
|
118
|
-
from torch_mlir import ExampleArgs, OutputType
|
119
|
-
from torch_mlir import compile as torch_mlir_compile
|
120
|
-
from torch_mlir.dynamo import _get_decomposition_table
|
121
|
-
|
122
|
-
example_args = ExampleArgs.get(example_inputs)
|
123
|
-
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
|
124
|
-
"forward"
|
125
|
-
]
|
126
237
|
try:
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
pass
|
133
|
-
|
134
|
-
try:
|
135
|
-
compiled = torch_mlir_compile(
|
136
|
-
model,
|
137
|
-
example_inputs,
|
138
|
-
use_tracing=False,
|
139
|
-
ignore_traced_shapes=False,
|
140
|
-
output_type=OutputType.RAW,
|
141
|
-
use_make_fx=False,
|
142
|
-
)
|
143
|
-
return compiled
|
144
|
-
except Exception:
|
145
|
-
pass
|
146
|
-
|
147
|
-
# If the module can't be exported directly, we try to create an FX graph and then export it.
|
148
|
-
model = make_fx(
|
149
|
-
model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
|
150
|
-
)(*args)
|
151
|
-
compiled = torch_mlir_compile(
|
152
|
-
model,
|
153
|
-
example_inputs,
|
154
|
-
use_tracing=False,
|
155
|
-
ignore_traced_shapes=False,
|
156
|
-
output_type=OutputType.RAW,
|
157
|
-
use_make_fx=False,
|
158
|
-
)
|
159
|
-
|
160
|
-
return compiled
|
161
|
-
|
162
|
-
|
163
|
-
def to_torch_mlir(
|
164
|
-
model,
|
165
|
-
example_inputs,
|
166
|
-
dynamic_shapes: Optional[
|
167
|
-
Sequence[Dict[int, Union[str, "torch.export.dynamic_shapes._Dim"]]]
|
168
|
-
] = None,
|
169
|
-
):
|
170
|
-
if has_fx_importer_torch_export:
|
171
|
-
return to_torch_mlir_w_torch_export(model, example_inputs, dynamic_shapes)
|
172
|
-
else:
|
173
|
-
return to_torch_mlir_fallback(model, example_inputs)
|
174
|
-
|
238
|
+
module.operation.verify()
|
239
|
+
except Exception as exc:
|
240
|
+
raise Exception(
|
241
|
+
"failed to verify converted torch model MLIR module: {}".format(module)
|
242
|
+
) from exc
|
175
243
|
|
176
|
-
|
177
|
-
if has_fx_importer_torch_export:
|
178
|
-
return modelexecpb.ModelInfo.MODEL_KIND_TORCH
|
179
|
-
else:
|
180
|
-
return modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT
|
244
|
+
return module
|
gml/device.py
CHANGED
@@ -17,6 +17,7 @@
|
|
17
17
|
from typing import List
|
18
18
|
|
19
19
|
import gml.proto.src.api.corepb.v1.cp_edge_pb2 as cpedgepb
|
20
|
+
import gml.proto.src.api.corepb.v1.device_info_pb2 as deviceinfopb
|
20
21
|
|
21
22
|
|
22
23
|
class DeviceCapabilities:
|
@@ -27,48 +28,46 @@ class DeviceCapabilities:
|
|
27
28
|
def to_proto(self) -> cpedgepb.DeviceCapabilities:
|
28
29
|
return cpedgepb.DeviceCapabilities(
|
29
30
|
model_runtimes=[
|
30
|
-
|
31
|
+
deviceinfopb.ModelRuntimeInfo(
|
31
32
|
type=_runtime_str_to_runtime_protos(runtime)
|
32
33
|
)
|
33
34
|
for runtime in self.runtimes
|
34
35
|
],
|
35
36
|
cameras=[
|
36
|
-
|
37
|
+
deviceinfopb.CameraInfo(
|
37
38
|
driver=_camera_driver_str_to_camera_driver_protos(camera),
|
38
39
|
camera_id=str(idx),
|
39
40
|
)
|
40
41
|
for idx, camera in enumerate(self.cameras)
|
41
42
|
],
|
43
|
+
camera_drivers=[
|
44
|
+
deviceinfopb.CameraDriverInfo(
|
45
|
+
driver=_camera_driver_str_to_camera_driver_protos(camera)
|
46
|
+
)
|
47
|
+
for camera in self.cameras
|
48
|
+
],
|
42
49
|
)
|
43
50
|
|
44
51
|
|
45
52
|
def _runtime_str_to_runtime_protos(
|
46
53
|
runtime: str,
|
47
|
-
) ->
|
54
|
+
) -> deviceinfopb.ModelRuntimeType:
|
48
55
|
match runtime.lower():
|
49
56
|
case "tensorrt":
|
50
|
-
return
|
51
|
-
cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
|
52
|
-
)
|
57
|
+
return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
|
53
58
|
case "openvino":
|
54
|
-
return
|
55
|
-
cpedgepb.DeviceCapabilities.ModelRuntimeInfo.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
|
56
|
-
)
|
59
|
+
return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
|
57
60
|
case _:
|
58
61
|
raise ValueError("invalid runtime: {}".format(runtime))
|
59
62
|
|
60
63
|
|
61
64
|
def _camera_driver_str_to_camera_driver_protos(
|
62
65
|
driver: str,
|
63
|
-
) ->
|
66
|
+
) -> deviceinfopb.CameraDriver:
|
64
67
|
match driver.lower():
|
65
68
|
case "argus":
|
66
|
-
return
|
67
|
-
cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_ARGUS
|
68
|
-
)
|
69
|
+
return deviceinfopb.CameraDriver.CAMERA_DRIVER_ARGUS
|
69
70
|
case "v4l2":
|
70
|
-
return
|
71
|
-
cpedgepb.DeviceCapabilities.CameraInfo.CameraDriver.CAMERA_DRIVER_V4L2
|
72
|
-
)
|
71
|
+
return deviceinfopb.CameraDriver.CAMERA_DRIVER_V4L2
|
73
72
|
case _:
|
74
73
|
raise ValueError("invalid driver: {}".format(driver))
|