gimlet-api 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gimlet_api-0.0.10.dist-info/METADATA CHANGED
@@ -9,11 +9,13 @@ Classifier: Typing :: Typed
  Requires-Python: >=3
  Requires-Dist: protobuf
  Requires-Dist: grpcio
- Requires-Dist: torch>=2.3.0
+ Requires-Dist: torch>=2.6.0
  Requires-Dist: torch-mlir-gml
  Requires-Dist: numpy<2.0.0
+ Requires-Dist: rich
  Requires-Dist: transformers>=4.43.3
+ Requires-Dist: tokenizers>=0.21.0
  Requires-Dist: safetensors-mlir
- Version: 0.0.9
+ Version: 0.0.10

  UNKNOWN
gimlet_api-0.0.10.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
  gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
  gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
  gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
- gml/client.py,sha256=jPA71PTbv-4DX2FsfFcj1VPN-eMGdF4tKCj2NnOW7ZE,13862
- gml/compile.py,sha256=Ih43r_zU07p91w9aiA0lrPJfmACpAWg0x_HFddMSy7Q,8346
+ gml/client.py,sha256=AcnG5mniHOfq-He-uCph2-xQ39cZwmXZePaUEed87b8,14378
+ gml/compile.py,sha256=3L5fpD8DK45RLiywj1b5NuDlbsxpzRxI87k1GahlMpc,9851
  gml/device.py,sha256=Iw71NnuLcgjY32ZMXHlnlPkosTuHEmL9E98utmNChlM,2650
- gml/hf.py,sha256=pp215wNmaPyCVy4DqFJbe_vEe1BRJ1GAJEURZnLuU0g,28220
- gml/model.py,sha256=xESdD7tlqn93ym67Lyyk7TZdM3wUqyn7qWdP2AbgdkI,7261
+ gml/hf.py,sha256=Kv2yffy8omTRQDPnoIZocG2EOyfhr7UvLFIvTmRxw0g,36170
+ gml/model.py,sha256=8fIYlLRduTsUZfYJr_YVPNxbEVIzr7_yaaTe4T-TZ2Y,8429
  gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
  gml/pipelines.py,sha256=LKj_lh5I5HzyUUIPG4CImiqBnQPrJsj0CHPKhLiOOGo,8374
- gml/preprocessing.py,sha256=MaKkEW4ZP9fjpkJQfpc0X3rCUuSuSmJnGMClHamKmZU,3210
+ gml/preprocessing.py,sha256=YPcxwBOdx0h0ADzoloYbFw9qUGFbi167E8HA4Zwn7Pk,3928
  gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
  gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
  gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
@@ -23,29 +23,31 @@ gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r
  gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
  gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
  gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
- gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=t2Far6oVcUFQIimzgAkZ8vQd0asMIlvECp4osC0ujgg,9735
+ gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=k8oW5tmFlJK2574Ky6kDc0JmNNQCLroRwCCGyxDd7JA,9968
  gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
- gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=BosvQ6GYaUGsNTkRZH7osP2dZGWP6U9WyxItIQ_QS-8,9769
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=oIpxq13C1ynK3alzDNZTOL5URxz5qzbDLD9NOM5xxjE,14511
- gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=5orIOJAkvtH9pWBSXveDASFi4Rn59YWdOSnVLdj891A,5356
- gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=2ljfF16Xeqgj9TM3gHN54BqRHqS3SQNhOCenEY9K9qU,4718
- gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=Un9OwDUmWdqv92QP66K-WVOAzxP_4hMoz33JI4W1G5Y,7868
- gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=o0drstrDssejqCFo8Cmm9F0zDw_bmzeOUHiYFrruOqE,29877
+ gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py,sha256=g3MxBqshtwaM9_Nrbvwo995_XWq-maXGP6mDeiEzZKo,7529
+ gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=DylHEVXr36Deh5p-WK8aRwQF-uGW5mJ2mo8pJ3qg7KA,13213
+ gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=H0WgAgv6-qaf7wnnKALmSBpD_czmUNHNYpsnE3Tmcrs,14988
+ gml/proto/src/api/corepb/v1/deployed_pipeline_pb2.py,sha256=cZjoJuZ3fpCiw2Ox7bcHCXYqRTebb08n-aodwjE-xKI,3053
+ gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=pTZGPjfglje-Wu_-R4qiwPtewXNJIGq5Kedme9SHiaU,6713
+ gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=vC0g3k9hDv-LhiV6LwaYCly6x00Xx_YA0i2AZSwCo_I,5396
+ gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=mgi5-prV7Lz0XJ2wo04jGLSvbnDGtdmduSv_6d6I9oA,8368
+ gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=_TXJvHSxkX1Il6xEVEiFIfei_ZV4KhdL3cSKaMgIYIw,33548
  gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
  gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
  gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=R8jcxOlR1iz4Y7MnxIKoJ2RaNayqWPiBSt0W496QT-c,3262
+ gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=4mp1QWV7FOzF_nC3RDKZ9vTA-ezMhukcjBEt1lcjGmM,4933
  gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
  gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
  gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=wvLQvoh2UA5qCcMALT6PS47LYmmVdBz9U47WFLs5Ayg,6330
+ gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=2s2p6dURKJLboaR965m2-rGTo_63Bi1cXsA90Hz9u-M,6632
  gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
  gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
  gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
  gml/tensor.py,sha256=aPLm3I3qkYNDcJmntaUycqqN5rsZmcj8ql0EkupJudY,14977
- gimlet_api-0.0.9.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
- gimlet_api-0.0.9.dist-info/METADATA,sha256=P5wOKzPZyJroiZTPRpWsDdjiS5XQL21GK-heo5Set_E,531
- gimlet_api-0.0.9.dist-info/RECORD,,
+ gimlet_api-0.0.10.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
+ gimlet_api-0.0.10.dist-info/METADATA,sha256=i3n2dnjznNFL6XFsj1bL0T544E0FmMVQySLgiBkUW04,586
+ gimlet_api-0.0.10.dist-info/RECORD,,
gml/client.py CHANGED
@@ -18,8 +18,12 @@ import os
  import uuid
  from pathlib import Path
  from typing import BinaryIO, List, Optional, TextIO, Union
+ from urllib.parse import quote

  import grpc
+ from rich.progress import (
+     Console,
+ )

  import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
  import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
@@ -39,6 +43,7 @@ from gml.model import Model
  from gml.pipelines import Pipeline

  DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
+ console = Console()


  class _ChannelFactory:
@@ -282,31 +287,28 @@ class Client:
      def create_model(self, model: Model) -> modelexecpb.Model:
          existing_model = self._get_model_if_exists(model.name)
          if existing_model is not None:
-             print(
-                 'warning: model "{}" already exists and will not be uploaded.'.format(
-                     model.name
-                 )
+             console.print(
+                 f'[yellow]warning:[/yellow] model "{model.name}" already exists and will not be uploaded.'
              )
              return existing_model
-
          model_info = model.to_proto()
-         with model.collect_assets() as model_assets:
-             for asset_name, file in model_assets.items():
-                 if isinstance(file, Path) or isinstance(file, str):
-                     file = open(file, "rb")
-
-                 sha256 = sha256sum(file)
+         with console.status(f'Creating model "{model.name}"...'):
+             with model.collect_assets() as model_assets:
+                 for asset_name, file in model_assets.items():
+                     if isinstance(file, Path) or isinstance(file, str):
+                         file = open(file, "rb")

-                 upload_name = model.name
-                 if asset_name:
-                     upload_name += ":" + asset_name
-                 print(f"Uploading {upload_name}...")
+                     sha256 = sha256sum(file)

-                 file_info = self._upload_file_if_not_exists(sha256, file, sha256)
+                     upload_name = model.name
+                     if asset_name:
+                         upload_name += ":" + asset_name
+                     file_info = self._upload_file_if_not_exists(sha256, file, sha256)
+                     console.print(f"Uploaded {upload_name}.")

-                 model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
+                     model_info.file_assets[asset_name].MergeFrom(file_info.file_id)

-                 file.close()
+                     file.close()

          return self._create_model(model_info)

@@ -331,6 +333,8 @@ class Client:
          else:
              raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")

+         console.print(f'Uploading pipeline "{name}" to {self._org_name}...')
+
          for model in models:
              self.create_model(model)

@@ -343,6 +347,11 @@ class Client:
          resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
              req, metadata=self._get_request_metadata(idempotent=True)
          )
+
+         url = f"https://{os.getenv('GML_CONTROLPLANE_ADDR')}/orgs/{quote(self._org_name)}/pipelines/{quote(name)}"
+         console.print(
+             f"[green]Pipeline upload complete![/green]\nView your pipeline at: [cyan]{url}[/cyan]"
+         )
          return resp.id

      def check_compile(
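
Note: the 0.0.10 client replaces bare print() calls with a shared rich console, which is why rich was added to the requirements above. A minimal sketch of the status-spinner plus markup-print pattern it adopts (the model and asset names here are made up for illustration):

```python
from rich.progress import Console  # rich.progress re-exports Console; rich.console works too

console = Console()

# Spinner while long-running work is in progress, plus colored status lines,
# mirroring the shape of Client.create_model / Client.upload_pipeline in 0.0.10.
with console.status('Creating model "my-model"...'):
    for asset_name in ("weights", "tokenizer"):
        # ... upload the asset here ...
        console.print(f"Uploaded my-model:{asset_name}.")
console.print("[green]Pipeline upload complete![/green]")
```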
gml/compile.py CHANGED
@@ -16,10 +16,11 @@

  import contextlib
  import functools
- from typing import Any, Dict, List, Optional, Sequence, Union
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

  import safetensors_mlir
  import torch
+ import torch.utils._pytree
  import torch_mlir
  from mlir.ir import (
      BF16Type,
@@ -28,6 +29,7 @@ from mlir.ir import (
      F16Type,
      F32Type,
      F64Type,
+     Float8E4M3FNType,
      IntegerType,
      Operation,
      RankedTensorType,
@@ -40,6 +42,7 @@ from torch_mlir.dialects import torch as torch_d
  from torch_mlir.extras.fx_decomp_util import get_decomposition_table
  from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks, InputInfo
  from torch_mlir.fx import export_and_import
+ from transformers import DynamicCache

  from gml.asset_manager import AssetManager
  from gml.register_submodules import submodule_registration_workarounds
@@ -53,6 +56,45 @@ def _default_decomposition_denylist():
      ]


+ _registered_dynamic_cache_pytree_node = False
+
+
+ def register_dynamic_cache_pytree_node():
+     """
+     Registers flattening/unflattening for transformers.DynamicCache
+     Pytree is a representation of tensor collections used inside torch.export.
+     """
+
+     global _registered_dynamic_cache_pytree_node
+     if _registered_dynamic_cache_pytree_node:
+         return
+     _registered_dynamic_cache_pytree_node = True
+
+     def flatten_cache_with_keys(dynamic_cache: DynamicCache):
+         return [
+             (
+                 torch.utils._pytree.MappingKey(i),
+                 list(value),
+             )
+             for i, value in enumerate(dynamic_cache.to_legacy_cache())
+         ], None
+
+     def flatten_cache(dynamic_cache: DynamicCache):
+         flattened, ctx = flatten_cache_with_keys(dynamic_cache)
+         return [v for _, v in flattened], ctx
+
+     def unflatten_cache(flattened: Iterable[Any], context: Any):
+         return DynamicCache.from_legacy_cache(flattened)
+
+     torch.utils._pytree.register_pytree_node(
+         DynamicCache,
+         flatten_cache,
+         unflatten_cache,
+         serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+         flatten_with_keys_fn=flatten_cache_with_keys,
+     )
+
+
  @contextlib.contextmanager
  def _patch_aot_export_module():
      """This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
@@ -91,6 +133,8 @@ _torch_dtype_to_builtin_element_type = {
      torch.complex32: lambda: ComplexType.get(F16Type.get()),
      torch.complex64: lambda: ComplexType.get(F32Type.get()),
      torch.complex128: lambda: ComplexType.get(F64Type.get()),
+     # Quantized types.
+     torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(),
  }


@@ -179,6 +223,7 @@ def to_torch_mlir(
      ] = None,
      decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
      weight_manager: Optional[AssetManager] = None,
+     export_predispatch: bool = False,
  ):
      if dynamic_shapes is not None:
          for shape in dynamic_shapes:
@@ -205,10 +250,11 @@
          # Ignore errors running the model. This can happen when the model has data dependent branches.
          pass

+     register_dynamic_cache_pytree_node()
      prog = _export(
          model,
          tuple(example_inputs),
-         pre_dispatch=False,
+         pre_dispatch=export_predispatch,
          strict=False,
          dynamic_shapes=dynamic_shapes,
      )
gml/hf.py CHANGED
@@ -24,8 +24,10 @@ from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Tuple

  import torch
  import transformers
+ from rich.progress import Console
  from transformers import (
      BaseImageProcessor,
+     DynamicCache,
      Pipeline,
      PreTrainedModel,
      PreTrainedTokenizer,
@@ -49,6 +51,7 @@ from gml.tensor import (
      DetectionNumCandidatesDimension,
      DetectionOutputDimension,
      DimensionSemantics,
+     EmbeddingDimension,
      ImageChannelDimension,
      ImageHeightDimension,
      ImageWidthDimension,
@@ -60,6 +63,11 @@ from gml.tensor import (

  FALLBACK_RESIZE_SIZE = 512

+ # Set dynamic dimension max size to less than the int64 max, leaving leeway for the size to be ~4x by the model.
+ MAX_DYNAMIC_VAL = 2**61
+
+ console = Console()
+

  class HuggingFaceTokenizer(Model):
      def __init__(self, tokenizer: PreTrainedTokenizer, name: Optional[str] = None):
@@ -105,7 +113,6 @@ def flatten(items):


  class WrapWithFunctionalCache(torch.nn.Module):
-
      def __init__(self, model: transformers.PreTrainedModel):
          super().__init__()
          self.model = model
@@ -128,6 +135,8 @@ class HuggingFaceTextGenerationPipeline:
          name: Optional[str] = None,
          tokenizer_name: Optional[str] = None,
          dynamic_seqlen: bool = False,
+         dynamic_batch: bool = False,
+         export_predispatch: bool = False,
      ):
          self.pipeline = pipeline
          self.tokenizer_model = HuggingFaceTokenizer(pipeline.tokenizer, tokenizer_name)
@@ -139,13 +148,20 @@ class HuggingFaceTextGenerationPipeline:
          self.model = self.model.to(torch.float16)
          self.model = WrapWithFunctionalCache(pipeline.model)

+         self.dynamic_batch = dynamic_batch
+         self.batch_size = 1
+         if self.dynamic_batch:
+             # dynamic tracing fails for dimensions of size 1.
+             self.batch_size = 2
+
          self.language_model = TorchModel(
              name,
              torch_module=self.model,
+             export_predispatch=export_predispatch,
              **self._guess_model_spec(dynamic_seqlen),
          )

-     def _initialize_key_value_cache(self):
+     def _initialize_key_value_cache(self) -> DynamicCache:
          cache = []
          config = self.pipeline.model.config
          head_dim = (
@@ -158,7 +174,12 @@
              if config.num_key_value_heads is None
              else config.num_key_value_heads
          )
-         cache_shape = (1, num_key_value_heads, self._cache_length_for_tracing, head_dim)
+         cache_shape = (
+             self.batch_size,
+             num_key_value_heads,
+             self._cache_length_for_tracing,
+             head_dim,
+         )
          for _ in range(config.num_hidden_layers):
              cache.append(
                  [
@@ -166,7 +187,67 @@
                      torch.zeros(cache_shape).to(torch.float16),
                  ]
              )
-         return cache
+         return DynamicCache.from_legacy_cache(cache)
+
+     def _parse_transformer_config(
+         self, model: transformers.PreTrainedModel
+     ) -> modelexecpb.TransformerConfig:
+         # Only non-default rope config set the rope_scaling parameter
+         attention_head_size = getattr(
+             model.config,
+             "attention_head_size",
+             model.config.hidden_size // model.config.num_attention_heads,
+         )
+         partial_rotary_factor = getattr(model.config, "partial_rotary_factor", 1.0)
+         rotary_embedding_dim = getattr(
+             model.config,
+             "rotary_dim",
+             int(attention_head_size * partial_rotary_factor),
+         )
+         if (
+             hasattr(model.config, "rope_scaling")
+             and model.config.rope_scaling is not None
+         ):
+             rope_scaling = model.config.rope_scaling
+             rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+             if not rope_type == "llama3":
+                 raise NotImplementedError(
+                     "rope scaling type {} is not supported".format(rope_type)
+                 )
+             # LLAMA 3 example config: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
+             llama3_config = modelexecpb.Llama3RopeConfig()
+             llama3_config.theta = model.config.rope_theta
+             llama3_config.rotary_embedding_dim = rotary_embedding_dim
+             llama3_config.max_position_embeddings = model.config.max_position_embeddings
+
+             llama3_config.factor = rope_scaling["factor"]
+             llama3_config.high_freq_factor = rope_scaling["high_freq_factor"]
+             llama3_config.low_freq_factor = rope_scaling["low_freq_factor"]
+             llama3_config.original_max_position_embeddings = rope_scaling[
+                 "original_max_position_embeddings"
+             ]
+             return modelexecpb.TransformerConfig(
+                 position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                     kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_LLAMA3,
+                     llama3_rope_config=llama3_config,
+                 ),
+             )
+         # Default rope configs:
+         # 1. Llama-2: https://huggingface.co/NousResearch/Llama-2-7b-hf/blob/main/config.json
+         # 2. Qwen2.5: https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json
+         # 3. Mixtral: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
+         default_rope_config = modelexecpb.DefaultRopeConfig()
+         default_rope_config.theta = model.config.rope_theta
+         default_rope_config.max_position_embeddings = (
+             model.config.max_position_embeddings
+         )
+         default_rope_config.rotary_embedding_dim = rotary_embedding_dim
+         return modelexecpb.TransformerConfig(
+             position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                 kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_DEFAULT,
+                 default_rope_config=default_rope_config,
+             ),
+         )

      def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
          input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
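
Note: _parse_transformer_config above derives the rotary dimension when the HF config does not set it explicitly. For example, with hidden_size=2048, num_attention_heads=32 and partial_rotary_factor=0.5, the head size is 2048 // 32 = 64 and rotary_embedding_dim = int(64 * 0.5) = 32. For the llama3 branch, the rope_scaling block it reads has roughly this shape (the key names match the code above; the values below are only illustrative, see the linked Llama-3.2-1B config.json for real ones):

```python
# Hypothetical "llama3"-style rope_scaling block with the keys read by _parse_transformer_config.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 32.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # -> "llama3"
```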
@@ -179,7 +260,7 @@
          input_tensor_semantics = []

          # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
-         inputs.append(input_dict["input_ids"])
+         inputs.append(torch.tile(input_dict["input_ids"], [self.batch_size, 1]))
          input_tensor_semantics.append(
              TensorSemantics(
                  dimensions=[
@@ -192,7 +273,7 @@
          # Assume that the model supports a KeyValue cache.
          cache_values = self._initialize_key_value_cache()
          inputs.append(cache_values)
-         for _ in cache_values:
+         for _ in range(len(cache_values)):
              input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
              input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())

@@ -209,7 +290,7 @@
              if (
                  not found_logits
                  and len(tensor.shape) == 3
-                 and tensor.shape[0] == 1
+                 and tensor.shape[0] == self.batch_size
                  and tensor.shape[1] == seqlen
              ):
                  # This should be the logits tensor.
@@ -226,14 +307,38 @@
              else:
                  output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())

+         if not found_logits:
+             raise ValueError(
+                 "could not determine output logits tensor for text generation model"
+             )
+
+         num_experts_per_tok = (
+             1
+             if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
+             else self.pipeline.model.config.num_experts_per_tok
+         )
+
          dynamic_shapes = None
-         seqlen = torch.export.Dim("seqlen", min=2, max=9223372036854775096)
+         # Set range to half of seqlen to account for # of tokens per expert.
+         # pytorch export creates a constraint on the number of possible tokens
+         # sent to each expert. That value is num_experts * seqlen. If we don't divide
+         # by number of experts, the tracing creates an integer value that exceeds the valid int64
+         # range and will throw a hard to decipher error message.
+         seqlen = torch.export.Dim(
+             "seqlen", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
+         )

-         cache_length = torch.export.Dim("cache_length", min=2, max=9223372036854775096)
+         cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)
          dynamic_shapes = [
              {1: seqlen},
-             [[{2: cache_length}, {2: cache_length}] for _ in cache_values],
+             [[{2: cache_length}, {2: cache_length}] for _ in range(len(cache_values))],
          ]
+         if self.dynamic_batch:
+             batch = torch.export.Dim("batch")
+             dynamic_shapes[0][0] = batch
+             for i in range(len(cache_values)):
+                 dynamic_shapes[1][i][0][0] = batch
+                 dynamic_shapes[1][i][1][0] = batch


          return {
@@ -241,6 +346,7 @@
              "example_inputs": inputs,
              "input_tensor_semantics": input_tensor_semantics,
              "output_tensor_semantics": output_tensor_semantics,
              "generation_config": HuggingFaceGenerationConfig(self.pipeline.model),
+             "transformer_config": self._parse_transformer_config(self.pipeline.model),
          }

@@ -695,8 +801,8 @@ class HuggingFaceZeroShotObjectDetectionPipeline:

          spec["dynamic_shapes"].extend(
              [
-                 {0: "num_labels"},
-                 {0: "num_labels"},
+                 {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
+                 {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
              ]
          )

@@ -762,33 +868,121 @@ class HuggingFaceDepthEstimationPipeline:
          return [self.model]


- def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
-     if pipeline.framework != "pt":
-         raise ValueError(
-             "unimplemented: hugging face pipeline framework: {}".format(
-                 pipeline.framework
+ class HuggingFaceFeatureExtractionPipeline:
+     def __init__(self, pipeline: Pipeline, name: Optional[str] = None):
+         self.pipeline = pipeline
+         if name is None:
+             name = pipeline.model.name_or_path
+
+         self.tokenizer_model = HuggingFaceTokenizer(self.pipeline.tokenizer)
+
+         self.model = TorchModel(
+             name=name,
+             torch_module=self.pipeline.model,
+             **self._guess_model_spec(),
+         )
+
+     def _guess_model_spec(self) -> Dict:
+         spec = {
+             "example_inputs": [],
+             "input_tensor_semantics": [],
+             "output_tensor_semantics": [],
+             "dynamic_shapes": [],
+         }
+
+         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
+         if "input_ids" not in input_dict:
+             raise ValueError(
+                 'HuggingFaceFeatureExtractionPipeline expects preprocessed inputs to have an "input_ids" tensor'
              )
+
+         spec["example_inputs"].append(input_dict["input_ids"])
+         spec["input_tensor_semantics"].extend(
+             [
+                 TensorSemantics(
+                     dimensions=[
+                         BatchDimension(),
+                         TokensDimension(),
+                     ]
+                 ),
+             ]
+         )
+
+         spec["output_tensor_semantics"].extend(
+             [
+                 TensorSemantics(
+                     dimensions=[
+                         BatchDimension(),
+                         TokensDimension(),
+                         EmbeddingDimension(),
+                     ],
+                 ),
+                 TensorSemantics(
+                     dimensions=[
+                         BatchDimension(),
+                         EmbeddingDimension(),
+                     ],
+                 ),
+             ]
          )

-     if pipeline.task == "text-generation":
-         return HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
-     elif pipeline.task == "image-segmentation":
-         return HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
-     elif pipeline.task == "object-detection":
-         return HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
-     elif pipeline.task == "zero-shot-object-detection":
-         return HuggingFaceZeroShotObjectDetectionPipeline(pipeline, **kwargs).models()
-     elif pipeline.task == "depth-estimation":
-         return HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
-     raise ValueError(
-         "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
-             pipeline.task,
+         max_seqlen = (
+             getattr(self.pipeline.model.config, "max_position_embeddings", 500) - 1
+         )
+         spec["dynamic_shapes"].extend(
              [
-                 "text-generation",
-                 "image-segmentation",
-                 "object-detection",
-                 "zero-shot-object-detection",
-                 "depth-estimation",
-             ],
-         )
-     )
+                 {
+                     1: torch.export.Dim(
+                         "seqlen",
+                         max=max_seqlen,
+                     )
+                 },
+             ]
+         )
+         return spec
+
+     def models(self) -> List[Model]:
+         return [self.model, self.tokenizer_model]
+
+
+ def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
+     with console.status(
+         f'Importing HuggingFace pipeline: "{pipeline.model.name_or_path}"'
+     ):
+         if pipeline.framework != "pt":
+             raise ValueError(
+                 "unimplemented: hugging face pipeline framework: {}".format(
+                     pipeline.framework
+                 )
+             )
+
+         if pipeline.task == "text-generation":
+             result = HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
+         elif pipeline.task == "image-segmentation":
+             result = HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
+         elif pipeline.task == "object-detection":
+             result = HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
+         elif pipeline.task == "zero-shot-object-detection":
+             result = HuggingFaceZeroShotObjectDetectionPipeline(
+                 pipeline, **kwargs
+             ).models()
+         elif pipeline.task == "depth-estimation":
+             result = HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
+         elif pipeline.task == "feature-extraction":
+             result = HuggingFaceFeatureExtractionPipeline(pipeline, **kwargs).models()
+         else:
+             raise ValueError(
+                 "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
+                     pipeline.task,
+                     [
+                         "text-generation",
+                         "image-segmentation",
+                         "object-detection",
+                         "zero-shot-object-detection",
+                         "depth-estimation",
+                         "feature-extraction",
+                     ],
+                 )
+             )
+     console.print(f'Imported HuggingFace pipeline: "{pipeline.model.name_or_path}".')
+     return result
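
Note: with the feature-extraction branch above, an embedding pipeline can now be imported the same way as the other supported tasks. A hedged usage sketch (the checkpoint name is only an example; any PyTorch feature-extraction pipeline should dispatch to HuggingFaceFeatureExtractionPipeline):

```python
from transformers import pipeline

from gml.hf import import_huggingface_pipeline

# Build a HF feature-extraction (embedding) pipeline; the model name here is illustrative.
hf_pipe = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")

# For feature-extraction pipelines this returns [TorchModel, HuggingFaceTokenizer].
models = import_huggingface_pipeline(hf_pipe)
```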