truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,14 @@
1
+ environment_variables: {}
2
+ external_package_dirs: []
3
+ model_metadata: {}
4
+ model_name: Test Build Commands
5
+ python_version: py39
6
+ resources:
7
+ accelerator: null
8
+ cpu: '1'
9
+ memory: 2Gi
10
+ use_gpu: false
11
+ secrets: {}
12
+ system_packages: []
13
+ build_commands:
14
+ - sed -i 's/TEST_FIRST_VALUE/TEST_SECOND_VALUE/g' /packages/constants/constants.py
@@ -0,0 +1,12 @@
1
+ from constants.constants import TEST_KEY
2
+
3
+
4
+ class Model:
5
+ def __init__(self, **kwargs):
6
+ pass
7
+
8
+ def load(self):
9
+ pass
10
+
11
+ def predict(self, input) -> str:
12
+ return TEST_KEY
@@ -0,0 +1 @@
1
+ TEST_KEY = "TEST_FIRST_VALUE"
@@ -8,6 +8,7 @@ requirements:
8
8
  - torch
9
9
  model_cache:
10
10
  - repo_id: julien-c/EsperBERTo-small
11
+ use_volume: false
11
12
  ignore_patterns:
12
13
  - "*.bin"
13
14
  - "*.msgpack"
@@ -2099,3 +2099,16 @@ async def test_websocket_ping_timeout_behavior(caplog):
2099
2099
 
2100
2100
  # We wait 3 seconds, so there should be ~3 PING/PONGS
2101
2101
  assert 2 <= caplog.text.count("PING") <= 4
2102
+
2103
+
2104
+ @pytest.mark.integration
2105
+ def test_build_commands_on_model_files(test_data_path):
2106
+ with ensure_kill_all():
2107
+ truss_dir = test_data_path / "test_build_commands_truss"
2108
+ tr = TrussHandle(truss_dir)
2109
+ container, urls = tr.docker_run_for_test()
2110
+ time.sleep(3) # Sleeping to allow the load to finish
2111
+
2112
+ response = requests.post(urls.predict_url, json={})
2113
+ assert response.status_code == 200
2114
+ assert response.json() == "TEST_SECOND_VALUE"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.18rc500
3
+ Version: 0.11.24rc2
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -1,23 +1,27 @@
1
1
  truss/__init__.py,sha256=CoUcP6vx_pocyemRmpbCPlndkHhdMkABAlr0ZXVuPCk,1163
2
- truss/api/__init__.py,sha256=5GTE2rlupet-beaawUr0FPyDPEJ9UyBTUpJmCE3RGfc,5453
2
+ truss/api/__init__.py,sha256=6ZDLaMn6MI5kqmRZaHwgZf2jOQyxQS3H45-Zy9qeSvg,5634
3
3
  truss/api/definitions.py,sha256=QAaIBqL59Q-R7HtLcXcoeCIWBN2HqOzApdFX0PpCq2s,1604
4
4
  truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  truss/base/constants.py,sha256=yFvDVyWMqeAsWOfKUKJVUNbH26N2JZKDqgxmRSe4EWk,3664
6
6
  truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
7
7
  truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
9
- truss/base/truss_config.py,sha256=s39Xc1e20s8IV07YLl_aVnp-uRS18ZQ2TV-3FILx4nY,28416
9
+ truss/base/truss_config.py,sha256=woY42bYYzA3-hWjzHqvQxpJG9hsXPvHZ7LNNAjiLkAE,28675
10
10
  truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
- truss/cli/chains_commands.py,sha256=QijtACpuAt2O1RV_qhTNPw0jcFg-u0dX9PP-ct0t-rs,17716
12
- truss/cli/cli.py,sha256=VGOw1ell7h9bna64UmopavCpVPdjDerSaGPDoizIsRI,30313
13
- truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
- truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
11
+ truss/cli/chains_commands.py,sha256=7meWaNVYIPmNkHRhRmW0of_vGa-McFag1XIgnUnSYVQ,18982
12
+ truss/cli/cli.py,sha256=iEO6JOY3yu4OtwzdMhmrxr72AboenLujWMm5WG6vKk4,31198
13
+ truss/cli/remote_cli.py,sha256=UI0YptsLD7WBEdi37fUoJIgCHGjrSxkiNFYS61Wlqdk,2730
14
+ truss/cli/train_commands.py,sha256=rL-wqqEl7CbvRKi3YkhmoFQO_59STVehlsOIC2dvmxY,22102
15
15
  truss/cli/logs/base_watcher.py,sha256=vuqteoaMVGX34cgKcETf4X_gOkvnSnDaWz1_pbeFhqs,3343
16
16
  truss/cli/logs/model_log_watcher.py,sha256=38vQCcNItfDrTKucvdJ10ZYLOcbGa5ZAKUqUnV4nH34,1971
17
17
  truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
18
18
  truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
19
+ truss/cli/resolvers/chain_team_resolver.py,sha256=9luCeXu4YpGn9K9snTqqB0GwzSvdUGOjA7CkXG1s-ug,3872
20
+ truss/cli/resolvers/model_team_resolver.py,sha256=D4QML3HK9T07XTCPiZ8JoYcqtYuXTPCVpwRPIni6kwA,4283
21
+ truss/cli/resolvers/training_project_team_resolver.py,sha256=UKwmuWL7OfMTbEtmGcTCmamGG_aNLLi2PzBZziBBJbo,3881
22
+ truss/cli/train/cache.py,sha256=S2f2x0ObdjUAyjLhzp3CPQCE5x2FbtSc5Gm_iTjlbZc,11519
19
23
  truss/cli/train/common.py,sha256=xTR41U5FeSndXfNBBHF9wF5XwZH1sOIVFlv-XHjsKIU,1547
20
- truss/cli/train/core.py,sha256=fWuHvjIT4tkax19B7_1_SWvkX1ot2xQ6WwcDGBhTnus,26520
24
+ truss/cli/train/core.py,sha256=ahyjfjkCQSH9Eqb_6pcrRM0Td52gsUcnnTkrAsTDpLE,22336
21
25
  truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
22
26
  truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
23
27
  truss/cli/train/metrics_watcher.py,sha256=smz-zrEsBj_-wJHI0pAZ-EAPrvfCWzq1eQjGiFNM-Mk,12755
@@ -29,7 +33,7 @@ truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=r_IKMlqe
29
33
  truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=f8_UB7CF6Y3MOhaf8Zim0heNiauOOAmA-WqsyP3X9mk,386
30
34
  truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=YKVMy3xogmbubJNrN_1LCR6xdHj9lBOAlKgMxWHdlQM,1115
31
35
  truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py,sha256=eEo4ahRTsvRKOxCDj7DAHIUyUZWicn39VmC1PYS0pCY,314
32
- truss/cli/utils/common.py,sha256=ink9ZE0MsOv6PCFK_Ra5k1aHm281TXTnMpnLjf2PtUM,6585
36
+ truss/cli/utils/common.py,sha256=keB8t-IVq-Wgel9laOZ_Ag4p0KxqLEBY2MKRwWncwUw,7869
33
37
  truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
34
38
  truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
35
39
  truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
@@ -52,12 +56,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
52
56
  truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
53
57
  truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
54
58
  truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
55
- truss/remote/baseten/api.py,sha256=54Cl_2zHpRU4g2VXzK-BYlxPJeHHImceFrbxD9AASXo,30335
59
+ truss/remote/baseten/api.py,sha256=StxIk5k-88DYEd0YU7yKxuMXpcODSdgbqD7VjwRK--Q,32430
56
60
  truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
57
- truss/remote/baseten/core.py,sha256=FC5-87Vs2f0NR8eddtSRvr3Z5W2rF7mpiq9jCPrbzr4,23399
61
+ truss/remote/baseten/core.py,sha256=U2JwM-HFTf_KmVRnHeuGJU5g4bvi3auNfmTjkHlw48I,24210
58
62
  truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
59
63
  truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
60
- truss/remote/baseten/remote.py,sha256=aKG1BODtrnmuRV-M8T3F3pw8oHawGwI09caKANJ19BM,23420
64
+ truss/remote/baseten/remote.py,sha256=7xCm_zmfiWfnZLbJi4Txxq7xezN8B235a1tO0rVeTMQ,24372
61
65
  truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
62
66
  truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
63
67
  truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
@@ -72,7 +76,7 @@ truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj
72
76
  truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
73
77
  truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
74
78
  truss/templates/no_build.Dockerfile.jinja,sha256=8x2PJUxr_gHai0St8ue2aWyih36t8kBytXMGr_5LG4w,35
75
- truss/templates/server.Dockerfile.jinja,sha256=Mu5_ZxuAknwaEOsF0l-XssA9pDg3pD3eLl6JBzNJ4rg,7091
79
+ truss/templates/server.Dockerfile.jinja,sha256=wApuWMfyXrDpPlIodm9bK-KNrsi7fVdNKm-iw7mi7bw,7075
76
80
  truss/templates/control/requirements.txt,sha256=tJGr83WoE0CZm2FrloZ9VScK84q-_FTuVXjDYrexhW0,250
77
81
  truss/templates/control/control/application.py,sha256=5Kam6M-XtfKGaXQz8cc3d0bwDkB80o2MskABWROx1gk,5321
78
82
  truss/templates/control/control/endpoints.py,sha256=KzqsLVNJE6r6TCPW8D5FMCtsfHadTwR15A3z_viGxmM,11782
@@ -81,7 +85,7 @@ truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5m
81
85
  truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
82
86
  truss/templates/control/control/helpers/errors.py,sha256=LddFuQywuCCdYTEnFT5EalxdWos4uR89rbhMakCy2bA,970
83
87
  truss/templates/control/control/helpers/inference_server_controller.py,sha256=anFm7FwkGaUnYRQo2dxXohA9__c-XAVAqfA1EL2bIIY,6324
84
- truss/templates/control/control/helpers/inference_server_process_controller.py,sha256=8jhAWsUjG2JQ3elJA6ldCPJ8deo0vOLoD3rI8eKwkvk,4426
88
+ truss/templates/control/control/helpers/inference_server_process_controller.py,sha256=RkkLA2_bKUl1iqwrwlH7gZPZUuUVFfJDy74EVWkRVOI,4579
85
89
  truss/templates/control/control/helpers/inference_server_starter.py,sha256=Fz2Puijro6Cc5cvTsAqOveNJbBQR_ARYJXl4lvETJ8Y,2633
86
90
  truss/templates/control/control/helpers/truss_patch/__init__.py,sha256=CXZdUV_ylqLTJrKuFpvSnUT6PUFrZrMF2y6jiHbdaKU,998
87
91
  truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py,sha256=LTIIADLz0wRn7V49j64dU1U7Hbta9YLde3pb5YZWvzQ,2001
@@ -97,7 +101,7 @@ truss/templates/docker_server/supervisord.conf.jinja,sha256=AliMMd6bNn-oCYIB8Gum
97
101
  truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
102
  truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
99
103
  truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
100
- truss/templates/server/requirements.txt,sha256=ZRnawwwAMkf88-S5GhXdjkLzB36IRg11AKUvIOV8kxg,672
104
+ truss/templates/server/requirements.txt,sha256=1EvdUuD9Fyy_xgo0a5WuAxMZaFRTZtozt79PB5KkGcI,672
101
105
  truss/templates/server/truss_server.py,sha256=YKcG7Sr0T_8XjIC3GK9vBwoNb8oxVgwic3-3Ikzpmgw,19781
102
106
  truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
103
107
  truss/templates/server/common/errors.py,sha256=My0P6-Y7imVTICIhazHT0vlSu3XJDH7As06OyVzu4Do,8589
@@ -112,7 +116,7 @@ truss/templates/shared/lazy_data_resolver.py,sha256=HxrZz6X30j2LbsExYSqhuOGoYEff
112
116
  truss/templates/shared/log_config.py,sha256=l9udyu4VKHZePlfK9LQEd5TOUUodPuehypsXRSUL4Ac,5411
113
117
  truss/templates/shared/secrets_resolver.py,sha256=3prDe3Q06NTmUEe7KCW-W4TD1CzGck9lpDG789209z4,2110
114
118
  truss/templates/shared/serialization.py,sha256=_WC_2PPkRi-MdTwxwjG8LKQptnHi4sANfpOlKWevqWc,3736
115
- truss/templates/shared/util.py,sha256=dPgFF4iL_YkeC6Kf8tZUHJH60rbpskHwVPh0ONLGaQM,2222
119
+ truss/templates/shared/util.py,sha256=NoIel4ES73fHzjY7tknxAFhZQGh11tqL45Ye6-RTh4k,2281
116
120
  truss/templates/train/config.py,sha256=aQJ3lsyVRlq6edjjZq4_Anz1bZVwkjLdclmZPJTdo1k,1626
117
121
  truss/templates/train/run.sh,sha256=2rimigJOn6yg4DguRfOJWkzm77X-meNSYXnidLafqNg,346
118
122
  truss/templates/trtllm-audio/model/model.py,sha256=o38QqW57b1lf8O_td1lW_AojZZ8R_qAZCgzOWtoIse8,1619
@@ -127,28 +131,32 @@ truss/templates/trtllm-audio/packages/whisper_trt/utils.py,sha256=pi-c486yPe85Te
127
131
  truss/templates/trtllm-briton/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
128
132
  truss/templates/trtllm-briton/src/extension.py,sha256=6qwYPIYQEmXd2xz50-v80Nilc_xLAMgdYkHu2JWboH4,2655
129
133
  truss/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
130
- truss/tests/conftest.py,sha256=LCP-knM4UsQYvGd2Vbv2hHnbPCuOELCqPui31DrtxyM,24221
134
+ truss/tests/conftest.py,sha256=0izxvPZPdlRZBIvxH0z0ISNcdq3Et3OQ1BLDwevrNpM,29928
131
135
  truss/tests/helpers.py,sha256=DKnUGzmt-k4IN_wSnlAXbVYCiEt58xFzFmmuCDSQ0dg,555
132
136
  truss/tests/test_build.py,sha256=Wq4sM9tmZFVTCN3YljvOcn04Kkj-L-41Tlw0DKQ8Z7c,709
133
- truss/tests/test_config.py,sha256=AVpVCL_XHYXKSGHzwecrh7BAJlB_Wr5AUlnnwMqWM98,30559
137
+ truss/tests/test_config.py,sha256=bxIZVELyX31y69lPRdv2On8NiLD1M8IyDtQCv78kN9c,30942
134
138
  truss/tests/test_context_builder_image.py,sha256=fVZNJSzZNiWa7Dr1X_VhhMJtyJ5HzsLaPglOr6NV2CA,1105
135
139
  truss/tests/test_control_truss_patching.py,sha256=qQOUfyF1MorZ_obOvPJK9utI4HUAzgT6YBS-fo90TEw,14497
136
140
  truss/tests/test_custom_server.py,sha256=GP2qMgnqxJMPRtfEciqbhBcG0_JUK7gNL7nrXPGrSLg,1305
137
141
  truss/tests/test_docker.py,sha256=3RI6jEC9CVQsKj83s_gOBl3EkdOaov-KEX4IihfMJW4,523
138
- truss/tests/test_model_inference.py,sha256=Q8mgNDNbwAUi7AQTgmyK-QrYuksuARDczYndTh56fKk,76205
142
+ truss/tests/test_model_inference.py,sha256=JbLmBb3N0oRr58LdVmxaNsMgK8aeSfjDyUfsabZPD9Q,76688
139
143
  truss/tests/test_model_schema.py,sha256=Bw28CZ4D0JQOkYdBQJZvgryeW0TRn7Axketp5kvZ_t4,14219
140
144
  truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX76RWiWGWM_UlC2IoT4NngMqE,3048
141
145
  truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
142
146
  truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
143
147
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
144
- truss/tests/cli/test_chains_cli.py,sha256=l9GTQrhRm9SRZn43WkMY4tdRslLmdsVyiydRPa1_Ja4,3162
145
- truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
146
- truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
148
+ truss/tests/cli/test_chains_cli.py,sha256=fmMYwA5XdA7lngyqnx4-3kKWxRB3vl094VjI44gVmzA,4582
149
+ truss/tests/cli/test_cli.py,sha256=kNQjwvUdsGhkDY6LbLT7X1AysfgVAvzRwF36swP-97E,5214
150
+ truss/tests/cli/test_cli_utils_common.py,sha256=X9eU5MngmQPs_R3hqf-t2TlGVyZUdxSwYfYv7nUC9R4,455
151
+ truss/tests/cli/test_model_team_resolver.py,sha256=Dsk0u5G66LaHhCpt-tWpxvKALXTINMzDP-mylN351Ak,10370
152
+ truss/tests/cli/chains/test_chains_team_parameter.py,sha256=psnb7batCG0jTUUhi_1dD4ji-XcqH0EKtEqoTvIsm3w,18816
153
+ truss/tests/cli/train/test_cache_view.py,sha256=912zvkbI6m3ffGVqUCgRcoB1cuaUPxvo3rSbmwQrO4E,30213
147
154
  truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
148
- truss/tests/cli/train/test_train_cli_core.py,sha256=vzYfxKdwoa3NaFMrVZbSg5qOoLXivMvZXN1ClQirGTQ,16148
155
+ truss/tests/cli/train/test_train_cli_core.py,sha256=Fabmz8bDPEElgUnCRxORyFVl9v8LzORCPrJgrDnFP10,16177
149
156
  truss/tests/cli/train/test_train_init.py,sha256=SRAZvvD5-PWYlpHHek2MftYTA4I3ZHi7gniHl2fYV98,17464
157
+ truss/tests/cli/train/test_train_team_parameter.py,sha256=kIth6iO7zd8MHqzhq2x3jFFkhxpYaY-V0bHUkaSvpHI,18748
150
158
  truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
151
- truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=16niCXZnuxFHXYQw2vPFZ8svSZafkH5DT0Gx3Z9Xdd8,22377
159
+ truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=okn0Dx7Pf976YIcRsMLRXqiIks81yAa9v1HgiDE9uv8,22512
152
160
  truss/tests/contexts/local_loader/test_load_local.py,sha256=D1qMH2IpYA2j5009v50QMgUnKdeOsX15ndkwXe10a4E,801
153
161
  truss/tests/contexts/local_loader/test_truss_module_finder.py,sha256=oN1K2lg3ATHY5yOVUTfQIaSqusTF9I2wFaYaTSo5-O4,5342
154
162
  truss/tests/local/test_local_config_handler.py,sha256=aLvcOyfppskA2MziVLy_kMcagjxMpO4mjar9zxUN6g0,2245
@@ -161,11 +169,11 @@ truss/tests/patch/test_types.py,sha256=OUVDiLckbjjjEN49I4hm62emOTAr8lv_QooJrmXxs
161
169
  truss/tests/remote/test_remote_factory.py,sha256=S-iZlF5Pf5SDoFUnMlZXy9iRMkosVgwLd22evzWlFr0,4842
162
170
  truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0aeC5EWRVkGw,2726
163
171
  truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
164
- truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23ULehMRJOQ,15885
172
+ truss/tests/remote/baseten/test_api.py,sha256=bRTq4r1KKvVtAXTwo-HOUvDu0mdhuy65_zqfGlT3UhM,20327
165
173
  truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
166
- truss/tests/remote/baseten/test_chain_upload.py,sha256=XaaF1ocovkBYsLMJ8EpXB9FUGfQZAwu4iyOWqoVn7tc,10886
167
- truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
168
- truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
174
+ truss/tests/remote/baseten/test_chain_upload.py,sha256=D_Uhgi8YfvYpF2vAcFFutCT-gDPWM4YobHdRs-L87n8,11409
175
+ truss/tests/remote/baseten/test_core.py,sha256=oi9dSmlkJWO88LwDTULquNlQbnEyxHnPYehm8ph01kg,29253
176
+ truss/tests/remote/baseten/test_remote.py,sha256=NHUDEU3UrSKYVqGda0Q2h9DhIfjVE8tscVde0dvSi5M,21476
169
177
  truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
170
178
  truss/tests/templates/control/control/conftest.py,sha256=euDFh0AhcHP-vAmTzi1Qj3lymnplDTgvtbt4Ez_lfpw,654
171
179
  truss/tests/templates/control/control/test_endpoints.py,sha256=HIlRYOicsdHD8r_V5gHpZWybDC26uwXJfbvCohdE3HI,3751
@@ -226,6 +234,10 @@ truss/tests/test_data/test_build_commands_failure/__init__.py,sha256=47DEQpj8HBS
226
234
  truss/tests/test_data/test_build_commands_failure/config.yaml,sha256=xosGlR8QZNo-eE1kOj6_4wVrkBD1Y928NGcFeK8Lo2g,259
227
235
  truss/tests/test_data/test_build_commands_failure/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
228
236
  truss/tests/test_data/test_build_commands_failure/model/model.py,sha256=ELYdtI0UT0T45c1yfnSsc4LvQHQn66e50UMj9RYEm1g,502
237
+ truss/tests/test_data/test_build_commands_truss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
238
+ truss/tests/test_data/test_build_commands_truss/config.yaml,sha256=NaAlcxR9QRew2n3C5uMhreJdirjuP0OrhhTXuEWyajQ,330
239
+ truss/tests/test_data/test_build_commands_truss/model/model.py,sha256=BMgwja9pg_GQCPp08kh8DiXKhMF7NAQK2U7aiAYY8X8,199
240
+ truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py,sha256=WelkZEadCUlo8BkCYwQt2XPhTuCLbey7v10sZMEqClo,30
229
241
  truss/tests/test_data/test_concurrency_truss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
230
242
  truss/tests/test_data/test_concurrency_truss/config.yaml,sha256=8a1tsXlHn8IVbu_X3anP93YfkkaFxJ6wwWVI2t-q3UA,111
231
243
  truss/tests/test_data/test_concurrency_truss/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -301,7 +313,7 @@ truss/tests/test_data/test_truss/packages/__init__.py,sha256=47DEQpj8HBSa-_TImW-
301
313
  truss/tests/test_data/test_truss/packages/test_package/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
302
314
  truss/tests/test_data/test_truss/packages/test_package/test.py,sha256=Crrh4K5yghbuRJk8Wjp1X4scOH2Uf8TE9yyrDkqEIUs,6
303
315
  truss/tests/test_data/test_truss_server_model_cache_v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
304
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml,sha256=k8wO0gs60WsLhkF6Y8ObUkhdJoPZENLqCRyg9g8ScUM,379
316
+ truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml,sha256=Cjdm2tQSgtA6K7xSn3gPx0UkorMP8IgKq0ixJm2ekWg,401
305
317
  truss/tests/test_data/test_truss_server_model_cache_v1/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
306
318
  truss/tests/test_data/test_truss_server_model_cache_v1/model/model.py,sha256=PzIUuHohTxl7rqFpKNF6Gx2t9cfYhc_T6_a8FEPygyo,448
307
319
  truss/tests/test_data/test_truss_server_model_cache_v2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -352,7 +364,7 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
352
364
  truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
353
365
  truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
354
366
  truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
355
- truss_chains/private_types.py,sha256=vdcl8FuVsL9JGIu_9K7fd2EW9Ytzoq8nfEx5pmuMKTA,9063
367
+ truss_chains/private_types.py,sha256=fSdlsJE0VST2rU-OyATz9hOQyK1IwUB4O4sIhMOz6CI,9727
356
368
  truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
357
369
  truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
358
370
  truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -361,7 +373,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
361
373
  truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
362
374
  truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
363
375
  truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
364
- truss_chains/deployment/deployment_client.py,sha256=4cHuvaynVCclJ6M9pw8ukhO1E2NRKohIRxftvOfNvOE,34499
376
+ truss_chains/deployment/deployment_client.py,sha256=mg9q2lw_kVCgRnv36pfLq8RLlcQwO4X_TIoyyf7eqRI,34733
365
377
  truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
366
378
  truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
367
379
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -369,13 +381,13 @@ truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e6
369
381
  truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
370
382
  truss_chains/remote_chainlet/utils.py,sha256=Zn3GZRvK8f65WUa-qa-8uPFZ2pD7ukRFxbLOvT-BL0Q,24063
371
383
  truss_train/__init__.py,sha256=A3MzRPMInZfmzLvPpZI7gdKgshAVCw6bwhU-6JYU2zs,939
372
- truss_train/definitions.py,sha256=jcaVICE03iI8lBqEPe01uO3vFiMu_8pqB-j_dX-zwhI,8209
373
- truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
384
+ truss_train/definitions.py,sha256=z5s7VaK9nhj9rpZ6Yfa_FzA-XiEzpE3bxNQr6VnyK9s,8287
385
+ truss_train/deployment.py,sha256=SfDMFBOLZvH0iuTKuOpVIOn7wnATfKnAVubP9DKFeTw,3608
374
386
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
375
- truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
387
+ truss_train/public_api.py,sha256=22li-_qXj74e9NsvxzvCbwHjGW7M6Np5-apLuh8tMKo,1322
376
388
  truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
377
- truss-0.11.18rc500.dist-info/METADATA,sha256=o9mIIX2iDFH8mEVD39WocUjQofwSiLA6t8IM0vVBHFI,6683
378
- truss-0.11.18rc500.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
379
- truss-0.11.18rc500.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
380
- truss-0.11.18rc500.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
381
- truss-0.11.18rc500.dist-info/RECORD,,
389
+ truss-0.11.24rc2.dist-info/METADATA,sha256=L0Pb4nyRwyXqtaxmu7ZEaABdjUpkbEz24wldAy_wstA,6681
390
+ truss-0.11.24rc2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
391
+ truss-0.11.24rc2.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
392
+ truss-0.11.24rc2.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
393
+ truss-0.11.24rc2.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -502,10 +502,13 @@ def _create_baseten_chain(
502
502
  f"Pushing Chain '{baseten_options.chain_name}' to Baseten "
503
503
  f"(publish={baseten_options.publish}, environment={baseten_options.environment})."
504
504
  )
505
- remote_provider = cast(
506
- b10_remote.BasetenRemote,
507
- remote_factory.RemoteFactory.create(remote=baseten_options.remote),
508
- )
505
+ if baseten_options.remote_provider is not None:
506
+ remote_provider = baseten_options.remote_provider
507
+ else:
508
+ remote_provider = cast(
509
+ b10_remote.BasetenRemote,
510
+ remote_factory.RemoteFactory.create(remote=baseten_options.remote),
511
+ )
509
512
 
510
513
  if user_config.settings.include_git_info or baseten_options.include_git_info:
511
514
  truss_user_env = b10_types.TrussUserEnv.collect_with_git_info(
@@ -531,6 +534,8 @@ def _create_baseten_chain(
531
534
  environment=baseten_options.environment,
532
535
  progress_bar=progress_bar,
533
536
  disable_chain_download=baseten_options.disable_chain_download,
537
+ deployment_name=baseten_options.deployment_name,
538
+ team_id=baseten_options.team_id,
534
539
  )
535
540
  return BasetenChainService(
536
541
  baseten_options.chain_name,
@@ -21,6 +21,8 @@ import pydantic
21
21
 
22
22
  from truss.base import custom_types
23
23
  from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
24
+ from truss.remote.baseten.remote import BasetenRemote
25
+ from truss.remote.remote_factory import RemoteFactory
24
26
  from truss_chains import public_types, utils
25
27
 
26
28
  TRUSS_CONFIG_CHAINS_KEY = "chains_metadata"
@@ -266,6 +268,9 @@ class PushOptionsBaseten(PushOptions):
266
268
  include_git_info: bool
267
269
  working_dir: pathlib.Path
268
270
  disable_chain_download: bool = False
271
+ deployment_name: Optional[str] = None
272
+ team_id: Optional[str] = None
273
+ remote_provider: Optional[BasetenRemote] = None
269
274
 
270
275
  @classmethod
271
276
  def create(
@@ -279,11 +284,18 @@ class PushOptionsBaseten(PushOptions):
279
284
  working_dir: pathlib.Path,
280
285
  environment: Optional[str] = None,
281
286
  disable_chain_download: bool = False,
287
+ deployment_name: Optional[str] = None,
288
+ team_id: Optional[str] = None,
289
+ remote_provider: Optional[BasetenRemote] = None,
282
290
  ) -> "PushOptionsBaseten":
283
291
  if promote and not environment:
284
292
  environment = PRODUCTION_ENVIRONMENT_NAME
285
293
  if environment:
286
294
  publish = True
295
+
296
+ if remote_provider is None and remote and not only_generate_trusses:
297
+ remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
298
+
287
299
  return PushOptionsBaseten(
288
300
  remote=remote,
289
301
  chain_name=chain_name,
@@ -293,6 +305,9 @@ class PushOptionsBaseten(PushOptions):
293
305
  include_git_info=include_git_info,
294
306
  working_dir=working_dir,
295
307
  disable_chain_download=disable_chain_download,
308
+ deployment_name=deployment_name,
309
+ team_id=team_id,
310
+ remote_provider=remote_provider,
296
311
  )
297
312
 
298
313
 
@@ -99,6 +99,7 @@ class LoadCheckpointConfig(custom_types.SafeModelNoExtra):
99
99
  class CheckpointingConfig(custom_types.SafeModelNoExtra):
100
100
  enabled: bool = False
101
101
  checkpoint_path: Optional[str] = None
102
+ volume_size_gib: Optional[int] = None
102
103
 
103
104
 
104
105
  class CacheConfig(custom_types.SafeModelNoExtra):
@@ -125,7 +126,7 @@ class Runtime(custom_types.SafeModelNoExtra):
125
126
  raise ValueError(
126
127
  "Cannot set both 'enable_cache' and 'cache_config'. "
127
128
  "'enable_cache' is deprecated. Prefer migrating to 'cache_config' with "
128
- "`enabled=True` and `enable_legacy_hf_cache=True`."
129
+ "`enabled=True` and `enable_legacy_hf_mount=True`."
129
130
  )
130
131
 
131
132
  # Migrate enable_cache to cache_config if enable_cache is True
@@ -181,6 +182,7 @@ class TrainingProject(custom_types.SafeModelNoExtra):
181
182
  # TrainingProject is the wrapper around project config and job config. However, we exclude job
182
183
  # in serialization so just TrainingProject metadata is included in API requests.
183
184
  job: TrainingJob = pydantic.Field(exclude=True)
185
+ team_name: Optional[str] = None
184
186
 
185
187
 
186
188
  class Checkpoint(custom_types.ConfigModel, ABC):
truss_train/deployment.py CHANGED
@@ -9,7 +9,6 @@ from truss.remote.baseten.api import BasetenApi
9
9
  from truss.remote.baseten.core import archive_dir
10
10
  from truss.remote.baseten.remote import BasetenRemote
11
11
  from truss.remote.baseten.utils import transfer
12
- from truss_train import loader
13
12
  from truss_train.definitions import TrainingJob, TrainingProject
14
13
 
15
14
 
@@ -22,6 +21,7 @@ class S3Artifact(SafeModel):
22
21
  # to the end user via the TrainingJob SDK.
23
22
  class PreparedTrainingJob(TrainingJob):
24
23
  runtime_artifacts: List[S3Artifact] = []
24
+ truss_user_env: Optional[b10_types.TrussUserEnv] = None
25
25
 
26
26
  def model_dump(self, *args, **kwargs):
27
27
  data = super().model_dump(*args, **kwargs)
@@ -31,7 +31,12 @@ class PreparedTrainingJob(TrainingJob):
31
31
  return data
32
32
 
33
33
 
34
- def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJob):
34
+ def prepare_push(
35
+ api: BasetenApi,
36
+ config: pathlib.Path,
37
+ training_job: TrainingJob,
38
+ truss_user_env: Optional[b10_types.TrussUserEnv] = None,
39
+ ):
35
40
  # Assume config is at the root of the directory.
36
41
  archive = archive_dir(config.absolute().parent)
37
42
  credentials = api.get_blob_credentials(b10_types.BlobType.TRAIN)
@@ -49,16 +54,27 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo
49
54
  runtime_artifacts=[
50
55
  S3Artifact(s3_key=credentials["s3_key"], s3_bucket=credentials["s3_bucket"])
51
56
  ],
57
+ truss_user_env=truss_user_env,
52
58
  )
53
59
 
54
60
 
55
- def create_training_job(
56
- remote_provider: BasetenRemote, training_project: TrainingProject, config: Path
61
+ def _upsert_project_and_create_job(
62
+ remote_provider: BasetenRemote,
63
+ training_project: TrainingProject,
64
+ config: Path,
65
+ team_id: Optional[str] = None,
57
66
  ) -> dict:
58
- project_resp = remote_provider.api.upsert_training_project(
59
- training_project=training_project
67
+ project_resp = remote_provider.upsert_training_project(
68
+ training_project=training_project, team_id=team_id
69
+ )
70
+
71
+ # Collect TrussUserEnv with git info from the config directory
72
+ working_dir = config.absolute().parent
73
+ truss_user_env = b10_types.TrussUserEnv.collect_with_git_info(working_dir)
74
+
75
+ prepared_job = prepare_push(
76
+ remote_provider.api, config, training_project.job, truss_user_env=truss_user_env
60
77
  )
61
- prepared_job = prepare_push(remote_provider.api, config, training_project.job)
62
78
 
63
79
  job_resp = remote_provider.api.create_training_job(
64
80
  project_id=project_resp["id"], job=prepared_job
@@ -66,22 +82,28 @@ def create_training_job(
66
82
  return job_resp
67
83
 
68
84
 
69
- def create_training_job_from_file(
85
+ def create_training_job(
70
86
  remote_provider: BasetenRemote,
71
87
  config: Path,
88
+ training_project: TrainingProject,
72
89
  job_name_from_cli: Optional[str] = None,
90
+ team_name: Optional[str] = None,
91
+ team_id: Optional[str] = None,
73
92
  ) -> dict:
74
- with loader.import_training_project(config) as training_project:
75
- if job_name_from_cli:
76
- if training_project.job.name:
77
- console.print(
78
- f"[bold yellow]⚠ Warning:[/bold yellow] name '{training_project.job.name}' provided in config file will be ignored. Using job name '{job_name_from_cli}' provided via --job-name flag."
79
- )
80
- training_project.job.name = job_name_from_cli
81
- job_resp = create_training_job(
82
- remote_provider=remote_provider,
83
- training_project=training_project,
84
- config=config,
85
- )
86
- job_resp["job_object"] = training_project.job
93
+ if job_name_from_cli:
94
+ if training_project.job.name:
95
+ console.print(
96
+ f"[bold yellow]⚠ Warning:[/bold yellow] name '{training_project.job.name}' provided in config file will be ignored. Using job name '{job_name_from_cli}' provided via --job-name flag."
97
+ )
98
+ training_project.job.name = job_name_from_cli
99
+ if team_name:
100
+ training_project.team_name = team_name
101
+
102
+ job_resp = _upsert_project_and_create_job(
103
+ remote_provider=remote_provider,
104
+ training_project=training_project,
105
+ config=config,
106
+ team_id=team_id,
107
+ )
108
+ job_resp["job_object"] = training_project.job
87
109
  return job_resp
truss_train/public_api.py CHANGED
@@ -3,7 +3,8 @@ from typing import cast
3
3
 
4
4
  from truss.remote.baseten.remote import BasetenRemote
5
5
  from truss.remote.remote_factory import RemoteFactory
6
- from truss_train.deployment import create_training_job_from_file
6
+ from truss_train import loader
7
+ from truss_train.deployment import create_training_job
7
8
 
8
9
 
9
10
  def push(config: Path, remote: str = "baseten"):
@@ -30,4 +31,5 @@ def push(config: Path, remote: str = "baseten"):
30
31
  remote_provider: BasetenRemote = cast(
31
32
  BasetenRemote, RemoteFactory.create(remote=remote)
32
33
  )
33
- return create_training_job_from_file(remote_provider, config)
34
+ with loader.import_training_project(config) as training_project:
35
+ return create_training_job(remote_provider, config, training_project)