truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
environment_variables: {}
|
|
2
|
+
external_package_dirs: []
|
|
3
|
+
model_metadata: {}
|
|
4
|
+
model_name: Test Build Commands
|
|
5
|
+
python_version: py39
|
|
6
|
+
resources:
|
|
7
|
+
accelerator: null
|
|
8
|
+
cpu: '1'
|
|
9
|
+
memory: 2Gi
|
|
10
|
+
use_gpu: false
|
|
11
|
+
secrets: {}
|
|
12
|
+
system_packages: []
|
|
13
|
+
build_commands:
|
|
14
|
+
- sed -i 's/TEST_FIRST_VALUE/TEST_SECOND_VALUE/g' /packages/constants/constants.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
TEST_KEY = "TEST_FIRST_VALUE"
|
|
@@ -2099,3 +2099,16 @@ async def test_websocket_ping_timeout_behavior(caplog):
|
|
|
2099
2099
|
|
|
2100
2100
|
# We wait 3 seconds, so there should be ~3 PING/PONGS
|
|
2101
2101
|
assert 2 <= caplog.text.count("PING") <= 4
|
|
2102
|
+
|
|
2103
|
+
|
|
2104
|
+
@pytest.mark.integration
|
|
2105
|
+
def test_build_commands_on_model_files(test_data_path):
|
|
2106
|
+
with ensure_kill_all():
|
|
2107
|
+
truss_dir = test_data_path / "test_build_commands_truss"
|
|
2108
|
+
tr = TrussHandle(truss_dir)
|
|
2109
|
+
container, urls = tr.docker_run_for_test()
|
|
2110
|
+
time.sleep(3) # Sleeping to allow the load to finish
|
|
2111
|
+
|
|
2112
|
+
response = requests.post(urls.predict_url, json={})
|
|
2113
|
+
assert response.status_code == 200
|
|
2114
|
+
assert response.json() == "TEST_SECOND_VALUE"
|
|
@@ -1,23 +1,27 @@
|
|
|
1
1
|
truss/__init__.py,sha256=CoUcP6vx_pocyemRmpbCPlndkHhdMkABAlr0ZXVuPCk,1163
|
|
2
|
-
truss/api/__init__.py,sha256=
|
|
2
|
+
truss/api/__init__.py,sha256=6ZDLaMn6MI5kqmRZaHwgZf2jOQyxQS3H45-Zy9qeSvg,5634
|
|
3
3
|
truss/api/definitions.py,sha256=QAaIBqL59Q-R7HtLcXcoeCIWBN2HqOzApdFX0PpCq2s,1604
|
|
4
4
|
truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
truss/base/constants.py,sha256=yFvDVyWMqeAsWOfKUKJVUNbH26N2JZKDqgxmRSe4EWk,3664
|
|
6
6
|
truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
|
|
7
7
|
truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
8
8
|
truss/base/trt_llm_config.py,sha256=rEtBVFg2QnNMxnaz11s5Z69dJB1w7Bpt48Wf6jSsVZI,33087
|
|
9
|
-
truss/base/truss_config.py,sha256=
|
|
9
|
+
truss/base/truss_config.py,sha256=woY42bYYzA3-hWjzHqvQxpJG9hsXPvHZ7LNNAjiLkAE,28675
|
|
10
10
|
truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
11
|
-
truss/cli/chains_commands.py,sha256=
|
|
12
|
-
truss/cli/cli.py,sha256=
|
|
13
|
-
truss/cli/remote_cli.py,sha256=
|
|
14
|
-
truss/cli/train_commands.py,sha256=
|
|
11
|
+
truss/cli/chains_commands.py,sha256=7meWaNVYIPmNkHRhRmW0of_vGa-McFag1XIgnUnSYVQ,18982
|
|
12
|
+
truss/cli/cli.py,sha256=iEO6JOY3yu4OtwzdMhmrxr72AboenLujWMm5WG6vKk4,31198
|
|
13
|
+
truss/cli/remote_cli.py,sha256=UI0YptsLD7WBEdi37fUoJIgCHGjrSxkiNFYS61Wlqdk,2730
|
|
14
|
+
truss/cli/train_commands.py,sha256=rL-wqqEl7CbvRKi3YkhmoFQO_59STVehlsOIC2dvmxY,22102
|
|
15
15
|
truss/cli/logs/base_watcher.py,sha256=vuqteoaMVGX34cgKcETf4X_gOkvnSnDaWz1_pbeFhqs,3343
|
|
16
16
|
truss/cli/logs/model_log_watcher.py,sha256=38vQCcNItfDrTKucvdJ10ZYLOcbGa5ZAKUqUnV4nH34,1971
|
|
17
17
|
truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
|
|
18
18
|
truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
|
|
19
|
+
truss/cli/resolvers/chain_team_resolver.py,sha256=9luCeXu4YpGn9K9snTqqB0GwzSvdUGOjA7CkXG1s-ug,3872
|
|
20
|
+
truss/cli/resolvers/model_team_resolver.py,sha256=D4QML3HK9T07XTCPiZ8JoYcqtYuXTPCVpwRPIni6kwA,4283
|
|
21
|
+
truss/cli/resolvers/training_project_team_resolver.py,sha256=UKwmuWL7OfMTbEtmGcTCmamGG_aNLLi2PzBZziBBJbo,3881
|
|
22
|
+
truss/cli/train/cache.py,sha256=S2f2x0ObdjUAyjLhzp3CPQCE5x2FbtSc5Gm_iTjlbZc,11519
|
|
19
23
|
truss/cli/train/common.py,sha256=xTR41U5FeSndXfNBBHF9wF5XwZH1sOIVFlv-XHjsKIU,1547
|
|
20
|
-
truss/cli/train/core.py,sha256=
|
|
24
|
+
truss/cli/train/core.py,sha256=ahyjfjkCQSH9Eqb_6pcrRM0Td52gsUcnnTkrAsTDpLE,22336
|
|
21
25
|
truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
|
|
22
26
|
truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
|
|
23
27
|
truss/cli/train/metrics_watcher.py,sha256=smz-zrEsBj_-wJHI0pAZ-EAPrvfCWzq1eQjGiFNM-Mk,12755
|
|
@@ -29,7 +33,7 @@ truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=r_IKMlqe
|
|
|
29
33
|
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=f8_UB7CF6Y3MOhaf8Zim0heNiauOOAmA-WqsyP3X9mk,386
|
|
30
34
|
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=YKVMy3xogmbubJNrN_1LCR6xdHj9lBOAlKgMxWHdlQM,1115
|
|
31
35
|
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py,sha256=eEo4ahRTsvRKOxCDj7DAHIUyUZWicn39VmC1PYS0pCY,314
|
|
32
|
-
truss/cli/utils/common.py,sha256=
|
|
36
|
+
truss/cli/utils/common.py,sha256=keB8t-IVq-Wgel9laOZ_Ag4p0KxqLEBY2MKRwWncwUw,7869
|
|
33
37
|
truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
|
|
34
38
|
truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
|
|
35
39
|
truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomUDpalvKo,3925
|
|
@@ -52,12 +56,12 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
52
56
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
53
57
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
54
58
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
55
|
-
truss/remote/baseten/api.py,sha256=
|
|
59
|
+
truss/remote/baseten/api.py,sha256=StxIk5k-88DYEd0YU7yKxuMXpcODSdgbqD7VjwRK--Q,32430
|
|
56
60
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
57
|
-
truss/remote/baseten/core.py,sha256=
|
|
61
|
+
truss/remote/baseten/core.py,sha256=U2JwM-HFTf_KmVRnHeuGJU5g4bvi3auNfmTjkHlw48I,24210
|
|
58
62
|
truss/remote/baseten/custom_types.py,sha256=g7yWkE8p6uIAG5JqgfELFGHzjFLvO7vLPzbe-yl1nYs,4735
|
|
59
63
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
60
|
-
truss/remote/baseten/remote.py,sha256=
|
|
64
|
+
truss/remote/baseten/remote.py,sha256=7xCm_zmfiWfnZLbJi4Txxq7xezN8B235a1tO0rVeTMQ,24372
|
|
61
65
|
truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
|
|
62
66
|
truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
|
|
63
67
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
@@ -72,7 +76,7 @@ truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj
|
|
|
72
76
|
truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
|
|
73
77
|
truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
|
|
74
78
|
truss/templates/no_build.Dockerfile.jinja,sha256=8x2PJUxr_gHai0St8ue2aWyih36t8kBytXMGr_5LG4w,35
|
|
75
|
-
truss/templates/server.Dockerfile.jinja,sha256=
|
|
79
|
+
truss/templates/server.Dockerfile.jinja,sha256=wApuWMfyXrDpPlIodm9bK-KNrsi7fVdNKm-iw7mi7bw,7075
|
|
76
80
|
truss/templates/control/requirements.txt,sha256=tJGr83WoE0CZm2FrloZ9VScK84q-_FTuVXjDYrexhW0,250
|
|
77
81
|
truss/templates/control/control/application.py,sha256=5Kam6M-XtfKGaXQz8cc3d0bwDkB80o2MskABWROx1gk,5321
|
|
78
82
|
truss/templates/control/control/endpoints.py,sha256=KzqsLVNJE6r6TCPW8D5FMCtsfHadTwR15A3z_viGxmM,11782
|
|
@@ -81,7 +85,7 @@ truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5m
|
|
|
81
85
|
truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
|
|
82
86
|
truss/templates/control/control/helpers/errors.py,sha256=LddFuQywuCCdYTEnFT5EalxdWos4uR89rbhMakCy2bA,970
|
|
83
87
|
truss/templates/control/control/helpers/inference_server_controller.py,sha256=anFm7FwkGaUnYRQo2dxXohA9__c-XAVAqfA1EL2bIIY,6324
|
|
84
|
-
truss/templates/control/control/helpers/inference_server_process_controller.py,sha256=
|
|
88
|
+
truss/templates/control/control/helpers/inference_server_process_controller.py,sha256=RkkLA2_bKUl1iqwrwlH7gZPZUuUVFfJDy74EVWkRVOI,4579
|
|
85
89
|
truss/templates/control/control/helpers/inference_server_starter.py,sha256=Fz2Puijro6Cc5cvTsAqOveNJbBQR_ARYJXl4lvETJ8Y,2633
|
|
86
90
|
truss/templates/control/control/helpers/truss_patch/__init__.py,sha256=CXZdUV_ylqLTJrKuFpvSnUT6PUFrZrMF2y6jiHbdaKU,998
|
|
87
91
|
truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py,sha256=LTIIADLz0wRn7V49j64dU1U7Hbta9YLde3pb5YZWvzQ,2001
|
|
@@ -97,7 +101,7 @@ truss/templates/docker_server/supervisord.conf.jinja,sha256=AliMMd6bNn-oCYIB8Gum
|
|
|
97
101
|
truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
98
102
|
truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
|
|
99
103
|
truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
|
|
100
|
-
truss/templates/server/requirements.txt,sha256=
|
|
104
|
+
truss/templates/server/requirements.txt,sha256=1EvdUuD9Fyy_xgo0a5WuAxMZaFRTZtozt79PB5KkGcI,672
|
|
101
105
|
truss/templates/server/truss_server.py,sha256=YKcG7Sr0T_8XjIC3GK9vBwoNb8oxVgwic3-3Ikzpmgw,19781
|
|
102
106
|
truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
|
|
103
107
|
truss/templates/server/common/errors.py,sha256=My0P6-Y7imVTICIhazHT0vlSu3XJDH7As06OyVzu4Do,8589
|
|
@@ -112,7 +116,7 @@ truss/templates/shared/lazy_data_resolver.py,sha256=HxrZz6X30j2LbsExYSqhuOGoYEff
|
|
|
112
116
|
truss/templates/shared/log_config.py,sha256=l9udyu4VKHZePlfK9LQEd5TOUUodPuehypsXRSUL4Ac,5411
|
|
113
117
|
truss/templates/shared/secrets_resolver.py,sha256=3prDe3Q06NTmUEe7KCW-W4TD1CzGck9lpDG789209z4,2110
|
|
114
118
|
truss/templates/shared/serialization.py,sha256=_WC_2PPkRi-MdTwxwjG8LKQptnHi4sANfpOlKWevqWc,3736
|
|
115
|
-
truss/templates/shared/util.py,sha256=
|
|
119
|
+
truss/templates/shared/util.py,sha256=NoIel4ES73fHzjY7tknxAFhZQGh11tqL45Ye6-RTh4k,2281
|
|
116
120
|
truss/templates/train/config.py,sha256=aQJ3lsyVRlq6edjjZq4_Anz1bZVwkjLdclmZPJTdo1k,1626
|
|
117
121
|
truss/templates/train/run.sh,sha256=2rimigJOn6yg4DguRfOJWkzm77X-meNSYXnidLafqNg,346
|
|
118
122
|
truss/templates/trtllm-audio/model/model.py,sha256=o38QqW57b1lf8O_td1lW_AojZZ8R_qAZCgzOWtoIse8,1619
|
|
@@ -127,28 +131,32 @@ truss/templates/trtllm-audio/packages/whisper_trt/utils.py,sha256=pi-c486yPe85Te
|
|
|
127
131
|
truss/templates/trtllm-briton/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
128
132
|
truss/templates/trtllm-briton/src/extension.py,sha256=6qwYPIYQEmXd2xz50-v80Nilc_xLAMgdYkHu2JWboH4,2655
|
|
129
133
|
truss/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
130
|
-
truss/tests/conftest.py,sha256=
|
|
134
|
+
truss/tests/conftest.py,sha256=0izxvPZPdlRZBIvxH0z0ISNcdq3Et3OQ1BLDwevrNpM,29928
|
|
131
135
|
truss/tests/helpers.py,sha256=DKnUGzmt-k4IN_wSnlAXbVYCiEt58xFzFmmuCDSQ0dg,555
|
|
132
136
|
truss/tests/test_build.py,sha256=Wq4sM9tmZFVTCN3YljvOcn04Kkj-L-41Tlw0DKQ8Z7c,709
|
|
133
|
-
truss/tests/test_config.py,sha256=
|
|
137
|
+
truss/tests/test_config.py,sha256=bxIZVELyX31y69lPRdv2On8NiLD1M8IyDtQCv78kN9c,30942
|
|
134
138
|
truss/tests/test_context_builder_image.py,sha256=fVZNJSzZNiWa7Dr1X_VhhMJtyJ5HzsLaPglOr6NV2CA,1105
|
|
135
139
|
truss/tests/test_control_truss_patching.py,sha256=qQOUfyF1MorZ_obOvPJK9utI4HUAzgT6YBS-fo90TEw,14497
|
|
136
140
|
truss/tests/test_custom_server.py,sha256=GP2qMgnqxJMPRtfEciqbhBcG0_JUK7gNL7nrXPGrSLg,1305
|
|
137
141
|
truss/tests/test_docker.py,sha256=3RI6jEC9CVQsKj83s_gOBl3EkdOaov-KEX4IihfMJW4,523
|
|
138
|
-
truss/tests/test_model_inference.py,sha256=
|
|
142
|
+
truss/tests/test_model_inference.py,sha256=JbLmBb3N0oRr58LdVmxaNsMgK8aeSfjDyUfsabZPD9Q,76688
|
|
139
143
|
truss/tests/test_model_schema.py,sha256=Bw28CZ4D0JQOkYdBQJZvgryeW0TRn7Axketp5kvZ_t4,14219
|
|
140
144
|
truss/tests/test_testing_utilities_for_other_tests.py,sha256=YqIKflnd_BUMYaDBSkX76RWiWGWM_UlC2IoT4NngMqE,3048
|
|
141
145
|
truss/tests/test_truss_gatherer.py,sha256=bn288OEkC49YY0mhly4cAl410ktZPfElNdWwZy82WfA,1261
|
|
142
146
|
truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGlMBA,30750
|
|
143
147
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
144
|
-
truss/tests/cli/test_chains_cli.py,sha256=
|
|
145
|
-
truss/tests/cli/test_cli.py,sha256=
|
|
146
|
-
truss/tests/cli/
|
|
148
|
+
truss/tests/cli/test_chains_cli.py,sha256=fmMYwA5XdA7lngyqnx4-3kKWxRB3vl094VjI44gVmzA,4582
|
|
149
|
+
truss/tests/cli/test_cli.py,sha256=kNQjwvUdsGhkDY6LbLT7X1AysfgVAvzRwF36swP-97E,5214
|
|
150
|
+
truss/tests/cli/test_cli_utils_common.py,sha256=X9eU5MngmQPs_R3hqf-t2TlGVyZUdxSwYfYv7nUC9R4,455
|
|
151
|
+
truss/tests/cli/test_model_team_resolver.py,sha256=Dsk0u5G66LaHhCpt-tWpxvKALXTINMzDP-mylN351Ak,10370
|
|
152
|
+
truss/tests/cli/chains/test_chains_team_parameter.py,sha256=psnb7batCG0jTUUhi_1dD4ji-XcqH0EKtEqoTvIsm3w,18816
|
|
153
|
+
truss/tests/cli/train/test_cache_view.py,sha256=912zvkbI6m3ffGVqUCgRcoB1cuaUPxvo3rSbmwQrO4E,30213
|
|
147
154
|
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
|
|
148
|
-
truss/tests/cli/train/test_train_cli_core.py,sha256=
|
|
155
|
+
truss/tests/cli/train/test_train_cli_core.py,sha256=Fabmz8bDPEElgUnCRxORyFVl9v8LzORCPrJgrDnFP10,16177
|
|
149
156
|
truss/tests/cli/train/test_train_init.py,sha256=SRAZvvD5-PWYlpHHek2MftYTA4I3ZHi7gniHl2fYV98,17464
|
|
157
|
+
truss/tests/cli/train/test_train_team_parameter.py,sha256=kIth6iO7zd8MHqzhq2x3jFFkhxpYaY-V0bHUkaSvpHI,18748
|
|
150
158
|
truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
|
|
151
|
-
truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=
|
|
159
|
+
truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=okn0Dx7Pf976YIcRsMLRXqiIks81yAa9v1HgiDE9uv8,22512
|
|
152
160
|
truss/tests/contexts/local_loader/test_load_local.py,sha256=D1qMH2IpYA2j5009v50QMgUnKdeOsX15ndkwXe10a4E,801
|
|
153
161
|
truss/tests/contexts/local_loader/test_truss_module_finder.py,sha256=oN1K2lg3ATHY5yOVUTfQIaSqusTF9I2wFaYaTSo5-O4,5342
|
|
154
162
|
truss/tests/local/test_local_config_handler.py,sha256=aLvcOyfppskA2MziVLy_kMcagjxMpO4mjar9zxUN6g0,2245
|
|
@@ -161,11 +169,11 @@ truss/tests/patch/test_types.py,sha256=OUVDiLckbjjjEN49I4hm62emOTAr8lv_QooJrmXxs
|
|
|
161
169
|
truss/tests/remote/test_remote_factory.py,sha256=S-iZlF5Pf5SDoFUnMlZXy9iRMkosVgwLd22evzWlFr0,4842
|
|
162
170
|
truss/tests/remote/test_truss_remote.py,sha256=Rguyrnbx5RlbPJHFfCtsRtX1czAJ9Fo0aeC5EWRVkGw,2726
|
|
163
171
|
truss/tests/remote/baseten/conftest.py,sha256=vNk0nfDB7XdmqatOMhjdANCWFGYM4VwSHVKlaBO2PPk,442
|
|
164
|
-
truss/tests/remote/baseten/test_api.py,sha256=
|
|
172
|
+
truss/tests/remote/baseten/test_api.py,sha256=bRTq4r1KKvVtAXTwo-HOUvDu0mdhuy65_zqfGlT3UhM,20327
|
|
165
173
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
166
|
-
truss/tests/remote/baseten/test_chain_upload.py,sha256=
|
|
167
|
-
truss/tests/remote/baseten/test_core.py,sha256=
|
|
168
|
-
truss/tests/remote/baseten/test_remote.py,sha256=
|
|
174
|
+
truss/tests/remote/baseten/test_chain_upload.py,sha256=D_Uhgi8YfvYpF2vAcFFutCT-gDPWM4YobHdRs-L87n8,11409
|
|
175
|
+
truss/tests/remote/baseten/test_core.py,sha256=oi9dSmlkJWO88LwDTULquNlQbnEyxHnPYehm8ph01kg,29253
|
|
176
|
+
truss/tests/remote/baseten/test_remote.py,sha256=NHUDEU3UrSKYVqGda0Q2h9DhIfjVE8tscVde0dvSi5M,21476
|
|
169
177
|
truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
|
|
170
178
|
truss/tests/templates/control/control/conftest.py,sha256=euDFh0AhcHP-vAmTzi1Qj3lymnplDTgvtbt4Ez_lfpw,654
|
|
171
179
|
truss/tests/templates/control/control/test_endpoints.py,sha256=HIlRYOicsdHD8r_V5gHpZWybDC26uwXJfbvCohdE3HI,3751
|
|
@@ -226,6 +234,10 @@ truss/tests/test_data/test_build_commands_failure/__init__.py,sha256=47DEQpj8HBS
|
|
|
226
234
|
truss/tests/test_data/test_build_commands_failure/config.yaml,sha256=xosGlR8QZNo-eE1kOj6_4wVrkBD1Y928NGcFeK8Lo2g,259
|
|
227
235
|
truss/tests/test_data/test_build_commands_failure/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
228
236
|
truss/tests/test_data/test_build_commands_failure/model/model.py,sha256=ELYdtI0UT0T45c1yfnSsc4LvQHQn66e50UMj9RYEm1g,502
|
|
237
|
+
truss/tests/test_data/test_build_commands_truss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
238
|
+
truss/tests/test_data/test_build_commands_truss/config.yaml,sha256=NaAlcxR9QRew2n3C5uMhreJdirjuP0OrhhTXuEWyajQ,330
|
|
239
|
+
truss/tests/test_data/test_build_commands_truss/model/model.py,sha256=BMgwja9pg_GQCPp08kh8DiXKhMF7NAQK2U7aiAYY8X8,199
|
|
240
|
+
truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py,sha256=WelkZEadCUlo8BkCYwQt2XPhTuCLbey7v10sZMEqClo,30
|
|
229
241
|
truss/tests/test_data/test_concurrency_truss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
230
242
|
truss/tests/test_data/test_concurrency_truss/config.yaml,sha256=8a1tsXlHn8IVbu_X3anP93YfkkaFxJ6wwWVI2t-q3UA,111
|
|
231
243
|
truss/tests/test_data/test_concurrency_truss/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -301,7 +313,7 @@ truss/tests/test_data/test_truss/packages/__init__.py,sha256=47DEQpj8HBSa-_TImW-
|
|
|
301
313
|
truss/tests/test_data/test_truss/packages/test_package/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
302
314
|
truss/tests/test_data/test_truss/packages/test_package/test.py,sha256=Crrh4K5yghbuRJk8Wjp1X4scOH2Uf8TE9yyrDkqEIUs,6
|
|
303
315
|
truss/tests/test_data/test_truss_server_model_cache_v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
304
|
-
truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml,sha256=
|
|
316
|
+
truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml,sha256=Cjdm2tQSgtA6K7xSn3gPx0UkorMP8IgKq0ixJm2ekWg,401
|
|
305
317
|
truss/tests/test_data/test_truss_server_model_cache_v1/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
306
318
|
truss/tests/test_data/test_truss_server_model_cache_v1/model/model.py,sha256=PzIUuHohTxl7rqFpKNF6Gx2t9cfYhc_T6_a8FEPygyo,448
|
|
307
319
|
truss/tests/test_data/test_truss_server_model_cache_v2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -352,7 +364,7 @@ truss/util/requirements.py,sha256=6T4nVV_NbSl3mAEo-CAk3JFmyJ_RJD768QaR55RdUJQ,69
|
|
|
352
364
|
truss/util/user_config.py,sha256=CvBf5oouNyfdcFXOg3HFhELVW-THiuwyOYdW3aTxdHw,9130
|
|
353
365
|
truss_chains/__init__.py,sha256=QDw1YwdqMaQpz5Oltu2Eq2vzEX9fDrMoqnhtbeh60i4,1278
|
|
354
366
|
truss_chains/framework.py,sha256=CS7tSegPe2Q8UUT6CDkrtSrB3utr_1QN1jTEPjrj5Ug,67519
|
|
355
|
-
truss_chains/private_types.py,sha256=
|
|
367
|
+
truss_chains/private_types.py,sha256=fSdlsJE0VST2rU-OyATz9hOQyK1IwUB4O4sIhMOz6CI,9727
|
|
356
368
|
truss_chains/public_api.py,sha256=civY8juJU92jSGBI7zM1qMnA7hlUdCq7L8o4IOo5meA,9722
|
|
357
369
|
truss_chains/public_types.py,sha256=RPr8jgKO_F_26F7H3CpwbidL-6euoKPdFHVpEIpYqrQ,29415
|
|
358
370
|
truss_chains/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -361,7 +373,7 @@ truss_chains/streaming.py,sha256=DGl2LEAN67YwP7Nn9MK488KmYc4KopWmcHuE6WjyO1Q,125
|
|
|
361
373
|
truss_chains/utils.py,sha256=LvpCG2lnN6dqPqyX3PwLH9tyjUzqQN3N4WeEFROMHak,6291
|
|
362
374
|
truss_chains/deployment/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
363
375
|
truss_chains/deployment/code_gen.py,sha256=397FiSNZuW59J3Ma7N9GKGfvG_87BNFAXCIV8BW41t0,32669
|
|
364
|
-
truss_chains/deployment/deployment_client.py,sha256=
|
|
376
|
+
truss_chains/deployment/deployment_client.py,sha256=mg9q2lw_kVCgRnv36pfLq8RLlcQwO4X_TIoyyf7eqRI,34733
|
|
365
377
|
truss_chains/reference_code/reference_chainlet.py,sha256=5feSeqGtrHDbldkfZCfX2R5YbbW0Uhc35mhaP2pXrHw,1340
|
|
366
378
|
truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1hz3g3lQ00tiLo55cSM,322
|
|
367
379
|
truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -369,13 +381,13 @@ truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e6
|
|
|
369
381
|
truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
|
|
370
382
|
truss_chains/remote_chainlet/utils.py,sha256=Zn3GZRvK8f65WUa-qa-8uPFZ2pD7ukRFxbLOvT-BL0Q,24063
|
|
371
383
|
truss_train/__init__.py,sha256=A3MzRPMInZfmzLvPpZI7gdKgshAVCw6bwhU-6JYU2zs,939
|
|
372
|
-
truss_train/definitions.py,sha256=
|
|
373
|
-
truss_train/deployment.py,sha256=
|
|
384
|
+
truss_train/definitions.py,sha256=z5s7VaK9nhj9rpZ6Yfa_FzA-XiEzpE3bxNQr6VnyK9s,8287
|
|
385
|
+
truss_train/deployment.py,sha256=SfDMFBOLZvH0iuTKuOpVIOn7wnATfKnAVubP9DKFeTw,3608
|
|
374
386
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
375
|
-
truss_train/public_api.py,sha256=
|
|
387
|
+
truss_train/public_api.py,sha256=22li-_qXj74e9NsvxzvCbwHjGW7M6Np5-apLuh8tMKo,1322
|
|
376
388
|
truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
|
|
377
|
-
truss-0.11.
|
|
378
|
-
truss-0.11.
|
|
379
|
-
truss-0.11.
|
|
380
|
-
truss-0.11.
|
|
381
|
-
truss-0.11.
|
|
389
|
+
truss-0.11.24rc2.dist-info/METADATA,sha256=L0Pb4nyRwyXqtaxmu7ZEaABdjUpkbEz24wldAy_wstA,6681
|
|
390
|
+
truss-0.11.24rc2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
391
|
+
truss-0.11.24rc2.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
392
|
+
truss-0.11.24rc2.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
393
|
+
truss-0.11.24rc2.dist-info/RECORD,,
|
|
@@ -502,10 +502,13 @@ def _create_baseten_chain(
|
|
|
502
502
|
f"Pushing Chain '{baseten_options.chain_name}' to Baseten "
|
|
503
503
|
f"(publish={baseten_options.publish}, environment={baseten_options.environment})."
|
|
504
504
|
)
|
|
505
|
-
remote_provider
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
505
|
+
if baseten_options.remote_provider is not None:
|
|
506
|
+
remote_provider = baseten_options.remote_provider
|
|
507
|
+
else:
|
|
508
|
+
remote_provider = cast(
|
|
509
|
+
b10_remote.BasetenRemote,
|
|
510
|
+
remote_factory.RemoteFactory.create(remote=baseten_options.remote),
|
|
511
|
+
)
|
|
509
512
|
|
|
510
513
|
if user_config.settings.include_git_info or baseten_options.include_git_info:
|
|
511
514
|
truss_user_env = b10_types.TrussUserEnv.collect_with_git_info(
|
|
@@ -531,6 +534,8 @@ def _create_baseten_chain(
|
|
|
531
534
|
environment=baseten_options.environment,
|
|
532
535
|
progress_bar=progress_bar,
|
|
533
536
|
disable_chain_download=baseten_options.disable_chain_download,
|
|
537
|
+
deployment_name=baseten_options.deployment_name,
|
|
538
|
+
team_id=baseten_options.team_id,
|
|
534
539
|
)
|
|
535
540
|
return BasetenChainService(
|
|
536
541
|
baseten_options.chain_name,
|
truss_chains/private_types.py
CHANGED
|
@@ -21,6 +21,8 @@ import pydantic
|
|
|
21
21
|
|
|
22
22
|
from truss.base import custom_types
|
|
23
23
|
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
|
|
24
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
25
|
+
from truss.remote.remote_factory import RemoteFactory
|
|
24
26
|
from truss_chains import public_types, utils
|
|
25
27
|
|
|
26
28
|
TRUSS_CONFIG_CHAINS_KEY = "chains_metadata"
|
|
@@ -266,6 +268,9 @@ class PushOptionsBaseten(PushOptions):
|
|
|
266
268
|
include_git_info: bool
|
|
267
269
|
working_dir: pathlib.Path
|
|
268
270
|
disable_chain_download: bool = False
|
|
271
|
+
deployment_name: Optional[str] = None
|
|
272
|
+
team_id: Optional[str] = None
|
|
273
|
+
remote_provider: Optional[BasetenRemote] = None
|
|
269
274
|
|
|
270
275
|
@classmethod
|
|
271
276
|
def create(
|
|
@@ -279,11 +284,18 @@ class PushOptionsBaseten(PushOptions):
|
|
|
279
284
|
working_dir: pathlib.Path,
|
|
280
285
|
environment: Optional[str] = None,
|
|
281
286
|
disable_chain_download: bool = False,
|
|
287
|
+
deployment_name: Optional[str] = None,
|
|
288
|
+
team_id: Optional[str] = None,
|
|
289
|
+
remote_provider: Optional[BasetenRemote] = None,
|
|
282
290
|
) -> "PushOptionsBaseten":
|
|
283
291
|
if promote and not environment:
|
|
284
292
|
environment = PRODUCTION_ENVIRONMENT_NAME
|
|
285
293
|
if environment:
|
|
286
294
|
publish = True
|
|
295
|
+
|
|
296
|
+
if remote_provider is None and remote and not only_generate_trusses:
|
|
297
|
+
remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
|
|
298
|
+
|
|
287
299
|
return PushOptionsBaseten(
|
|
288
300
|
remote=remote,
|
|
289
301
|
chain_name=chain_name,
|
|
@@ -293,6 +305,9 @@ class PushOptionsBaseten(PushOptions):
|
|
|
293
305
|
include_git_info=include_git_info,
|
|
294
306
|
working_dir=working_dir,
|
|
295
307
|
disable_chain_download=disable_chain_download,
|
|
308
|
+
deployment_name=deployment_name,
|
|
309
|
+
team_id=team_id,
|
|
310
|
+
remote_provider=remote_provider,
|
|
296
311
|
)
|
|
297
312
|
|
|
298
313
|
|
truss_train/definitions.py
CHANGED
|
@@ -99,6 +99,7 @@ class LoadCheckpointConfig(custom_types.SafeModelNoExtra):
|
|
|
99
99
|
class CheckpointingConfig(custom_types.SafeModelNoExtra):
|
|
100
100
|
enabled: bool = False
|
|
101
101
|
checkpoint_path: Optional[str] = None
|
|
102
|
+
volume_size_gib: Optional[int] = None
|
|
102
103
|
|
|
103
104
|
|
|
104
105
|
class CacheConfig(custom_types.SafeModelNoExtra):
|
|
@@ -125,7 +126,7 @@ class Runtime(custom_types.SafeModelNoExtra):
|
|
|
125
126
|
raise ValueError(
|
|
126
127
|
"Cannot set both 'enable_cache' and 'cache_config'. "
|
|
127
128
|
"'enable_cache' is deprecated. Prefer migrating to 'cache_config' with "
|
|
128
|
-
"`enabled=True` and `
|
|
129
|
+
"`enabled=True` and `enable_legacy_hf_mount=True`."
|
|
129
130
|
)
|
|
130
131
|
|
|
131
132
|
# Migrate enable_cache to cache_config if enable_cache is True
|
|
@@ -181,6 +182,7 @@ class TrainingProject(custom_types.SafeModelNoExtra):
|
|
|
181
182
|
# TrainingProject is the wrapper around project config and job config. However, we exclude job
|
|
182
183
|
# in serialization so just TrainingProject metadata is included in API requests.
|
|
183
184
|
job: TrainingJob = pydantic.Field(exclude=True)
|
|
185
|
+
team_name: Optional[str] = None
|
|
184
186
|
|
|
185
187
|
|
|
186
188
|
class Checkpoint(custom_types.ConfigModel, ABC):
|
truss_train/deployment.py
CHANGED
|
@@ -9,7 +9,6 @@ from truss.remote.baseten.api import BasetenApi
|
|
|
9
9
|
from truss.remote.baseten.core import archive_dir
|
|
10
10
|
from truss.remote.baseten.remote import BasetenRemote
|
|
11
11
|
from truss.remote.baseten.utils import transfer
|
|
12
|
-
from truss_train import loader
|
|
13
12
|
from truss_train.definitions import TrainingJob, TrainingProject
|
|
14
13
|
|
|
15
14
|
|
|
@@ -22,6 +21,7 @@ class S3Artifact(SafeModel):
|
|
|
22
21
|
# to the end user via the TrainingJob SDK.
|
|
23
22
|
class PreparedTrainingJob(TrainingJob):
|
|
24
23
|
runtime_artifacts: List[S3Artifact] = []
|
|
24
|
+
truss_user_env: Optional[b10_types.TrussUserEnv] = None
|
|
25
25
|
|
|
26
26
|
def model_dump(self, *args, **kwargs):
|
|
27
27
|
data = super().model_dump(*args, **kwargs)
|
|
@@ -31,7 +31,12 @@ class PreparedTrainingJob(TrainingJob):
|
|
|
31
31
|
return data
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def prepare_push(
|
|
34
|
+
def prepare_push(
|
|
35
|
+
api: BasetenApi,
|
|
36
|
+
config: pathlib.Path,
|
|
37
|
+
training_job: TrainingJob,
|
|
38
|
+
truss_user_env: Optional[b10_types.TrussUserEnv] = None,
|
|
39
|
+
):
|
|
35
40
|
# Assume config is at the root of the directory.
|
|
36
41
|
archive = archive_dir(config.absolute().parent)
|
|
37
42
|
credentials = api.get_blob_credentials(b10_types.BlobType.TRAIN)
|
|
@@ -49,16 +54,27 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo
|
|
|
49
54
|
runtime_artifacts=[
|
|
50
55
|
S3Artifact(s3_key=credentials["s3_key"], s3_bucket=credentials["s3_bucket"])
|
|
51
56
|
],
|
|
57
|
+
truss_user_env=truss_user_env,
|
|
52
58
|
)
|
|
53
59
|
|
|
54
60
|
|
|
55
|
-
def
|
|
56
|
-
remote_provider: BasetenRemote,
|
|
61
|
+
def _upsert_project_and_create_job(
|
|
62
|
+
remote_provider: BasetenRemote,
|
|
63
|
+
training_project: TrainingProject,
|
|
64
|
+
config: Path,
|
|
65
|
+
team_id: Optional[str] = None,
|
|
57
66
|
) -> dict:
|
|
58
|
-
project_resp = remote_provider.
|
|
59
|
-
training_project=training_project
|
|
67
|
+
project_resp = remote_provider.upsert_training_project(
|
|
68
|
+
training_project=training_project, team_id=team_id
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Collect TrussUserEnv with git info from the config directory
|
|
72
|
+
working_dir = config.absolute().parent
|
|
73
|
+
truss_user_env = b10_types.TrussUserEnv.collect_with_git_info(working_dir)
|
|
74
|
+
|
|
75
|
+
prepared_job = prepare_push(
|
|
76
|
+
remote_provider.api, config, training_project.job, truss_user_env=truss_user_env
|
|
60
77
|
)
|
|
61
|
-
prepared_job = prepare_push(remote_provider.api, config, training_project.job)
|
|
62
78
|
|
|
63
79
|
job_resp = remote_provider.api.create_training_job(
|
|
64
80
|
project_id=project_resp["id"], job=prepared_job
|
|
@@ -66,22 +82,28 @@ def create_training_job(
|
|
|
66
82
|
return job_resp
|
|
67
83
|
|
|
68
84
|
|
|
69
|
-
def
|
|
85
|
+
def create_training_job(
|
|
70
86
|
remote_provider: BasetenRemote,
|
|
71
87
|
config: Path,
|
|
88
|
+
training_project: TrainingProject,
|
|
72
89
|
job_name_from_cli: Optional[str] = None,
|
|
90
|
+
team_name: Optional[str] = None,
|
|
91
|
+
team_id: Optional[str] = None,
|
|
73
92
|
) -> dict:
|
|
74
|
-
|
|
75
|
-
if
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
93
|
+
if job_name_from_cli:
|
|
94
|
+
if training_project.job.name:
|
|
95
|
+
console.print(
|
|
96
|
+
f"[bold yellow]⚠ Warning:[/bold yellow] name '{training_project.job.name}' provided in config file will be ignored. Using job name '{job_name_from_cli}' provided via --job-name flag."
|
|
97
|
+
)
|
|
98
|
+
training_project.job.name = job_name_from_cli
|
|
99
|
+
if team_name:
|
|
100
|
+
training_project.team_name = team_name
|
|
101
|
+
|
|
102
|
+
job_resp = _upsert_project_and_create_job(
|
|
103
|
+
remote_provider=remote_provider,
|
|
104
|
+
training_project=training_project,
|
|
105
|
+
config=config,
|
|
106
|
+
team_id=team_id,
|
|
107
|
+
)
|
|
108
|
+
job_resp["job_object"] = training_project.job
|
|
87
109
|
return job_resp
|
truss_train/public_api.py
CHANGED
|
@@ -3,7 +3,8 @@ from typing import cast
|
|
|
3
3
|
|
|
4
4
|
from truss.remote.baseten.remote import BasetenRemote
|
|
5
5
|
from truss.remote.remote_factory import RemoteFactory
|
|
6
|
-
from truss_train
|
|
6
|
+
from truss_train import loader
|
|
7
|
+
from truss_train.deployment import create_training_job
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
def push(config: Path, remote: str = "baseten"):
|
|
@@ -30,4 +31,5 @@ def push(config: Path, remote: str = "baseten"):
|
|
|
30
31
|
remote_provider: BasetenRemote = cast(
|
|
31
32
|
BasetenRemote, RemoteFactory.create(remote=remote)
|
|
32
33
|
)
|
|
33
|
-
|
|
34
|
+
with loader.import_training_project(config) as training_project:
|
|
35
|
+
return create_training_job(remote_provider, config, training_project)
|
|
File without changes
|
|
File without changes
|