gimlet-api 0.0.1__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. gimlet_api-0.0.1.dist-info/METADATA +16 -0
  2. gimlet_api-0.0.1.dist-info/RECORD +41 -0
  3. gimlet_api-0.0.1.dist-info/WHEEL +4 -0
  4. gml/__init__.py +17 -0
  5. gml/_utils.py +41 -0
  6. gml/client.py +297 -0
  7. gml/compile.py +61 -0
  8. gml/model.py +149 -0
  9. gml/preprocessing.py +77 -0
  10. gml/proto/gogoproto/gogo_pb2.py +101 -0
  11. gml/proto/mediapipe/framework/calculator_contract_test_pb2.py +32 -0
  12. gml/proto/mediapipe/framework/calculator_options_pb2.py +28 -0
  13. gml/proto/mediapipe/framework/calculator_pb2.py +56 -0
  14. gml/proto/mediapipe/framework/calculator_profile_pb2.py +47 -0
  15. gml/proto/mediapipe/framework/mediapipe_options_pb2.py +26 -0
  16. gml/proto/mediapipe/framework/packet_factory_pb2.py +30 -0
  17. gml/proto/mediapipe/framework/packet_generator_pb2.py +32 -0
  18. gml/proto/mediapipe/framework/packet_test_pb2.py +32 -0
  19. gml/proto/mediapipe/framework/status_handler_pb2.py +27 -0
  20. gml/proto/mediapipe/framework/stream_handler_pb2.py +29 -0
  21. gml/proto/mediapipe/framework/test_calculators_pb2.py +32 -0
  22. gml/proto/mediapipe/framework/thread_pool_executor_pb2.py +30 -0
  23. gml/proto/opentelemetry/proto/common/v1/common_pb2.py +34 -0
  24. gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py +62 -0
  25. gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py +27 -0
  26. gml/proto/src/api/corepb/v1/controlplane_pb2.py +56 -0
  27. gml/proto/src/api/corepb/v1/cp_edge_pb2.py +117 -0
  28. gml/proto/src/api/corepb/v1/mediastream_pb2.py +64 -0
  29. gml/proto/src/api/corepb/v1/model_exec_pb2.py +174 -0
  30. gml/proto/src/common/typespb/jwt_pb2.py +61 -0
  31. gml/proto/src/common/typespb/status_pb2.py +29 -0
  32. gml/proto/src/common/typespb/uuid_pb2.py +26 -0
  33. gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py +115 -0
  34. gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py +452 -0
  35. gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py +70 -0
  36. gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py +231 -0
  37. gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +59 -0
  38. gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +132 -0
  39. gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py +47 -0
  40. gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py +99 -0
  41. gml/tensor.py +193 -0
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.1
2
+ Name: gimlet-api
3
+ Author: Gimlet Labs, Inc.
4
+ Author-email: support@gimletlabs.ai
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: Operating System :: OS Independent
7
+ Classifier: License :: OSI Approved :: Apache Software License
8
+ Classifier: Typing :: Typed
9
+ Requires-Python: >=3
10
+ Requires-Dist: protobuf
11
+ Requires-Dist: grpcio
12
+ Requires-Dist: torch
13
+ Requires-Dist: torch_mlir_gml
14
+ Version: 0.0.1
15
+
16
+ UNKNOWN
@@ -0,0 +1,41 @@
1
+ gml/__init__.py,sha256=0yRaNlE_nPT6q2OW7NQzMHNTKt1FYWNYjhCEyhLaRFM,727
2
+ gml/_utils.py,sha256=X22j6gPuVyFFdGjPflwkCNo34OohdulxtovBnHY8QrM,1193
3
+ gml/client.py,sha256=x-k0YnNpRu1-0kVcPfiUsQ-lLKAOfAwpaFd28ynntp8,10229
4
+ gml/compile.py,sha256=CrM7WJn7olXw7nuTOm1IIqOo1a_WbzTmBuUCHZ_hlRQ,2096
5
+ gml/model.py,sha256=RENQPeE1Df1oV4CyK3KzAztHrm-vrT5gTVipxJx6bTs,5220
6
+ gml/preprocessing.py,sha256=jcdzY1Xn5a9QU7qiWrOFJ-MZNM0eVTMS3inGR5qSA1E,3013
7
+ gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
8
+ gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
9
+ gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
10
+ gml/proto/mediapipe/framework/calculator_pb2.py,sha256=sP2SNeN0mp3YXtWaGFKMDdC3R_uclrLxnYgxWwTy4p8,9092
11
+ gml/proto/mediapipe/framework/calculator_profile_pb2.py,sha256=uy2yO6DffTL7f4HNuxvpztSZefVE3-SVxvnyxR49Ya4,6335
12
+ gml/proto/mediapipe/framework/mediapipe_options_pb2.py,sha256=84T1x8HgcwmtF0Eq8ZKeNMWwG6_COAr2246hZE611n0,1318
13
+ gml/proto/mediapipe/framework/packet_factory_pb2.py,sha256=AfSB4dAuJbDUNgay93Yi56Z9_DLjJZ-XeD6gNZeuxCc,1916
14
+ gml/proto/mediapipe/framework/packet_generator_pb2.py,sha256=ORDDp5UBLwMPhYiXoAbW8TSLZ6QTaMUeSrmh5jwHPpA,2167
15
+ gml/proto/mediapipe/framework/packet_test_pb2.py,sha256=K8Uo_1VzSg-u5HFgNjFLQptVsakzHNFeWJz3eldto5g,1797
16
+ gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_Xt0MLvC1ZtDgkhhXDByshMk,1707
17
+ gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
18
+ gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
19
+ gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
20
+ gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
21
+ gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
22
+ gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
23
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=tMndruR8CilHUaZ8QBPOkT2rh-qXmR20h_TlbF3m_MU,5121
24
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
25
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=0vavfDHBauwg-yDdh1GQDPw_kleUAY_lZhIjOhtStec,6347
26
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=v50FEbj2I9aw6whinGx2yA11nEAUXpktP7hphXxtpyc,25459
27
+ gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
28
+ gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
29
+ gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
30
+ gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
31
+ gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
32
+ gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
33
+ gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
34
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=2yD8ZKS5KMqvDd2LFgfOWUU9wC-y1lcSOkumAOOiRCY,5851
35
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
36
+ gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
37
+ gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
38
+ gml/tensor.py,sha256=veEDZGWRCJGa16gAabuCZwSS3jLXDXBk4xTH-v5C-Dw,7170
39
+ gimlet_api-0.0.1.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
40
+ gimlet_api-0.0.1.dist-info/METADATA,sha256=G8Y-XJT84t2wtTppxUTAseI3QO18Ta8HCjuqVYlkwGg,429
41
+ gimlet_api-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: bazel-wheelmaker 1.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
gml/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright © 2023- Gimlet Labs, Inc.
2
+ # All Rights Reserved.
3
+ #
4
+ # NOTICE: All information contained herein is, and remains
5
+ # the property of Gimlet Labs, Inc. and its suppliers,
6
+ # if any. The intellectual and technical concepts contained
7
+ # herein are proprietary to Gimlet Labs, Inc. and its suppliers and
8
+ # may be covered by U.S. and Foreign Patents, patents in process,
9
+ # and are protected by trade secret or copyright law. Dissemination
10
+ # of this information or reproduction of this material is strictly
11
+ # forbidden unless prior written permission is obtained from
12
+ # Gimlet Labs, Inc.
13
+ #
14
+ # SPDX-License-Identifier: Proprietary
15
+
16
+ from gml.client import Client # noqa
17
+ from gml.model import ModelFromFiles, TorchModel # noqa
gml/_utils.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright © 2023- Gimlet Labs, Inc.
2
+ # All Rights Reserved.
3
+ #
4
+ # NOTICE: All information contained herein is, and remains
5
+ # the property of Gimlet Labs, Inc. and its suppliers,
6
+ # if any. The intellectual and technical concepts contained
7
+ # herein are proprietary to Gimlet Labs, Inc. and its suppliers and
8
+ # may be covered by U.S. and Foreign Patents, patents in process,
9
+ # and are protected by trade secret or copyright law. Dissemination
10
+ # of this information or reproduction of this material is strictly
11
+ # forbidden unless prior written permission is obtained from
12
+ # Gimlet Labs, Inc.
13
+ #
14
+ # SPDX-License-Identifier: Proprietary
15
+
16
+ import hashlib
17
+ import os
18
+ from typing import BinaryIO, TextIO
19
+
20
+
21
+ def chunk_file(f: TextIO | BinaryIO, chunk_size=64 * 1024):
22
+ while True:
23
+ chunk = f.read(chunk_size)
24
+ if not chunk:
25
+ break
26
+ yield chunk
27
+
28
+
29
def sha256sum(f: BinaryIO, buffer_size=64 * 1024):
    """
    Return the hex SHA-256 digest of the entire contents of ``f``.

    The stream is rewound to offset 0 before hashing so the digest always
    covers the whole file. Uploads also stream from offset 0, so the hash
    matches the uploaded bytes even if ``f`` was partially read beforehand.
    On return the stream is positioned at EOF.

    Args:
        f: Seekable binary stream to hash.
        buffer_size: Number of bytes read per iteration.
    """
    f.seek(0)
    digest = hashlib.sha256()
    while block := f.read(buffer_size):
        digest.update(block)
    return digest.hexdigest()
35
+
36
+
37
+ def get_file_size(f: TextIO | BinaryIO):
38
+ f.seek(0, os.SEEK_END)
39
+ file_size = f.tell()
40
+ f.seek(0)
41
+ return file_size
gml/client.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright © 2023- Gimlet Labs, Inc.
2
+ # All Rights Reserved.
3
+ #
4
+ # NOTICE: All information contained herein is, and remains
5
+ # the property of Gimlet Labs, Inc. and its suppliers,
6
+ # if any. The intellectual and technical concepts contained
7
+ # herein are proprietary to Gimlet Labs, Inc. and its suppliers and
8
+ # may be covered by U.S. and Foreign Patents, patents in process,
9
+ # and are protected by trade secret or copyright law. Dissemination
10
+ # of this information or reproduction of this material is strictly
11
+ # forbidden unless prior written permission is obtained from
12
+ # Gimlet Labs, Inc.
13
+ #
14
+ # SPDX-License-Identifier: Proprietary
15
+
16
+ import os
17
+ import uuid
18
+ from pathlib import Path
19
+ from typing import BinaryIO, List, Optional, TextIO
20
+
21
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
22
+ import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
23
+ import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
24
+ import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
25
+ import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
26
+ import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2_grpc as ftpb_grpc
27
+ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
28
+ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
29
+ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
30
+ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
31
+ import grpc
32
+ from gml._utils import chunk_file, sha256sum
33
+ from gml.model import Model
34
+
35
# Controlplane address used by Client when neither the `controlplane_addr`
# argument nor the GML_CONTROLPLANE_ADDR environment variable is provided.
DEFAULT_CONTROLPLANE_ADDR: str = "app.gimletlabs.ai"
36
+
37
+
38
class _ChannelFactory:
    """
    _ChannelFactory creates grpc channels to a controlplane.

    The first channel created is cached and returned by every later call to
    ``get_grpc_channel``, so all stubs built from one factory share a single
    underlying channel.
    """

    def __init__(self, controlplane_addr: str, insecure_no_ssl=False):
        self.controlplane_addr = controlplane_addr
        # When True, connect over a plaintext (non-TLS) channel.
        self.insecure_no_ssl = insecure_no_ssl

        self._channel_cache: Optional[grpc.Channel] = None

    def get_grpc_channel(self) -> grpc.Channel:
        """Return the cached channel, creating and caching it on first use."""
        if self._channel_cache is None:
            self._channel_cache = self._create_grpc_channel()
        return self._channel_cache

    def _create_grpc_channel(self) -> grpc.Channel:
        """Build a new plaintext or TLS channel to ``controlplane_addr``."""
        if self.insecure_no_ssl:
            return grpc.insecure_channel(self.controlplane_addr)

        creds = grpc.ssl_channel_credentials()
        return grpc.secure_channel(self.controlplane_addr, creds)
60
+
61
+
62
class FileAlreadyExists(Exception):
    """Raised when creating a file whose name is already taken on the file transfer service."""
64
+
65
+
66
class OrgNotSet(Exception):
    """Raised when an org-specific operation is attempted on a Client built without an org."""
68
+
69
+
70
class APIKeyNotSet(Exception):
    """Raised when no API key is passed to Client and GML_API_KEY is unset."""
72
+
73
+
74
+ class Client:
75
+ """
76
+ Client provides authorized access to a controlplane.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ api_key: Optional[str] = None,
82
+ controlplane_addr: Optional[str] = None,
83
+ org: Optional[str] = None,
84
+ insecure_no_ssl: bool = False,
85
+ ):
86
+ self._org_name = org
87
+ self._api_key = api_key
88
+ if self._api_key is None:
89
+ self._api_key = os.getenv("GML_API_KEY")
90
+ if self._api_key is None:
91
+ raise APIKeyNotSet(
92
+ "must provide api_key explicitly or through environment variable GML_API_KEY"
93
+ )
94
+
95
+ self._controlplane_addr = controlplane_addr
96
+ if self._controlplane_addr is None:
97
+ self._controlplane_addr = os.getenv("GML_CONTROLPLANE_ADDR")
98
+ if self._controlplane_addr is None:
99
+ self._controlplane_addr = DEFAULT_CONTROLPLANE_ADDR
100
+
101
+ self._channel_factory = _ChannelFactory(
102
+ self._controlplane_addr, insecure_no_ssl=insecure_no_ssl
103
+ )
104
+
105
+ self._org_id_cache: Optional[uuidpb.UUID] = None
106
+ self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
107
+ self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
108
+ self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
109
+ self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
110
+
111
+ def _get_request_metadata(self, idempotent=False):
112
+ md = [("x-api-key", self._api_key)]
113
+ if idempotent:
114
+ md.append(("x-idempotency-key", uuid.uuid4().hex))
115
+ return md
116
+
117
+ def _fts_stub(self):
118
+ if self._fts_stub_cache is None:
119
+ self._fts_stub_cache = ftpb_grpc.FileTransferServiceStub(
120
+ self._channel_factory.get_grpc_channel()
121
+ )
122
+ return self._fts_stub_cache
123
+
124
+ def _lps_stub(self):
125
+ if self._lps_stub_cache is None:
126
+ self._lps_stub_cache = lppb_grpc.LogicalPipelineServiceStub(
127
+ self._channel_factory.get_grpc_channel()
128
+ )
129
+ return self._lps_stub_cache
130
+
131
+ def _ods_stub(self):
132
+ if self._ods_stub_cache is None:
133
+ self._ods_stub_cache = directorypb_grpc.OrgDirectoryServiceStub(
134
+ self._channel_factory.get_grpc_channel()
135
+ )
136
+ return self._ods_stub_cache
137
+
138
+ def _ms_stub(self):
139
+ if self._ms_stub_cache is None:
140
+ self._ms_stub_cache = mpb_grpc.ModelServiceStub(
141
+ self._channel_factory.get_grpc_channel()
142
+ )
143
+ return self._ms_stub_cache
144
+
145
+ def _get_org_id(self):
146
+ if self._org_name is None:
147
+ raise OrgNotSet("organization not set for method that is org specific")
148
+ stub = self._ods_stub()
149
+ req = directorypb.GetOrgRequest(org_name=self._org_name)
150
+ resp: directorypb.GetOrgResponse = stub.GetOrg(
151
+ req, metadata=self._get_request_metadata()
152
+ )
153
+ return resp.org_info.id
154
+
155
+ def _org_id(self):
156
+ if self._org_id_cache is None:
157
+ self._org_id_cache = self._get_org_id()
158
+ return self._org_id_cache
159
+
160
+ def _create_file(self, name: str) -> ftpb.FileInfo:
161
+ stub = self._fts_stub()
162
+ try:
163
+ req = ftpb.CreateFileInfoRequest(name=name)
164
+ resp: ftpb.CreateFileInfoResponse = stub.CreateFileInfo(
165
+ req, metadata=self._get_request_metadata()
166
+ )
167
+ except grpc.RpcError as e:
168
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
169
+ raise FileAlreadyExists(f"A file already exists with name: {name}")
170
+ raise e
171
+ return resp.info
172
+
173
+ def _file_info_by_name(self, name: str) -> ftpb.GetFileInfoByNameResponse:
174
+ stub = self._fts_stub()
175
+
176
+ req = ftpb.GetFileInfoByNameRequest(name=name)
177
+ return stub.GetFileInfoByName(req, metadata=self._get_request_metadata()).info
178
+
179
+ def _upload_created_file(
180
+ self,
181
+ file_id: uuidpb.UUID,
182
+ sha256: str,
183
+ file: TextIO | BinaryIO,
184
+ chunk_size=64 * 1024,
185
+ ):
186
+ def chunked_requests():
187
+ file.seek(0)
188
+ for chunk in chunk_file(file, chunk_size):
189
+ req = ftpb.UploadFileRequest(
190
+ file_id=file_id, sha256sum=sha256, chunk=chunk
191
+ )
192
+ yield req
193
+
194
+ stub = self._fts_stub()
195
+ resp: ftpb.UploadFileResponse = stub.UploadFile(
196
+ chunked_requests(), metadata=self._get_request_metadata()
197
+ )
198
+ return resp
199
+
200
+ def upload_file(
201
+ self,
202
+ name: str,
203
+ file: TextIO | BinaryIO,
204
+ sha256: Optional[str] = None,
205
+ chunk_size=64 * 1024,
206
+ ) -> ftpb.FileInfo:
207
+ file_info = self._create_file(name)
208
+
209
+ if sha256 is None:
210
+ sha256 = sha256sum(file)
211
+ self._upload_created_file(file_info.file_id, sha256, file, chunk_size)
212
+ return self._file_info_by_name(name)
213
+
214
+ def _upload_file_if_not_exists(
215
+ self,
216
+ name: str,
217
+ file: TextIO | BinaryIO,
218
+ sha256: Optional[str] = None,
219
+ ) -> ftpb.FileInfo:
220
+ file_info: Optional[ftpb.FileInfo] = None
221
+ try:
222
+ file_info = self.upload_file(name, file, sha256)
223
+ except FileAlreadyExists:
224
+ file_info = self._file_info_by_name(name)
225
+
226
+ match file_info.status:
227
+ case ftpb.FILE_STATUS_READY:
228
+ pass
229
+ case ftpb.FILE_STATUS_CREATED:
230
+ self._upload_created_file(file_info.file_id, sha256, file)
231
+ file_info = self._file_info_by_name(name)
232
+ case _:
233
+ raise Exception("file status is deleted or unknown, cannot re-upload")
234
+ return file_info
235
+
236
+ def _create_model(self, model_info: modelexecpb.ModelInfo):
237
+ req = mpb.CreateModelRequest(
238
+ org_id=self._get_org_id(),
239
+ name=model_info.name,
240
+ model_info=model_info,
241
+ )
242
+ stub = self._ms_stub()
243
+ resp = stub.CreateModel(
244
+ req, metadata=self._get_request_metadata(idempotent=True)
245
+ )
246
+ return resp.id
247
+
248
+ def create_model(self, model: Model):
249
+ model_info = model.to_proto()
250
+ for asset_name, file in model.collect_assets().items():
251
+ if isinstance(file, Path) or isinstance(file, str):
252
+ file = open(file, "rb")
253
+
254
+ sha256 = sha256sum(file)
255
+
256
+ upload_name = model.name
257
+ if asset_name:
258
+ upload_name += ":" + asset_name
259
+ print(f"Uploading {upload_name}...")
260
+
261
+ file_info = self._upload_file_if_not_exists(sha256, file, sha256)
262
+
263
+ model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
264
+
265
+ file.close()
266
+
267
+ return self._create_model(model_info)
268
+
269
+ def upload_pipeline(
270
+ self,
271
+ *,
272
+ name: str,
273
+ models: List[Model],
274
+ pipeline_file: Optional[Path] = None,
275
+ pipeline: Optional[str] = None,
276
+ ) -> uuidpb.UUID:
277
+ if pipeline_file is not None:
278
+ with open(pipeline_file, "r") as f:
279
+ yaml = f.read()
280
+ elif pipeline is not None:
281
+ yaml = pipeline
282
+ else:
283
+ raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
284
+
285
+ for model in models:
286
+ self.create_model(model)
287
+
288
+ stub = self._lps_stub()
289
+ req = lppb.CreateLogicalPipelineRequest(
290
+ org_id=self._org_id(),
291
+ name=name,
292
+ yaml=yaml,
293
+ )
294
+ resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
295
+ req, metadata=self._get_request_metadata(idempotent=True)
296
+ )
297
+ return resp.id
gml/compile.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright © 2023- Gimlet Labs, Inc.
2
+ # All Rights Reserved.
3
+ #
4
+ # NOTICE: All information contained herein is, and remains
5
+ # the property of Gimlet Labs, Inc. and its suppliers,
6
+ # if any. The intellectual and technical concepts contained
7
+ # herein are proprietary to Gimlet Labs, Inc. and its suppliers and
8
+ # may be covered by U.S. and Foreign Patents, patents in process,
9
+ # and are protected by trade secret or copyright law. Dissemination
10
+ # of this information or reproduction of this material is strictly
11
+ # forbidden unless prior written permission is obtained from
12
+ # Gimlet Labs, Inc.
13
+ #
14
+ # SPDX-License-Identifier: Proprietary
15
+
16
+ from torch.fx.experimental.proxy_tensor import make_fx
17
+ from torch_mlir import ExampleArgs, OutputType
18
+ from torch_mlir import compile as torch_mlir_compile
19
+ from torch_mlir.dynamo import _get_decomposition_table
20
+
21
+
22
def to_torch_mlir(model, example_inputs):
    """
    Compile ``model`` to a raw torch-mlir module for the given example inputs.

    First, compilation is attempted directly on ``model``. If that raises, the
    model is traced into an FX graph with ``make_fx``, using torch-mlir's
    dynamo decomposition table, and that graph is compiled instead. Errors
    from this fallback propagate to the caller.
    """

    def _compile_raw(module):
        # Both attempts use identical torch-mlir options.
        return torch_mlir_compile(
            module,
            example_inputs,
            use_tracing=False,
            ignore_traced_shapes=False,
            output_type=OutputType.RAW,
            use_make_fx=False,
        )

    per_method_args = ExampleArgs.get(example_inputs)._get_for_tracing(
        use_tracing=True, ignore_traced_shapes=True
    )
    forward_args = per_method_args["forward"]

    # Running the model a few times on the inputs leads to more consistent
    # compiled results.
    try:
        for _warmup in range(2):
            model(*forward_args)
    except Exception:
        # Best-effort warm-up: this can fail when the model has
        # data-dependent branches, which is not fatal for compilation.
        pass

    try:
        return _compile_raw(model)
    except Exception:
        # Fall through to the FX-graph route below.
        pass

    # If the module can't be exported directly, build an FX graph and export that.
    fx_module = make_fx(
        model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
    )(*forward_args)
    return _compile_raw(fx_module)
gml/model.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright © 2023- Gimlet Labs, Inc.
2
+ # All Rights Reserved.
3
+ #
4
+ # NOTICE: All information contained herein is, and remains
5
+ # the property of Gimlet Labs, Inc. and its suppliers,
6
+ # if any. The intellectual and technical concepts contained
7
+ # herein are proprietary to Gimlet Labs, Inc. and its suppliers and
8
+ # may be covered by U.S. and Foreign Patents, patents in process,
9
+ # and are protected by trade secret or copyright law. Dissemination
10
+ # of this information or reproduction of this material is strictly
11
+ # forbidden unless prior written permission is obtained from
12
+ # Gimlet Labs, Inc.
13
+ #
14
+ # SPDX-License-Identifier: Proprietary
15
+
16
+ import abc
17
+ import io
18
+ from pathlib import Path
19
+ from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
20
+
21
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
22
+ import torch
23
+ import torch_mlir
24
+ from gml.compile import to_torch_mlir
25
+ from gml.preprocessing import ImagePreprocessingStep
26
+ from gml.tensor import TensorSemantics
27
+
28
+
29
class Model(abc.ABC):
    """
    Base description of a model to be uploaded to the controlplane.

    Holds model metadata and serializes it to a ``modelexecpb.ModelInfo``.
    Subclasses supply the model's files through ``collect_assets``.
    """

    def __init__(
        self,
        name: str,
        kind: modelexecpb.ModelInfo.ModelKind,
        storage_format: modelexecpb.ModelInfo.ModelStorageFormat,
        input_tensor_semantics: List[TensorSemantics],
        output_tensor_semantics: List[TensorSemantics],
        class_labels: Optional[List[str]] = None,
        class_labels_file: Optional[Path] = None,
        image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
    ):
        """
        Args:
            name: Model name.
            kind: Model kind enum value.
            storage_format: Storage format enum value of the model's files.
            input_tensor_semantics: Semantics of each input tensor.
            output_tensor_semantics: Semantics of each output tensor.
            class_labels: Optional list of class labels.
            class_labels_file: Optional text file with one label per line.
                When given (and truthy), it replaces ``class_labels``.
            image_preprocessing_steps: Optional image preprocessing steps.
        """
        self.name = name
        self.kind = kind
        self.storage_format = storage_format

        labels = class_labels
        if class_labels_file:
            # One label per line, with surrounding whitespace stripped.
            with open(class_labels_file, "r") as labels_fh:
                labels = [raw_line.strip() for raw_line in labels_fh]
        self.class_labels = labels

        self.input_tensor_semantics = input_tensor_semantics
        self.output_tensor_semantics = output_tensor_semantics
        self.image_preprocessing_steps = image_preprocessing_steps

    def to_proto(self) -> modelexecpb.ModelInfo:
        """
        Serialize this model's metadata into a ``ModelInfo`` proto.

        ``file_assets`` is not populated here.
        """
        steps = self.image_preprocessing_steps
        preprocessing = [step.to_proto() for step in steps] if steps else None
        inputs = [ts.to_proto() for ts in self.input_tensor_semantics]
        outputs = [ts.to_proto() for ts in self.output_tensor_semantics]
        return modelexecpb.ModelInfo(
            name=self.name,
            kind=self.kind,
            format=self.storage_format,
            class_labels=self.class_labels,
            image_preprocessing_steps=preprocessing,
            input_tensor_semantics=inputs,
            output_tensor_semantics=outputs,
        )

    @abc.abstractmethod
    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """
        Return the files making up this model, keyed by asset name.

        Values may be open file objects or filesystem paths.
        """
77
+
78
+
79
class TorchModel(Model):
    """
    A PyTorch module that is compiled to torch-mlir text for upload.

    Compilation happens each time ``collect_assets`` is called.
    """

    def __init__(
        self,
        name: str,
        torch_module: torch.nn.Module,
        input_shapes: List[List[int]],
        input_dtypes: List[torch.dtype],
        **kwargs,
    ):
        """
        Args:
            name: Model name.
            torch_module: Module to compile.
            input_shapes: Shape of each forward() input, in order.
            input_dtypes: Dtype of each forward() input. Must pair one-to-one
                with ``input_shapes``.
            **kwargs: Forwarded to Model. Must include
                ``input_tensor_semantics`` and ``output_tensor_semantics``.

        Raises:
            ValueError: if ``input_shapes`` and ``input_dtypes`` differ in length.
        """
        if len(input_shapes) != len(input_dtypes):
            # zip() in _convert_to_torch_mlir would otherwise silently drop
            # the unmatched inputs and compile a model with missing inputs.
            raise ValueError(
                "input_shapes and input_dtypes must have the same length "
                f"(got {len(input_shapes)} and {len(input_dtypes)})"
            )
        super().__init__(
            name,
            modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT,
            **kwargs,
        )
        self.torch_module = torch_module
        self.input_shapes = input_shapes
        self.input_dtypes = input_dtypes

    def _convert_to_torch_mlir(self):
        """Compile the module using one TensorPlaceholder per (shape, dtype) pair."""
        return to_torch_mlir(
            self.torch_module,
            [
                torch_mlir.TensorPlaceholder(shape, dtype)
                for shape, dtype in zip(self.input_shapes, self.input_dtypes)
            ],
        )

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the compiled module's text, UTF-8 encoded, under the empty asset name."""
        compiled = self._convert_to_torch_mlir()
        file = io.BytesIO(str(compiled).encode("utf-8"))
        return {"": file}
111
+
112
+
113
def _kind_str_to_kind_format_protos(
    kind: str,
) -> Tuple[modelexecpb.ModelInfo.ModelKind, modelexecpb.ModelInfo.ModelStorageFormat]:
    """
    Map a user-facing model kind string to its (ModelKind, ModelStorageFormat) protos.

    Matching is case-insensitive. Supported kinds are "openvino", "onnx" and "tfl".

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds.
    """
    info = modelexecpb.ModelInfo
    kind_table = {
        "openvino": (info.MODEL_KIND_OPENVINO, info.MODEL_STORAGE_FORMAT_OPENVINO),
        "onnx": (info.MODEL_KIND_ONNX, info.MODEL_STORAGE_FORMAT_PROTOBUF),
        "tfl": (info.MODEL_KIND_TFLITE, info.MODEL_STORAGE_FORMAT_FLATBUFFER),
    }
    protos = kind_table.get(kind.lower())
    if protos is None:
        raise ValueError("invalid model kind: {}".format(kind))
    return protos
134
+
135
+
136
class ModelFromFiles(Model):
    """A model whose assets are pre-built files of kind "openvino", "onnx" or "tfl"."""

    def __init__(
        self,
        name: str,
        kind: Literal["openvino", "onnx", "tfl"],
        files: Dict[str, TextIO | BinaryIO | Path],
        **kwargs,
    ):
        """
        Args:
            name: Model name.
            kind: Model kind string (case-insensitive).
            files: Mapping from asset name to file object or path. It is
                returned unchanged by ``collect_assets``.
            **kwargs: Forwarded to Model (tensor semantics, labels, preprocessing).
        """
        model_kind, model_format = _kind_str_to_kind_format_protos(kind)
        super().__init__(
            name=name, kind=model_kind, storage_format=model_format, **kwargs
        )
        self.files = files

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the user-supplied asset mapping unchanged."""
        return self.files