gimlet-api 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,10 +13,10 @@ Requires-Dist: torch>=2.6.0
  Requires-Dist: torch-mlir-gml
  Requires-Dist: numpy<2.0.0
  Requires-Dist: rich
- Requires-Dist: transformers>=4.43.3
+ Requires-Dist: transformers>=4.53.0
  Requires-Dist: tokenizers>=0.21.0
  Requires-Dist: safetensors-mlir
  Requires-Dist: packaging
- Version: 0.0.13
+ Version: 0.0.15

  UNKNOWN
@@ -1,13 +1,13 @@
  gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
  gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
  gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
- gml/client.py,sha256=ztYvImrY_o8tR-_RJMXdDPVkqzjfVsF8jjM3CUhnFOY,14441
- gml/compile.py,sha256=WnsVgZTaiW7Uh-D_ObkX1ee9p4_8PDAF2KIQCiHAbFA,10744
- gml/device.py,sha256=Iw71NnuLcgjY32ZMXHlnlPkosTuHEmL9E98utmNChlM,2650
- gml/hf.py,sha256=Kv2yffy8omTRQDPnoIZocG2EOyfhr7UvLFIvTmRxw0g,36170
+ gml/client.py,sha256=3rqVTMSv7QUlSrMqSQDzXuXDvfm3WTF3G8enK8on9zU,15861
+ gml/compile.py,sha256=1zM0ihwbrptZ4FphyvtvfKJyVFuVVfyBnDnnEeJ12fg,10397
+ gml/device.py,sha256=9Z7dsBfpvTHShd1OWSi1Pvn85EFYDmn1dszWV8YHIJI,2648
+ gml/hf.py,sha256=4tU2c3Th_mc__78Odetg2g3b16eZM4oSevUtMr27H_k,37811
  gml/model.py,sha256=8fIYlLRduTsUZfYJr_YVPNxbEVIzr7_yaaTe4T-TZ2Y,8429
- gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
- gml/pipelines.py,sha256=hjsh7yNICDsjyKB8gQh9rtpwmfSk2q6otGogc932eLA,8454
+ gml/model_utils.py,sha256=Kw08MIPmwIOocoQXfjlqjn78mVerCQ2uzleT0H_zcck,1821
+ gml/pipelines.py,sha256=Nif9FNqXqjYM6iL3RZzkX-KBsZamSu7xh8HwjNqsB5A,10220
  gml/preprocessing.py,sha256=YPcxwBOdx0h0ADzoloYbFw9qUGFbi167E8HA4Zwn7Pk,3928
  gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
  gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
@@ -22,33 +22,36 @@ gml/proto/mediapipe/framework/status_handler_pb2.py,sha256=dgiW2ohm-ho07z1k4TM_X
  gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r9IoAUanjhk9jh01Z1KXu6Q,2043
  gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
  gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
- gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
+ gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=2l9c_xGfUvShLFkzofChHmbgpa7I0-u-FJ_J2Wv3lvs,3168
  gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=k8oW5tmFlJK2574Ky6kDc0JmNNQCLroRwCCGyxDd7JA,9968
- gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
- gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py,sha256=g3MxBqshtwaM9_Nrbvwo995_XWq-maXGP6mDeiEzZKo,7529
- gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=r0tM1XFlorKCCv5hCYEq8-LXo7moJ4C8PXFGQl-GqiU,14694
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=KuLwOZprktJ6aFnUM_OkHucQXSKQ8yg-3lRHOZidWeI,21014
+ gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=08f2F5overFnGlyNyZHb5rUyv7-G9pC15c4xDCccIPY,1831
+ gml/proto/src/api/corepb/v1/actor_net_pb2.py,sha256=XV3UZbyxvHcXRG-kFdA5LImaTAnSVRw2scw8Cz5Mn6Q,11009
+ gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py,sha256=K2xhqSAxQ3w_3VFC43XpeqXm9_EDmVz6Fb8PrA_jxOY,7609
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=n6LOgoQCAJZ2LJivb9DffmzLpYY7e9K1EDT8pQsIbno,14705
+ gml/proto/src/api/corepb/v1/cp_dp_pb2.py,sha256=cHJkgfehbgnohipFXMtoHSzR5mGfVZ9oKwOVFjEMRc4,10358
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=Ruw7_GcoElKzrhEoOMYkUezJDvrjwA29cVMtBhfV8I8,23294
+ gml/proto/src/api/corepb/v1/dataplane_pb2.py,sha256=D3nA4c8624Irh1cIWM8rbvUBUmr29CJ0lQlPko2BgMU,1966
  gml/proto/src/api/corepb/v1/deployed_pipeline_pb2.py,sha256=XbppBI1fQ-FazD2in1o6Z9_BIPRBArCE5dVUF7iUn3Y,6649
- gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=hcnU9CSZjTa0liXMGPLOos1oSKvF3jQdUaAgXZSqFS0,6760
+ gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=lXFF04AkL_Y3Tcg9mXAvgo-w3lmAMiuHQZT5yLpyO4s,7029
  gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=vC0g3k9hDv-LhiV6LwaYCly6x00Xx_YA0i2AZSwCo_I,5396
- gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=LB5YJNw_MMfFa4hgfWhpqp4yG2rTzxKZa4L3vzsB_lU,9838
- gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=Z4y7P6nyO_6dwhEkv7qhsYKEyAJVEB4nS41LdO1NpYA,34465
- gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=fAB7s7w4soBtaWJXwni5OI--lWapnLM1LeqZzIBWnlo,10359
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=eh6-tQPq5GEhsuRM6IgJRc2PvKhMlQADu_Lj7FZN5O8,37754
+ gml/proto/src/common/typespb/jwt_pb2.py,sha256=JxBZr8JU1mBoo1PKPClXr3SdfjZynRYRlQ-JHZRjqhE,6134
  gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
  gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=4mp1QWV7FOzF_nC3RDKZ9vTA-ezMhukcjBEt1lcjGmM,4933
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=MIizns1dezQjocpJjeNJ4Z7BFWNweKrRpJ070L9IaCk,5203
  gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=S3OzKYO34BRuYs3rSKbLfjAgm3LQb6wQFS-sfFdQSfk,11496
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=1-d46lvO8c80Rj7rWNaExSQsNeH9CioHilP9wW_o6I8,7985
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=cX7p2xLe01WzYTaB2TzqhePUUhaZkTE4iOAZzHaklmQ,11634
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=CaRhKtLOcz9AhIw9Rxws0ALqmFAktqkQyRlgmoA6OF0,8976
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=NA_Ud55lnbJfgOw729j97MqfKuylCopiimt6pryNJoU,13740
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=IryUZ-TlpQKvU52-XFRKvuAfiH-0EkXrTzwvmmK7Fmk,4591
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=ZABewZxEthfQx2pEfvBfLc4M8JoKazheQOH181CegjY,6586
  gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
- gml/tensor.py,sha256=aPLm3I3qkYNDcJmntaUycqqN5rsZmcj8ql0EkupJudY,14977
+ gml/tensor.py,sha256=ojRlfMEf5wsLyOHuAVl0wuZlcaqO0KF4EDYdtEju6hk,15229
  gml/version_utils.py,sha256=ouCemolnoDm71NiQRcfpa5k5bETTLaFCH6lrEyivGNY,1626
- gimlet_api-0.0.13.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
- gimlet_api-0.0.13.dist-info/METADATA,sha256=iU7mWRms-_xioOe-DyeOgOE4oTfJ5c1f5NDFSp2Ew90,611
- gimlet_api-0.0.13.dist-info/RECORD,,
+ gimlet_api-0.0.15.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
+ gimlet_api-0.0.15.dist-info/METADATA,sha256=E-gLFjOe2M6B_KiEyg41boQKRbEw7ej_FwV-pJaop1k,611
+ gimlet_api-0.0.15.dist-info/RECORD,,
gml/client.py CHANGED
@@ -16,6 +16,7 @@

  import os
  import uuid
+ import warnings
  from pathlib import Path
  from typing import BinaryIO, List, Optional, TextIO, Union
  from urllib.parse import quote
@@ -24,6 +25,8 @@ import grpc
  from rich.progress import (
      Console,
  )
+ from tqdm import TqdmExperimentalWarning
+ from tqdm.rich import tqdm

  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
  import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
@@ -42,6 +45,9 @@ from gml.device import DeviceCapabilities
  from gml.model import Model
  from gml.pipelines import Pipeline

+ # Filter out tqdm experimental warnings for the rich progress bar.
+ warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
+
  DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
  console = Console()

@@ -206,15 +212,27 @@ class Client:
          file_id: uuidpb.UUID,
          sha256: str,
          file: TextIO | BinaryIO,
+         display_name: str,
          chunk_size=1024 * 1024,
      ):
          def chunked_requests():
+             file.seek(0, os.SEEK_END)
+             total_size = file.tell()
+
              file.seek(0)
-             for chunk in chunk_file(file, chunk_size):
-                 req = ftpb.UploadFileRequest(
-                     file_id=file_id, sha256sum=sha256, chunk=chunk
-                 )
-                 yield req
+
+             with tqdm(
+                 total=total_size,
+                 desc=f"Uploading {display_name}",
+                 unit="B",
+                 unit_scale=True,
+             ) as pbar:
+                 for chunk in chunk_file(file, chunk_size):
+                     pbar.update(len(chunk))
+                     req = ftpb.UploadFileRequest(
+                         file_id=file_id, sha256sum=sha256, chunk=chunk
+                     )
+                     yield req

          stub = self._fts_stub()
          resp: ftpb.UploadFileResponse = stub.UploadFile(
@@ -226,6 +244,7 @@ class Client:
          self,
          name: str,
          file: TextIO | BinaryIO,
+         display_name: str,
          sha256: Optional[str] = None,
          chunk_size=1024 * 1024,
      ) -> ftpb.FileInfo:
@@ -233,18 +252,30 @@ class Client:

          if sha256 is None:
              sha256 = sha256sum(file)
-         self._upload_created_file(file_info.file_id, sha256, file, chunk_size)
+         self._upload_created_file(
+             file_id=file_info.file_id,
+             sha256=sha256,
+             file=file,
+             display_name=display_name,
+             chunk_size=chunk_size,
+         )
          return self._file_info_by_name(name)

      def _upload_file_if_not_exists(
          self,
-         name: str,
+         name: str,  # name is what is stored in the file service, typically a sha256 of the file.
          file: TextIO | BinaryIO,
+         display_name: str,
          sha256: Optional[str] = None,
      ) -> ftpb.FileInfo:
          file_info: Optional[ftpb.FileInfo] = None
          try:
-             file_info = self.upload_file(name, file, sha256)
+             file_info = self.upload_file(
+                 name=name,
+                 file=file,
+                 display_name=display_name,
+                 sha256=sha256,
+             )
          except FileAlreadyExists:
              file_info = self._file_info_by_name(name)

@@ -252,7 +283,12 @@ class Client:
              case ftpb.FILE_STATUS_READY:
                  pass
              case ftpb.FILE_STATUS_CREATED:
-                 self._upload_created_file(file_info.file_id, sha256, file)
+                 self._upload_created_file(
+                     file_id=file_info.file_id,
+                     sha256=sha256,
+                     file=file,
+                     display_name=display_name,
+                 )
                  file_info = self._file_info_by_name(name)
              case _:
                  raise Exception("file status is deleted or unknown, cannot re-upload")
@@ -292,23 +328,32 @@ class Client:
              )
              return existing_model
          model_info = model.to_proto()
-         with console.status(f'Creating model "{model.name}"...'):
-             with model.collect_assets() as model_assets:
-                 for asset_name, file in model_assets.items():
-                     if isinstance(file, Path) or isinstance(file, str):
-                         file = open(file, "rb")
-
-                     sha256 = sha256sum(file)
-
-                     upload_name = model.name
-                     if asset_name:
-                         upload_name += ":" + asset_name
-                     file_info = self._upload_file_if_not_exists(sha256, file, sha256)
-                     console.print(f"Uploaded {upload_name}.")
+         # Don't use context lib because we want to show progress for asset collection, then will have a separate
+         # progress bar for file upload.
+         status = console.status(f'Tracing model "{model.name}"...')
+         status.start()
+         with model.collect_assets() as model_assets:
+             status.stop()
+             for asset_name, file in model_assets.items():
+                 if isinstance(file, Path) or isinstance(file, str):
+                     file = open(file, "rb")
+
+                 sha256 = sha256sum(file)
+
+                 display_name = model.name
+                 if asset_name:
+                     display_name += ":" + asset_name
+                 file_info = self._upload_file_if_not_exists(
+                     name=sha256,
+                     file=file,
+                     display_name=display_name,
+                     sha256=sha256,
+                 )
+                 console.print(f"Uploaded {display_name}.")

-                     model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
+                 model_info.file_assets[asset_name].MergeFrom(file_info.file_id)

-                     file.close()
+                 file.close()

          return self._create_model(model_info)

@@ -350,7 +395,7 @@ class Client:
              req, metadata=self._get_request_metadata(idempotent=True)
          )

-         url = f"https://{self._controlplane_addr}/orgs/{quote(self._org_name)}/pipeline/{quote(name)}"
+         url = f"https://{self._controlplane_addr}/orgs/{quote(self._org_name)}/workloads/{quote(name, safe='')}"
          console.print(
              f"[green]Pipeline upload complete![/green]\nView your pipeline at: [cyan]{url}[/cyan]"
          )
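
Note on the upload changes above: the new `_upload_created_file` drives a rich-backed tqdm progress bar from the same generator that feeds the gRPC streaming upload. The following is a minimal, self-contained sketch of that pattern, assuming only the tqdm package; `chunked_with_progress` and the in-memory buffer are illustrative stand-ins, not part of the gimlet-api API.

import io
import os
import warnings

from tqdm import TqdmExperimentalWarning
from tqdm.rich import tqdm

# tqdm's rich front-end is experimental; silence its warning, as the diff does.
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)


def chunked_with_progress(file, display_name, chunk_size=1024 * 1024):
    # Size the bar by seeking to the end, then rewind before reading chunks.
    file.seek(0, os.SEEK_END)
    total_size = file.tell()
    file.seek(0)

    with tqdm(
        total=total_size, desc=f"Uploading {display_name}", unit="B", unit_scale=True
    ) as pbar:
        while chunk := file.read(chunk_size):
            pbar.update(len(chunk))
            # The real client wraps each chunk in an UploadFileRequest and
            # yields it to the gRPC streaming stub; here we just yield bytes.
            yield chunk


# Stand-in for a model asset; any seekable binary file object works.
buf = io.BytesIO(b"\x00" * (3 * 1024 * 1024))
for _ in chunked_with_progress(buf, "example-asset"):
    pass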
gml/compile.py CHANGED
@@ -253,18 +253,7 @@ def to_torch_mlir(
      if decomposition_denylist is None:
          decomposition_denylist = _default_decomposition_denylist()

-     model = model.eval().to("cpu")
-
      submodule_registration_workarounds(model)
-
-     try:
-         # Running the model a few times on the inputs, leads to more consistent compiled results.
-         for _ in range(2):
-             _ = model(*example_inputs)
-     except: # noqa
-         # Ignore errors running the model. This can happen when the model has data dependent branches.
-         pass
-
      register_dynamic_cache_pytree_node()
      prog = _export(
          model,
gml/device.py CHANGED
@@ -57,8 +57,8 @@ def _runtime_str_to_runtime_protos(
              return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_TENSORRT
          case "openvino":
              return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_OPENVINO
-         case "hailort":
-             return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_HAILORT
+         case "habana":
+             return deviceinfopb.ModelRuntimeType.MODEL_RUNTIME_TYPE_HABANA
          case _:
              raise ValueError("invalid runtime: {}".format(runtime))

gml/hf.py CHANGED
@@ -27,6 +27,7 @@ import transformers
  from rich.progress import Console
  from transformers import (
      BaseImageProcessor,
+     Cache,
      DynamicCache,
      Pipeline,
      PreTrainedModel,
@@ -52,9 +53,11 @@ from gml.tensor import (
      DetectionOutputDimension,
      DimensionSemantics,
      EmbeddingDimension,
+     IgnoreDimension,
      ImageChannelDimension,
      ImageHeightDimension,
      ImageWidthDimension,
+     PositionIDsDimension,
      SegmentationMaskChannel,
      TensorSemantics,
      TokensDimension,
@@ -117,10 +120,18 @@ class WrapWithFunctionalCache(torch.nn.Module):
          super().__init__()
          self.model = model

-     def forward(self, input_ids, cache):
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         past_key_values: Cache,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+     ):
          outputs = self.model(
              input_ids=input_ids,
-             past_key_values=cache,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
              return_dict=True,
              use_cache=True,
          )
@@ -134,7 +145,7 @@ class HuggingFaceTextGenerationPipeline:
          pipeline: Pipeline,
          name: Optional[str] = None,
          tokenizer_name: Optional[str] = None,
-         dynamic_seqlen: bool = False,
+         trace_w_attn_mask_and_pos_ids: bool = False,
          dynamic_batch: bool = False,
          export_predispatch: bool = False,
      ):
@@ -158,7 +169,7 @@ class HuggingFaceTextGenerationPipeline:
              name,
              torch_module=self.model,
              export_predispatch=export_predispatch,
-             **self._guess_model_spec(dynamic_seqlen),
+             **self._guess_model_spec(trace_w_attn_mask_and_pos_ids),
          )

      def _initialize_key_value_cache(self) -> DynamicCache:
@@ -249,7 +260,13 @@ class HuggingFaceTextGenerationPipeline:
              ),
          )

-     def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
+     def _guess_model_spec(self, trace_w_attn_mask_and_pos_ids: bool) -> Dict:
+         num_experts_per_tok = (
+             1
+             if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
+             else self.pipeline.model.config.num_experts_per_tok
+         )
+
          input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
          if "input_ids" not in input_dict:
              raise ValueError(
@@ -258,9 +275,22 @@ class HuggingFaceTextGenerationPipeline:

          inputs = []
          input_tensor_semantics = []
+         dynamic_shapes = []
+
+         # Set range to half of seq_length to account for # of tokens per expert.
+         # pytorch export creates a constraint on the number of possible tokens
+         # sent to each expert. That value is num_experts * seq_length. If we don't divide
+         # by number of experts, the tracing creates an integer value that exceeds the valid int64
+         # range and will throw a hard to decipher error message.
+         seq_length = torch.export.Dim(
+             "seq_length", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
+         )
+         batch_shape = {0: torch.export.Dim("batch_size")} if self.dynamic_batch else {}

          # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
-         inputs.append(torch.tile(input_dict["input_ids"], [self.batch_size, 1]))
+         inputs.append(
+             torch.tile(input_dict["input_ids"].to(torch.int32), [self.batch_size, 1])
+         )
          input_tensor_semantics.append(
              TensorSemantics(
                  dimensions=[
@@ -269,76 +299,84 @@ class HuggingFaceTextGenerationPipeline:
                  ],
              )
          )
+         dynamic_shapes.append({1: seq_length} | batch_shape)
+
+         cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)

          # Assume that the model supports a KeyValue cache.
          cache_values = self._initialize_key_value_cache()
+         cache_shapes = []
          inputs.append(cache_values)
          for _ in range(len(cache_values)):
              input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
              input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
-
-         outputs = self.model(*inputs)
-
-         # Determine output semantics.
-         output_tensor_semantics = []
-         seqlen = inputs[0].shape[1]
-         found_logits = False
-         for tensor in flatten(outputs):
-             if not isinstance(tensor, torch.Tensor):
-                 continue
-
-             if (
-                 not found_logits
-                 and len(tensor.shape) == 3
-                 and tensor.shape[0] == self.batch_size
-                 and tensor.shape[1] == seqlen
-             ):
-                 # This should be the logits tensor.
-                 output_tensor_semantics.append(
-                     TensorSemantics(
-                         dimensions=[
-                             BatchDimension(),
-                             TokensDimension(),
-                             VocabLogitsDimension(),
-                         ],
+             cache_shapes.append(
+                 [{2: cache_length} | batch_shape, {2: cache_length} | batch_shape]
+             )
+         dynamic_shapes.append(cache_shapes)
+
+         if trace_w_attn_mask_and_pos_ids:
+             input_len = input_dict["input_ids"].shape[1]
+             # Assume that the model supports a 4D attention mask.
+             # This is typically an optional input and not specifying it means we treat it as a causal mask,
+             # however in scenarios where we have padded inputs or KV caches, this may be explicitly set.
+             inputs.append(
+                 torch.triu(
+                     torch.ones(
+                         (input_len, input_len + self._cache_length_for_tracing),
+                         dtype=torch.float16,
                      )
+                     * (-float("inf")),
+                     diagonal=1,
+                 ).expand(self.batch_size, 1, -1, -1)
+             )
+             input_tensor_semantics.append(
+                 TensorSemantics(
+                     dimensions=[
+                         BatchDimension(),
+                         IgnoreDimension(),
+                         AttentionMaskDimension(),
+                         AttentionMaskDimension(),
+                     ],
                  )
-                 found_logits = True
-             else:
-                 output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
-
-         if not found_logits:
-             raise ValueError(
-                 "could not determine output logits tensor for text generation model"
+             )
+             seq_and_cache_length = torch.export.Dim(
+                 "seq_and_cache_length",
+                 min=4,
+                 max=MAX_DYNAMIC_VAL + MAX_DYNAMIC_VAL // num_experts_per_tok,
+             )
+             dynamic_shapes.append(
+                 {2: seq_length, 3: seq_and_cache_length} | batch_shape
              )

-         num_experts_per_tok = (
-             1
-             if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
-             else self.pipeline.model.config.num_experts_per_tok
-         )
-
-         dynamic_shapes = None
-         # Set range to half of seqlen to account for # of tokens per expert.
-         # pytorch export creates a constraint on the number of possible tokens
-         # sent to each expert. That value is num_experts * seqlen. If we don't divide
-         # by number of experts, the tracing creates an integer value that exceeds the valid int64
-         # range and will throw a hard to decipher error message.
-         seqlen = torch.export.Dim(
-             "seqlen", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
-         )
+             # Assume that the model supports position ids.
+             inputs.append(
+                 torch.arange(
+                     self._cache_length_for_tracing,
+                     self._cache_length_for_tracing + input_len,
+                     dtype=torch.int32,
+                 ).expand(self.batch_size, -1)
+             )
+             input_tensor_semantics.append(
+                 TensorSemantics(
+                     dimensions=[BatchDimension(), PositionIDsDimension()],
+                 )
+             )
+             dynamic_shapes.append({1: seq_length} | batch_shape)

-         cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)
-         dynamic_shapes = [
-             {1: seqlen},
-             [[{2: cache_length}, {2: cache_length}] for _ in range(len(cache_values))],
+         # Since we wrap the model with WrapWithFunctionalCache, the outputs are well defined.
+         output_tensor_semantics = [
+             TensorSemantics(
+                 dimensions=[
+                     BatchDimension(),
+                     TokensDimension(),
+                     VocabLogitsDimension(),
+                 ],
+             ),
+         ] + [
+             AttentionKeyValueCacheTensorSemantics()
+             for _ in range(len(cache_values) * 2)
          ]
-         if self.dynamic_batch:
-             batch = torch.export.Dim("batch")
-             dynamic_shapes[0][0] = batch
-             for i in range(len(cache_values)):
-                 dynamic_shapes[1][i][0][0] = batch
-                 dynamic_shapes[1][i][1][0] = batch

          return {
              "example_inputs": inputs,
gml/model_utils.py CHANGED
@@ -15,11 +15,13 @@
  # SPDX-License-Identifier: Apache-2.0


- def prepare_ultralytics_yolo(model):
+ def prepare_ultralytics_yolo(model, example_inputs, num_iters=2):
      """Prepares an ultralytics YOLO model for export.

      Ultralytics YOLO models requires setting `export=True` on some of the torch modules for exporting to work properly.
      This function handles setting that value on the necessary modules.
+
+     This also runs forward passes on the model to stabilize the exported weights.
      """
      if not hasattr(model, "model"):
          raise ValueError(
@@ -33,3 +35,10 @@ def prepare_ultralytics_yolo(model):
              m.export = True
              # YOLOv8 requires setting `format` when `export = True`
              m.format = "custom"
+
+     # Run a couple of forward passes as a warmup since the exported weights seem to change
+     # after a forward run.
+     # See https://github.com/ultralytics/yolov5/blob/2540fd4c1c2d9186126a71b3eb681d3a0a11861e/models/yolo.py#L118
+     model.model.eval().to("cpu")
+     for _ in range(num_iters):
+         model.model(*example_inputs)
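
For reference, a hypothetical call site for the updated `prepare_ultralytics_yolo`. It assumes the `ultralytics` package and a local `yolov8n.pt` checkpoint, neither of which is part of this diff.

import torch
from ultralytics import YOLO

from gml.model_utils import prepare_ultralytics_yolo

yolo = YOLO("yolov8n.pt")  # assumed checkpoint
example_inputs = (torch.zeros(1, 3, 640, 640),)

# Flags the detection head for export and runs the warmup forward passes
# added above before the model is traced.
prepare_ultralytics_yolo(yolo, example_inputs)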