truss 0.10.9rc601__py3-none-any.whl → 0.10.10rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic.

Files changed (32)
  1. truss/base/constants.py +0 -1
  2. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
  3. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
  4. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
  5. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
  6. truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
  7. truss/cli/train_commands.py +11 -3
  8. truss/contexts/image_builder/cache_warmer.py +1 -3
  9. truss/contexts/image_builder/serving_image_builder.py +24 -32
  10. truss/remote/baseten/api.py +11 -0
  11. truss/remote/baseten/core.py +209 -1
  12. truss/remote/baseten/utils/time.py +15 -0
  13. truss/templates/server/model_wrapper.py +0 -12
  14. truss/templates/server/requirements.txt +1 -1
  15. truss/templates/server/truss_server.py +0 -13
  16. truss/templates/server.Dockerfile.jinja +1 -1
  17. truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
  18. truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
  19. truss/tests/remote/baseten/conftest.py +18 -0
  20. truss/tests/remote/baseten/test_api.py +49 -14
  21. truss/tests/remote/baseten/test_core.py +517 -1
  22. truss/tests/test_data/test_openai/model/model.py +0 -3
  23. truss/truss_handle/truss_handle.py +0 -1
  24. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/METADATA +2 -2
  25. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/RECORD +30 -28
  26. truss_train/definitions.py +6 -0
  27. truss_train/deployment.py +15 -2
  28. truss/tests/util/test_basetenpointer.py +0 -227
  29. truss/util/basetenpointer.py +0 -160
  30. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/WHEEL +0 -0
  31. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/entry_points.txt +0 -0
  32. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/licenses/LICENSE +0 -0
truss/base/constants.py CHANGED
@@ -29,7 +29,6 @@ BEI_REQUIRED_MAX_NUM_TOKENS = 16384
 TRTLLM_MIN_MEMORY_REQUEST_GI = 10
 HF_MODELS_API_URL = "https://huggingface.co/api/models"
 HF_ACCESS_TOKEN_KEY = "hf_access_token"
-HF_ACCESS_TOKEN_FILE_NAME = "hf_access_token"
 TRUSSLESS_MAX_PAYLOAD_SIZE = "64M"
 # Alias for TEMPLATES_DIR
 SERVING_DIR: pathlib.Path = TEMPLATES_DIR
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py CHANGED
@@ -33,6 +33,10 @@ from .deploy_lora_checkpoints import (
     hydrate_lora_checkpoint,
     render_vllm_lora_truss_config,
 )
+from .deploy_whisper_checkpoints import (
+    hydrate_whisper_checkpoint,
+    render_vllm_whisper_truss_config,
+)

 HF_TOKEN_ENVVAR_NAME = "HF_TOKEN"
 # If we change this, make sure to update the logic in backend codebase
@@ -178,6 +182,8 @@ def hydrate_checkpoint(
         return hydrate_lora_checkpoint(job_id, checkpoint_id, checkpoint)
     elif checkpoint_type.lower() == ModelWeightsFormat.FULL.value:
         return hydrate_full_checkpoint(job_id, checkpoint_id, checkpoint)
+    elif checkpoint_type.lower() == ModelWeightsFormat.WHISPER.value:
+        return hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
     else:
         raise ValueError(
             f"Unsupported checkpoint type: {checkpoint_type}. Contact Baseten for support with other checkpoint types."
@@ -196,6 +202,8 @@ def _render_truss_config_for_checkpoint_deployment(
         return render_vllm_lora_truss_config(checkpoint_deploy)
     elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.FULL:
         return render_vllm_full_truss_config(checkpoint_deploy)
+    elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER:
+        return render_vllm_whisper_truss_config(checkpoint_deploy)
     else:
         raise ValueError(
             f"Unsupported model weight format: {checkpoint_deploy.model_weight_format}. Please upgrade to the latest Truss version to access the latest supported formats. Contact Baseten if you would like us to support additional formats."
@@ -288,18 +296,6 @@ def _get_checkpoint_ids_to_deploy(
     return checkpoint_ids


-def _select_single_checkpoint(checkpoint_id_options: List[str]) -> List[str]:
-    """Select a single checkpoint using interactive prompt."""
-    checkpoint_id = inquirer.select(
-        message="Select the checkpoint to deploy:", choices=checkpoint_id_options
-    ).execute()
-
-    if not checkpoint_id:
-        raise click.UsageError("A checkpoint must be selected.")
-
-    return [checkpoint_id]
-
-
 def _select_multiple_checkpoints(checkpoint_id_options: List[str]) -> List[str]:
     """Select multiple checkpoints using interactive checkbox."""
     checkpoint_ids = inquirer.checkbox(
@@ -351,6 +347,8 @@ def _get_base_model_id(user_input: Optional[str], checkpoint: dict) -> Optional[
         )
     elif checkpoint.get("checkpoint_type") == ModelWeightsFormat.FULL.value.lower():
         return None
+    elif checkpoint.get("checkpoint_type") == ModelWeightsFormat.WHISPER.value.lower():
+        return None
     else:
         base_model_id = inquirer.text(message="Enter the base model id.").execute()
         if not base_model_id:
@@ -416,18 +414,28 @@ def _validate_selected_checkpoints(
             "Unable to infer model weight format. Reach out to Baseten for support."
         )

-    has_full_checkpoint = any(
-        response_checkpoints[checkpoint_id].get("checkpoint_type")
-        == ModelWeightsFormat.FULL.value
-        for checkpoint_id in checkpoint_ids
-    )
-
-    if has_full_checkpoint and len(checkpoint_ids) > 1:
-        # vLLM does not support multiple checkpoints when any checkpoint is full model weights.
-        raise ValueError(
-            "Full checkpoints are not supported for multiple checkpoints. Please select a single checkpoint."
+    validation_rules = {
+        ModelWeightsFormat.FULL.value: {
+            "error_message": "Full checkpoints are not supported for multiple checkpoints. Please select a single checkpoint.",
+            "reason": "vLLM does not support multiple checkpoints when any checkpoint is full model weights.",
+        },
+        ModelWeightsFormat.WHISPER.value: {
+            "error_message": "Whisper checkpoints are not supported for multiple checkpoints. Please select a single checkpoint.",
+            "reason": "vLLM does not support multiple checkpoints when any checkpoint is whisper model weights.",
+        },
+    }
+
+    # Check each checkpoint type that has restrictions
+    for checkpoint_type, rule in validation_rules.items():
+        has_restricted_checkpoint = any(
+            response_checkpoints[checkpoint_id].get("checkpoint_type")
+            == checkpoint_type
+            for checkpoint_id in checkpoint_ids
         )

+        if has_restricted_checkpoint and len(checkpoint_ids) > 1:
+            raise ValueError(rule["error_message"])
+

 def get_hf_secret_name(user_input: Union[str, SecretReference, None]) -> str:
     """Get HuggingFace secret name from user input or prompt for it."""
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path

 from truss.base import truss_config
 from truss.cli.train.types import DeployCheckpointsConfigComplete
-from truss_train.definitions import SecretReference
+from truss_train.definitions import ModelWeightsFormat, SecretReference

 START_COMMAND_ENVVAR_NAME = "BT_DOCKER_SERVER_START_CMD"

@@ -12,8 +12,14 @@ def setup_base_truss_config(
     checkpoint_deploy: DeployCheckpointsConfigComplete,
 ) -> truss_config.TrussConfig:
     """Set up the base truss config with common properties."""
+    truss_deploy_config = None
+    truss_base_file = (
+        "deploy_from_checkpoint_config_whisper.yml"
+        if checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER
+        else "deploy_from_checkpoint_config.yml"
+    )
     truss_deploy_config = truss_config.TrussConfig.from_yaml(
-        Path(os.path.dirname(__file__), "..", "deploy_from_checkpoint_config.yml")
+        Path(os.path.dirname(__file__), "..", truss_base_file)
     )
     if not truss_deploy_config.docker_server:
         raise ValueError(
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py CHANGED
@@ -40,7 +40,7 @@ def render_vllm_full_truss_config(
         truss_deploy_config, checkpoint_deploy
     )

-    checkpoint_str = _build_full_checkpoint_string(truss_deploy_config)
+    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)

     accelerator = checkpoint_deploy.compute.accelerator

@@ -71,7 +71,7 @@ def hydrate_full_checkpoint(
     return FullCheckpoint(training_job_id=job_id, paths=paths)


-def _build_full_checkpoint_string(truss_deploy_config) -> str:
+def build_full_checkpoint_string(truss_deploy_config) -> str:
     """Build checkpoint string from artifact references for full checkpoints.

     Args:
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py ADDED
@@ -0,0 +1,63 @@
+from jinja2 import Template
+
+from truss.base import truss_config
+from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
+    START_COMMAND_ENVVAR_NAME,
+)
+from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
+    build_full_checkpoint_string,
+)
+from truss.cli.train.types import DeployCheckpointsConfigComplete
+from truss_train.definitions import WhisperCheckpoint
+
+from .deploy_checkpoints_helpers import (
+    setup_base_truss_config,
+    setup_environment_variables_and_secrets,
+)
+
+VLLM_WHISPER_START_COMMAND = Template(
+    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
+    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }}'"
+)
+
+
+def render_vllm_whisper_truss_config(
+    checkpoint_deploy: DeployCheckpointsConfigComplete,
+) -> truss_config.TrussConfig:
+    """Render truss config specifically for whisper checkpoints using vLLM."""
+    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
+
+    start_command_envvars = setup_environment_variables_and_secrets(
+        truss_deploy_config, checkpoint_deploy
+    )
+
+    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
+
+    accelerator = checkpoint_deploy.compute.accelerator
+
+    start_command_args = {
+        "model_path": checkpoint_str,
+        "envvars": start_command_envvars,
+        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
+    }
+    # Note: we set the start command as an environment variable in supervisord config.
+    # This is so that we don't have to change the supervisord config when the start command changes.
+    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
+    start_command = VLLM_WHISPER_START_COMMAND.render(**start_command_args)
+    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
+    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
+    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
+        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
+    )
+
+    return truss_deploy_config
+
+
+def hydrate_whisper_checkpoint(
+    job_id: str, checkpoint_id: str, checkpoint: dict
+) -> WhisperCheckpoint:
+    """Create a Checkpoint object for whisper model weights."""
+    # NOTE: Slash at the end is important since it means the checkpoint is a directory
+    paths = [f"rank-0/{checkpoint_id}/"]
+    return WhisperCheckpoint(training_job_id=job_id, paths=paths)
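
For reference, a sketch of the command this template renders, with the checkpoint path and GPU count as placeholders (the rendered envvars depend on the deployment, and the doubled $$ presumably escapes $ for the process supervisor):

    sh -c 'VLLM_LOGGING_LEVEL=WARNING HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && vllm serve <checkpoint-path> --port 8000 --tensor-parallel-size 2'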
truss/cli/train/deploy_from_checkpoint_config_whisper.yml ADDED
@@ -0,0 +1,17 @@
+base_image:
+  image: vllm/vllm-openai:latest
+
+docker_server:
+  start_command: sh -c "" # replaced when deploying
+  readiness_endpoint: /health
+  liveness_endpoint: /health
+  predict_endpoint: /v1/audio/transcriptions
+  server_port: 8000
+runtime:
+  predict_concurrency : 256
+environment_variables:
+  VLLM_LOGGING_LEVEL: WARNING
+  VLLM_USE_V1: 0
+  HF_HUB_ENABLE_HF_TRANSFER: 1
+requirements:
+  - vllm[audio]
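
Since predict_endpoint maps to vLLM's OpenAI-compatible transcription route, a deployed server should accept OpenAI-style multipart requests. A hypothetical local smoke test (file and served model names are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": open("sample.wav", "rb")},
        data={"model": "whisper-checkpoint"},  # served model name depends on the deploy
    )
    print(resp.json())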
truss/cli/train_commands.py CHANGED
@@ -13,6 +13,7 @@ from truss.cli.train import common as train_common
 from truss.cli.train import core
 from truss.cli.utils import common
 from truss.cli.utils.output import console, error_console
+from truss.remote.baseten.core import get_training_job_logs_with_pagination
 from truss.remote.baseten.remote import BasetenRemote
 from truss.remote.remote_factory import RemoteFactory

@@ -72,8 +73,11 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
 @click.argument("config", type=Path, required=True)
 @click.option("--remote", type=str, required=False, help="Remote to use")
 @click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
+@click.option("--job-name", type=str, required=False, help="Name of the training job.")
 @common.common_options()
-def push_training_job(config: Path, remote: Optional[str], tail: bool):
+def push_training_job(
+    config: Path, remote: Optional[str], tail: bool, job_name: Optional[str]
+):
     """Run a training job"""
     from truss_train import deployment

@@ -84,7 +88,9 @@ def push_training_job(config: Path, remote: Optional[str], tail: bool):
     remote_provider: BasetenRemote = cast(
         BasetenRemote, RemoteFactory.create(remote=remote)
     )
-    job_resp = deployment.create_training_job_from_file(remote_provider, config)
+    job_resp = deployment.create_training_job_from_file(
+        remote_provider, config, job_name
+    )

     # Note: This post create logic needs to happen outside the context
     # of the above context manager, as only one console session can be active
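
With the new option, a job can presumably be named at push time, e.g. truss train push config.py --job-name my-whisper-finetune --tail (config path and name illustrative).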
@@ -138,7 +144,9 @@ def get_job_logs(
     )

     if not tail:
-        logs = remote_provider.api.get_training_job_logs(project_id, job_id)
+        logs = get_training_job_logs_with_pagination(
+            remote_provider.api, project_id, job_id
+        )
         for log in cli_log_utils.parse_logs(logs):
             cli_log_utils.output_log(log)
     else:
truss/contexts/image_builder/cache_warmer.py CHANGED
@@ -15,8 +15,6 @@ from botocore.exceptions import ClientError, NoCredentialsError
 from google.cloud import storage
 from huggingface_hub import hf_hub_download

-from truss.base import constants
-
 B10CP_PATH_TRUSS_ENV_VAR_NAME = "B10CP_PATH_TRUSS"

 GCS_CREDENTIALS = "/app/data/service_account.json"
@@ -110,7 +108,7 @@ class RepositoryFile(ABC):

 class HuggingFaceFile(RepositoryFile):
     def download_to_cache(self):
-        secret_path = Path(f"/etc/secrets/{constants.HF_ACCESS_TOKEN_FILE_NAME}")
+        secret_path = Path("/etc/secrets/hf-access-token")
         secret = secret_path.read_text().strip() if secret_path.exists() else None
         try:
             hf_hub_download(
truss/contexts/image_builder/serving_image_builder.py CHANGED
@@ -73,7 +73,6 @@ from truss.contexts.image_builder.util import (
 )
 from truss.contexts.truss_context import TrussContext
 from truss.truss_handle.patch.hash import directory_content_hash
-from truss.util.basetenpointer import model_cache_hf_to_b10ptr
 from truss.util.jinja import read_template_from_fs
 from truss.util.path import (
     build_truss_target_directory,
@@ -93,6 +92,8 @@ USER_TRUSS_IGNORE_FILE = ".truss_ignore"
 GCS_CREDENTIALS = "service_account.json"
 S3_CREDENTIALS = "s3_credentials.json"

+HF_ACCESS_TOKEN_FILE_NAME = "hf-access-token"
+
 CLOUD_BUCKET_CACHE = MODEL_CACHE_PATH

 HF_SOURCE_DIR = Path("./root/.cache/huggingface/hub/")
@@ -324,36 +325,27 @@ def get_files_to_model_cache_v1(config: TrussConfig, truss_dir: Path, build_dir:
 def build_model_cache_v2_and_copy_bptr_manifest(config: TrussConfig, build_dir: Path):
     assert config.model_cache.is_v2
     assert all(model.volume_folder is not None for model in config.model_cache.models)
-    try:
-        from truss_transfer import PyModelRepo, create_basetenpointer_from_models
-
-        py_models = [
-            PyModelRepo(
-                repo_id=model.repo_id,
-                revision=model.revision,
-                runtime_secret_name=model.runtime_secret_name,
-                allow_patterns=model.allow_patterns,
-                ignore_patterns=model.ignore_patterns,
-                volume_folder=model.volume_folder,
-                kind=model.kind.value,
-            )
-            for model in config.model_cache.models
-        ]
-        # create BasetenPointer from models
-        basetenpointer_json = create_basetenpointer_from_models(models=py_models)
-        bptr_py = json.loads(basetenpointer_json)["pointers"]
-        logging.info(f"created ({len(bptr_py)}) Basetenpointer")
-        logging.info(f"pointers json: {basetenpointer_json}")
-        with open(build_dir / "bptr-manifest", "w") as f:
-            f.write(basetenpointer_json)
-    except Exception as e:
-        logging.warning(f"debug: failed to create BasetenPointer: {e}")
-        # TODO: remove below section + remove logging lines above.
-        # builds BasetenManifest for caching
-        basetenpointers = model_cache_hf_to_b10ptr(config.model_cache)
-        # write json of bastenpointers into build dir
-        with open(build_dir / "bptr-manifest", "w") as f:
-            f.write(basetenpointers.model_dump_json())
+    from truss_transfer import PyModelRepo, create_basetenpointer_from_models
+
+    py_models = [
+        PyModelRepo(
+            repo_id=model.repo_id,
+            revision=model.revision,
+            runtime_secret_name=model.runtime_secret_name,
+            allow_patterns=model.allow_patterns,
+            ignore_patterns=model.ignore_patterns,
+            volume_folder=model.volume_folder,
+            kind=model.kind.value,
+        )
+        for model in config.model_cache.models
+    ]
+    # create BasetenPointer from models
+    basetenpointer_json = create_basetenpointer_from_models(models=py_models)
+    bptr_py = json.loads(basetenpointer_json)["pointers"]
+    logging.info(f"created ({len(bptr_py)}) Basetenpointer")
+    logging.info(f"pointers json: {basetenpointer_json}")
+    with open(build_dir / "bptr-manifest", "w") as f:
+        f.write(basetenpointer_json)


 def generate_docker_server_nginx_config(build_dir, config):
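
Per the json.loads(basetenpointer_json)["pointers"] access and logging above, the bptr-manifest written here is a JSON object with a top-level pointers array; schematically (entry fields are defined by truss_transfer and omitted here):

    manifest = {"pointers": [...]}  # one entry per file resolved from the model repos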
@@ -819,7 +811,7 @@ class ServingImageBuilder(ImageBuilder):
             model_cache_v1=config.model_cache.is_v1,
             model_cache_v2=config.model_cache.is_v2,
             hf_access_token=hf_access_token,
-            hf_access_token_file_name=constants.HF_ACCESS_TOKEN_FILE_NAME,
+            hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME,
             external_data_files=external_data_files,
             build_commands=build_commands,
             use_local_src=config.use_local_src,
truss/remote/baseten/api.py CHANGED
@@ -669,6 +669,17 @@ class BasetenApi:
         # NB(nikhil): reverse order so latest logs are at the end
         return resp_json["logs"][::-1]

+    def _fetch_log_batch(
+        self, project_id: str, job_id: str, query_params: Dict[str, Any]
+    ) -> List[Any]:
+        """
+        Fetch a single batch of logs from the API.
+        """
+        resp_json = self._rest_api_client.post(
+            f"v1/training_projects/{project_id}/jobs/{job_id}/logs", body=query_params
+        )
+        return resp_json["logs"]
+
     def get_training_job_checkpoint_presigned_url(
         self, project_id: str, job_id: str, page_size: int = 100
     ) -> List[Dict[str, str]]:
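
The query_params body this endpoint receives is built by _build_log_query_params in core.py (below); a representative payload, with illustrative timestamps spanning the 2-hour window:

    query_params = {
        "start_epoch_millis": 1_700_000_000_000,
        "end_epoch_millis": 1_700_007_200_000,  # start + 2 * MILLISECONDS_PER_HOUR
        "limit": 1000,                          # MAX_BATCH_SIZE
        "direction": "asc",
    }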
truss/remote/baseten/core.py CHANGED
@@ -3,7 +3,9 @@ import json
 import logging
 import pathlib
 import textwrap
-from typing import IO, TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Type
+from typing import IO, TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Type
+
+import requests

 from truss.base.errors import ValidationError

@@ -15,6 +17,7 @@ from truss.remote.baseten import custom_types as b10_types
 from truss.remote.baseten.api import BasetenApi
 from truss.remote.baseten.error import ApiError
 from truss.remote.baseten.utils.tar import create_tar_with_progress_bar
+from truss.remote.baseten.utils.time import iso_to_millis
 from truss.remote.baseten.utils.transfer import multipart_upload_boto3
 from truss.util.path import load_trussignore_patterns_from_truss_dir

@@ -27,6 +30,16 @@ NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING = (
     "Model hasn't been deployed yet. No environments exist."
 )

+# Maximum number of iterations to prevent infinite loops when paginating logs
+MAX_ITERATIONS = 10_000
+MIN_BATCH_SIZE = 100
+
+# LIMIT for the number of logs to fetch per request defined by the server
+MAX_BATCH_SIZE = 1000
+
+NANOSECONDS_PER_MILLISECOND = 1_000_000
+MILLISECONDS_PER_HOUR = 60 * 60 * 1000
+

 class ModelIdentifier:
     value: str
@@ -465,3 +478,198 @@ def validate_truss_config_against_backend(api: BasetenApi, config: str):
         raise ValidationError(
             f"Validation failed with the following errors:\n{error_messages}"
         )
+
+
+def _build_log_query_params(
+    start_time: Optional[int], end_time: Optional[int], batch_size: int
+) -> Dict[str, Any]:
+    """
+    Build query parameters for log fetching request.
+
+    Args:
+        start_time: Start time in milliseconds since epoch
+        end_time: End time in milliseconds since epoch
+        batch_size: Number of logs to fetch per request
+
+    Returns:
+        Dictionary of query parameters with None values removed
+    """
+    query_body = {
+        "start_epoch_millis": start_time,
+        "end_epoch_millis": end_time,
+        "limit": batch_size,
+        "direction": "asc",
+    }
+
+    return {k: v for k, v in query_body.items() if v is not None}
+
+
+def _handle_server_error_backoff(
+    error: requests.HTTPError, job_id: str, iteration: int, batch_size: int
+) -> int:
+    """
+    Slash the batch size in half and return the new batch size
+    """
+    old_batch_size = batch_size
+    new_batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
+
+    logging.warning(
+        f"Server error (HTTP {error.response.status_code}) for job {job_id} at iteration {iteration}. "
+        f"Reducing batch size from {old_batch_size} to {new_batch_size}. Retrying..."
+    )
+
+    return new_batch_size
+
+
+def _process_batch_logs(
+    batch_logs: List[Any], job_id: str, iteration: int, batch_size: int
+) -> Tuple[bool, Optional[int], Optional[int]]:
+    """
+    Process a batch of logs and determine if pagination should continue.
+
+    Args:
+        batch_logs: List of logs from the current batch
+        job_id: The job ID for logging
+        iteration: Current iteration number for logging
+        batch_size: Expected batch size
+
+    Returns:
+        Tuple of (should_continue, next_start_time, next_end_time)
+    """
+
+    # If no logs returned, we're done
+    if not batch_logs:
+        logging.info(f"No logs returned for job {job_id} at iteration {iteration}")
+        return False, None, None
+
+    # If we got fewer logs than the batch size, we've reached the end
+    if len(batch_logs) == 0:
+        logging.info(f"Reached end of logs for job {job_id} at iteration {iteration}")
+        return False, None, None
+
+    # Timestamp returned in nanoseconds for the last log in this batch converted
+    # to milliseconds to use as start for next iteration
+    last_log_timestamp = int(batch_logs[-1]["timestamp"]) // NANOSECONDS_PER_MILLISECOND
+
+    # Update start time for next iteration (add 1ms to avoid overlap)
+    next_start_time_ms = last_log_timestamp + 1
+
+    # Set end time to 2 hours from next start time, maximum time delta allowed by the API
+    next_end_time_ms = next_start_time_ms + 2 * MILLISECONDS_PER_HOUR
+
+    return True, next_start_time_ms, next_end_time_ms
+
+
+class BatchedTrainingLogsFetcher:
+    """
+    Iterator for fetching training job logs in batches using time-based pagination.
+
+    This iterator handles the complexity of paginating through training job logs,
+    including error handling, batch size adjustment, and time window management.
+    """
+
+    def __init__(
+        self,
+        api: BasetenApi,
+        project_id: str,
+        job_id: str,
+        batch_size: int = MAX_BATCH_SIZE,
+    ):
+        self.api = api
+        self.project_id = project_id
+        self.job_id = job_id
+        self.batch_size = batch_size
+        self.iteration = 0
+        self.current_start_time = None
+        self.current_end_time = None
+        self._initialize_time_window()
+
+    def _initialize_time_window(self):
+        training_job = self.api.get_training_job(self.project_id, self.job_id)
+        self.current_start_time = iso_to_millis(
+            training_job["training_job"]["created_at"]
+        )
+        self.current_end_time = self.current_start_time + 2 * MILLISECONDS_PER_HOUR
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> List[Any]:
+        if self.iteration >= MAX_ITERATIONS:
+            logging.warning(
+                f"Reached maximum iteration limit ({MAX_ITERATIONS}) while paginating "
+                f"training job logs for project_id={self.project_id}, job_id={self.job_id}."
+            )
+            raise StopIteration
+
+        query_params = _build_log_query_params(
+            self.current_start_time, self.current_end_time, self.batch_size
+        )
+
+        try:
+            batch_logs = self.api._fetch_log_batch(
+                self.project_id, self.job_id, query_params
+            )
+
+            should_continue, next_start_time, next_end_time = _process_batch_logs(
+                batch_logs, self.job_id, self.iteration, self.batch_size
+            )
+
+            if not should_continue:
+                logging.info(
+                    f"Completed pagination for job {self.job_id}. Total iterations: {self.iteration + 1}"
+                )
+                raise StopIteration
+
+            self.current_start_time = next_start_time  # type: ignore[assignment]
+            self.current_end_time = next_end_time  # type: ignore[assignment]
+            self.iteration += 1
+
+            return batch_logs
+
+        except requests.HTTPError as e:
+            if 500 <= e.response.status_code < 600:
+                if self.batch_size == MIN_BATCH_SIZE:
+                    logging.error(
+                        "Failed to fetch all training job logs due to persistent server errors. "
+                        "Please try again later or contact support if the issue persists."
+                    )
+                    raise StopIteration
+                self.batch_size = _handle_server_error_backoff(
+                    e, self.job_id, self.iteration, self.batch_size
+                )
+                # Retry the same iteration with reduced batch size
+                return self.__next__()
+            else:
+                logging.error(
+                    f"HTTP error fetching logs for job {self.job_id} at iteration {self.iteration}: {e}"
+                )
+                raise StopIteration
+        except Exception as e:
+            logging.error(
+                f"Error fetching logs for job {self.job_id} at iteration {self.iteration}: {e}"
+            )
+            raise StopIteration
+
+
+def get_training_job_logs_with_pagination(
+    api: BasetenApi, project_id: str, job_id: str, batch_size: int = MAX_BATCH_SIZE
+) -> List[Any]:
+    """
+    This method implements forward time-based pagination by starting from the earliest
+    available log and working forward in time. It uses the timestamp of the newest log in
+    each batch as the start time for the next request.
+
+    Returns:
+        List of all logs in chronological order (oldest first)
+    """
+    all_logs = []
+
+    logs_iterator = BatchedTrainingLogsFetcher(api, project_id, job_id, batch_size)
+
+    for batch_logs in logs_iterator:
+        all_logs.extend(batch_logs)
+
+    logging.info(f"Completed pagination for job {job_id}. Total logs: {len(all_logs)}")
+
+    return all_logs
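
A minimal sketch of calling the new pagination entry point directly, mirroring get_job_logs in train_commands.py (remote name and IDs are hypothetical):

    from truss.remote.baseten.core import get_training_job_logs_with_pagination
    from truss.remote.remote_factory import RemoteFactory

    remote_provider = RemoteFactory.create(remote="baseten")
    logs = get_training_job_logs_with_pagination(
        remote_provider.api, "proj_abc123", "job_def456"
    )
    print(f"fetched {len(logs)} log lines")

On repeated 5xx responses the fetcher halves the limit each retry (1000 → 500 → 250 → 125 → 100, since MIN_BATCH_SIZE floors it at 100) and gives up only once MIN_BATCH_SIZE also fails.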
truss/remote/baseten/utils/time.py ADDED
@@ -0,0 +1,15 @@
+from dateutil import parser
+
+
+def iso_to_millis(ts: str) -> int:
+    """
+    Convert ISO 8601 timestamp string to milliseconds since epoch.
+
+    Args:
+        ts: ISO 8601 timestamp string (handles Zulu/UTC (Z) automatically)
+
+    Returns:
+        Milliseconds since epoch as integer
+    """
+    dt = parser.isoparse(ts)  # handles Zulu/UTC (Z) automatically
+    return int(dt.timestamp() * 1000)
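
A quick sanity check for the new helper (UTC example; dateutil handles the trailing Z):

    from truss.remote.baseten.utils.time import iso_to_millis

    assert iso_to_millis("2024-01-01T00:00:00Z") == 1_704_067_200_000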