sagemaker-core 2.1.1__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/common_utils.py +119 -1
- sagemaker/core/experiments/experiment.py +3 -0
- sagemaker/core/fw_utils.py +56 -12
- sagemaker/core/git_utils.py +66 -0
- sagemaker/core/helper/session_helper.py +22 -10
- sagemaker/core/image_retriever/image_retriever_utils.py +1 -3
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +111 -1
- sagemaker/core/image_uri_config/huggingface-llm.json +110 -1
- sagemaker/core/image_uri_config/huggingface-neuronx.json +182 -6
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +151 -2
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +40 -0
- sagemaker/core/image_uri_config/sklearn.json +48 -0
- sagemaker/core/image_uri_config/xgboost.json +84 -0
- sagemaker/core/image_uris.py +9 -3
- sagemaker/core/iterators.py +11 -0
- sagemaker/core/jumpstart/models.py +2 -0
- sagemaker/core/jumpstart/region_config.json +8 -0
- sagemaker/core/local/data.py +10 -0
- sagemaker/core/local/utils.py +6 -5
- sagemaker/core/model_monitor/clarify_model_monitoring.py +2 -0
- sagemaker/core/model_registry.py +1 -1
- sagemaker/core/modules/configs.py +14 -1
- sagemaker/core/modules/train/container_drivers/common/utils.py +2 -10
- sagemaker/core/modules/train/sm_recipes/utils.py +1 -1
- sagemaker/core/processing.py +2 -0
- sagemaker/core/remote_function/client.py +31 -6
- sagemaker/core/remote_function/core/pipeline_variables.py +0 -6
- sagemaker/core/remote_function/core/serialization.py +16 -28
- sagemaker/core/remote_function/core/stored_function.py +8 -11
- sagemaker/core/remote_function/errors.py +1 -3
- sagemaker/core/remote_function/invoke_function.py +1 -6
- sagemaker/core/remote_function/job.py +2 -21
- sagemaker/core/telemetry/constants.py +6 -8
- sagemaker/core/telemetry/telemetry_logging.py +6 -5
- sagemaker/core/training/configs.py +16 -4
- sagemaker/core/workflow/utilities.py +10 -3
- {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +1 -1
- {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/RECORD +43 -47
- sagemaker/core/huggingface/__init__.py +0 -29
- sagemaker/core/huggingface/llm_utils.py +0 -150
- sagemaker/core/huggingface/processing.py +0 -139
- sagemaker/core/huggingface/training_compiler/__init__.py +0 -0
- sagemaker/core/huggingface/training_compiler/config.py +0 -167
- sagemaker/core/image_uri_config/__init__.py +0 -13
- {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
- {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/top_level.txt +0 -0
sagemaker/__init__.py
ADDED
sagemaker/core/common_utils.py
CHANGED
|
@@ -59,6 +59,7 @@ ALTERNATE_DOMAINS = {
|
|
|
59
59
|
"us-isob-east-1": "sc2s.sgov.gov",
|
|
60
60
|
"us-isof-south-1": "csp.hci.ic.gov",
|
|
61
61
|
"us-isof-east-1": "csp.hci.ic.gov",
|
|
62
|
+
"eu-isoe-west-1": "cloud.adc-e.uk",
|
|
62
63
|
}
|
|
63
64
|
|
|
64
65
|
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
|
|
@@ -74,6 +75,20 @@ DEFAULT_SLEEP_TIME_SECONDS = 10
|
|
|
74
75
|
WAITING_DOT_NUMBER = 10
|
|
75
76
|
MAX_ITEMS = 100
|
|
76
77
|
PAGE_SIZE = 10
|
|
78
|
+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
|
|
79
|
+
|
|
80
|
+
_SENSITIVE_SYSTEM_PATHS = [
|
|
81
|
+
abspath(os.path.expanduser("~/.aws")),
|
|
82
|
+
abspath(os.path.expanduser("~/.ssh")),
|
|
83
|
+
abspath(os.path.expanduser("~/.kube")),
|
|
84
|
+
abspath(os.path.expanduser("~/.docker")),
|
|
85
|
+
abspath(os.path.expanduser("~/.config")),
|
|
86
|
+
abspath(os.path.expanduser("~/.credentials")),
|
|
87
|
+
"/etc",
|
|
88
|
+
"/root",
|
|
89
|
+
"/var/lib",
|
|
90
|
+
"/opt/ml/metadata",
|
|
91
|
+
]
|
|
77
92
|
|
|
78
93
|
logger = logging.getLogger(__name__)
|
|
79
94
|
|
|
@@ -409,6 +424,9 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
|
|
|
409
424
|
|
|
410
425
|
prefix = prefix.lstrip("/")
|
|
411
426
|
|
|
427
|
+
if ".." in prefix:
|
|
428
|
+
raise ValueError("Traversal components are not allowed in S3 path!")
|
|
429
|
+
|
|
412
430
|
# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
|
|
413
431
|
# Do this first, in case the object has broader permissions than the bucket.
|
|
414
432
|
if not prefix.endswith("/"):
|
|
@@ -607,11 +625,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
|
|
|
607
625
|
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
|
|
608
626
|
|
|
609
627
|
|
|
628
|
+
def _validate_source_directory(source_directory):
|
|
629
|
+
"""Validate that source_directory is safe to use.
|
|
630
|
+
|
|
631
|
+
Ensures the source directory path does not access restricted system locations.
|
|
632
|
+
|
|
633
|
+
Args:
|
|
634
|
+
source_directory (str): The source directory path to validate.
|
|
635
|
+
|
|
636
|
+
Raises:
|
|
637
|
+
ValueError: If the path is not allowed.
|
|
638
|
+
"""
|
|
639
|
+
if not source_directory or source_directory.lower().startswith("s3://"):
|
|
640
|
+
# S3 paths and None are safe
|
|
641
|
+
return
|
|
642
|
+
|
|
643
|
+
# Resolve symlinks to get the actual path
|
|
644
|
+
abs_source = abspath(realpath(source_directory))
|
|
645
|
+
|
|
646
|
+
# Check if the source path is under any sensitive directory
|
|
647
|
+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
|
|
648
|
+
if abs_source != "/" and abs_source.startswith(sensitive_path):
|
|
649
|
+
raise ValueError(
|
|
650
|
+
f"source_directory cannot access sensitive system paths. "
|
|
651
|
+
f"Got: {source_directory} (resolved to {abs_source})"
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def _validate_dependency_path(dependency):
|
|
656
|
+
"""Validate that a dependency path is safe to use.
|
|
657
|
+
|
|
658
|
+
Ensures the dependency path does not access restricted system locations.
|
|
659
|
+
|
|
660
|
+
Args:
|
|
661
|
+
dependency (str): The dependency path to validate.
|
|
662
|
+
|
|
663
|
+
Raises:
|
|
664
|
+
ValueError: If the path is not allowed.
|
|
665
|
+
"""
|
|
666
|
+
if not dependency:
|
|
667
|
+
return
|
|
668
|
+
|
|
669
|
+
# Resolve symlinks to get the actual path
|
|
670
|
+
abs_dependency = abspath(realpath(dependency))
|
|
671
|
+
|
|
672
|
+
# Check if the dependency path is under any sensitive directory
|
|
673
|
+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
|
|
674
|
+
if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
|
|
675
|
+
raise ValueError(
|
|
676
|
+
f"dependency path cannot access sensitive system paths. "
|
|
677
|
+
f"Got: {dependency} (resolved to {abs_dependency})"
|
|
678
|
+
)
|
|
679
|
+
|
|
680
|
+
|
|
610
681
|
def _create_or_update_code_dir(
|
|
611
682
|
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
|
|
612
683
|
):
|
|
613
684
|
"""Placeholder docstring"""
|
|
614
685
|
code_dir = os.path.join(model_dir, "code")
|
|
686
|
+
resolved_code_dir = _get_resolved_path(code_dir)
|
|
687
|
+
|
|
688
|
+
# Validate that code_dir does not resolve to a sensitive system path
|
|
689
|
+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
|
|
690
|
+
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
|
|
691
|
+
raise ValueError(
|
|
692
|
+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
|
|
693
|
+
)
|
|
694
|
+
|
|
615
695
|
if source_directory and source_directory.lower().startswith("s3://"):
|
|
616
696
|
local_code_path = os.path.join(tmp, "local_code.tar.gz")
|
|
617
697
|
download_file_from_url(source_directory, local_code_path, sagemaker_session)
|
|
@@ -620,6 +700,8 @@ def _create_or_update_code_dir(
|
|
|
620
700
|
custom_extractall_tarfile(t, code_dir)
|
|
621
701
|
|
|
622
702
|
elif source_directory:
|
|
703
|
+
# Validate source_directory for security
|
|
704
|
+
_validate_source_directory(source_directory)
|
|
623
705
|
if os.path.exists(code_dir):
|
|
624
706
|
shutil.rmtree(code_dir)
|
|
625
707
|
shutil.copytree(source_directory, code_dir)
|
|
@@ -635,6 +717,8 @@ def _create_or_update_code_dir(
|
|
|
635
717
|
raise
|
|
636
718
|
|
|
637
719
|
for dependency in dependencies:
|
|
720
|
+
# Validate dependency path for security
|
|
721
|
+
_validate_dependency_path(dependency)
|
|
638
722
|
lib_dir = os.path.join(code_dir, "lib")
|
|
639
723
|
if os.path.isdir(dependency):
|
|
640
724
|
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
|
|
@@ -1555,7 +1639,7 @@ def get_instance_type_family(instance_type: str) -> str:
|
|
|
1555
1639
|
"""
|
|
1556
1640
|
instance_type_family = ""
|
|
1557
1641
|
if isinstance(instance_type, str):
|
|
1558
|
-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1642
|
+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
|
|
1559
1643
|
if match is not None:
|
|
1560
1644
|
instance_type_family = match[1]
|
|
1561
1645
|
return instance_type_family
|
|
@@ -1646,6 +1730,38 @@ def _get_safe_members(members):
|
|
|
1646
1730
|
yield file_info
|
|
1647
1731
|
|
|
1648
1732
|
|
|
1733
|
+
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation to ensure all extracted files and directories
    are within the intended extraction path. Every entry found by ``os.walk`` is
    resolved with ``_get_resolved_path`` and must either be the extraction root
    itself or live beneath it on a path-component boundary.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file is outside the expected extraction path.
            The offending path is also logged at ERROR level before raising.
    """
    base = _get_resolved_path(extract_path)
    # Children must sit under ``base`` on a separator boundary; a bare
    # ``startswith(base)`` would accept a sibling such as "<base>-evil/...".
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Check directories (symlinked directories are listed here too, and are
        # caught when they resolve outside ``base``)
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            resolved = _get_resolved_path(dir_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted directory escaped extraction path: %s", dir_path)
                raise ValueError(f"Extracted path outside expected directory: {dir_path}")

        # Check files
        for file_name in files:
            file_path = os.path.join(root, file_name)
            resolved = _get_resolved_path(file_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted file escaped extraction path: %s", file_path)
                raise ValueError(f"Extracted path outside expected directory: {file_path}")
|
|
1763
|
+
|
|
1764
|
+
|
|
1649
1765
|
def custom_extractall_tarfile(tar, extract_path):
|
|
1650
1766
|
"""Extract a tarfile, optionally using data_filter if available.
|
|
1651
1767
|
|
|
@@ -1666,6 +1782,8 @@ def custom_extractall_tarfile(tar, extract_path):
|
|
|
1666
1782
|
tar.extractall(path=extract_path, filter="data")
|
|
1667
1783
|
else:
|
|
1668
1784
|
tar.extractall(path=extract_path, members=_get_safe_members(tar))
|
|
1785
|
+
# Re-validate extracted paths to catch symlink race conditions
|
|
1786
|
+
_validate_extracted_paths(extract_path)
|
|
1669
1787
|
|
|
1670
1788
|
|
|
1671
1789
|
def can_model_package_source_uri_autopopulate(source_uri: str):
|
|
@@ -21,6 +21,8 @@ from sagemaker.core.apiutils import _base_types
|
|
|
21
21
|
from sagemaker.core.experiments.trial import _Trial
|
|
22
22
|
from sagemaker.core.experiments.trial_component import _TrialComponent
|
|
23
23
|
from sagemaker.core.common_utils import format_tags
|
|
24
|
+
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
|
|
25
|
+
from sagemaker.core.telemetry.constants import Feature
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class Experiment(_base_types.Record):
|
|
@@ -93,6 +95,7 @@ class Experiment(_base_types.Record):
|
|
|
93
95
|
)
|
|
94
96
|
|
|
95
97
|
@classmethod
|
|
98
|
+
@_telemetry_emitter(feature=Feature.MLOPS, func_name="experiment.create")
|
|
96
99
|
def create(
|
|
97
100
|
cls,
|
|
98
101
|
experiment_name,
|
sagemaker/core/fw_utils.py
CHANGED
|
@@ -25,13 +25,13 @@ from typing import Dict, List, Optional, Union
|
|
|
25
25
|
|
|
26
26
|
from packaging import version
|
|
27
27
|
|
|
28
|
-
import sagemaker.core.common_utils as
|
|
29
|
-
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
|
|
28
|
+
import sagemaker.core.common_utils as utils
|
|
29
|
+
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
|
|
30
30
|
from sagemaker.core.instance_group import InstanceGroup
|
|
31
|
-
from sagemaker.core.s3 import s3_path_join
|
|
31
|
+
from sagemaker.core.s3.utils import s3_path_join
|
|
32
32
|
from sagemaker.core.session_settings import SessionSettings
|
|
33
33
|
from sagemaker.core.workflow import is_pipeline_variable
|
|
34
|
-
from sagemaker.core.
|
|
34
|
+
from sagemaker.core.workflow.entities import PipelineVariable
|
|
35
35
|
|
|
36
36
|
logger = logging.getLogger(__name__)
|
|
37
37
|
|
|
@@ -155,6 +155,9 @@ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
|
|
|
155
155
|
"2.3.1",
|
|
156
156
|
"2.4.1",
|
|
157
157
|
"2.5.1",
|
|
158
|
+
"2.6.0",
|
|
159
|
+
"2.7.1",
|
|
160
|
+
"2.8.0",
|
|
158
161
|
]
|
|
159
162
|
|
|
160
163
|
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
|
|
@@ -455,7 +458,7 @@ def tar_and_upload_dir(
|
|
|
455
458
|
|
|
456
459
|
try:
|
|
457
460
|
source_files = _list_files_to_compress(script, directory) + dependencies
|
|
458
|
-
tar_file =
|
|
461
|
+
tar_file = utils.create_tar_file(
|
|
459
462
|
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
|
|
460
463
|
)
|
|
461
464
|
|
|
@@ -516,7 +519,7 @@ def framework_name_from_image(image_uri):
|
|
|
516
519
|
- str: The image tag
|
|
517
520
|
- str: If the TensorFlow image is script mode
|
|
518
521
|
"""
|
|
519
|
-
sagemaker_pattern = re.compile(
|
|
522
|
+
sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
|
|
520
523
|
sagemaker_match = sagemaker_pattern.match(image_uri)
|
|
521
524
|
if sagemaker_match is None:
|
|
522
525
|
return None, None, None, None
|
|
@@ -595,7 +598,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
|
|
|
595
598
|
"""
|
|
596
599
|
name_from_image = f"/model_code/{int(time.time())}"
|
|
597
600
|
if not is_pipeline_variable(image):
|
|
598
|
-
name_from_image =
|
|
601
|
+
name_from_image = utils.name_from_image(image)
|
|
599
602
|
return s3_path_join(code_location_key_prefix, model_name or name_from_image)
|
|
600
603
|
|
|
601
604
|
|
|
@@ -961,7 +964,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
|
|
|
961
964
|
"""
|
|
962
965
|
err_msg = ""
|
|
963
966
|
if isinstance(instance_type, str):
|
|
964
|
-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
967
|
+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
|
|
965
968
|
if match and match[1].startswith("trn"):
|
|
966
969
|
keys = list(distribution.keys())
|
|
967
970
|
if len(keys) == 0:
|
|
@@ -1062,7 +1065,7 @@ def validate_torch_distributed_distribution(
|
|
|
1062
1065
|
)
|
|
1063
1066
|
|
|
1064
1067
|
# Check entry point type
|
|
1065
|
-
if not entry_point.endswith(".py"):
|
|
1068
|
+
if entry_point is not None and not entry_point.endswith(".py"):
|
|
1066
1069
|
err_msg += (
|
|
1067
1070
|
"Unsupported entry point type for the distribution torch_distributed.\n"
|
|
1068
1071
|
"Only python programs (*.py) are supported."
|
|
@@ -1082,7 +1085,7 @@ def _is_gpu_instance(instance_type):
|
|
|
1082
1085
|
bool: Whether or not the instance_type supports GPU
|
|
1083
1086
|
"""
|
|
1084
1087
|
if isinstance(instance_type, str):
|
|
1085
|
-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1088
|
+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
|
|
1086
1089
|
if match:
|
|
1087
1090
|
if match[1].startswith("p") or match[1].startswith("g"):
|
|
1088
1091
|
return True
|
|
@@ -1101,7 +1104,7 @@ def _is_trainium_instance(instance_type):
|
|
|
1101
1104
|
bool: Whether or not the instance_type is a Trainium instance
|
|
1102
1105
|
"""
|
|
1103
1106
|
if isinstance(instance_type, str):
|
|
1104
|
-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1107
|
+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
|
|
1105
1108
|
if match and match[1].startswith("trn"):
|
|
1106
1109
|
return True
|
|
1107
1110
|
return False
|
|
@@ -1148,7 +1151,7 @@ def _instance_type_supports_profiler(instance_type):
|
|
|
1148
1151
|
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
|
|
1149
1152
|
"""
|
|
1150
1153
|
if isinstance(instance_type, str):
|
|
1151
|
-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
|
|
1154
|
+
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
|
|
1152
1155
|
if match and match[1].startswith("trn"):
|
|
1153
1156
|
return True
|
|
1154
1157
|
return False
|
|
@@ -1174,3 +1177,44 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
|
|
|
1174
1177
|
"framework_version or py_version was None, yet image_uri was also None. "
|
|
1175
1178
|
"Either specify both framework_version and py_version, or specify image_uri."
|
|
1176
1179
|
)
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
def create_image_uri(
    region,
    framework,
    instance_type,
    framework_version,
    py_version=None,
    account=None,  # pylint: disable=W0613
    accelerator_type=None,
    optimized_families=None,  # pylint: disable=W0613
):
    """Deprecated method. Please use sagemaker.image_uris.retrieve().

    Thin compatibility shim: it flags the deprecated name via ``renamed_warning``
    and then delegates to ``sagemaker.core.image_uris.retrieve``.

    Args:
        region (str): AWS region where the image is uploaded.
        framework (str): framework used by the image.
        instance_type (str): SageMaker instance type. Used to determine device
            type (cpu/gpu/family-specific optimized).
        framework_version (str): The version of the framework; passed to
            ``retrieve`` as ``version``.
        py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
            If not specified, image uri will not include a python component.
        account (str): Ignored. Kept only so existing callers do not break; it
            is not forwarded to ``retrieve``.
        accelerator_type (str): SageMaker Elastic Inference accelerator type.
        optimized_families (str): Deprecated. A no-op argument.

    Returns:
        the image uri returned by ``image_uris.retrieve``
    """
    from sagemaker.core import image_uris

    renamed_warning("The method create_image_uri")

    # Map the legacy argument names onto the keywords ``retrieve`` expects.
    # ``account`` and ``optimized_families`` are deliberately not forwarded.
    retrieve_kwargs = {
        "framework": framework,
        "region": region,
        "version": framework_version,
        "py_version": py_version,
        "instance_type": instance_type,
        "accelerator_type": accelerator_type,
    }
    return image_uris.retrieve(**retrieve_kwargs)
|
sagemaker/core/git_utils.py
CHANGED
|
@@ -20,7 +20,69 @@ import tempfile
|
|
|
20
20
|
import warnings
|
|
21
21
|
import six
|
|
22
22
|
from six.moves import urllib
|
|
23
|
+
import re
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from urllib.parse import urlparse
|
|
26
|
+
|
|
27
|
+
def _sanitize_git_url(repo_url):
|
|
28
|
+
"""Sanitize Git repository URL to prevent URL injection attacks.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
repo_url (str): The Git repository URL to sanitize
|
|
23
32
|
|
|
33
|
+
Returns:
|
|
34
|
+
str: The sanitized URL
|
|
35
|
+
|
|
36
|
+
Raises:
|
|
37
|
+
ValueError: If the URL contains suspicious patterns that could indicate injection
|
|
38
|
+
"""
|
|
39
|
+
at_count = repo_url.count("@")
|
|
40
|
+
|
|
41
|
+
if repo_url.startswith("git@"):
|
|
42
|
+
# git@ format requires exactly one @
|
|
43
|
+
if at_count != 1:
|
|
44
|
+
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
|
|
45
|
+
elif repo_url.startswith("ssh://"):
|
|
46
|
+
# ssh:// format can have 0 or 1 @ symbols
|
|
47
|
+
if at_count > 1:
|
|
48
|
+
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
|
|
49
|
+
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
|
|
50
|
+
# HTTPS format allows 0 or 1 @ symbols
|
|
51
|
+
if at_count > 1:
|
|
52
|
+
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")
|
|
53
|
+
|
|
54
|
+
# Check for invalid characters in the URL before parsing
|
|
55
|
+
# These characters should not appear in legitimate URLs
|
|
56
|
+
invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
|
|
57
|
+
for char in invalid_chars:
|
|
58
|
+
if char in repo_url:
|
|
59
|
+
raise ValueError("Invalid characters in hostname")
|
|
60
|
+
|
|
61
|
+
try:
|
|
62
|
+
parsed = urlparse(repo_url)
|
|
63
|
+
|
|
64
|
+
# Check for suspicious characters in hostname that could indicate injection
|
|
65
|
+
if parsed.hostname:
|
|
66
|
+
# Check for URL-encoded characters that might be used for obfuscation
|
|
67
|
+
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
|
|
68
|
+
for pattern in suspicious_patterns:
|
|
69
|
+
if pattern in parsed.hostname.lower():
|
|
70
|
+
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")
|
|
71
|
+
|
|
72
|
+
# Validate that the hostname looks legitimate
|
|
73
|
+
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
|
|
74
|
+
raise ValueError("Invalid characters in hostname")
|
|
75
|
+
|
|
76
|
+
except Exception as e:
|
|
77
|
+
if isinstance(e, ValueError):
|
|
78
|
+
raise
|
|
79
|
+
raise ValueError(f"Failed to parse URL: {str(e)}")
|
|
80
|
+
else:
|
|
81
|
+
raise ValueError(
|
|
82
|
+
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return repo_url
|
|
24
86
|
|
|
25
87
|
def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
|
|
26
88
|
"""Git clone repo containing the training code and serving code.
|
|
@@ -87,6 +149,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
|
|
|
87
149
|
if entry_point is None:
|
|
88
150
|
raise ValueError("Please provide an entry point.")
|
|
89
151
|
_validate_git_config(git_config)
|
|
152
|
+
|
|
153
|
+
# SECURITY: Sanitize the repository URL to prevent injection attacks
|
|
154
|
+
git_config["repo"] = _sanitize_git_url(git_config["repo"])
|
|
155
|
+
|
|
90
156
|
dest_dir = tempfile.mkdtemp()
|
|
91
157
|
_generate_and_run_clone_command(git_config, dest_dir)
|
|
92
158
|
|
|
@@ -330,16 +330,16 @@ class Session(object): # pylint: disable=too-many-public-methods
|
|
|
330
330
|
user_profile_name = metadata.get("UserProfileName")
|
|
331
331
|
execution_role_arn = metadata.get("ExecutionRoleArn")
|
|
332
332
|
try:
|
|
333
|
+
# find execution role from the metadata file if present
|
|
334
|
+
if execution_role_arn is not None:
|
|
335
|
+
return execution_role_arn
|
|
336
|
+
|
|
333
337
|
if domain_id is None:
|
|
334
338
|
instance_desc = self.sagemaker_client.describe_notebook_instance(
|
|
335
339
|
NotebookInstanceName=instance_name
|
|
336
340
|
)
|
|
337
341
|
return instance_desc["RoleArn"]
|
|
338
342
|
|
|
339
|
-
# find execution role from the metadata file if present
|
|
340
|
-
if execution_role_arn is not None:
|
|
341
|
-
return execution_role_arn
|
|
342
|
-
|
|
343
343
|
user_profile_desc = self.sagemaker_client.describe_user_profile(
|
|
344
344
|
DomainId=domain_id, UserProfileName=user_profile_name
|
|
345
345
|
)
|
|
@@ -666,9 +666,16 @@ class Session(object): # pylint: disable=too-many-public-methods
|
|
|
666
666
|
|
|
667
667
|
"""
|
|
668
668
|
try:
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
669
|
+
if self.default_bucket_prefix:
|
|
670
|
+
s3.meta.client.list_objects_v2(
|
|
671
|
+
Bucket=bucket_name,
|
|
672
|
+
Prefix=self.default_bucket_prefix,
|
|
673
|
+
ExpectedBucketOwner=expected_bucket_owner_id,
|
|
674
|
+
)
|
|
675
|
+
else:
|
|
676
|
+
s3.meta.client.head_bucket(
|
|
677
|
+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
|
|
678
|
+
)
|
|
672
679
|
except ClientError as e:
|
|
673
680
|
error_code = e.response["Error"]["Code"]
|
|
674
681
|
message = e.response["Error"]["Message"]
|
|
@@ -699,7 +706,12 @@ class Session(object): # pylint: disable=too-many-public-methods
|
|
|
699
706
|
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
|
|
700
707
|
"""
|
|
701
708
|
try:
|
|
702
|
-
|
|
709
|
+
if self.default_bucket_prefix:
|
|
710
|
+
s3.meta.client.list_objects_v2(
|
|
711
|
+
Bucket=bucket_name, Prefix=self.default_bucket_prefix
|
|
712
|
+
)
|
|
713
|
+
else:
|
|
714
|
+
s3.meta.client.head_bucket(Bucket=bucket_name)
|
|
703
715
|
except ClientError as e:
|
|
704
716
|
error_code = e.response["Error"]["Code"]
|
|
705
717
|
message = e.response["Error"]["Message"]
|
|
@@ -1865,7 +1877,7 @@ class Session(object): # pylint: disable=too-many-public-methods
|
|
|
1865
1877
|
if "/" in role:
|
|
1866
1878
|
return role
|
|
1867
1879
|
return self.boto_session.resource("iam").Role(role).arn
|
|
1868
|
-
|
|
1880
|
+
|
|
1869
1881
|
|
|
1870
1882
|
def _expand_container_def(c_def):
|
|
1871
1883
|
"""Placeholder docstring"""
|
|
@@ -2962,4 +2974,4 @@ def container_def(
|
|
|
2962
2974
|
c_def["Mode"] = container_mode
|
|
2963
2975
|
if image_config:
|
|
2964
2976
|
c_def["ImageConfig"] = image_config
|
|
2965
|
-
return c_def
|
|
2977
|
+
return c_def
|
|
@@ -185,9 +185,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
|
|
|
185
185
|
|
|
186
186
|
def config_for_framework(framework):
|
|
187
187
|
"""Loads the JSON config for the given framework."""
|
|
188
|
-
|
|
189
|
-
return response.json()
|
|
190
|
-
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
|
|
188
|
+
fname = os.path.join(os.path.dirname(__file__), "..", "image_uri_config", "{}.json".format(framework))
|
|
191
189
|
with open(fname) as f:
|
|
192
190
|
return json.load(f)
|
|
193
191
|
|
|
@@ -4,7 +4,9 @@
|
|
|
4
4
|
"inf2"
|
|
5
5
|
],
|
|
6
6
|
"version_aliases": {
|
|
7
|
-
"0.0": "0.0.28"
|
|
7
|
+
"0.0": "0.0.28",
|
|
8
|
+
"0.2": "0.2.0",
|
|
9
|
+
"0.3": "0.3.0"
|
|
8
10
|
},
|
|
9
11
|
"versions": {
|
|
10
12
|
"0.0.16": {
|
|
@@ -654,6 +656,114 @@
|
|
|
654
656
|
"container_version": {
|
|
655
657
|
"inf2": "ubuntu22.04"
|
|
656
658
|
}
|
|
659
|
+
},
|
|
660
|
+
"0.2.0": {
|
|
661
|
+
"py_versions": [
|
|
662
|
+
"py310"
|
|
663
|
+
],
|
|
664
|
+
"registries": {
|
|
665
|
+
"af-south-1": "626614931356",
|
|
666
|
+
"ap-east-1": "871362719292",
|
|
667
|
+
"ap-east-2": "975050140332",
|
|
668
|
+
"ap-northeast-1": "763104351884",
|
|
669
|
+
"ap-northeast-2": "763104351884",
|
|
670
|
+
"ap-northeast-3": "364406365360",
|
|
671
|
+
"ap-south-1": "763104351884",
|
|
672
|
+
"ap-south-2": "772153158452",
|
|
673
|
+
"ap-southeast-1": "763104351884",
|
|
674
|
+
"ap-southeast-2": "763104351884",
|
|
675
|
+
"ap-southeast-3": "907027046896",
|
|
676
|
+
"ap-southeast-4": "457447274322",
|
|
677
|
+
"ap-southeast-5": "550225433462",
|
|
678
|
+
"ap-southeast-6": "633930458069",
|
|
679
|
+
"ap-southeast-7": "590183813437",
|
|
680
|
+
"ca-central-1": "763104351884",
|
|
681
|
+
"ca-west-1": "204538143572",
|
|
682
|
+
"cn-north-1": "727897471807",
|
|
683
|
+
"cn-northwest-1": "727897471807",
|
|
684
|
+
"eu-central-1": "763104351884",
|
|
685
|
+
"eu-central-2": "380420809688",
|
|
686
|
+
"eu-north-1": "763104351884",
|
|
687
|
+
"eu-south-1": "692866216735",
|
|
688
|
+
"eu-south-2": "503227376785",
|
|
689
|
+
"eu-west-1": "763104351884",
|
|
690
|
+
"eu-west-2": "763104351884",
|
|
691
|
+
"eu-west-3": "763104351884",
|
|
692
|
+
"il-central-1": "780543022126",
|
|
693
|
+
"me-central-1": "914824155844",
|
|
694
|
+
"me-south-1": "217643126080",
|
|
695
|
+
"mx-central-1": "637423239942",
|
|
696
|
+
"sa-east-1": "763104351884",
|
|
697
|
+
"us-east-1": "763104351884",
|
|
698
|
+
"us-east-2": "763104351884",
|
|
699
|
+
"us-gov-east-1": "446045086412",
|
|
700
|
+
"us-gov-west-1": "442386744353",
|
|
701
|
+
"us-iso-east-1": "886529160074",
|
|
702
|
+
"us-isob-east-1": "094389454867",
|
|
703
|
+
"us-isof-east-1": "303241398832",
|
|
704
|
+
"us-isof-south-1": "454834333376",
|
|
705
|
+
"us-west-1": "763104351884",
|
|
706
|
+
"us-west-2": "763104351884"
|
|
707
|
+
},
|
|
708
|
+
"tag_prefix": "2.5.1-optimum3.3.4",
|
|
709
|
+
"repository": "huggingface-pytorch-tgi-inference",
|
|
710
|
+
"container_version": {
|
|
711
|
+
"inf2": "ubuntu22.04"
|
|
712
|
+
}
|
|
713
|
+
},
|
|
714
|
+
"0.3.0": {
|
|
715
|
+
"py_versions": [
|
|
716
|
+
"py310"
|
|
717
|
+
],
|
|
718
|
+
"registries": {
|
|
719
|
+
"af-south-1": "626614931356",
|
|
720
|
+
"ap-east-1": "871362719292",
|
|
721
|
+
"ap-east-2": "975050140332",
|
|
722
|
+
"ap-northeast-1": "763104351884",
|
|
723
|
+
"ap-northeast-2": "763104351884",
|
|
724
|
+
"ap-northeast-3": "364406365360",
|
|
725
|
+
"ap-south-1": "763104351884",
|
|
726
|
+
"ap-south-2": "772153158452",
|
|
727
|
+
"ap-southeast-1": "763104351884",
|
|
728
|
+
"ap-southeast-2": "763104351884",
|
|
729
|
+
"ap-southeast-3": "907027046896",
|
|
730
|
+
"ap-southeast-4": "457447274322",
|
|
731
|
+
"ap-southeast-5": "550225433462",
|
|
732
|
+
"ap-southeast-6": "633930458069",
|
|
733
|
+
"ap-southeast-7": "590183813437",
|
|
734
|
+
"ca-central-1": "763104351884",
|
|
735
|
+
"ca-west-1": "204538143572",
|
|
736
|
+
"cn-north-1": "727897471807",
|
|
737
|
+
"cn-northwest-1": "727897471807",
|
|
738
|
+
"eu-central-1": "763104351884",
|
|
739
|
+
"eu-central-2": "380420809688",
|
|
740
|
+
"eu-north-1": "763104351884",
|
|
741
|
+
"eu-south-1": "692866216735",
|
|
742
|
+
"eu-south-2": "503227376785",
|
|
743
|
+
"eu-west-1": "763104351884",
|
|
744
|
+
"eu-west-2": "763104351884",
|
|
745
|
+
"eu-west-3": "763104351884",
|
|
746
|
+
"il-central-1": "780543022126",
|
|
747
|
+
"me-central-1": "914824155844",
|
|
748
|
+
"me-south-1": "217643126080",
|
|
749
|
+
"mx-central-1": "637423239942",
|
|
750
|
+
"sa-east-1": "763104351884",
|
|
751
|
+
"us-east-1": "763104351884",
|
|
752
|
+
"us-east-2": "763104351884",
|
|
753
|
+
"us-gov-east-1": "446045086412",
|
|
754
|
+
"us-gov-west-1": "442386744353",
|
|
755
|
+
"us-iso-east-1": "886529160074",
|
|
756
|
+
"us-isob-east-1": "094389454867",
|
|
757
|
+
"us-isof-east-1": "303241398832",
|
|
758
|
+
"us-isof-south-1": "454834333376",
|
|
759
|
+
"us-west-1": "763104351884",
|
|
760
|
+
"us-west-2": "763104351884"
|
|
761
|
+
},
|
|
762
|
+
"tag_prefix": "2.7.0-optimum3.3.6",
|
|
763
|
+
"repository": "huggingface-pytorch-tgi-inference",
|
|
764
|
+
"container_version": {
|
|
765
|
+
"inf2": "ubuntu22.04"
|
|
766
|
+
}
|
|
657
767
|
}
|
|
658
768
|
}
|
|
659
769
|
}
|