truss 0.11.10rc4__py3-none-any.whl → 0.11.10rc500__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/base/trt_llm_config.py +1 -0
- truss/cli/train/core.py +19 -6
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +40 -12
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -27
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -2
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -3
- truss/cli/train/types.py +19 -0
- truss/cli/train_commands.py +45 -3
- truss/remote/baseten/api.py +1 -1
- truss/remote/baseten/service.py +0 -7
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/server/requirements.txt +1 -2
- truss/templates/server/truss_server.py +1 -1
- truss/templates/server.Dockerfile.jinja +1 -1
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -3
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/util/__init__.py +0 -0
- {truss-0.11.10rc4.dist-info → truss-0.11.10rc500.dist-info}/METADATA +2 -2
- {truss-0.11.10rc4.dist-info → truss-0.11.10rc500.dist-info}/RECORD +24 -23
- truss_train/definitions.py +3 -2
- {truss-0.11.10rc4.dist-info → truss-0.11.10rc500.dist-info}/WHEEL +0 -0
- {truss-0.11.10rc4.dist-info → truss-0.11.10rc500.dist-info}/entry_points.txt +0 -0
- {truss-0.11.10rc4.dist-info → truss-0.11.10rc500.dist-info}/licenses/LICENSE +0 -0
truss/base/trt_llm_config.py
CHANGED
truss/cli/train/core.py
CHANGED
|
@@ -16,7 +16,11 @@ from rich.text import Text
|
|
|
16
16
|
|
|
17
17
|
from truss.cli.train import common, deploy_checkpoints
|
|
18
18
|
from truss.cli.train.metrics_watcher import MetricsWatcher
|
|
19
|
-
from truss.cli.train.types import
|
|
19
|
+
from truss.cli.train.types import (
|
|
20
|
+
DeployCheckpointArgs,
|
|
21
|
+
DeployCheckpointsConfigComplete,
|
|
22
|
+
DeploySuccessResult,
|
|
23
|
+
)
|
|
20
24
|
from truss.cli.utils import common as cli_common
|
|
21
25
|
from truss.cli.utils.output import console
|
|
22
26
|
from truss.remote.baseten.custom_types import (
|
|
@@ -244,17 +248,25 @@ def view_training_job_metrics(
|
|
|
244
248
|
|
|
245
249
|
def create_model_version_from_inference_template(
|
|
246
250
|
remote_provider: BasetenRemote, args: DeployCheckpointArgs
|
|
247
|
-
) ->
|
|
251
|
+
) -> DeploySuccessResult:
|
|
248
252
|
if not args.deploy_config_path:
|
|
249
253
|
return deploy_checkpoints.create_model_version_from_inference_template(
|
|
250
|
-
remote_provider,
|
|
254
|
+
remote_provider,
|
|
255
|
+
DeployCheckpointsConfig(),
|
|
256
|
+
args.project_id,
|
|
257
|
+
args.job_id,
|
|
258
|
+
args.dry_run,
|
|
251
259
|
)
|
|
252
260
|
#### User provided a checkpoint deploy config file
|
|
253
261
|
with loader.import_deploy_checkpoints_config(
|
|
254
262
|
Path(args.deploy_config_path)
|
|
255
263
|
) as checkpoint_deploy:
|
|
256
264
|
return deploy_checkpoints.create_model_version_from_inference_template(
|
|
257
|
-
remote_provider,
|
|
265
|
+
remote_provider,
|
|
266
|
+
checkpoint_deploy,
|
|
267
|
+
args.project_id,
|
|
268
|
+
args.job_id,
|
|
269
|
+
args.dry_run,
|
|
258
270
|
)
|
|
259
271
|
|
|
260
272
|
|
|
@@ -262,7 +274,7 @@ def _get_checkpoint_names(
|
|
|
262
274
|
checkpoint_deploy_config: DeployCheckpointsConfigComplete,
|
|
263
275
|
) -> list[str]:
|
|
264
276
|
return [
|
|
265
|
-
checkpoint.
|
|
277
|
+
checkpoint.checkpoint_name
|
|
266
278
|
for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints
|
|
267
279
|
]
|
|
268
280
|
|
|
@@ -299,9 +311,10 @@ def display_training_job(
|
|
|
299
311
|
table.add_column("Value")
|
|
300
312
|
|
|
301
313
|
# Basic job details
|
|
314
|
+
table.add_row("Job Name", job["name"])
|
|
315
|
+
table.add_row("Job ID", job["id"])
|
|
302
316
|
table.add_row("Project ID", job["training_project"]["id"])
|
|
303
317
|
table.add_row("Project Name", job["training_project"]["name"])
|
|
304
|
-
table.add_row("Job ID", job["id"])
|
|
305
318
|
table.add_row("Status", job["current_status"])
|
|
306
319
|
table.add_row("Instance Type", job["instance_type"]["name"])
|
|
307
320
|
table.add_row("Created", cli_common.format_localized_time(job["created_at"]))
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import re
|
|
2
3
|
from collections import OrderedDict
|
|
3
4
|
from typing import List, Optional, Union
|
|
@@ -7,7 +8,11 @@ from InquirerPy import inquirer
|
|
|
7
8
|
|
|
8
9
|
from truss.base import truss_config
|
|
9
10
|
from truss.cli.train import common
|
|
10
|
-
from truss.cli.train.types import
|
|
11
|
+
from truss.cli.train.types import (
|
|
12
|
+
DeployCheckpointsConfigComplete,
|
|
13
|
+
DeploySuccessModelVersion,
|
|
14
|
+
DeploySuccessResult,
|
|
15
|
+
)
|
|
11
16
|
from truss.cli.utils.output import console
|
|
12
17
|
from truss.remote.baseten.remote import BasetenRemote
|
|
13
18
|
from truss_train.definitions import (
|
|
@@ -35,13 +40,16 @@ def create_model_version_from_inference_template(
|
|
|
35
40
|
checkpoint_deploy_config: DeployCheckpointsConfig,
|
|
36
41
|
project_id: Optional[str],
|
|
37
42
|
job_id: Optional[str],
|
|
38
|
-
|
|
43
|
+
dry_run: bool,
|
|
44
|
+
) -> DeploySuccessResult:
|
|
39
45
|
checkpoint_deploy_config = _hydrate_deploy_config(
|
|
40
46
|
checkpoint_deploy_config, remote_provider, project_id, job_id
|
|
41
47
|
)
|
|
42
48
|
|
|
43
49
|
request_data = _build_inference_template_request(
|
|
44
|
-
checkpoint_deploy_config,
|
|
50
|
+
checkpoint_deploy_config=checkpoint_deploy_config,
|
|
51
|
+
remote_provider=remote_provider,
|
|
52
|
+
dry_run=dry_run,
|
|
45
53
|
)
|
|
46
54
|
|
|
47
55
|
# Call the GraphQL mutation to create model version from inference template
|
|
@@ -49,7 +57,9 @@ def create_model_version_from_inference_template(
|
|
|
49
57
|
result = remote_provider.api.create_model_version_from_inference_template(
|
|
50
58
|
request_data
|
|
51
59
|
)
|
|
60
|
+
truss_config_result = _get_truss_config_from_result(result)
|
|
52
61
|
|
|
62
|
+
model_version = None
|
|
53
63
|
if result and result.get("model_version"):
|
|
54
64
|
console.print(
|
|
55
65
|
f"Successfully created model version: {result['model_version']['name']}",
|
|
@@ -58,7 +68,10 @@ def create_model_version_from_inference_template(
|
|
|
58
68
|
console.print(
|
|
59
69
|
f"Model version ID: {result['model_version']['id']}", style="yellow"
|
|
60
70
|
)
|
|
61
|
-
|
|
71
|
+
model_version = DeploySuccessModelVersion.model_validate(
|
|
72
|
+
result["model_version"]
|
|
73
|
+
)
|
|
74
|
+
elif not dry_run:
|
|
62
75
|
console.print(
|
|
63
76
|
"Warning: Unexpected response format from server", style="yellow"
|
|
64
77
|
)
|
|
@@ -68,12 +81,31 @@ def create_model_version_from_inference_template(
|
|
|
68
81
|
console.print(f"Error creating model version: {e}", style="red")
|
|
69
82
|
raise
|
|
70
83
|
|
|
71
|
-
return
|
|
84
|
+
return DeploySuccessResult(
|
|
85
|
+
deploy_config=checkpoint_deploy_config,
|
|
86
|
+
truss_config=truss_config_result,
|
|
87
|
+
model_version=model_version,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _get_truss_config_from_result(result: dict) -> Optional[truss_config.TrussConfig]:
|
|
92
|
+
if result and result.get("truss_config"):
|
|
93
|
+
truss_config_dict = json.loads(result["truss_config"])
|
|
94
|
+
return truss_config.TrussConfig.from_dict(truss_config_dict)
|
|
95
|
+
# Although this should never happen, we defensively allow ourselves to return None
|
|
96
|
+
# because a failure to handle the truss config doesn't necessarily mean we failed to deploy
|
|
97
|
+
# the model version.
|
|
98
|
+
console.print(
|
|
99
|
+
"No truss config returned. Reach out to Baseten for support if this persists.",
|
|
100
|
+
style="red",
|
|
101
|
+
)
|
|
102
|
+
return None
|
|
72
103
|
|
|
73
104
|
|
|
74
105
|
def _build_inference_template_request(
|
|
75
106
|
checkpoint_deploy_config: DeployCheckpointsConfigComplete,
|
|
76
107
|
remote_provider: BasetenRemote,
|
|
108
|
+
dry_run: bool,
|
|
77
109
|
) -> dict:
|
|
78
110
|
"""
|
|
79
111
|
Build the GraphQL request data structure for createModelVersionFromInferenceTemplate mutation.
|
|
@@ -83,18 +115,12 @@ def _build_inference_template_request(
|
|
|
83
115
|
weights_sources = []
|
|
84
116
|
for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints:
|
|
85
117
|
# Extract checkpoint name from the first path
|
|
86
|
-
checkpoint_name = (
|
|
87
|
-
checkpoint.paths[0].strip("/").split("/")[-1]
|
|
88
|
-
if checkpoint.paths
|
|
89
|
-
else "checkpoint"
|
|
90
|
-
)
|
|
91
|
-
|
|
92
118
|
weights_source = {
|
|
93
119
|
"weight_source_type": "B10_CHECKPOINTING",
|
|
94
120
|
"b10_training_checkpoint_weights_source": {
|
|
95
121
|
"checkpoint": {
|
|
96
122
|
"training_job_id": checkpoint.training_job_id,
|
|
97
|
-
"checkpoint_name": checkpoint_name,
|
|
123
|
+
"checkpoint_name": checkpoint.checkpoint_name,
|
|
98
124
|
}
|
|
99
125
|
},
|
|
100
126
|
}
|
|
@@ -126,6 +152,7 @@ def _build_inference_template_request(
|
|
|
126
152
|
"weights_sources": weights_sources,
|
|
127
153
|
"inference_stack": inference_stack,
|
|
128
154
|
"instance_type_id": instance_type_id,
|
|
155
|
+
"dry_run": dry_run,
|
|
129
156
|
}
|
|
130
157
|
|
|
131
158
|
return request_data
|
|
@@ -276,6 +303,7 @@ def _ensure_checkpoint_details(
|
|
|
276
303
|
job_id: Optional[str],
|
|
277
304
|
) -> CheckpointList:
|
|
278
305
|
if checkpoint_details and checkpoint_details.checkpoints:
|
|
306
|
+
# TODO: check here
|
|
279
307
|
return _process_user_provided_checkpoints(checkpoint_details, remote_provider)
|
|
280
308
|
else:
|
|
281
309
|
return _prompt_user_for_checkpoint_details(
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
|
|
3
1
|
from truss_train.definitions import FullCheckpoint
|
|
4
2
|
|
|
5
3
|
|
|
@@ -8,28 +6,4 @@ def hydrate_full_checkpoint(
|
|
|
8
6
|
) -> FullCheckpoint:
|
|
9
7
|
"""Create a Checkpoint object for full model weights."""
|
|
10
8
|
# NOTE: Slash at the end is important since it means the checkpoint is a directory
|
|
11
|
-
|
|
12
|
-
return FullCheckpoint(training_job_id=job_id, paths=paths)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def build_full_checkpoint_string(truss_deploy_config) -> str:
|
|
16
|
-
"""Build checkpoint string from artifact references for full checkpoints.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
truss_deploy_config: The truss deploy configuration containing training checkpoints.
|
|
20
|
-
|
|
21
|
-
Returns:
|
|
22
|
-
A space-separated string of checkpoint paths.
|
|
23
|
-
"""
|
|
24
|
-
checkpoint_parts = []
|
|
25
|
-
for (
|
|
26
|
-
truss_checkpoint
|
|
27
|
-
) in truss_deploy_config.training_checkpoints.artifact_references: # type: ignore
|
|
28
|
-
ckpt_path = Path(
|
|
29
|
-
truss_deploy_config.training_checkpoints.download_folder, # type: ignore
|
|
30
|
-
truss_checkpoint.training_job_id,
|
|
31
|
-
truss_checkpoint.paths[0],
|
|
32
|
-
)
|
|
33
|
-
checkpoint_parts.append(str(ckpt_path))
|
|
34
|
-
|
|
35
|
-
return " ".join(checkpoint_parts)
|
|
9
|
+
return FullCheckpoint(training_job_id=job_id, checkpoint_name=checkpoint_id)
|
|
@@ -11,11 +11,10 @@ def hydrate_lora_checkpoint(
|
|
|
11
11
|
) -> LoRACheckpoint:
|
|
12
12
|
"""Create a LoRA-specific Checkpoint object."""
|
|
13
13
|
# NOTE: Slash at the end is important since it means the checkpoint is a directory
|
|
14
|
-
paths = [f"rank-0/{checkpoint_id}/"]
|
|
15
14
|
return LoRACheckpoint(
|
|
16
15
|
training_job_id=job_id,
|
|
17
|
-
paths=paths,
|
|
18
16
|
lora_details=LoRADetails(rank=_get_lora_rank(checkpoint)),
|
|
17
|
+
checkpoint_name=checkpoint_id,
|
|
19
18
|
)
|
|
20
19
|
|
|
21
20
|
|
|
@@ -5,6 +5,4 @@ def hydrate_whisper_checkpoint(
|
|
|
5
5
|
job_id: str, checkpoint_id: str, checkpoint: dict
|
|
6
6
|
) -> WhisperCheckpoint:
|
|
7
7
|
"""Create a Checkpoint object for whisper model weights."""
|
|
8
|
-
|
|
9
|
-
paths = [f"rank-0/{checkpoint_id}/"]
|
|
10
|
-
return WhisperCheckpoint(training_job_id=job_id, paths=paths)
|
|
8
|
+
return WhisperCheckpoint(training_job_id=job_id, checkpoint_name=checkpoint_id)
|
truss/cli/train/types.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from truss.base import truss_config
|
|
4
7
|
from truss_train.definitions import (
|
|
5
8
|
CheckpointList,
|
|
6
9
|
Compute,
|
|
@@ -11,6 +14,7 @@ from truss_train.definitions import (
|
|
|
11
14
|
|
|
12
15
|
@dataclass
|
|
13
16
|
class DeployCheckpointArgs:
|
|
17
|
+
dry_run: bool
|
|
14
18
|
project_id: Optional[str]
|
|
15
19
|
job_id: Optional[str]
|
|
16
20
|
deploy_config_path: Optional[str]
|
|
@@ -26,3 +30,18 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
|
|
|
26
30
|
model_name: str
|
|
27
31
|
runtime: DeployCheckpointsRuntime
|
|
28
32
|
compute: Compute
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class DeploySuccessModelVersion(BaseModel):
|
|
36
|
+
# allow extra fields to be forwards compatible with server
|
|
37
|
+
class Config:
|
|
38
|
+
extra = "allow"
|
|
39
|
+
|
|
40
|
+
name: str
|
|
41
|
+
id: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DeploySuccessResult(BaseModel):
|
|
45
|
+
deploy_config: DeployCheckpointsConfigComplete
|
|
46
|
+
truss_config: Optional[truss_config.TrussConfig]
|
|
47
|
+
model_version: Optional[DeploySuccessModelVersion]
|
truss/cli/train_commands.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import sys
|
|
3
|
+
from datetime import datetime
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Optional, cast
|
|
5
6
|
|
|
@@ -22,6 +23,7 @@ from truss.cli.train.core import (
|
|
|
22
23
|
SORT_ORDER_ASC,
|
|
23
24
|
SORT_ORDER_DESC,
|
|
24
25
|
)
|
|
26
|
+
from truss.cli.train.types import DeploySuccessResult
|
|
25
27
|
from truss.cli.utils import common
|
|
26
28
|
from truss.cli.utils.output import console, error_console
|
|
27
29
|
from truss.remote.baseten.core import get_training_job_logs_with_pagination
|
|
@@ -306,6 +308,12 @@ def get_job_metrics(
|
|
|
306
308
|
@click.option(
|
|
307
309
|
"--dry-run", is_flag=True, help="Generate a truss config without deploying"
|
|
308
310
|
)
|
|
311
|
+
@click.option(
|
|
312
|
+
"--truss-config-output-dir",
|
|
313
|
+
type=str,
|
|
314
|
+
required=False,
|
|
315
|
+
help="Path to output the truss config to. If not provided, will output to truss_configs/<model_version_name>_<model_version_id> or truss_configs/dry_run_<timestamp> if dry run.",
|
|
316
|
+
)
|
|
309
317
|
@click.option("--remote", type=str, required=False, help="Remote to use")
|
|
310
318
|
@common.common_options()
|
|
311
319
|
def deploy_checkpoints(
|
|
@@ -315,6 +323,7 @@ def deploy_checkpoints(
|
|
|
315
323
|
config: Optional[str],
|
|
316
324
|
remote: Optional[str],
|
|
317
325
|
dry_run: bool,
|
|
326
|
+
truss_config_output_dir: Optional[str],
|
|
318
327
|
):
|
|
319
328
|
"""
|
|
320
329
|
Deploy a LoRA checkpoint via vLLM.
|
|
@@ -332,13 +341,46 @@ def deploy_checkpoints(
|
|
|
332
341
|
result = train_cli.create_model_version_from_inference_template(
|
|
333
342
|
remote_provider,
|
|
334
343
|
train_cli.DeployCheckpointArgs(
|
|
335
|
-
project_id=project_id,
|
|
344
|
+
project_id=project_id,
|
|
345
|
+
job_id=job_id,
|
|
346
|
+
deploy_config_path=config,
|
|
347
|
+
dry_run=dry_run,
|
|
336
348
|
),
|
|
337
349
|
)
|
|
338
350
|
|
|
339
351
|
if dry_run:
|
|
340
|
-
console.print("--dry-run flag provided
|
|
341
|
-
|
|
352
|
+
console.print("did not deploy because --dry-run flag provided", style="yellow")
|
|
353
|
+
|
|
354
|
+
_write_truss_config(result, truss_config_output_dir, dry_run)
|
|
355
|
+
|
|
356
|
+
if not dry_run:
|
|
357
|
+
train_cli.print_deploy_checkpoints_success_message(result.deploy_config)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _write_truss_config(
|
|
361
|
+
result: DeploySuccessResult, truss_config_output_dir: Optional[str], dry_run: bool
|
|
362
|
+
) -> None:
|
|
363
|
+
if not result.truss_config:
|
|
364
|
+
return
|
|
365
|
+
# format: 20251006_123456
|
|
366
|
+
datestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
367
|
+
folder_name = (
|
|
368
|
+
f"{result.model_version.name}_{result.model_version.id}"
|
|
369
|
+
if result.model_version
|
|
370
|
+
else f"dry_run_{datestamp}"
|
|
371
|
+
)
|
|
372
|
+
output_dir_str = truss_config_output_dir or f"truss_configs/{folder_name}"
|
|
373
|
+
output_dir = Path(output_dir_str)
|
|
374
|
+
output_path = output_dir / "config.yaml"
|
|
375
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
376
|
+
console.print(f"Writing truss config to {output_path}", style="yellow")
|
|
377
|
+
console.print(f"👀 Run `cat {output_path}` to view the truss config", style="green")
|
|
378
|
+
if dry_run:
|
|
379
|
+
console.print(
|
|
380
|
+
f"🚀 Run `cd {output_dir} && truss push --publish` to deploy the truss",
|
|
381
|
+
style="green",
|
|
382
|
+
)
|
|
383
|
+
result.truss_config.write_to_yaml_file(output_path)
|
|
342
384
|
|
|
343
385
|
|
|
344
386
|
@train.command(name="download")
|
truss/remote/baseten/api.py
CHANGED
truss/remote/baseten/service.py
CHANGED
|
@@ -137,13 +137,6 @@ class BasetenService(TrussService):
|
|
|
137
137
|
|
|
138
138
|
return decode_content()
|
|
139
139
|
|
|
140
|
-
parsed_response = response.json()
|
|
141
|
-
|
|
142
|
-
if "error" in parsed_response:
|
|
143
|
-
# In the case that the model is in a non-ready state, the response
|
|
144
|
-
# will be a json with an `error` key.
|
|
145
|
-
return parsed_response
|
|
146
|
-
|
|
147
140
|
return response.json()
|
|
148
141
|
|
|
149
142
|
def authenticate(self) -> dict:
|
|
@@ -72,9 +72,9 @@ class ControlServer:
|
|
|
72
72
|
# httptools installed, which does not work with our requests & version
|
|
73
73
|
# of uvicorn.
|
|
74
74
|
http="h11",
|
|
75
|
+
loop="uvloop",
|
|
75
76
|
**extra_kwargs,
|
|
76
77
|
)
|
|
77
|
-
cfg.setup_event_loop()
|
|
78
78
|
|
|
79
79
|
server = uvicorn.Server(cfg)
|
|
80
80
|
asyncio.run(server.serve())
|
|
@@ -7,7 +7,6 @@ python-json-logger>=2.0.2
|
|
|
7
7
|
tenacity>=8.1.0
|
|
8
8
|
# To avoid divergence, this should follow the latest release.
|
|
9
9
|
truss==0.11.1
|
|
10
|
-
|
|
11
|
-
uvicorn>=0.24.0,<0.36.0
|
|
10
|
+
uvicorn>=0.24.0
|
|
12
11
|
uvloop>=0.19.0
|
|
13
12
|
websockets>=10.0
|
|
@@ -19,7 +19,6 @@ python-json-logger>=2.0.2
|
|
|
19
19
|
pyyaml>=6.0.0
|
|
20
20
|
requests>=2.31.0
|
|
21
21
|
truss-transfer==0.0.32
|
|
22
|
-
|
|
23
|
-
uvicorn>=0.24.0,<0.36.0
|
|
22
|
+
uvicorn>=0.24.0
|
|
24
23
|
uvloop>=0.19.0
|
|
25
24
|
websockets>=10.0
|
|
@@ -497,9 +497,9 @@ class TrussServer:
|
|
|
497
497
|
timeout_graceful_shutdown=TIMEOUT_GRACEFUL_SHUTDOWN,
|
|
498
498
|
log_config=log_config.make_log_config(log_level),
|
|
499
499
|
ws_max_size=WS_MAX_MSG_SZ_BYTES,
|
|
500
|
+
loop="uvloop",
|
|
500
501
|
**extra_kwargs,
|
|
501
502
|
)
|
|
502
|
-
cfg.setup_event_loop() # Call this so uvloop gets used
|
|
503
503
|
server = uvicorn.Server(config=cfg)
|
|
504
504
|
self._server = server
|
|
505
505
|
asyncio.run(server.serve())
|
|
@@ -69,7 +69,7 @@ COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
|
|
|
69
69
|
|
|
70
70
|
{%- if model_cache_v2 %}
|
|
71
71
|
{# v0.0.9, keep synced with server_requirements.txt #}
|
|
72
|
-
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.
|
|
72
|
+
RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.11.10rc5/truss-transfer-cli-v0.11.10rc5-linux-x86_64-unknown-linux-musl
|
|
73
73
|
RUN chmod +x /usr/local/bin/truss-transfer-cli
|
|
74
74
|
RUN mkdir /static-bptr
|
|
75
75
|
RUN echo "hash {{model_cache_hash}}"
|
|
@@ -115,8 +115,7 @@ def test_hydrate_full_checkpoint():
|
|
|
115
115
|
assert isinstance(result, definitions.FullCheckpoint)
|
|
116
116
|
assert result.training_job_id == job_id
|
|
117
117
|
assert result.model_weight_format == ModelWeightsFormat.FULL
|
|
118
|
-
assert
|
|
119
|
-
assert result.paths[0] == f"rank-0/{checkpoint_id}/"
|
|
118
|
+
assert result.checkpoint_name == checkpoint_id
|
|
120
119
|
|
|
121
120
|
|
|
122
121
|
def test_hydrate_checkpoint_dispatcher_full():
|
|
@@ -272,6 +271,6 @@ def test_hydrate_whisper_checkpoint():
|
|
|
272
271
|
result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
|
|
273
272
|
|
|
274
273
|
assert result.training_job_id == job_id
|
|
275
|
-
assert result.
|
|
274
|
+
assert result.checkpoint_name == checkpoint_id
|
|
276
275
|
assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
|
|
277
276
|
assert isinstance(result, definitions.WhisperCheckpoint)
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
from unittest.mock import MagicMock
|
|
2
|
+
|
|
1
3
|
from truss.remote.baseten import service
|
|
4
|
+
from truss.remote.baseten.core import ModelVersionHandle
|
|
2
5
|
|
|
3
6
|
|
|
4
7
|
def test_model_invoke_url_prod():
|
|
@@ -65,3 +68,56 @@ def test_chain_logs_url():
|
|
|
65
68
|
"https://app.baseten.co", "abc", "666", "543"
|
|
66
69
|
)
|
|
67
70
|
assert url == "https://app.baseten.co/chains/abc/logs/666/543"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_predict_response_to_json():
|
|
74
|
+
"""Test that predict method returns JSON response for normal dict result."""
|
|
75
|
+
# Create a mock BasetenService
|
|
76
|
+
mock_handle = MagicMock(spec=ModelVersionHandle)
|
|
77
|
+
mock_handle.model_id = "test-model"
|
|
78
|
+
mock_handle.version_id = "test-version"
|
|
79
|
+
mock_handle.hostname = "https://model-test.api.baseten.co"
|
|
80
|
+
|
|
81
|
+
mock_api = MagicMock()
|
|
82
|
+
mock_api.app_url = "https://app.baseten.co"
|
|
83
|
+
|
|
84
|
+
service_instance = service.BasetenService(
|
|
85
|
+
model_version_handle=mock_handle,
|
|
86
|
+
is_draft=False,
|
|
87
|
+
api_key="test-key",
|
|
88
|
+
service_url="https://test.com",
|
|
89
|
+
api=mock_api,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
# Mock the _send_request method to return a successful JSON response
|
|
93
|
+
mock_response = MagicMock()
|
|
94
|
+
mock_response.json.return_value = {"result": "success"}
|
|
95
|
+
service_instance._send_request = MagicMock(return_value=mock_response)
|
|
96
|
+
|
|
97
|
+
# Test predict method
|
|
98
|
+
result = service_instance.predict({"input": "test"})
|
|
99
|
+
|
|
100
|
+
# Verify that the JSON response is returned directly
|
|
101
|
+
assert result == {"result": "success"}
|
|
102
|
+
|
|
103
|
+
# Test non-dict response types below
|
|
104
|
+
|
|
105
|
+
# With integer response
|
|
106
|
+
mock_response.json.return_value = 42
|
|
107
|
+
result = service_instance.predict({"input": "test"})
|
|
108
|
+
assert result == 42
|
|
109
|
+
|
|
110
|
+
# With string response
|
|
111
|
+
mock_response.json.return_value = "success"
|
|
112
|
+
result = service_instance.predict({"input": "test"})
|
|
113
|
+
assert result == "success"
|
|
114
|
+
|
|
115
|
+
# With list response
|
|
116
|
+
mock_response.json.return_value = [1, 2, 3, 4]
|
|
117
|
+
result = service_instance.predict({"input": "test"})
|
|
118
|
+
assert result == [1, 2, 3, 4]
|
|
119
|
+
|
|
120
|
+
# With boolean response
|
|
121
|
+
mock_response.json.return_value = True
|
|
122
|
+
result = service_instance.predict({"input": "test"})
|
|
123
|
+
assert result is True
|
truss/util/__init__.py
ADDED
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: truss
|
|
3
|
-
Version: 0.11.
|
|
3
|
+
Version: 0.11.10rc500
|
|
4
4
|
Summary: A seamless bridge from model development to model delivery
|
|
5
5
|
Project-URL: Repository, https://github.com/basetenlabs/truss
|
|
6
6
|
Project-URL: Homepage, https://truss.baseten.co
|
|
@@ -37,7 +37,7 @@ Requires-Dist: rich<14,>=13.4.2
|
|
|
37
37
|
Requires-Dist: ruff>=0.4.8
|
|
38
38
|
Requires-Dist: tenacity>=8.0.1
|
|
39
39
|
Requires-Dist: tomlkit>=0.13.2
|
|
40
|
-
Requires-Dist: truss-transfer<0.0.
|
|
40
|
+
Requires-Dist: truss-transfer<0.0.36,>=0.0.32
|
|
41
41
|
Requires-Dist: watchfiles<0.20,>=0.19.0
|
|
42
42
|
Description-Content-Type: text/markdown
|
|
43
43
|
|
|
@@ -5,30 +5,30 @@ truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
5
5
|
truss/base/constants.py,sha256=sExArdnuGg83z83XMgaQ4b8SS3V_j_bJEpOATDGJzpE,3600
|
|
6
6
|
truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
|
|
7
7
|
truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
|
|
8
|
-
truss/base/trt_llm_config.py,sha256
|
|
8
|
+
truss/base/trt_llm_config.py,sha256=81ZZxRQF3o29HLCX6nlXtPwALejcdns6c4mbrExwASk,32958
|
|
9
9
|
truss/base/truss_config.py,sha256=7CtiJIwMHtDU8Wzn8UTJUVVunD0pWFl4QUVycK2aIpY,28055
|
|
10
10
|
truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
|
|
11
11
|
truss/cli/chains_commands.py,sha256=Kpa5mCg6URAJQE2ZmZfVQFhjBHEitKT28tKiW0H6XAI,17406
|
|
12
12
|
truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
|
|
13
13
|
truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
|
|
14
|
-
truss/cli/train_commands.py,sha256=
|
|
14
|
+
truss/cli/train_commands.py,sha256=CrVqWsdkmSxgi3i2sSEyiE4QdfD0Z96F2Ib-PMZJjm8,20444
|
|
15
15
|
truss/cli/logs/base_watcher.py,sha256=vuqteoaMVGX34cgKcETf4X_gOkvnSnDaWz1_pbeFhqs,3343
|
|
16
16
|
truss/cli/logs/model_log_watcher.py,sha256=38vQCcNItfDrTKucvdJ10ZYLOcbGa5ZAKUqUnV4nH34,1971
|
|
17
17
|
truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
|
|
18
18
|
truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
|
|
19
19
|
truss/cli/train/common.py,sha256=xTR41U5FeSndXfNBBHF9wF5XwZH1sOIVFlv-XHjsKIU,1547
|
|
20
|
-
truss/cli/train/core.py,sha256=
|
|
20
|
+
truss/cli/train/core.py,sha256=fWuHvjIT4tkax19B7_1_SWvkX1ot2xQ6WwcDGBhTnus,26520
|
|
21
21
|
truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
|
|
22
22
|
truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
|
|
23
23
|
truss/cli/train/metrics_watcher.py,sha256=smz-zrEsBj_-wJHI0pAZ-EAPrvfCWzq1eQjGiFNM-Mk,12755
|
|
24
24
|
truss/cli/train/poller.py,sha256=TGRzELxsicga0bEXewSX1ujw6lfPmDnHd6nr8zvOFO8,3550
|
|
25
|
-
truss/cli/train/types.py,sha256=
|
|
25
|
+
truss/cli/train/types.py,sha256=0tfgInTm1Dmq7p8Gh_0RwxGViN6Tpp0YjsiYeMUNiPE,1163
|
|
26
26
|
truss/cli/train/deploy_checkpoints/__init__.py,sha256=HuiDD6-qfwJ7fbRVX4s9Fxn6rmO6nwTLb0fVxwcMKls,137
|
|
27
|
-
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=
|
|
27
|
+
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=AEl41GHI6s35ajVaPUGnbRJAdAwUaQdVnLPbUbGAifE,22478
|
|
28
28
|
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=r_IKMlqejMwIU6gsfxDIRzfmltfDf6Sz9-vnKrP_10s,83
|
|
29
|
-
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=
|
|
30
|
-
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=
|
|
31
|
-
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py,sha256=
|
|
29
|
+
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=f8_UB7CF6Y3MOhaf8Zim0heNiauOOAmA-WqsyP3X9mk,386
|
|
30
|
+
truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=YKVMy3xogmbubJNrN_1LCR6xdHj9lBOAlKgMxWHdlQM,1115
|
|
31
|
+
truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py,sha256=eEo4ahRTsvRKOxCDj7DAHIUyUZWicn39VmC1PYS0pCY,314
|
|
32
32
|
truss/cli/utils/common.py,sha256=ink9ZE0MsOv6PCFK_Ra5k1aHm281TXTnMpnLjf2PtUM,6585
|
|
33
33
|
truss/cli/utils/output.py,sha256=GNjU85ZAMp5BI6Yij5wYXcaAvpm_kmHV0nHNmdkMxb0,646
|
|
34
34
|
truss/cli/utils/self_upgrade.py,sha256=eTJZA4Wc8uUp4Qh6viRQp6bZm--wnQp7KWe5KRRpPtg,5427
|
|
@@ -52,14 +52,14 @@ truss/patch/truss_dir_patch_applier.py,sha256=ALnaVnu96g0kF2UmGuBFTua3lrXpwAy4sG
|
|
|
52
52
|
truss/remote/remote_factory.py,sha256=-0gLh_yIyNDgD48Q6sR8Yo5dOMQg84lrHRvn_XR0n4s,3585
|
|
53
53
|
truss/remote/truss_remote.py,sha256=TEe6h6by5-JLy7PMFsDN2QxIY5FmdIYN3bKvHHl02xM,8440
|
|
54
54
|
truss/remote/baseten/__init__.py,sha256=XNqJW1zyp143XQc6-7XVwsUA_Q_ZJv_ausn1_Ohtw9Y,176
|
|
55
|
-
truss/remote/baseten/api.py,sha256=
|
|
55
|
+
truss/remote/baseten/api.py,sha256=5B5IXNy0v8hRHNH2ar3rldDa47kwt5s1PtKZQ9_pfmE,28263
|
|
56
56
|
truss/remote/baseten/auth.py,sha256=tI7s6cI2EZgzpMIzrdbILHyGwiHDnmoKf_JBhJXT55E,776
|
|
57
57
|
truss/remote/baseten/core.py,sha256=uxtmBI9RAVHu1glIEJb5Q4ccJYLeZM1Cp5Svb9W68Yw,21965
|
|
58
58
|
truss/remote/baseten/custom_types.py,sha256=bYrfTzGgYr6FDoya0omyadCLSTcTc-83U2scQORyUj0,4715
|
|
59
59
|
truss/remote/baseten/error.py,sha256=3TNTwwPqZnr4NRd9Sl6SfLUQR2fz9l6akDPpOntTpzA,578
|
|
60
60
|
truss/remote/baseten/remote.py,sha256=Se8AES5mk8jxa8S9fN2DSG7wnsaV7ftRjJ4Uwc_w_S0,22544
|
|
61
61
|
truss/remote/baseten/rest_client.py,sha256=_t3CWsWARt2u0C0fDsF4rtvkkHe-lH7KXoPxWXAkKd4,1185
|
|
62
|
-
truss/remote/baseten/service.py,sha256=
|
|
62
|
+
truss/remote/baseten/service.py,sha256=HMaKiYbr2Mzv4BfXF9QkJ8H3Wwrq3LOMpFt9js4t0rs,5834
|
|
63
63
|
truss/remote/baseten/utils/status.py,sha256=jputc9N9AHXxUuW4KOk6mcZKzQ_gOBOe5BSx9K0DxPY,1266
|
|
64
64
|
truss/remote/baseten/utils/tar.py,sha256=pMUv--YkwXDngUx1WUOK-KmAIKMcOg2E-CD5x4heh3s,2514
|
|
65
65
|
truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzlNeGxI,397
|
|
@@ -71,11 +71,11 @@ truss/templates/cache.Dockerfile.jinja,sha256=1qZqDo1phrcqi-Vwol-VafYJkADsBbQWU6
|
|
|
71
71
|
truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj7wGs61nUs,87
|
|
72
72
|
truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
|
|
73
73
|
truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
|
|
74
|
-
truss/templates/server.Dockerfile.jinja,sha256=
|
|
75
|
-
truss/templates/control/requirements.txt,sha256=
|
|
74
|
+
truss/templates/server.Dockerfile.jinja,sha256=KzFZwJlZ3zExYheHtfXFIiL72pzbCQzJC5Z7uZyrX2A,7071
|
|
75
|
+
truss/templates/control/requirements.txt,sha256=tJGr83WoE0CZm2FrloZ9VScK84q-_FTuVXjDYrexhW0,250
|
|
76
76
|
truss/templates/control/control/application.py,sha256=5Kam6M-XtfKGaXQz8cc3d0bwDkB80o2MskABWROx1gk,5321
|
|
77
77
|
truss/templates/control/control/endpoints.py,sha256=KzqsLVNJE6r6TCPW8D5FMCtsfHadTwR15A3z_viGxmM,11782
|
|
78
|
-
truss/templates/control/control/server.py,sha256=
|
|
78
|
+
truss/templates/control/control/server.py,sha256=bdyXoqhW9e4jvFhxt3eoBKUPzH2Y5UAnMhY5GR0Kd5M,3124
|
|
79
79
|
truss/templates/control/control/helpers/context_managers.py,sha256=W6dyFgLBhPa5meqrOb3w_phMtKfaJI-GhwUfpiycDc8,413
|
|
80
80
|
truss/templates/control/control/helpers/custom_types.py,sha256=n_lTudtLTpy4oPV3aDdJ4X2rh3KCV5btYO9UnTeUouQ,5471
|
|
81
81
|
truss/templates/control/control/helpers/errors.py,sha256=LddFuQywuCCdYTEnFT5EalxdWos4uR89rbhMakCy2bA,970
|
|
@@ -96,8 +96,8 @@ truss/templates/docker_server/supervisord.conf.jinja,sha256=dd37fwZE--cutrvOUCqE
|
|
|
96
96
|
truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
97
97
|
truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
|
|
98
98
|
truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
|
|
99
|
-
truss/templates/server/requirements.txt,sha256=
|
|
100
|
-
truss/templates/server/truss_server.py,sha256=
|
|
99
|
+
truss/templates/server/requirements.txt,sha256=2jknFfEZyN0kKKhQo5hNvZFfpgImTiFDLOTw76G3Fjk,672
|
|
100
|
+
truss/templates/server/truss_server.py,sha256=YKcG7Sr0T_8XjIC3GK9vBwoNb8oxVgwic3-3Ikzpmgw,19781
|
|
101
101
|
truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
|
|
102
102
|
truss/templates/server/common/errors.py,sha256=My0P6-Y7imVTICIhazHT0vlSu3XJDH7As06OyVzu4Do,8589
|
|
103
103
|
truss/templates/server/common/patches.py,sha256=uEOzvDnXsHOkTSa8zygGYuR4GHhrFNVHNQc5peJcwvo,1393
|
|
@@ -142,7 +142,7 @@ truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGl
|
|
|
142
142
|
truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
|
|
143
143
|
truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
|
|
144
144
|
truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
|
|
145
|
-
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=
|
|
145
|
+
truss/tests/cli/train/test_deploy_checkpoints.py,sha256=Ndkd9YxEgDLf3zLAZYH0myFK_wkKTz0oGZ57yWQt_l8,10100
|
|
146
146
|
truss/tests/cli/train/test_train_cli_core.py,sha256=vzYfxKdwoa3NaFMrVZbSg5qOoLXivMvZXN1ClQirGTQ,16148
|
|
147
147
|
truss/tests/cli/train/test_train_init.py,sha256=SRAZvvD5-PWYlpHHek2MftYTA4I3ZHi7gniHl2fYV98,17464
|
|
148
148
|
truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
|
|
@@ -163,7 +163,7 @@ truss/tests/remote/baseten/test_api.py,sha256=AKJeNsrUtTNa0QPClfEvXlBOSJ214PKp23
|
|
|
163
163
|
truss/tests/remote/baseten/test_auth.py,sha256=ttu4bDnmwGfo3oiNut4HVGnh-QnjAefwZJctiibQJKY,669
|
|
164
164
|
truss/tests/remote/baseten/test_core.py,sha256=6NzJTDmoSUv6Muy1LFEYIUg10-cqw-hbLyeTSWcdNjY,26117
|
|
165
165
|
truss/tests/remote/baseten/test_remote.py,sha256=y1qSPL1t7dBeYI3xMFn436fttG7wkYdAoENTz7qKObg,23634
|
|
166
|
-
truss/tests/remote/baseten/test_service.py,sha256=
|
|
166
|
+
truss/tests/remote/baseten/test_service.py,sha256=ehbGkzzSPdLN7JHxc0O9YDPfzzKqU8OBzJGjRdw08zE,3786
|
|
167
167
|
truss/tests/templates/control/control/conftest.py,sha256=euDFh0AhcHP-vAmTzi1Qj3lymnplDTgvtbt4Ez_lfpw,654
|
|
168
168
|
truss/tests/templates/control/control/test_endpoints.py,sha256=HIlRYOicsdHD8r_V5gHpZWybDC26uwXJfbvCohdE3HI,3751
|
|
169
169
|
truss/tests/templates/control/control/test_server.py,sha256=0D0OMwZ-9jZRxxHoiQYij0RBMenuA9o29LlwNzd04Vk,9149
|
|
@@ -335,6 +335,7 @@ truss/truss_handle/patch/local_truss_patch_applier.py,sha256=fOHWKt3teYnbqeRsF63
|
|
|
335
335
|
truss/truss_handle/patch/signature.py,sha256=8eas8gy6Japd1hrgdmtHmKTTxQmWsbmgKRQQGL2PVuA,858
|
|
336
336
|
truss/truss_handle/patch/truss_dir_patch_applier.py,sha256=uhhHvKYHn_dpfz0xp4jwO9_qAej5sO3f8of_h-21PP4,3666
|
|
337
337
|
truss/util/.truss_ignore,sha256=jpQA9ou-r_JEIcEHsUqGLHhir_m3d4IPGNyzKXtS-2g,3131
|
|
338
|
+
truss/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
338
339
|
truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
|
|
339
340
|
truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
|
|
340
341
|
truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
|
|
@@ -364,13 +365,13 @@ truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e6
|
|
|
364
365
|
truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
|
|
365
366
|
truss_chains/remote_chainlet/utils.py,sha256=Zn3GZRvK8f65WUa-qa-8uPFZ2pD7ukRFxbLOvT-BL0Q,24063
|
|
366
367
|
truss_train/__init__.py,sha256=A3MzRPMInZfmzLvPpZI7gdKgshAVCw6bwhU-6JYU2zs,939
|
|
367
|
-
truss_train/definitions.py,sha256=
|
|
368
|
+
truss_train/definitions.py,sha256=jcaVICE03iI8lBqEPe01uO3vFiMu_8pqB-j_dX-zwhI,8209
|
|
368
369
|
truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
|
|
369
370
|
truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
|
|
370
371
|
truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
|
|
371
372
|
truss_train/restore_from_checkpoint.py,sha256=8hdPm-WSgkt74HDPjvCjZMBpvA9MwtoYsxVjOoa7BaM,1176
|
|
372
|
-
truss-0.11.
|
|
373
|
-
truss-0.11.
|
|
374
|
-
truss-0.11.
|
|
375
|
-
truss-0.11.
|
|
376
|
-
truss-0.11.
|
|
373
|
+
truss-0.11.10rc500.dist-info/METADATA,sha256=r5CBFbt58s_-RKflr6vpSPiGBxKyP_SsEllzlh3d5oQ,6683
|
|
374
|
+
truss-0.11.10rc500.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
375
|
+
truss-0.11.10rc500.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
|
|
376
|
+
truss-0.11.10rc500.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
|
|
377
|
+
truss-0.11.10rc500.dist-info/RECORD,,
|
truss_train/definitions.py
CHANGED
|
@@ -185,12 +185,13 @@ class TrainingProject(custom_types.SafeModelNoExtra):
|
|
|
185
185
|
|
|
186
186
|
class Checkpoint(custom_types.ConfigModel, ABC):
|
|
187
187
|
training_job_id: str
|
|
188
|
-
|
|
188
|
+
checkpoint_name: str
|
|
189
189
|
model_weight_format: ModelWeightsFormat
|
|
190
190
|
|
|
191
191
|
def to_truss_config(self) -> truss_config.TrainingArtifactReference:
|
|
192
192
|
return truss_config.TrainingArtifactReference(
|
|
193
|
-
training_job_id=self.training_job_id,
|
|
193
|
+
training_job_id=self.training_job_id,
|
|
194
|
+
paths=[f"rank-0/{self.checkpoint_name}/"],
|
|
194
195
|
)
|
|
195
196
|
|
|
196
197
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|