gimlet-api 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gimlet_api-0.0.4.dist-info/METADATA +16 -0
- gimlet_api-0.0.4.dist-info/RECORD +43 -0
- {gimlet_api-0.0.1.dist-info → gimlet_api-0.0.4.dist-info}/WHEEL +1 -2
- gml/__init__.py +18 -0
- gml/_utils.py +42 -0
- gml/client.py +308 -0
- gml/compile.py +44 -8
- gml/model.py +150 -0
- gml/model_utils.py +33 -0
- gml/pipelines.py +149 -0
- gml/preprocessing.py +78 -0
- gml/proto/gogoproto/gogo_pb2.py +101 -0
- gml/proto/mediapipe/framework/calculator_contract_test_pb2.py +32 -0
- gml/proto/mediapipe/framework/calculator_options_pb2.py +28 -0
- gml/proto/mediapipe/framework/calculator_pb2.py +56 -0
- gml/proto/mediapipe/framework/calculator_profile_pb2.py +47 -0
- gml/proto/mediapipe/framework/mediapipe_options_pb2.py +26 -0
- gml/proto/mediapipe/framework/packet_factory_pb2.py +30 -0
- gml/proto/mediapipe/framework/packet_generator_pb2.py +32 -0
- gml/proto/mediapipe/framework/packet_test_pb2.py +32 -0
- gml/proto/mediapipe/framework/status_handler_pb2.py +27 -0
- gml/proto/mediapipe/framework/stream_handler_pb2.py +29 -0
- gml/proto/mediapipe/framework/test_calculators_pb2.py +32 -0
- gml/proto/mediapipe/framework/thread_pool_executor_pb2.py +30 -0
- gml/proto/opentelemetry/proto/common/v1/common_pb2.py +34 -0
- gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py +62 -0
- gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py +27 -0
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +60 -0
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +117 -0
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +66 -0
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +167 -0
- gml/proto/src/common/typespb/jwt_pb2.py +61 -0
- gml/proto/src/common/typespb/status_pb2.py +29 -0
- gml/proto/src/common/typespb/uuid_pb2.py +26 -0
- gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py +115 -0
- gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py +452 -0
- gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py +70 -0
- gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py +231 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +57 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +132 -0
- gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py +47 -0
- gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py +99 -0
- gml/tensor.py +333 -0
- gimlet_api-0.0.1.dist-info/METADATA +0 -6
- gimlet_api-0.0.1.dist-info/RECORD +0 -6
- gimlet_api-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,16 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: gimlet-api
|
3
|
+
Author: Gimlet Labs, Inc.
|
4
|
+
Author-email: support@gimletlabs.ai
|
5
|
+
Classifier: Programming Language :: Python :: 3
|
6
|
+
Classifier: Operating System :: OS Independent
|
7
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
8
|
+
Classifier: Typing :: Typed
|
9
|
+
Requires-Python: >=3
|
10
|
+
Requires-Dist: protobuf
|
11
|
+
Requires-Dist: grpcio
|
12
|
+
Requires-Dist: torch
|
13
|
+
Requires-Dist: torch_mlir_gml
|
14
|
+
Version: 0.0.4
|
15
|
+
|
16
|
+
UNKNOWN
|
@@ -0,0 +1,43 @@
|
|
1
|
+
gml/__init__.py,sha256=e81afOgzDBpBNVb3pv_9Y0waGPuX95Y73ofWsq-UVak,718
|
2
|
+
gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
|
3
|
+
gml/client.py,sha256=4wJ5JeyhMc7UTJazD0GvcHH90Oqiu-x7UY-Q1Zl2Wxo,10640
|
4
|
+
gml/compile.py,sha256=pS_mjATIvAH-zxSbIgZeizmpi08571zpOkn5bWRyzM0,2087
|
5
|
+
gml/model.py,sha256=91I0kCyix1e2ZR7UAdxtyxOIS6pTEfUQ3G62MBkYx5k,5221
|
6
|
+
gml/model_utils.py,sha256=67w32zLI7rjlNsKW6nvCUp-HUjYNMCKMDhjymgJ6EOU,1275
|
7
|
+
gml/pipelines.py,sha256=RObsuECms-dGFAMskxlknxfV5-9d5sEOll7McUP5reo,3737
|
8
|
+
gml/preprocessing.py,sha256=w-D_GzRndY-Hec8O5UdMSVDXMWtohbwPRxizUjhHsFA,3004
|
9
|
+
gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
|
10
|
+
gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
|
11
|
+
gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
|
12
|
+
gml/proto/mediapipe/framework/calculator_pb2.py,sha256=sP2SNeN0mp3YXtWaGFKMDdC3R_uclrLxnYgxWwTy4p8,9092
|
13
|
+
gml/proto/mediapipe/framework/calculator_profile_pb2.py,sha256=uy2yO6DffTL7f4HNuxvpztSZefVE3-SVxvnyxR49Ya4,6335
|
14
|
+
gml/proto/mediapipe/framework/mediapipe_options_pb2.py,sha256=84T1x8HgcwmtF0Eq8ZKeNMWwG6_COAr2246hZE611n0,1318
|
15
|
+
gml/proto/mediapipe/framework/packet_factory_pb2.py,sha256=AfSB4dAuJbDUNgay93Yi56Z9_DLjJZ-XeD6gNZeuxCc,1916
|
16
|
+
gml/proto/mediapipe/framework/packet_generator_pb2.py,sha256=ORDDp5UBLwMPhYiXoAbW8TSLZ6QTaMUeSrmh5jwHPpA,2167
|
17
|
+
gml/proto/mediapipe/framework/packet_test_pb2.py,sha256=K8Uo_1VzSg-u5HFgNjFLQptVsakzHNFeWJz3eldto5g,1797
|
18
|
+
gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_Xt0MLvC1ZtDgkhhXDByshMk,1707
|
19
|
+
gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
|
20
|
+
gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
|
21
|
+
gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
|
22
|
+
gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
|
23
|
+
gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
|
24
|
+
gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
|
25
|
+
gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
|
26
|
+
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
|
27
|
+
gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=yNTIivI9LChFifs90eO5dKwVmAR3WGPYQWXHOuXGUD4,6528
|
28
|
+
gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=U8ujw_7ijKrULiCqxJ4WY9VWiUtIemfIy2e6h1NPrto,25106
|
29
|
+
gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
|
30
|
+
gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
|
31
|
+
gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
|
32
|
+
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
|
33
|
+
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
|
34
|
+
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
|
35
|
+
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
|
36
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=AvPGThkFumcElIYY9fjqWdTX9J44RCEJbaYi1F7gn4c,5661
|
37
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
|
38
|
+
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
|
39
|
+
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
|
40
|
+
gml/tensor.py,sha256=8Ne95YFAkid_jjbtuXFtn_Eu0Hn9u5IaBRuwR7BpzxI,11827
|
41
|
+
gimlet_api-0.0.4.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
|
42
|
+
gimlet_api-0.0.4.dist-info/METADATA,sha256=Bm_ca7tpLR7foub0QOavwvJ9_2_4FQlFPr6n8zwoJBU,429
|
43
|
+
gimlet_api-0.0.4.dist-info/RECORD,,
|
gml/__init__.py
CHANGED
@@ -0,0 +1,18 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
from gml.client import Client # noqa
|
18
|
+
from gml.model import ModelFromFiles, TorchModel # noqa
|
gml/_utils.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
import hashlib
|
18
|
+
import os
|
19
|
+
from typing import BinaryIO, TextIO
|
20
|
+
|
21
|
+
|
22
|
+
def chunk_file(f: TextIO | BinaryIO, chunk_size=64 * 1024):
|
23
|
+
while True:
|
24
|
+
chunk = f.read(chunk_size)
|
25
|
+
if not chunk:
|
26
|
+
break
|
27
|
+
yield chunk
|
28
|
+
|
29
|
+
|
30
|
+
def sha256sum(f: BinaryIO, buffer_size=64 * 1024):
    """Return the hex SHA-256 digest of *f* from its current position to EOF.

    Reads in ``buffer_size``-sized blocks so arbitrarily large files can be
    hashed without loading them fully into memory.
    """
    digest = hashlib.sha256()
    while block := f.read(buffer_size):
        digest.update(block)
    return digest.hexdigest()
36
|
+
|
37
|
+
|
38
|
+
def get_file_size(f: TextIO | BinaryIO):
|
39
|
+
f.seek(0, os.SEEK_END)
|
40
|
+
file_size = f.tell()
|
41
|
+
f.seek(0)
|
42
|
+
return file_size
|
gml/client.py
ADDED
@@ -0,0 +1,308 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
import os
|
18
|
+
import uuid
|
19
|
+
from pathlib import Path
|
20
|
+
from typing import BinaryIO, List, Optional, TextIO, Union
|
21
|
+
|
22
|
+
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
23
|
+
import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
|
24
|
+
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
|
25
|
+
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
|
26
|
+
import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
|
27
|
+
import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2_grpc as ftpb_grpc
|
28
|
+
import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
|
29
|
+
import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
|
30
|
+
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
|
31
|
+
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
|
32
|
+
import grpc
|
33
|
+
from gml._utils import chunk_file, sha256sum
|
34
|
+
from gml.model import Model
|
35
|
+
from gml.pipelines import Pipeline
|
36
|
+
|
37
|
+
DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
|
38
|
+
|
39
|
+
|
40
|
+
class _ChannelFactory:
    """
    _ChannelFactory creates grpc channels to a controlplane.

    Channels are created lazily and cached, so repeated stub construction
    reuses a single underlying channel.
    """

    def __init__(self, controlplane_addr: str, insecure_no_ssl=False):
        self.controlplane_addr = controlplane_addr
        # When True, skip TLS entirely (plaintext channel) — intended for
        # local/dev controlplanes only.
        self.insecure_no_ssl = insecure_no_ssl

        self._channel_cache: Optional[grpc.Channel] = None

    def get_grpc_channel(self) -> grpc.Channel:
        """Return the cached channel, creating it on first use.

        Bug fix: the previous implementation returned a freshly created
        channel without ever assigning `self._channel_cache`, so the cache
        was always empty and every call opened a new channel.
        """
        if self._channel_cache is None:
            self._channel_cache = self._create_grpc_channel()
        return self._channel_cache

    def _create_grpc_channel(self) -> grpc.Channel:
        # Plaintext channel for insecure mode; otherwise TLS with the
        # system default root certificates.
        if self.insecure_no_ssl:
            return grpc.insecure_channel(self.controlplane_addr)

        creds = grpc.ssl_channel_credentials()
        return grpc.secure_channel(self.controlplane_addr, creds)
|
62
|
+
|
63
|
+
|
64
|
+
class FileAlreadyExists(Exception):
    """Raised when creating a file whose name already exists on the controlplane."""

    pass
|
66
|
+
|
67
|
+
|
68
|
+
class OrgNotSet(Exception):
    """Raised when an org-specific method is called but no org was configured."""

    pass
|
70
|
+
|
71
|
+
|
72
|
+
class APIKeyNotSet(Exception):
    """Raised when no API key is provided via argument or the GML_API_KEY env var."""

    pass
|
74
|
+
|
75
|
+
|
76
|
+
class Client:
    """
    Client provides authorized access to a controlplane.

    Authentication uses an API key, taken from the `api_key` argument or the
    GML_API_KEY environment variable. The controlplane address is resolved
    from the `controlplane_addr` argument, then GML_CONTROLPLANE_ADDR, then
    DEFAULT_CONTROLPLANE_ADDR. gRPC stubs and the org id are created lazily
    and cached.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        controlplane_addr: Optional[str] = None,
        org: Optional[str] = None,
        insecure_no_ssl: bool = False,
    ):
        self._org_name = org
        self._api_key = api_key
        # Fall back to the environment; fail fast if no key is available,
        # since every RPC attaches it as metadata.
        if self._api_key is None:
            self._api_key = os.getenv("GML_API_KEY")
        if self._api_key is None:
            raise APIKeyNotSet(
                "must provide api_key explicitly or through environment variable GML_API_KEY"
            )

        self._controlplane_addr = controlplane_addr
        if self._controlplane_addr is None:
            self._controlplane_addr = os.getenv("GML_CONTROLPLANE_ADDR")
        if self._controlplane_addr is None:
            self._controlplane_addr = DEFAULT_CONTROLPLANE_ADDR

        self._channel_factory = _ChannelFactory(
            self._controlplane_addr, insecure_no_ssl=insecure_no_ssl
        )

        # Lazily-populated caches for the org id and the per-service stubs.
        self._org_id_cache: Optional[uuidpb.UUID] = None
        self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
        self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
        self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
        self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None

    def _get_request_metadata(self, idempotent=False):
        """Build per-RPC metadata: the API key, plus an idempotency key on request."""
        md = [("x-api-key", self._api_key)]
        if idempotent:
            # Fresh random key per call so retried RPCs can be deduplicated
            # server-side.
            md.append(("x-idempotency-key", uuid.uuid4().hex))
        return md

    def _fts_stub(self):
        # Lazily create and cache the FileTransferService stub.
        if self._fts_stub_cache is None:
            self._fts_stub_cache = ftpb_grpc.FileTransferServiceStub(
                self._channel_factory.get_grpc_channel()
            )
        return self._fts_stub_cache

    def _lps_stub(self):
        # Lazily create and cache the LogicalPipelineService stub.
        if self._lps_stub_cache is None:
            self._lps_stub_cache = lppb_grpc.LogicalPipelineServiceStub(
                self._channel_factory.get_grpc_channel()
            )
        return self._lps_stub_cache

    def _ods_stub(self):
        # Lazily create and cache the OrgDirectoryService stub.
        if self._ods_stub_cache is None:
            self._ods_stub_cache = directorypb_grpc.OrgDirectoryServiceStub(
                self._channel_factory.get_grpc_channel()
            )
        return self._ods_stub_cache

    def _ms_stub(self):
        # Lazily create and cache the ModelService stub.
        if self._ms_stub_cache is None:
            self._ms_stub_cache = mpb_grpc.ModelServiceStub(
                self._channel_factory.get_grpc_channel()
            )
        return self._ms_stub_cache

    def _get_org_id(self) -> uuidpb.UUID:
        """Look up the org id for the configured org name (uncached RPC).

        Raises:
            OrgNotSet: if the client was constructed without an `org`.
        """
        if self._org_name is None:
            raise OrgNotSet("organization not set for method that is org specific")
        stub = self._ods_stub()
        req = directorypb.GetOrgRequest(org_name=self._org_name)
        resp: directorypb.GetOrgResponse = stub.GetOrg(
            req, metadata=self._get_request_metadata()
        )
        return resp.org_info.id

    def _org_id(self) -> uuidpb.UUID:
        """Return the org id, fetching it once and caching thereafter."""
        if self._org_id_cache is None:
            self._org_id_cache = self._get_org_id()
        return self._org_id_cache

    def _create_file(self, name: str) -> ftpb.FileInfo:
        """Register a new file record by name.

        Raises:
            FileAlreadyExists: if a file with this name already exists.
        """
        stub = self._fts_stub()
        try:
            req = ftpb.CreateFileInfoRequest(name=name)
            resp: ftpb.CreateFileInfoResponse = stub.CreateFileInfo(
                req, metadata=self._get_request_metadata()
            )
        except grpc.RpcError as e:
            # Translate the gRPC status into a typed exception callers can
            # catch; re-raise anything else untouched.
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise FileAlreadyExists(
                    f"A file already exists with name: {name}"
                ) from e
            raise e
        return resp.info

    def _file_info_by_name(self, name: str) -> ftpb.FileInfo:
        """Fetch the FileInfo for a named file.

        (Annotation fixed: this returns the response's `.info` field, a
        `FileInfo`, not the full `GetFileInfoByNameResponse`.)
        """
        stub = self._fts_stub()

        req = ftpb.GetFileInfoByNameRequest(name=name)
        return stub.GetFileInfoByName(req, metadata=self._get_request_metadata()).info

    def _upload_created_file(
        self,
        file_id: uuidpb.UUID,
        sha256: str,
        file: TextIO | BinaryIO,
        chunk_size=64 * 1024,
    ):
        """Stream the file's contents to an already-created file record."""

        def chunked_requests():
            # Rewind first so the whole file is uploaded regardless of any
            # prior reads (e.g. hashing).
            file.seek(0)
            for chunk in chunk_file(file, chunk_size):
                req = ftpb.UploadFileRequest(
                    file_id=file_id, sha256sum=sha256, chunk=chunk
                )
                yield req

        stub = self._fts_stub()
        resp: ftpb.UploadFileResponse = stub.UploadFile(
            chunked_requests(), metadata=self._get_request_metadata()
        )
        return resp

    def upload_file(
        self,
        name: str,
        file: TextIO | BinaryIO,
        sha256: Optional[str] = None,
        chunk_size=64 * 1024,
    ) -> ftpb.FileInfo:
        """Create a file record named `name` and upload `file`'s contents.

        Computes the sha256 if not supplied. Returns the post-upload
        FileInfo (re-fetched so it reflects the final status).
        """
        file_info = self._create_file(name)

        if sha256 is None:
            sha256 = sha256sum(file)
        self._upload_created_file(file_info.file_id, sha256, file, chunk_size)
        return self._file_info_by_name(name)

    def _upload_file_if_not_exists(
        self,
        name: str,
        file: TextIO | BinaryIO,
        sha256: Optional[str] = None,
    ) -> ftpb.FileInfo:
        """Upload `file` under `name`, tolerating a pre-existing record.

        If the record exists but its contents were never uploaded
        (FILE_STATUS_CREATED), the upload is retried.
        """
        file_info: Optional[ftpb.FileInfo] = None
        try:
            file_info = self.upload_file(name, file, sha256)
        except FileAlreadyExists:
            file_info = self._file_info_by_name(name)

        match file_info.status:
            case ftpb.FILE_STATUS_READY:
                pass
            case ftpb.FILE_STATUS_CREATED:
                # NOTE(review): if `sha256` was passed as None and the record
                # already existed in CREATED state, this forwards None as the
                # checksum — confirm callers always supply sha256 (the one
                # caller in create_model does).
                self._upload_created_file(file_info.file_id, sha256, file)
                file_info = self._file_info_by_name(name)
            case _:
                raise Exception("file status is deleted or unknown, cannot re-upload")
        return file_info

    def _create_model(self, model_info: modelexecpb.ModelInfo):
        """Register the model record on the controlplane and return its id."""
        req = mpb.CreateModelRequest(
            # NOTE(review): this bypasses the `_org_id()` cache (unlike
            # upload_pipeline) and performs a GetOrg RPC per call — confirm
            # whether the cached accessor was intended.
            org_id=self._get_org_id(),
            name=model_info.name,
            model_info=model_info,
        )
        stub = self._ms_stub()
        resp = stub.CreateModel(
            req, metadata=self._get_request_metadata(idempotent=True)
        )
        return resp.id

    def create_model(self, model: Model):
        """Upload a model's assets and register the model; returns the model id."""
        model_info = model.to_proto()
        for asset_name, file in model.collect_assets().items():
            # Assets may be given as paths or already-open file objects.
            if isinstance(file, Path) or isinstance(file, str):
                file = open(file, "rb")

            sha256 = sha256sum(file)

            # `upload_name` ("model" or "model:asset") is only used for the
            # progress message below.
            upload_name = model.name
            if asset_name:
                upload_name += ":" + asset_name
            print(f"Uploading {upload_name}...")

            # NOTE(review): files are uploaded keyed by their sha256
            # (content-addressed, enabling dedup across models), not by
            # `upload_name` — confirm this is intended.
            file_info = self._upload_file_if_not_exists(sha256, file, sha256)

            model_info.file_assets[asset_name].MergeFrom(file_info.file_id)

            # Closes caller-provided file objects too, not just ones opened
            # above from a path.
            file.close()

        return self._create_model(model_info)

    def upload_pipeline(
        self,
        *,
        name: str,
        models: List[Model],
        pipeline_file: Optional[Path] = None,
        pipeline: Optional[Union[str, Pipeline]] = None,
    ) -> uuidpb.UUID:
        """Upload the given models, then create a logical pipeline from YAML.

        Exactly one of `pipeline_file` (a YAML file path) or `pipeline`
        (a YAML string or a Pipeline object) must be provided. Returns the
        created logical pipeline's id.
        """
        if pipeline_file is not None:
            with open(pipeline_file, "r") as f:
                yaml = f.read()
        elif pipeline is not None:
            if isinstance(pipeline, Pipeline):
                # Rendering a Pipeline object requires the org name.
                if self._org_name is None:
                    raise ValueError("must set `org` to upload a pipeline")
                yaml = pipeline.to_yaml(
                    [model.name for model in models], self._org_name
                )
            else:
                yaml = pipeline
        else:
            raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")

        # Ensure every referenced model exists before creating the pipeline.
        for model in models:
            self.create_model(model)

        stub = self._lps_stub()
        req = lppb.CreateLogicalPipelineRequest(
            org_id=self._org_id(),
            name=name,
            yaml=yaml,
        )
        resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
            req, metadata=self._get_request_metadata(idempotent=True)
        )
        return resp.id
|
gml/compile.py
CHANGED
@@ -1,19 +1,55 @@
|
|
1
|
-
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
16
|
|
3
|
-
|
17
|
+
from torch.fx.experimental.proxy_tensor import make_fx
|
18
|
+
from torch_mlir import ExampleArgs, OutputType
|
4
19
|
from torch_mlir import compile as torch_mlir_compile
|
5
|
-
from torch_mlir import OutputType
|
6
20
|
from torch_mlir.dynamo import _get_decomposition_table
|
7
21
|
|
8
22
|
|
9
23
|
def to_torch_mlir(model, example_inputs):
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
24
|
+
example_args = ExampleArgs.get(example_inputs)
|
25
|
+
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
|
26
|
+
"forward"
|
27
|
+
]
|
28
|
+
try:
|
29
|
+
# Running the model a few times on the inputs, leads to more consistent compiled results.
|
30
|
+
for _ in range(2):
|
31
|
+
_ = model(*args)
|
32
|
+
except Exception:
|
33
|
+
# Ignore errors running the model. This can happen when the model has data dependent branches.
|
34
|
+
pass
|
14
35
|
|
15
|
-
|
36
|
+
try:
|
37
|
+
compiled = torch_mlir_compile(
|
38
|
+
model,
|
39
|
+
example_inputs,
|
40
|
+
use_tracing=False,
|
41
|
+
ignore_traced_shapes=False,
|
42
|
+
output_type=OutputType.RAW,
|
43
|
+
use_make_fx=False,
|
44
|
+
)
|
45
|
+
return compiled
|
46
|
+
except Exception:
|
47
|
+
pass
|
16
48
|
|
49
|
+
# If the module can't be exported directly, we try to create an FX graph and then export it.
|
50
|
+
model = make_fx(
|
51
|
+
model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
|
52
|
+
)(*args)
|
17
53
|
compiled = torch_mlir_compile(
|
18
54
|
model,
|
19
55
|
example_inputs,
|
gml/model.py
ADDED
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
import abc
|
18
|
+
import io
|
19
|
+
from pathlib import Path
|
20
|
+
from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
|
21
|
+
|
22
|
+
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
23
|
+
import torch
|
24
|
+
import torch_mlir
|
25
|
+
from gml.compile import to_torch_mlir
|
26
|
+
from gml.preprocessing import ImagePreprocessingStep
|
27
|
+
from gml.tensor import TensorSemantics
|
28
|
+
|
29
|
+
|
30
|
+
class Model(abc.ABC):
    """Abstract description of a model that can be uploaded to a controlplane.

    Holds the metadata (kind, storage format, tensor semantics, optional
    class labels and preprocessing steps) and serializes it to a
    `ModelInfo` proto; subclasses supply the actual model bytes via
    `collect_assets`.
    """

    def __init__(
        self,
        name: str,
        kind: modelexecpb.ModelInfo.ModelKind,
        storage_format: modelexecpb.ModelInfo.ModelStorageFormat,
        input_tensor_semantics: List[TensorSemantics],
        output_tensor_semantics: List[TensorSemantics],
        class_labels: Optional[List[str]] = None,
        class_labels_file: Optional[Path] = None,
        image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
    ):
        self.name = name
        self.kind = kind
        self.storage_format = storage_format
        self.class_labels = class_labels
        # A labels file, when given, replaces any `class_labels` list:
        # one label per line, whitespace-stripped.
        if class_labels_file:
            self.class_labels = []
            with open(class_labels_file, "r") as f:
                for line in f.readlines():
                    self.class_labels.append(line.strip())
        self.input_tensor_semantics = input_tensor_semantics
        self.output_tensor_semantics = output_tensor_semantics
        self.image_preprocessing_steps = image_preprocessing_steps

    def to_proto(self) -> modelexecpb.ModelInfo:
        """Serialize this model's metadata into a `ModelInfo` proto.

        File assets are not populated here; the client attaches them after
        uploading (see `Client.create_model`).
        """
        image_preprocessing_steps = None
        if self.image_preprocessing_steps:
            image_preprocessing_steps = [
                step.to_proto() for step in self.image_preprocessing_steps
            ]
        return modelexecpb.ModelInfo(
            name=self.name,
            kind=self.kind,
            format=self.storage_format,
            class_labels=self.class_labels,
            image_preprocessing_steps=image_preprocessing_steps,
            input_tensor_semantics=[
                semantics.to_proto() for semantics in self.input_tensor_semantics
            ],
            output_tensor_semantics=[
                semantics.to_proto() for semantics in self.output_tensor_semantics
            ],
        )

    @abc.abstractmethod
    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the model's file assets keyed by asset name.

        Values may be open file objects or filesystem paths. An empty-string
        key is used for the default/unnamed asset (see `TorchModel`).
        """
        pass
|
78
|
+
|
79
|
+
|
80
|
+
class TorchModel(Model):
    """A model defined by an in-memory `torch.nn.Module`.

    Stored as MLIR text: the module is compiled with torch-mlir at
    asset-collection time, using the given input shapes/dtypes as
    placeholders.
    """

    def __init__(
        self,
        name: str,
        torch_module: torch.nn.Module,
        input_shapes: List[List[int]],
        input_dtypes: List[torch.dtype],
        **kwargs,
    ):
        super().__init__(
            name,
            modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT,
            **kwargs,
        )
        self.torch_module = torch_module
        # `input_shapes[i]` / `input_dtypes[i]` describe the i-th forward()
        # input; they are zipped into TensorPlaceholders for compilation.
        self.input_shapes = input_shapes
        self.input_dtypes = input_dtypes

    def _convert_to_torch_mlir(self):
        # Conversion happens on CPU; inputs are described by
        # shape/dtype placeholders rather than concrete tensors.
        return to_torch_mlir(
            self.torch_module.to("cpu"),
            [
                torch_mlir.TensorPlaceholder(shape, dtype)
                for shape, dtype in zip(self.input_shapes, self.input_dtypes)
            ],
        )

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Compile the module to MLIR text and return it as the sole asset.

        The empty-string key marks it as the default (unnamed) asset.
        """
        compiled = self._convert_to_torch_mlir()
        file = io.BytesIO(str(compiled).encode("utf-8"))
        return {"": file}
|
112
|
+
|
113
|
+
|
114
|
+
def _kind_str_to_kind_format_protos(
    kind: str,
) -> Tuple[modelexecpb.ModelInfo.ModelKind, modelexecpb.ModelInfo.ModelStorageFormat]:
    """Map a user-facing model-kind string to its (kind, storage format) protos.

    Accepted values (case-insensitive): "openvino", "onnx", "tfl".

    Raises:
        ValueError: for any other string.
    """
    kind_to_protos = {
        "openvino": (
            modelexecpb.ModelInfo.MODEL_KIND_OPENVINO,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_OPENVINO,
        ),
        "onnx": (
            modelexecpb.ModelInfo.MODEL_KIND_ONNX,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_PROTOBUF,
        ),
        "tfl": (
            modelexecpb.ModelInfo.MODEL_KIND_TFLITE,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_FLATBUFFER,
        ),
    }
    normalized = kind.lower()
    if normalized not in kind_to_protos:
        raise ValueError("invalid model kind: {}".format(kind))
    return kind_to_protos[normalized]
|
135
|
+
|
136
|
+
|
137
|
+
class ModelFromFiles(Model):
    """A model backed by pre-existing files (OpenVINO, ONNX, or TFLite).

    `files` maps asset names to open file objects or filesystem paths; they
    are uploaded as-is by the client.
    """

    def __init__(
        self,
        name: str,
        kind: Literal["openvino", "onnx", "tfl"],
        files: Dict[str, TextIO | BinaryIO | Path],
        **kwargs,
    ):
        # Translate the user-facing kind string into the proto enums
        # (raises ValueError for unsupported kinds).
        kind, storage_format = _kind_str_to_kind_format_protos(kind)
        super().__init__(name=name, kind=kind, storage_format=storage_format, **kwargs)
        self.files = files

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the user-supplied files keyed by asset name."""
        return self.files
|
gml/model_utils.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
# Copyright 2023- Gimlet Labs, Inc.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
# SPDX-License-Identifier: Apache-2.0
|
16
|
+
|
17
|
+
|
18
|
+
def prepare_ultralytics_yolo(model):
    """Prepare an ultralytics YOLO model for export.

    Ultralytics YOLO models require setting `export=True` on some of the
    torch modules for exporting to work properly. This fuses the model's
    layers when a `fuse()` method is available and then sets the export
    flag on every submodule that supports it.

    Raises:
        ValueError: if *model* does not look like an ultralytics YOLO model
            (i.e. has no `model` attribute).
    """
    _missing = object()
    inner = getattr(model, "model", _missing)
    if inner is _missing:
        raise ValueError(
            "input to `prepare_ultralytics_yolo` is not a supported ultralytics yolo model"
        )

    # Fuse conv/bn layers when the wrapped model offers a callable fuse().
    fuse = getattr(inner, "fuse", None)
    if callable(fuse):
        fuse()

    for _name, module in model.named_modules():
        if hasattr(module, "export"):
            module.export = True
|