truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (50)
  1. truss/api/__init__.py +5 -2
  2. truss/base/truss_config.py +10 -3
  3. truss/cli/chains_commands.py +39 -1
  4. truss/cli/cli.py +35 -5
  5. truss/cli/remote_cli.py +29 -0
  6. truss/cli/resolvers/chain_team_resolver.py +82 -0
  7. truss/cli/resolvers/model_team_resolver.py +90 -0
  8. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  9. truss/cli/train/cache.py +332 -0
  10. truss/cli/train/core.py +19 -143
  11. truss/cli/train_commands.py +69 -11
  12. truss/cli/utils/common.py +40 -3
  13. truss/remote/baseten/api.py +58 -5
  14. truss/remote/baseten/core.py +22 -4
  15. truss/remote/baseten/remote.py +24 -2
  16. truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
  17. truss/templates/server/requirements.txt +1 -1
  18. truss/templates/server.Dockerfile.jinja +10 -10
  19. truss/templates/shared/util.py +6 -5
  20. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  21. truss/tests/cli/test_chains_cli.py +44 -0
  22. truss/tests/cli/test_cli.py +134 -1
  23. truss/tests/cli/test_cli_utils_common.py +11 -0
  24. truss/tests/cli/test_model_team_resolver.py +279 -0
  25. truss/tests/cli/train/test_cache_view.py +240 -3
  26. truss/tests/cli/train/test_train_cli_core.py +2 -2
  27. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  28. truss/tests/conftest.py +187 -0
  29. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  30. truss/tests/remote/baseten/test_api.py +122 -3
  31. truss/tests/remote/baseten/test_chain_upload.py +10 -1
  32. truss/tests/remote/baseten/test_core.py +86 -0
  33. truss/tests/remote/baseten/test_remote.py +216 -288
  34. truss/tests/test_config.py +21 -12
  35. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  36. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  37. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  38. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  39. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  40. truss/tests/test_model_inference.py +13 -0
  41. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
  42. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
  43. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  44. truss_chains/deployment/deployment_client.py +9 -4
  45. truss_chains/private_types.py +15 -0
  46. truss_train/definitions.py +3 -1
  47. truss_train/deployment.py +43 -21
  48. truss_train/public_api.py +4 -2
  49. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  50. {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -200,6 +200,8 @@ class BasetenApi:
200
200
  deployment_name: Optional[str] = None,
201
201
  origin: Optional[b10_types.ModelOrigin] = None,
202
202
  environment: Optional[str] = None,
203
+ deploy_timeout_minutes: Optional[int] = None,
204
+ team_id: Optional[str] = None,
203
205
  ):
204
206
  query_string = f"""
205
207
  mutation ($trussUserEnv: String) {{
@@ -213,6 +215,8 @@ class BasetenApi:
213
215
  {f'version_name: "{deployment_name}"' if deployment_name else ""}
214
216
  {f"model_origin: {origin.value}" if origin else ""}
215
217
  {f'environment_name: "{environment}"' if environment else ""}
218
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
219
+ {f'team_id: "{team_id}"' if team_id else ""}
216
220
  ) {{
217
221
  model_version {{
218
222
  id
@@ -244,6 +248,7 @@ class BasetenApi:
244
248
  deployment_name: Optional[str] = None,
245
249
  environment: Optional[str] = None,
246
250
  preserve_env_instance_type: bool = True,
251
+ deploy_timeout_minutes: Optional[int] = None,
247
252
  ):
248
253
  query_string = f"""
249
254
  mutation ($trussUserEnv: String) {{
@@ -257,6 +262,7 @@ class BasetenApi:
257
262
  preserve_env_instance_type: {"true" if preserve_env_instance_type else "false"}
258
263
  {f'name: "{deployment_name}"' if deployment_name else ""}
259
264
  {f'environment_name: "{environment}"' if environment else ""}
265
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
260
266
  ) {{
261
267
  model_version {{
262
268
  id
@@ -286,6 +292,8 @@ class BasetenApi:
286
292
  truss_user_env: b10_types.TrussUserEnv,
287
293
  allow_truss_download=True,
288
294
  origin: Optional[b10_types.ModelOrigin] = None,
295
+ deploy_timeout_minutes: Optional[int] = None,
296
+ team_id: Optional[str] = None,
289
297
  ):
290
298
  query_string = f"""
291
299
  mutation ($trussUserEnv: String) {{
@@ -295,6 +303,8 @@ class BasetenApi:
295
303
  truss_user_env: $trussUserEnv
296
304
  allow_truss_download: {"true" if allow_truss_download else "false"}
297
305
  {f"model_origin: {origin.value}" if origin else ""}
306
+ {f"deploy_timeout_minutes: {deploy_timeout_minutes}" if deploy_timeout_minutes is not None else ""}
307
+ {f'team_id: "{team_id}"' if team_id else ""}
298
308
  ) {{
299
309
  model_version {{
300
310
  id
@@ -327,6 +337,8 @@ class BasetenApi:
327
337
  is_draft: bool = False,
328
338
  original_source_artifact_s3_key: Optional[str] = None,
329
339
  allow_truss_download: Optional[bool] = True,
340
+ deployment_name: Optional[str] = None,
341
+ team_id: Optional[str] = None,
330
342
  ):
331
343
  if allow_truss_download is None:
332
344
  allow_truss_download = True
@@ -350,10 +362,14 @@ class BasetenApi:
350
362
  params.append(
351
363
  f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
352
364
  )
365
+ if team_id:
366
+ params.append(f'team_id: "{team_id}"')
353
367
 
354
368
  params.append(f"is_draft: {str(is_draft).lower()}")
355
369
  if allow_truss_download is False:
356
370
  params.append("allow_truss_download: false")
371
+ if deployment_name:
372
+ params.append(f'deployment_name: "{deployment_name}"')
357
373
 
358
374
  params_str = PARAMS_INDENT.join(params)
359
375
 
@@ -382,18 +398,24 @@ class BasetenApi:
382
398
 
383
399
  return resp["data"]["deploy_chain_atomic"]
384
400
 
385
- def get_chains(self):
401
+ def get_chains(self, team_id: Optional[str] = None):
386
402
  query_string = """
387
403
  {
388
404
  chains {
389
405
  id
390
406
  name
407
+ team {
408
+ name
409
+ }
391
410
  }
392
411
  }
393
412
  """
394
413
 
395
414
  resp = self._post_graphql_query(query_string)
396
- return resp["data"]["chains"]
415
+ chains = resp["data"]["chains"]
416
+
417
+ # TODO(COR-492): Filter by team_id in the backend
418
+ return chains
397
419
 
398
420
  def get_chain_deployments(self, chain_id: str):
399
421
  query_string = f"""
@@ -456,6 +478,10 @@ class BasetenApi:
456
478
  models {
457
479
  id,
458
480
  name
481
+ team {
482
+ id
483
+ name
484
+ }
459
485
  versions{
460
486
  id,
461
487
  semver,
@@ -495,6 +521,10 @@ class BasetenApi:
495
521
  id
496
522
  name
497
523
  hostname
524
+ team {{
525
+ id
526
+ name
527
+ }}
498
528
  versions {{
499
529
  id
500
530
  semver
@@ -647,10 +677,14 @@ class BasetenApi:
647
677
  "v1/api_keys", body={"type": api_key_type.value, "name": name}
648
678
  )
649
679
 
650
- def upsert_training_project(self, training_project):
680
+ def upsert_training_project(self, training_project, team_id: Optional[str] = None):
681
+ if team_id:
682
+ endpoint = f"v1/teams/{team_id}/training_projects"
683
+ else:
684
+ endpoint = "v1/training_projects"
651
685
  resp_json = self._rest_api_client.post(
652
- "v1/training_projects",
653
- body={"training_project": training_project.model_dump()},
686
+ endpoint,
687
+ body={"training_project": training_project.model_dump(exclude_none=True)},
654
688
  )
655
689
  return resp_json["training_project"]
656
690
 
@@ -903,3 +937,22 @@ class BasetenApi:
903
937
  return [
904
938
  InstanceTypeV1(**instance_type) for instance_type in instance_types_data
905
939
  ]
940
+
941
+ def get_teams(self) -> Dict[str, Dict[str, str]]:
942
+ """
943
+ Get all available teams via GraphQL API.
944
+ Returns a dictionary mapping team name to team data (with 'id' and 'name' keys).
945
+ """
946
+ query_string = """
947
+ query Teams {
948
+ teams {
949
+ id
950
+ name
951
+ }
952
+ }
953
+ """
954
+
955
+ resp = self._post_graphql_query(query_string)
956
+ teams_data = resp["data"]["teams"]
957
+ # Convert list to dict mapping team_name -> team
958
+ return {team["name"]: team for team in teams_data}
@@ -92,19 +92,21 @@ class ModelVersionHandle(NamedTuple):
92
92
  instance_type_name: Optional[str] = None
93
93
 
94
94
 
95
- def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]:
95
+ def get_chain_id_by_name(
96
+ api: BasetenApi, chain_name: str, team_id: Optional[str] = None
97
+ ) -> Optional[str]:
96
98
  """
97
99
  Check if a chain with the given name exists in the Baseten remote.
98
100
 
99
101
  Args:
100
102
  api: BasetenApi instance
101
103
  chain_name: Name of the chain to check for existence
104
+ team_id: Optional team_id to filter chains by team
102
105
 
103
106
  Returns:
104
107
  chain_id if present, otherwise None
105
108
  """
106
- chains = api.get_chains()
107
-
109
+ chains = api.get_chains(team_id=team_id)
108
110
  chain_name_id_mapping = {chain["name"]: chain["id"] for chain in chains}
109
111
  return chain_name_id_mapping.get(chain_name)
110
112
 
@@ -132,6 +134,8 @@ def create_chain_atomic(
132
134
  environment: Optional[str],
133
135
  original_source_artifact_s3_key: Optional[str] = None,
134
136
  allow_truss_download: bool = True,
137
+ deployment_name: Optional[str] = None,
138
+ team_id: Optional[str] = None,
135
139
  ) -> ChainDeploymentHandleAtomic:
136
140
  if environment and is_draft:
137
141
  logging.info(
@@ -140,7 +144,7 @@ def create_chain_atomic(
140
144
  )
141
145
  is_draft = False
142
146
 
143
- chain_id = get_chain_id_by_name(api, chain_name)
147
+ chain_id = get_chain_id_by_name(api, chain_name, team_id=team_id)
144
148
 
145
149
  # TODO(Tyron): Refactor for better readability:
146
150
  # 1. Prepare all arguments for `deploy_chain_atomic`.
@@ -156,6 +160,8 @@ def create_chain_atomic(
156
160
  truss_user_env=truss_user_env,
157
161
  original_source_artifact_s3_key=original_source_artifact_s3_key,
158
162
  allow_truss_download=allow_truss_download,
163
+ deployment_name=deployment_name,
164
+ team_id=team_id,
159
165
  )
160
166
  elif chain_id:
161
167
  # This is the only case where promote has relevance, since
@@ -171,6 +177,8 @@ def create_chain_atomic(
171
177
  truss_user_env=truss_user_env,
172
178
  original_source_artifact_s3_key=original_source_artifact_s3_key,
173
179
  allow_truss_download=allow_truss_download,
180
+ deployment_name=deployment_name,
181
+ team_id=team_id,
174
182
  )
175
183
  except ApiError as e:
176
184
  if (
@@ -193,6 +201,8 @@ def create_chain_atomic(
193
201
  truss_user_env=truss_user_env,
194
202
  original_source_artifact_s3_key=original_source_artifact_s3_key,
195
203
  allow_truss_download=allow_truss_download,
204
+ deployment_name=deployment_name,
205
+ team_id=team_id,
196
206
  )
197
207
 
198
208
  return ChainDeploymentHandleAtomic(
@@ -397,6 +407,8 @@ def create_truss_service(
397
407
  origin: Optional[b10_types.ModelOrigin] = None,
398
408
  environment: Optional[str] = None,
399
409
  preserve_env_instance_type: bool = True,
410
+ deploy_timeout_minutes: Optional[int] = None,
411
+ team_id: Optional[str] = None,
400
412
  ) -> ModelVersionHandle:
401
413
  """
402
414
  Create a model in the Baseten remote.
@@ -412,6 +424,7 @@ def create_truss_service(
412
424
  to zero.
413
425
  deployment_name: Name to apply to the created deployment. Not applied to
414
426
  development model.
427
+ team_id: ID of the team to create the model in.
415
428
 
416
429
  Returns:
417
430
  A Model Version handle.
@@ -424,6 +437,8 @@ def create_truss_service(
424
437
  truss_user_env,
425
438
  allow_truss_download=allow_truss_download,
426
439
  origin=origin,
440
+ deploy_timeout_minutes=deploy_timeout_minutes,
441
+ team_id=team_id,
427
442
  )
428
443
 
429
444
  return ModelVersionHandle(
@@ -448,6 +463,8 @@ def create_truss_service(
448
463
  deployment_name=deployment_name,
449
464
  origin=origin,
450
465
  environment=environment,
466
+ deploy_timeout_minutes=deploy_timeout_minutes,
467
+ team_id=team_id,
451
468
  )
452
469
 
453
470
  return ModelVersionHandle(
@@ -472,6 +489,7 @@ def create_truss_service(
472
489
  deployment_name=deployment_name,
473
490
  environment=environment,
474
491
  preserve_env_instance_type=preserve_env_instance_type,
492
+ deploy_timeout_minutes=deploy_timeout_minutes,
475
493
  )
476
494
  except ApiError as e:
477
495
  if (
@@ -69,6 +69,7 @@ class FinalPushData(custom_types.OracleData):
69
69
  origin: Optional[custom_types.ModelOrigin] = None
70
70
  environment: Optional[str] = None
71
71
  allow_truss_download: bool
72
+ team_id: Optional[str] = None
72
73
 
73
74
 
74
75
  class BasetenRemote(TrussRemote):
@@ -127,6 +128,8 @@ class BasetenRemote(TrussRemote):
127
128
  origin: Optional[custom_types.ModelOrigin] = None,
128
129
  environment: Optional[str] = None,
129
130
  progress_bar: Optional[Type["progress.Progress"]] = None,
131
+ deploy_timeout_minutes: Optional[int] = None,
132
+ team_id: Optional[str] = None,
130
133
  ) -> FinalPushData:
131
134
  if model_name.isspace():
132
135
  raise ValueError("Model name cannot be empty")
@@ -164,6 +167,13 @@ class BasetenRemote(TrussRemote):
164
167
  "Deployment name must only contain alphanumeric, -, _ and . characters"
165
168
  )
166
169
 
170
+ if deploy_timeout_minutes is not None and (
171
+ deploy_timeout_minutes < 10 or deploy_timeout_minutes > 1440
172
+ ):
173
+ raise ValueError(
174
+ "deploy-timeout-minutes must be between 10 minutes and 1440 minutes (24 hours)"
175
+ )
176
+
167
177
  model_id = exists_model(self._api, model_name)
168
178
 
169
179
  if model_id is not None and disable_truss_download:
@@ -188,6 +198,7 @@ class BasetenRemote(TrussRemote):
188
198
  origin=origin,
189
199
  environment=environment,
190
200
  allow_truss_download=not disable_truss_download,
201
+ team_id=team_id,
191
202
  )
192
203
 
193
204
  def push( # type: ignore
@@ -205,6 +216,8 @@ class BasetenRemote(TrussRemote):
205
216
  progress_bar: Optional[Type["progress.Progress"]] = None,
206
217
  include_git_info: bool = False,
207
218
  preserve_env_instance_type: bool = True,
219
+ deploy_timeout_minutes: Optional[int] = None,
220
+ team_id: Optional[str] = None,
208
221
  ) -> BasetenService:
209
222
  push_data = self._prepare_push(
210
223
  truss_handle=truss_handle,
@@ -217,6 +230,8 @@ class BasetenRemote(TrussRemote):
217
230
  origin=origin,
218
231
  environment=environment,
219
232
  progress_bar=progress_bar,
233
+ deploy_timeout_minutes=deploy_timeout_minutes,
234
+ team_id=team_id,
220
235
  )
221
236
 
222
237
  if include_git_info:
@@ -242,6 +257,8 @@ class BasetenRemote(TrussRemote):
242
257
  environment=push_data.environment,
243
258
  truss_user_env=truss_user_env,
244
259
  preserve_env_instance_type=preserve_env_instance_type,
260
+ deploy_timeout_minutes=deploy_timeout_minutes,
261
+ team_id=push_data.team_id,
245
262
  )
246
263
 
247
264
  if model_version_handle.instance_type_name:
@@ -269,6 +286,8 @@ class BasetenRemote(TrussRemote):
269
286
  environment: Optional[str] = None,
270
287
  progress_bar: Optional[Type["progress.Progress"]] = None,
271
288
  disable_chain_download: bool = False,
289
+ deployment_name: Optional[str] = None,
290
+ team_id: Optional[str] = None,
272
291
  ) -> ChainDeploymentHandleAtomic:
273
292
  # If we are promoting a model to an environment after deploy, it must be published.
274
293
  # Draft models cannot be promoted.
@@ -289,6 +308,7 @@ class BasetenRemote(TrussRemote):
289
308
  origin=custom_types.ModelOrigin.CHAINS,
290
309
  progress_bar=progress_bar,
291
310
  disable_truss_download=disable_chain_download,
311
+ deployment_name=deployment_name,
292
312
  )
293
313
  oracle_data = custom_types.OracleData(
294
314
  model_name=push_data.model_name,
@@ -326,6 +346,8 @@ class BasetenRemote(TrussRemote):
326
346
  environment=environment,
327
347
  original_source_artifact_s3_key=raw_chain_s3_key,
328
348
  allow_truss_download=not disable_chain_download,
349
+ deployment_name=deployment_name,
350
+ team_id=team_id,
329
351
  )
330
352
  logging.info("Successfully pushed to baseten. Chain is building and deploying.")
331
353
  return chain_deployment_handle
@@ -589,5 +611,5 @@ class BasetenRemote(TrussRemote):
589
611
  ) -> PatchResult:
590
612
  return self._patch(watch_path, truss_ignore_patterns, console=None)
591
613
 
592
- def upsert_training_project(self, training_project):
593
- return self._api.upsert_training_project(training_project)
614
+ def upsert_training_project(self, training_project, team_id=None):
615
+ return self._api.upsert_training_project(training_project, team_id=team_id)
@@ -49,7 +49,9 @@ class InferenceServerProcessController:
49
49
 
50
50
  def _terminate_children_and_process(self):
51
51
  """Kill child processes first, then parent. Prevents port binding conflicts."""
52
- kill_child_processes(self._inference_server_process.pid)
52
+ # Use a shorter timeout than the truss patch read timeout (=120s):
53
+ # see remote/baseten/api.py:_post_graphql_query()
54
+ kill_child_processes(self._inference_server_process.pid, timeout_seconds=30)
53
55
  self._inference_server_process.terminate()
54
56
 
55
57
  def stop(self):
@@ -18,7 +18,7 @@ psutil>=5.9.4
18
18
  python-json-logger>=2.0.2
19
19
  pyyaml>=6.0.0
20
20
  requests>=2.31.0
21
- truss-transfer==0.0.37
21
+ truss-transfer==0.0.38
22
22
  uvicorn>=0.24.0
23
23
  uvloop>=0.19.0
24
24
  websockets>=10.0
@@ -56,12 +56,6 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }}
56
56
  {% endfor %} {#- endfor external_data_files #}
57
57
  {%- endif %} {#- endif external_data_files #}
58
58
 
59
- {%- if build_commands %}
60
- {% for command in build_commands %}
61
- RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
62
- {% endfor %} {#- endfor build_commands #}
63
- {%- endif %} {#- endif build_commands #}
64
-
65
59
  {# Copy data before code for better caching #}
66
60
  {%- if data_dir_exists %}
67
61
  COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
@@ -109,7 +103,13 @@ USER {{ app_username }}
109
103
  {%- endif %} {#- endif non_root_user #}
110
104
  {%- endmacro -%}
111
105
 
112
- {%- if config.docker_server %}
106
+ {%- if build_commands %}
107
+ {% for command in build_commands %}
108
+ RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount=type=secret,id={{ secret }},target={{ path }}{%- endfor %} {{ command }}
109
+ {% endfor %} {#- endfor build_commands #}
110
+ {%- endif %} {#- endif build_commands #}
111
+
112
+ {%- if config.docker_server %}
113
113
  RUN apt-get update -y && apt-get install -y --no-install-recommends \
114
114
  curl nginx && rm -rf /var/lib/apt/lists/*
115
115
  COPY --chown={{ default_owner }} ./docker_server_requirements.txt ${APP_HOME}/docker_server_requirements.txt
@@ -131,7 +131,7 @@ RUN rm -f /etc/nginx/sites-enabled/default
131
131
  {{ chown_and_switch_to_regular_user_if_enabled(["/var/lib/nginx", "/var/log/nginx", "/run"]) }}
132
132
  ENTRYPOINT ["/docker_server/.venv/bin/supervisord", "-c", "{{ supervisor_config_path }}"]
133
133
 
134
- {%- elif requires_live_reload %} {#- elif requires_live_reload #}
134
+ {%- elif requires_live_reload %} {#- elif requires_live_reload #}
135
135
  ENV HASH_TRUSS="{{ truss_hash }}"
136
136
  ENV CONTROL_SERVER_PORT="8080"
137
137
  ENV INFERENCE_SERVER_PORT="8090"
@@ -139,11 +139,11 @@ ENV SERVER_START_CMD="/control/.env/bin/python /control/control/server.py"
139
139
  {{ chown_and_switch_to_regular_user_if_enabled() }}
140
140
  ENTRYPOINT ["/control/.env/bin/python", "/control/control/server.py"]
141
141
 
142
- {%- else %} {#- else (default inference server) #}
142
+ {%- else %} {#- else (default inference server) #}
143
143
  ENV INFERENCE_SERVER_PORT="8080"
144
144
  ENV SERVER_START_CMD="{{ python_executable }} /app/main.py"
145
145
  {{ chown_and_switch_to_regular_user_if_enabled() }}
146
146
  ENTRYPOINT ["{{ python_executable }}", "/app/main.py"]
147
- {%- endif %} {#- endif config.docker_server / live_reload #}
147
+ {%- endif %} {#- endif config.docker_server / live_reload #}
148
148
 
149
149
  {% endblock %} {#- endblock run #}
@@ -1,7 +1,7 @@
1
1
  import multiprocessing
2
2
  import os
3
3
  import sys
4
- from typing import List
4
+ from typing import List, Optional
5
5
 
6
6
  import psutil
7
7
 
@@ -62,7 +62,10 @@ def all_processes_dead(procs: List[multiprocessing.Process]) -> bool:
62
62
  return True
63
63
 
64
64
 
65
- def kill_child_processes(parent_pid: int):
65
+ def kill_child_processes(
66
+ parent_pid: int,
67
+ timeout_seconds: Optional[float] = CHILD_PROCESS_WAIT_TIMEOUT_SECONDS,
68
+ ):
66
69
  try:
67
70
  parent = psutil.Process(parent_pid)
68
71
  except psutil.NoSuchProcess:
@@ -70,8 +73,6 @@ def kill_child_processes(parent_pid: int):
70
73
  children = parent.children(recursive=True)
71
74
  for process in children:
72
75
  process.terminate()
73
- gone, alive = psutil.wait_procs(
74
- children, timeout=CHILD_PROCESS_WAIT_TIMEOUT_SECONDS
75
- )
76
+ gone, alive = psutil.wait_procs(children, timeout=timeout_seconds)
76
77
  for process in alive:
77
78
  process.kill()