gimlet-api 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gimlet_api-0.0.9.dist-info → gimlet_api-0.0.10.dist-info}/METADATA +4 -2
- {gimlet_api-0.0.9.dist-info → gimlet_api-0.0.10.dist-info}/RECORD +19 -17
- gml/client.py +27 -18
- gml/compile.py +48 -2
- gml/hf.py +232 -38
- gml/model.py +37 -0
- gml/preprocessing.py +17 -3
- gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py +39 -38
- gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py +64 -0
- gml/proto/src/api/corepb/v1/controlplane_pb2.py +35 -9
- gml/proto/src/api/corepb/v1/cp_edge_pb2.py +37 -35
- gml/proto/src/api/corepb/v1/deployed_pipeline_pb2.py +37 -0
- gml/proto/src/api/corepb/v1/device_info_pb2.py +19 -11
- gml/proto/src/api/corepb/v1/gem_config_pb2.py +17 -13
- gml/proto/src/api/corepb/v1/mediastream_pb2.py +42 -39
- gml/proto/src/api/corepb/v1/model_exec_pb2.py +129 -109
- gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py +20 -8
- gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py +25 -23
- {gimlet_api-0.0.9.dist-info → gimlet_api-0.0.10.dist-info}/WHEEL +0 -0
{gimlet_api-0.0.9.dist-info → gimlet_api-0.0.10.dist-info}/METADATA
CHANGED
@@ -9,11 +9,13 @@ Classifier: Typing :: Typed
 Requires-Python: >=3
 Requires-Dist: protobuf
 Requires-Dist: grpcio
-Requires-Dist: torch>=2.
+Requires-Dist: torch>=2.6.0
 Requires-Dist: torch-mlir-gml
 Requires-Dist: numpy<2.0.0
+Requires-Dist: rich
 Requires-Dist: transformers>=4.43.3
+Requires-Dist: tokenizers>=0.21.0
 Requires-Dist: safetensors-mlir
-Version: 0.0.9
+Version: 0.0.10
 
 UNKNOWN
{gimlet_api-0.0.9.dist-info → gimlet_api-0.0.10.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 gml/__init__.py,sha256=H3WQZ_RaN7VNeb__qeHEbKLEwkaG7gpL5FQ8s1IotUA,773
 gml/_utils.py,sha256=mSCWHhCdzcUvHqmJIB2FS215K1LMgJCWcZ6e6FWK3hQ,1184
 gml/asset_manager.py,sha256=VnbqUZHPOgPrAh6ri9C0EuNhS8tAHIrbUyJPAJuD9po,2053
-gml/client.py,sha256=
-gml/compile.py,sha256=
+gml/client.py,sha256=AcnG5mniHOfq-He-uCph2-xQ39cZwmXZePaUEed87b8,14378
+gml/compile.py,sha256=3L5fpD8DK45RLiywj1b5NuDlbsxpzRxI87k1GahlMpc,9851
 gml/device.py,sha256=Iw71NnuLcgjY32ZMXHlnlPkosTuHEmL9E98utmNChlM,2650
-gml/hf.py,sha256=
-gml/model.py,sha256=
+gml/hf.py,sha256=Kv2yffy8omTRQDPnoIZocG2EOyfhr7UvLFIvTmRxw0g,36170
+gml/model.py,sha256=8fIYlLRduTsUZfYJr_YVPNxbEVIzr7_yaaTe4T-TZ2Y,8429
 gml/model_utils.py,sha256=vZvE5cHZIDkUkeZ4Pk4hhV-zOYMiREluv4b8kdqQ3Ig,1375
 gml/pipelines.py,sha256=LKj_lh5I5HzyUUIPG4CImiqBnQPrJsj0CHPKhLiOOGo,8374
-gml/preprocessing.py,sha256=
+gml/preprocessing.py,sha256=YPcxwBOdx0h0ADzoloYbFw9qUGFbi167E8HA4Zwn7Pk,3928
 gml/proto/gogoproto/gogo_pb2.py,sha256=WVMIAR8K--mCUkTPM7mEeeXGpQlRRtt_kco10iP3CZs,15728
 gml/proto/mediapipe/framework/calculator_contract_test_pb2.py,sha256=hNjyZCBz3RYa6rN4xR3FOCZKA24gq_LsJ3EMegl5wK4,2031
 gml/proto/mediapipe/framework/calculator_options_pb2.py,sha256=Nq1BQRtLdsIgfkw7ymD3eg2p2_RSlZhiHS7YbDhNHR0,1563
@@ -23,29 +23,31 @@ gml/proto/mediapipe/framework/stream_handler_pb2.py,sha256=kNo-2Fdua_CeyJInI3q5r
 gml/proto/mediapipe/framework/test_calculators_pb2.py,sha256=tXF25VpGtHGArffRqFmjD6FO7xmuCPd5j9UYON2SVSM,2230
 gml/proto/mediapipe/framework/thread_pool_executor_pb2.py,sha256=9TJ66fqSo1BiJmEAQesK0fnVe55zcJpOqVip6HotgyE,2345
 gml/proto/opentelemetry/proto/common/v1/common_pb2.py,sha256=wQjeDti-C8JiNwRn-z5M5p-Fqxm-SmnbPaoitJcSK-4,2860
-gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=
+gml/proto/opentelemetry/proto/metrics/v1/metrics_pb2.py,sha256=k8oW5tmFlJK2574Ky6kDc0JmNNQCLroRwCCGyxDd7JA,9968
 gml/proto/opentelemetry/proto/resource/v1/resource_pb2.py,sha256=cbNmE12Nm3PjW4NXU7-Z-9m_0Zs3Ab8R1xLkDnvclCg,1730
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
-gml/proto/src/api/corepb/v1/
+gml/proto/src/api/corepb/v1/compiled_pipeline_pb2.py,sha256=g3MxBqshtwaM9_Nrbvwo995_XWq-maXGP6mDeiEzZKo,7529
+gml/proto/src/api/corepb/v1/controlplane_pb2.py,sha256=DylHEVXr36Deh5p-WK8aRwQF-uGW5mJ2mo8pJ3qg7KA,13213
+gml/proto/src/api/corepb/v1/cp_edge_pb2.py,sha256=H0WgAgv6-qaf7wnnKALmSBpD_czmUNHNYpsnE3Tmcrs,14988
+gml/proto/src/api/corepb/v1/deployed_pipeline_pb2.py,sha256=cZjoJuZ3fpCiw2Ox7bcHCXYqRTebb08n-aodwjE-xKI,3053
+gml/proto/src/api/corepb/v1/device_info_pb2.py,sha256=pTZGPjfglje-Wu_-R4qiwPtewXNJIGq5Kedme9SHiaU,6713
+gml/proto/src/api/corepb/v1/gem_config_pb2.py,sha256=vC0g3k9hDv-LhiV6LwaYCly6x00Xx_YA0i2AZSwCo_I,5396
+gml/proto/src/api/corepb/v1/mediastream_pb2.py,sha256=mgi5-prV7Lz0XJ2wo04jGLSvbnDGtdmduSv_6d6I9oA,8368
+gml/proto/src/api/corepb/v1/model_exec_pb2.py,sha256=_TXJvHSxkX1Il6xEVEiFIfei_ZV4KhdL3cSKaMgIYIw,33548
 gml/proto/src/common/typespb/jwt_pb2.py,sha256=lxy-bqbyg96i9n_xr2JbkuWX-ldnoJavXPMnApzVSio,5580
 gml/proto/src/common/typespb/status_pb2.py,sha256=IbBJnbsAlvsuTtyT285ZuW6k5VaPfl5kRSOnBxD_H8M,2109
 gml/proto/src/common/typespb/uuid_pb2.py,sha256=5Fm3jYpCPX7sMrP6RhRYsF0SnuZNIBEQJk9f0jwZ2Rw,1188
-gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=
+gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2.py,sha256=4mp1QWV7FOzF_nC3RDKZ9vTA-ezMhukcjBEt1lcjGmM,4933
 gml/proto/src/controlplane/compiler/cpb/v1/cpb_pb2_grpc.py,sha256=l-gTK9nYpTlVb7QGAckSQXlHhkRdKe2-nrxXc8NQavY,2912
 gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2.py,sha256=KgoUT8ccF-yJPe1r4otQjAPQoKBaQzdBlHoIUSkk0yE,11445
 gml/proto/src/controlplane/directory/directorypb/v1/directory_pb2_grpc.py,sha256=p3OpT8-hfNHu4-29qr-ZahRwO-LoCYM9Q4jomAHTXGA,24572
 gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2.py,sha256=r8mbJNTq45_c0amPnTr8OFZasCk7XWu2YS_eu7GfWJg,7050
 gml/proto/src/controlplane/filetransfer/ftpb/v1/ftpb_pb2_grpc.py,sha256=XlE4R2PJaOmzQocx7y6SKJvuqt8tYBGzBuhajvzG0cc,12919
-gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=
+gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2.py,sha256=2s2p6dURKJLboaR965m2-rGTo_63Bi1cXsA90Hz9u-M,6632
 gml/proto/src/controlplane/logicalpipeline/lppb/v1/lppb_pb2_grpc.py,sha256=-snjW7n6JveUzJVPFcm25XlL19kowPSKgd61l_jPnHA,9541
 gml/proto/src/controlplane/model/mpb/v1/mpb_pb2.py,sha256=RVedXkNYu2iF5OHiXoYyRw9AGRCUWG7qNyY-5QY71Go,3762
 gml/proto/src/controlplane/model/mpb/v1/mpb_pb2_grpc.py,sha256=KSdb6V04qUHDsb1R2o3wixwTyZgrhwnPYobjnRgWX4I,4735
 gml/register_submodules.py,sha256=U8IwjVygX2vxNi_aK6ljHOD4mmrOhbyVczvy4wwulqU,5027
 gml/tensor.py,sha256=aPLm3I3qkYNDcJmntaUycqqN5rsZmcj8ql0EkupJudY,14977
-gimlet_api-0.0.
-gimlet_api-0.0.
-gimlet_api-0.0.
+gimlet_api-0.0.10.dist-info/WHEEL,sha256=sobxWSyDDkdg_rinUth-jxhXHqoNqlmNMJY3aTZn2Us,91
+gimlet_api-0.0.10.dist-info/METADATA,sha256=i3n2dnjznNFL6XFsj1bL0T544E0FmMVQySLgiBkUW04,586
+gimlet_api-0.0.10.dist-info/RECORD,,
gml/client.py
CHANGED
@@ -18,8 +18,12 @@ import os
 import uuid
 from pathlib import Path
 from typing import BinaryIO, List, Optional, TextIO, Union
+from urllib.parse import quote
 
 import grpc
+from rich.progress import (
+    Console,
+)
 
 import gml.proto.src.api.corepb.v1.model_exec_pb2 as modelexecpb
 import gml.proto.src.common.typespb.uuid_pb2 as uuidpb
@@ -39,6 +43,7 @@ from gml.model import Model
 from gml.pipelines import Pipeline
 
 DEFAULT_CONTROLPLANE_ADDR = "app.gimletlabs.ai"
+console = Console()
 
 
 class _ChannelFactory:
@@ -282,31 +287,28 @@ class Client:
     def create_model(self, model: Model) -> modelexecpb.Model:
         existing_model = self._get_model_if_exists(model.name)
         if existing_model is not None:
-            print(
-                'warning: model "{}" already exists and will not be uploaded.'
-                    model.name
-            )
+            console.print(
+                f'[yellow]warning:[/yellow] model "{model.name}" already exists and will not be uploaded.'
             )
             return existing_model
-
         model_info = model.to_proto()
-        with
-
-
-            file
-
-            sha256 = sha256sum(file)
+        with console.status(f'Creating model "{model.name}"...'):
+            with model.collect_assets() as model_assets:
+                for asset_name, file in model_assets.items():
+                    if isinstance(file, Path) or isinstance(file, str):
+                        file = open(file, "rb")
 
-
-            if asset_name:
-                upload_name += ":" + asset_name
-            print(f"Uploading {upload_name}...")
+                    sha256 = sha256sum(file)
 
-
+                    upload_name = model.name
+                    if asset_name:
+                        upload_name += ":" + asset_name
+                    file_info = self._upload_file_if_not_exists(sha256, file, sha256)
+                    console.print(f"Uploaded {upload_name}.")
 
-
+                    model_info.file_assets[asset_name].MergeFrom(file_info.file_id)
 
-
+                    file.close()
 
         return self._create_model(model_info)
 
@@ -331,6 +333,8 @@ class Client:
         else:
             raise ValueError("must specify one of 'pipeline_file' or 'pipeline'")
 
+        console.print(f'Uploading pipeline "{name}" to {self._org_name}...')
+
         for model in models:
             self.create_model(model)
 
@@ -343,6 +347,11 @@ class Client:
         resp: lppb.CreateLogicalPipelineResponse = stub.CreateLogicalPipeline(
             req, metadata=self._get_request_metadata(idempotent=True)
         )
+
+        url = f"https://{os.getenv('GML_CONTROLPLANE_ADDR')}/orgs/{quote(self._org_name)}/pipelines/{quote(name)}"
+        console.print(
+            f"[green]Pipeline upload complete![/green]\nView your pipeline at: [cyan]{url}[/cyan]"
+        )
         return resp.id
 
     def check_compile(
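The client.py changes replace plain `print` output with a module-level rich `Console`: model creation runs inside a `console.status(...)` spinner and the final pipeline message uses console markup. The sketch below is not part of the package; it only illustrates the rich calls the new code relies on, with a hypothetical `upload_model` helper standing in for the client logic (the package imports `Console` via `rich.progress`, while `rich.console` is the canonical path used here).

```python
# Illustrative only: the rich Console APIs used by the new client.py code.
from rich.console import Console

console = Console()


def upload_model(name: str) -> None:
    # status() shows a spinner while the body runs, as create_model now does.
    with console.status(f'Creating model "{name}"...'):
        console.print(f"Uploaded {name}.")
    # Markup tags style the final message, mirroring the pipeline URL output.
    console.print(f"[green]Upload complete![/green] Model: [cyan]{name}[/cyan]")


if __name__ == "__main__":
    upload_model("example-model")
```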
gml/compile.py
CHANGED
@@ -16,10 +16,11 @@
 
 import contextlib
 import functools
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
 
 import safetensors_mlir
 import torch
+import torch.utils._pytree
 import torch_mlir
 from mlir.ir import (
     BF16Type,
@@ -28,6 +29,7 @@ from mlir.ir import (
     F16Type,
     F32Type,
     F64Type,
+    Float8E4M3FNType,
     IntegerType,
     Operation,
     RankedTensorType,
@@ -40,6 +42,7 @@ from torch_mlir.dialects import torch as torch_d
 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks, InputInfo
 from torch_mlir.fx import export_and_import
+from transformers import DynamicCache
 
 from gml.asset_manager import AssetManager
 from gml.register_submodules import submodule_registration_workarounds
@@ -53,6 +56,45 @@ def _default_decomposition_denylist():
     ]
 
 
+_registered_dynamic_cache_pytree_node = False
+
+
+def register_dynamic_cache_pytree_node():
+    """
+    Registers flattening/unflattening for transformers.DynamicCache
+    Pytree is a representation of tensor collections used inside torch.export.
+    """
+
+    global _registered_dynamic_cache_pytree_node
+    if _registered_dynamic_cache_pytree_node:
+        return
+    _registered_dynamic_cache_pytree_node = True
+
+    def flatten_cache_with_keys(dynamic_cache: DynamicCache):
+        return [
+            (
+                torch.utils._pytree.MappingKey(i),
+                list(value),
+            )
+            for i, value in enumerate(dynamic_cache.to_legacy_cache())
+        ], None
+
+    def flatten_cache(dynamic_cache: DynamicCache):
+        flattened, ctx = flatten_cache_with_keys(dynamic_cache)
+        return [v for _, v in flattened], ctx
+
+    def unflatten_cache(flattened: Iterable[Any], context: Any):
+        return DynamicCache.from_legacy_cache(flattened)
+
+    torch.utils._pytree.register_pytree_node(
+        DynamicCache,
+        flatten_cache,
+        unflatten_cache,
+        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+        flatten_with_keys_fn=flatten_cache_with_keys,
+    )
+
+
 @contextlib.contextmanager
 def _patch_aot_export_module():
     """This contextmanager prevents PyTorch dispatch from running when calling aot_export_module.
@@ -91,6 +133,8 @@ _torch_dtype_to_builtin_element_type = {
     torch.complex32: lambda: ComplexType.get(F16Type.get()),
     torch.complex64: lambda: ComplexType.get(F32Type.get()),
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
+    # Quantized types.
+    torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(),
 }
 
 
@@ -179,6 +223,7 @@ def to_torch_mlir(
     ] = None,
     decomposition_denylist: Optional[List[torch._ops.OperatorBase]] = None,
     weight_manager: Optional[AssetManager] = None,
+    export_predispatch: bool = False,
 ):
     if dynamic_shapes is not None:
         for shape in dynamic_shapes:
@@ -205,10 +250,11 @@ def to_torch_mlir(
         # Ignore errors running the model. This can happen when the model has data dependent branches.
         pass
 
+    register_dynamic_cache_pytree_node()
     prog = _export(
         model,
         tuple(example_inputs),
-        pre_dispatch=
+        pre_dispatch=export_predispatch,
         strict=False,
         dynamic_shapes=dynamic_shapes,
     )
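compile.py now registers `transformers.DynamicCache` as a pytree node so that `torch.export` can flatten the KV cache into plain tensors. The sketch below is not gimlet-api code; it shows the same registration mechanism on a made-up `KVPair` container, using `torch.utils._pytree.register_pytree_node` as the diff does (a private but long-standing PyTorch API).

```python
# Illustrative only: registering a custom container as a pytree node,
# analogous to register_dynamic_cache_pytree_node for DynamicCache.
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree


class KVPair:
    """Hypothetical container holding one layer's key/value tensors."""

    def __init__(self, key: torch.Tensor, value: torch.Tensor):
        self.key = key
        self.value = value


def _flatten(pair: KVPair) -> Tuple[List[torch.Tensor], Any]:
    # Return the tensor children plus an (unused) context object.
    return [pair.key, pair.value], None


def _unflatten(children: List[torch.Tensor], context: Any) -> KVPair:
    key, value = children
    return KVPair(key, value)


# After registration, tree utilities (and torch.export tracing) treat KVPair
# like a tuple of tensors instead of an opaque object.
pytree.register_pytree_node(
    KVPair,
    _flatten,
    _unflatten,
    serialized_type_name=f"{KVPair.__module__}.{KVPair.__name__}",
)

leaves, spec = pytree.tree_flatten(KVPair(torch.zeros(2, 4), torch.ones(2, 4)))
rebuilt = pytree.tree_unflatten(leaves, spec)
assert torch.equal(rebuilt.value, torch.ones(2, 4))
```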
gml/hf.py
CHANGED
@@ -24,8 +24,10 @@ from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Tuple
 
 import torch
 import transformers
+from rich.progress import Console
 from transformers import (
     BaseImageProcessor,
+    DynamicCache,
     Pipeline,
     PreTrainedModel,
     PreTrainedTokenizer,
@@ -49,6 +51,7 @@ from gml.tensor import (
     DetectionNumCandidatesDimension,
     DetectionOutputDimension,
     DimensionSemantics,
+    EmbeddingDimension,
     ImageChannelDimension,
     ImageHeightDimension,
     ImageWidthDimension,
@@ -60,6 +63,11 @@ from gml.tensor import (
 
 FALLBACK_RESIZE_SIZE = 512
 
+# Set dynamic dimension max size to less than the int64 max, leaving leeway for the size to be ~4x by the model.
+MAX_DYNAMIC_VAL = 2**61
+
+console = Console()
+
 
 class HuggingFaceTokenizer(Model):
     def __init__(self, tokenizer: PreTrainedTokenizer, name: Optional[str] = None):
@@ -105,7 +113,6 @@ def flatten(items):
 
 
 class WrapWithFunctionalCache(torch.nn.Module):
-
     def __init__(self, model: transformers.PreTrainedModel):
         super().__init__()
         self.model = model
@@ -128,6 +135,8 @@ class HuggingFaceTextGenerationPipeline:
         name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
         dynamic_seqlen: bool = False,
+        dynamic_batch: bool = False,
+        export_predispatch: bool = False,
     ):
         self.pipeline = pipeline
         self.tokenizer_model = HuggingFaceTokenizer(pipeline.tokenizer, tokenizer_name)
@@ -139,13 +148,20 @@ class HuggingFaceTextGenerationPipeline:
             self.model = self.model.to(torch.float16)
         self.model = WrapWithFunctionalCache(pipeline.model)
 
+        self.dynamic_batch = dynamic_batch
+        self.batch_size = 1
+        if self.dynamic_batch:
+            # dynamic tracing fails for dimensions of size 1.
+            self.batch_size = 2
+
         self.language_model = TorchModel(
             name,
             torch_module=self.model,
+            export_predispatch=export_predispatch,
             **self._guess_model_spec(dynamic_seqlen),
         )
 
-    def _initialize_key_value_cache(self):
+    def _initialize_key_value_cache(self) -> DynamicCache:
         cache = []
         config = self.pipeline.model.config
         head_dim = (
@@ -158,7 +174,12 @@ class HuggingFaceTextGenerationPipeline:
             if config.num_key_value_heads is None
             else config.num_key_value_heads
        )
-        cache_shape = (
+        cache_shape = (
+            self.batch_size,
+            num_key_value_heads,
+            self._cache_length_for_tracing,
+            head_dim,
+        )
         for _ in range(config.num_hidden_layers):
             cache.append(
                 [
@@ -166,7 +187,67 @@ class HuggingFaceTextGenerationPipeline:
                     torch.zeros(cache_shape).to(torch.float16),
                 ]
             )
-        return cache
+        return DynamicCache.from_legacy_cache(cache)
+
+    def _parse_transformer_config(
+        self, model: transformers.PreTrainedModel
+    ) -> modelexecpb.TransformerConfig:
+        # Only non-default rope config set the rope_scaling parameter
+        attention_head_size = getattr(
+            model.config,
+            "attention_head_size",
+            model.config.hidden_size // model.config.num_attention_heads,
+        )
+        partial_rotary_factor = getattr(model.config, "partial_rotary_factor", 1.0)
+        rotary_embedding_dim = getattr(
+            model.config,
+            "rotary_dim",
+            int(attention_head_size * partial_rotary_factor),
+        )
+        if (
+            hasattr(model.config, "rope_scaling")
+            and model.config.rope_scaling is not None
+        ):
+            rope_scaling = model.config.rope_scaling
+            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+            if not rope_type == "llama3":
+                raise NotImplementedError(
+                    "rope scaling type {} is not supported".format(rope_type)
+                )
+            # LLAMA 3 example config: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
+            llama3_config = modelexecpb.Llama3RopeConfig()
+            llama3_config.theta = model.config.rope_theta
+            llama3_config.rotary_embedding_dim = rotary_embedding_dim
+            llama3_config.max_position_embeddings = model.config.max_position_embeddings
+
+            llama3_config.factor = rope_scaling["factor"]
+            llama3_config.high_freq_factor = rope_scaling["high_freq_factor"]
+            llama3_config.low_freq_factor = rope_scaling["low_freq_factor"]
+            llama3_config.original_max_position_embeddings = rope_scaling[
+                "original_max_position_embeddings"
+            ]
+            return modelexecpb.TransformerConfig(
+                position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                    kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_LLAMA3,
+                    llama3_rope_config=llama3_config,
+                ),
+            )
+        # Default rope configs:
+        # 1. Llama-2: https://huggingface.co/NousResearch/Llama-2-7b-hf/blob/main/config.json
+        # 2. Qwen2.5: https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json
+        # 3. Mixtral: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
+        default_rope_config = modelexecpb.DefaultRopeConfig()
+        default_rope_config.theta = model.config.rope_theta
+        default_rope_config.max_position_embeddings = (
+            model.config.max_position_embeddings
+        )
+        default_rope_config.rotary_embedding_dim = rotary_embedding_dim
+        return modelexecpb.TransformerConfig(
+            position_embedding_config=modelexecpb.PositionEmbeddingConfig(
+                kind=modelexecpb.PositionEmbeddingKind.POSITION_EMBEDDING_KIND_ROPE_DEFAULT,
+                default_rope_config=default_rope_config,
+            ),
+        )
 
     def _guess_model_spec(self, dynamic_seqlen: bool) -> Dict:
         input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
@@ -179,7 +260,7 @@ class HuggingFaceTextGenerationPipeline:
         input_tensor_semantics = []
 
         # This currently assumes that all HF language models have inputs that are [B, NUM_TOKENS].
-        inputs.append(input_dict["input_ids"])
+        inputs.append(torch.tile(input_dict["input_ids"], [self.batch_size, 1]))
         input_tensor_semantics.append(
             TensorSemantics(
                 dimensions=[
@@ -192,7 +273,7 @@ class HuggingFaceTextGenerationPipeline:
         # Assume that the model supports a KeyValue cache.
         cache_values = self._initialize_key_value_cache()
         inputs.append(cache_values)
-        for _ in cache_values:
+        for _ in range(len(cache_values)):
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
             input_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
 
@@ -209,7 +290,7 @@ class HuggingFaceTextGenerationPipeline:
             if (
                 not found_logits
                 and len(tensor.shape) == 3
-                and tensor.shape[0] ==
+                and tensor.shape[0] == self.batch_size
                 and tensor.shape[1] == seqlen
             ):
                 # This should be the logits tensor.
@@ -226,14 +307,38 @@ class HuggingFaceTextGenerationPipeline:
             else:
                 output_tensor_semantics.append(AttentionKeyValueCacheTensorSemantics())
 
+        if not found_logits:
+            raise ValueError(
+                "could not determine output logits tensor for text generation model"
+            )
+
+        num_experts_per_tok = (
+            1
+            if not hasattr(self.pipeline.model.config, "num_experts_per_tok")
+            else self.pipeline.model.config.num_experts_per_tok
+        )
+
         dynamic_shapes = None
-
+        # Set range to half of seqlen to account for # of tokens per expert.
+        # pytorch export creates a constraint on the number of possible tokens
+        # sent to each expert. That value is num_experts * seqlen. If we don't divide
+        # by number of experts, the tracing creates an integer value that exceeds the valid int64
+        # range and will throw a hard to decipher error message.
+        seqlen = torch.export.Dim(
+            "seqlen", min=2, max=MAX_DYNAMIC_VAL // num_experts_per_tok
+        )
 
-        cache_length = torch.export.Dim("cache_length", min=2, max=
+        cache_length = torch.export.Dim("cache_length", min=2, max=MAX_DYNAMIC_VAL)
         dynamic_shapes = [
             {1: seqlen},
-            [[{2: cache_length}, {2: cache_length}] for _ in cache_values],
+            [[{2: cache_length}, {2: cache_length}] for _ in range(len(cache_values))],
         ]
+        if self.dynamic_batch:
+            batch = torch.export.Dim("batch")
+            dynamic_shapes[0][0] = batch
+            for i in range(len(cache_values)):
+                dynamic_shapes[1][i][0][0] = batch
+                dynamic_shapes[1][i][1][0] = batch
 
         return {
             "example_inputs": inputs,
@@ -241,6 +346,7 @@ class HuggingFaceTextGenerationPipeline:
             "input_tensor_semantics": input_tensor_semantics,
             "output_tensor_semantics": output_tensor_semantics,
             "generation_config": HuggingFaceGenerationConfig(self.pipeline.model),
+            "transformer_config": self._parse_transformer_config(self.pipeline.model),
         }
 
     def models(self) -> List[Model]:
@@ -695,8 +801,8 @@ class HuggingFaceZeroShotObjectDetectionPipeline:
 
         spec["dynamic_shapes"].extend(
             [
-                {0: "num_labels"},
-                {0: "num_labels"},
+                {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
+                {0: torch.export.Dim("num_labels", max=MAX_DYNAMIC_VAL)},
             ]
         )
 
@@ -762,33 +868,121 @@ class HuggingFaceDepthEstimationPipeline:
         return [self.model]
 
 
-
-
-
-
-
+class HuggingFaceFeatureExtractionPipeline:
+    def __init__(self, pipeline: Pipeline, name: Optional[str] = None):
+        self.pipeline = pipeline
+        if name is None:
+            name = pipeline.model.name_or_path
+
+        self.tokenizer_model = HuggingFaceTokenizer(self.pipeline.tokenizer)
+
+        self.model = TorchModel(
+            name=name,
+            torch_module=self.pipeline.model,
+            **self._guess_model_spec(),
+        )
+
+    def _guess_model_spec(self) -> Dict:
+        spec = {
+            "example_inputs": [],
+            "input_tensor_semantics": [],
+            "output_tensor_semantics": [],
+            "dynamic_shapes": [],
+        }
+
+        input_dict = self.pipeline.preprocess("this is a prompt! Test test test?")
+        if "input_ids" not in input_dict:
+            raise ValueError(
+                'HuggingFaceFeatureExtractionPipeline expects preprocessed inputs to have an "input_ids" tensor'
             )
+
+        spec["example_inputs"].append(input_dict["input_ids"])
+        spec["input_tensor_semantics"].extend(
+            [
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        TokensDimension(),
+                    ]
+                ),
+            ]
+        )
+
+        spec["output_tensor_semantics"].extend(
+            [
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        TokensDimension(),
+                        EmbeddingDimension(),
+                    ],
+                ),
+                TensorSemantics(
+                    dimensions=[
+                        BatchDimension(),
+                        EmbeddingDimension(),
+                    ],
+                ),
+            ]
         )
 
-
-
-
-
-    elif pipeline.task == "object-detection":
-        return HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "zero-shot-object-detection":
-        return HuggingFaceZeroShotObjectDetectionPipeline(pipeline, **kwargs).models()
-    elif pipeline.task == "depth-estimation":
-        return HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
-    raise ValueError(
-        "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
-            pipeline.task,
+        max_seqlen = (
+            getattr(self.pipeline.model.config, "max_position_embeddings", 500) - 1
+        )
+        spec["dynamic_shapes"].extend(
             [
-
-
-
-
-
-
-
-
+                {
+                    1: torch.export.Dim(
+                        "seqlen",
+                        max=max_seqlen,
+                    )
+                },
+            ]
+        )
+        return spec
+
+    def models(self) -> List[Model]:
+        return [self.model, self.tokenizer_model]
+
+
+def import_huggingface_pipeline(pipeline: Pipeline, **kwargs) -> List[Model]:
+    with console.status(
+        f'Importing HuggingFace pipeline: "{pipeline.model.name_or_path}"'
+    ):
+        if pipeline.framework != "pt":
+            raise ValueError(
+                "unimplemented: hugging face pipeline framework: {}".format(
+                    pipeline.framework
+                )
+            )
+
+        if pipeline.task == "text-generation":
+            result = HuggingFaceTextGenerationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "image-segmentation":
+            result = HuggingFaceImageSegmentationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "object-detection":
+            result = HuggingFaceObjectDetectionPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "zero-shot-object-detection":
+            result = HuggingFaceZeroShotObjectDetectionPipeline(
+                pipeline, **kwargs
+            ).models()
+        elif pipeline.task == "depth-estimation":
+            result = HuggingFaceDepthEstimationPipeline(pipeline, **kwargs).models()
+        elif pipeline.task == "feature-extraction":
+            result = HuggingFaceFeatureExtractionPipeline(pipeline, **kwargs).models()
+        else:
+            raise ValueError(
+                "unimplemented: hugging face pipeline task: {} (supported tasks: [{}])".format(
                    pipeline.task,
+                    [
+                        "text-generation",
+                        "image-segmentation",
+                        "object-detection",
+                        "zero-shot-object-detection",
+                        "depth-estimation",
+                        "feature-extraction",
+                    ],
+                )
+            )
+    console.print(f'Imported HuggingFace pipeline: "{pipeline.model.name_or_path}".')
+    return result