truss 0.10.9rc538__py3-none-any.whl → 0.10.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


truss/cli/train_commands.py CHANGED
@@ -73,8 +73,11 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
 @click.argument("config", type=Path, required=True)
 @click.option("--remote", type=str, required=False, help="Remote to use")
 @click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
+@click.option("--job-name", type=str, required=False, help="Name of the training job.")
 @common.common_options()
-def push_training_job(config: Path, remote: Optional[str], tail: bool):
+def push_training_job(
+    config: Path, remote: Optional[str], tail: bool, job_name: Optional[str]
+):
     """Run a training job"""
     from truss_train import deployment

@@ -85,7 +88,9 @@ def push_training_job(config: Path, remote: Optional[str], tail: bool):
     remote_provider: BasetenRemote = cast(
         BasetenRemote, RemoteFactory.create(remote=remote)
     )
-    job_resp = deployment.create_training_job_from_file(remote_provider, config)
+    job_resp = deployment.create_training_job_from_file(
+        remote_provider, config, job_name
+    )

     # Note: This post create logic needs to happen outside the context
     # of the above context manager, as only one console session can be active
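
The net effect: a job name supplied on the command line is threaded straight through to the deployment helper. A minimal sketch of that call path, assuming the import locations of RemoteFactory and BasetenRemote (everything else is named in this diff):

from pathlib import Path
from typing import Optional, cast

from truss.remote.baseten.remote import BasetenRemote  # assumed module path
from truss.remote.remote_factory import RemoteFactory  # assumed module path
from truss_train import deployment

def push(config: Path, remote: Optional[str], job_name: Optional[str]) -> dict:
    remote_provider = cast(BasetenRemote, RemoteFactory.create(remote=remote))
    # job_name is optional; when None, any name from the config file is kept.
    return deployment.create_training_job_from_file(remote_provider, config, job_name)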
truss/contexts/image_builder/serving_image_builder.py CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import os
 import json
 import logging
 import re
@@ -74,7 +73,6 @@ from truss.contexts.image_builder.util import (
 )
 from truss.contexts.truss_context import TrussContext
 from truss.truss_handle.patch.hash import directory_content_hash
-from truss.util.basetenpointer import model_cache_hf_to_b10ptr
 from truss.util.jinja import read_template_from_fs
 from truss.util.path import (
     build_truss_target_directory,
@@ -327,36 +325,27 @@ def get_files_to_model_cache_v1(config: TrussConfig, truss_dir: Path, build_dir:
 def build_model_cache_v2_and_copy_bptr_manifest(config: TrussConfig, build_dir: Path):
     assert config.model_cache.is_v2
     assert all(model.volume_folder is not None for model in config.model_cache.models)
-    try:
-        from truss_transfer import PyModelRepo, create_basetenpointer_from_models
-
-        py_models = [
-            PyModelRepo(
-                repo_id=model.repo_id,
-                revision=model.revision,
-                runtime_secret_name=model.runtime_secret_name,
-                allow_patterns=model.allow_patterns,
-                ignore_patterns=model.ignore_patterns,
-                volume_folder=model.volume_folder,
-                kind=model.kind.value,
-            )
-            for model in config.model_cache.models
-        ]
-        # create BasetenPointer from models
-        basetenpointer_json = create_basetenpointer_from_models(models=py_models)
-        bptr_py = json.loads(basetenpointer_json)["pointers"]
-        logging.info(f"created ({len(bptr_py)}) Basetenpointer")
-        logging.info(f"pointers json: {basetenpointer_json}")
-        with open(build_dir / "bptr-manifest", "w") as f:
-            f.write(basetenpointer_json)
-    except Exception as e:
-        logging.warning(f"debug: failed to create BasetenPointer: {e}")
-        # TODO: remove below section + remove logging lines above.
-        # builds BasetenManifest for caching
-        basetenpointers = model_cache_hf_to_b10ptr(config.model_cache)
-        # write json of bastenpointers into build dir
-        with open(build_dir / "bptr-manifest", "w") as f:
-            f.write(basetenpointers.model_dump_json())
+    from truss_transfer import PyModelRepo, create_basetenpointer_from_models
+
+    py_models = [
+        PyModelRepo(
+            repo_id=model.repo_id,
+            revision=model.revision,
+            runtime_secret_name=model.runtime_secret_name,
+            allow_patterns=model.allow_patterns,
+            ignore_patterns=model.ignore_patterns,
+            volume_folder=model.volume_folder,
+            kind=model.kind.value,
+        )
+        for model in config.model_cache.models
+    ]
+    # create BasetenPointer from models
+    basetenpointer_json = create_basetenpointer_from_models(models=py_models)
+    bptr_py = json.loads(basetenpointer_json)["pointers"]
+    logging.info(f"created ({len(bptr_py)}) Basetenpointer")
+    logging.info(f"pointers json: {basetenpointer_json}")
+    with open(build_dir / "bptr-manifest", "w") as f:
+        f.write(basetenpointer_json)


 def generate_docker_server_nginx_config(build_dir, config):
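
The rc-era fallback to truss.util.basetenpointer is gone; truss_transfer is now the only manifest builder. For reference, a single entry in the resulting bptr-manifest looks like this (values lifted from the expected pointers in the test file deleted further down; the full manifest is a JSON object of the form {"pointers": [...]}):

pointer = {
    "resolution": {
        "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/config.json",
        "expiration_timestamp": 2373918212,
    },
    "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:config.json",
    "file_name": "/app/model_cache/databricks_dolly_v2_12b/config.json",
    "hashtype": "etag",
    "hash": "888c677eda015e2375fad52d75062d14b30ebad9",
    "size": 818,
}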
@@ -794,8 +783,6 @@ class ServingImageBuilder(ImageBuilder):
             config
         )

-        non_root_user = os.getenv("BT_USE_NON_ROOT_USER", False)
-        enable_model_container_admin_commands = os.getenv("BT_ENABLE_MODEL_CONTAINER_ADMIN_CMDS")
         dockerfile_contents = dockerfile_template.render(
             should_install_server_requirements=should_install_server_requirements,
             base_image_name_and_tag=base_image_name_and_tag,
@@ -829,8 +816,6 @@ class ServingImageBuilder(ImageBuilder):
             build_commands=build_commands,
             use_local_src=config.use_local_src,
             passthrough_environment_variables=passthrough_environment_variables,
-            non_root_user=non_root_user,
-            enable_model_container_admin_commands=enable_model_container_admin_commands,
             **FILENAME_CONSTANTS_MAP,
         )
         # Consolidate repeated empty lines to single empty lines.
truss/templates/base.Dockerfile.jinja CHANGED
@@ -8,32 +8,6 @@ FROM {{ base_image_name_and_tag }} AS truss_server
 {%- set python_executable = config.base_image.python_executable_path or 'python3' %}
 ENV PYTHON_EXECUTABLE="{{ python_executable }}"

-{%- set app_username = "app" %} {# needed later for USER directive#}
-{% block user_setup %}
-{%- set app_user_uid = 60000 %}
-{%- set control_server_dir = "/control" %}
-{%- set default_owner = "root:root" %}
-# The non-root user's home directory.
-# uv will use $HOME to install packages.
-ENV HOME=/home/{{ app_username }}
-# Directory containing inference server code.
-ENV APP_HOME=/{{ app_username }}
-RUN mkdir -p ${APP_HOME} {{ control_server_dir }}
-# Create a non-root user to run model containers.
-RUN useradd -u {{ app_user_uid }} -ms /bin/bash {{ app_username }}
-# Add the user's local bin to the path, used by uv.
-ENV PATH=${PATH}:${HOME}/.local/bin
-{% endblock %} {#- endblock user_setup #}
-
-{%- if enable_model_container_admin_commands %}
-RUN apt update && apt install -y sudo
-{%- set allowed_admin_commands = ["/usr/bin/apt install *", "/usr/bin/apt update"] %}
-RUN echo "Defaults:{{ app_username }} passwd_tries=0\n{{ app_username }} ALL=(root) NOPASSWD: {{ allowed_admin_commands | join(", ") }}" > /etc/sudoers.d/app-packages
-RUN chmod 0440 /etc/sudoers.d/app-packages
-{#- optional but good practice: check if the sudoers file is valid #}
-RUN visudo -c
-{%- endif %}
-
 {%- set UV_VERSION = "0.7.19" %}
 {#
 NB(nikhil): We use a semi-complex uv installation command across the board:
@@ -65,6 +39,7 @@ RUN if ! command -v uv >/dev/null 2>&1; then \
     command -v curl >/dev/null 2>&1 || (apt update && apt install -y curl) && \
     curl -LsSf --retry 5 --retry-delay 5 https://astral.sh/uv/{{ UV_VERSION }}/install.sh | sh; \
 fi
+ENV PATH="/root/.local/bin:$PATH"
 {% endblock %}

 {% block base_image_patch %}
@@ -82,7 +57,7 @@ RUN {{ sys_pip_install_command }} install mkl

 {% block install_system_requirements %}
 {%- if should_install_system_requirements %}
-COPY --chown={{ default_owner }} ./{{ system_packages_filename }} {{ system_packages_filename }}
+COPY ./{{ system_packages_filename }} {{ system_packages_filename }}
 RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ system_packages_filename }}) \
     && apt-get autoremove -y \
     && apt-get clean -y \
@@ -93,11 +68,11 @@ RUN apt-get update && apt-get install --yes --no-install-recommends $(cat {{ sys

 {% block install_requirements %}
 {%- if should_install_user_requirements_file %}
-COPY --chown={{ default_owner }} ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
+COPY ./{{ user_supplied_requirements_filename }} {{ user_supplied_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ user_supplied_requirements_filename }} --no-cache-dir
 {%- endif %}
 {%- if should_install_requirements %}
-COPY --chown={{ default_owner }} ./{{ config_requirements_filename }} {{ config_requirements_filename }}
+COPY ./{{ config_requirements_filename }} {{ config_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cache-dir
 {%- endif %}
 {% endblock %}
@@ -105,6 +80,7 @@ RUN {{ sys_pip_install_command }} -r {{ config_requirements_filename }} --no-cac


 {%- if not config.docker_server %}
+ENV APP_HOME="/app"
 WORKDIR $APP_HOME
 {%- endif %}

@@ -114,7 +90,7 @@ WORKDIR $APP_HOME

 {% block bundled_packages_copy %}
 {%- if bundled_packages_dir_exists %}
-COPY --chown={{ default_owner }} ./{{ config.bundled_packages_dir }} /packages
+COPY ./{{ config.bundled_packages_dir }} /packages
 {%- endif %}
 {% endblock %}

truss/templates/cache.Dockerfile.jinja CHANGED
@@ -1,7 +1,7 @@
 FROM python:3.11-slim AS cache_warmer

-RUN mkdir -p ${APP_HOME}/model_cache
-WORKDIR ${APP_HOME}
+RUN mkdir -p /app/model_cache
+WORKDIR /app

 {% if hf_access_token %}
 ENV HUGGING_FACE_HUB_TOKEN="{{hf_access_token}}"
@@ -9,12 +9,12 @@ ENV HUGGING_FACE_HUB_TOKEN="{{hf_access_token}}"

 RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
 ENV B10CP_PATH_TRUSS="/app/b10cp"
-COPY --chown={{ default_owner }} ./cache_requirements.txt ${APP_HOME}/cache_requirements.txt
+COPY ./cache_requirements.txt /app/cache_requirements.txt
 RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
-COPY --chown={{ default_owner }} ./cache_warmer.py /cache_warmer.py
+COPY ./cache_warmer.py /cache_warmer.py

 {% for credential in credentials_to_cache %}
-COPY ./{{credential}} ${APP_HOME}/{{credential}}
+COPY ./{{credential}} /app/{{credential}}
 {% endfor %}

 {% for repo, hf_dir in models.items() %}
truss/templates/copy_cache_files.Dockerfile.jinja CHANGED
@@ -1,3 +1,3 @@
 {% for file in cached_files %}
-COPY --chown={{ default_owner }} --from=cache_warmer {{file.source}} {{file.dst}}
+COPY --from=cache_warmer {{file.source}} {{file.dst}}
 {% endfor %}
truss/templates/docker_server/supervisord.conf.jinja CHANGED
@@ -1,5 +1,4 @@
 [supervisord]
-pidfile=/tmp/supervisord.pid ; Set PID file location to /tmp to be writable by the non-root user
 nodaemon=true ; Run supervisord in the foreground (useful for containers)
 logfile=/dev/null ; Disable logging to file (send logs to /dev/null)
 logfile_maxbytes=0 ; No size limit on logfile (since logging is disabled)
truss/templates/server/requirements.txt CHANGED
@@ -18,7 +18,7 @@ psutil>=5.9.4
 python-json-logger>=2.0.2
 pyyaml>=6.0.0
 requests>=2.31.0
-truss-transfer==0.0.27
+truss-transfer==0.0.29
 uvicorn>=0.24.0
 uvloop>=0.19.0
 websockets>=10.0
truss/templates/server.Dockerfile.jinja CHANGED
@@ -20,7 +20,7 @@ RUN apt update && \
     && apt-get clean -y \
     && rm -rf /var/lib/apt/lists/*

-COPY --chown={{ default_owner }} ./{{ base_server_requirements_filename }} {{ base_server_requirements_filename }}
+COPY ./{{ base_server_requirements_filename }} {{ base_server_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ base_server_requirements_filename }} --no-cache-dir
 {%- endif %} {#- endif not config.docker_server #}

@@ -38,7 +38,7 @@ RUN ln -sf {{ config.base_image.python_executable_path }} /usr/local/bin/python

 {% block install_requirements %}
 {%- if should_install_server_requirements %}
-COPY --chown={{ default_owner }} ./{{ server_requirements_filename }} {{ server_requirements_filename }}
+COPY ./{{ server_requirements_filename }} {{ server_requirements_filename }}
 RUN {{ sys_pip_install_command }} -r {{ server_requirements_filename }} --no-cache-dir
 {%- endif %} {#- endif should_install_server_requirements #}
 {{ super() }}
@@ -65,47 +65,47 @@ RUN {% for secret,path in config.build.secret_to_path_mapping.items() %} --mount

 {# Copy data before code for better caching #}
 {%- if data_dir_exists %}
-COPY --chown={{ default_owner }} ./{{ config.data_dir }} ${APP_HOME}/data
+COPY ./{{ config.data_dir }} /app/data
 {%- endif %} {#- endif data_dir_exists #}

 {%- if model_cache_v2 %}
 # v0.0.9, keep synced with server_requirements.txt
-RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.9rc0/truss-transfer-cli-v0.10.9rc0-linux-x86_64-unknown-linux-musl
+RUN curl -sSL --fail --retry 5 --retry-delay 2 -o /usr/local/bin/truss-transfer-cli https://github.com/basetenlabs/truss/releases/download/v0.10.10rc1/truss-transfer-cli-v0.10.10rc1-linux-x86_64-unknown-linux-musl
 RUN chmod +x /usr/local/bin/truss-transfer-cli
 RUN mkdir /static-bptr
 RUN echo "hash {{model_cache_hash}}"
-COPY --chown={{ default_owner }} ./bptr-manifest /static-bptr/static-bptr-manifest.json
+COPY ./bptr-manifest /static-bptr/static-bptr-manifest.json
 {%- endif %} {#- endif model_cache_v2 #}

 {%- if not config.docker_server %}
-COPY --chown={{ default_owner }} ./server ${APP_HOME}
+COPY ./server /app
 {%- endif %} {#- endif not config.docker_server #}

 {%- if use_local_src %}
 {# This path takes precedence over site-packages. #}
-COPY --chown={{ default_owner }} ./truss_chains ${APP_HOME}/truss_chains
-COPY --chown={{ default_owner }} ./truss ${APP_HOME}/truss
+COPY ./truss_chains /app/truss_chains
+COPY ./truss /app/truss
 {%- endif %} {#- endif use_local_src #}

-COPY --chown={{ default_owner }} ./config.yaml ${APP_HOME}/config.yaml
+COPY ./config.yaml /app/config.yaml
 {%- if requires_live_reload %}
 RUN uv python install {{ control_python_version }}
 RUN uv venv /control/.env --python {{ control_python_version }}

-COPY --chown={{ default_owner }} ./control /control
+COPY ./control /control
 RUN uv pip install -r /control/requirements.txt --python /control/.env/bin/python --no-cache-dir
 {%- endif %} {#- endif requires_live_reload #}

 {%- if model_dir_exists %}
-COPY --chown={{ default_owner }} ./{{ config.model_module_dir }} ${APP_HOME}/model
+COPY ./{{ config.model_module_dir }} /app/model
 {%- endif %} {#- endif model_dir_exists #}
 {% endblock %} {#- endblock app_copy #}

 {% block run %}
 {%- if config.docker_server %}
-RUN apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     curl nginx && rm -rf /var/lib/apt/lists/*
-COPY --chown={{ default_owner }} ./docker_server_requirements.txt ${APP_HOME}/docker_server_requirements.txt
+COPY ./docker_server_requirements.txt /app/docker_server_requirements.txt

 {# NB(nikhil): Use the same python version for custom server proxy as the control server, for consistency. #}
 RUN uv python install {{ control_python_version }}
@@ -115,38 +115,21 @@ RUN uv pip install --python /docker_server/.venv/bin/python -r /app/docker_serve
 {% set supervisor_config_path = "/etc/supervisor/supervisord.conf" %}
 {% set supervisor_log_dir = "/var/log/supervisor" %}
 {% set supervisor_server_url = "http://localhost:8080" %}
-COPY --chown={{ default_owner }} ./proxy.conf {{ proxy_config_path }}
+COPY ./proxy.conf {{ proxy_config_path }}
 RUN mkdir -p {{ supervisor_log_dir }}
-COPY --chown={{ default_owner }} ./supervisord.conf {{ supervisor_config_path }}
+COPY supervisord.conf {{ supervisor_config_path }}
 ENV SUPERVISOR_SERVER_URL="{{ supervisor_server_url }}"
 ENV SERVER_START_CMD="/docker_server/.venv/bin/supervisord -c {{ supervisor_config_path }}"
-{#- default configuration uses port 80, which requires root privileges, so we remove it #}
-RUN rm -f /etc/nginx/sites-enabled/default
-{%- if non_root_user %}
-{#- nginx writes to /var/lib/nginx, /var/log/nginx, and /run directories #}
-{% set nginx_dirs = ["/var/lib/nginx", "/var/log/nginx", "/run"] %}
-RUN chown -R {{ app_username }}:{{ app_username }} {{ nginx_dirs | join(" ") }}
-RUN chown -R {{ app_username }}:{{ app_username }} ${HOME} ${APP_HOME}
-USER {{ app_username }}
-{%- endif %} {#- endif non_root_user #}
 ENTRYPOINT ["/docker_server/.venv/bin/supervisord", "-c", "{{ supervisor_config_path }}"]
 {%- elif requires_live_reload %} {#- elif requires_live_reload #}
 ENV HASH_TRUSS="{{ truss_hash }}"
 ENV CONTROL_SERVER_PORT="8080"
 ENV INFERENCE_SERVER_PORT="8090"
 ENV SERVER_START_CMD="/control/.env/bin/python /control/control/server.py"
-{%- if non_root_user %}
-RUN chown -R {{ app_username }}:{{ app_username }} ${HOME} ${APP_HOME}
-USER {{ app_username }}
-{%- endif %} {#- endif non_root_user #}
 ENTRYPOINT ["/control/.env/bin/python", "/control/control/server.py"]
 {%- else %} {#- else (default inference server) #}
 ENV INFERENCE_SERVER_PORT="8080"
 ENV SERVER_START_CMD="{{ python_executable }} /app/main.py"
-{%- if non_root_user %}
-RUN chown -R {{ app_username }}:{{ app_username }} ${HOME} ${APP_HOME}
-USER {{ app_username }}
-{%- endif %} {#- endif non_root_user #}
 ENTRYPOINT ["{{ python_executable }}", "/app/main.py"]
 {%- endif %} {#- endif config.docker_server / live_reload #}
 {% endblock %} {#- endblock run #}
truss-0.10.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: truss
-Version: 0.10.9rc538
+Version: 0.10.10
 Summary: A seamless bridge from model development to model delivery
 Project-URL: Repository, https://github.com/basetenlabs/truss
 Project-URL: Homepage, https://truss.baseten.co
@@ -37,7 +37,7 @@ Requires-Dist: rich<14,>=13.4.2
 Requires-Dist: ruff>=0.4.8
 Requires-Dist: tenacity>=8.0.1
 Requires-Dist: tomlkit>=0.13.2
-Requires-Dist: truss-transfer==0.0.27
+Requires-Dist: truss-transfer==0.0.29
 Requires-Dist: watchfiles<0.20,>=0.19.0
 Description-Content-Type: text/markdown

truss-0.10.10.dist-info/RECORD CHANGED
@@ -11,7 +11,7 @@ truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
 truss/cli/chains_commands.py,sha256=y6pdIAGCcKOPG9bPuCXPfSA0onQm5x-tT_3blSBfPYg,16971
 truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
 truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
-truss/cli/train_commands.py,sha256=u8FhhA4r9j6smIr0bfvPbnBOJyd6sCqyEemsWXfcwWs,12193
+truss/cli/train_commands.py,sha256=P9bdnpq1SgEGXBaVf9joKdsaCDX2v29P4MhLMuz-jYw,12344
 truss/cli/logs/base_watcher.py,sha256=KKyd7lIrdaEeDVt8EtjMioSPGVpLyOcF0ewyzE_GGdQ,2785
 truss/cli/logs/model_log_watcher.py,sha256=NACcP-wkcaroYa2Cb9BZC7Yr0554WZa_FSM2LXOf4A8,1263
 truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
@@ -36,7 +36,7 @@ truss/contexts/docker_build_setup.py,sha256=cF4ExZgtYvrWxvyCAaUZUvV_DB_7__MqVomU
 truss/contexts/truss_context.py,sha256=uS6L-ACHxNk0BsJwESOHh1lA0OGGw0pb33aFKGsASj4,436
 truss/contexts/image_builder/cache_warmer.py,sha256=TGMV1Mh87n2e_dSowH0sf0rZhZraDOR-LVapZL3a5r8,7377
 truss/contexts/image_builder/image_builder.py,sha256=IuRgDeeoHVLzIkJvKtX3807eeqEyaroCs_KWDcIHZUg,1461
-truss/contexts/image_builder/serving_image_builder.py,sha256=yrFBAGmYspt_4rtdrku-zeKQfEkqfHfTzIJEkDNDkng,34226
+truss/contexts/image_builder/serving_image_builder.py,sha256=FH5HPnrr9_OomN5WplsyUrGGETe9ld6h3q9JCpvB6FY,33322
 truss/contexts/image_builder/util.py,sha256=y2-CjUKv0XV-0w2sr1fUCflysDJLsoU4oPp6tvvoFnk,1203
 truss/contexts/local_loader/docker_build_emulator.py,sha256=rmf7I28zksSmHjwvJMx2rIa6xK4KeR5fBm5YFth_fQg,2464
 truss/contexts/local_loader/dockerfile_parser.py,sha256=GoRJ0Af_3ILyLhjovK5lrCGn1rMxz6W3l681ro17ZzI,1344
@@ -66,12 +66,12 @@ truss/remote/baseten/utils/time.py,sha256=Ry9GMjYnbIGYVIGwtmv4V8ljWjvdcaCf5NOQzl
 truss/remote/baseten/utils/transfer.py,sha256=d3VptuQb6M1nyS6kz0BAfeOYDLkMKUjatJXpY-mp-As,1548
 truss/templates/README.md.jinja,sha256=N7CJdyldZuJamj5jLh47le0hFBdu9irVsTBqoxhPNPQ,2476
 truss/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-truss/templates/base.Dockerfile.jinja,sha256=7EjUj08Wy0OmG0euDnwi-MFEYYiizcTIeNeDRVyMyhk,5255
-truss/templates/cache.Dockerfile.jinja,sha256=evMbbMXjceDjUmGGuB0shLS_PxZsc9ZGkbzR-6aIW-Q,1114
+truss/templates/base.Dockerfile.jinja,sha256=vFAJH1lC9jg90-076H2DCmkXUAlpseitIN6c4UwagxA,4020
+truss/templates/cache.Dockerfile.jinja,sha256=LhsVP9F3BATKQGkgya_YT4v6ABTUkpy-Jb3N36zsw10,1030
 truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj7wGs61nUs,87
-truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
+truss/templates/copy_cache_files.Dockerfile.jinja,sha256=arHldnuclt7vUFHyRz6vus5NGMDkIofm-1RU37A0xZM,98
 truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
-truss/templates/server.Dockerfile.jinja,sha256=qGhUrii-WrM3dnE-5QGE4NYSvT93mG_lYXOPFeB7xpc,7247
+truss/templates/server.Dockerfile.jinja,sha256=Ts4kty2ZXTJS69XkNHTNHtEyr8yf8VwNQgBBLY89chk,5996
 truss/templates/control/requirements.txt,sha256=Kk0tYID7trPk5gwX38Wrt2-YGWZAXFJCJRcqJ8ZzCjc,251
 truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
 truss/templates/control/control/endpoints.py,sha256=FM-sgao7I3gMoUTasM3Xq_g2LDoJQe75JxIoaQxzeNo,10031
@@ -92,11 +92,11 @@ truss/templates/custom/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
 truss/templates/custom/model/model.py,sha256=J04rLxK09Pwt2F4GoKOLKL-H-CqZUdYIM-PL2CE9PoE,1079
 truss/templates/custom_python_dx/my_model.py,sha256=NG75mQ6wxzB1BYUemDFZvRLBET-UrzuUK4FuHjqI29U,910
 truss/templates/docker_server/proxy.conf.jinja,sha256=Lg-PcZzKflG85exZKHNgW_I6r0mATV8AtOIBaE40-RM,1669
-truss/templates/docker_server/supervisord.conf.jinja,sha256=dd37fwZE--cutrvOUCqEyJQQQhlp61H2IUs2huKWsSk,1808
+truss/templates/docker_server/supervisord.conf.jinja,sha256=CoaSLv0Lr8t1tS_q102IFufNX2lWrlbCHJLjMhYjOwM,1711
 truss/templates/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 truss/templates/server/main.py,sha256=kWXrdD8z8IpamyWxc8qcvd5ck9gM1Kz2QH5qHJCnmOQ,222
 truss/templates/server/model_wrapper.py,sha256=k75VVISwwlsx5EGb82UZsu8kCM_i6Yi3-Hd0-Kpm1yo,42055
-truss/templates/server/requirements.txt,sha256=iRR2BEpBQnt-YOiTEKOnaab7tlR4C23V1cuURuIt7ZY,672
+truss/templates/server/requirements.txt,sha256=Xvf7mT4zjK1B6rIrNW80-An03yCNCXvWiB6OvWrhIxg,672
 truss/templates/server/truss_server.py,sha256=ob_nceeGtFPZzKKdk_ZZGLoZrJOGE6hR52xM1sPR97A,19498
 truss/templates/server/common/__init__.py,sha256=qHIqr68L5Tn4mV6S-PbORpcuJ4jmtBR8aCuRTIWDvNo,85
 truss/templates/server/common/errors.py,sha256=qWeZlmNI8ZGbZbOIp_mtS6IKvUFIzhj3QH8zp-xTp9o,8554
@@ -309,7 +309,6 @@ truss/tests/test_data/test_truss_with_error/packages/helpers_1.py,sha256=qIm-hQY
 truss/tests/test_data/test_truss_with_error/packages/helpers_2.py,sha256=q_UpVfXq_K2tuHv6YwsIzVHC3sy5k5hKDw6lMCdS0oc,53
 truss/tests/trt_llm/test_trt_llm_config.py,sha256=lNQ4EEkOsiT17KvnvW1snCeEBd7K_cl9_Y0dko3qpn8,8505
 truss/tests/trt_llm/test_validation.py,sha256=dmax2EHxRfqxJvWzV8uubkTef50833KBBHw-WkHufL8,2120
-truss/tests/util/test_basetenpointer.py,sha256=Bdms21_m8T4xmFNHRO5nS2tU2wU7094_1SkfBxjptmk,9824
 truss/tests/util/test_config_checks.py,sha256=aoZF_Q-eRd3qz5wjUqa8Cr_7qF2SxodXbBIY_DBuFWg,522
 truss/tests/util/test_env_vars.py,sha256=hthgB1mU0bJb1H4Jugc-0khArlLZ3x6tLE82cDaa-J0,390
 truss/tests/util/test_path.py,sha256=YfW3-IM_7iRsdR1Cb26KB1BkDsG_53_BUGBzoxY2Nog,7408
@@ -331,7 +330,6 @@ truss/truss_handle/patch/local_truss_patch_applier.py,sha256=fOHWKt3teYnbqeRsF63
 truss/truss_handle/patch/signature.py,sha256=8eas8gy6Japd1hrgdmtHmKTTxQmWsbmgKRQQGL2PVuA,858
 truss/truss_handle/patch/truss_dir_patch_applier.py,sha256=uhhHvKYHn_dpfz0xp4jwO9_qAej5sO3f8of_h-21PP4,3666
 truss/util/.truss_ignore,sha256=jpQA9ou-r_JEIcEHsUqGLHhir_m3d4IPGNyzKXtS-2g,3131
-truss/util/basetenpointer.py,sha256=PJ_meuTuXAopnWsHe1ZaH2RnltfmqdQ4QeXDQEPrblI,5596
 truss/util/docker.py,sha256=6PD7kMBBrOjsdvgkuSv7JMgZbe3NoJIeGasljMm2SwA,3934
 truss/util/download.py,sha256=1lfBwzyaNLEp7SAVrBd9BX5inZpkCVp8sBnS9RNoiJA,2521
 truss/util/env_vars.py,sha256=7Bv686eER71Barrs6fNamk_TrTJGmu9yV2TxaVmupn0,1232
@@ -361,12 +359,12 @@ truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e6
 truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
 truss_chains/remote_chainlet/utils.py,sha256=O_5P-VAUvg0cegEW1uKCOf5EBwD8rEGYVoGMivOmc7k,22374
 truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
-truss_train/definitions.py,sha256=RZs4bCWkq7gBJALDLgmd4QxjlxWk6GMs2a62kiAalvw,6758
-truss_train/deployment.py,sha256=zmeJ66kg1Wc7l7bwA_cXqv85uMF77hYl7NPHuhc1NPs,2493
+truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
+truss_train/deployment.py,sha256=fDYRfzFRtVKMRVG0bKXYPmx6HXwLE0ukSQ0f81hG8kk,3020
 truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
 truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
-truss-0.10.9rc538.dist-info/METADATA,sha256=FYmavkZGrV6sXYBVv98hgPBT3XRi4SE1Oq_0MssgDoE,6674
-truss-0.10.9rc538.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-truss-0.10.9rc538.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
-truss-0.10.9rc538.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
-truss-0.10.9rc538.dist-info/RECORD,,
+truss-0.10.10.dist-info/METADATA,sha256=5q5tQ4MtWhQspNwHEsMnSDOLj-fYFeB7zL4VinS2I28,6670
+truss-0.10.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+truss-0.10.10.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
+truss-0.10.10.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
+truss-0.10.10.dist-info/RECORD,,
truss_train/definitions.py CHANGED
@@ -127,6 +127,7 @@ class TrainingJob(custom_types.SafeModelNoExtra):
     image: Image
     compute: Compute = Compute()
    runtime: Runtime = Runtime()
+    name: Optional[str] = None

     def model_dump(self, *args, **kwargs):
         data = super().model_dump(*args, **kwargs)
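
With the new field, a training config can pin a job name directly. A hedged sketch (the class names come from this diff; the Image arguments and the idea that a config module simply instantiates TrainingJob are assumptions):

from truss_train.definitions import Compute, Image, Runtime, TrainingJob

job = TrainingJob(
    image=Image(base_image="python:3.11-slim"),  # Image arguments are assumed
    compute=Compute(),
    runtime=Runtime(),
    name="sft-llama-run-1",  # new in 0.10.10; optional, overridden by --job-name
)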
truss_train/deployment.py CHANGED
@@ -1,8 +1,9 @@
 import pathlib
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from truss.base.custom_types import SafeModel
+from truss.cli.utils.output import console
 from truss.remote.baseten import custom_types as b10_types
 from truss.remote.baseten.api import BasetenApi
 from truss.remote.baseten.core import archive_dir
@@ -44,6 +45,7 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo
         image=training_job.image,
         runtime=training_job.runtime,
         compute=training_job.compute,
+        name=training_job.name,
         runtime_artifacts=[
             S3Artifact(s3_key=credentials["s3_key"], s3_bucket=credentials["s3_bucket"])
         ],
@@ -57,14 +59,25 @@ def create_training_job(
         training_project=training_project
     )
     prepared_job = prepare_push(remote_provider.api, config, training_project.job)
+
     job_resp = remote_provider.api.create_training_job(
         project_id=project_resp["id"], job=prepared_job
     )
     return job_resp


-def create_training_job_from_file(remote_provider: BasetenRemote, config: Path) -> dict:
+def create_training_job_from_file(
+    remote_provider: BasetenRemote,
+    config: Path,
+    job_name_from_cli: Optional[str] = None,
+) -> dict:
     with loader.import_training_project(config) as training_project:
+        if job_name_from_cli:
+            if training_project.job.name:
+                console.print(
+                    f"[bold yellow]⚠ Warning:[/bold yellow] name '{training_project.job.name}' provided in config file will be ignored. Using job name '{job_name_from_cli}' provided via --job-name flag."
+                )
+            training_project.job.name = job_name_from_cli
         job_resp = create_training_job(
             remote_provider=remote_provider,
             training_project=training_project,
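
The precedence rule this adds is small but worth spelling out: an explicit --job-name always wins, and a name in the config file is only used when the flag is absent. A standalone illustration of that logic (hypothetical helper, not part of the package):

from typing import Optional

def resolve_job_name(config_name: Optional[str], cli_name: Optional[str]) -> Optional[str]:
    # Mirrors the behavior in create_training_job_from_file: the CLI flag
    # overrides the config value, with a warning when both are set.
    if cli_name:
        if config_name:
            print(f"Warning: name {config_name!r} from the config file is ignored; using {cli_name!r}.")
        return cli_name
    return config_name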
truss/tests/util/test_basetenpointer.py DELETED
@@ -1,227 +0,0 @@
-import time
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
-import requests
-from huggingface_hub.errors import HfHubHTTPError
-
-from truss.base.truss_config import ModelCache, ModelRepo
-from truss.util.basetenpointer import model_cache_hf_to_b10ptr
-
-
-def test_dolly_12b():
-    ModelCached = ModelCache(
-        [
-            dict(
-                repo_id="databricks/dolly-v2-12b",
-                revision="19308160448536e378e3db21a73a751579ee7fdd",
-                use_volume=True,
-                volume_folder="databricks_dolly_v2_12b",
-                runtime_secret_name="hf_access_token",
-            )
-        ]
-    )
-    for _ in range(2):
-        try:
-            bptr = model_cache_hf_to_b10ptr(ModelCached)
-            continue
-        # timeout by huggingface hub timeout error
-        except requests.exceptions.ReadTimeout as e:
-            # this is expected to timeout when the request takes too long
-            # due to the large size of the model
-            print("ReadTimeout Error: ", e)
-            pytest.skip(
-                "Skipping test due to ReadTimeout error from Hugging Face API, "
-                "this can happen for large models like Dolly-12b"
-            )
-        except HfHubHTTPError as e:
-            if e.response.status_code == 429:
-                pytest.skip("Hugging Face API rate limit exceeded")
-            raise
-    bptr_list = bptr.pointers
-    expected = [
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/.gitattributes",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:.gitattributes",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/.gitattributes",
-            "hashtype": "etag",
-            "hash": "c7d9f3332a950355d5a77d85000f05e6f45435ea",
-            "size": 1477,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/README.md",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:README.md",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/README.md",
-            "hashtype": "etag",
-            "hash": "2912eb39545af0367335cff448d07214519c5eed",
-            "size": 10746,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/config.json",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:config.json",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/config.json",
-            "hashtype": "etag",
-            "hash": "888c677eda015e2375fad52d75062d14b30ebad9",
-            "size": 818,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/instruct_pipeline.py",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:instruct_pipeline.py",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/instruct_pipeline.py",
-            "hashtype": "etag",
-            "hash": "f8b291569e936cf104f44d003f95451bf5e1f965",
-            "size": 9159,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/pytorch_model.bin",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:pytorch_model.bin",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/pytorch_model.bin",
-            "hashtype": "etag",
-            "hash": "19e10711310992c310c3775964c7635f4b28dd86587403e718c6d6d524a406a5",
-            "size": 23834965761,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/special_tokens_map.json",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:special_tokens_map.json",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/special_tokens_map.json",
-            "hashtype": "etag",
-            "hash": "ecc1ee07dec13ee276fa9f1b29a1078da3280a4d",
-            "size": 228,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/tokenizer.json",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:tokenizer.json",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/tokenizer.json",
-            "hashtype": "etag",
-            "hash": "22868c8caf99a303c1a44bfea98f20f4254fc0e5",
-            "size": 2114274,
-        },
-        {
-            "resolution": {
-                "url": "https://huggingface.co/databricks/dolly-v2-12b/resolve/19308160448536e378e3db21a73a751579ee7fdd/tokenizer_config.json",
-                "expiration_timestamp": 2373918212,
-            },
-            "uid": "databricks/dolly-v2-12b:19308160448536e378e3db21a73a751579ee7fdd:tokenizer_config.json",
-            "file_name": "/app/model_cache/databricks_dolly_v2_12b/tokenizer_config.json",
-            "hashtype": "etag",
-            "hash": "51e564ead5d28eebc74b25d86f0a694b7c7cc618",
-            "size": 449,
-        },
-    ]
-    assert len(bptr_list) == len(expected), (
-        f"Expected {len(expected)} but got {len(bptr_list)}"
-    )
-    for expected, actual in zip(expected, bptr_list):
-        assert expected["uid"] == actual.uid, (
-            f"Expected uid {expected['uid']} but got {actual.uid}"
-        )
-        assert expected["file_name"] == actual.file_name, (
-            f"Expected file_name {expected['file_name']} but got {actual.file_name}"
-        )
-        assert expected["hash"] == actual.hash, (
-            f"Expected hash {expected['hash']} but got {actual.hash}"
-        )
-        assert expected["size"] == actual.size, (
-            f"Expected size {expected['size']} but got {actual.size}"
-        )
-        assert expected["resolution"]["url"] == actual.resolution.url, (
-            f"Expected resolution url {expected['resolution']['url']} but got {actual.resolution.url}"
-        )
-        # 100 years or more ahead
-        assert (
-            actual.resolution.expiration_timestamp
-            >= time.time() + 20 * 365 * 24 * 60 * 60
-        ), (
-            f"Expected unix expiration timestamp to be at least 20 years ahead, but got {actual.resolution.expiration_timestamp}. "
-        )
-
-    # download first file and verify size
-    with TemporaryDirectory() as tmp:
-        # Get the first pointer (.gitattributes)
-        first_pointer = bptr_list[0]
-        tmp_path = Path(tmp) / "downloaded_file"
-
-        # Download the file
-        response = requests.get(first_pointer.resolution.url)
-        response.raise_for_status()
-
-        # Save the file
-        tmp_path.write_bytes(response.content)
-
-        # Verify file size matches metadata
-        actual_size = tmp_path.stat().st_size
-        assert actual_size == first_pointer.size, (
-            f"Downloaded file size {actual_size} does not match expected size {first_pointer.size}"
-        )
-
-
-def test_with_main():
-    # main should be resolved to 41dec486b25746052d3335decc8f5961607418a0
-    cache = ModelCache(
-        [
-            ModelRepo(
-                repo_id="intfloat/llm-retriever-base",
-                revision="main",
-                ignore_patterns=["*.json", "*.txt", "*.md", "*.bin", "*.model"],
-                volume_folder="mistral_demo",
-                use_volume=True,
-            )
-        ]
-    )
-    b10ptr = model_cache_hf_to_b10ptr(cache)
-    expected = {
-        "pointers": [
-            {
-                "resolution": {
-                    "url": "https://huggingface.co/intfloat/llm-retriever-base/resolve/41dec486b25746052d3335decc8f5961607418a0/.gitattributes",
-                    "expiration_timestamp": 4044816725,
-                },
-                "uid": "intfloat/llm-retriever-base:main:.gitattributes",
-                "file_name": "/app/model_cache/mistral_demo/.gitattributes",
-                "hashtype": "etag",
-                "hash": "a6344aac8c09253b3b630fb776ae94478aa0275b",
-                "size": 1519,
-                "runtime_secret_name": "hf_access_token",
-            },
-            {
-                "resolution": {
-                    "url": "https://huggingface.co/intfloat/llm-retriever-base/resolve/41dec486b25746052d3335decc8f5961607418a0/model.safetensors",
-                    "expiration_timestamp": 4044816725,
-                },
-                "uid": "intfloat/llm-retriever-base:main:model.safetensors",
-                "file_name": "/app/model_cache/mistral_demo/model.safetensors",
-                "hashtype": "etag",
-                "hash": "565dd4f1cc6318ccf07af8680c27fd935b3b56ca2684d1af58abcd4e8bf6ecfa",
-                "size": 437955512,
-                "runtime_secret_name": "hf_access_token",
-            },
-        ]
-    }
-    assert b10ptr.model_dump() == expected
-
-
-if __name__ == "__main__":
-    test_dolly_12b()
-    test_with_main()
truss/util/basetenpointer.py DELETED
@@ -1,160 +0,0 @@
-"""This file contains the utils to create a basetenpointer from a huggingface repo, which can be resolved at runtime."""
-
-import time
-from pathlib import Path
-from typing import TYPE_CHECKING, Optional
-
-import requests
-from huggingface_hub import hf_api, hf_hub_url
-from huggingface_hub.utils import filter_repo_objects
-from pydantic import BaseModel
-
-if TYPE_CHECKING:
-    from truss.base.truss_config import ModelCache
-
-
-# copied from: https://github.com/basetenlabs/baseten/blob/caeba66cd544a5152bb6a018d6ac2871814f327b/baseten_shared/baseten_shared/lms/types.py#L13
-class Resolution(BaseModel):
-    url: str
-    expiration_timestamp: int
-
-
-class BasetenPointer(BaseModel):
-    resolution: Optional[Resolution] = None
-    uid: str
-    file_name: str
-    hashtype: str
-    hash: str
-    size: int
-    runtime_secret_name: str = "hf_access_token"  # TODO: remove the default
-
-
-class BasetenPointerList(BaseModel):
-    pointers: list[BasetenPointer]
-
-
-def get_hf_metadata(api: "hf_api.HfApi", repo: str, revision: str, file: str):
-    url = hf_hub_url(repo_id=repo, revision=revision, filename=file)
-    meta = api.get_hf_file_metadata(url=url)
-    return {"etag": meta.etag, "location": meta.location, "size": meta.size, "url": url}
-
-
-def filter_repo_files(
-    files: list[str],
-    allow_patterns: Optional[list[str]],
-    ignore_patterns: Optional[list[str]],
-) -> list[str]:
-    return list(
-        filter_repo_objects(
-            items=files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
-        )
-    )
-
-
-def metadata_hf_repo(
-    repo: str,
-    revision: str,
-    allow_patterns: Optional[list[str]] = None,
-    ignore_patterns: Optional[list[str]] = None,
-) -> dict[str, dict]:
-    """Lists all files, gathers metadata without downloading, just using the Hugging Face API.
-    Example:
-    [{'.gitattributes': HfFileMetadata(
-        commit_hash='07163b72af1488142a360786df853f237b1a3ca1',
-        etag='a6344aac8c09253b3b630fb776ae94478aa0275b',
-        location='https://huggingface.co/intfloat/e5-mistral-7b-instruct/resolve/main/.gitattributes',
-        url='https://huggingface.co/intfloat/e5-mistral-7b-instruct/resolve/main/.gitattributes',
-        size=1519)]
-    """
-    api = hf_api.HfApi()
-    model_info = api.model_info(repo_id=repo, revision=revision)
-    real_revision = model_info.sha
-    real_revision = real_revision or revision
-    if revision != real_revision:
-        print(
-            f"Warning: revision {revision} is moving, using {real_revision} instead. "
-            f"Please update your code to use `revision={real_revision}` instead otherwise you will keep moving. "
-        )
-    files: list[str] = api.list_repo_files(repo_id=repo, revision=real_revision)
-    files = filter_repo_files(
-        files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
-    )
-
-    hf_files_meta = {
-        file: get_hf_metadata(api, repo, real_revision, file) for file in files
-    }
-
-    return hf_files_meta
-
-
-def model_cache_hf_to_b10ptr(cache: "ModelCache") -> BasetenPointerList:
-    """
-    Convert a ModelCache object to a BasetenPointer object.
-    """
-    assert cache.is_v2, "ModelCache is not v2"
-
-    basetenpointers: list[BasetenPointer] = []
-
-    for model in cache.models:
-        assert model.revision is not None, "ModelCache is not v2, revision is None"
-        exception = None
-        for _ in range(3):
-            try:
-                metadata_hf_repo_list = metadata_hf_repo(
-                    repo=model.repo_id,
-                    revision=model.revision,
-                    allow_patterns=model.allow_patterns,
-                    ignore_patterns=model.ignore_patterns,
-                )
-                break
-            except requests.exceptions.ReadTimeout as e:
-                # this is expected, sometimes huggingface hub times out
-                print("ReadTimeout Error: ", e)
-                time.sleep(5)
-                exception = e
-            except Exception as e:
-                raise e
-        else:
-            # if we get here, we have exhausted the retries
-            assert exception is not None, "ReadTimeout Error: " + str(exception)
-            raise exception
-        # convert the metadata to b10 pointer format
-        b10_pointer_list = [
-            BasetenPointer(
-                uid=f"{model.repo_id}:{model.revision}:{filename}",
-                file_name=(Path(model.runtime_path) / filename).as_posix(),
-                hashtype="etag",
-                hash=content["etag"],
-                size=content["size"],
-                runtime_secret_name=model.runtime_secret_name,
-                resolution=Resolution(
-                    url=content["url"],
-                    expiration_timestamp=int(
-                        4044816725  # 90 years in the future, hf does not expire. needs to be static, to have cache hits.
-                    ),
-                ),
-            )
-            for filename, content in metadata_hf_repo_list.items()
-        ]
-        basetenpointers.extend(b10_pointer_list)
-
-    return BasetenPointerList(pointers=basetenpointers)
-
-
-if __name__ == "__main__":
-    # example usage
-    from truss.base.truss_config import ModelCache, ModelRepo
-
-    cache = ModelCache(
-        [
-            ModelRepo(
-                repo_id="intfloat/llm-retriever-base",
-                revision="main",
-                ignore_patterns=["*.json", "*.txt", "*.md", "*.bin", "*.model"],
-                volume_folder="mistral_demo",
-                use_volume=True,
-            )
-        ]
-    )
-    b10ptr = model_cache_hf_to_b10ptr(cache)
-    print(b10ptr.model_dump_json())
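
The deleted helper's role is now filled by truss_transfer's implementation. A hedged equivalence sketch (PyModelRepo and create_basetenpointer_from_models appear in this diff; the field values mirror the deleted __main__ example, and the "kind" value is an assumption):

import json

from truss_transfer import PyModelRepo, create_basetenpointer_from_models

# One PyModelRepo per entry in config.model_cache.models; kind is assumed to
# accept the same values as ModelRepo.kind.value in the truss config.
repo = PyModelRepo(
    repo_id="intfloat/llm-retriever-base",
    revision="main",
    runtime_secret_name="hf_access_token",
    allow_patterns=None,
    ignore_patterns=["*.json", "*.txt", "*.md", "*.bin", "*.model"],
    volume_folder="mistral_demo",
    kind="hf",  # assumed value
)
manifest = json.loads(create_basetenpointer_from_models(models=[repo]))
print(len(manifest["pointers"]))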