truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/remote/baseten/api.py
CHANGED
```diff
@@ -200,6 +200,8 @@ class BasetenApi:
         deployment_name: Optional[str] = None,
         origin: Optional[b10_types.ModelOrigin] = None,
         environment: Optional[str] = None,
+        deploy_timeout_minutes: Optional[int] = None,
+        team_id: Optional[str] = None,
     ):
         query_string = f"""
         mutation ($trussUserEnv: String) {{
@@ -213,6 +215,8 @@ class BasetenApi:
                 {f'version_name: "{deployment_name}"' if deployment_name else ""}
                 {f"model_origin: {origin.value}" if origin else ""}
                 {f'environment_name: "{environment}"' if environment else ""}
+                {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
+                {f'team_id: "{team_id}"' if team_id else ""}
             ) {{
                 model_version {{
                     id
@@ -244,6 +248,7 @@ class BasetenApi:
         deployment_name: Optional[str] = None,
         environment: Optional[str] = None,
         preserve_env_instance_type: bool = True,
+        deploy_timeout_minutes: Optional[int] = None,
     ):
         query_string = f"""
         mutation ($trussUserEnv: String) {{
@@ -257,6 +262,7 @@ class BasetenApi:
                 preserve_env_instance_type: {"true" if preserve_env_instance_type else "false"}
                 {f'name: "{deployment_name}"' if deployment_name else ""}
                 {f'environment_name: "{environment}"' if environment else ""}
+                {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
             ) {{
                 model_version {{
                     id
@@ -286,6 +292,8 @@ class BasetenApi:
         truss_user_env: b10_types.TrussUserEnv,
         allow_truss_download=True,
         origin: Optional[b10_types.ModelOrigin] = None,
+        deploy_timeout_minutes: Optional[int] = None,
+        team_id: Optional[str] = None,
     ):
         query_string = f"""
         mutation ($trussUserEnv: String) {{
@@ -295,6 +303,8 @@ class BasetenApi:
                 truss_user_env: $trussUserEnv
                 allow_truss_download: {"true" if allow_truss_download else "false"}
                 {f"model_origin: {origin.value}" if origin else ""}
+                {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
+                {f'team_id: "{team_id}"' if team_id else ""}
             ) {{
                 model_version {{
                     id
@@ -327,6 +337,8 @@ class BasetenApi:
         is_draft: bool = False,
         original_source_artifact_s3_key: Optional[str] = None,
         allow_truss_download: Optional[bool] = True,
+        deployment_name: Optional[str] = None,
+        team_id: Optional[str] = None,
     ):
         if allow_truss_download is None:
             allow_truss_download = True
@@ -350,10 +362,14 @@ class BasetenApi:
             params.append(
                 f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
             )
+        if team_id:
+            params.append(f'team_id: "{team_id}"')

         params.append(f"is_draft: {str(is_draft).lower()}")
         if allow_truss_download is False:
             params.append("allow_truss_download: false")
+        if deployment_name:
+            params.append(f'deployment_name: "{deployment_name}"')

         params_str = PARAMS_INDENT.join(params)

@@ -382,18 +398,24 @@ class BasetenApi:

         return resp["data"]["deploy_chain_atomic"]

-    def get_chains(self):
+    def get_chains(self, team_id: Optional[str] = None):
         query_string = """
         {
             chains {
                 id
                 name
+                team {
+                    name
+                }
             }
         }
         """

         resp = self._post_graphql_query(query_string)
-        return resp["data"]["chains"]
+        chains = resp["data"]["chains"]
+
+        # TODO(COR-492): Filter by team_id in the backend
+        return chains

     def get_chain_deployments(self, chain_id: str):
         query_string = f"""
@@ -456,6 +478,10 @@ class BasetenApi:
             models {
                 id,
                 name
+                team {
+                    id
+                    name
+                }
                 versions{
                     id,
                     semver,
@@ -495,6 +521,10 @@ class BasetenApi:
                 id
                 name
                 hostname
+                team {{
+                    id
+                    name
+                }}
                 versions {{
                     id
                     semver
@@ -647,10 +677,14 @@ class BasetenApi:
             "v1/api_keys", body={"type": api_key_type.value, "name": name}
         )

-    def upsert_training_project(self, training_project):
+    def upsert_training_project(self, training_project, team_id: Optional[str] = None):
+        if team_id:
+            endpoint = f"v1/teams/{team_id}/training_projects"
+        else:
+            endpoint = "v1/training_projects"
         resp_json = self._rest_api_client.post(
-            "v1/training_projects",
-            body={"training_project": training_project.model_dump()},
+            endpoint,
+            body={"training_project": training_project.model_dump(exclude_none=True)},
         )
         return resp_json["training_project"]

@@ -903,3 +937,22 @@ class BasetenApi:
         return [
             InstanceTypeV1(**instance_type) for instance_type in instance_types_data
         ]
+
+    def get_teams(self) -> Dict[str, Dict[str, str]]:
+        """
+        Get all available teams via GraphQL API.
+        Returns a dictionary mapping team name to team data (with 'id' and 'name' keys).
+        """
+        query_string = """
+        query Teams {
+            teams {
+                id
+                name
+            }
+        }
+        """
+
+        resp = self._post_graphql_query(query_string)
+        teams_data = resp["data"]["teams"]
+        # Convert list to dict mapping team_name -> team
+        return {team["name"]: team for team in teams_data}
```
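The new `get_teams()` helper returns a name-keyed mapping, which is presumably what the new CLI-side resolvers under `truss/cli/resolvers/` (see the file list above) build on to turn a team name into the `team_id` accepted by the deploy mutations. A minimal sketch of that resolution, assuming an already-authenticated `BasetenApi` instance; the helper name and the team name "research" are placeholders, not part of this diff:

```python
# Sketch only: `api` is assumed to be an authenticated BasetenApi instance;
# "research" is a placeholder team name, and resolve_team_id is hypothetical.
from typing import Optional

from truss.remote.baseten.api import BasetenApi


def resolve_team_id(api: BasetenApi, team_name: Optional[str]) -> Optional[str]:
    if team_name is None:
        return None
    # get_teams() maps team name -> {"id": ..., "name": ...} per the diff above.
    teams = api.get_teams()
    if team_name not in teams:
        raise ValueError(f"Unknown team {team_name!r}; available: {sorted(teams)}")
    return teams[team_name]["id"]


# The resolved id can then be threaded into the deploy path, e.g.
# create_truss_service(..., team_id=resolve_team_id(api, "research")) in core.py below.
```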
truss/remote/baseten/core.py
CHANGED
```diff
@@ -92,19 +92,21 @@ class ModelVersionHandle(NamedTuple):
     instance_type_name: Optional[str] = None


-def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]:
+def get_chain_id_by_name(
+    api: BasetenApi, chain_name: str, team_id: Optional[str] = None
+) -> Optional[str]:
     """
     Check if a chain with the given name exists in the Baseten remote.

     Args:
         api: BasetenApi instance
         chain_name: Name of the chain to check for existence
+        team_id: Optional team_id to filter chains by team

     Returns:
         chain_id if present, otherwise None
     """
-    chains = api.get_chains()
-
+    chains = api.get_chains(team_id=team_id)
     chain_name_id_mapping = {chain["name"]: chain["id"] for chain in chains}
     return chain_name_id_mapping.get(chain_name)

@@ -132,6 +134,8 @@ def create_chain_atomic(
     environment: Optional[str],
     original_source_artifact_s3_key: Optional[str] = None,
     allow_truss_download: bool = True,
+    deployment_name: Optional[str] = None,
+    team_id: Optional[str] = None,
 ) -> ChainDeploymentHandleAtomic:
     if environment and is_draft:
         logging.info(
@@ -140,7 +144,7 @@ def create_chain_atomic(
         )
         is_draft = False

-    chain_id = get_chain_id_by_name(api, chain_name)
+    chain_id = get_chain_id_by_name(api, chain_name, team_id=team_id)

     # TODO(Tyron): Refactor for better readability:
     # 1. Prepare all arguments for `deploy_chain_atomic`.
@@ -156,6 +160,8 @@ def create_chain_atomic(
             truss_user_env=truss_user_env,
             original_source_artifact_s3_key=original_source_artifact_s3_key,
             allow_truss_download=allow_truss_download,
+            deployment_name=deployment_name,
+            team_id=team_id,
         )
     elif chain_id:
         # This is the only case where promote has relevance, since
@@ -171,6 +177,8 @@ def create_chain_atomic(
                 truss_user_env=truss_user_env,
                 original_source_artifact_s3_key=original_source_artifact_s3_key,
                 allow_truss_download=allow_truss_download,
+                deployment_name=deployment_name,
+                team_id=team_id,
             )
         except ApiError as e:
             if (
@@ -193,6 +201,8 @@ def create_chain_atomic(
                 truss_user_env=truss_user_env,
                 original_source_artifact_s3_key=original_source_artifact_s3_key,
                 allow_truss_download=allow_truss_download,
+                deployment_name=deployment_name,
+                team_id=team_id,
             )

     return ChainDeploymentHandleAtomic(
@@ -397,6 +407,8 @@ def create_truss_service(
     origin: Optional[b10_types.ModelOrigin] = None,
     environment: Optional[str] = None,
     preserve_env_instance_type: bool = True,
+    deploy_timeout_minutes: Optional[int] = None,
+    team_id: Optional[str] = None,
 ) -> ModelVersionHandle:
     """
     Create a model in the Baseten remote.
@@ -412,6 +424,7 @@ def create_truss_service(
             to zero.
         deployment_name: Name to apply to the created deployment. Not applied to
             development model.
+        team_id: ID of the team to create the model in.

     Returns:
         A Model Version handle.
@@ -424,6 +437,8 @@ def create_truss_service(
             truss_user_env,
             allow_truss_download=allow_truss_download,
             origin=origin,
+            deploy_timeout_minutes=deploy_timeout_minutes,
+            team_id=team_id,
         )

         return ModelVersionHandle(
@@ -448,6 +463,8 @@ def create_truss_service(
             deployment_name=deployment_name,
             origin=origin,
             environment=environment,
+            deploy_timeout_minutes=deploy_timeout_minutes,
+            team_id=team_id,
         )

         return ModelVersionHandle(
@@ -472,6 +489,7 @@ def create_truss_service(
             deployment_name=deployment_name,
             environment=environment,
             preserve_env_instance_type=preserve_env_instance_type,
+            deploy_timeout_minutes=deploy_timeout_minutes,
         )
     except ApiError as e:
         if (
```
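Because `get_chains(team_id=...)` does not yet filter server-side (see the TODO in api.py above), `get_chain_id_by_name` still resolves purely by name over the returned chains; the team id is simply threaded through. A minimal usage sketch, with placeholder names:

```python
# Sketch only: `api` is assumed to be an authenticated BasetenApi instance;
# "my-chain" and the team id are placeholders, not values from this diff.
from truss.remote.baseten.core import get_chain_id_by_name

chain_id = get_chain_id_by_name(api, "my-chain", team_id="team_abc123")
if chain_id is None:
    # create_chain_atomic() above would create the chain in this case,
    # forwarding the same team_id (and optional deployment_name) to
    # deploy_chain_atomic.
    pass
```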
truss/remote/baseten/remote.py
CHANGED
```diff
@@ -69,6 +69,7 @@ class FinalPushData(custom_types.OracleData):
     origin: Optional[custom_types.ModelOrigin] = None
     environment: Optional[str] = None
     allow_truss_download: bool
+    team_id: Optional[str] = None


 class BasetenRemote(TrussRemote):
@@ -127,6 +128,8 @@ class BasetenRemote(TrussRemote):
         origin: Optional[custom_types.ModelOrigin] = None,
         environment: Optional[str] = None,
         progress_bar: Optional[Type["progress.Progress"]] = None,
+        deploy_timeout_minutes: Optional[int] = None,
+        team_id: Optional[str] = None,
     ) -> FinalPushData:
         if model_name.isspace():
             raise ValueError("Model name cannot be empty")
@@ -164,6 +167,13 @@
                 "Deployment name must only contain alphanumeric, -, _ and . characters"
             )

+        if deploy_timeout_minutes is not None and (
+            deploy_timeout_minutes < 10 or deploy_timeout_minutes > 1440
+        ):
+            raise ValueError(
+                "deploy-timeout-minutes must be between 10 minutes and 1440 minutes (24 hours)"
+            )
+
         model_id = exists_model(self._api, model_name)

         if model_id is not None and disable_truss_download:
@@ -188,6 +198,7 @@
             origin=origin,
             environment=environment,
             allow_truss_download=not disable_truss_download,
+            team_id=team_id,
         )

     def push(  # type: ignore
@@ -205,6 +216,8 @@
         progress_bar: Optional[Type["progress.Progress"]] = None,
         include_git_info: bool = False,
         preserve_env_instance_type: bool = True,
+        deploy_timeout_minutes: Optional[int] = None,
+        team_id: Optional[str] = None,
     ) -> BasetenService:
         push_data = self._prepare_push(
             truss_handle=truss_handle,
@@ -217,6 +230,8 @@
             origin=origin,
             environment=environment,
             progress_bar=progress_bar,
+            deploy_timeout_minutes=deploy_timeout_minutes,
+            team_id=team_id,
         )

         if include_git_info:
@@ -242,6 +257,8 @@
             environment=push_data.environment,
             truss_user_env=truss_user_env,
             preserve_env_instance_type=preserve_env_instance_type,
+            deploy_timeout_minutes=deploy_timeout_minutes,
+            team_id=push_data.team_id,
         )

         if model_version_handle.instance_type_name:
@@ -269,6 +286,8 @@
         environment: Optional[str] = None,
         progress_bar: Optional[Type["progress.Progress"]] = None,
         disable_chain_download: bool = False,
+        deployment_name: Optional[str] = None,
+        team_id: Optional[str] = None,
     ) -> ChainDeploymentHandleAtomic:
         # If we are promoting a model to an environment after deploy, it must be published.
         # Draft models cannot be promoted.
@@ -289,6 +308,7 @@
             origin=custom_types.ModelOrigin.CHAINS,
             progress_bar=progress_bar,
             disable_truss_download=disable_chain_download,
+            deployment_name=deployment_name,
         )
         oracle_data = custom_types.OracleData(
             model_name=push_data.model_name,
@@ -326,6 +346,8 @@
             environment=environment,
             original_source_artifact_s3_key=raw_chain_s3_key,
             allow_truss_download=not disable_chain_download,
+            deployment_name=deployment_name,
+            team_id=team_id,
         )
         logging.info("Successfully pushed to baseten. Chain is building and deploying.")
         return chain_deployment_handle
@@ -589,5 +611,5 @@
     ) -> PatchResult:
         return self._patch(watch_path, truss_ignore_patterns, console=None)

-    def upsert_training_project(self, training_project):
-        return self._api.upsert_training_project(training_project)
+    def upsert_training_project(self, training_project, team_id=None):
+        return self._api.upsert_training_project(training_project, team_id=team_id)
```
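`_prepare_push` now rejects out-of-range deploy timeouts before anything is uploaded, so the 10–1440 minute constraint is enforced client-side. A minimal sketch of a push that passes the new arguments; the remote, truss handle, model name, and team id are placeholders, and the exact call shape is illustrative — only `deploy_timeout_minutes` and `team_id` come from this diff:

```python
# Sketch only: `remote` is assumed to be a configured BasetenRemote and
# `truss_handle` a loaded TrussHandle; names/ids are placeholders.
service = remote.push(
    truss_handle,
    model_name="my-model",
    deploy_timeout_minutes=60,  # accepted: inside the validated [10, 1440] range
    team_id="team_abc123",      # stored on FinalPushData and forwarded to create_truss_service
)

# Values outside the range fail fast in _prepare_push:
#   remote.push(truss_handle, model_name="my-model", deploy_timeout_minutes=5)
#   -> ValueError: deploy-timeout-minutes must be between 10 minutes and 1440 minutes (24 hours)
```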
truss/templates/control/control/helpers/inference_server_process_controller.py
CHANGED

```diff
@@ -49,7 +49,9 @@ class InferenceServerProcessController:

     def _terminate_children_and_process(self):
         """Kill child processes first, then parent. Prevents port binding conflicts."""
-        kill_child_processes(self._inference_server_process.pid)
+        # Use a shorter timeout than the truss patch read timeout (=120s):
+        # see remote/baseten/api.py:_post_graphql_query()
+        kill_child_processes(self._inference_server_process.pid, timeout_seconds=30)
         self._inference_server_process.terminate()

     def stop(self):
```
truss/templates/server.Dockerfile.jinja
CHANGED

```diff
@@ -56,12 +56,6 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }}
 {% endfor %} {#- endfor external_data_files #}
 {%- endif %} {#- endif external_data_files #}

-{%- if build_commands %}
-{% for command in build_commands %}
-RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
-{% endfor %} {#- endfor build_commands #}
-{%- endif %} {#- endif build_commands #}
-
 {# Copy data before code for better caching #}
 {%- if data_dir_exists %}
 COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
@@ -109,7 +103,13 @@ USER {{ app_username }}
 {%- endif %} {#- endif non_root_user #}
 {%- endmacro -%}

-{%- if config.docker_server %}
+{%- if build_commands %}
+{% for command in build_commands %}
+RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
+{% endfor %} {#- endfor build_commands #}
+{%- endif %} {#- endif build_commands #}
+
+{%- if config.docker_server %}
 RUN apt-get update -y && apt-get install -y --no-install-recommends \
     curl nginx && rm -rf /var/lib/apt/lists/*
 COPY --chown={{ default_owner }} ./docker_server_requirements.txt ${APP_HOME}/docker_server_requirements.txt
@@ -131,7 +131,7 @@ RUN rm -f /etc/nginx/sites-enabled/default
 {{ chown_and_switch_to_regular_user_if_enabled(["/var/lib/nginx", "/var/log/nginx", "/run"]) }}
 ENTRYPOINT ["/docker_server/.venv/bin/supervisord", "-c", "{{ supervisor_config_path }}"]

-{%- elif requires_live_reload %}
+{%- elif requires_live_reload %} {#- elif requires_live_reload #}
 ENV HASH_TRUSS="{{ truss_hash }}"
 ENV CONTROL_SERVER_PORT="8080"
 ENV INFERENCE_SERVER_PORT="8090"
@@ -139,11 +139,11 @@ ENV SERVER_START_CMD="/control/.env/bin/python /control/control/server.py"
 {{ chown_and_switch_to_regular_user_if_enabled() }}
 ENTRYPOINT ["/control/.env/bin/python", "/control/control/server.py"]

-{%- else %}
+{%- else %} {#- else (default inference server) #}
 ENV INFERENCE_SERVER_PORT="8080"
 ENV SERVER_START_CMD="{{ python_executable }} /app/main.py"
 {{ chown_and_switch_to_regular_user_if_enabled() }}
 ENTRYPOINT ["{{ python_executable }}", "/app/main.py"]
-{%- endif %}
+{%- endif %} {#- endif config.docker_server / live_reload #}

 {% endblock %} {#- endblock run #}
```
truss/templates/shared/util.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 import multiprocessing
 import os
 import sys
-from typing import List
+from typing import List, Optional

 import psutil

@@ -62,7 +62,10 @@ def all_processes_dead(procs: List[multiprocessing.Process]) -> bool:
     return True


-def kill_child_processes(parent_pid: int):
+def kill_child_processes(
+    parent_pid: int,
+    timeout_seconds: Optional[float] = CHILD_PROCESS_WAIT_TIMEOUT_SECONDS,
+):
     try:
         parent = psutil.Process(parent_pid)
     except psutil.NoSuchProcess:
@@ -70,8 +73,6 @@ def kill_child_processes(parent_pid: int):
     children = parent.children(recursive=True)
     for process in children:
         process.terminate()
-    gone, alive = psutil.wait_procs(
-        children, timeout=CHILD_PROCESS_WAIT_TIMEOUT_SECONDS
-    )
+    gone, alive = psutil.wait_procs(children, timeout=timeout_seconds)
     for process in alive:
         process.kill()
```
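The timeout override is what lets the control server's process controller (above) wait only 30 seconds instead of the module default. A minimal sketch, assuming `shared/util.py` is importable as `shared.util` the way it is inside the built server image; the import path and the use of the current PID are assumptions, not taken from this diff:

```python
# Sketch only: the import path is an assumption about how shared/util.py is
# packaged into the image; os.getpid() merely stands in for a real parent pid.
import os

from shared.util import kill_child_processes

# Terminate children of the given process, wait up to 30 seconds for them to
# exit, then force-kill anything still alive.
kill_child_processes(os.getpid(), timeout_seconds=30)

# Omitting timeout_seconds keeps the previous behavior: the module-level
# CHILD_PROCESS_WAIT_TIMEOUT_SECONDS default is used.
```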