gimlet-api 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. gimlet_api-0.0.4.dist-info/METADATA +16 -0
  2. gimlet_api-0.0.4.dist-info/RECORD +43 -0
  3. {gimlet_api-0.0.1.dist-info → gimlet_api-0.0.4.dist-info}/WHEEL +1 -2
  4. gml/__init__.py +18 -0
  5. gml/_utils.py +42 -0
  6. gml/client.py +308 -0
  7. gml/compile.py +44 -8
  8. gml/model.py +150 -0
  9. gml/model_utils.py +33 -0
  10. gml/pipelines.py +149 -0
  11. gml/preprocessing.py +78 -0
  12. gml/proto/gogoproto/gogo_pb2.py +101 -0
  13. gml/proto/mediapipe/framework/calculator_contract_test_pb2.py +32 -0
  14. gml/proto/mediapipe/framework/calculator_options_pb2.py +28 -0
  15. gml/proto/mediapipe/framework/calculator_pb2.py +56 -0
  16. gml/proto/mediapipe/framework/calculator_profile_pb2.py +47 -0
  17. gml/proto/mediapipe/framework/mediapipe_options_pb2.py +26 -0
  18. gml/proto/mediapipe/framework/packet_factory_pb2.py +30 -0
  19. gml/proto/mediapipe/framework/packet_generator_pb2.py +32 -0
  20. gml/proto/mediapipe/framework/packet_test_pb2.py +32 -0
  21. gml/proto/mediapipe/framework/status_handler_pb2.py +27 -0
  22. gml/proto/mediapipe/framework/stream_handler_pb2.py +29 -0
  23. gml/proto/mediapipe/framework/test_calculators_pb2.py +32 -0
  24. gml/proto/mediapipe/framework/thread_pool_executor_pb2.py +30 -0
  25. gml/proto/opentelemetry/proto/common/v1/common_pb2.py +34 -0
  26. gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py +62 -0
  27. gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py +27 -0
  28. gml/proto/src/api/corepb/v1/controlplane_pb2.py +60 -0
  29. gml/proto/src/api/corepb/v1/cp_edge_pb2.py +117 -0
  30. gml/proto/src/api/corepb/v1/mediastream_pb2.py +66 -0
  31. gml/proto/src/api/corepb/v1/model_exec_pb2.py +167 -0
  32. gml/proto/src/common/typespb/jwt_pb2.py +61 -0
  33. gml/proto/src/common/typespb/status_pb2.py +29 -0
  34. gml/proto/src/common/typespb/uuid_pb2.py +26 -0
  35. gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py +115 -0
  36. gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py +452 -0
  37. gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py +70 -0
  38. gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py +231 -0
  39. gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +57 -0
  40. gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +132 -0
  41. gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py +47 -0
  42. gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py +99 -0
  43. gml/tensor.py +333 -0
  44. gimlet_api-0.0.1.dist-info/METADATA +0 -6
  45. gimlet_api-0.0.1.dist-info/RECORD +0 -6
  46. gimlet_api-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.1
2
+ Name: gimlet-api
3
+ Author: Gimlet Labs, Inc.
4
+ Author-email: support@gimletlabs.ai
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: Operating System :: OS Independent
7
+ Classifier: License :: OSI Approved :: Apache Software License
8
+ Classifier: Typing :: Typed
9
+ Requires-Python: >=3
10
+ Requires-Dist: protobuf
11
+ Requires-Dist: grpcio
12
+ Requires-Dist: torch
13
+ Requires-Dist: torch_mlir_gml
14
+ Version: 0.0.4
15
+
16
+ UNKNOWN
@@ -0,0 +1,43 @@
1
+ gml/__init__.py,sha256=e81afOgzDBpBNVb3pv_9Y0waGPuX95Y73ofWsq-UVak,718
2
+ gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
3
+ gml/client.py,sha256=4wJ5JeyhMc7UTJazD0GvcHH90Oqiu-x7UY-Q1Zl2Wxo,10640
4
+ gml/compile.py,sha256=pS_mjATIvAH-zxSbIgZeizmpi08571zpOkn5bWRyzM0,2087
5
+ gml/model.py,sha256=91I0kCyix1e2ZR7UAdxtyxOIS6pTEfUQ3G62MBkYx5k,5221
6
+ gml/model_utils.py,sha256=67w32zLI7rjlNsKW6nvCUp-HUjYNMCKMDhjymgJ6EOU,1275
7
+ gml/pipelines.py,sha256=RObsuECms-dGFAMskxlknxfV5-9d5sEOll7McUP5reo,3737
8
+ gml/preprocessing.py,sha256=w-D_GzRndY-Hec8O5UdMSVDXMWtohbwPRxizUjhHsFA,3004
9
+ gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
10
+ gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
11
+ gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
12
+ gml/proto/mediapipe/framework/calculator_pb2.py,sha256=sP2SNeN0mp3YXtWaGFKMDdC3R_uclrLxnYgxWwTy4p8,9092
13
+ gml/proto/mediapipe/framework/calculator_profile_pb2.py,sha256=uy2yO6DffTL7f4HNuxvpztSZefVE3-SVxvnyxR49Ya4,6335
14
+ gml/proto/mediapipe/framework/mediapipe_options_pb2.py,sha256=84T1x8HgcwmtF0Eq8ZKeNMWwG6_COAr2246hZE611n0,1318
15
+ gml/proto/mediapipe/framework/packet_factory_pb2.py,sha256=AfSB4dAuJbDUNgay93Yi56Z9_DLjJZ-XeD6gNZeuxCc,1916
16
+ gml/proto/mediapipe/framework/packet_generator_pb2.py,sha256=ORDDp5UBLwMPhYiXoAbW8TSLZ6QTaMUeSrmh5jwHPpA,2167
17
+ gml/proto/mediapipe/framework/packet_test_pb2.py,sha256=K8Uo_1VzSg-u5HFgNjFLQptVsakzHNFeWJz3eldto5g,1797
18
+ gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_Xt0MLvC1ZtDgkhhXDByshMk,1707
19
+ gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
20
+ gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
21
+ gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
22
+ gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
23
+ gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
24
+ gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
25
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
26
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
27
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=yNTIivI9LChFifs90eO5dKwVmAR3WGPYQWXHOuXGUD4,6528
28
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=U8ujw_7ijKrULiCqxJ4WY9VWiUtIemfIy2e6h1NPrto,25106
29
+ gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
30
+ gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
31
+ gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
32
+ gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
33
+ gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
34
+ gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
35
+ gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
36
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=AvPGThkFumcElIYY9fjqWdTX9J44RCEJbaYi1F7gn4c,5661
37
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
38
+ gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
39
+ gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
40
+ gml/tensor.py,sha256=8Ne95YFAkid_jjbtuXFtn_Eu0Hn9u5IaBRuwR7BpzxI,11827
41
+ gimlet_api-0.0.4.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
42
+ gimlet_api-0.0.4.dist-info/METADATA,sha256=Bm_ca7tpLR7foub0QOavwvJ9_2_4FQlFPr6n8zwoJBU,429
43
+ gimlet_api-0.0.4.dist-info/RECORD,,
@@ -1,5 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: bazel-wheelmaker 1.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
-
gml/__init__.py CHANGED
@@ -0,0 +1,18 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from gml.client import Client # noqa
18
+ from gml.model import ModelFromFiles, TorchModel # noqa
gml/_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import hashlib
18
+ import os
19
+ from typing import BinaryIO, TextIO
20
+
21
+
22
+ def chunk_file(f: TextIO | BinaryIO, chunk_size=64 * 1024):
23
+ while True:
24
+ chunk = f.read(chunk_size)
25
+ if not chunk:
26
+ break
27
+ yield chunk
28
+
29
+
30
def sha256sum(f: BinaryIO, buffer_size=64 * 1024):
    """Return the hex-encoded SHA-256 digest of the data read from `f`.

    The file is consumed from its current position via `chunk_file`, using
    reads of `buffer_size`, so the whole content is never held in memory.
    """
    digest = hashlib.sha256()
    for block in chunk_file(f, buffer_size):
        digest.update(block)
    return digest.hexdigest()
36
+
37
+
38
+ def get_file_size(f: TextIO | BinaryIO):
39
+ f.seek(0, os.SEEK_END)
40
+ file_size = f.tell()
41
+ f.seek(0)
42
+ return file_size
gml/client.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import os
18
+ import uuid
19
+ from pathlib import Path
20
+ from typing import BinaryIO, List, Optional, TextIO, Union
21
+
22
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
23
+ import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
24
+ import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
25
+ import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
26
+ import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
27
+ import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2_grpc as ftpb_grpc
28
+ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2 as lppb
29
+ import gml.proto.src.controlplane.logicalpipeline.lppb.v1.lppb_pb2_grpc as lppb_grpc
30
+ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
31
+ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
32
+ import grpc
33
+ from gml._utils import chunk_file, sha256sum
34
+ from gml.model import Model
35
+ from gml.pipelines import Pipeline
36
+
37
# Controlplane address used by Client when neither the `controlplane_addr`
# argument nor the GML_CONTROLPLANE_ADDR environment variable is set.
DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
38
+
39
+
40
class _ChannelFactory:
    """
    _ChannelFactory creates grpc channels to a controlplane.

    The first channel created is cached and handed out on every later call,
    so all stubs built from one factory share a single channel.
    """

    def __init__(self, controlplane_addr: str, insecure_no_ssl=False):
        """Remember the target address and whether to skip TLS.

        Args:
            controlplane_addr: Address passed to the grpc channel constructor.
            insecure_no_ssl: When true, an insecure (plaintext) channel is
                created instead of one using SSL channel credentials.
        """
        self.controlplane_addr = controlplane_addr
        self.insecure_no_ssl = insecure_no_ssl

        self._channel_cache: Optional[grpc.Channel] = None

    def get_grpc_channel(self) -> grpc.Channel:
        """Return the cached channel, creating it on first use."""
        if self._channel_cache is None:
            # Store the channel; previously the cache was checked but never
            # populated, so every caller got a brand new channel.
            self._channel_cache = self._create_grpc_channel()
        return self._channel_cache

    def _create_grpc_channel(self) -> grpc.Channel:
        """Build a new channel to `controlplane_addr` (never cached here)."""
        if self.insecure_no_ssl:
            return grpc.insecure_channel(self.controlplane_addr)

        creds = grpc.ssl_channel_credentials()
        return grpc.secure_channel(self.controlplane_addr, creds)
62
+
63
+
64
class FileAlreadyExists(Exception):
    """Raised when file creation is rejected because the name is already in use."""
66
+
67
+
68
class OrgNotSet(Exception):
    """Raised when an org-specific operation is attempted without an org name."""
70
+
71
+
72
class APIKeyNotSet(Exception):
    """Raised when no API key is given explicitly or via GML_API_KEY."""
74
+
75
+
76
+ class Client:
77
+ """
78
+ Client provides authorized access to a controlplane.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ api_key: Optional[str] = None,
84
+ controlplane_addr: Optional[str] = None,
85
+ org: Optional[str] = None,
86
+ insecure_no_ssl: bool = False,
87
+ ):
88
+ self._org_name = org
89
+ self._api_key = api_key
90
+ if self._api_key is None:
91
+ self._api_key = os.getenv("GML_API_KEY")
92
+ if self._api_key is None:
93
+ raise APIKeyNotSet(
94
+ "must provide api_key explicitly or through environment variable GML_API_KEY"
95
+ )
96
+
97
+ self._controlplane_addr = controlplane_addr
98
+ if self._controlplane_addr is None:
99
+ self._controlplane_addr = os.getenv("GML_CONTROLPLANE_ADDR")
100
+ if self._controlplane_addr is None:
101
+ self._controlplane_addr = DEFAULT_CONTROLPLANE_ADDR
102
+
103
+ self._channel_factory = _ChannelFactory(
104
+ self._controlplane_addr, insecure_no_ssl=insecure_no_ssl
105
+ )
106
+
107
+ self._org_id_cache: Optional[uuidpb.UUID] = None
108
+ self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
109
+ self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
110
+ self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
111
+ self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
112
+
113
+ def _get_request_metadata(self, idempotent=False):
114
+ md = [("x-api-key", self._api_key)]
115
+ if idempotent:
116
+ md.append(("x-idempotency-key", uuid.uuid4().hex))
117
+ return md
118
+
119
+ def _fts_stub(self):
120
+ if self._fts_stub_cache is None:
121
+ self._fts_stub_cache = ftpb_grpc.FileTransferServiceStub(
122
+ self._channel_factory.get_grpc_channel()
123
+ )
124
+ return self._fts_stub_cache
125
+
126
+ def _lps_stub(self):
127
+ if self._lps_stub_cache is None:
128
+ self._lps_stub_cache = lppb_grpc.LogicalPipelineServiceStub(
129
+ self._channel_factory.get_grpc_channel()
130
+ )
131
+ return self._lps_stub_cache
132
+
133
+ def _ods_stub(self):
134
+ if self._ods_stub_cache is None:
135
+ self._ods_stub_cache = directorypb_grpc.OrgDirectoryServiceStub(
136
+ self._channel_factory.get_grpc_channel()
137
+ )
138
+ return self._ods_stub_cache
139
+
140
+ def _ms_stub(self):
141
+ if self._ms_stub_cache is None:
142
+ self._ms_stub_cache = mpb_grpc.ModelServiceStub(
143
+ self._channel_factory.get_grpc_channel()
144
+ )
145
+ return self._ms_stub_cache
146
+
147
+ def _get_org_id(self):
148
+ if self._org_name is None:
149
+ raise OrgNotSet("organization not set for method that is org specific")
150
+ stub = self._ods_stub()
151
+ req = directorypb.GetOrgRequest(org_name=self._org_name)
152
+ resp: directorypb.GetOrgResponse = stub.GetOrg(
153
+ req, metadata=self._get_request_metadata()
154
+ )
155
+ return resp.org_info.id
156
+
157
+ def _org_id(self):
158
+ if self._org_id_cache is None:
159
+ self._org_id_cache = self._get_org_id()
160
+ return self._org_id_cache
161
+
162
+ def _create_file(self, name: str) -> ftpb.FileInfo:
163
+ stub = self._fts_stub()
164
+ try:
165
+ req = ftpb.CreateFileInfoRequest(name=name)
166
+ resp: ftpb.CreateFileInfoResponse = stub.CreateFileInfo(
167
+ req, metadata=self._get_request_metadata()
168
+ )
169
+ except grpc.RpcError as e:
170
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
171
+ raise FileAlreadyExists(
172
+ f"A file already exists with name: {name}"
173
+ ) from e
174
+ raise e
175
+ return resp.info
176
+
177
+ def _file_info_by_name(self, name: str) -> ftpb.GetFileInfoByNameResponse:
178
+ stub = self._fts_stub()
179
+
180
+ req = ftpb.GetFileInfoByNameRequest(name=name)
181
+ return stub.GetFileInfoByName(req, metadata=self._get_request_metadata()).info
182
+
183
+ def _upload_created_file(
184
+ self,
185
+ file_id: uuidpb.UUID,
186
+ sha256: str,
187
+ file: TextIO | BinaryIO,
188
+ chunk_size=64 * 1024,
189
+ ):
190
+ def chunked_requests():
191
+ file.seek(0)
192
+ for chunk in chunk_file(file, chunk_size):
193
+ req = ftpb.UploadFileRequest(
194
+ file_id=file_id, sha256sum=sha256, chunk=chunk
195
+ )
196
+ yield req
197
+
198
+ stub = self._fts_stub()
199
+ resp: ftpb.UploadFileResponse = stub.UploadFile(
200
+ chunked_requests(), metadata=self._get_request_metadata()
201
+ )
202
+ return resp
203
+
204
+ def upload_file(
205
+ self,
206
+ name: str,
207
+ file: TextIO | BinaryIO,
208
+ sha256: Optional[str] = None,
209
+ chunk_size=64 * 1024,
210
+ ) -> ftpb.FileInfo:
211
+ file_info = self._create_file(name)
212
+
213
+ if sha256 is None:
214
+ sha256 = sha256sum(file)
215
+ self._upload_created_file(file_info.file_id, sha256, file, chunk_size)
216
+ return self._file_info_by_name(name)
217
+
218
+ def _upload_file_if_not_exists(
219
+ self,
220
+ name: str,
221
+ file: TextIO | BinaryIO,
222
+ sha256: Optional[str] = None,
223
+ ) -> ftpb.FileInfo:
224
+ file_info: Optional[ftpb.FileInfo] = None
225
+ try:
226
+ file_info = self.upload_file(name, file, sha256)
227
+ except FileAlreadyExists:
228
+ file_info = self._file_info_by_name(name)
229
+
230
+ match file_info.status:
231
+ case ftpb.FILE_STATUS_READY:
232
+ pass
233
+ case ftpb.FILE_STATUS_CREATED:
234
+ self._upload_created_file(file_info.file_id, sha256, file)
235
+ file_info = self._file_info_by_name(name)
236
+ case _:
237
+ raise Exception("file status is deleted or unknown, cannot re-upload")
238
+ return file_info
239
+
240
+ def _create_model(self, model_info: modelexecpb.ModelInfo):
241
+ req = mpb.CreateModelRequest(
242
+ org_id=self._get_org_id(),
243
+ name=model_info.name,
244
+ model_info=model_info,
245
+ )
246
+ stub = self._ms_stub()
247
+ resp = stub.CreateModel(
248
+ req, metadata=self._get_request_metadata(idempotent=True)
249
+ )
250
+ return resp.id
251
+
252
+ def create_model(self, model: Model):
253
+ model_info = model.to_proto()
254
+ for asset_name, file in model.collect_assets().items():
255
+ if isinstance(file, Path) or isinstance(file, str):
256
+ file = open(file, "rb")
257
+
258
+ sha256 = sha256sum(file)
259
+
260
+ upload_name = model.name
261
+ if asset_name:
262
+ upload_name += ":" + asset_name
263
+ print(f"Uploading {upload_name}...")
264
+
265
+ file_info = self._upload_file_if_not_exists(sha256, file, sha256)
266
+
267
+ model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
268
+
269
+ file.close()
270
+
271
+ return self._create_model(model_info)
272
+
273
+ def upload_pipeline(
274
+ self,
275
+ *,
276
+ name: str,
277
+ models: List[Model],
278
+ pipeline_file: Optional[Path] = None,
279
+ pipeline: Optional[Union[str, Pipeline]] = None,
280
+ ) -> uuidpb.UUID:
281
+ if pipeline_file is not None:
282
+ with open(pipeline_file, "r") as f:
283
+ yaml = f.read()
284
+ elif pipeline is not None:
285
+ if isinstance(pipeline, Pipeline):
286
+ if self._org_name is None:
287
+ raise ValueError("must set `org` to upload a pipeline")
288
+ yaml = pipeline.to_yaml(
289
+ [model.name for model in models], self._org_name
290
+ )
291
+ else:
292
+ yaml = pipeline
293
+ else:
294
+ raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
295
+
296
+ for model in models:
297
+ self.create_model(model)
298
+
299
+ stub = self._lps_stub()
300
+ req = lppb.CreateLogicalPipelineRequest(
301
+ org_id=self._org_id(),
302
+ name=name,
303
+ yaml=yaml,
304
+ )
305
+ resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
306
+ req, metadata=self._get_request_metadata(idempotent=True)
307
+ )
308
+ return resp.id
gml/compile.py CHANGED
@@ -1,19 +1,55 @@
1
- from torch.fx.experimental.proxy_tensor import make_fx
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
2
16
 
3
- #from torch_mlir.torchscript import compile as torch_mlir_compile
17
+ from torch.fx.experimental.proxy_tensor import make_fx
18
+ from torch_mlir import ExampleArgs, OutputType
4
19
  from torch_mlir import compile as torch_mlir_compile
5
- from torch_mlir import OutputType
6
20
  from torch_mlir.dynamo import _get_decomposition_table
7
21
 
8
22
 
9
23
  def to_torch_mlir(model, example_inputs):
10
- # Running the model a few times on the inputs, leads to more consistent compiled results.
11
- # TODO(james): would be good to understand why this is the case.
12
- for _ in range(2):
13
- _ = model(*example_inputs)
24
+ example_args = ExampleArgs.get(example_inputs)
25
+ args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)[
26
+ "forward"
27
+ ]
28
+ try:
29
+ # Running the model a few times on the inputs, leads to more consistent compiled results.
30
+ for _ in range(2):
31
+ _ = model(*args)
32
+ except Exception:
33
+ # Ignore errors running the model. This can happen when the model has data dependent branches.
34
+ pass
14
35
 
15
- model = make_fx(model, pre_dispatch=True, decomposition_table=_get_decomposition_table())(*example_inputs)
36
+ try:
37
+ compiled = torch_mlir_compile(
38
+ model,
39
+ example_inputs,
40
+ use_tracing=False,
41
+ ignore_traced_shapes=False,
42
+ output_type=OutputType.RAW,
43
+ use_make_fx=False,
44
+ )
45
+ return compiled
46
+ except Exception:
47
+ pass
16
48
 
49
+ # If the module can't be exported directly, we try to create an FX graph and then export it.
50
+ model = make_fx(
51
+ model, pre_dispatch=True, decomposition_table=_get_decomposition_table()
52
+ )(*args)
17
53
  compiled = torch_mlir_compile(
18
54
  model,
19
55
  example_inputs,
gml/model.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import abc
18
+ import io
19
+ from pathlib import Path
20
+ from typing import BinaryIO, Dict, List, Literal, Optional, TextIO, Tuple
21
+
22
+ import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
23
+ import torch
24
+ import torch_mlir
25
+ from gml.compile import to_torch_mlir
26
+ from gml.preprocessing import ImagePreprocessingStep
27
+ from gml.tensor import TensorSemantics
28
+
29
+
30
class Model(abc.ABC):
    """Base class describing a model and the assets that make it up.

    Subclasses supply the actual asset files through `collect_assets`.
    """

    def __init__(
        self,
        name: str,
        kind: modelexecpb.ModelInfo.ModelKind,
        storage_format: modelexecpb.ModelInfo.ModelStorageFormat,
        input_tensor_semantics: List[TensorSemantics],
        output_tensor_semantics: List[TensorSemantics],
        class_labels: Optional[List[str]] = None,
        class_labels_file: Optional[Path] = None,
        image_preprocessing_steps: Optional[List[ImagePreprocessingStep]] = None,
    ):
        """Store the model description.

        When `class_labels_file` is given, its whitespace-stripped lines
        replace any `class_labels` passed in directly.
        """
        self.name = name
        self.kind = kind
        self.storage_format = storage_format
        self.input_tensor_semantics = input_tensor_semantics
        self.output_tensor_semantics = output_tensor_semantics
        self.image_preprocessing_steps = image_preprocessing_steps

        self.class_labels = class_labels
        if class_labels_file:
            with open(class_labels_file, "r") as labels_fp:
                self.class_labels = [raw_line.strip() for raw_line in labels_fp]

    def to_proto(self) -> modelexecpb.ModelInfo:
        """Convert this description into a `ModelInfo` proto message."""
        if self.image_preprocessing_steps:
            preprocessing_protos = [
                step.to_proto() for step in self.image_preprocessing_steps
            ]
        else:
            preprocessing_protos = None

        input_protos = [sem.to_proto() for sem in self.input_tensor_semantics]
        output_protos = [sem.to_proto() for sem in self.output_tensor_semantics]

        return modelexecpb.ModelInfo(
            name=self.name,
            kind=self.kind,
            format=self.storage_format,
            class_labels=self.class_labels,
            image_preprocessing_steps=preprocessing_protos,
            input_tensor_semantics=input_protos,
            output_tensor_semantics=output_protos,
        )

    @abc.abstractmethod
    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the model's assets, keyed by asset name."""
78
+
79
+
80
class TorchModel(Model):
    """A model backed by a `torch.nn.Module` that is compiled via torch-mlir.

    The module is registered with kind MODEL_KIND_TORCHSCRIPT and storage
    format MODEL_STORAGE_FORMAT_MLIR_TEXT.
    """

    def __init__(
        self,
        name: str,
        torch_module: torch.nn.Module,
        input_shapes: List[List[int]],
        input_dtypes: List[torch.dtype],
        **kwargs,
    ):
        """Store the module and its input signature.

        `input_shapes` and `input_dtypes` are paired up positionally; the
        remaining keyword arguments are forwarded to `Model.__init__`.
        """
        torchscript_kind = modelexecpb.ModelInfo.MODEL_KIND_TORCHSCRIPT
        mlir_text_format = modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_MLIR_TEXT
        super().__init__(name, torchscript_kind, mlir_text_format, **kwargs)

        self.torch_module = torch_module
        self.input_shapes = input_shapes
        self.input_dtypes = input_dtypes

    def _convert_to_torch_mlir(self):
        """Compile the module with `to_torch_mlir` using placeholder inputs."""
        # `.to("cpu")` is applied to the stored module before compiling.
        cpu_module = self.torch_module.to("cpu")
        placeholders = [
            torch_mlir.TensorPlaceholder(shape, dtype)
            for shape, dtype in zip(self.input_shapes, self.input_dtypes)
        ]
        return to_torch_mlir(cpu_module, placeholders)

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the compiled module as one UTF-8 in-memory asset named ""."""
        mlir_text = str(self._convert_to_torch_mlir())
        return {"": io.BytesIO(mlir_text.encode("utf-8"))}
112
+
113
+
114
def _kind_str_to_kind_format_protos(
    kind: str,
) -> Tuple[modelexecpb.ModelInfo.ModelKind, modelexecpb.ModelInfo.ModelStorageFormat]:
    """Map a model kind name to its (kind, storage format) proto enum pair.

    Matching is case-insensitive. Recognised names are "openvino", "onnx"
    and "tfl"; anything else raises `ValueError`.
    """
    normalized = kind.lower()
    if normalized == "openvino":
        return (
            modelexecpb.ModelInfo.MODEL_KIND_OPENVINO,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_OPENVINO,
        )
    if normalized == "onnx":
        return (
            modelexecpb.ModelInfo.MODEL_KIND_ONNX,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_PROTOBUF,
        )
    if normalized == "tfl":
        return (
            modelexecpb.ModelInfo.MODEL_KIND_TFLITE,
            modelexecpb.ModelInfo.MODEL_STORAGE_FORMAT_FLATBUFFER,
        )
    # The message reports the name exactly as the caller supplied it.
    raise ValueError("invalid model kind: {}".format(kind))
135
+
136
+
137
class ModelFromFiles(Model):
    """A model whose assets are files supplied directly by the caller."""

    def __init__(
        self,
        name: str,
        kind: Literal["openvino", "onnx", "tfl"],
        files: Dict[str, TextIO | BinaryIO | Path],
        **kwargs,
    ):
        """Resolve `kind` to its proto enums and store the asset mapping.

        The remaining keyword arguments are forwarded to `Model.__init__`.
        """
        kind_proto, format_proto = _kind_str_to_kind_format_protos(kind)
        super().__init__(
            name=name, kind=kind_proto, storage_format=format_proto, **kwargs
        )
        self.files = files

    def collect_assets(self) -> Dict[str, TextIO | BinaryIO | Path]:
        """Return the asset mapping exactly as it was passed to the constructor."""
        return self.files
gml/model_utils.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+
18
def prepare_ultralytics_yolo(model):
    """Get an ultralytics YOLO model ready for export.

    Exporting these models only works properly when `export=True` is set on
    some of their torch modules. This function calls `model.model.fuse()`
    when that is available, then sets `export = True` on every module from
    `model.named_modules()` that already has an `export` attribute.

    Raises:
        ValueError: If `model` has no `model` attribute.
    """
    if not hasattr(model, "model"):
        raise ValueError(
            "input to `prepare_ultralytics_yolo` is not a supported ultralytics yolo model"
        )

    fuse = getattr(model.model, "fuse", None)
    if callable(fuse):
        fuse()

    for _name, module in model.named_modules():
        if not hasattr(module, "export"):
            continue
        module.export = True