gimlet-api 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,8 +9,11 @@ Classifier: Typing :: Typed
9
9
  Requires-Python: >=3
10
10
  Requires-Dist: protobuf
11
11
  Requires-Dist: grpcio
12
- Requires-Dist: torch
13
- Requires-Dist: torch_mlir_gml
14
- Version: 0.0.5
12
+ Requires-Dist: torch>=2.3.0
13
+ Requires-Dist: torch-mlir-gml
14
+ Requires-Dist: numpy<2.0.0
15
+ Requires-Dist: transformers>=4.43.3
16
+ Requires-Dist: safetensors-mlir
17
+ Version: 0.0.7
15
18
 
16
19
  UNKNOWN
@@ -1,11 +1,14 @@
1
- gml/__init__.py,sha256=e81afOgzDBpBNVb3pv_9Y0waGPuX95Y73ofWsq-UVak,718
1
+ gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
2
2
  gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
3
- gml/client.py,sha256=4wJ5JeyhMc7UTJazD0GvcHH90Oqiu-x7UY-Q1Zl2Wxo,10640
4
- gml/compile.py,sha256=pS_mjATIvAH-zxSbIgZeizmpi08571zpOkn5bWRyzM0,2087
5
- gml/model.py,sha256=Nwf44iTZjMOwFL05JKHaEvTd1QAkJdtgaAuFs6RVUp8,5256
6
- gml/model_utils.py,sha256=67w32zLI7rjlNsKW6nvCUp-HUjYNMCKMDhjymgJ6EOU,1275
7
- gml/pipelines.py,sha256=RObsuECms-dGFAMskxlknxfV5-9d5sEOll7McUP5reo,3737
8
- gml/preprocessing.py,sha256=w-D_GzRndY-Hec8O5UdMSVDXMWtohbwPRxizUjhHsFA,3004
3
+ gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
4
+ gml/client.py,sha256=5QDKljltBeBTCd2hH38--fTSP0bVVcAvSnWsA9YEFQc,13819
5
+ gml/compile.py,sha256=K4WdC01WkyLlbcoSJzdF2LaVmOgxIkIdc3YNjRFRw9s,10849
6
+ gml/device.py,sha256=sMILurG02aDjL8wrdBW3ftC44WoAPUeZ4Y0yQ0DtaBk,2665
7
+ gml/hf.py,sha256=hi0Af0Q3FM7VvfLB1PkrNai1j7siH6Ouwc_sXf8QE8c,17900
8
+ gml/model.py,sha256=cHFjIEplWCDeSCSl_IPHYzNyv-KOenx-OAsEd-5TpTs,7260
9
+ gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
10
+ gml/pipelines.py,sha256=6tujvMpAACwmEmUfZGFuRn0N8zKvjSVIsIX7FKcPvEU,4301
11
+ gml/preprocessing.py,sha256=STQDSA1_jXPTenJotNtsNMXOc9h1x_wJyQ100LXS6-g,3209
9
12
  gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
10
13
  gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
11
14
  gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
@@ -22,22 +25,26 @@ gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmE
22
25
  gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
23
26
  gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
24
27
  gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
25
- gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=FPNx5fRXj-bnN5mkDUXVz17M33vuHV_hmxH0ggkAUVs,5536
26
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=QPd5gMZvt97ukeHbrgqwroJH4oKTNtjEiAlgfk-FH-k,13496
27
- gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=yNTIivI9LChFifs90eO5dKwVmAR3WGPYQWXHOuXGUD4,6528
28
- gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=U8ujw_7ijKrULiCqxJ4WY9VWiUtIemfIy2e6h1NPrto,25106
28
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=ChS3EbDhgtywVrI1fU-CqqynhRpRU-YrdK7gFCaN46w,6008
29
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=O0YhmiQrwy-bFuqYt5Vno2M8m99-qlwyhkfeLBB14WQ,14903
30
+ gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=j4gvzhM1MsDLMHCAwah1X5xxpfG_5nODs7K83mTY0zI,3425
31
+ gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=yyEqUqq3-YiX-ByAhbTbZfdh09KuNzEtIYhgk_noJVM,3367
32
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=_WV7zav0uaPHzP-yvjRtUtrwexWiz4eqVclIZmhqdcY,7193
33
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=1DM58lSFgfoHk0ui3ZTjDfifgp4dhE7nHvhMwmInpsA,27103
29
34
  gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
30
35
  gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
31
36
  gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
37
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=tkJFPWpndKZy19TFuLKlBfWW1fUQPj0lJLiQ9HfugZU,3213
38
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
32
39
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
33
40
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
34
41
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
35
42
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
36
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=AvPGThkFumcElIYY9fjqWdTX9J44RCEJbaYi1F7gn4c,5661
37
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=q1PugN3Jm_4v5hVWADJLCIkIEC2_beKEqEH4vb_SpH8,7396
43
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA5qCcMALT6PS47LYmmVdBz9U47WFLs5Ayg,6330
44
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
38
45
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
39
46
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
40
- gml/tensor.py,sha256=8Ne95YFAkid_jjbtuXFtn_Eu0Hn9u5IaBRuwR7BpzxI,11827
41
- gimlet_api-0.0.5.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
42
- gimlet_api-0.0.5.dist-info/METADATA,sha256=t2CgErN9FNsBNkrCtmXNZcFlTnWzB3cA1vtL4jA1Y-c,429
43
- gimlet_api-0.0.5.dist-info/RECORD,,
47
+ gml/tensor.py,sha256=753IsMFYZD7p_f0cQPt4nTIBo5p5S5ELqwCuoHORdMk,14823
48
+ gimlet_api-0.0.7.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
49
+ gimlet_api-0.0.7.dist-info/METADATA,sha256=IhN5ODabyw5hfMJDD-LCeGoiZSmqn1kWCKDNay8or5s,531
50
+ gimlet_api-0.0.7.dist-info/RECORD,,
gml/__init__.py CHANGED
@@ -15,4 +15,5 @@
15
15
  # SPDX-License-Identifier: Apache-2.0
16
16
 
17
17
  from gml.client import Client # noqa
18
+ from gml.hf import import_huggingface_pipeline # noqa
18
19
  from gml.model import ModelFromFiles, TorchModel # noqa
gml/asset_manager.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright 2023- Gimlet Labs, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import abc
18
+ import tempfile
19
+ from pathlib import Path
20
+ from typing import Dict
21
+
22
+
23
class AssetManager(abc.ABC):
    """Interface for handing out on-disk locations for named model assets.

    Subclasses decide where each asset lives and track the name -> path
    mapping. Instances are context managers; this base ``__exit__`` releases
    nothing, so subclasses that own resources override it.

    Inheriting ``abc.ABC`` is what makes the ``@abc.abstractmethod``
    decorators below effective: without ABCMeta they are ignored. In that case
    an incomplete subclass (or this base class itself) could be instantiated.
    """

    @abc.abstractmethod
    def add_asset(self, name: str) -> Path:
        """Register an asset called ``name`` and return the path to write it to."""
        pass

    @abc.abstractmethod
    def assets(self) -> Dict[str, Path]:
        """Return the mapping from asset name to path for every registered asset."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc, value, tb):
        # Nothing to release at this level.
        pass
37
+
38
+
39
class DirectoryAssetManager(AssetManager):
    """Asset manager that places every asset directly inside one directory.

    Paths are only computed and remembered; nothing is created on disk here.
    """

    def __init__(self, path: str | Path):
        self.path = Path(path)
        self._asset_paths: Dict[str, Path] = {}

    def add_asset(self, name: str) -> Path:
        """Record ``name`` and return ``self.path / name``."""
        target = self.path.joinpath(name)
        self._asset_paths[name] = target
        return target

    def assets(self) -> Dict[str, Path]:
        """Return the live name -> path mapping (the internal dict, not a copy)."""
        return self._asset_paths
51
+
52
+
53
class TempFileAssetManager(AssetManager):
    """Asset manager that backs each asset with a ``tempfile.NamedTemporaryFile``.

    A temporary file is created and entered when an asset is added, and its
    path is handed out. The files are closed in ``__exit__``. With
    NamedTemporaryFile's default ``delete=True``, closing also removes them.
    Use instances as a context manager so cleanup always runs.
    """

    def __init__(self):
        # name -> temporary-file wrapper; keeping a reference keeps the file open.
        self._assets = dict()
        # name -> filesystem path of that temporary file.
        self._asset_paths: Dict[str, Path] = dict()

    def add_asset(self, name: str) -> Path:
        """Create a temporary file for ``name`` and return its path."""
        tmp = tempfile.NamedTemporaryFile(mode="w")
        self._assets[name] = tmp
        file = tmp.__enter__()
        self._asset_paths[name] = Path(file.name)
        return self._asset_paths[name]

    def assets(self) -> Dict[str, Path]:
        """Return the live name -> path mapping (the internal dict, not a copy)."""
        return self._asset_paths

    def __enter__(self):
        return self

    def __exit__(self, exc, value, tb):
        # Close every temporary file even if closing one of them raises. For
        # example, on some Python versions tempfile raises FileNotFoundError
        # if the file was already removed. Then clear our bookkeeping and
        # surface the first failure. Previously one failure left the
        # remaining files on disk and the dicts populated.
        first_error = None
        for tmp in self._assets.values():
            try:
                tmp.__exit__(exc, value, tb)
            except Exception as err:
                if first_error is None:
                    first_error = err
        self._assets.clear()
        self._asset_paths.clear()
        if first_error is not None:
            raise first_error
gml/client.py CHANGED
@@ -21,6 +21,8 @@ from typing import BinaryIO, List, Optional, TextIO, Union
21
21
 
22
22
  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
23
23
  import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
24
+ import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2 as cpb
25
+ import gml.proto.src.controlplane.compiler.cpb.v1.cpb_pb2_grpc as cpb_grpc
24
26
  import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2 as directorypb
25
27
  import gml.proto.src.controlplane.directory.directorypb.v1.directory_pb2_grpc as directorypb_grpc
26
28
  import gml.proto.src.controlplane.filetransfer.ftpb.v1.ftpb_pb2 as ftpb
@@ -31,6 +33,7 @@ import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2 as mpb
31
33
  import gml.proto.src.controlplane.model.mpb.v1.mpb_pb2_grpc as mpb_grpc
32
34
  import grpc
33
35
  from gml._utils import chunk_file, sha256sum
36
+ from gml.device import DeviceCapabilities
34
37
  from gml.model import Model
35
38
  from gml.pipelines import Pipeline
36
39
 
@@ -65,6 +68,10 @@ class FileAlreadyExists(Exception):
65
68
  pass
66
69
 
67
70
 
71
class ModelAlreadyExists(Exception):
    """Exception type for a model name that is already registered.

    NOTE(review): not raised anywhere in the visible code; confirm usage.
    """
73
+
74
+
68
75
class OrgNotSet(Exception):
    """Exception type for operations that need an org when none is configured.

    NOTE(review): raised outside the visible code; confirm where it is used.
    """
70
77
 
@@ -106,6 +113,7 @@ class Client:
106
113
 
107
114
  self._org_id_cache: Optional[uuidpb.UUID] = None
108
115
  self._fts_stub_cache: Optional[ftpb_grpc.FileTransferServiceStub] = None
116
+ self._cs_stub_cache: Optional[cpb_grpc.CompilerServiceStub] = None
109
117
  self._lps_stub_cache: Optional[lppb_grpc.LogicalPipelineServiceStub] = None
110
118
  self._ods_stub_cache: Optional[directorypb_grpc.OrgDirectoryServiceStub] = None
111
119
  self._ms_stub_cache: Optional[mpb_grpc.ModelServiceStub] = None
@@ -137,6 +145,13 @@ class Client:
137
145
  )
138
146
  return self._ods_stub_cache
139
147
 
148
def _cs_stub(self):
    """Return the CompilerService stub, creating it on first use and caching it."""
    stub = self._cs_stub_cache
    if stub is None:
        channel = self._channel_factory.get_grpc_channel()
        stub = cpb_grpc.CompilerServiceStub(channel)
        self._cs_stub_cache = stub
    return stub
154
+
140
155
  def _ms_stub(self):
141
156
  if self._ms_stub_cache is None:
142
157
  self._ms_stub_cache = mpb_grpc.ModelServiceStub(
@@ -185,7 +200,7 @@ class Client:
185
200
  file_id: uuidpb.UUID,
186
201
  sha256: str,
187
202
  file: TextIO | BinaryIO,
188
- chunk_size=64 * 1024,
203
+ chunk_size=1024 * 1024,
189
204
  ):
190
205
  def chunked_requests():
191
206
  file.seek(0)
@@ -206,7 +221,7 @@ class Client:
206
221
  name: str,
207
222
  file: TextIO | BinaryIO,
208
223
  sha256: Optional[str] = None,
209
- chunk_size=64 * 1024,
224
+ chunk_size=1024 * 1024,
210
225
  ) -> ftpb.FileInfo:
211
226
  file_info = self._create_file(name)
212
227
 
@@ -237,7 +252,21 @@ class Client:
237
252
  raise Exception("file status is deleted or unknown, cannot re-upload")
238
253
  return file_info
239
254
 
240
- def _create_model(self, model_info: modelexecpb.ModelInfo):
255
def _get_model_if_exists(self, name: str) -> Optional[cpb.Model]:
    """Fetch the model called ``name`` in the current org, if any.

    Returns a ``cpb.Model`` built from the ModelService ``GetModel`` response.
    Returns ``None`` when the service reports NOT_FOUND. Any other RPC error
    is re-raised.
    """
    request = mpb.GetModelRequest(name=name, org_id=self._get_org_id())
    model_stub = self._ms_stub()
    try:
        response = model_stub.GetModel(
            request, metadata=self._get_request_metadata()
        )
    except grpc.RpcError as e:
        if e.code() == grpc.StatusCode.NOT_FOUND:
            return None
        raise e
    return cpb.Model(id=response.id, info=response.model_info)
268
+
269
+ def _create_model(self, model_info: modelexecpb.ModelInfo) -> cpb.Model:
241
270
  req = mpb.CreateModelRequest(
242
271
  org_id=self._get_org_id(),
243
272
  name=model_info.name,
@@ -247,26 +276,36 @@ class Client:
247
276
  resp = stub.CreateModel(
248
277
  req, metadata=self._get_request_metadata(idempotent=True)
249
278
  )
250
- return resp.id
279
+ return cpb.Model(id=resp.id, info=model_info)
280
+
281
def create_model(self, model: Model) -> cpb.Model:
    """Upload ``model``'s assets and register it, unless it already exists.

    If a model with the same name is already registered, a warning is printed
    and the existing model is returned; nothing is uploaded. Otherwise:

    - Each asset yielded by ``model.collect_assets()`` is hashed with
      ``sha256sum`` and passed to ``_upload_file_if_not_exists`` along with
      that digest.
    - The resulting file id is recorded under the asset's name in the model
      proto.
    - The model is then created.

    Returns:
        The existing or newly created ``cpb.Model``.
    """
    existing_model = self._get_model_if_exists(model.name)
    if existing_model is not None:
        print(
            'warning: model "{}" already exists and will not be uploaded.'.format(
                model.name
            )
        )
        return existing_model

    model_info = model.to_proto()
    with model.collect_assets() as model_assets:
        for asset_name, file in model_assets.items():
            # Assets may be given as paths or as already-open file objects.
            if isinstance(file, (Path, str)):
                file = open(file, "rb")
            # Always close the handle, even if hashing or uploading raises;
            # previously a failed upload leaked the open file.
            try:
                sha256 = sha256sum(file)

                upload_name = model.name
                if asset_name:
                    upload_name += ":" + asset_name
                print(f"Uploading {upload_name}...")

                file_info = self._upload_file_if_not_exists(sha256, file, sha256)

                model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
            finally:
                file.close()

    return self._create_model(model_info)
272
311
 
@@ -285,9 +324,7 @@ class Client:
285
324
  if isinstance(pipeline, Pipeline):
286
325
  if self._org_name is None:
287
326
  raise ValueError("must set `org` to upload a pipeline")
288
- yaml = pipeline.to_yaml(
289
- [model.name for model in models], self._org_name
290
- )
327
+ yaml = pipeline.to_yaml(models, self._org_name)
291
328
  else:
292
329
  yaml = pipeline
293
330
  else:
@@ -306,3 +343,52 @@ class Client:
306
343
  req, metadata=self._get_request_metadata(idempotent=True)
307
344
  )
308
345
  return resp.id
346
+
347
def check_compile(
    self,
    *,
    models: List[Model],
    pipeline_file: Optional[Path] = None,
    pipeline: Optional[Union[str, Pipeline]] = None,
    runtimes: Optional[List[str]] = None,
    cameras: Optional[List[str]] = None,
) -> bool:
    """Ask the CompilerService whether a pipeline compiles for given capabilities.

    The pipeline comes from ``pipeline_file`` if given, otherwise from
    ``pipeline`` (YAML text or a ``Pipeline`` object). The steps are:

    1. Register/upload each model via ``create_model``.
    2. Parse the YAML into a logical pipeline with the LogicalPipelineService.
    3. Send the logical pipeline to the CompilerService, together with
       ``DeviceCapabilities(runtimes, cameras)``.

    Args:
        models: Models referenced by the pipeline.
        pipeline_file: Path to a pipeline YAML file; takes precedence.
        pipeline: Pipeline YAML text or a ``Pipeline`` object.
        runtimes: Runtime names for the device capabilities (default: empty).
        cameras: Camera names for the device capabilities (default: empty).

    Returns:
        True if the Compile RPC succeeded. False if it failed; the gRPC code
        and details are printed.

    Raises:
        ValueError: if neither ``pipeline_file`` nor ``pipeline`` is given, or
            a ``Pipeline`` object is given while the client's org is unset.
    """
    # These used to be written ``runtimes=List[str]``, which made the typing
    # object the default *value*. Omitting them then sent nonsense to
    # DeviceCapabilities.
    if runtimes is None:
        runtimes = []
    if cameras is None:
        cameras = []

    # Validate the pipeline arguments before uploading anything, so a bad
    # call does not leave freshly uploaded models behind.
    if pipeline_file is None:
        if pipeline is None:
            raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
        if isinstance(pipeline, Pipeline) and self._org_name is None:
            raise ValueError("must set `org` to compile a Pipeline object")

    model_with_assets = [self.create_model(model) for model in models]

    if pipeline_file is not None:
        with open(pipeline_file, "r") as f:
            yaml = f.read()
    elif isinstance(pipeline, Pipeline):
        yaml = pipeline.to_yaml(models, self._org_name)
    else:
        yaml = pipeline

    lps_stub = self._lps_stub()
    parse_req = lppb.ParseLogicalPipelineYAMLRequest(
        yaml=yaml,
    )
    parse_resp: lppb.ParseLogicalPipelineYAMLResponse = (
        lps_stub.ParseLogicalPipelineYAML(
            parse_req, metadata=self._get_request_metadata()
        )
    )
    logical_pipeline = parse_resp.logical_pipeline

    capabilities = DeviceCapabilities(runtimes=runtimes, cameras=cameras)

    cs_stub = self._cs_stub()
    compile_req = cpb.CompileRequest(
        logical_pipeline=logical_pipeline,
        device_capabilities=capabilities.to_proto(),
        models=model_with_assets,
    )
    try:
        cs_stub.Compile(compile_req, metadata=self._get_request_metadata())
    except grpc.RpcError as e:
        print("Compilation failed, code {0}: {1}".format(e.code(), e.details()))
        return False
    print("Compilation successful")
    return True