gimlet-api 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gimlet_api-0.0.13.dist-info → gimlet_api-0.0.15.dist-info}/METADATA +2 -2
- {gimlet_api-0.0.13.dist-info → gimlet_api-0.0.15.dist-info}/RECORD +25 -22
- gml/client.py +70 -25
- gml/compile.py +0 -11
- gml/device.py +2 -2
- gml/hf.py +101 -63
- gml/model_utils.py +10 -1
- gml/pipelines.py +97 -23
- gml/proto/opentelemetry/proto/common/v1/common_pb2.py +3 -1
- gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py +2 -2
- gml/proto/src/api/corepb/v1/actor_net_pb2.py +77 -0
- gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py +28 -28
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +35 -33
- gml/proto/src/api/corepb/v1/cp_dp_pb2.py +92 -0
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +90 -77
- gml/proto/src/api/corepb/v1/dataplane_pb2.py +33 -0
- gml/proto/src/api/corepb/v1/device_info_pb2.py +19 -17
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +11 -9
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +164 -141
- gml/proto/src/common/typespb/jwt_pb2.py +14 -10
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py +10 -8
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +17 -11
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py +33 -0
- gml/tensor.py +7 -0
- {gimlet_api-0.0.13.dist-info → gimlet_api-0.0.15.dist-info}/WHEEL +0 -0
{gimlet_api-0.0.13.dist-info → gimlet_api-0.0.15.dist-info}/METADATA
CHANGED
@@ -13,10 +13,10 @@ Requires-Dist: torch>=2.6.0
 Requires-Dist: torch-mlir-gml
 Requires-Dist: numpy<2.0.0
 Requires-Dist: rich
-Requires-Dist: transformers>=4.
+Requires-Dist: transformers>=4.53.0
 Requires-Dist: tokenizers>=0.21.0
 Requires-Dist: safetensors-mlir
 Requires-Dist: packaging
-Version: 0.0.
+Version: 0.0.15
 
 UNKNOWN
{gimlet_api-0.0.13.dist-info → gimlet_api-0.0.15.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
 gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
 gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
-gml/client.py,sha256=
-gml/compile.py,sha256=
-gml/device.py,sha256=
-gml/hf.py,sha256=
+gml/client.py,sha256=3rqVTMSv7QUlSrMqSQDzXuXDvfm3WTF3G8enK8on9zU,15861
+gml/compile.py,sha256=1zM0ihwbrptZ4FphyvtvfKJyVFuVVfyBnDnnEeJ12fg,10397
+gml/device.py,sha256=9Z7dsBfpvTHShd1OWSi1Pvn85EFYDmn1dszWV8YHIJI,2648
+gml/hf.py,sha256=4tU2c3Th_mc__78Odetg2g3b16eZM4oSevUtMr27H_k,37811
 gml/model.py,sha256=8fIYlLRduTsUZfYJr_YVPNxbEVIzr7_yaaTe4T-TZ2Y,8429
-gml/model_utils.py,sha256=
-gml/pipelines.py,sha256=
+gml/model_utils.py,sha256=Kw08MIPmwIOocoQXfjlqjn78mVerCQ2uzleT0H_zcck,1821
+gml/pipelines.py,sha256=Nif9FNqXqjYM6iL3RZzkX-KBsZamSu7xh8HwjNqsB5A,10220
 gml/preprocessing.py,sha256=YPcxwBOdx0h0ADzoloYbFw9qUGFbi167E8HA4Zwn7Pk,3928
 gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
 gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
@@ -22,33 +22,36 @@ gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_X
 gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
 gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
 gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
-gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=
+gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=2l9c_xGfUvShLFkzofChHmbgpa7I0-u-FJ_J2Wv3lvs,3168
 gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=k8oW5tmFlJK2574Ky6kDc0JmNNQCLroRwCCGyxDd7JA,9968
-gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
+gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=08f2F5overFnGlyNyZHb5rUyv7-G9pC15c4xDCccIPY,1831
+gml/proto/src/api/corepb/v1/actor_net_pb2.py,sha256=XV3UZbyxvHcXRG-kFdA5LImaTAnSVRw2scw8Cz5Mn6Q,11009
+gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py,sha256=K2xhqSAxQ3w_3VFC43XpeqXm9_EDmVz6Fb8PrA_jxOY,7609
+gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=n6LOgoQCAJZ2LJivb9DffmzLpYY7e9K1EDT8pQsIbno,14705
+gml/proto/src/api/corepb/v1/cp_dp_pb2.py,sha256=cHJkgfehbgnohipFXMtoHSzR5mGfVZ9oKwOVFjEMRc4,10358
+gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=Ruw7_GcoElKzrhEoOMYkUezJDvrjwA29cVMtBhfV8I8,23294
+gml/proto/src/api/corepb/v1/dataplane_pb2.py,sha256=D3nA4c8624Irh1cIWM8rbvUBUmr29CJ0lQlPko2BgMU,1966
 gml/proto/src/api/corepb/v1/deployed_pipeline_pb2.py,sha256=XbppBI1fQ-FazD2in1o6Z9_BIPRBArCE5dVUF7iUn3Y,6649
-gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=
+gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=lXFF04AkL_Y3Tcg9mXAvgo-w3lmAMiuHQZT5yLpyO4s,7029
 gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=vC0g3k9hDv-LhiV6LwaYCly6x00Xx_YA0i2AZSwCo_I,5396
-gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=
-gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=
-gml/proto/src/common/typespb/jwt_pb2.py,sha256=
+gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=fAB7s7w4soBtaWJXwni5OI--lWapnLM1LeqZzIBWnlo,10359
+gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=eh6-tQPq5GEhsuRM6IgJRc2PvKhMlQADu_Lj7FZN5O8,37754
+gml/proto/src/common/typespb/jwt_pb2.py,sha256=JxBZr8JU1mBoo1PKPClXr3SdfjZynRYRlQ-JHZRjqhE,6134
 gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
 gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
-gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=
+gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=MIizns1dezQjocpJjeNJ4Z7BFWNweKrRpJ070L9IaCk,5203
 gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
 gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=S3OzKYO34BRuYs3rSKbLfjAgm3LQb6wQFS-sfFdQSfk,11496
 gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
 gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
 gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
-gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=
-gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=
+gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=CaRhKtLOcz9AhIw9Rxws0ALqmFAktqkQyRlgmoA6OF0,8976
+gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=NA_Ud55lnbJfgOw729j97MqfKuylCopiimt6pryNJoU,13740
 gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=IryUZ-TlpQKvU52-XFRKvuAfiH-0EkXrTzwvmmK7Fmk,4591
 gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=ZABewZxEthfQx2pEfvBfLc4M8JoKazheQOH181CegjY,6586
 gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
-gml/tensor.py,sha256=
+gml/tensor.py,sha256=ojRlfMEf5wsLyOHuAVl0wuZlcaqO0KF4EDYdtEju6hk,15229
 gml/version_utils.py,sha256=ouCemolnoDm71NiQRcfpa5k5bETTLaFCH6lrEyivGNY,1626
-gimlet_api-0.0.
-gimlet_api-0.0.
-gimlet_api-0.0.
+gimlet_api-0.0.15.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
+gimlet_api-0.0.15.dist-info/METADATA,sha256=E-gLFjOe2M6B_KiEyg41boQKRbEw7ej_FwV-pJaop1k,611
+gimlet_api-0.0.15.dist-info/RECORD,,
gml/client.py
CHANGED
@@ -16,6 +16,7 @@
 
 import os
 import uuid
+import warnings
 from pathlib import Path
 from typing import BinaryIO, List, Optional, TextIO, Union
 from urllib.parse import quote
@@ -24,6 +25,8 @@ import grpc
 from rich.progress import (
     Console,
 )
+from tqdm import TqdmExperimentalWarning
+from tqdm.rich import tqdm
 
 import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
 import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
@@ -42,6 +45,9 @@ from gml.device import DeviceCapabilities
 from gml.model import Model
 from gml.pipelines import Pipeline
 
+# Filter out tqdm experimental warnings for the rich progress bar.
+warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
+
 DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
 console = Console()
 
@@ -206,15 +212,27 @@ class Client:
         file_id: uuidpb.UUID,
         sha256: str,
         file: TextIO | BinaryIO,
+        display_name: str,
         chunk_size=1024 * 1024,
     ):
         def chunked_requests():
+            file.seek(0, os.SEEK_END)
+            total_size = file.tell()
+
             file.seek(0)
-
-
-
-
-
+
+            with tqdm(
+                total=total_size,
+                desc=f"Uploading {display_name}",
+                unit="B",
+                unit_scale=True,
+            ) as pbar:
+                for chunk in chunk_file(file, chunk_size):
+                    pbar.update(len(chunk))
+                    req = ftpb.UploadFileRequest(
+                        file_id=file_id, sha256sum=sha256, chunk=chunk
+                    )
+                    yield req
 
         stub = self._fts_stub()
         resp: ftpb.UploadFileResponse = stub.UploadFile(
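The rewritten `chunked_requests` generator sizes the stream up front, then reports progress as each chunk is yielded. A minimal standalone sketch of the same pattern, assuming a hypothetical `chunk_file` helper in place of gml's real one (and yielding raw bytes instead of `UploadFileRequest` protos):

```python
import os
from typing import BinaryIO, Iterator

from tqdm.rich import tqdm


def chunk_file(file: BinaryIO, chunk_size: int) -> Iterator[bytes]:
    # Hypothetical stand-in for gml's chunk_file helper.
    while chunk := file.read(chunk_size):
        yield chunk


def chunks_with_progress(file: BinaryIO, display_name: str, chunk_size: int = 1024 * 1024):
    # Seek to the end to learn the total size, then rewind before streaming.
    file.seek(0, os.SEEK_END)
    total_size = file.tell()
    file.seek(0)

    with tqdm(total=total_size, desc=f"Uploading {display_name}", unit="B", unit_scale=True) as pbar:
        for chunk in chunk_file(file, chunk_size):
            pbar.update(len(chunk))
            # The real code wraps each chunk in an ftpb.UploadFileRequest here.
            yield chunk
```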
@@ -226,6 +244,7 @@ class Client:
         self,
         name: str,
         file: TextIO | BinaryIO,
+        display_name: str,
         sha256: Optional[str] = None,
         chunk_size=1024 * 1024,
     ) -> ftpb.FileInfo:
@@ -233,18 +252,30 @@ class Client:
 
         if sha256 is None:
             sha256 = sha256sum(file)
-        self._upload_created_file(
+        self._upload_created_file(
+            file_id=file_info.file_id,
+            sha256=sha256,
+            file=file,
+            display_name=display_name,
+            chunk_size=chunk_size,
+        )
         return self._file_info_by_name(name)
 
     def _upload_file_if_not_exists(
         self,
-        name: str,
+        name: str,  # name is what is stored in the file service, typically a sha256 of the file.
         file: TextIO | BinaryIO,
+        display_name: str,
         sha256: Optional[str] = None,
     ) -> ftpb.FileInfo:
         file_info: Optional[ftpb.FileInfo] = None
         try:
-            file_info = self.upload_file(
+            file_info = self.upload_file(
+                name=name,
+                file=file,
+                display_name=display_name,
+                sha256=sha256,
+            )
         except FileAlreadyExists:
             file_info = self._file_info_by_name(name)
 
@@ -252,7 +283,12 @@ class Client:
             case ftpb.FILE_STATUS_READY:
                 pass
             case ftpb.FILE_STATUS_CREATED:
-                self._upload_created_file(
+                self._upload_created_file(
+                    file_id=file_info.file_id,
+                    sha256=sha256,
+                    file=file,
+                    display_name=display_name,
+                )
                 file_info = self._file_info_by_name(name)
             case _:
                 raise Exception("file status is deleted or unknown, cannot re-upload")
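`upload_file` falls back to `sha256sum(file)` when no checksum is passed, and `_upload_created_file` then re-reads the stream from the start, so the helper presumably has to rewind after hashing. A plausible sketch of such a helper for binary streams (an assumption; the diff does not show gml's actual implementation):

```python
import hashlib
from typing import BinaryIO


def sha256sum(file: BinaryIO, block_size: int = 1024 * 1024) -> str:
    # Hash the stream in blocks, then rewind so the upload can re-read it.
    digest = hashlib.sha256()
    file.seek(0)
    while block := file.read(block_size):
        digest.update(block)
    file.seek(0)
    return digest.hexdigest()
```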
@@ -292,23 +328,32 @@ class Client:
             )
             return existing_model
         model_info = model.to_proto()
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Don't use context lib because we want to show progress for asset collection, then will have a separate
+        # progress bar for file upload.
+        status = console.status(f'Tracing model "{model.name}"...')
+        status.start()
+        with model.collect_assets() as model_assets:
+            status.stop()
+            for asset_name, file in model_assets.items():
+                if isinstance(file, Path) or isinstance(file, str):
+                    file = open(file, "rb")
+
+                sha256 = sha256sum(file)
+
+                display_name = model.name
+                if asset_name:
+                    display_name += ":" + asset_name
+                file_info = self._upload_file_if_not_exists(
+                    name=sha256,
+                    file=file,
+                    display_name=display_name,
+                    sha256=sha256,
+                )
+                console.print(f"Uploaded {display_name}.")
 
-
+                model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
 
-
+                file.close()
 
         return self._create_model(model_info)
 
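Note how the new `create_model` loop uploads each asset under `name=sha256`: the file service becomes content-addressed, so `_upload_file_if_not_exists` naturally dedupes identical assets shared across models, while `display_name` keeps the human-readable `model:asset` label. A condensed sketch of the flow (the `client` object and asset mapping here are hypothetical):

```python
from pathlib import Path

# model.collect_assets() yields a mapping of asset_name -> Path or open file.
assets = {"": Path("weights.bin"), "tokenizer": Path("tokenizer.json")}

for asset_name, file in assets.items():
    if isinstance(file, (Path, str)):
        file = open(file, "rb")

    sha256 = sha256sum(file)  # helper sketched earlier
    display_name = "my-model" + (f":{asset_name}" if asset_name else "")

    # name=sha256 makes the upload content-addressed and idempotent.
    file_info = client._upload_file_if_not_exists(
        name=sha256, file=file, display_name=display_name, sha256=sha256
    )
    file.close()
```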
@@ -350,7 +395,7 @@ class Client:
             req, metadata=self._get_request_metadata(idempotent=True)
         )
 
-        url = f"https://{self._controlplane_addr}/orgs/{quote(self._org_name)}/
+        url = f"https://{self._controlplane_addr}/orgs/{quote(self._org_name)}/workloads/{quote(name, safe='')}"
         console.print(
             f"[green]Pipeline upload complete![/green]\nView your pipeline at: [cyan]{url}[/cyan]"
         )
gml/compile.py
CHANGED
@@ -253,18 +253,7 @@ def to_torch_mlir(
     if decomposition_denylist is None:
         decomposition_denylist = _default_decomposition_denylist()
 
-    model = model.eval().to("cpu")
-
     submodule_registration_workarounds(model)
-
-    try:
-        # Running the model a few times on the inputs, leads to more consistent compiled results.
-        for _ in range(2):
-            _ = model(*example_inputs)
-    except: # noqa
-        # Ignore errors running the model. This can happen when the model has data dependent branches.
-        pass
-
     register_dynamic_cache_pytree_node()
     prog = _export(
         model,
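With the eval/warmup logic deleted from `to_torch_mlir`, callers that still need it presumably have to prepare the model themselves (the Ultralytics-specific variant now lives in `prepare_ultralytics_yolo`, shown later in this diff). A hedged sketch of an equivalent caller-side helper, mirroring the removed code:

```python
import torch


def prepare_for_export(model: torch.nn.Module, example_inputs, num_warmup: int = 2) -> torch.nn.Module:
    # Mirror of the logic removed from to_torch_mlir: CPU eval mode plus a few
    # warmup passes, which led to more consistent compiled results.
    model = model.eval().to("cpu")
    try:
        with torch.no_grad():
            for _ in range(num_warmup):
                model(*example_inputs)
    except Exception:
        # Models with data-dependent branches may fail on example inputs; ignore.
        pass
    return model
```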
gml/device.py
CHANGED
@@ -57,8 +57,8 @@ def _runtime_str_to_runtime_protos(
             return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
         case "openvino":
             return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
-        case "
-            return deviceinfopb.ModelRuntimeType.
+        case "habana":
+            return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_HABANA
         case _:
             raise ValueError("invalid runtime: {}".format(runtime))
 
gml/hf.py
CHANGED
@@ -27,6 +27,7 @@ import transformers
 from rich.progress import Console
 from transformers import (
     BaseImageProcessor,
+    Cache,
     DynamicCache,
     Pipeline,
     PreTrainedModel,
@@ -52,9 +53,11 @@ from gml.tensor import (
     DetectionOutputDimension,
     DimensionSemantics,
     EmbeddingDimension,
+    IgnoreDimension,
     ImageChannelDimension,
     ImageHeightDimension,
     ImageWidthDimension,
+    PositionIDsDimension,
     SegmentationMaskChannel,
     TensorSemantics,
     TokensDimension,
@@ -117,10 +120,18 @@ class WrapWithFunctionalCache(torch.nn.Module):
         super().__init__()
         self.model = model
 
-    def forward(
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Cache,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+    ):
         outputs = self.model(
             input_ids=input_ids,
-
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
             return_dict=True,
             use_cache=True,
         )
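`WrapWithFunctionalCache` turns the hidden KV-cache state into an explicit forward argument so `torch.export` can trace it; the new signature also threads through the optional attention mask and position ids. A self-contained sketch of the wrapper, with the return statement inferred (the diff only shows the signature and the call into the wrapped model):

```python
from typing import Optional

import torch
from transformers import Cache


class WrapWithFunctionalCache(torch.nn.Module):
    """Expose the KV cache (and mask/position ids) as explicit forward args."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            return_dict=True,
            use_cache=True,
        )
        # Inferred: return the logits together with the updated cache so the
        # module stays purely functional for export.
        return outputs.logits, outputs.past_key_values
```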
@@ -134,7 +145,7 @@ class HuggingFaceTextGenerationPipeline:
         pipeline: Pipeline,
         name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
-
+        trace_w_attn_mask_and_pos_ids: bool = False,
         dynamic_batch: bool = False,
         export_predispatch: bool = False,
     ):
@@ -158,7 +169,7 @@ class HuggingFaceTextGenerationPipeline:
             name,
             torch_module=self.model,
             export_predispatch=export_predispatch,
-            **self._guess_model_spec(
+            **self._guess_model_spec(trace_w_attn_mask_and_pos_ids),
         )
 
     def _initialize_key_value_cache(self) -> DynamicCache:
@@ -249,7 +260,13 @@ class HuggingFaceTextGenerationPipeline:
             ),
         )
 
-    def _guess_model_spec(self,
+    def _guess_model_spec(self, trace_w_attn_mask_and_pos_ids: bool) -> Dict:
+        num_experts_per_tok = (
+            1
+            if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
+            else self.pipeline.model.config.num_experts_per_tok
+        )
+
         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
         if "input_ids" not in input_dict:
             raise ValueError(
@@ -258,9 +275,22 @@ class HuggingFaceTextGenerationPipeline:
 
         inputs = []
         input_tensor_semantics = []
+        dynamic_shapes = []
+
+        # Set range to half of seq_length to account for # of tokens per expert.
+        # pytorch export creates a constraint on the number of possible tokens
+        # sent to each expert. That value is num_experts * seq_length. If we don't divide
+        # by number of experts, the tracing creates an integer value that exceeds the valid int64
+        # range and will throw a hard to decipher error message.
+        seq_length = torch.export.Dim(
+            "seq_length", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
+        )
+        batch_shape = {0: torch.export.Dim("batch_size")} if self.dynamic_batch else {}
 
         # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
-        inputs.append(
+        inputs.append(
+            torch.tile(input_dict["input_ids"].to(torch.int32), [self.batch_size, 1])
+        )
         input_tensor_semantics.append(
             TensorSemantics(
                 dimensions=[
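The `seq_length` bound above encodes a real export constraint: per the diff's comment, the exporter tracks a value of `num_experts * seq_length` for mixture-of-experts models, so capping `seq_length` at `MAX_DYNAMIC_VAL // num_experts_per_tok` keeps that product inside int64 range. A small sketch of the `torch.export.Dim` wiring (the value of `MAX_DYNAMIC_VAL` here is an assumption; the diff does not show it):

```python
import torch

MAX_DYNAMIC_VAL = 2**31 - 1  # assumed sentinel; the real constant is defined in gml
num_experts_per_tok = 2      # e.g. a Mixtral-style router picking 2 experts per token

# Cap the dynamic dim so num_experts_per_tok * seq_length cannot overflow the
# exporter's constraint arithmetic.
seq_length = torch.export.Dim("seq_length", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok)
batch = torch.export.Dim("batch_size")

# dynamic_shapes mirrors the input structure: dim index -> Dim object,
# here for input_ids of shape [B, NUM_TOKENS].
dynamic_shapes = [{0: batch, 1: seq_length}]
```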
@@ -269,76 +299,84 @@ class HuggingFaceTextGenerationPipeline:
                 ],
             )
         )
+        dynamic_shapes.append({1: seq_length} | batch_shape)
+
+        cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)
 
         # Assume that the model supports a KeyValue cache.
         cache_values = self._initialize_key_value_cache()
+        cache_shapes = []
         inputs.append(cache_values)
         for _ in range(len(cache_values)):
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                and tensor.shape[1] == seqlen
-            ):
-                # This should be the logits tensor.
-                output_tensor_semantics.append(
-                    TensorSemantics(
-                        dimensions=[
-                            BatchDimension(),
-                            TokensDimension(),
-                            VocabLogitsDimension(),
-                        ],
+            cache_shapes.append(
+                [{2: cache_length} | batch_shape, {2: cache_length} | batch_shape]
+            )
+        dynamic_shapes.append(cache_shapes)
+
+        if trace_w_attn_mask_and_pos_ids:
+            input_len = input_dict["input_ids"].shape[1]
+            # Assume that the model supports a 4D attention mask.
+            # This is typically an optional input and not specifying it means we treat it as a causal mask,
+            # however in scenarios where we have padded inputs or KV caches, this may be explicitly set.
+            inputs.append(
+                torch.triu(
+                    torch.ones(
+                        (input_len, input_len + self._cache_length_for_tracing),
+                        dtype=torch.float16,
                     )
+                    * (-float("inf")),
+                    diagonal=1,
+                ).expand(self.batch_size, 1, -1, -1)
+            )
+            input_tensor_semantics.append(
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        IgnoreDimension(),
+                        AttentionMaskDimension(),
+                        AttentionMaskDimension(),
+                    ],
                 )
-
-
-
-
-
-
-
+            )
+            seq_and_cache_length = torch.export.Dim(
+                "seq_and_cache_length",
+                min=4,
+                max=MAX_DYNAMIC_VAL + MAX_DYNAMIC_VAL // num_experts_per_tok,
+            )
+            dynamic_shapes.append(
+                {2: seq_length, 3: seq_and_cache_length} | batch_shape
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
+            # Assume that the model supports position ids.
+            inputs.append(
+                torch.arange(
+                    self._cache_length_for_tracing,
+                    self._cache_length_for_tracing + input_len,
+                    dtype=torch.int32,
+                ).expand(self.batch_size, -1)
+            )
+            input_tensor_semantics.append(
+                TensorSemantics(
+                    dimensions=[BatchDimension(), PositionIDsDimension()],
+                )
+            )
+            dynamic_shapes.append({1: seq_length} | batch_shape)
 
-
-
-
-
+        # Since we wrap the model with WrapWithFunctionalCache, the outputs are well defined.
+        output_tensor_semantics = [
+            TensorSemantics(
+                dimensions=[
+                    BatchDimension(),
+                    TokensDimension(),
+                    VocabLogitsDimension(),
+                ],
+            ),
+        ] + [
+            AttentionKeyValueCacheTensorSemantics()
+            for _ in range(len(cache_values) * 2)
         ]
-        if self.dynamic_batch:
-            batch = torch.export.Dim("batch")
-            dynamic_shapes[0][0] = batch
-            for i in range(len(cache_values)):
-                dynamic_shapes[1][i][0][0] = batch
-                dynamic_shapes[1][i][1][0] = batch
 
         return {
             "example_inputs": inputs,
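The traced attention mask in the hunk above is an additive float mask: `torch.triu(ones * -inf, diagonal=1)` leaves 0.0 on and below the diagonal (attend) and -inf strictly above it (masked), expanded to the `[batch, 1, query, key]` layout of a 4D HF mask. A tiny standalone illustration:

```python
import torch


def causal_additive_mask(input_len: int, cache_len: int, batch_size: int) -> torch.Tensor:
    # 0.0 where attention is allowed, -inf on strictly-upper-diagonal (future) positions.
    mask = torch.triu(
        torch.ones((input_len, input_len + cache_len), dtype=torch.float16) * (-float("inf")),
        diagonal=1,
    )
    return mask.expand(batch_size, 1, -1, -1)  # [B, 1, query, key]


m = causal_additive_mask(input_len=3, cache_len=2, batch_size=1)
# m[0, 0]:
# [[0., -inf, -inf, -inf, -inf],
#  [0., 0., -inf, -inf, -inf],
#  [0., 0., 0., -inf, -inf]]
```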
gml/model_utils.py
CHANGED
@@ -15,11 +15,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 
-def prepare_ultralytics_yolo(model):
+def prepare_ultralytics_yolo(model, example_inputs, num_iters=2):
     """Prepares an ultralytics YOLO model for export.
 
     Ultralytics YOLO models requires setting `export=True` on some of the torch modules for exporting to work properly.
     This function handles setting that value on the necessary modules.
+
+    This also runs forward passes on the model to stabilize the exported weights.
     """
     if not hasattr(model, "model"):
         raise ValueError(
@@ -33,3 +35,10 @@ def prepare_ultralytics_yolo(model):
             m.export = True
             # YOLOv8 requires setting `format` when `export = True`
             m.format = "custom"
+
+    # Run a couple of forward passes as a warmup since the exported weights seem to change
+    # after a forward run.
+    # See https://github.com/ultralytics/yolov5/blob/2540fd4c1c2d9186126a71b3eb681d3a0a11861e/models/yolo.py#L118
+    model.model.eval().to("cpu")
+    for _ in range(num_iters):
+        model.model(*example_inputs)
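A hypothetical usage sketch of the new signature; the caller now supplies the example inputs used for the warmup passes (the checkpoint name and input shape below are illustrative, not from the diff):

```python
import torch
from ultralytics import YOLO

from gml.model_utils import prepare_ultralytics_yolo

model = YOLO("yolov8n.pt")  # illustrative checkpoint
example_inputs = (torch.rand(1, 3, 640, 640),)

# Sets export=True (and format="custom") on the detection modules, then runs
# num_iters warmup forward passes so the exported weights settle.
prepare_ultralytics_yolo(model, example_inputs, num_iters=2)
```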