truss 0.11.2rc503__py3-none-any.whl → 0.11.2rc505__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (33) hide show
  1. truss/base/constants.py +3 -0
  2. truss/cli/chains_commands.py +20 -7
  3. truss/cli/train/core.py +156 -0
  4. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +1 -1
  5. truss/cli/train_commands.py +72 -0
  6. truss/templates/base.Dockerfile.jinja +1 -3
  7. truss/templates/control/control/endpoints.py +82 -33
  8. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +3 -20
  9. truss/templates/control/requirements.txt +1 -1
  10. truss/templates/server/common/errors.py +1 -0
  11. truss/templates/server/truss_server.py +5 -3
  12. truss/templates/server.Dockerfile.jinja +2 -4
  13. truss/templates/train/config.py +46 -0
  14. truss/templates/train/run.sh +11 -0
  15. truss/tests/cli/train/test_deploy_checkpoints.py +3 -3
  16. truss/tests/cli/train/test_train_init.py +499 -0
  17. truss/tests/patch/test_calc_patch.py +14 -26
  18. truss/tests/templates/control/control/test_endpoints.py +20 -14
  19. truss/tests/test_control_truss_patching.py +0 -17
  20. truss/truss_handle/patch/calc_patch.py +5 -20
  21. {truss-0.11.2rc503.dist-info → truss-0.11.2rc505.dist-info}/METADATA +1 -1
  22. {truss-0.11.2rc503.dist-info → truss-0.11.2rc505.dist-info}/RECORD +32 -29
  23. truss_chains/deployment/code_gen.py +5 -1
  24. truss_chains/deployment/deployment_client.py +45 -7
  25. truss_chains/public_types.py +6 -3
  26. truss_chains/remote_chainlet/utils.py +46 -7
  27. truss_train/__init__.py +4 -0
  28. truss_train/definitions.py +47 -2
  29. truss_train/restore_from_checkpoint.py +42 -0
  30. truss/templates/server/entrypoint.sh +0 -32
  31. {truss-0.11.2rc503.dist-info → truss-0.11.2rc505.dist-info}/WHEEL +0 -0
  32. {truss-0.11.2rc503.dist-info → truss-0.11.2rc505.dist-info}/entry_points.txt +0 -0
  33. {truss-0.11.2rc503.dist-info → truss-0.11.2rc505.dist-info}/licenses/LICENSE +0 -0
@@ -9,8 +9,6 @@ from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
9
9
  from truss.tests.test_truss_handle import (
10
10
  verify_python_requirement_installed_on_container,
11
11
  verify_python_requirement_not_installed_on_container,
12
- verify_system_package_installed_on_container,
13
- verify_system_requirement_not_installed_on_container,
14
12
  )
15
13
  from truss.truss_handle.truss_gatherer import calc_shadow_truss_dirname
16
14
  from truss.truss_handle.truss_handle import TrussHandle
@@ -147,13 +145,6 @@ def test_control_truss_python_sys_req_patch(
147
145
  th.remove_python_requirement(req)
148
146
  return th.docker_predict([1], tag=tag, binary=binary, local_port=None)
149
147
 
150
- def predict_with_system_requirement_added(pkg):
151
- th.add_system_package(pkg)
152
- return th.docker_predict([1], tag=tag, binary=binary, local_port=None)
153
-
154
- def predict_with_system_requirement_removed(pkg):
155
- th.remove_system_package(pkg)
156
- return th.docker_predict([1], tag=tag, binary=binary, local_port=None)
157
148
 
158
149
  with ensure_kill_all():
159
150
  th.docker_predict([1], tag=tag, binary=binary, local_port=None)
@@ -170,14 +161,6 @@ def test_control_truss_python_sys_req_patch(
170
161
  assert current_num_docker_images(th) == orig_num_truss_images
171
162
  verify_python_requirement_not_installed_on_container(container, python_req)
172
163
 
173
- system_pkg = "jq"
174
- predict_with_system_requirement_added(system_pkg)
175
- assert current_num_docker_images(th) == orig_num_truss_images
176
- verify_system_package_installed_on_container(container, system_pkg)
177
-
178
- predict_with_system_requirement_removed(system_pkg)
179
- assert current_num_docker_images(th) == orig_num_truss_images
180
- verify_system_requirement_not_installed_on_container(container, system_pkg)
181
164
 
182
165
 
183
166
  @pytest.mark.integration
@@ -425,20 +425,12 @@ def _calc_system_packages_patches(
425
425
  ) -> List[Patch]:
426
426
  """Calculate patch based on changes to system packates.
427
427
 
428
- Empty list means no relevant differences found.
428
+ System package patches are no longer supported, so this always returns an empty list.
429
+ Changes to system packages will require a full rebuild instead of patching.
429
430
  """
430
- patches = []
431
- prev_pkgs = system_packages_set(prev_config.system_packages)
432
- new_pkgs = system_packages_set(new_config.system_packages)
433
- removed_pkgs = prev_pkgs.difference(new_pkgs)
434
- for removed_pkg in removed_pkgs:
435
- patches.append(_mk_system_package_patch(Action.REMOVE, removed_pkg))
436
-
437
- added_pkgs = new_pkgs.difference(prev_pkgs)
438
- for added_pkg in added_pkgs:
439
- patches.append(_mk_system_package_patch(Action.ADD, added_pkg))
440
-
441
- return patches
431
+ # System package patches are no longer supported - return empty list
432
+ # This will cause any system package changes to be handled by full rebuild
433
+ return []
442
434
 
443
435
 
444
436
  def _mk_config_patch(action: Action, config: dict) -> Patch:
@@ -471,13 +463,6 @@ def _mk_python_requirement_patch(action: Action, requirement: str) -> Patch:
471
463
  )
472
464
 
473
465
 
474
- def _mk_system_package_patch(action: Action, package: str) -> Patch:
475
- return Patch(
476
- type=PatchType.SYSTEM_PACKAGE,
477
- body=SystemPackagePatch(action=action, package=package),
478
- )
479
-
480
-
481
466
  def _relative_to(path: str, relative_to_path: str):
482
467
  return str(Path(path).relative_to(relative_to_path))
483
468
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.2rc503
3
+ Version: 0.11.2rc505
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -2,29 +2,29 @@ truss/__init__.py,sha256=CoUcP6vx_pocyemRmpbCPlndkHhdMkABAlr0ZXVuPCk,1163
2
2
  truss/api/__init__.py,sha256=spBAa_m1pItiid97iDLKPmumgAkSirPkv-E8RWMZyOk,5090
3
3
  truss/api/definitions.py,sha256=QAaIBqL59Q-R7HtLcXcoeCIWBN2HqOzApdFX0PpCq2s,1604
4
4
  truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- truss/base/constants.py,sha256=qwNNkd9EOAuiTxYLVccJaiPCNRayBAFvyj_GisYOT3I,3488
5
+ truss/base/constants.py,sha256=sExArdnuGg83z83XMgaQ4b8SS3V_j_bJEpOATDGJzpE,3600
6
6
  truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
7
7
  truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=CRz3AqGDAyv8YpcBWXUrnfjvNAauyo3yf8ZOGVsSt6g,32782
9
9
  truss/base/truss_config.py,sha256=7CtiJIwMHtDU8Wzn8UTJUVVunD0pWFl4QUVycK2aIpY,28055
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=bqOXQ-0RPS66vSP_OPQdJ5dvctGiVrsGoSUMbURGdSI,16970
11
+ truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
- truss/cli/train_commands.py,sha256=GDye7yXGL_nQvXAlY5MWsdj5x0zYOvcQw0Ubn14TiRU,14365
14
+ truss/cli/train_commands.py,sha256=TZhtvofviWQF34pYppRCaQ6qayTsvPnx6afTrYbFpOM,17319
15
15
  truss/cli/logs/base_watcher.py,sha256=KKyd7lIrdaEeDVt8EtjMioSPGVpLyOcF0ewyzE_GGdQ,2785
16
16
  truss/cli/logs/model_log_watcher.py,sha256=NACcP-wkcaroYa2Cb9BZC7Yr0554WZa_FSM2LXOf4A8,1263
17
17
  truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
18
18
  truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
19
19
  truss/cli/train/common.py,sha256=xTR41U5FeSndXfNBBHF9wF5XwZH1sOIVFlv-XHjsKIU,1547
20
- truss/cli/train/core.py,sha256=dAmetxKqSc4bQPnVS_8WLfNsw1L7vLT2tU02BVwRPgc,20206
20
+ truss/cli/train/core.py,sha256=4vPnREmaJh8R_rlwR0_H5NRaXhdyY2g07w11uab-9qw,25908
21
21
  truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
22
22
  truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
23
23
  truss/cli/train/metrics_watcher.py,sha256=smz-zrEsBj_-wJHI0pAZ-EAPrvfCWzq1eQjGiFNM-Mk,12755
24
24
  truss/cli/train/poller.py,sha256=TGRzELxsicga0bEXewSX1ujw6lfPmDnHd6nr8zvOFO8,3550
25
25
  truss/cli/train/types.py,sha256=alGtr4Q71GeB65PpGMhsoKygw4k_ncR6MKIP1ioP8rI,951
26
26
  truss/cli/train/deploy_checkpoints/__init__.py,sha256=wL-M2yu8PxO2tFvjwshXAfPnB-5TlvsBp2v_bdzimRU,99
27
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=xfblHi3py7GDgY24NcuAaDKzcQeOm67rjtWOK6vAEe4,17352
27
+ truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=nJHUjR4f4_13mFPNetWKq7ecsqr-cradbv3RBBDj2pk,17364
28
28
  truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=6x5nS_HnWYtS9vi-Pg8akzrJk9L_agjvFhm5EFh1m6Y,1964
29
29
  truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=FYRG5KTMlxEMZS-RA_m2gp1wuqWbSpqt2RhdQfLibhA,3968
30
30
  truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=P91dIAzuhl2GlzmrWwCcYI7uCMT1Lm7C79JQHM_exN4,4442
@@ -66,15 +66,15 @@ truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzl
66
66
  truss/remote/baseten/utils/transfer.py,sha256=d3VptuQb6M1nyS6kz0BAfeOYDLkMKUjatJXpY-mp-As,1548
67
67
  truss/templates/README.md.jinja,sha256=N7CJdyldZuJamj5jLh47le0hFBdu9irVsTBqoxhPNPQ,2476
68
68
  truss/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
- truss/templates/base.Dockerfile.jinja,sha256=3deYATR6gbgNmmYhg4yXopbOUnICSbXCfjeGU8mdUaQ,5734
69
+ truss/templates/base.Dockerfile.jinja,sha256=irked6fWbiZ4tMkhR3zi3njpaaI9bANVqq7PTjp_Tmc,5610
70
70
  truss/templates/cache.Dockerfile.jinja,sha256=1qZqDo1phrcqi-Vwol-VafYJkADsBbQWU6huQ-_1x00,1146
71
71
  truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj7wGs61nUs,87
72
72
  truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
73
73
  truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
74
- truss/templates/server.Dockerfile.jinja,sha256=fNrCi1sGefjGlu2JOzLu8E7PiwZvpoq8JgW_BBhvya0,7219
75
- truss/templates/control/requirements.txt,sha256=Kk0tYID7trPk5gwX38Wrt2-YGWZAXFJCJRcqJ8ZzCjc,251
74
+ truss/templates/server.Dockerfile.jinja,sha256=CUYnF_hgxPGq2re7__0UPWlwzOHMoFkxp6NVKi3U16s,7071
75
+ truss/templates/control/requirements.txt,sha256=tJGr83WoE0CZm2FrloZ9VScK84q-_FTuVXjDYrexhW0,250
76
76
  truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
77
- truss/templates/control/control/endpoints.py,sha256=FM-sgao7I3gMoUTasM3Xq_g2LDoJQe75JxIoaQxzeNo,10031
77
+ truss/templates/control/control/endpoints.py,sha256=VQ1lvZjFvR091yRkiFdvXw1Q7PiNGXT9rJwY7_sX6yg,11828
78
78
  truss/templates/control/control/server.py,sha256=R4Y219i1dcz0kkksN8obLoX-YXWGo9iW1igindyG50c,3128
79
79
  truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5meqrOb3w_phMtKfaJI-GhwUfpiycDc8,413
80
80
  truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
@@ -84,7 +84,7 @@ truss/templates/control/control/helpers/inference_server_process_controller.py,s
84
84
  truss/templates/control/control/helpers/inference_server_starter.py,sha256=Fz2Puijro6Cc5cvTsAqOveNJbBQR_ARYJXl4lvETJ8Y,2633
85
85
  truss/templates/control/control/helpers/truss_patch/__init__.py,sha256=CXZdUV_ylqLTJrKuFpvSnUT6PUFrZrMF2y6jiHbdaKU,998
86
86
  truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py,sha256=LTIIADLz0wRn7V49j64dU1U7Hbta9YLde3pb5YZWvzQ,2001
87
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py,sha256=62TDVaDmgAH0-X116xSDnNTOFEgUQH4sNJr0aALFl_0,7149
87
+ truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py,sha256=uiZvKhLa_rpr_N67NpApgA-XVrBFLXraZlikSOFDDOw,6445
88
88
  truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py,sha256=CL3KEAj4B3ApMQShd7TI5umXVbazLZY5StrNlwHwWtc,1995
89
89
  truss/templates/control/control/helpers/truss_patch/system_packages.py,sha256=IYh1CVU_kooAvtSGXKQDDWnNdOhlv7ENWagsL1wvhgw,208
90
90
  truss/templates/custom/examples.yaml,sha256=2UcCtEdavImWmiCtj31ckBlAKVOwNMC5AwMIIznKDag,48
@@ -94,13 +94,12 @@ truss/templates/custom_python_dx/my_model.py,sha256=NG75mQ6wxzB1BYUemDFZvRLBET-U
94
94
  truss/templates/docker_server/proxy.conf.jinja,sha256=Lg-PcZzKflG85exZKHNgW_I6r0mATV8AtOIBaE40-RM,1669
95
95
  truss/templates/docker_server/supervisord.conf.jinja,sha256=dd37fwZE--cutrvOUCqEyJQQQhlp61H2IUs2huKWsSk,1808
96
96
  truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
97
- truss/templates/server/entrypoint.sh,sha256=aRqtkHt_gHRgEa2S1sGnDoaFaZtA43NbvKYUy5CtzcM,881
98
97
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
99
98
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
100
99
  truss/templates/server/requirements.txt,sha256=XblmpfxAmRo3X1V_9oMj8yjdpZ5Wk-C2oa3z6nq4OGw,672
101
- truss/templates/server/truss_server.py,sha256=ob_nceeGtFPZzKKdk_ZZGLoZrJOGE6hR52xM1sPR97A,19498
100
+ truss/templates/server/truss_server.py,sha256=noXfGJMsKIhgF4oI_8LC1UHkcx8Vg8nGSITZJ_bkRFQ,19598
102
101
  truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
103
- truss/templates/server/common/errors.py,sha256=qWeZlmNI8ZGbZbOIp_mtS6IKvUFIzhj3QH8zp-xTp9o,8554
102
+ truss/templates/server/common/errors.py,sha256=My0P6-Y7imVTICIhazHT0vlSu3XJDH7As06OyVzu4Do,8589
104
103
  truss/templates/server/common/patches.py,sha256=uEOzvDnXsHOkTSa8zygGYuR4GHhrFNVHNQc5peJcwvo,1393
105
104
  truss/templates/server/common/retry.py,sha256=dtz6yvwLoY0i55FnxECz57zEOKjAhGMYvvM-k9jiR9c,624
106
105
  truss/templates/server/common/schema.py,sha256=WLFtVyEKmk4whg5_gk6Gt1vOD6wM5fWKLb4zNuD0bkw,6042
@@ -113,6 +112,8 @@ truss/templates/shared/log_config.py,sha256=l9udyu4VKHZePlfK9LQEd5TOUUodPuehypsX
113
112
  truss/templates/shared/secrets_resolver.py,sha256=3prDe3Q06NTmUEe7KCW-W4TD1CzGck9lpDG789209z4,2110
114
113
  truss/templates/shared/serialization.py,sha256=_WC_2PPkRi-MdTwxwjG8LKQptnHi4sANfpOlKWevqWc,3736
115
114
  truss/templates/shared/util.py,sha256=dPgFF4iL_YkeC6Kf8tZUHJH60rbpskHwVPh0ONLGaQM,2222
115
+ truss/templates/train/config.py,sha256=aQJ3lsyVRlq6edjjZq4_Anz1bZVwkjLdclmZPJTdo1k,1626
116
+ truss/templates/train/run.sh,sha256=2rimigJOn6yg4DguRfOJWkzm77X-meNSYXnidLafqNg,346
116
117
  truss/templates/trtllm-audio/model/model.py,sha256=o38QqW57b1lf8O_td1lW_AojZZ8R_qAZCgzOWtoIse8,1619
117
118
  truss/templates/trtllm-audio/packages/sigint_patch.py,sha256=t6pYpVwgQsLCgcxQq7-V3scr9ZOiIxtYSpy9LCfdNTk,414
118
119
  truss/templates/trtllm-audio/packages/whisper_trt/__init__.py,sha256=5ZQfVlwtkWrnjYiuBIVSviYDhV-kksygDkHEWBS_ijM,7065
@@ -130,7 +131,7 @@ truss/tests/helpers.py,sha256=DKnUGzmt-k4IN_wSnlAXbVYCiEt58xFzFmmuCDSQ0dg,555
130
131
  truss/tests/test_build.py,sha256=Wq4sM9tmZFVTCN3YljvOcn04Kkj-L-41Tlw0DKQ8Z7c,709
131
132
  truss/tests/test_config.py,sha256=AVpVCL_XHYXKSGHzwecrh7BAJlB_Wr5AUlnnwMqWM98,30559
132
133
  truss/tests/test_context_builder_image.py,sha256=fVZNJSzZNiWa7Dr1X_VhhMJtyJ5HzsLaPglOr6NV2CA,1105
133
- truss/tests/test_control_truss_patching.py,sha256=lbMuAjLbkeDRLxUxXHWr41BZyhZKHQYoMnbJSj3dqrc,15390
134
+ truss/tests/test_control_truss_patching.py,sha256=geBSW8g-Em9FH2T5hsmBkc_Hr5DWPJ8ye2GmstKKiQ0,14499
134
135
  truss/tests/test_custom_server.py,sha256=GP2qMgnqxJMPRtfEciqbhBcG0_JUK7gNL7nrXPGrSLg,1305
135
136
  truss/tests/test_docker.py,sha256=3RI6jEC9CVQsKj83s_gOBl3EkdOaov-KEX4IihfMJW4,523
136
137
  truss/tests/test_model_inference.py,sha256=9QfPMa1kjxvKCWg5XKocjwcpfDkKB7pWd8bn4hIkshk,76213
@@ -141,14 +142,15 @@ truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGl
141
142
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
142
143
  truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
143
144
  truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
144
- truss/tests/cli/train/test_deploy_checkpoints.py,sha256=wQZ3DPLPAyXE3iaQiyHJTBO15v_gXN44eDk1StYkKmM,44764
145
+ truss/tests/cli/train/test_deploy_checkpoints.py,sha256=lDk88uAUPYatJ30JKVVtJDdXv_zWNk1nxXFyUH6IVGw,44800
145
146
  truss/tests/cli/train/test_train_cli_core.py,sha256=vzYfxKdwoa3NaFMrVZbSg5qOoLXivMvZXN1ClQirGTQ,16148
147
+ truss/tests/cli/train/test_train_init.py,sha256=pv8BfyLlVG0QtdowTziITjKa_OE1KigatmAGx8XSZrM,17238
146
148
  truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
147
149
  truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=16niCXZnuxFHXYQw2vPFZ8svSZafkH5DT0Gx3Z9Xdd8,22377
148
150
  truss/tests/contexts/local_loader/test_load_local.py,sha256=D1qMH2IpYA2j5009v50QMgUnKdeOsX15ndkwXe10a4E,801
149
151
  truss/tests/contexts/local_loader/test_truss_module_finder.py,sha256=oN1K2lg3ATHY5yOVUTfQIaSqusTF9I2wFaYaTSo5-O4,5342
150
152
  truss/tests/local/test_local_config_handler.py,sha256=aLvcOyfppskA2MziVLy_kMcagjxMpO4mjar9zxUN6g0,2245
151
- truss/tests/patch/test_calc_patch.py,sha256=GBPvIwIQ12fgYwqiHn7BUzAO5693-AX5R4upwGqfdB8,31960
153
+ truss/tests/patch/test_calc_patch.py,sha256=avV5-OpJK5rL811d0ERx-Mv9HERkpP4vxecxqh8hM7A,31782
152
154
  truss/tests/patch/test_dir_signature.py,sha256=HnG9Cyqt86YagYkY-jurSf36yYP2oM7PQvfb_d5T2mY,1033
153
155
  truss/tests/patch/test_hash.py,sha256=VsGAllNP653rmyrvPYBRY1gEc0gTpLl38tAhjXFUGGM,5997
154
156
  truss/tests/patch/test_signature.py,sha256=vdAy5dbIqTEWLZVpO6szTGdNTRZgE8PtABGuhPP0Y6s,728
@@ -162,7 +164,7 @@ truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZ
162
164
  truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
163
165
  truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
164
166
  truss/tests/remote/baseten/test_service.py,sha256=ufZbtQlBNIzFCxRt_iE-APLpWbVw_3ViUpSh6H9W5nU,1945
165
- truss/tests/templates/control/control/test_endpoints.py,sha256=tGU3w8zOKC8LfWGdhp-TlV7E603KXg2xGwpqDdf8Pnw,3385
167
+ truss/tests/templates/control/control/test_endpoints.py,sha256=fxTiiCR0ltaHCL_-v-22Ie1qVgnch1lqcj3w0U3R-fk,3644
166
168
  truss/tests/templates/control/control/test_server.py,sha256=r1O3VEK9eoIL2-cg8nYLXYct_H3jf5rGp1wLT1KBdeA,9488
167
169
  truss/tests/templates/control/control/test_server_integration.py,sha256=EdDY3nLzjrRCJ5LI5yZsNCEImSRkxTL7Rn9mGnK67zA,11837
168
170
  truss/tests/templates/control/control/helpers/test_context_managers.py,sha256=3LoonRaKu_UvhaWs1eNmEQCZq-iJ3aIjI0Mn4amC8Bw,283
@@ -323,7 +325,7 @@ truss/truss_handle/readme_generator.py,sha256=B4XbGwUjzMNOr71DWNAL8kCu5_ZHq7YOM8
323
325
  truss/truss_handle/truss_gatherer.py,sha256=Xysl_UnCVhehPfZeHa8p7WFp94ENqh-VVpbuqnCui3A,2870
324
326
  truss/truss_handle/truss_handle.py,sha256=WF2MQSly9DQ1SoAvqfi87Ulu4llTadpXoncsDjpL79E,40886
325
327
  truss/truss_handle/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
326
- truss/truss_handle/patch/calc_patch.py,sha256=Qyk1QmacK4jy9Ia8-93L8VtAWJhw15z22DdZUkBKlys,18334
328
+ truss/truss_handle/patch/calc_patch.py,sha256=zaM30WExGxKsZBGiBjevDs583jwk5QSyO-uxH0PogX4,17936
327
329
  truss/truss_handle/patch/constants.py,sha256=pCEi5Pwi8Rnqthrr3VEsWL9EP1P1VV1T8DEYuitHLmc,139
328
330
  truss/truss_handle/patch/custom_types.py,sha256=QklzhgLD_PpvNvNYQCvujAd16eYEaDGfLA1scxk6zsA,3481
329
331
  truss/truss_handle/patch/dir_signature.py,sha256=UCdZCzXkI-l-ae0I0pdmB2bavB9qzhhOKYXyLnDFQZY,921
@@ -346,27 +348,28 @@ truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
346
348
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
347
349
  truss_chains/private_types.py,sha256=6CaQEPawFLXjEbJ-01lqfexJtUIekF_q61LNENWegFo,8917
348
350
  truss_chains/public_api.py,sha256=0AXV6UdZIFAMycUNG_klgo4aLFmBZeKGfrulZEWzR0M,9532
349
- truss_chains/public_types.py,sha256=q8Oet6MpECW1FhWW25SCExpZhmk4cFmEsqrO30oZIMw,29112
351
+ truss_chains/public_types.py,sha256=Am1vc5pWuEDs65UQp8Be4iOU05kMXjPG3QrJfNXDmHs,29225
350
352
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
351
353
  truss_chains/pydantic_numpy.py,sha256=MG8Ji_Inwo_JSfM2n7TPj8B-nbrBlDYsY3SOeBwD8fE,4289
352
354
  truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,12521
353
355
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
354
356
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
355
- truss_chains/deployment/code_gen.py,sha256=AmAUZ3h1hP3uYkl3J6o096K5RFLuBOP7kOFSnFC_C4U,32568
356
- truss_chains/deployment/deployment_client.py,sha256=haFiVmQek42ewlN_YflBaRDQT4ZYbmT20tvvJOkcUX0,32899
357
+ truss_chains/deployment/code_gen.py,sha256=IBOYdhsWUyW_sBVhlEQAhvwxKcsGflDjgmR-1HyJJLg,32666
358
+ truss_chains/deployment/deployment_client.py,sha256=2paNyBjrpFTxROP0YrmJMUlH6o8mrkF-iPms7VJhLdA,34017
357
359
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
358
360
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
359
361
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
360
362
  truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e68i2Z-fFlyV3sz0qH8,2376
361
363
  truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
362
- truss_chains/remote_chainlet/utils.py,sha256=O_5P-VAUvg0cegEW1uKCOf5EBwD8rEGYVoGMivOmc7k,22374
363
- truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
364
- truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
364
+ truss_chains/remote_chainlet/utils.py,sha256=Zn3GZRvK8f65WUa-qa-8uPFZ2pD7ukRFxbLOvT-BL0Q,24063
365
+ truss_train/__init__.py,sha256=A3MzRPMInZfmzLvPpZI7gdKgshAVCw6bwhU-6JYU2zs,939
366
+ truss_train/definitions.py,sha256=3wVxkxMtHlcc-hb2umtj74FjA9TjenfiPTX7EQSh6zw,8245
365
367
  truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
366
368
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
367
369
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
368
- truss-0.11.2rc503.dist-info/METADATA,sha256=ibPdGGEmpNg-1bG8GcmjRw0UR1KB6Dak-4uHszZ9_qo,6674
369
- truss-0.11.2rc503.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
370
- truss-0.11.2rc503.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
371
- truss-0.11.2rc503.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
372
- truss-0.11.2rc503.dist-info/RECORD,,
370
+ truss_train/restore_from_checkpoint.py,sha256=KmJuTUVpvtvlkEClcmllxAF2TKgbp-FuzfblfGh06XA,1239
371
+ truss-0.11.2rc505.dist-info/METADATA,sha256=xktjK8t_rAM2j4EYCBYXAPMKlVFZuowom6o6uE06q_Y,6674
372
+ truss-0.11.2rc505.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
373
+ truss-0.11.2rc505.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
374
+ truss-0.11.2rc505.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
375
+ truss-0.11.2rc505.dist-info/RECORD,,
@@ -544,7 +544,11 @@ async def websocket(self, websocket: fastapi.WebSocket) -> None:
544
544
  )"""
545
545
  return _Source(
546
546
  src=src,
547
- imports={"import fastapi", "from truss_chains.remote_chainlet import utils"},
547
+ imports={
548
+ "import fastapi",
549
+ "from starlette.websockets import WebSocketState",
550
+ "from truss_chains.remote_chainlet import utils",
551
+ },
548
552
  )
549
553
 
550
554
 
@@ -138,7 +138,12 @@ class ChainService(abc.ABC):
138
138
 
139
139
  def _generate_chainlet_artifacts(
140
140
  options: private_types.PushOptions, entrypoint: Type[private_types.ABCChainlet]
141
- ) -> tuple[b10_types.ChainletArtifact, list[b10_types.ChainletArtifact], bool]:
141
+ ) -> tuple[
142
+ b10_types.ChainletArtifact,
143
+ list[b10_types.ChainletArtifact],
144
+ bool,
145
+ Optional[private_types.ChainletAPIDescriptor],
146
+ ]:
142
147
  chain_root = _get_chain_root(entrypoint)
143
148
  entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None
144
149
  dependency_artifacts: list[b10_types.ChainletArtifact] = []
@@ -192,7 +197,19 @@ def _generate_chainlet_artifacts(
192
197
 
193
198
  assert entrypoint_artifact is not None
194
199
 
195
- return entrypoint_artifact, dependency_artifacts, has_engine_builder_chainlets
200
+ # Find the entrypoint descriptor
201
+ entrypoint_descriptor = None
202
+ for chainlet_descriptor in _get_ordered_dependencies([entrypoint]):
203
+ if chainlet_descriptor.chainlet_cls == entrypoint:
204
+ entrypoint_descriptor = chainlet_descriptor
205
+ break
206
+
207
+ return (
208
+ entrypoint_artifact,
209
+ dependency_artifacts,
210
+ has_engine_builder_chainlets,
211
+ entrypoint_descriptor,
212
+ )
196
213
 
197
214
 
198
215
  @framework.raise_validation_errors_before
@@ -201,9 +218,12 @@ def push(
201
218
  options: private_types.PushOptions,
202
219
  progress_bar: Optional[Type["progress.Progress"]] = None,
203
220
  ) -> Optional[ChainService]:
204
- entrypoint_artifact, dependency_artifacts, has_engine_builder_chainlets = (
205
- _generate_chainlet_artifacts(options, entrypoint)
206
- )
221
+ (
222
+ entrypoint_artifact,
223
+ dependency_artifacts,
224
+ has_engine_builder_chainlets,
225
+ entrypoint_descriptor,
226
+ ) = _generate_chainlet_artifacts(options, entrypoint)
207
227
  if options.only_generate_trusses:
208
228
  return None
209
229
  if isinstance(options, private_types.PushOptionsBaseten):
@@ -213,7 +233,11 @@ def push(
213
233
  "not supportd, push with `--publish`."
214
234
  )
215
235
  return _create_baseten_chain(
216
- options, entrypoint_artifact, dependency_artifacts, progress_bar
236
+ options,
237
+ entrypoint_artifact,
238
+ dependency_artifacts,
239
+ progress_bar,
240
+ entrypoint_descriptor,
217
241
  )
218
242
  elif isinstance(options, private_types.PushOptionsLocalDocker):
219
243
  if has_engine_builder_chainlets:
@@ -369,16 +393,19 @@ def _create_docker_chain(
369
393
  class BasetenChainService(ChainService):
370
394
  _chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic
371
395
  _remote: b10_remote.BasetenRemote
396
+ _entrypoint_descriptor: Optional[private_types.ChainletAPIDescriptor]
372
397
 
373
398
  def __init__(
374
399
  self,
375
400
  name: str,
376
401
  chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic,
377
402
  remote: b10_remote.BasetenRemote,
403
+ entrypoint_descriptor: Optional[private_types.ChainletAPIDescriptor] = None,
378
404
  ) -> None:
379
405
  super().__init__(name)
380
406
  self._chain_deployment_handle = chain_deployment_handle
381
407
  self._remote = remote
408
+ self._entrypoint_descriptor = entrypoint_descriptor
382
409
 
383
410
  @property
384
411
  def run_remote_url(self) -> str:
@@ -393,6 +420,13 @@ class BasetenChainService(ChainService):
393
420
  is_draft=handle.is_draft,
394
421
  )
395
422
 
423
+ @property
424
+ def is_websocket(self) -> bool:
425
+ """Check if the entrypoint uses websockets."""
426
+ if self._entrypoint_descriptor is None:
427
+ return False
428
+ return self._entrypoint_descriptor.endpoint.is_websocket
429
+
396
430
  def run_remote(self, json_data: Dict) -> Any:
397
431
  """Invokes the entrypoint with JSON data.
398
432
 
@@ -462,6 +496,7 @@ def _create_baseten_chain(
462
496
  entrypoint_artifact: b10_types.ChainletArtifact,
463
497
  dependency_artifacts: list[b10_types.ChainletArtifact],
464
498
  progress_bar: Optional[Type["progress.Progress"]],
499
+ entrypoint_descriptor: Optional[private_types.ChainletAPIDescriptor] = None,
465
500
  ):
466
501
  logging.info(
467
502
  f"Pushing Chain '{baseten_options.chain_name}' to Baseten "
@@ -491,7 +526,10 @@ def _create_baseten_chain(
491
526
  progress_bar=progress_bar,
492
527
  )
493
528
  return BasetenChainService(
494
- baseten_options.chain_name, chain_deployment_handle, remote_provider
529
+ baseten_options.chain_name,
530
+ chain_deployment_handle,
531
+ remote_provider,
532
+ entrypoint_descriptor,
495
533
  )
496
534
 
497
535
 
@@ -473,6 +473,7 @@ class WebSocketProtocol(Protocol):
473
473
 
474
474
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
475
475
 
476
+ async def receive(self) -> Union[str, bytes]: ...
476
477
  async def receive_text(self) -> str: ...
477
478
  async def receive_bytes(self) -> bytes: ...
478
479
  async def receive_json(self) -> Any: ...
@@ -481,9 +482,11 @@ class WebSocketProtocol(Protocol):
481
482
  async def send_bytes(self, data: bytes) -> None: ...
482
483
  async def send_json(self, data: Any) -> None: ...
483
484
 
484
- def iter_text(self) -> AsyncIterator[str]: ...
485
- def iter_bytes(self) -> AsyncIterator[bytes]: ...
486
- def iter_json(self) -> AsyncIterator[Any]: ...
485
+ async def iter_text(self) -> AsyncIterator[str]: ...
486
+ async def iter_bytes(self) -> AsyncIterator[bytes]: ...
487
+ async def iter_json(self) -> AsyncIterator[Any]: ...
488
+
489
+ def is_connected(self) -> bool: ...
487
490
 
488
491
 
489
492
  class EngineBuilderLLMInput(pydantic.BaseModel):
@@ -11,6 +11,7 @@ import textwrap
11
11
  import threading
12
12
  import time
13
13
  import traceback
14
+ import typing
14
15
  from collections.abc import AsyncIterator
15
16
  from typing import (
16
17
  TYPE_CHECKING,
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
38
39
  try:
39
40
  import prometheus_client
40
41
  except ImportError:
41
- logging.warning("Optional `prometheus_client` is not installed. ")
42
+ logging.warning("Optional `prometheus_client` is not installed.")
42
43
 
43
44
  class _NoOpMetric:
44
45
  def labels(self, *args: object, **kwargs: object) -> "_NoOpMetric":
@@ -64,6 +65,26 @@ except ImportError:
64
65
  return _NoOpMetric()
65
66
 
66
67
 
68
+ try:
69
+ from fastapi import WebSocketDisconnect
70
+ except ImportError:
71
+ # NB(nikhil): Stub implementation of WebSocketDisconnect, in case local environment doesn't have
72
+ # fastapi.
73
+ class WebSocketDisconnect(Exception): # type: ignore[no-redef]
74
+ def __init__(self, code: int, reason: Optional[str] = None):
75
+ super().__init__()
76
+ self.code = code
77
+ self.reason = reason
78
+
79
+
80
+ try:
81
+ from starlette.websockets import WebSocketState
82
+ except ImportError:
83
+ # NB(nikhil): Stub implementation of WebSocketState, in case local environment doesn't have starlette.
84
+ class WebSocketState: # type: ignore[no-redef]
85
+ CONNECTED = "connected"
86
+
87
+
67
88
  T = TypeVar("T")
68
89
 
69
90
  _LockT = TypeVar("_LockT", bound=Union[threading.Lock, asyncio.Lock])
@@ -586,6 +607,18 @@ class WebsocketWrapperFastAPI:
586
607
  async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
587
608
  await self._websocket.close(code=code, reason=reason)
588
609
 
610
async def receive(self) -> Union[str, bytes]:
    """Receive the next text or binary message from the client.

    Returns:
        The payload: ``str`` for a text message, ``bytes`` for a binary message.

    Raises:
        WebSocketDisconnect: If the underlying receive reports a disconnect.
    """
    message = await self._websocket.receive()

    if message.get("type") == "websocket.disconnect":
        # NB(nikhil): Mimics FastAPI `_raise_on_disconnect`, since otherwise the user has no
        # way of detecting that the client disconnected.
        raise WebSocketDisconnect(message["code"], message.get("reason"))

    # Test for presence, not truthiness: an empty text frame ("") is a valid text
    # message and must not fall through to the bytes branch, where "bytes" may be
    # absent or None.
    text = message.get("text")
    if text is not None:
        return typing.cast(str, text)
    return typing.cast(bytes, message["bytes"])
621
+
589
622
  async def receive_text(self) -> str:
590
623
  return await self._websocket.receive_text()
591
624
 
@@ -605,13 +638,19 @@ class WebsocketWrapperFastAPI:
605
638
  await self._websocket.send_json(data)
606
639
 
607
640
  async def iter_text(self) -> AsyncIterator[str]:
608
- while True:
609
- yield await self.receive_text()
641
+ return self._websocket.iter_text()
610
642
 
611
643
  async def iter_bytes(self) -> AsyncIterator[bytes]:
612
- while True:
613
- yield await self.receive_bytes()
644
+ return self._websocket.iter_bytes()
614
645
 
615
646
  async def iter_json(self) -> AsyncIterator[Any]:
616
- while True:
617
- yield await self.receive_json()
647
+ return self._websocket.iter_json()
648
+
649
def is_connected(self) -> bool:
    """Best-effort check that both ends of the websocket are still connected.

    Synchronous to match the protocol stub ``def is_connected(self) -> bool``. As a
    coroutine, ``if ws.is_connected():`` would always be truthy (a coroutine object)
    and would never be awaited. Nothing here needs to await.

    Returns:
        True if both the application state and the client state are CONNECTED.
    """
    # NB(nikhil): This isn't a foolproof mechanism for detecting whether a websocket
    # connection is actually alive, ping/pong messages are best suited for that. However,
    # as a heuristic to determine if a message is safe to send, this can do a pretty good job.
    return (
        self._websocket.application_state == WebSocketState.CONNECTED
        and self._websocket.client_state == WebSocketState.CONNECTED
    )
truss_train/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from truss_train.definitions import (
2
2
  AWSIAMDockerAuth,
3
+ BasetenCheckpoint,
3
4
  CacheConfig,
4
5
  CheckpointingConfig,
5
6
  CheckpointList,
@@ -10,6 +11,7 @@ from truss_train.definitions import (
10
11
  FullCheckpoint,
11
12
  GCPServiceAccountJSONDockerAuth,
12
13
  Image,
14
+ LoadCheckpointConfig,
13
15
  LoRACheckpoint,
14
16
  LoRADetails,
15
17
  ModelWeightsFormat,
@@ -36,6 +38,8 @@ __all__ = [
36
38
  "CacheConfig",
37
39
  "AWSIAMDockerAuth",
38
40
  "GCPServiceAccountJSONDockerAuth",
41
+ "LoadCheckpointConfig",
42
+ "BasetenCheckpoint",
39
43
  "DockerAuth",
40
44
  "Image",
41
45
  ]
@@ -1,11 +1,11 @@
1
1
  import enum
2
2
  from abc import ABC
3
- from typing import Dict, List, Optional, Union
3
+ from typing import Dict, List, Literal, Optional, Union
4
4
 
5
5
  import pydantic
6
6
  from pydantic import field_validator, model_validator
7
7
 
8
- from truss.base import custom_types, truss_config
8
+ from truss.base import constants, custom_types, truss_config
9
9
 
10
10
  # NOTE(review): presumably the LoRA rank used when none is given (cf. LoRADetails) — confirm.
  DEFAULT_LORA_RANK = 16
11
11
 
@@ -56,6 +56,50 @@ class Compute(custom_types.SafeModelNoExtra):
56
56
  )
57
57
 
58
58
 
59
class _CheckpointBase(custom_types.SafeModelNoExtra):
    """Shared base for checkpoint references.

    Subclasses narrow ``typ`` to a ``Literal`` tag identifying the concrete kind.
    """

    # Discriminating tag; each subclass pins it to a fixed literal value.
    typ: str
61
+
62
+
63
class _BasetenLatestCheckpoint(_CheckpointBase):
    """Reference to the latest Baseten checkpoint.

    ``job_id`` and ``project_name`` are both optional and default to ``None``.
    """

    # Optional identifiers; left as None when not supplied.
    job_id: Optional[str] = None
    project_name: Optional[str] = None
    # Fixed tag distinguishing this kind from named checkpoints.
    typ: Literal["baseten_latest_checkpoint"] = "baseten_latest_checkpoint"
67
+
68
+
69
class _BasetenNamedCheckpoint(_CheckpointBase):
    """Reference to a Baseten checkpoint identified by ``checkpoint_name``.

    ``job_id`` and ``project_name`` default to ``None``. Under pydantic v2, an
    ``Optional[...]`` field without a default is *required*, which would make
    configs that omit them fail validation. It would also be inconsistent with
    ``_BasetenLatestCheckpoint`` and with the optional arguments of
    ``BasetenCheckpoint.from_named_checkpoint``.
    """

    checkpoint_name: str
    job_id: Optional[str] = None
    project_name: Optional[str] = None
    # Fixed tag distinguishing this kind from latest-checkpoint references.
    typ: Literal["baseten_named_checkpoint"] = "baseten_named_checkpoint"
74
+
75
+
76
class BasetenCheckpoint:
    """Factory helpers for building Baseten checkpoint references.

    Both constructors are ``@staticmethod``s; neither uses class state. In each,
    ``project_name`` and ``job_id`` are optional and passed through unchanged.
    """

    @staticmethod
    def from_latest_checkpoint(
        project_name: Optional[str] = None, job_id: Optional[str] = None
    ) -> _BasetenLatestCheckpoint:
        """Build a reference to the latest checkpoint.

        Args:
            project_name: Optional project name, forwarded as-is.
            job_id: Optional job id, forwarded as-is.
        """
        return _BasetenLatestCheckpoint(project_name=project_name, job_id=job_id)

    @staticmethod
    def from_named_checkpoint(
        checkpoint_name: str,
        project_name: Optional[str] = None,
        job_id: Optional[str] = None,
    ) -> _BasetenNamedCheckpoint:
        """Build a reference to the checkpoint named ``checkpoint_name``.

        Args:
            checkpoint_name: Name of the checkpoint to reference.
            project_name: Optional project name, forwarded as-is.
            job_id: Optional job id, forwarded as-is.
        """
        return _BasetenNamedCheckpoint(
            checkpoint_name=checkpoint_name, project_name=project_name, job_id=job_id
        )
93
+
94
+
95
class LoadCheckpointConfig(custom_types.SafeModelNoExtra):
    """Settings for loading existing checkpoints into a training job."""

    # Master switch; disabled unless explicitly turned on.
    enabled: bool = False
    # Checkpoint references to load; defaults to a single latest-checkpoint reference.
    checkpoints: List[
        Union[_BasetenLatestCheckpoint, _BasetenNamedCheckpoint]
    ] = [_BasetenLatestCheckpoint()]
    # Destination directory for downloaded checkpoints.
    download_folder: str = constants.DEFAULT_TRAINING_CHECKPOINT_FOLDER
101
+
102
+
59
103
  class CheckpointingConfig(custom_types.SafeModelNoExtra):
60
104
  enabled: bool = False
61
105
  checkpoint_path: Optional[str] = None
@@ -72,6 +116,7 @@ class Runtime(custom_types.SafeModelNoExtra):
72
116
  environment_variables: Dict[str, Union[str, SecretReference]] = {}
73
117
  enable_cache: Optional[bool] = None
74
118
  checkpointing_config: CheckpointingConfig = CheckpointingConfig()
119
+ load_checkpoint_config: Optional[LoadCheckpointConfig] = None
75
120
  cache_config: Optional[CacheConfig] = None
76
121
 
77
122
  @model_validator(mode="before")
@@ -0,0 +1,42 @@
1
+ from truss_train import (
2
+ BasetenCheckpoint,
3
+ CheckpointingConfig,
4
+ Image,
5
+ LoadCheckpointConfig,
6
+ Runtime,
7
+ TrainingJob,
8
+ TrainingProject,
9
+ )
10
+
11
# Minimal form: `LoadCheckpointConfig(enabled=True)` on its own loads the latest
# checkpoint, because `checkpoints` defaults to a single latest-checkpoint reference.
# The fuller configuration below is the one this job actually uses.

load_from_most_recent_checkpoint = BasetenCheckpoint.from_latest_checkpoint()

# Variant: the latest checkpoint with explicit job/project identifiers.
# Shown for illustration only; it is not included in `load_checkpoint_config` below.
load_most_recent_checkpoint = BasetenCheckpoint.from_latest_checkpoint(
    job_id="lqz4pw4",  # Optional
    project_name="first-project",  # Optional
)

load_from_named_checkpoint = BasetenCheckpoint.from_named_checkpoint(
    checkpoint_name="checkpoint-24",
    project_name="first-project",  # Optional
    job_id="lqz4pw4",  # Optional
)

load_checkpoint_config = LoadCheckpointConfig(
    enabled=True,
    # Optional; when omitted it defaults to constants.DEFAULT_TRAINING_CHECKPOINT_FOLDER.
    download_folder="/tmp/custom_location",
    checkpoints=[load_from_most_recent_checkpoint, load_from_named_checkpoint],
)

checkpointing_config = CheckpointingConfig(enabled=True)

job = TrainingJob(
    image=Image(base_image="ghcr.io/baseten-ai/truss-train-base:latest"),
    runtime=Runtime(
        checkpointing_config=checkpointing_config,
        load_checkpoint_config=load_checkpoint_config,
    ),
)

project = TrainingProject(name="new-project", job=job)