sagemaker-core 2.1.1__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/common_utils.py +119 -1
  3. sagemaker/core/experiments/experiment.py +3 -0
  4. sagemaker/core/fw_utils.py +56 -12
  5. sagemaker/core/git_utils.py +66 -0
  6. sagemaker/core/helper/session_helper.py +22 -10
  7. sagemaker/core/image_retriever/image_retriever_utils.py +1 -3
  8. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +111 -1
  9. sagemaker/core/image_uri_config/huggingface-llm.json +110 -1
  10. sagemaker/core/image_uri_config/huggingface-neuronx.json +182 -6
  11. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  12. sagemaker/core/image_uri_config/huggingface.json +151 -2
  13. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +40 -0
  14. sagemaker/core/image_uri_config/sklearn.json +48 -0
  15. sagemaker/core/image_uri_config/xgboost.json +84 -0
  16. sagemaker/core/image_uris.py +9 -3
  17. sagemaker/core/iterators.py +11 -0
  18. sagemaker/core/jumpstart/models.py +2 -0
  19. sagemaker/core/jumpstart/region_config.json +8 -0
  20. sagemaker/core/local/data.py +10 -0
  21. sagemaker/core/local/utils.py +6 -5
  22. sagemaker/core/model_monitor/clarify_model_monitoring.py +2 -0
  23. sagemaker/core/model_registry.py +1 -1
  24. sagemaker/core/modules/configs.py +14 -1
  25. sagemaker/core/modules/train/container_drivers/common/utils.py +2 -10
  26. sagemaker/core/modules/train/sm_recipes/utils.py +1 -1
  27. sagemaker/core/processing.py +2 -0
  28. sagemaker/core/remote_function/client.py +31 -6
  29. sagemaker/core/remote_function/core/pipeline_variables.py +0 -6
  30. sagemaker/core/remote_function/core/serialization.py +16 -28
  31. sagemaker/core/remote_function/core/stored_function.py +8 -11
  32. sagemaker/core/remote_function/errors.py +1 -3
  33. sagemaker/core/remote_function/invoke_function.py +1 -6
  34. sagemaker/core/remote_function/job.py +2 -21
  35. sagemaker/core/telemetry/constants.py +6 -8
  36. sagemaker/core/telemetry/telemetry_logging.py +6 -5
  37. sagemaker/core/training/configs.py +16 -4
  38. sagemaker/core/workflow/utilities.py +10 -3
  39. {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +1 -1
  40. {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/RECORD +43 -47
  41. sagemaker/core/huggingface/__init__.py +0 -29
  42. sagemaker/core/huggingface/llm_utils.py +0 -150
  43. sagemaker/core/huggingface/processing.py +0 -139
  44. sagemaker/core/huggingface/training_compiler/__init__.py +0 -0
  45. sagemaker/core/huggingface/training_compiler/config.py +0 -167
  46. sagemaker/core/image_uri_config/__init__.py +0 -13
  47. {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  48. {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
  49. {sagemaker_core-2.1.1.dist-info → sagemaker_core-2.3.1.dist-info}/top_level.txt +0 -0
sagemaker/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """Namespace package for SageMaker."""
2
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
@@ -59,6 +59,7 @@ ALTERNATE_DOMAINS = {
59
59
  "us-isob-east-1": "sc2s.sgov.gov",
60
60
  "us-isof-south-1": "csp.hci.ic.gov",
61
61
  "us-isof-east-1": "csp.hci.ic.gov",
62
+ "eu-isoe-west-1": "cloud.adc-e.uk",
62
63
  }
63
64
 
64
65
  ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -74,6 +75,20 @@ DEFAULT_SLEEP_TIME_SECONDS = 10
74
75
  WAITING_DOT_NUMBER = 10
75
76
  MAX_ITEMS = 100
76
77
  PAGE_SIZE = 10
78
+ _MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
79
+
80
+ _SENSITIVE_SYSTEM_PATHS = [
81
+ abspath(os.path.expanduser("~/.aws")),
82
+ abspath(os.path.expanduser("~/.ssh")),
83
+ abspath(os.path.expanduser("~/.kube")),
84
+ abspath(os.path.expanduser("~/.docker")),
85
+ abspath(os.path.expanduser("~/.config")),
86
+ abspath(os.path.expanduser("~/.credentials")),
87
+ "/etc",
88
+ "/root",
89
+ "/var/lib",
90
+ "/opt/ml/metadata",
91
+ ]
77
92
 
78
93
  logger = logging.getLogger(__name__)
79
94
 
@@ -409,6 +424,9 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
409
424
 
410
425
  prefix = prefix.lstrip("/")
411
426
 
427
+ if ".." in prefix:
428
+ raise ValueError("Traversal components are not allowed in S3 path!")
429
+
412
430
  # Try to download the prefix as an object first, in case it is a file and not a 'directory'.
413
431
  # Do this first, in case the object has broader permissions than the bucket.
414
432
  if not prefix.endswith("/"):
@@ -607,11 +625,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
607
625
  shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
608
626
 
609
627
 
628
+ def _validate_source_directory(source_directory):
629
+ """Validate that source_directory is safe to use.
630
+
631
+ Ensures the source directory path does not access restricted system locations.
632
+
633
+ Args:
634
+ source_directory (str): The source directory path to validate.
635
+
636
+ Raises:
637
+ ValueError: If the path is not allowed.
638
+ """
639
+ if not source_directory or source_directory.lower().startswith("s3://"):
640
+ # S3 paths and None are safe
641
+ return
642
+
643
+ # Resolve symlinks to get the actual path
644
+ abs_source = abspath(realpath(source_directory))
645
+
646
+ # Check if the source path is under any sensitive directory
647
+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
648
+ if abs_source != "/" and abs_source.startswith(sensitive_path):
649
+ raise ValueError(
650
+ f"source_directory cannot access sensitive system paths. "
651
+ f"Got: {source_directory} (resolved to {abs_source})"
652
+ )
653
+
654
+
655
+ def _validate_dependency_path(dependency):
656
+ """Validate that a dependency path is safe to use.
657
+
658
+ Ensures the dependency path does not access restricted system locations.
659
+
660
+ Args:
661
+ dependency (str): The dependency path to validate.
662
+
663
+ Raises:
664
+ ValueError: If the path is not allowed.
665
+ """
666
+ if not dependency:
667
+ return
668
+
669
+ # Resolve symlinks to get the actual path
670
+ abs_dependency = abspath(realpath(dependency))
671
+
672
+ # Check if the dependency path is under any sensitive directory
673
+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
674
+ if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
675
+ raise ValueError(
676
+ f"dependency path cannot access sensitive system paths. "
677
+ f"Got: {dependency} (resolved to {abs_dependency})"
678
+ )
679
+
680
+
610
681
  def _create_or_update_code_dir(
611
682
  model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
612
683
  ):
613
684
  """Placeholder docstring"""
614
685
  code_dir = os.path.join(model_dir, "code")
686
+ resolved_code_dir = _get_resolved_path(code_dir)
687
+
688
+ # Validate that code_dir does not resolve to a sensitive system path
689
+ for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
690
+ if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
691
+ raise ValueError(
692
+ f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
693
+ )
694
+
615
695
  if source_directory and source_directory.lower().startswith("s3://"):
616
696
  local_code_path = os.path.join(tmp, "local_code.tar.gz")
617
697
  download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -620,6 +700,8 @@ def _create_or_update_code_dir(
620
700
  custom_extractall_tarfile(t, code_dir)
621
701
 
622
702
  elif source_directory:
703
+ # Validate source_directory for security
704
+ _validate_source_directory(source_directory)
623
705
  if os.path.exists(code_dir):
624
706
  shutil.rmtree(code_dir)
625
707
  shutil.copytree(source_directory, code_dir)
@@ -635,6 +717,8 @@ def _create_or_update_code_dir(
635
717
  raise
636
718
 
637
719
  for dependency in dependencies:
720
+ # Validate dependency path for security
721
+ _validate_dependency_path(dependency)
638
722
  lib_dir = os.path.join(code_dir, "lib")
639
723
  if os.path.isdir(dependency):
640
724
  shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1555,7 +1639,7 @@ def get_instance_type_family(instance_type: str) -> str:
1555
1639
  """
1556
1640
  instance_type_family = ""
1557
1641
  if isinstance(instance_type, str):
1558
- match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1642
+ match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
1559
1643
  if match is not None:
1560
1644
  instance_type_family = match[1]
1561
1645
  return instance_type_family
@@ -1646,6 +1730,38 @@ def _get_safe_members(members):
1646
1730
  yield file_info
1647
1731
 
1648
1732
 
1733
+ def _validate_extracted_paths(extract_path):
1734
+ """Validate that extracted paths remain within the expected directory.
1735
+
1736
+ Performs post-extraction validation to ensure all extracted files and directories
1737
+ are within the intended extraction path.
1738
+
1739
+ Args:
1740
+ extract_path (str): The path where files were extracted.
1741
+
1742
+ Raises:
1743
+ ValueError: If any extracted file is outside the expected extraction path.
1744
+ """
1745
+ base = _get_resolved_path(extract_path)
1746
+
1747
+ for root, dirs, files in os.walk(extract_path):
1748
+ # Check directories
1749
+ for dir_name in dirs:
1750
+ dir_path = os.path.join(root, dir_name)
1751
+ resolved = _get_resolved_path(dir_path)
1752
+ if not resolved.startswith(base):
1753
+ logger.error("Extracted directory escaped extraction path: %s", dir_path)
1754
+ raise ValueError(f"Extracted path outside expected directory: {dir_path}")
1755
+
1756
+ # Check files
1757
+ for file_name in files:
1758
+ file_path = os.path.join(root, file_name)
1759
+ resolved = _get_resolved_path(file_path)
1760
+ if not resolved.startswith(base):
1761
+ logger.error("Extracted file escaped extraction path: %s", file_path)
1762
+ raise ValueError(f"Extracted path outside expected directory: {file_path}")
1763
+
1764
+
1649
1765
  def custom_extractall_tarfile(tar, extract_path):
1650
1766
  """Extract a tarfile, optionally using data_filter if available.
1651
1767
 
@@ -1666,6 +1782,8 @@ def custom_extractall_tarfile(tar, extract_path):
1666
1782
  tar.extractall(path=extract_path, filter="data")
1667
1783
  else:
1668
1784
  tar.extractall(path=extract_path, members=_get_safe_members(tar))
1785
+ # Re-validate extracted paths to catch symlink race conditions
1786
+ _validate_extracted_paths(extract_path)
1669
1787
 
1670
1788
 
1671
1789
  def can_model_package_source_uri_autopopulate(source_uri: str):
@@ -21,6 +21,8 @@ from sagemaker.core.apiutils import _base_types
21
21
  from sagemaker.core.experiments.trial import _Trial
22
22
  from sagemaker.core.experiments.trial_component import _TrialComponent
23
23
  from sagemaker.core.common_utils import format_tags
24
+ from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
25
+ from sagemaker.core.telemetry.constants import Feature
24
26
 
25
27
 
26
28
  class Experiment(_base_types.Record):
@@ -93,6 +95,7 @@ class Experiment(_base_types.Record):
93
95
  )
94
96
 
95
97
  @classmethod
98
+ @_telemetry_emitter(feature=Feature.MLOPS, func_name="experiment.create")
96
99
  def create(
97
100
  cls,
98
101
  experiment_name,
@@ -25,13 +25,13 @@ from typing import Dict, List, Optional, Union
25
25
 
26
26
  from packaging import version
27
27
 
28
- import sagemaker.core.common_utils as sagemaker_utils
29
- from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
28
+ import sagemaker.core.common_utils as utils
29
+ from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
30
30
  from sagemaker.core.instance_group import InstanceGroup
31
- from sagemaker.core.s3 import s3_path_join
31
+ from sagemaker.core.s3.utils import s3_path_join
32
32
  from sagemaker.core.session_settings import SessionSettings
33
33
  from sagemaker.core.workflow import is_pipeline_variable
34
- from sagemaker.core.helper.pipeline_variable import PipelineVariable
34
+ from sagemaker.core.workflow.entities import PipelineVariable
35
35
 
36
36
  logger = logging.getLogger(__name__)
37
37
 
@@ -155,6 +155,9 @@ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
155
155
  "2.3.1",
156
156
  "2.4.1",
157
157
  "2.5.1",
158
+ "2.6.0",
159
+ "2.7.1",
160
+ "2.8.0",
158
161
  ]
159
162
 
160
163
  TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -455,7 +458,7 @@ def tar_and_upload_dir(
455
458
 
456
459
  try:
457
460
  source_files = _list_files_to_compress(script, directory) + dependencies
458
- tar_file = sagemaker_utils.create_tar_file(
461
+ tar_file = utils.create_tar_file(
459
462
  source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
460
463
  )
461
464
 
@@ -516,7 +519,7 @@ def framework_name_from_image(image_uri):
516
519
  - str: The image tag
517
520
  - str: If the TensorFlow image is script mode
518
521
  """
519
- sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
522
+ sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
520
523
  sagemaker_match = sagemaker_pattern.match(image_uri)
521
524
  if sagemaker_match is None:
522
525
  return None, None, None, None
@@ -595,7 +598,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
595
598
  """
596
599
  name_from_image = f"/model_code/{int(time.time())}"
597
600
  if not is_pipeline_variable(image):
598
- name_from_image = sagemaker_utils.name_from_image(image)
601
+ name_from_image = utils.name_from_image(image)
599
602
  return s3_path_join(code_location_key_prefix, model_name or name_from_image)
600
603
 
601
604
 
@@ -961,7 +964,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
961
964
  """
962
965
  err_msg = ""
963
966
  if isinstance(instance_type, str):
964
- match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
967
+ match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
965
968
  if match and match[1].startswith("trn"):
966
969
  keys = list(distribution.keys())
967
970
  if len(keys) == 0:
@@ -1062,7 +1065,7 @@ def validate_torch_distributed_distribution(
1062
1065
  )
1063
1066
 
1064
1067
  # Check entry point type
1065
- if not entry_point.endswith(".py"):
1068
+ if entry_point is not None and not entry_point.endswith(".py"):
1066
1069
  err_msg += (
1067
1070
  "Unsupported entry point type for the distribution torch_distributed.\n"
1068
1071
  "Only python programs (*.py) are supported."
@@ -1082,7 +1085,7 @@ def _is_gpu_instance(instance_type):
1082
1085
  bool: Whether or not the instance_type supports GPU
1083
1086
  """
1084
1087
  if isinstance(instance_type, str):
1085
- match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1088
+ match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
1086
1089
  if match:
1087
1090
  if match[1].startswith("p") or match[1].startswith("g"):
1088
1091
  return True
@@ -1101,7 +1104,7 @@ def _is_trainium_instance(instance_type):
1101
1104
  bool: Whether or not the instance_type is a Trainium instance
1102
1105
  """
1103
1106
  if isinstance(instance_type, str):
1104
- match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1107
+ match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
1105
1108
  if match and match[1].startswith("trn"):
1106
1109
  return True
1107
1110
  return False
@@ -1148,7 +1151,7 @@ def _instance_type_supports_profiler(instance_type):
1148
1151
  bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1149
1152
  """
1150
1153
  if isinstance(instance_type, str):
1151
- match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1154
+ match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
1152
1155
  if match and match[1].startswith("trn"):
1153
1156
  return True
1154
1157
  return False
@@ -1174,3 +1177,44 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
1174
1177
  "framework_version or py_version was None, yet image_uri was also None. "
1175
1178
  "Either specify both framework_version and py_version, or specify image_uri."
1176
1179
  )
1180
+
1181
+
1182
+ def create_image_uri(
1183
+ region,
1184
+ framework,
1185
+ instance_type,
1186
+ framework_version,
1187
+ py_version=None,
1188
+ account=None, # pylint: disable=W0613
1189
+ accelerator_type=None,
1190
+ optimized_families=None, # pylint: disable=W0613
1191
+ ):
1192
+ """Deprecated method. Please use sagemaker.image_uris.retrieve().
1193
+
1194
+ Args:
1195
+ region (str): AWS region where the image is uploaded.
1196
+ framework (str): framework used by the image.
1197
+ instance_type (str): SageMaker instance type. Used to determine device
1198
+ type (cpu/gpu/family-specific optimized).
1199
+ framework_version (str): The version of the framework.
1200
+ py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1201
+ If not specified, image uri will not include a python component.
1202
+ account (str): AWS account that contains the image. (default:
1203
+ '520713654638')
1204
+ accelerator_type (str): SageMaker Elastic Inference accelerator type.
1205
+ optimized_families (str): Deprecated. A no-op argument.
1206
+
1207
+ Returns:
1208
+ the image uri
1209
+ """
1210
+ from sagemaker.core import image_uris
1211
+
1212
+ renamed_warning("The method create_image_uri")
1213
+ return image_uris.retrieve(
1214
+ framework=framework,
1215
+ region=region,
1216
+ version=framework_version,
1217
+ py_version=py_version,
1218
+ instance_type=instance_type,
1219
+ accelerator_type=accelerator_type,
1220
+ )
@@ -20,7 +20,69 @@ import tempfile
20
20
  import warnings
21
21
  import six
22
22
  from six.moves import urllib
23
+ import re
24
+ from pathlib import Path
25
+ from urllib.parse import urlparse
26
+
27
+ def _sanitize_git_url(repo_url):
28
+ """Sanitize Git repository URL to prevent URL injection attacks.
29
+
30
+ Args:
31
+ repo_url (str): The Git repository URL to sanitize
23
32
 
33
+ Returns:
34
+ str: The sanitized URL
35
+
36
+ Raises:
37
+ ValueError: If the URL contains suspicious patterns that could indicate injection
38
+ """
39
+ at_count = repo_url.count("@")
40
+
41
+ if repo_url.startswith("git@"):
42
+ # git@ format requires exactly one @
43
+ if at_count != 1:
44
+ raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
45
+ elif repo_url.startswith("ssh://"):
46
+ # ssh:// format can have 0 or 1 @ symbols
47
+ if at_count > 1:
48
+ raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
49
+ elif repo_url.startswith("https://") or repo_url.startswith("http://"):
50
+ # HTTPS format allows 0 or 1 @ symbols
51
+ if at_count > 1:
52
+ raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")
53
+
54
+ # Check for invalid characters in the URL before parsing
55
+ # These characters should not appear in legitimate URLs
56
+ invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
57
+ for char in invalid_chars:
58
+ if char in repo_url:
59
+ raise ValueError("Invalid characters in hostname")
60
+
61
+ try:
62
+ parsed = urlparse(repo_url)
63
+
64
+ # Check for suspicious characters in hostname that could indicate injection
65
+ if parsed.hostname:
66
+ # Check for URL-encoded characters that might be used for obfuscation
67
+ suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
68
+ for pattern in suspicious_patterns:
69
+ if pattern in parsed.hostname.lower():
70
+ raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")
71
+
72
+ # Validate that the hostname looks legitimate
73
+ if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
74
+ raise ValueError("Invalid characters in hostname")
75
+
76
+ except Exception as e:
77
+ if isinstance(e, ValueError):
78
+ raise
79
+ raise ValueError(f"Failed to parse URL: {str(e)}")
80
+ else:
81
+ raise ValueError(
82
+ "Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
83
+ )
84
+
85
+ return repo_url
24
86
 
25
87
  def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
26
88
  """Git clone repo containing the training code and serving code.
@@ -87,6 +149,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
87
149
  if entry_point is None:
88
150
  raise ValueError("Please provide an entry point.")
89
151
  _validate_git_config(git_config)
152
+
153
+ # SECURITY: Sanitize the repository URL to prevent injection attacks
154
+ git_config["repo"] = _sanitize_git_url(git_config["repo"])
155
+
90
156
  dest_dir = tempfile.mkdtemp()
91
157
  _generate_and_run_clone_command(git_config, dest_dir)
92
158
 
@@ -330,16 +330,16 @@ class Session(object): # pylint: disable=too-many-public-methods
330
330
  user_profile_name = metadata.get("UserProfileName")
331
331
  execution_role_arn = metadata.get("ExecutionRoleArn")
332
332
  try:
333
+ # find execution role from the metadata file if present
334
+ if execution_role_arn is not None:
335
+ return execution_role_arn
336
+
333
337
  if domain_id is None:
334
338
  instance_desc = self.sagemaker_client.describe_notebook_instance(
335
339
  NotebookInstanceName=instance_name
336
340
  )
337
341
  return instance_desc["RoleArn"]
338
342
 
339
- # find execution role from the metadata file if present
340
- if execution_role_arn is not None:
341
- return execution_role_arn
342
-
343
343
  user_profile_desc = self.sagemaker_client.describe_user_profile(
344
344
  DomainId=domain_id, UserProfileName=user_profile_name
345
345
  )
@@ -666,9 +666,16 @@ class Session(object): # pylint: disable=too-many-public-methods
666
666
 
667
667
  """
668
668
  try:
669
- s3.meta.client.head_bucket(
670
- Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
671
- )
669
+ if self.default_bucket_prefix:
670
+ s3.meta.client.list_objects_v2(
671
+ Bucket=bucket_name,
672
+ Prefix=self.default_bucket_prefix,
673
+ ExpectedBucketOwner=expected_bucket_owner_id,
674
+ )
675
+ else:
676
+ s3.meta.client.head_bucket(
677
+ Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
678
+ )
672
679
  except ClientError as e:
673
680
  error_code = e.response["Error"]["Code"]
674
681
  message = e.response["Error"]["Message"]
@@ -699,7 +706,12 @@ class Session(object): # pylint: disable=too-many-public-methods
699
706
  bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
700
707
  """
701
708
  try:
702
- s3.meta.client.head_bucket(Bucket=bucket_name)
709
+ if self.default_bucket_prefix:
710
+ s3.meta.client.list_objects_v2(
711
+ Bucket=bucket_name, Prefix=self.default_bucket_prefix
712
+ )
713
+ else:
714
+ s3.meta.client.head_bucket(Bucket=bucket_name)
703
715
  except ClientError as e:
704
716
  error_code = e.response["Error"]["Code"]
705
717
  message = e.response["Error"]["Message"]
@@ -1865,7 +1877,7 @@ class Session(object): # pylint: disable=too-many-public-methods
1865
1877
  if "/" in role:
1866
1878
  return role
1867
1879
  return self.boto_session.resource("iam").Role(role).arn
1868
-
1880
+
1869
1881
 
1870
1882
  def _expand_container_def(c_def):
1871
1883
  """Placeholder docstring"""
@@ -2962,4 +2974,4 @@ def container_def(
2962
2974
  c_def["Mode"] = container_mode
2963
2975
  if image_config:
2964
2976
  c_def["ImageConfig"] = image_config
2965
- return c_def
2977
+ return c_def
@@ -185,9 +185,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
185
185
 
186
186
  def config_for_framework(framework):
187
187
  """Loads the JSON config for the given framework."""
188
- response = requests.get(s3_url)
189
- return response.json()
190
- fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
188
+ fname = os.path.join(os.path.dirname(__file__), "..", "image_uri_config", "{}.json".format(framework))
191
189
  with open(fname) as f:
192
190
  return json.load(f)
193
191
 
@@ -4,7 +4,9 @@
4
4
  "inf2"
5
5
  ],
6
6
  "version_aliases": {
7
- "0.0": "0.0.28"
7
+ "0.0": "0.0.28",
8
+ "0.2": "0.2.0",
9
+ "0.3": "0.3.0"
8
10
  },
9
11
  "versions": {
10
12
  "0.0.16": {
@@ -654,6 +656,114 @@
654
656
  "container_version": {
655
657
  "inf2": "ubuntu22.04"
656
658
  }
659
+ },
660
+ "0.2.0": {
661
+ "py_versions": [
662
+ "py310"
663
+ ],
664
+ "registries": {
665
+ "af-south-1": "626614931356",
666
+ "ap-east-1": "871362719292",
667
+ "ap-east-2": "975050140332",
668
+ "ap-northeast-1": "763104351884",
669
+ "ap-northeast-2": "763104351884",
670
+ "ap-northeast-3": "364406365360",
671
+ "ap-south-1": "763104351884",
672
+ "ap-south-2": "772153158452",
673
+ "ap-southeast-1": "763104351884",
674
+ "ap-southeast-2": "763104351884",
675
+ "ap-southeast-3": "907027046896",
676
+ "ap-southeast-4": "457447274322",
677
+ "ap-southeast-5": "550225433462",
678
+ "ap-southeast-6": "633930458069",
679
+ "ap-southeast-7": "590183813437",
680
+ "ca-central-1": "763104351884",
681
+ "ca-west-1": "204538143572",
682
+ "cn-north-1": "727897471807",
683
+ "cn-northwest-1": "727897471807",
684
+ "eu-central-1": "763104351884",
685
+ "eu-central-2": "380420809688",
686
+ "eu-north-1": "763104351884",
687
+ "eu-south-1": "692866216735",
688
+ "eu-south-2": "503227376785",
689
+ "eu-west-1": "763104351884",
690
+ "eu-west-2": "763104351884",
691
+ "eu-west-3": "763104351884",
692
+ "il-central-1": "780543022126",
693
+ "me-central-1": "914824155844",
694
+ "me-south-1": "217643126080",
695
+ "mx-central-1": "637423239942",
696
+ "sa-east-1": "763104351884",
697
+ "us-east-1": "763104351884",
698
+ "us-east-2": "763104351884",
699
+ "us-gov-east-1": "446045086412",
700
+ "us-gov-west-1": "442386744353",
701
+ "us-iso-east-1": "886529160074",
702
+ "us-isob-east-1": "094389454867",
703
+ "us-isof-east-1": "303241398832",
704
+ "us-isof-south-1": "454834333376",
705
+ "us-west-1": "763104351884",
706
+ "us-west-2": "763104351884"
707
+ },
708
+ "tag_prefix": "2.5.1-optimum3.3.4",
709
+ "repository": "huggingface-pytorch-tgi-inference",
710
+ "container_version": {
711
+ "inf2": "ubuntu22.04"
712
+ }
713
+ },
714
+ "0.3.0": {
715
+ "py_versions": [
716
+ "py310"
717
+ ],
718
+ "registries": {
719
+ "af-south-1": "626614931356",
720
+ "ap-east-1": "871362719292",
721
+ "ap-east-2": "975050140332",
722
+ "ap-northeast-1": "763104351884",
723
+ "ap-northeast-2": "763104351884",
724
+ "ap-northeast-3": "364406365360",
725
+ "ap-south-1": "763104351884",
726
+ "ap-south-2": "772153158452",
727
+ "ap-southeast-1": "763104351884",
728
+ "ap-southeast-2": "763104351884",
729
+ "ap-southeast-3": "907027046896",
730
+ "ap-southeast-4": "457447274322",
731
+ "ap-southeast-5": "550225433462",
732
+ "ap-southeast-6": "633930458069",
733
+ "ap-southeast-7": "590183813437",
734
+ "ca-central-1": "763104351884",
735
+ "ca-west-1": "204538143572",
736
+ "cn-north-1": "727897471807",
737
+ "cn-northwest-1": "727897471807",
738
+ "eu-central-1": "763104351884",
739
+ "eu-central-2": "380420809688",
740
+ "eu-north-1": "763104351884",
741
+ "eu-south-1": "692866216735",
742
+ "eu-south-2": "503227376785",
743
+ "eu-west-1": "763104351884",
744
+ "eu-west-2": "763104351884",
745
+ "eu-west-3": "763104351884",
746
+ "il-central-1": "780543022126",
747
+ "me-central-1": "914824155844",
748
+ "me-south-1": "217643126080",
749
+ "mx-central-1": "637423239942",
750
+ "sa-east-1": "763104351884",
751
+ "us-east-1": "763104351884",
752
+ "us-east-2": "763104351884",
753
+ "us-gov-east-1": "446045086412",
754
+ "us-gov-west-1": "442386744353",
755
+ "us-iso-east-1": "886529160074",
756
+ "us-isob-east-1": "094389454867",
757
+ "us-isof-east-1": "303241398832",
758
+ "us-isof-south-1": "454834333376",
759
+ "us-west-1": "763104351884",
760
+ "us-west-2": "763104351884"
761
+ },
762
+ "tag_prefix": "2.7.0-optimum3.3.6",
763
+ "repository": "huggingface-pytorch-tgi-inference",
764
+ "container_version": {
765
+ "inf2": "ubuntu22.04"
766
+ }
657
767
  }
658
768
  }
659
769
  }