gimlet-api 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gimlet_api-0.0.1.dist-info/METADATA +16 -0
- gimlet_api-0.0.1.dist-info/RECORD +41 -0
- gimlet_api-0.0.1.dist-info/WHEEL +4 -0
- gml/__init__.py +17 -0
- gml/_utils.py +41 -0
- gml/client.py +297 -0
- gml/compile.py +61 -0
- gml/model.py +149 -0
- gml/preprocessing.py +77 -0
- gml/proto/gogoproto/gogo_pb2.py +101 -0
- gml/proto/mediapipe/framework/calculator_contract_test_pb2.py +32 -0
- gml/proto/mediapipe/framework/calculator_options_pb2.py +28 -0
- gml/proto/mediapipe/framework/calculator_pb2.py +56 -0
- gml/proto/mediapipe/framework/calculator_profile_pb2.py +47 -0
- gml/proto/mediapipe/framework/mediapipe_options_pb2.py +26 -0
- gml/proto/mediapipe/framework/packet_factory_pb2.py +30 -0
- gml/proto/mediapipe/framework/packet_generator_pb2.py +32 -0
- gml/proto/mediapipe/framework/packet_test_pb2.py +32 -0
- gml/proto/mediapipe/framework/status_handler_pb2.py +27 -0
- gml/proto/mediapipe/framework/stream_handler_pb2.py +29 -0
- gml/proto/mediapipe/framework/test_calculators_pb2.py +32 -0
- gml/proto/mediapipe/framework/thread_pool_executor_pb2.py +30 -0
- gml/proto/opentelemetry/proto/common/v1/common_pb2.py +34 -0
- gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py +62 -0
- gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py +27 -0
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +56 -0
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +117 -0
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +64 -0
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +174 -0
- gml/proto/src/common/typespb/jwt_pb2.py +61 -0
- gml/proto/src/common/typespb/status_pb2.py +29 -0
- gml/proto/src/common/typespb/uuid_pb2.py +26 -0
- gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py +115 -0
- gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py +452 -0
- gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py +70 -0
- gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py +231 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +59 -0
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +132 -0
- gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py +47 -0
- gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py +99 -0
- gml/tensor.py +193 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: gimlet-api
|
3
|
+
Author: Gimlet Labs, Inc.
|
4
|
+
Author-email: support@gimletlabs.ai
|
5
|
+
Classifier: Programming Language :: Python :: 3
|
6
|
+
Classifier: Operating System :: OS Independent
|
7
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
8
|
+
Classifier: Typing :: Typed
|
9
|
+
Requires-Python: >=3
|
10
|
+
Requires-Dist: protobuf
|
11
|
+
Requires-Dist: grpcio
|
12
|
+
Requires-Dist: torch
|
13
|
+
Requires-Dist: torch_mlir_gml
|
14
|
+
Version: 0.0.1
|
15
|
+
|
16
|
+
UNKNOWN
|
@@ -0,0 +1,41 @@
|
|
1
|
+
gml/__init__.py,sha256=0yRaNlE_nPT6q2OW7NQzMHNTKt1FYWNYjhCEyhLaRFM,727
|
2
|
+
gml/_utils.py,sha256=X22j6gPuVyFFdGjPflwkCNo34OohdulxtovBnHY8QrM,1193
|
3
|
+
gml/client.py,sha256=x-k0YnNpRu1-0kVcPfiUsQ-lLKAOfAwpaFd28ynntp8,10229
|
4
|
+
gml/compile.py,sha256=CrM7WJn7olXw7nuTOm1IIqOo1a_WbzTmBuUCHZ_hlRQ,2096
|
5
|
+
gml/model.py,sha256=RENQPeE1Df1oV4CyK3KzAztHrm-vrT5gTVipxJx6bTs,5220
|
6
|
+
gml/preprocessing.py,sha256=jcdzY1Xn5a9QU7qiWrOFJ-MZNM0eVTMS3inGR5qSA1E,3013
|
7
|
+
gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
|
8
|
+
gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
|
9
|
+
gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
|
10
|
+
gml/proto/mediapipe/framework/calculator_pb2.py,sha256=sP2SNeN0mp3YXtWaGFKMDdC3R_uclrLxnYgxWwTy4p8,9092
|
11
|
+
gml/proto/mediapipe/framework/calculator_profile_pb2.py,sha256=uy2yO6DffTL7f4HNuxvpztSZefVE3-SVxvnyxR49Ya4,6335
|
12
|
+
gml/proto/mediapipe/framework/mediapipe_options_pb2.py,sha256=84T1x8HgcwmtF0Eq8ZKeNMWwG6_COAr2246hZE611n0,1318
|
13
|
+
gml/proto/mediapipe/framework/packet_factory_pb2.py,sha256=AfSB4dAuJbDUNgay93Yi56Z9_DLjJZ-XeD6gNZeuxCc,1916
|
14
|
+
gml/proto/mediapipe/framework/packet_generator_pb2.py,sha256=ORDDp5UBLwMPhYiXoAbW8TSLZ6QTaMUeSrmh5jwHPpA,2167
|
15
|
+
gml/proto/mediapipe/framework/packet_test_pb2.py,sha256=K8Uo_1VzSg-u5HFgNjFLQptVsakzHNFeWJz3eldto5g,1797
|
16
|
+
gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_Xt0MLvC1ZtDgkhhXDByshMk,1707
|
17
|
+
gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
|
18
|
+
gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
|
19
|
+
gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
|
20
|
+
gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
|
21
|
+
gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
|
22
|
+
gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
|
23
|
+
gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=tMndruR8CilHUaZ8QBPOkT2rh-qXmR20h_TlbF3m_MU,5121
|
24
|
+
gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
|
25
|
+
gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=0vavfDHBauwg-yDdh1GQDPw_kleUAY_lZhIjOhtStec,6347
|
26
|
+
gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=v50FEbj2I9aw6whinGx2yA11nEAUXpktP7hphXxtpyc,25459
|
27
|
+
gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
|
28
|
+
gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
|
29
|
+
gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
|
30
|
+
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
|
31
|
+
gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
|
32
|
+
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
|
33
|
+
gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
|
34
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=2yD8ZKS5KMqvDd2LFgfOWUU9wC-y1lcSOkumAOOiRCY,5851
|
35
|
+
gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
|
36
|
+
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
|
37
|
+
gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
|
38
|
+
gml/tensor.py,sha256=veEDZGWRCJGa16gAabuCZwSS3jLXDXBk4xTH-v5C-Dw,7170
|
39
|
+
gimlet_api-0.0.1.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
|
40
|
+
gimlet_api-0.0.1.dist-info/METADATA,sha256=G8Y-XJT84t2wtTppxUTAseI3QO18Ta8HCjuqVYlkwGg,429
|
41
|
+
gimlet_api-0.0.1.dist-info/RECORD,,
|
gml/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
# Copyright © 2023- Gimlet Labs, Inc.
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# NOTICE: All information contained herein is, and remains
|
5
|
+
# the property of Gimlet Labs, Inc. and its suppliers,
|
6
|
+
# if any. The intellectual and technical concepts contained
|
7
|
+
# herein are proprietary to Gimlet Labs, Inc. and its suppliers and
|
8
|
+
# may be covered by U.S. and Foreign Patents, patents in process,
|
9
|
+
# and are protected by trade secret or copyright law. Dissemination
|
10
|
+
# of this information or reproduction of this material is strictly
|
11
|
+
# forbidden unless prior written permission is obtained from
|
12
|
+
# Gimlet Labs, Inc.
|
13
|
+
#
|
14
|
+
# SPDX-License-Identifier: Proprietary
|
15
|
+
|
16
|
+
from gml.client import Client # noqa
|
17
|
+
from gml.model import ModelFromFiles, TorchModel # noqa
|
gml/_utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
# Copyright © 2023- Gimlet Labs, Inc.
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# NOTICE: All information contained herein is, and remains
|
5
|
+
# the property of Gimlet Labs, Inc. and its suppliers,
|
6
|
+
# if any. The intellectual and technical concepts contained
|
7
|
+
# herein are proprietary to Gimlet Labs, Inc. and its suppliers and
|
8
|
+
# may be covered by U.S. and Foreign Patents, patents in process,
|
9
|
+
# and are protected by trade secret or copyright law. Dissemination
|
10
|
+
# of this information or reproduction of this material is strictly
|
11
|
+
# forbidden unless prior written permission is obtained from
|
12
|
+
# Gimlet Labs, Inc.
|
13
|
+
#
|
14
|
+
# SPDX-License-Identifier: Proprietary
|
15
|
+
|
16
|
+
import hashlib
|
17
|
+
import os
|
18
|
+
from typing import BinaryIO, TextIO
|
19
|
+
|
20
|
+
|
21
|
+
def chunk_file(f: TextIO | BinaryIO, chunk_size=64 * 1024):
|
22
|
+
while True:
|
23
|
+
chunk = f.read(chunk_size)
|
24
|
+
if not chunk:
|
25
|
+
break
|
26
|
+
yield chunk
|
27
|
+
|
28
|
+
|
29
|
+
def sha256sum(f: BinaryIO, buffer_size=64 * 1024):
    """Return the hex SHA-256 digest of ``f`` read from its current position.

    The stream is consumed in ``buffer_size`` reads until an empty read, so
    ``f`` is left at end of file.
    """
    digest = hashlib.sha256()
    while block := f.read(buffer_size):
        digest.update(block)
    return digest.hexdigest()
|
35
|
+
|
36
|
+
|
37
|
+
def get_file_size(f: TextIO | BinaryIO):
|
38
|
+
f.seek(0, os.SEEK_END)
|
39
|
+
file_size = f.tell()
|
40
|
+
f.seek(0)
|
41
|
+
return file_size
|
gml/client.py
ADDED
@@ -0,0 +1,297 @@
|
|
1
|
+
# Copyright © 2023- Gimlet Labs, Inc.
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# NOTICE: All information contained herein is, and remains
|
5
|
+
# the property of Gimlet Labs, Inc. and its suppliers,
|
6
|
+
# if any. The intellectual and technical concepts contained
|
7
|
+
# herein are proprietary to Gimlet Labs, Inc. and its suppliers and
|
8
|
+
# may be covered by U.S. and Foreign Patents, patents in process,
|
9
|
+
# and are protected by trade secret or copyright law. Dissemination
|
10
|
+
# of this information or reproduction of this material is strictly
|
11
|
+
# forbidden unless prior written permission is obtained from
|
12
|
+
# Gimlet Labs, Inc.
|
13
|
+
#
|
14
|
+
# SPDX-License-Identifier: Proprietary
|
15
|
+
|
16
|
+
import os
|
17
|
+
import uuid
|
18
|
+
from pathlib import Path
|
19
|
+
from typing import BinaryIO, List, Optional, TextIO
|
20
|
+
|
21
|
+
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
22
|
+
import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
|
23
|
+
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
|
24
|
+
import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
|
25
|
+
import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
|
26
|
+
import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2_grpc as ftpb_grpc
|
27
|
+
import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
|
28
|
+
import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
|
29
|
+
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
|
30
|
+
import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
|
31
|
+
import grpc
|
32
|
+
from gml._utils import chunk_file, sha256sum
|
33
|
+
from gml.model import Model
|
34
|
+
|
35
|
+
# Controlplane address used by Client when no `controlplane_addr` argument is
# given and the GML_CONTROLPLANE_ADDR environment variable is not set.
DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
|
36
|
+
|
37
|
+
|
38
|
+
class _ChannelFactory:
    """
    _ChannelFactory creates grpc channels to a controlplane.

    The first channel created is cached and returned by every subsequent call
    to ``get_grpc_channel``, so all stubs built from one factory share a
    single connection.
    """

    def __init__(self, controlplane_addr: str, insecure_no_ssl=False):
        """
        Args:
            controlplane_addr: target address passed to grpc when the channel
                is created.
            insecure_no_ssl: if true, create a plaintext channel instead of a
                TLS one.
        """
        self.controlplane_addr = controlplane_addr
        self.insecure_no_ssl = insecure_no_ssl

        self._channel_cache: Optional[grpc.Channel] = None

    def get_grpc_channel(self) -> grpc.Channel:
        """Return the shared channel, creating and caching it on first use."""
        # Previously the created channel was never stored, so every call opened
        # a brand-new channel and the cache was never hit.
        if self._channel_cache is None:
            self._channel_cache = self._create_grpc_channel()
        return self._channel_cache

    def _create_grpc_channel(self) -> grpc.Channel:
        """Open a new channel: plaintext if ``insecure_no_ssl``, else TLS with grpc's default credentials."""
        if self.insecure_no_ssl:
            return grpc.insecure_channel(self.controlplane_addr)

        creds = grpc.ssl_channel_credentials()
        return grpc.secure_channel(self.controlplane_addr, creds)
|
60
|
+
|
61
|
+
|
62
|
+
class FileAlreadyExists(Exception):
    """Raised when the file transfer service already has a file with the requested name."""


class OrgNotSet(Exception):
    """Raised when an organization-specific operation is called on a Client without an org."""


class APIKeyNotSet(Exception):
    """Raised when no API key is passed to Client and GML_API_KEY is not set."""
|
72
|
+
|
73
|
+
|
74
|
+
class Client:
|
75
|
+
"""
|
76
|
+
Client provides authorized access to a controlplane.
|
77
|
+
"""
|
78
|
+
|
79
|
+
def __init__(
|
80
|
+
self,
|
81
|
+
api_key: Optional[str] = None,
|
82
|
+
controlplane_addr: Optional[str] = None,
|
83
|
+
org: Optional[str] = None,
|
84
|
+
insecure_no_ssl: bool = False,
|
85
|
+
):
|
86
|
+
self._org_name = org
|
87
|
+
self._api_key = api_key
|
88
|
+
if self._api_key is None:
|
89
|
+
self._api_key = os.getenv("GML_API_KEY")
|
90
|
+
if self._api_key is None:
|
91
|
+
raise APIKeyNotSet(
|
92
|
+
"must provide api_key explicitly or through environment variable GML_API_KEY"
|
93
|
+
)
|
94
|
+
|
95
|
+
self._controlplane_addr = controlplane_addr
|
96
|
+
if self._controlplane_addr is None:
|
97
|
+
self._controlplane_addr = os.getenv("GML_CONTROLPLANE_ADDR")
|
98
|
+
if self._controlplane_addr is None:
|
99
|
+
self._controlplane_addr = DEFAULT_CONTROLPLANE_ADDR
|
100
|
+
|
101
|
+
self._channel_factory = _ChannelFactory(
|
102
|
+
self._controlplane_addr, insecure_no_ssl=insecure_no_ssl
|
103
|
+
)
|
104
|
+
|
105
|
+
self._org_id_cache: Optional[uuidpb.UUID] = None
|
106
|
+
self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
|
107
|
+
self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
|
108
|
+
self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
|
109
|
+
self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
|
110
|
+
|
111
|
+
def _get_request_metadata(self, idempotent=False):
|
112
|
+
md = [("x-api-key", self._api_key)]
|
113
|
+
if idempotent:
|
114
|
+
md.append(("x-idempotency-key", uuid.uuid4().hex))
|
115
|
+
return md
|
116
|
+
|
117
|
+
def _fts_stub(self):
|
118
|
+
if self._fts_stub_cache is None:
|
119
|
+
self._fts_stub_cache = ftpb_grpc.FileTransferServiceStub(
|
120
|
+
self._channel_factory.get_grpc_channel()
|
121
|
+
)
|
122
|
+
return self._fts_stub_cache
|
123
|
+
|
124
|
+
def _lps_stub(self):
|
125
|
+
if self._lps_stub_cache is None:
|
126
|
+
self._lps_stub_cache = lppb_grpc.LogicalPipelineServiceStub(
|
127
|
+
self._channel_factory.get_grpc_channel()
|
128
|
+
)
|
129
|
+
return self._lps_stub_cache
|
130
|
+
|
131
|
+
def _ods_stub(self):
|
132
|
+
if self._ods_stub_cache is None:
|
133
|
+
self._ods_stub_cache = directorypb_grpc.OrgDirectoryServiceStub(
|
134
|
+
self._channel_factory.get_grpc_channel()
|
135
|
+
)
|
136
|
+
return self._ods_stub_cache
|
137
|
+
|
138
|
+
def _ms_stub(self):
|
139
|
+
if self._ms_stub_cache is None:
|
140
|
+
self._ms_stub_cache = mpb_grpc.ModelServiceStub(
|
141
|
+
self._channel_factory.get_grpc_channel()
|
142
|
+
)
|
143
|
+
return self._ms_stub_cache
|
144
|
+
|
145
|
+
def _get_org_id(self):
|
146
|
+
if self._org_name is None:
|
147
|
+
raise OrgNotSet("organization not set for method that is org specific")
|
148
|
+
stub = self._ods_stub()
|
149
|
+
req = directorypb.GetOrgRequest(org_name=self._org_name)
|
150
|
+
resp: directorypb.GetOrgResponse = stub.GetOrg(
|
151
|
+
req, metadata=self._get_request_metadata()
|
152
|
+
)
|
153
|
+
return resp.org_info.id
|
154
|
+
|
155
|
+
def _org_id(self):
|
156
|
+
if self._org_id_cache is None:
|
157
|
+
self._org_id_cache = self._get_org_id()
|
158
|
+
return self._org_id_cache
|
159
|
+
|
160
|
+
def _create_file(self, name: str) -> ftpb.FileInfo:
|
161
|
+
stub = self._fts_stub()
|
162
|
+
try:
|
163
|
+
req = ftpb.CreateFileInfoRequest(name=name)
|
164
|
+
resp: ftpb.CreateFileInfoResponse = stub.CreateFileInfo(
|
165
|
+
req, metadata=self._get_request_metadata()
|
166
|
+
)
|
167
|
+
except grpc.RpcError as e:
|
168
|
+
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
|
169
|
+
raise FileAlreadyExists(f"A file already exists with name: {name}")
|
170
|
+
raise e
|
171
|
+
return resp.info
|
172
|
+
|
173
|
+
def _file_info_by_name(self, name: str) -> ftpb.GetFileInfoByNameResponse:
|
174
|
+
stub = self._fts_stub()
|
175
|
+
|
176
|
+
req = ftpb.GetFileInfoByNameRequest(name=name)
|
177
|
+
return stub.GetFileInfoByName(req, metadata=self._get_request_metadata()).info
|
178
|
+
|
179
|
+
def _upload_created_file(
|
180
|
+
self,
|
181
|
+
file_id: uuidpb.UUID,
|
182
|
+
sha256: str,
|
183
|
+
file: TextIO | BinaryIO,
|
184
|
+
chunk_size=64 * 1024,
|
185
|
+
):
|
186
|
+
def chunked_requests():
|
187
|
+
file.seek(0)
|
188
|
+
for chunk in chunk_file(file, chunk_size):
|
189
|
+
req = ftpb.UploadFileRequest(
|
190
|
+
file_id=file_id, sha256sum=sha256, chunk=chunk
|
191
|
+
)
|
192
|
+
yield req
|
193
|
+
|
194
|
+
stub = self._fts_stub()
|
195
|
+
resp: ftpb.UploadFileResponse = stub.UploadFile(
|
196
|
+
chunked_requests(), metadata=self._get_request_metadata()
|
197
|
+
)
|
198
|
+
return resp
|
199
|
+
|
200
|
+
def upload_file(
|
201
|
+
self,
|
202
|
+
name: str,
|
203
|
+
file: TextIO | BinaryIO,
|
204
|
+
sha256: Optional[str] = None,
|
205
|
+
chunk_size=64 * 1024,
|
206
|
+
) -> ftpb.FileInfo:
|
207
|
+
file_info = self._create_file(name)
|
208
|
+
|
209
|
+
if sha256 is None:
|
210
|
+
sha256 = sha256sum(file)
|
211
|
+
self._upload_created_file(file_info.file_id, sha256, file, chunk_size)
|
212
|
+
return self._file_info_by_name(name)
|
213
|
+
|
214
|
+
def _upload_file_if_not_exists(
|
215
|
+
self,
|
216
|
+
name: str,
|
217
|
+
file: TextIO | BinaryIO,
|
218
|
+
sha256: Optional[str] = None,
|
219
|
+
) -> ftpb.FileInfo:
|
220
|
+
file_info: Optional[ftpb.FileInfo] = None
|
221
|
+
try:
|
222
|
+
file_info = self.upload_file(name, file, sha256)
|
223
|
+
except FileAlreadyExists:
|
224
|
+
file_info = self._file_info_by_name(name)
|
225
|
+
|
226
|
+
match file_info.status:
|
227
|
+
case ftpb.FILE_STATUS_READY:
|
228
|
+
pass
|
229
|
+
case ftpb.FILE_STATUS_CREATED:
|
230
|
+
self._upload_created_file(file_info.file_id, sha256, file)
|
231
|
+
file_info = self._file_info_by_name(name)
|
232
|
+
case _:
|
233
|
+
raise Exception("file status is deleted or unknown, cannot re-upload")
|
234
|
+
return file_info
|
235
|
+
|
236
|
+
def _create_model(self, model_info: modelexecpb.ModelInfo):
|
237
|
+
req = mpb.CreateModelRequest(
|
238
|
+
org_id=self._get_org_id(),
|
239
|
+
name=model_info.name,
|
240
|
+
model_info=model_info,
|
241
|
+
)
|
242
|
+
stub = self._ms_stub()
|
243
|
+
resp = stub.CreateModel(
|
244
|
+
req, metadata=self._get_request_metadata(idempotent=True)
|
245
|
+
)
|
246
|
+
return resp.id
|
247
|
+
|
248
|
+
def create_model(self, model: Model):
|
249
|
+
model_info = model.to_proto()
|
250
|
+
for asset_name, file in model.collect_assets().items():
|
251
|
+
if isinstance(file, Path) or isinstance(file, str):
|
252
|
+
file = open(file, "rb")
|
253
|
+
|
254
|
+
sha256 = sha256sum(file)
|
255
|
+
|
256
|
+
upload_name = model.name
|
257
|
+
if asset_name:
|
258
|
+
upload_name += ":" + asset_name
|
259
|
+
print(f"Uploading {upload_name}...")
|
260
|
+
|
261
|
+
file_info = self._upload_file_if_not_exists(sha256, file, sha256)
|
262
|
+
|
263
|
+
model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
|
264
|
+
|
265
|
+
file.close()
|
266
|
+
|
267
|
+
return self._create_model(model_info)
|
268
|
+
|
269
|
+
def upload_pipeline(
|
270
|
+
self,
|
271
|
+
*,
|
272
|
+
name: str,
|
273
|
+
models: List[Model],
|
274
|
+
pipeline_file: Optional[Path] = None,
|
275
|
+
pipeline: Optional[str] = None,
|
276
|
+
) -> uuidpb.UUID:
|
277
|
+
if pipeline_file is not None:
|
278
|
+
with open(pipeline_file, "r") as f:
|
279
|
+
yaml = f.read()
|
280
|
+
elif pipeline is not None:
|
281
|
+
yaml = pipeline
|
282
|
+
else:
|
283
|
+
raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
|
284
|
+
|
285
|
+
for model in models:
|
286
|
+
self.create_model(model)
|
287
|
+
|
288
|
+
stub = self._lps_stub()
|
289
|
+
req = lppb.CreateLogicalPipelineRequest(
|
290
|
+
org_id=self._org_id(),
|
291
|
+
name=name,
|
292
|
+
yaml=yaml,
|
293
|
+
)
|
294
|
+
resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
|
295
|
+
req, metadata=self._get_request_metadata(idempotent=True)
|
296
|
+
)
|
297
|
+
return resp.id
|
gml/compile.py
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright © 2023- Gimlet Labs, Inc.
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# NOTICE: All information contained herein is, and remains
|
5
|
+
# the property of Gimlet Labs, Inc. and its suppliers,
|
6
|
+
# if any. The intellectual and technical concepts contained
|
7
|
+
# herein are proprietary to Gimlet Labs, Inc. and its suppliers and
|
8
|
+
# may be covered by U.S. and Foreign Patents, patents in process,
|
9
|
+
# and are protected by trade secret or copyright law. Dissemination
|
10
|
+
# of this information or reproduction of this material is strictly
|
11
|
+
# forbidden unless prior written permission is obtained from
|
12
|
+
# Gimlet Labs, Inc.
|
13
|
+
#
|
14
|
+
# SPDX-License-Identifier: Proprietary
|
15
|
+
|
16
|
+
from torch.fx.experimental.proxy_tensor import make_fx
|
17
|
+
from torch_mlir import ExampleArgs, OutputType
|
18
|
+
from torch_mlir import compile as torch_mlir_compile
|
19
|
+
from torch_mlir.dynamo import _get_decomposition_table
|
20
|
+
|
21
|
+
|
22
|
+
def _compile_raw(model, example_inputs):
    """Compile ``model`` with torch-mlir to the RAW output type (no tracing, no make_fx)."""
    return torch_mlir_compile(
        model,
        example_inputs,
        use_tracing=False,
        ignore_traced_shapes=False,
        output_type=OutputType.RAW,
        use_make_fx=False,
    )


def to_torch_mlir(model, example_inputs):
    """Compile a torch module to torch-mlir with the RAW output type.

    Direct export is attempted first. If it raises, the module is converted to
    an FX graph with ``make_fx`` (pre-dispatch, using torch-mlir's
    decomposition table) and that graph is compiled instead.

    Args:
        model: callable torch module to compile.
        example_inputs: example inputs accepted by ``ExampleArgs.get`` (in this
            package, a list of ``torch_mlir.TensorPlaceholder``).

    Returns:
        The object returned by ``torch_mlir.compile``.

    Raises:
        Whatever the fallback path raises. The error from the direct-export
        attempt stays attached as the exception's ``__context__``.
    """
    example_args = ExampleArgs.get(example_inputs)
    args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
        "forward"
    ]
    try:
        # Running the model a few times on the inputs, leads to more consistent compiled results.
        for _ in range(2):
            model(*args)
    except Exception:
        # Ignore errors running the model. This can happen when the model has data dependent branches.
        pass

    try:
        return _compile_raw(model, example_inputs)
    except Exception:
        # If the module can't be exported directly, we try to create an FX graph and then export it.
        # Running the fallback inside this handler means that, if it fails too,
        # the original direct-export error is chained instead of being lost.
        fx_model = make_fx(
            model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
        )(*args)
        return _compile_raw(fx_model, example_inputs)
|
gml/model.py
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
# Copyright © 2023- Gimlet Labs, Inc.
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# NOTICE: All information contained herein is, and remains
|
5
|
+
# the property of Gimlet Labs, Inc. and its suppliers,
|
6
|
+
# if any. The intellectual and technical concepts contained
|
7
|
+
# herein are proprietary to Gimlet Labs, Inc. and its suppliers and
|
8
|
+
# may be covered by U.S. and Foreign Patents, patents in process,
|
9
|
+
# and are protected by trade secret or copyright law. Dissemination
|
10
|
+
# of this information or reproduction of this material is strictly
|
11
|
+
# forbidden unless prior written permission is obtained from
|
12
|
+
# Gimlet Labs, Inc.
|
13
|
+
#
|
14
|
+
# SPDX-License-Identifier: Proprietary
|
15
|
+
|
16
|
+
import abc
|
17
|
+
import io
|
18
|
+
from pathlib import Path
|
19
|
+
from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
|
20
|
+
|
21
|
+
import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
|
22
|
+
import torch
|
23
|
+
import torch_mlir
|
24
|
+
from gml.compile import to_torch_mlir
|
25
|
+
from gml.preprocessing import ImagePreprocessingStep
|
26
|
+
from gml.tensor import TensorSemantics
|
27
|
+
|
28
|
+
|
29
|
+
class Model(abc.ABC):
    """Abstract description of a model to upload to the controlplane.

    Subclasses supply the model's files via ``collect_assets``; ``to_proto``
    serializes the remaining metadata into a ``modelexecpb.ModelInfo``.
    """

    def __init__(
        self,
        name: str,
        kind: modelexecpb.ModelInfo.ModelKind,
        storage_format: modelexecpb.ModelInfo.ModelStorageFormat,
        input_tensor_semantics: List[TensorSemantics],
        output_tensor_semantics: List[TensorSemantics],
        class_labels: Optional[List[str]] = None,
        class_labels_file: Optional[Path] = None,
        image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
    ):
        """
        Args:
            name: model name.
            kind: ModelKind enum value.
            storage_format: ModelStorageFormat enum value.
            input_tensor_semantics: per-input semantics, serialized in order.
            output_tensor_semantics: per-output semantics, serialized in order.
            class_labels: explicit class labels.
            class_labels_file: if given, read one label per line (surrounding
                whitespace stripped); replaces ``class_labels``.
            image_preprocessing_steps: optional preprocessing steps.
        """
        self.name = name
        self.kind = kind
        self.storage_format = storage_format
        self.class_labels = class_labels
        if class_labels_file:
            with open(class_labels_file, "r") as label_stream:
                self.class_labels = [raw_line.strip() for raw_line in label_stream]
        self.input_tensor_semantics = input_tensor_semantics
        self.output_tensor_semantics = output_tensor_semantics
        self.image_preprocessing_steps = image_preprocessing_steps

    def to_proto(self) -> modelexecpb.ModelInfo:
        """Serialize this model's metadata into a ``ModelInfo`` message."""
        preprocessing = (
            [step.to_proto() for step in self.image_preprocessing_steps]
            if self.image_preprocessing_steps
            else None
        )
        inputs = [sem.to_proto() for sem in self.input_tensor_semantics]
        outputs = [sem.to_proto() for sem in self.output_tensor_semantics]
        return modelexecpb.ModelInfo(
            name=self.name,
            kind=self.kind,
            format=self.storage_format,
            class_labels=self.class_labels,
            image_preprocessing_steps=preprocessing,
            input_tensor_semantics=inputs,
            output_tensor_semantics=outputs,
        )

    @abc.abstractmethod
    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return a mapping of asset name to file object or path to upload."""
|
77
|
+
|
78
|
+
|
79
|
+
class TorchModel(Model):
    """A model given as a ``torch.nn.Module``, uploaded as torch-mlir text.

    The module is compiled at upload time (in ``collect_assets``) using input
    placeholders built from ``input_shapes`` and ``input_dtypes``.
    """

    def __init__(
        self,
        name: str,
        torch_module: torch.nn.Module,
        input_shapes: List[List[int]],
        input_dtypes: List[torch.dtype],
        **kwargs,
    ):
        """
        Args:
            name: model name.
            torch_module: module to compile.
            input_shapes: one shape per model input.
            input_dtypes: one dtype per model input, paired with ``input_shapes``
                by position (``zip``, so extra entries in either are ignored).
            **kwargs: forwarded to ``Model.__init__``.
        """
        super().__init__(
            name,
            modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT,
            **kwargs,
        )
        self.torch_module = torch_module
        self.input_shapes = input_shapes
        self.input_dtypes = input_dtypes

    def _convert_to_torch_mlir(self):
        """Compile ``torch_module`` with one TensorPlaceholder per (shape, dtype) pair."""
        placeholders = [
            torch_mlir.TensorPlaceholder(shape, dtype)
            for shape, dtype in zip(self.input_shapes, self.input_dtypes)
        ]
        return to_torch_mlir(self.torch_module, placeholders)

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the compiled module's UTF-8 text as a single unnamed ("") asset."""
        mlir_text = str(self._convert_to_torch_mlir())
        return {"": io.BytesIO(mlir_text.encode("utf-8"))}
|
111
|
+
|
112
|
+
|
113
|
+
def _kind_str_to_kind_format_protos(
    kind: str,
) -> Tuple[modelexecpb.ModelInfo.ModelKind, modelexecpb.ModelInfo.ModelStorageFormat]:
    """Map a model kind string to its (ModelKind, ModelStorageFormat) enum pair.

    Matching is case-insensitive.

    Raises:
        ValueError: if ``kind`` is not "openvino", "onnx" or "tfl".
    """
    info = modelexecpb.ModelInfo
    known_kinds = {
        "openvino": (info.MODEL_KIND_OPENVINO, info.MODEL_STORAGE_FORMAT_OPENVINO),
        "onnx": (info.MODEL_KIND_ONNX, info.MODEL_STORAGE_FORMAT_PROTOBUF),
        "tfl": (info.MODEL_KIND_TFLITE, info.MODEL_STORAGE_FORMAT_FLATBUFFER),
    }
    protos = known_kinds.get(kind.lower())
    if protos is None:
        raise ValueError(f"invalid model kind: {kind}")
    return protos
|
134
|
+
|
135
|
+
|
136
|
+
class ModelFromFiles(Model):
    """A model whose assets are pre-existing files (OpenVINO, ONNX or TFLite)."""

    def __init__(
        self,
        name: str,
        kind: Literal["openvino", "onnx", "tfl"],
        files: Dict[str, TextIO | BinaryIO | Path],
        **kwargs,
    ):
        """
        Args:
            name: model name.
            kind: model kind string (case-insensitive).
            files: mapping of asset name to file object or path, returned
                as-is by ``collect_assets``.
            **kwargs: forwarded to ``Model.__init__``.

        Raises:
            ValueError: if ``kind`` is not a supported model kind.
        """
        model_kind, model_format = _kind_str_to_kind_format_protos(kind)
        super().__init__(
            name=name, kind=model_kind, storage_format=model_format, **kwargs
        )
        self.files = files

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the ``files`` mapping given at construction."""
        return self.files
|