truss 0.10.9rc601__py3-none-any.whl → 0.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/base/constants.py +0 -1
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
- truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
- truss/cli/train_commands.py +11 -3
- truss/contexts/image_builder/cache_warmer.py +1 -3
- truss/contexts/image_builder/serving_image_builder.py +24 -32
- truss/remote/baseten/api.py +11 -0
- truss/remote/baseten/core.py +209 -1
- truss/remote/baseten/utils/time.py +15 -0
- truss/templates/server/model_wrapper.py +0 -12
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server/truss_server.py +0 -13
- truss/templates/server.Dockerfile.jinja +1 -1
- truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
- truss/tests/remote/baseten/conftest.py +18 -0
- truss/tests/remote/baseten/test_api.py +49 -14
- truss/tests/remote/baseten/test_core.py +517 -1
- truss/tests/test_data/test_openai/model/model.py +0 -3
- truss/truss_handle/truss_handle.py +0 -1
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/METADATA +2 -2
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/RECORD +30 -28
- truss_train/definitions.py +6 -0
- truss_train/deployment.py +15 -2
- truss/tests/util/test_basetenpointer.py +0 -227
- truss/util/basetenpointer.py +0 -160
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/WHEEL +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/entry_points.txt +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@ truss/__init__.py,sha256=CoUcP6vx_pocyemRmpbCPlndkHhdMkABAlr0ZXVuPCk,1163
|
|
|
2
2
|
truss/api/__init__.py,sha256=spBAa_m1pItiid97iDLKPmumgAkSirPkv-E8RWMZyOk,5090
|
|
3
3
|
truss/api/definitions.py,sha256=QAaIBqL59Q-R7HtLcXcoeCIWBN2HqOzApdFX0PpCq2s,1604
|
|
4
4
|
truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
-
truss/base/constants.py,sha256=
|
|
5
|
+
truss/base/constants.py,sha256=qwNNkd9EOAuiTxYLVccJaiPCNRayBAFvyj_GisYOT3I,3488
|
|
6
6
|
truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
|
|
7
7
|
truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
8
8
|
truss/base/trt_llm_config.py,sha256=CRz3AqGDAyv8YpcBWXUrnfjvNAauyo3yf8ZOGVsSt6g,32782
|
|
@@ -11,7 +11,7 @@ truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
|
11
11
|
truss/cli/chains_commands.py,sha256=y6pdIAGCcKOPG9bPuCXPfSA0onQm5x-tT_3blSBfPYg,16971
|
|
12
12
|
truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
|
|
13
13
|
truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
|
|
14
|
-
truss/cli/train_commands.py,sha256=
|
|
14
|
+
truss/cli/train_commands.py,sha256=P9bdnpq1SgEGXBaVf9joKdsaCDX2v29P4MhLMuz-jYw,12344
|
|
15
15
|
truss/cli/logs/base_watcher.py,sha256=KKyd7lIrdaEeDVt8EtjMioSPGVpLyOcF0ewyzE_GGdQ,2785
|
|
16
16
|
truss/cli/logs/model_log_watcher.py,sha256=NACcP-wkcaroYa2Cb9BZC7Yr0554WZa_FSM2LXOf4A8,1263
|
|
17
17
|
truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
|
|
@@ -19,22 +19,24 @@ truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
|
|
|
19
19
|
truss/cli/train/common.py,sha256=Es1yllSYxjM9x2uBzTGbYwyd8ML66cqqge0XO8_G_X0,992
|
|
20
20
|
truss/cli/train/core.py,sha256=MBOhPSVYOU7wVh09uWQrJDEVOhJQug_2Odv3u6tCVTA,13855
|
|
21
21
|
truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
|
|
22
|
+
truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
|
|
22
23
|
truss/cli/train/metrics_watcher.py,sha256=ftrLQ5m7V1lAqcAvdGbMv5r0aF4D0lypfKjokCBQvLw,12798
|
|
23
24
|
truss/cli/train/poller.py,sha256=TGRzELxsicga0bEXewSX1ujw6lfPmDnHd6nr8zvOFO8,3550
|
|
24
25
|
truss/cli/train/types.py,sha256=alGtr4Q71GeB65PpGMhsoKygw4k_ncR6MKIP1ioP8rI,951
|
|
25
26
|
truss/cli/train/deploy_checkpoints/__init__.py,sha256=wL-M2yu8PxO2tFvjwshXAfPnB-5TlvsBp2v_bdzimRU,99
|
|
26
|
-
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=
|
|
27
|
-
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=
|
|
28
|
-
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=
|
|
27
|
+
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=xfblHi3py7GDgY24NcuAaDKzcQeOm67rjtWOK6vAEe4,17352
|
|
28
|
+
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=6x5nS_HnWYtS9vi-Pg8akzrJk9L_agjvFhm5EFh1m6Y,1964
|
|
29
|
+
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=FYRG5KTMlxEMZS-RA_m2gp1wuqWbSpqt2RhdQfLibhA,3968
|
|
29
30
|
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=P91dIAzuhl2GlzmrWwCcYI7uCMT1Lm7C79JQHM_exN4,4442
|
|
31
|
+
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py,sha256=NSo2kEn-CxawGvUhn-xE81noxoTJ0cCv8X9fkrMxAsM,2617
|
|
30
32
|
truss/cli/utils/common.py,sha256=aWnla4qMSEz57dRMTl7R-EaScsuEpnQUeziGUaIeqeU,6149
|
|
31
33
|
truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
|
|
32
34
|
truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
|
|
33
35
|
truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
|
|
34
36
|
truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
|
|
35
|
-
truss/contexts/image_builder/cache_warmer.py,sha256=
|
|
37
|
+
truss/contexts/image_builder/cache_warmer.py,sha256=TGMV1Mh87n2e_dSowH0sf0rZhZraDOR-LVapZL3a5r8,7377
|
|
36
38
|
truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
|
|
37
|
-
truss/contexts/image_builder/serving_image_builder.py,sha256=
|
|
39
|
+
truss/contexts/image_builder/serving_image_builder.py,sha256=FH5HPnrr9_OomN5WplsyUrGGETe9ld6h3q9JCpvB6FY,33322
|
|
38
40
|
truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
|
|
39
41
|
truss/contexts/local_loader/docker_build_emulator.py,sha256=rmf7I28zksSmHjwvJMx2rIa6xK4KeR5fBm5YFth_fQg,2464
|
|
40
42
|
truss/contexts/local_loader/dockerfile_parser.py,sha256=GoRJ0Af_3ILyLhjovK5lrCGn1rMxz6W3l681ro17ZzI,1344
|
|
@@ -50,9 +52,9 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
50
52
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
51
53
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
52
54
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
53
|
-
truss/remote/baseten/api.py,sha256=
|
|
55
|
+
truss/remote/baseten/api.py,sha256=6Nie4hv4z5I62boeCQvP3tGA0Pwu96bMgz1vp5Tkxao,24447
|
|
54
56
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
55
|
-
truss/remote/baseten/core.py,sha256=
|
|
57
|
+
truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
|
|
56
58
|
truss/remote/baseten/custom_types.py,sha256=g7MwgYaeqIxF-e170G5iEVLWiw5jgAnqXztIUqVkdyc,3227
|
|
57
59
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
58
60
|
truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
|
|
@@ -60,6 +62,7 @@ truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxW
|
|
|
60
62
|
truss/remote/baseten/service.py,sha256=j_cCbSkpvCqIoptCDp65BCRKORPrp9NswTdI3BbiFqU,6081
|
|
61
63
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
62
64
|
truss/remote/baseten/utils/tar.py,sha256=pMUv--YkwXDngUx1WUOK-KmAIKMcOg2E-CD5x4heh3s,2514
|
|
65
|
+
truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzlNeGxI,397
|
|
63
66
|
truss/remote/baseten/utils/transfer.py,sha256=d3VptuQb6M1nyS6kz0BAfeOYDLkMKUjatJXpY-mp-As,1548
|
|
64
67
|
truss/templates/README.md.jinja,sha256=N7CJdyldZuJamj5jLh47le0hFBdu9irVsTBqoxhPNPQ,2476
|
|
65
68
|
truss/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -68,7 +71,7 @@ truss/templates/cache.Dockerfile.jinja,sha256=LhsVP9F3BATKQGkgya_YT4v6ABTUkpy-Jb
|
|
|
68
71
|
truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj7wGs61nUs,87
|
|
69
72
|
truss/templates/copy_cache_files.Dockerfile.jinja,sha256=arHldnuclt7vUFHyRz6vus5NGMDkIofm-1RU37A0xZM,98
|
|
70
73
|
truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
|
|
71
|
-
truss/templates/server.Dockerfile.jinja,sha256=
|
|
74
|
+
truss/templates/server.Dockerfile.jinja,sha256=Ts4kty2ZXTJS69XkNHTNHtEyr8yf8VwNQgBBLY89chk,5996
|
|
72
75
|
truss/templates/control/requirements.txt,sha256=Kk0tYID7trPk5gwX38Wrt2-YGWZAXFJCJRcqJ8ZzCjc,251
|
|
73
76
|
truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
|
|
74
77
|
truss/templates/control/control/endpoints.py,sha256=FM-sgao7I3gMoUTasM3Xq_g2LDoJQe75JxIoaQxzeNo,10031
|
|
@@ -92,9 +95,9 @@ truss/templates/docker_server/proxy.conf.jinja,sha256=Lg-PcZzKflG85exZKHNgW_I6r0
|
|
|
92
95
|
truss/templates/docker_server/supervisord.conf.jinja,sha256=CoaSLv0Lr8t1tS_q102IFufNX2lWrlbCHJLjMhYjOwM,1711
|
|
93
96
|
truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
94
97
|
truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
|
|
95
|
-
truss/templates/server/model_wrapper.py,sha256=
|
|
96
|
-
truss/templates/server/requirements.txt,sha256=
|
|
97
|
-
truss/templates/server/truss_server.py,sha256=
|
|
98
|
+
truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
|
|
99
|
+
truss/templates/server/requirements.txt,sha256=Xvf7mT4zjK1B6rIrNW80-An03yCNCXvWiB6OvWrhIxg,672
|
|
100
|
+
truss/templates/server/truss_server.py,sha256=ob_nceeGtFPZzKKdk_ZZGLoZrJOGE6hR52xM1sPR97A,19498
|
|
98
101
|
truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
|
|
99
102
|
truss/templates/server/common/errors.py,sha256=qWeZlmNI8ZGbZbOIp_mtS6IKvUFIzhj3QH8zp-xTp9o,8554
|
|
100
103
|
truss/templates/server/common/patches.py,sha256=uEOzvDnXsHOkTSa8zygGYuR4GHhrFNVHNQc5peJcwvo,1393
|
|
@@ -136,10 +139,10 @@ truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy
|
|
|
136
139
|
truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
|
|
137
140
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
138
141
|
truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
|
|
139
|
-
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=
|
|
142
|
+
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=wQZ3DPLPAyXE3iaQiyHJTBO15v_gXN44eDk1StYkKmM,44764
|
|
140
143
|
truss/tests/cli/train/test_train_cli_core.py,sha256=T1Xa6-NRk2nTJGX6sXaA8x4qCwL3Ini72PBI2gW7rYM,7879
|
|
141
144
|
truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
|
|
142
|
-
truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=
|
|
145
|
+
truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=iJA7nxcLXhBmyjhLIKeN64ql0OI_R53l-qSt3SsENV8,22368
|
|
143
146
|
truss/tests/contexts/local_loader/test_load_local.py,sha256=D1qMH2IpYA2j5009v50QMgUnKdeOsX15ndkwXe10a4E,801
|
|
144
147
|
truss/tests/contexts/local_loader/test_truss_module_finder.py,sha256=oN1K2lg3ATHY5yOVUTfQIaSqusTF9I2wFaYaTSo5-O4,5342
|
|
145
148
|
truss/tests/local/test_local_config_handler.py,sha256=aLvcOyfppskA2MziVLy_kMcagjxMpO4mjar9zxUN6g0,2245
|
|
@@ -151,9 +154,10 @@ truss/tests/patch/test_truss_dir_patch_applier.py,sha256=P0SCqkVXLge7laSPlFWZM7A
|
|
|
151
154
|
truss/tests/patch/test_types.py,sha256=OUVDiLckbjjjEN49I4hm62emOTAr8lv_QooJrmXxs5o,306
|
|
152
155
|
truss/tests/remote/test_remote_factory.py,sha256=S-iZlF5Pf5SDoFUnMlZXy9iRMkosVgwLd22evzWlFr0,4842
|
|
153
156
|
truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0aeC5EWRVkGw,2726
|
|
154
|
-
truss/tests/remote/baseten/
|
|
157
|
+
truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
|
|
158
|
+
truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
|
|
155
159
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
156
|
-
truss/tests/remote/baseten/test_core.py,sha256=
|
|
160
|
+
truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
|
|
157
161
|
truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
|
|
158
162
|
truss/tests/remote/baseten/test_service.py,sha256=ufZbtQlBNIzFCxRt_iE-APLpWbVw_3ViUpSh6H9W5nU,1945
|
|
159
163
|
truss/tests/templates/control/control/test_endpoints.py,sha256=tGU3w8zOKC8LfWGdhp-TlV7E603KXg2xGwpqDdf8Pnw,3385
|
|
@@ -236,7 +240,7 @@ truss/tests/test_data/test_go_custom_server_truss/docker/main.go,sha256=WR3mJU1o
|
|
|
236
240
|
truss/tests/test_data/test_openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
237
241
|
truss/tests/test_data/test_openai/config.yaml,sha256=ByY_Smgx0lw24Yj0hqgofEmL3nrGNj7gZE5iBKlvwxk,235
|
|
238
242
|
truss/tests/test_data/test_openai/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
239
|
-
truss/tests/test_data/test_openai/model/model.py,sha256=
|
|
243
|
+
truss/tests/test_data/test_openai/model/model.py,sha256=GEtIJnWlU1snBid2sS-bZHrjQpP8UzL8tanzyH_tdgE,319
|
|
240
244
|
truss/tests/test_data/test_pyantic_v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
241
245
|
truss/tests/test_data/test_pyantic_v1/config.yaml,sha256=fqWpH3E4UPEnjvAw6Q9_F5oZZLy69RAfycbgtmCFsXo,270
|
|
242
246
|
truss/tests/test_data/test_pyantic_v1/requirements.txt,sha256=OpG4JAdJME9VWjoNftdHYg-y94k2gbhqdM1_NwOgcT8,13
|
|
@@ -305,7 +309,6 @@ truss/tests/test_data/test_truss_with_error/packages/helpers_1.py,sha256=qIm-hQY
|
|
|
305
309
|
truss/tests/test_data/test_truss_with_error/packages/helpers_2.py,sha256=q_UpVfXq_K2tuHv6YwsIzVHC3sy5k5hKDw6lMCdS0oc,53
|
|
306
310
|
truss/tests/trt_llm/test_trt_llm_config.py,sha256=lNQ4EEkOsiT17KvnvW1snCeEBd7K_cl9_Y0dko3qpn8,8505
|
|
307
311
|
truss/tests/trt_llm/test_validation.py,sha256=dmax2EHxRfqxJvWzV8uubkTef50833KBBHw-WkHufL8,2120
|
|
308
|
-
truss/tests/util/test_basetenpointer.py,sha256=Bdms21_m8T4xmFNHRO5nS2tU2wU7094_1SkfBxjptmk,9824
|
|
309
312
|
truss/tests/util/test_config_checks.py,sha256=aoZF_Q-eRd3qz5wjUqa8Cr_7qF2SxodXbBIY_DBuFWg,522
|
|
310
313
|
truss/tests/util/test_env_vars.py,sha256=hthgB1mU0bJb1H4Jugc-0khArlLZ3x6tLE82cDaa-J0,390
|
|
311
314
|
truss/tests/util/test_path.py,sha256=YfW3-IM_7iRsdR1Cb26KB1BkDsG_53_BUGBzoxY2Nog,7408
|
|
@@ -316,7 +319,7 @@ truss/truss_handle/build.py,sha256=BKFV-S57tnWcfRffvQ7SPp78BrjmRy3GhgF6ThaIrDM,3
|
|
|
316
319
|
truss/truss_handle/decorators.py,sha256=PUR5w2rl_cvcsVtAUpcYLzNXuOml9R0-wtpXy-9hDPk,407
|
|
317
320
|
truss/truss_handle/readme_generator.py,sha256=B4XbGwUjzMNOr71DWNAL8kCu5_ZHq7YOM8yVGaOZMSE,716
|
|
318
321
|
truss/truss_handle/truss_gatherer.py,sha256=Xysl_UnCVhehPfZeHa8p7WFp94ENqh-VVpbuqnCui3A,2870
|
|
319
|
-
truss/truss_handle/truss_handle.py,sha256=
|
|
322
|
+
truss/truss_handle/truss_handle.py,sha256=WF2MQSly9DQ1SoAvqfi87Ulu4llTadpXoncsDjpL79E,40886
|
|
320
323
|
truss/truss_handle/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
321
324
|
truss/truss_handle/patch/calc_patch.py,sha256=Qyk1QmacK4jy9Ia8-93L8VtAWJhw15z22DdZUkBKlys,18334
|
|
322
325
|
truss/truss_handle/patch/constants.py,sha256=pCEi5Pwi8Rnqthrr3VEsWL9EP1P1VV1T8DEYuitHLmc,139
|
|
@@ -327,7 +330,6 @@ truss/truss_handle/patch/local_truss_patch_applier.py,sha256=fOHWKt3teYnbqeRsF63
|
|
|
327
330
|
truss/truss_handle/patch/signature.py,sha256=8eas8gy6Japd1hrgdmtHmKTTxQmWsbmgKRQQGL2PVuA,858
|
|
328
331
|
truss/truss_handle/patch/truss_dir_patch_applier.py,sha256=uhhHvKYHn_dpfz0xp4jwO9_qAej5sO3f8of_h-21PP4,3666
|
|
329
332
|
truss/util/.truss_ignore,sha256=jpQA9ou-r_JEIcEHsUqGLHhir_m3d4IPGNyzKXtS-2g,3131
|
|
330
|
-
truss/util/basetenpointer.py,sha256=PJ_meuTuXAopnWsHe1ZaH2RnltfmqdQ4QeXDQEPrblI,5596
|
|
331
333
|
truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
|
|
332
334
|
truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
|
|
333
335
|
truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
|
|
@@ -357,12 +359,12 @@ truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e6
|
|
|
357
359
|
truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
|
|
358
360
|
truss_chains/remote_chainlet/utils.py,sha256=O_5P-VAUvg0cegEW1uKCOf5EBwD8rEGYVoGMivOmc7k,22374
|
|
359
361
|
truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
|
|
360
|
-
truss_train/definitions.py,sha256=
|
|
361
|
-
truss_train/deployment.py,sha256=
|
|
362
|
+
truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
|
|
363
|
+
truss_train/deployment.py,sha256=fDYRfzFRtVKMRVG0bKXYPmx6HXwLE0ukSQ0f81hG8kk,3020
|
|
362
364
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
363
365
|
truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
|
|
364
|
-
truss-0.10.
|
|
365
|
-
truss-0.10.
|
|
366
|
-
truss-0.10.
|
|
367
|
-
truss-0.10.
|
|
368
|
-
truss-0.10.
|
|
366
|
+
truss-0.10.10.dist-info/METADATA,sha256=5q5tQ4MtWhQspNwHEsMnSDOLj-fYFeB7zL4VinS2I28,6670
|
|
367
|
+
truss-0.10.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
368
|
+
truss-0.10.10.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
369
|
+
truss-0.10.10.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
370
|
+
truss-0.10.10.dist-info/RECORD,,
|
truss_train/definitions.py
CHANGED
|
@@ -18,6 +18,7 @@ class ModelWeightsFormat(str, enum.Enum):
|
|
|
18
18
|
|
|
19
19
|
LORA = "lora"
|
|
20
20
|
FULL = "full"
|
|
21
|
+
WHISPER = "whisper"
|
|
21
22
|
|
|
22
23
|
def to_truss_config(self) -> "ModelWeightsFormat":
|
|
23
24
|
return ModelWeightsFormat[self.name]
|
|
@@ -126,6 +127,7 @@ class TrainingJob(custom_types.SafeModelNoExtra):
|
|
|
126
127
|
image: Image
|
|
127
128
|
compute: Compute = Compute()
|
|
128
129
|
runtime: Runtime = Runtime()
|
|
130
|
+
name: Optional[str] = None
|
|
129
131
|
|
|
130
132
|
def model_dump(self, *args, **kwargs):
|
|
131
133
|
data = super().model_dump(*args, **kwargs)
|
|
@@ -170,6 +172,10 @@ class FullCheckpoint(Checkpoint):
|
|
|
170
172
|
model_weight_format: ModelWeightsFormat = ModelWeightsFormat.FULL
|
|
171
173
|
|
|
172
174
|
|
|
175
|
+
class WhisperCheckpoint(Checkpoint):
|
|
176
|
+
model_weight_format: ModelWeightsFormat = ModelWeightsFormat.WHISPER
|
|
177
|
+
|
|
178
|
+
|
|
173
179
|
class LoRACheckpoint(Checkpoint):
|
|
174
180
|
lora_details: LoRADetails = LoRADetails()
|
|
175
181
|
model_weight_format: ModelWeightsFormat = ModelWeightsFormat.LORA
|
truss_train/deployment.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
import pathlib
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import List
|
|
3
|
+
from typing import List, Optional
|
|
4
4
|
|
|
5
5
|
from truss.base.custom_types import SafeModel
|
|
6
|
+
from truss.cli.utils.output import console
|
|
6
7
|
from truss.remote.baseten import custom_types as b10_types
|
|
7
8
|
from truss.remote.baseten.api import BasetenApi
|
|
8
9
|
from truss.remote.baseten.core import archive_dir
|
|
@@ -44,6 +45,7 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo
|
|
|
44
45
|
image=training_job.image,
|
|
45
46
|
runtime=training_job.runtime,
|
|
46
47
|
compute=training_job.compute,
|
|
48
|
+
name=training_job.name,
|
|
47
49
|
runtime_artifacts=[
|
|
48
50
|
S3Artifact(s3_key=credentials["s3_key"], s3_bucket=credentials["s3_bucket"])
|
|
49
51
|
],
|
|
@@ -57,14 +59,25 @@ def create_training_job(
|
|
|
57
59
|
training_project=training_project
|
|
58
60
|
)
|
|
59
61
|
prepared_job = prepare_push(remote_provider.api, config, training_project.job)
|
|
62
|
+
|
|
60
63
|
job_resp = remote_provider.api.create_training_job(
|
|
61
64
|
project_id=project_resp["id"], job=prepared_job
|
|
62
65
|
)
|
|
63
66
|
return job_resp
|
|
64
67
|
|
|
65
68
|
|
|
66
|
-
def create_training_job_from_file(
|
|
69
|
+
def create_training_job_from_file(
|
|
70
|
+
remote_provider: BasetenRemote,
|
|
71
|
+
config: Path,
|
|
72
|
+
job_name_from_cli: Optional[str] = None,
|
|
73
|
+
) -> dict:
|
|
67
74
|
with loader.import_training_project(config) as training_project:
|
|
75
|
+
if job_name_from_cli:
|
|
76
|
+
if training_project.job.name:
|
|
77
|
+
console.print(
|
|
78
|
+
f"[bold yellow]⚠ Warning:[/bold yellow] name '{training_project.job.name}' provided in config file will be ignored. Using job name '{job_name_from_cli}' provided via --job-name flag."
|
|
79
|
+
)
|
|
80
|
+
training_project.job.name = job_name_from_cli
|
|
68
81
|
job_resp = create_training_job(
|
|
69
82
|
remote_provider=remote_provider,
|
|
70
83
|
training_project=training_project,
|
|
@@ -1,227 +0,0 @@
|
|
|
1
|
-
import time
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from tempfile import TemporaryDirectory
|
|
4
|
-
|
|
5
|
-
import pytest
|
|
6
|
-
import requests
|
|
7
|
-
from huggingface_hub.errors import HfHubHTTPError
|
|
8
|
-
|
|
9
|
-
from truss.base.truss_config import ModelCache, ModelRepo
|
|
10
|
-
from truss.util.basetenpointer import model_cache_hf_to_b10ptr
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def test_dolly_12b():
|
|
14
|
-
ModelCached = ModelCache(
|
|
15
|
-
[
|
|
16
|
-
dict(
|
|
17
|
-
repo_id="databricks/dolly-v2-12b",
|
|
18
|
-
revision="19308160448536e378e3db21a73a751579ee7fdd",
|
|
19
|
-
use_volume=True,
|
|
20
|
-
volume_folder="databricks_dolly_v2_12b",
|
|
21
|
-
runtime_secret_name="hf_access_token",
|
|
22
|
-
)
|
|
23
|
-
]
|
|
24
|
-
)
|
|
25
|
-
for _ in range(2):
|
|
26
|
-
try:
|
|
27
|
-
bptr = model_cache_hf_to_b10ptr(ModelCached)
|
|
28
|
-
continue
|
|
29
|
-
# timeout by huggingface hub timeout error
|
|
30
|
-
except requests.exceptions.ReadTimeout as e:
|
|
31
|
-
# this is expected to timeout when the request takes too long
|
|
32
|
-
# due to the large size of the model
|
|
33
|
-
print("ReadTimeout Error: ", e)
|
|
34
|
-
pytest.skip(
|
|
35
|
-
"Skipping test due to ReadTimeout error from Hugging Face API, "
|
|
36
|
-
"this can happen for large models like Dolly-12b"
|
|
37
|
-
)
|
|
38
|
-
except HfHubHTTPError as e:
|
|
39
|
-
if e.response.status_code == 429:
|
|
40
|
-
pytest.skip("Hugging Face API rate limit exceeded")
|
|
41
|
-
raise
|
|
42
|
-
bptr_list = bptr.pointers
|
|
43
|
-
expected = [
|
|
44
|
-
{
|
|
45
|
-
"resolution": {
|
|
46
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/.gitattributes",
|
|
47
|
-
"expiration_timestamp": 2373918212,
|
|
48
|
-
},
|
|
49
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:.gitattributes",
|
|
50
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/.gitattributes",
|
|
51
|
-
"hashtype": "etag",
|
|
52
|
-
"hash": "c7d9f3332a950355d5a77d85000f05e6f45435ea",
|
|
53
|
-
"size": 1477,
|
|
54
|
-
},
|
|
55
|
-
{
|
|
56
|
-
"resolution": {
|
|
57
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/README.md",
|
|
58
|
-
"expiration_timestamp": 2373918212,
|
|
59
|
-
},
|
|
60
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:README.md",
|
|
61
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/README.md",
|
|
62
|
-
"hashtype": "etag",
|
|
63
|
-
"hash": "2912eb39545af0367335cff448d07214519c5eed",
|
|
64
|
-
"size": 10746,
|
|
65
|
-
},
|
|
66
|
-
{
|
|
67
|
-
"resolution": {
|
|
68
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/config.json",
|
|
69
|
-
"expiration_timestamp": 2373918212,
|
|
70
|
-
},
|
|
71
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:config.json",
|
|
72
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/config.json",
|
|
73
|
-
"hashtype": "etag",
|
|
74
|
-
"hash": "888c677eda015e2375fad52d75062d14b30ebad9",
|
|
75
|
-
"size": 818,
|
|
76
|
-
},
|
|
77
|
-
{
|
|
78
|
-
"resolution": {
|
|
79
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/instruct_pipeline.py",
|
|
80
|
-
"expiration_timestamp": 2373918212,
|
|
81
|
-
},
|
|
82
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:instruct_pipeline.py",
|
|
83
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/instruct_pipeline.py",
|
|
84
|
-
"hashtype": "etag",
|
|
85
|
-
"hash": "f8b291569e936cf104f44d003f95451bf5e1f965",
|
|
86
|
-
"size": 9159,
|
|
87
|
-
},
|
|
88
|
-
{
|
|
89
|
-
"resolution": {
|
|
90
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/pytorch_model.bin",
|
|
91
|
-
"expiration_timestamp": 2373918212,
|
|
92
|
-
},
|
|
93
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:pytorch_model.bin",
|
|
94
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/pytorch_model.bin",
|
|
95
|
-
"hashtype": "etag",
|
|
96
|
-
"hash": "19e10711310992c310c3775964c7635f4b28dd86587403e718c6d6d524a406a5",
|
|
97
|
-
"size": 23834965761,
|
|
98
|
-
},
|
|
99
|
-
{
|
|
100
|
-
"resolution": {
|
|
101
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/special_tokens_map.json",
|
|
102
|
-
"expiration_timestamp": 2373918212,
|
|
103
|
-
},
|
|
104
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:special_tokens_map.json",
|
|
105
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/special_tokens_map.json",
|
|
106
|
-
"hashtype": "etag",
|
|
107
|
-
"hash": "ecc1ee07dec13ee276fa9f1b29a1078da3280a4d",
|
|
108
|
-
"size": 228,
|
|
109
|
-
},
|
|
110
|
-
{
|
|
111
|
-
"resolution": {
|
|
112
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/tokenizer.json",
|
|
113
|
-
"expiration_timestamp": 2373918212,
|
|
114
|
-
},
|
|
115
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:tokenizer.json",
|
|
116
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/tokenizer.json",
|
|
117
|
-
"hashtype": "etag",
|
|
118
|
-
"hash": "22868c8caf99a303c1a44bfea98f20f4254fc0e5",
|
|
119
|
-
"size": 2114274,
|
|
120
|
-
},
|
|
121
|
-
{
|
|
122
|
-
"resolution": {
|
|
123
|
-
"url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/tokenizer_config.json",
|
|
124
|
-
"expiration_timestamp": 2373918212,
|
|
125
|
-
},
|
|
126
|
-
"uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:tokenizer_config.json",
|
|
127
|
-
"file_name": "/app/model_cache/databricks_dolly_v2_12b/tokenizer_config.json",
|
|
128
|
-
"hashtype": "etag",
|
|
129
|
-
"hash": "51e564ead5d28eebc74b25d86f0a694b7c7cc618",
|
|
130
|
-
"size": 449,
|
|
131
|
-
},
|
|
132
|
-
]
|
|
133
|
-
assert len(bptr_list) == len(expected), (
|
|
134
|
-
f"Expected {len(expected)} but got {len(bptr_list)}"
|
|
135
|
-
)
|
|
136
|
-
for expected, actual in zip(expected, bptr_list):
|
|
137
|
-
assert expected["uid"] == actual.uid, (
|
|
138
|
-
f"Expected uid {expected['uid']} but got {actual.uid}"
|
|
139
|
-
)
|
|
140
|
-
assert expected["file_name"] == actual.file_name, (
|
|
141
|
-
f"Expected file_name {expected['file_name']} but got {actual.file_name}"
|
|
142
|
-
)
|
|
143
|
-
assert expected["hash"] == actual.hash, (
|
|
144
|
-
f"Expected hash {expected['hash']} but got {actual.hash}"
|
|
145
|
-
)
|
|
146
|
-
assert expected["size"] == actual.size, (
|
|
147
|
-
f"Expected size {expected['size']} but got {actual.size}"
|
|
148
|
-
)
|
|
149
|
-
assert expected["resolution"]["url"] == actual.resolution.url, (
|
|
150
|
-
f"Expected resolution url {expected['resolution']['url']} but got {actual.resolution.url}"
|
|
151
|
-
)
|
|
152
|
-
# 100 years or more ahead
|
|
153
|
-
assert (
|
|
154
|
-
actual.resolution.expiration_timestamp
|
|
155
|
-
>= time.time() + 20 * 365 * 24 * 60 * 60
|
|
156
|
-
), (
|
|
157
|
-
f"Expected unix expiration timestamp to be at least 20 years ahead, but got {actual.resolution.expiration_timestamp}. "
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
# download first file and verify size
|
|
161
|
-
with TemporaryDirectory() as tmp:
|
|
162
|
-
# Get the first pointer (.gitattributes)
|
|
163
|
-
first_pointer = bptr_list[0]
|
|
164
|
-
tmp_path = Path(tmp) / "downloaded_file"
|
|
165
|
-
|
|
166
|
-
# Download the file
|
|
167
|
-
response = requests.get(first_pointer.resolution.url)
|
|
168
|
-
response.raise_for_status()
|
|
169
|
-
|
|
170
|
-
# Save the file
|
|
171
|
-
tmp_path.write_bytes(response.content)
|
|
172
|
-
|
|
173
|
-
# Verify file size matches metadata
|
|
174
|
-
actual_size = tmp_path.stat().st_size
|
|
175
|
-
assert actual_size == first_pointer.size, (
|
|
176
|
-
f"Downloaded file size {actual_size} does not match expected size {first_pointer.size}"
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
def test_with_main():
|
|
181
|
-
# main should be resolved to 41dec486b25746052d3335decc8f5961607418a0
|
|
182
|
-
cache = ModelCache(
|
|
183
|
-
[
|
|
184
|
-
ModelRepo(
|
|
185
|
-
repo_id="intfloat/llm-retriever-base",
|
|
186
|
-
revision="main",
|
|
187
|
-
ignore_patterns=["*.json", "*.txt", "*.md", "*.bin", "*.model"],
|
|
188
|
-
volume_folder="mistral_demo",
|
|
189
|
-
use_volume=True,
|
|
190
|
-
)
|
|
191
|
-
]
|
|
192
|
-
)
|
|
193
|
-
b10ptr = model_cache_hf_to_b10ptr(cache)
|
|
194
|
-
expected = {
|
|
195
|
-
"pointers": [
|
|
196
|
-
{
|
|
197
|
-
"resolution": {
|
|
198
|
-
"url": "https://huggingface.co/intfloat/llm-retriever-base/resolve/41dec486b25746052d3335decc8f5961607418a0/.gitattributes",
|
|
199
|
-
"expiration_timestamp": 4044816725,
|
|
200
|
-
},
|
|
201
|
-
"uid": "intfloat/llm-retriever-base:main:.gitattributes",
|
|
202
|
-
"file_name": "/app/model_cache/mistral_demo/.gitattributes",
|
|
203
|
-
"hashtype": "etag",
|
|
204
|
-
"hash": "a6344aac8c09253b3b630fb776ae94478aa0275b",
|
|
205
|
-
"size": 1519,
|
|
206
|
-
"runtime_secret_name": "hf_access_token",
|
|
207
|
-
},
|
|
208
|
-
{
|
|
209
|
-
"resolution": {
|
|
210
|
-
"url": "https://huggingface.co/intfloat/llm-retriever-base/resolve/41dec486b25746052d3335decc8f5961607418a0/model.safetensors",
|
|
211
|
-
"expiration_timestamp": 4044816725,
|
|
212
|
-
},
|
|
213
|
-
"uid": "intfloat/llm-retriever-base:main:model.safetensors",
|
|
214
|
-
"file_name": "/app/model_cache/mistral_demo/model.safetensors",
|
|
215
|
-
"hashtype": "etag",
|
|
216
|
-
"hash": "565dd4f1cc6318ccf07af8680c27fd935b3b56ca2684d1af58abcd4e8bf6ecfa",
|
|
217
|
-
"size": 437955512,
|
|
218
|
-
"runtime_secret_name": "hf_access_token",
|
|
219
|
-
},
|
|
220
|
-
]
|
|
221
|
-
}
|
|
222
|
-
assert b10ptr.model_dump() == expected
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if __name__ == "__main__":
|
|
226
|
-
test_dolly_12b()
|
|
227
|
-
test_with_main()
|
truss/util/basetenpointer.py
DELETED
|
@@ -1,160 +0,0 @@
|
|
|
1
|
-
"""This file contains the utils to create a basetenpointer from a huggingface repo, which can be resolved at runtime."""
|
|
2
|
-
|
|
3
|
-
import time
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import TYPE_CHECKING, Optional
|
|
6
|
-
|
|
7
|
-
import requests
|
|
8
|
-
from huggingface_hub import hf_api, hf_hub_url
|
|
9
|
-
from huggingface_hub.utils import filter_repo_objects
|
|
10
|
-
from pydantic import BaseModel
|
|
11
|
-
|
|
12
|
-
if TYPE_CHECKING:
|
|
13
|
-
from truss.base.truss_config import ModelCache
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# copied from: https://github.com/basetenlabs/baseten/blob/caeba66cd544a5152bb6a018d6ac2871814f327b/baseten_shared/baseten_shared/lms/types.py#L13
|
|
17
|
-
class Resolution(BaseModel):
|
|
18
|
-
url: str
|
|
19
|
-
expiration_timestamp: int
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class BasetenPointer(BaseModel):
|
|
23
|
-
resolution: Optional[Resolution] = None
|
|
24
|
-
uid: str
|
|
25
|
-
file_name: str
|
|
26
|
-
hashtype: str
|
|
27
|
-
hash: str
|
|
28
|
-
size: int
|
|
29
|
-
runtime_secret_name: str = "hf_access_token" # TODO: remove the default
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class BasetenPointerList(BaseModel):
|
|
33
|
-
pointers: list[BasetenPointer]
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def get_hf_metadata(api: "hf_api.HfApi", repo: str, revision: str, file: str):
|
|
37
|
-
url = hf_hub_url(repo_id=repo, revision=revision, filename=file)
|
|
38
|
-
meta = api.get_hf_file_metadata(url=url)
|
|
39
|
-
return {"etag": meta.etag, "location": meta.location, "size": meta.size, "url": url}
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def filter_repo_files(
|
|
43
|
-
files: list[str],
|
|
44
|
-
allow_patterns: Optional[list[str]],
|
|
45
|
-
ignore_patterns: Optional[list[str]],
|
|
46
|
-
) -> list[str]:
|
|
47
|
-
return list(
|
|
48
|
-
filter_repo_objects(
|
|
49
|
-
items=files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
|
50
|
-
)
|
|
51
|
-
)
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def metadata_hf_repo(
|
|
55
|
-
repo: str,
|
|
56
|
-
revision: str,
|
|
57
|
-
allow_patterns: Optional[list[str]] = None,
|
|
58
|
-
ignore_patterns: Optional[list[str]] = None,
|
|
59
|
-
) -> dict[str, dict]:
|
|
60
|
-
"""Lists all files, gathers metadata without downloading, just using the Hugging Face API.
|
|
61
|
-
Example:
|
|
62
|
-
[{'.gitattributes': HfFileMetadata(
|
|
63
|
-
commit_hash='07163b72af1488142a360786df853f237b1a3ca1',
|
|
64
|
-
etag='a6344aac8c09253b3b630fb776ae94478aa0275b',
|
|
65
|
-
location='https://huggingface.co/intfloat/e5-mistral-7b-instruct/resolve/main/.gitattributes',
|
|
66
|
-
url='https://huggingface.co/intfloat/e5-mistral-7b-instruct/resolve/main/.gitattributes',
|
|
67
|
-
size=1519)]
|
|
68
|
-
"""
|
|
69
|
-
api = hf_api.HfApi()
|
|
70
|
-
model_info = api.model_info(repo_id=repo, revision=revision)
|
|
71
|
-
real_revision = model_info.sha
|
|
72
|
-
real_revision = real_revision or revision
|
|
73
|
-
if revision != real_revision:
|
|
74
|
-
print(
|
|
75
|
-
f"Warning: revision {revision} is moving, using {real_revision} instead. "
|
|
76
|
-
f"Please update your code to use `revision={real_revision}` instead otherwise you will keep moving. "
|
|
77
|
-
)
|
|
78
|
-
files: list[str] = api.list_repo_files(repo_id=repo, revision=real_revision)
|
|
79
|
-
files = filter_repo_files(
|
|
80
|
-
files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
|
81
|
-
)
|
|
82
|
-
|
|
83
|
-
hf_files_meta = {
|
|
84
|
-
file: get_hf_metadata(api, repo, real_revision, file) for file in files
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
return hf_files_meta
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
def model_cache_hf_to_b10ptr(cache: "ModelCache") -> BasetenPointerList:
|
|
91
|
-
"""
|
|
92
|
-
Convert a ModelCache object to a BasetenPointer object.
|
|
93
|
-
"""
|
|
94
|
-
assert cache.is_v2, "ModelCache is not v2"
|
|
95
|
-
|
|
96
|
-
basetenpointers: list[BasetenPointer] = []
|
|
97
|
-
|
|
98
|
-
for model in cache.models:
|
|
99
|
-
assert model.revision is not None, "ModelCache is not v2, revision is None"
|
|
100
|
-
exception = None
|
|
101
|
-
for _ in range(3):
|
|
102
|
-
try:
|
|
103
|
-
metadata_hf_repo_list = metadata_hf_repo(
|
|
104
|
-
repo=model.repo_id,
|
|
105
|
-
revision=model.revision,
|
|
106
|
-
allow_patterns=model.allow_patterns,
|
|
107
|
-
ignore_patterns=model.ignore_patterns,
|
|
108
|
-
)
|
|
109
|
-
break
|
|
110
|
-
except requests.exceptions.ReadTimeout as e:
|
|
111
|
-
# this is expected, sometimes huggingface hub times out
|
|
112
|
-
print("ReadTimeout Error: ", e)
|
|
113
|
-
time.sleep(5)
|
|
114
|
-
exception = e
|
|
115
|
-
except Exception as e:
|
|
116
|
-
raise e
|
|
117
|
-
else:
|
|
118
|
-
# if we get here, we have exhausted the retries
|
|
119
|
-
assert exception is not None, "ReadTimeout Error: " + str(exception)
|
|
120
|
-
raise exception
|
|
121
|
-
# convert the metadata to b10 pointer format
|
|
122
|
-
b10_pointer_list = [
|
|
123
|
-
BasetenPointer(
|
|
124
|
-
uid=f"{model.repo_id}:{model.revision}:{filename}",
|
|
125
|
-
file_name=(Path(model.runtime_path) / filename).as_posix(),
|
|
126
|
-
hashtype="etag",
|
|
127
|
-
hash=content["etag"],
|
|
128
|
-
size=content["size"],
|
|
129
|
-
runtime_secret_name=model.runtime_secret_name,
|
|
130
|
-
resolution=Resolution(
|
|
131
|
-
url=content["url"],
|
|
132
|
-
expiration_timestamp=int(
|
|
133
|
-
4044816725 # 90 years in the future, hf does not expire. needs to be static, to have cache hits.
|
|
134
|
-
),
|
|
135
|
-
),
|
|
136
|
-
)
|
|
137
|
-
for filename, content in metadata_hf_repo_list.items()
|
|
138
|
-
]
|
|
139
|
-
basetenpointers.extend(b10_pointer_list)
|
|
140
|
-
|
|
141
|
-
return BasetenPointerList(pointers=basetenpointers)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
if __name__ == "__main__":
|
|
145
|
-
# example usage
|
|
146
|
-
from truss.base.truss_config import ModelCache, ModelRepo
|
|
147
|
-
|
|
148
|
-
cache = ModelCache(
|
|
149
|
-
[
|
|
150
|
-
ModelRepo(
|
|
151
|
-
repo_id="intfloat/llm-retriever-base",
|
|
152
|
-
revision="main",
|
|
153
|
-
ignore_patterns=["*.json", "*.txt", "*.md", "*.bin", "*.model"],
|
|
154
|
-
volume_folder="mistral_demo",
|
|
155
|
-
use_volume=True,
|
|
156
|
-
)
|
|
157
|
-
]
|
|
158
|
-
)
|
|
159
|
-
b10ptr = model_cache_hf_to_b10ptr(cache)
|
|
160
|
-
print(b10ptr.model_dump_json())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|