sagemaker-core 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sagemaker-core might be problematic. Click here for more details.

@@ -1380,6 +1380,11 @@ SHAPE_DAG = {
1380
1380
  "shape": "AppLifecycleManagement",
1381
1381
  "type": "structure",
1382
1382
  },
1383
+ {
1384
+ "name": "BuiltInLifecycleConfigArn",
1385
+ "shape": "StudioLifecycleConfigArn",
1386
+ "type": "string",
1387
+ },
1383
1388
  ],
1384
1389
  "type": "structure",
1385
1390
  },
@@ -1884,6 +1889,7 @@ SHAPE_DAG = {
1884
1889
  "shape": "AppSecurityGroupManagement",
1885
1890
  "type": "string",
1886
1891
  },
1892
+ {"name": "TagPropagation", "shape": "TagPropagation", "type": "string"},
1887
1893
  {"name": "DefaultSpaceSettings", "shape": "DefaultSpaceSettings", "type": "structure"},
1888
1894
  ],
1889
1895
  "type": "structure",
@@ -3792,6 +3798,11 @@ SHAPE_DAG = {
3792
3798
  {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"},
3793
3799
  {"name": "FailureReason", "shape": "FailureReason", "type": "string"},
3794
3800
  {"name": "ResourceSpec", "shape": "ResourceSpec", "type": "structure"},
3801
+ {
3802
+ "name": "BuiltInLifecycleConfigArn",
3803
+ "shape": "StudioLifecycleConfigArn",
3804
+ "type": "string",
3805
+ },
3795
3806
  ],
3796
3807
  "type": "structure",
3797
3808
  },
@@ -4125,6 +4136,7 @@ SHAPE_DAG = {
4125
4136
  "shape": "AppSecurityGroupManagement",
4126
4137
  "type": "string",
4127
4138
  },
4139
+ {"name": "TagPropagation", "shape": "TagPropagation", "type": "string"},
4128
4140
  {"name": "DefaultSpaceSettings", "shape": "DefaultSpaceSettings", "type": "structure"},
4129
4141
  ],
4130
4142
  "type": "structure",
@@ -7697,6 +7709,11 @@ SHAPE_DAG = {
7697
7709
  "type": "structure",
7698
7710
  },
7699
7711
  {"name": "EmrSettings", "shape": "EmrSettings", "type": "structure"},
7712
+ {
7713
+ "name": "BuiltInLifecycleConfigArn",
7714
+ "shape": "StudioLifecycleConfigArn",
7715
+ "type": "string",
7716
+ },
7700
7717
  ],
7701
7718
  "type": "structure",
7702
7719
  },
@@ -14082,6 +14099,7 @@ SHAPE_DAG = {
14082
14099
  {"name": "DefaultSpaceSettings", "shape": "DefaultSpaceSettings", "type": "structure"},
14083
14100
  {"name": "SubnetIds", "shape": "Subnets", "type": "list"},
14084
14101
  {"name": "AppNetworkAccessType", "shape": "AppNetworkAccessType", "type": "string"},
14102
+ {"name": "TagPropagation", "shape": "TagPropagation", "type": "string"},
14085
14103
  ],
14086
14104
  "type": "structure",
14087
14105
  },
@@ -0,0 +1,167 @@
1
+ import boto3
2
+ import botocore
3
+
4
+ from boto3.session import Session
5
+ import botocore.client
6
+ from botocore.config import Config
7
+ from typing import Generator, Tuple, List
8
+ from sagemaker_core.main.utils import SingletonMeta
9
+
10
+
class CloudWatchLogsClient(metaclass=SingletonMeta):
    """
    Singleton wrapper that owns a boto3 CloudWatch Logs ("logs") client.

    The client is created from a default boto3 ``Session`` in that session's
    configured region. It is configured with botocore "standard" retry mode and
    up to 10 attempts.

    Attributes:
        client: The CloudWatch Logs client. It is ``None`` until the first instantiation.
    """

    client: botocore.client.BaseClient = None

    def __init__(self):
        # SingletonMeta (sagemaker_core.main.utils) presumably hands back one shared
        # instance; the guard ensures the client is only built once either way.
        if not self.client:
            session = Session()
            self.client = session.client(
                "logs",
                region_name=session.region_name,
                config=Config(retries={"max_attempts": 10, "mode": "standard"}),
            )
26
+
27
+
class LogStreamHandler:
    """
    Incrementally reads events from a single CloudWatch Logs stream.

    The handler stores the last ``nextForwardToken`` it received. Each call to
    ``get_latest_log_events`` therefore resumes after the events already returned,
    instead of re-reading the stream from the beginning.
    """

    log_group_name: str = None
    log_stream_name: str = None
    # Index assigned by the caller (MultiLogStreamHandler uses the enumeration order).
    stream_id: int = None
    # Forward token from the most recent GetLogEvents call; None before the first call.
    next_token: str = None
    cw_client = None

    def __init__(self, log_group_name: str, log_stream_name: str, stream_id: int):
        """
        Args:
            log_group_name: CloudWatch log group containing the stream.
            log_stream_name: Name of the stream to read.
            stream_id: Caller-assigned index for this stream; stored, not used here.
        """
        self.log_group_name = log_group_name
        self.log_stream_name = log_stream_name
        self.cw_client = CloudWatchLogsClient().client
        self.stream_id = stream_id

    def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
        """
        Yield every event available in this stream after the stored token.

        ``cw_client.get_log_events()`` always returns a ``nextForwardToken``. It can
        return an empty page while more events are still reachable through that token.
        The current end of the stream is reached when the returned token equals the
        token that was passed in. Paging therefore continues until an empty page comes
        back with an unchanged token.

        API Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/logs/client/get_log_events.html

        Yields:
            Tuple[str, dict]: ``(stream_name, event)``, where ``event`` has the form
                {
                    "ingestionTime": number,
                    "message": "string",
                    "timestamp": number
                }
        """
        while True:
            request_token = self.next_token
            token_args = {"nextToken": request_token} if request_token else {}

            response = self.cw_client.get_log_events(
                logGroupName=self.log_group_name,
                logStreamName=self.log_stream_name,
                startFromHead=True,
                **token_args,
            )

            # Advance the cursor first so the next call resumes after this page.
            self.next_token = response["nextForwardToken"]
            events = response["events"]

            for event in events:
                yield self.log_stream_name, event

            # An empty page with a *new* token may still precede more events; only an
            # unchanged token marks the current end of the stream.
            if not events and self.next_token == request_token:
                break
79
+
80
+
class MultiLogStreamHandler:
    """
    Reads CloudWatch log events from every stream of a job under a common prefix.

    Streams are discovered under ``<log_stream_name_prefix>/`` in the given log
    group. Once at least ``expected_stream_count`` streams exist, each one is
    wrapped in a ``LogStreamHandler`` and read in turn.
    """

    log_group_name: str = None
    log_stream_name_prefix: str = None
    expected_stream_count: int = None
    # Assigned per instance (in __init__ and ready()); a class-level list default
    # would be shared by every handler.
    streams: List["LogStreamHandler"]
    cw_client = None

    def __init__(
        self, log_group_name: str, log_stream_name_prefix: str, expected_stream_count: int
    ):
        """
        Args:
            log_group_name: CloudWatch log group to search, e.g. "/aws/sagemaker/TrainingJobs".
            log_stream_name_prefix: Stream-name prefix; "/" is appended when listing streams.
            expected_stream_count: Number of streams that must exist before logs are served.
        """
        self.log_group_name = log_group_name
        self.log_stream_name_prefix = log_stream_name_prefix
        self.expected_stream_count = expected_stream_count
        self.streams = []
        self.cw_client = CloudWatchLogsClient().client

    def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
        """
        Yield all currently available new events from each stream, stream by stream.

        Nothing is yielded while ``ready()`` is False.

        Yields:
            Tuple[str, dict]: ``(stream_name, event)``, where ``event`` has the form
                {
                    "ingestionTime": number,
                    "message": "string",
                    "timestamp": number
                }
        """
        if not self.ready():
            return

        for stream in self.streams:
            yield from stream.get_latest_log_events()

    def ready(self) -> bool:
        """
        Check whether the handler can serve log events now.

        If enough streams have already been discovered, return True without an API
        call. Otherwise, list the streams under the prefix. If their number reaches
        the expected stream count, build one ``LogStreamHandler`` per stream.

        Returns:
            bool: True when enough streams exist. False when too few exist yet or the
            log group does not exist yet.

        Raises:
            botocore.exceptions.ClientError: For any error other than ResourceNotFoundException.
        """
        if len(self.streams) >= self.expected_stream_count:
            return True

        try:
            stream_names = self._list_stream_names()
        except botocore.exceptions.ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                return False
            raise

        if len(stream_names) < self.expected_stream_count:
            # Log streams are created whenever a container starts writing to stdout/err,
            # so if the stream count is less than the expected number, return False
            return False

        self.streams = [
            LogStreamHandler(self.log_group_name, log_stream_name, index)
            for index, log_stream_name in enumerate(stream_names)
        ]
        return True

    def _list_stream_names(self) -> List[str]:
        """Return the names of all streams under ``<prefix>/``, following every result page."""
        request = {
            "logGroupName": self.log_group_name,
            "logStreamNamePrefix": self.log_stream_name_prefix + "/",
            "orderBy": "LogStreamName",
        }
        stream_names = []
        while True:
            response = self.cw_client.describe_log_streams(**request)
            stream_names.extend(stream["logStreamName"] for stream in response["logStreams"])
            next_token = response.get("nextToken")
            if not next_token:
                return stream_names
            request["nextToken"] = next_token
@@ -41,6 +41,7 @@ from sagemaker_core.main.intelligent_defaults_helper import (
41
41
  load_default_configs_for_resource_name,
42
42
  get_config_value,
43
43
  )
44
+ from sagemaker_core.main.logs import MultiLogStreamHandler
44
45
  from sagemaker_core.main.shapes import *
45
46
  from sagemaker_core.main.exceptions import *
46
47
 
@@ -812,7 +813,8 @@ class Algorithm(Base):
812
813
  Group(progress, status),
813
814
  title="Wait Log Panel",
814
815
  border_style=Style(color=Color.BLUE.value),
815
- )
816
+ ),
817
+ transient=True,
816
818
  ):
817
819
  while True:
818
820
  self.refresh()
@@ -977,6 +979,7 @@ class App(Base):
977
979
  creation_time: The creation time of the application. After an application has been shut down for 24 hours, SageMaker deletes all metadata for the application. To be considered an update and retain application metadata, applications must be restarted within 24 hours after the previous application has been shut down. After this time window, creation of an application is considered a new application rather than an update of the previous application.
978
980
  failure_reason: The failure reason.
979
981
  resource_spec: The instance type and the Amazon Resource Name (ARN) of the SageMaker image created on the instance.
982
+ built_in_lifecycle_config_arn: The lifecycle configuration that runs before the default lifecycle configuration
980
983
 
981
984
  """
982
985
 
@@ -992,6 +995,7 @@ class App(Base):
992
995
  creation_time: Optional[datetime.datetime] = Unassigned()
993
996
  failure_reason: Optional[str] = Unassigned()
994
997
  resource_spec: Optional[ResourceSpec] = Unassigned()
998
+ built_in_lifecycle_config_arn: Optional[str] = Unassigned()
995
999
 
996
1000
  def get_name(self) -> str:
997
1001
  attributes = vars(self)
@@ -1270,7 +1274,8 @@ class App(Base):
1270
1274
  Group(progress, status),
1271
1275
  title="Wait Log Panel",
1272
1276
  border_style=Style(color=Color.BLUE.value),
1273
- )
1277
+ ),
1278
+ transient=True,
1274
1279
  ):
1275
1280
  while True:
1276
1281
  self.refresh()
@@ -2652,7 +2657,11 @@ class AutoMLJob(Base):
2652
2657
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
2653
2658
 
2654
2659
  @Base.add_validate_call
2655
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
2660
+ def wait(
2661
+ self,
2662
+ poll: int = 5,
2663
+ timeout: Optional[int] = None,
2664
+ ) -> None:
2656
2665
  """
2657
2666
  Wait for a AutoMLJob resource.
2658
2667
 
@@ -2682,7 +2691,8 @@ class AutoMLJob(Base):
2682
2691
  Group(progress, status),
2683
2692
  title="Wait Log Panel",
2684
2693
  border_style=Style(color=Color.BLUE.value),
2685
- )
2694
+ ),
2695
+ transient=True,
2686
2696
  ):
2687
2697
  while True:
2688
2698
  self.refresh()
@@ -3130,7 +3140,11 @@ class AutoMLJobV2(Base):
3130
3140
  return self
3131
3141
 
3132
3142
  @Base.add_validate_call
3133
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
3143
+ def wait(
3144
+ self,
3145
+ poll: int = 5,
3146
+ timeout: Optional[int] = None,
3147
+ ) -> None:
3134
3148
  """
3135
3149
  Wait for a AutoMLJobV2 resource.
3136
3150
 
@@ -3160,7 +3174,8 @@ class AutoMLJobV2(Base):
3160
3174
  Group(progress, status),
3161
3175
  title="Wait Log Panel",
3162
3176
  border_style=Style(color=Color.BLUE.value),
3163
- )
3177
+ ),
3178
+ transient=True,
3164
3179
  ):
3165
3180
  while True:
3166
3181
  self.refresh()
@@ -3534,7 +3549,8 @@ class Cluster(Base):
3534
3549
  Group(progress, status),
3535
3550
  title="Wait Log Panel",
3536
3551
  border_style=Style(color=Color.BLUE.value),
3537
- )
3552
+ ),
3553
+ transient=True,
3538
3554
  ):
3539
3555
  while True:
3540
3556
  self.refresh()
@@ -4493,7 +4509,11 @@ class CompilationJob(Base):
4493
4509
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
4494
4510
 
4495
4511
  @Base.add_validate_call
4496
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
4512
+ def wait(
4513
+ self,
4514
+ poll: int = 5,
4515
+ timeout: Optional[int] = None,
4516
+ ) -> None:
4497
4517
  """
4498
4518
  Wait for a CompilationJob resource.
4499
4519
 
@@ -4523,7 +4543,8 @@ class CompilationJob(Base):
4523
4543
  Group(progress, status),
4524
4544
  title="Wait Log Panel",
4525
4545
  border_style=Style(color=Color.BLUE.value),
4526
- )
4546
+ ),
4547
+ transient=True,
4527
4548
  ):
4528
4549
  while True:
4529
4550
  self.refresh()
@@ -6147,6 +6168,7 @@ class Domain(Base):
6147
6168
  vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication.
6148
6169
  kms_key_id: The Amazon Web Services KMS customer managed key used to encrypt the EFS volume attached to the domain.
6149
6170
  app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided.
6171
+ tag_propagation: Indicates whether custom tag propagation is supported for the domain.
6150
6172
  default_space_settings: The default settings used to create a space.
6151
6173
 
6152
6174
  """
@@ -6172,6 +6194,7 @@ class Domain(Base):
6172
6194
  vpc_id: Optional[str] = Unassigned()
6173
6195
  kms_key_id: Optional[str] = Unassigned()
6174
6196
  app_security_group_management: Optional[str] = Unassigned()
6197
+ tag_propagation: Optional[str] = Unassigned()
6175
6198
  default_space_settings: Optional[DefaultSpaceSettings] = Unassigned()
6176
6199
 
6177
6200
  def get_name(self) -> str:
@@ -6270,6 +6293,7 @@ class Domain(Base):
6270
6293
  home_efs_file_system_kms_key_id: Optional[str] = Unassigned(),
6271
6294
  kms_key_id: Optional[str] = Unassigned(),
6272
6295
  app_security_group_management: Optional[str] = Unassigned(),
6296
+ tag_propagation: Optional[str] = Unassigned(),
6273
6297
  default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(),
6274
6298
  session: Optional[Session] = None,
6275
6299
  region: Optional[str] = None,
@@ -6289,6 +6313,7 @@ class Domain(Base):
6289
6313
  home_efs_file_system_kms_key_id: Use KmsKeyId.
6290
6314
  kms_key_id: SageMaker uses Amazon Web Services KMS to encrypt EFS and EBS volumes attached to the domain with an Amazon Web Services managed key by default. For more control, specify a customer managed key.
6291
6315
  app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. If setting up the domain for use with RStudio, this value must be set to Service.
6316
+ tag_propagation: Indicates whether custom tag propagation is supported for the domain. Defaults to DISABLED.
6292
6317
  default_space_settings: The default settings used to create a space.
6293
6318
  session: Boto3 session.
6294
6319
  region: Region name.
@@ -6330,6 +6355,7 @@ class Domain(Base):
6330
6355
  "HomeEfsFileSystemKmsKeyId": home_efs_file_system_kms_key_id,
6331
6356
  "KmsKeyId": kms_key_id,
6332
6357
  "AppSecurityGroupManagement": app_security_group_management,
6358
+ "TagPropagation": tag_propagation,
6333
6359
  "DefaultSpaceSettings": default_space_settings,
6334
6360
  }
6335
6361
 
@@ -6446,6 +6472,7 @@ class Domain(Base):
6446
6472
  default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(),
6447
6473
  subnet_ids: Optional[List[str]] = Unassigned(),
6448
6474
  app_network_access_type: Optional[str] = Unassigned(),
6475
+ tag_propagation: Optional[str] = Unassigned(),
6449
6476
  ) -> Optional["Domain"]:
6450
6477
  """
6451
6478
  Update a Domain resource
@@ -6482,6 +6509,7 @@ class Domain(Base):
6482
6509
  "DefaultSpaceSettings": default_space_settings,
6483
6510
  "SubnetIds": subnet_ids,
6484
6511
  "AppNetworkAccessType": app_network_access_type,
6512
+ "TagPropagation": tag_propagation,
6485
6513
  }
6486
6514
  logger.debug(f"Input request: {operation_input_args}")
6487
6515
  # serialize the input request
@@ -6574,7 +6602,8 @@ class Domain(Base):
6574
6602
  Group(progress, status),
6575
6603
  title="Wait Log Panel",
6576
6604
  border_style=Style(color=Color.BLUE.value),
6577
- )
6605
+ ),
6606
+ transient=True,
6578
6607
  ):
6579
6608
  while True:
6580
6609
  self.refresh()
@@ -7523,7 +7552,11 @@ class EdgePackagingJob(Base):
7523
7552
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
7524
7553
 
7525
7554
  @Base.add_validate_call
7526
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
7555
+ def wait(
7556
+ self,
7557
+ poll: int = 5,
7558
+ timeout: Optional[int] = None,
7559
+ ) -> None:
7527
7560
  """
7528
7561
  Wait for a EdgePackagingJob resource.
7529
7562
 
@@ -7553,7 +7586,8 @@ class EdgePackagingJob(Base):
7553
7586
  Group(progress, status),
7554
7587
  title="Wait Log Panel",
7555
7588
  border_style=Style(color=Color.BLUE.value),
7556
- )
7589
+ ),
7590
+ transient=True,
7557
7591
  ):
7558
7592
  while True:
7559
7593
  self.refresh()
@@ -8024,7 +8058,8 @@ class Endpoint(Base):
8024
8058
  Group(progress, status),
8025
8059
  title="Wait Log Panel",
8026
8060
  border_style=Style(color=Color.BLUE.value),
8027
- )
8061
+ ),
8062
+ transient=True,
8028
8063
  ):
8029
8064
  while True:
8030
8065
  self.refresh()
@@ -9540,7 +9575,8 @@ class FeatureGroup(Base):
9540
9575
  Group(progress, status),
9541
9576
  title="Wait Log Panel",
9542
9577
  border_style=Style(color=Color.BLUE.value),
9543
- )
9578
+ ),
9579
+ transient=True,
9544
9580
  ):
9545
9581
  while True:
9546
9582
  self.refresh()
@@ -10408,7 +10444,8 @@ class FlowDefinition(Base):
10408
10444
  Group(progress, status),
10409
10445
  title="Wait Log Panel",
10410
10446
  border_style=Style(color=Color.BLUE.value),
10411
- )
10447
+ ),
10448
+ transient=True,
10412
10449
  ):
10413
10450
  while True:
10414
10451
  self.refresh()
@@ -10903,7 +10940,8 @@ class Hub(Base):
10903
10940
  Group(progress, status),
10904
10941
  title="Wait Log Panel",
10905
10942
  border_style=Style(color=Color.BLUE.value),
10906
- )
10943
+ ),
10944
+ transient=True,
10907
10945
  ):
10908
10946
  while True:
10909
10947
  self.refresh()
@@ -11292,7 +11330,8 @@ class HubContent(Base):
11292
11330
  Group(progress, status),
11293
11331
  title="Wait Log Panel",
11294
11332
  border_style=Style(color=Color.BLUE.value),
11295
- )
11333
+ ),
11334
+ transient=True,
11296
11335
  ):
11297
11336
  while True:
11298
11337
  self.refresh()
@@ -11867,7 +11906,8 @@ class HumanTaskUi(Base):
11867
11906
  Group(progress, status),
11868
11907
  title="Wait Log Panel",
11869
11908
  border_style=Style(color=Color.BLUE.value),
11870
- )
11909
+ ),
11910
+ transient=True,
11871
11911
  ):
11872
11912
  while True:
11873
11913
  self.refresh()
@@ -12334,7 +12374,11 @@ class HyperParameterTuningJob(Base):
12334
12374
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
12335
12375
 
12336
12376
  @Base.add_validate_call
12337
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
12377
+ def wait(
12378
+ self,
12379
+ poll: int = 5,
12380
+ timeout: Optional[int] = None,
12381
+ ) -> None:
12338
12382
  """
12339
12383
  Wait for a HyperParameterTuningJob resource.
12340
12384
 
@@ -12364,7 +12408,8 @@ class HyperParameterTuningJob(Base):
12364
12408
  Group(progress, status),
12365
12409
  title="Wait Log Panel",
12366
12410
  border_style=Style(color=Color.BLUE.value),
12367
- )
12411
+ ),
12412
+ transient=True,
12368
12413
  ):
12369
12414
  while True:
12370
12415
  self.refresh()
@@ -12936,7 +12981,8 @@ class Image(Base):
12936
12981
  Group(progress, status),
12937
12982
  title="Wait Log Panel",
12938
12983
  border_style=Style(color=Color.BLUE.value),
12939
- )
12984
+ ),
12985
+ transient=True,
12940
12986
  ):
12941
12987
  while True:
12942
12988
  self.refresh()
@@ -13542,7 +13588,8 @@ class ImageVersion(Base):
13542
13588
  Group(progress, status),
13543
13589
  title="Wait Log Panel",
13544
13590
  border_style=Style(color=Color.BLUE.value),
13545
- )
13591
+ ),
13592
+ transient=True,
13546
13593
  ):
13547
13594
  while True:
13548
13595
  self.refresh()
@@ -13959,7 +14006,8 @@ class InferenceComponent(Base):
13959
14006
  Group(progress, status),
13960
14007
  title="Wait Log Panel",
13961
14008
  border_style=Style(color=Color.BLUE.value),
13962
- )
14009
+ ),
14010
+ transient=True,
13963
14011
  ):
13964
14012
  while True:
13965
14013
  self.refresh()
@@ -14594,7 +14642,8 @@ class InferenceExperiment(Base):
14594
14642
  Group(progress, status),
14595
14643
  title="Wait Log Panel",
14596
14644
  border_style=Style(color=Color.BLUE.value),
14597
- )
14645
+ ),
14646
+ transient=True,
14598
14647
  ):
14599
14648
  while True:
14600
14649
  self.refresh()
@@ -14967,7 +15016,11 @@ class InferenceRecommendationsJob(Base):
14967
15016
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
14968
15017
 
14969
15018
  @Base.add_validate_call
14970
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
15019
+ def wait(
15020
+ self,
15021
+ poll: int = 5,
15022
+ timeout: Optional[int] = None,
15023
+ ) -> None:
14971
15024
  """
14972
15025
  Wait for a InferenceRecommendationsJob resource.
14973
15026
 
@@ -14997,7 +15050,8 @@ class InferenceRecommendationsJob(Base):
14997
15050
  Group(progress, status),
14998
15051
  title="Wait Log Panel",
14999
15052
  border_style=Style(color=Color.BLUE.value),
15000
- )
15053
+ ),
15054
+ transient=True,
15001
15055
  ):
15002
15056
  while True:
15003
15057
  self.refresh()
@@ -15529,7 +15583,11 @@ class LabelingJob(Base):
15529
15583
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
15530
15584
 
15531
15585
  @Base.add_validate_call
15532
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
15586
+ def wait(
15587
+ self,
15588
+ poll: int = 5,
15589
+ timeout: Optional[int] = None,
15590
+ ) -> None:
15533
15591
  """
15534
15592
  Wait for a LabelingJob resource.
15535
15593
 
@@ -15559,7 +15617,8 @@ class LabelingJob(Base):
15559
15617
  Group(progress, status),
15560
15618
  title="Wait Log Panel",
15561
15619
  border_style=Style(color=Color.BLUE.value),
15562
- )
15620
+ ),
15621
+ transient=True,
15563
15622
  ):
15564
15623
  while True:
15565
15624
  self.refresh()
@@ -16307,7 +16366,8 @@ class MlflowTrackingServer(Base):
16307
16366
  Group(progress, status),
16308
16367
  title="Wait Log Panel",
16309
16368
  border_style=Style(color=Color.BLUE.value),
16310
- )
16369
+ ),
16370
+ transient=True,
16311
16371
  ):
16312
16372
  while True:
16313
16373
  self.refresh()
@@ -17565,7 +17625,8 @@ class ModelCard(Base):
17565
17625
  Group(progress, status),
17566
17626
  title="Wait Log Panel",
17567
17627
  border_style=Style(color=Color.BLUE.value),
17568
- )
17628
+ ),
17629
+ transient=True,
17569
17630
  ):
17570
17631
  while True:
17571
17632
  self.refresh()
@@ -17939,7 +18000,11 @@ class ModelCardExportJob(Base):
17939
18000
  return self
17940
18001
 
17941
18002
  @Base.add_validate_call
17942
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
18003
+ def wait(
18004
+ self,
18005
+ poll: int = 5,
18006
+ timeout: Optional[int] = None,
18007
+ ) -> None:
17943
18008
  """
17944
18009
  Wait for a ModelCardExportJob resource.
17945
18010
 
@@ -17969,7 +18034,8 @@ class ModelCardExportJob(Base):
17969
18034
  Group(progress, status),
17970
18035
  title="Wait Log Panel",
17971
18036
  border_style=Style(color=Color.BLUE.value),
17972
- )
18037
+ ),
18038
+ transient=True,
17973
18039
  ):
17974
18040
  while True:
17975
18041
  self.refresh()
@@ -18926,7 +18992,8 @@ class ModelPackage(Base):
18926
18992
  Group(progress, status),
18927
18993
  title="Wait Log Panel",
18928
18994
  border_style=Style(color=Color.BLUE.value),
18929
- )
18995
+ ),
18996
+ transient=True,
18930
18997
  ):
18931
18998
  while True:
18932
18999
  self.refresh()
@@ -19393,7 +19460,8 @@ class ModelPackageGroup(Base):
19393
19460
  Group(progress, status),
19394
19461
  title="Wait Log Panel",
19395
19462
  border_style=Style(color=Color.BLUE.value),
19396
- )
19463
+ ),
19464
+ transient=True,
19397
19465
  ):
19398
19466
  while True:
19399
19467
  self.refresh()
@@ -20782,7 +20850,8 @@ class MonitoringSchedule(Base):
20782
20850
  Group(progress, status),
20783
20851
  title="Wait Log Panel",
20784
20852
  border_style=Style(color=Color.BLUE.value),
20785
- )
20853
+ ),
20854
+ transient=True,
20786
20855
  ):
20787
20856
  while True:
20788
20857
  self.refresh()
@@ -21351,7 +21420,8 @@ class NotebookInstance(Base):
21351
21420
  Group(progress, status),
21352
21421
  title="Wait Log Panel",
21353
21422
  border_style=Style(color=Color.BLUE.value),
21354
- )
21423
+ ),
21424
+ transient=True,
21355
21425
  ):
21356
21426
  while True:
21357
21427
  self.refresh()
@@ -22190,7 +22260,11 @@ class OptimizationJob(Base):
22190
22260
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
22191
22261
 
22192
22262
  @Base.add_validate_call
22193
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
22263
+ def wait(
22264
+ self,
22265
+ poll: int = 5,
22266
+ timeout: Optional[int] = None,
22267
+ ) -> None:
22194
22268
  """
22195
22269
  Wait for a OptimizationJob resource.
22196
22270
 
@@ -22220,7 +22294,8 @@ class OptimizationJob(Base):
22220
22294
  Group(progress, status),
22221
22295
  title="Wait Log Panel",
22222
22296
  border_style=Style(color=Color.BLUE.value),
22223
- )
22297
+ ),
22298
+ transient=True,
22224
22299
  ):
22225
22300
  while True:
22226
22301
  self.refresh()
@@ -22691,7 +22766,8 @@ class Pipeline(Base):
22691
22766
  Group(progress, status),
22692
22767
  title="Wait Log Panel",
22693
22768
  border_style=Style(color=Color.BLUE.value),
22694
- )
22769
+ ),
22770
+ transient=True,
22695
22771
  ):
22696
22772
  while True:
22697
22773
  self.refresh()
@@ -23089,7 +23165,8 @@ class PipelineExecution(Base):
23089
23165
  Group(progress, status),
23090
23166
  title="Wait Log Panel",
23091
23167
  border_style=Style(color=Color.BLUE.value),
23092
- )
23168
+ ),
23169
+ transient=True,
23093
23170
  ):
23094
23171
  while True:
23095
23172
  self.refresh()
@@ -24066,13 +24143,19 @@ class ProcessingJob(Base):
24066
24143
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
24067
24144
 
24068
24145
  @Base.add_validate_call
24069
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
24146
+ def wait(
24147
+ self,
24148
+ poll: int = 5,
24149
+ timeout: Optional[int] = None,
24150
+ logs: Optional[bool] = False,
24151
+ ) -> None:
24070
24152
  """
24071
24153
  Wait for a ProcessingJob resource.
24072
24154
 
24073
24155
  Parameters:
24074
24156
  poll: The number of seconds to wait between each poll.
24075
24157
  timeout: The maximum number of seconds to wait before timing out.
24158
+ logs: Whether to print logs while waiting.
24076
24159
 
24077
24160
  Raises:
24078
24161
  TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
@@ -24091,18 +24174,32 @@ class ProcessingJob(Base):
24091
24174
  progress.add_task("Waiting for ProcessingJob...")
24092
24175
  status = Status("Current status:")
24093
24176
 
24177
+ instance_count = self.processing_resources.cluster_config.instance_count
24178
+ if logs:
24179
+ multi_stream_logger = MultiLogStreamHandler(
24180
+ log_group_name=f"/aws/sagemaker/ProcessingJobs",
24181
+ log_stream_name_prefix=self.get_name(),
24182
+ expected_stream_count=instance_count,
24183
+ )
24184
+
24094
24185
  with Live(
24095
24186
  Panel(
24096
24187
  Group(progress, status),
24097
24188
  title="Wait Log Panel",
24098
24189
  border_style=Style(color=Color.BLUE.value),
24099
- )
24190
+ ),
24191
+ transient=True,
24100
24192
  ):
24101
24193
  while True:
24102
24194
  self.refresh()
24103
24195
  current_status = self.processing_job_status
24104
24196
  status.update(f"Current status: [bold]{current_status}")
24105
24197
 
24198
+ if logs and multi_stream_logger.ready():
24199
+ stream_log_events = multi_stream_logger.get_latest_log_events()
24200
+ for stream_id, event in stream_log_events:
24201
+ logger.info(f"{stream_id}:\n{event['message']}")
24202
+
24106
24203
  if current_status in terminal_states:
24107
24204
  logger.info(f"Final Resource Status: [bold]{current_status}")
24108
24205
 
@@ -24530,7 +24627,8 @@ class Project(Base):
24530
24627
  Group(progress, status),
24531
24628
  title="Wait Log Panel",
24532
24629
  border_style=Style(color=Color.BLUE.value),
24533
- )
24630
+ ),
24631
+ transient=True,
24534
24632
  ):
24535
24633
  while True:
24536
24634
  self.refresh()
@@ -25175,7 +25273,8 @@ class Space(Base):
25175
25273
  Group(progress, status),
25176
25274
  title="Wait Log Panel",
25177
25275
  border_style=Style(color=Color.BLUE.value),
25178
- )
25276
+ ),
25277
+ transient=True,
25179
25278
  ):
25180
25279
  while True:
25181
25280
  self.refresh()
@@ -26426,13 +26525,19 @@ class TrainingJob(Base):
26426
26525
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
26427
26526
 
26428
26527
  @Base.add_validate_call
26429
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
26528
+ def wait(
26529
+ self,
26530
+ poll: int = 5,
26531
+ timeout: Optional[int] = None,
26532
+ logs: Optional[bool] = False,
26533
+ ) -> None:
26430
26534
  """
26431
26535
  Wait for a TrainingJob resource.
26432
26536
 
26433
26537
  Parameters:
26434
26538
  poll: The number of seconds to wait between each poll.
26435
26539
  timeout: The maximum number of seconds to wait before timing out.
26540
+ logs: Whether to print logs while waiting.
26436
26541
 
26437
26542
  Raises:
26438
26543
  TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
@@ -26451,18 +26556,41 @@ class TrainingJob(Base):
26451
26556
  progress.add_task("Waiting for TrainingJob...")
26452
26557
  status = Status("Current status:")
26453
26558
 
26559
+ instance_count = (
26560
+ sum(
26561
+ instance_group.instance_count
26562
+ for instance_group in self.resource_config.instance_groups
26563
+ )
26564
+ if self.resource_config.instance_groups
26565
+ and not isinstance(self.resource_config.instance_groups, Unassigned)
26566
+ else self.resource_config.instance_count
26567
+ )
26568
+
26569
+ if logs:
26570
+ multi_stream_logger = MultiLogStreamHandler(
26571
+ log_group_name=f"/aws/sagemaker/TrainingJobs",
26572
+ log_stream_name_prefix=self.get_name(),
26573
+ expected_stream_count=instance_count,
26574
+ )
26575
+
26454
26576
  with Live(
26455
26577
  Panel(
26456
26578
  Group(progress, status),
26457
26579
  title="Wait Log Panel",
26458
26580
  border_style=Style(color=Color.BLUE.value),
26459
- )
26581
+ ),
26582
+ transient=True,
26460
26583
  ):
26461
26584
  while True:
26462
26585
  self.refresh()
26463
26586
  current_status = self.training_job_status
26464
26587
  status.update(f"Current status: [bold]{current_status}")
26465
26588
 
26589
+ if logs and multi_stream_logger.ready():
26590
+ stream_log_events = multi_stream_logger.get_latest_log_events()
26591
+ for stream_id, event in stream_log_events:
26592
+ logger.info(f"{stream_id}:\n{event['message']}")
26593
+
26466
26594
  if current_status in terminal_states:
26467
26595
  logger.info(f"Final Resource Status: [bold]{current_status}")
26468
26596
 
@@ -26877,13 +27005,19 @@ class TransformJob(Base):
26877
27005
  logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
26878
27006
 
26879
27007
  @Base.add_validate_call
26880
- def wait(self, poll: int = 5, timeout: Optional[int] = None) -> None:
27008
+ def wait(
27009
+ self,
27010
+ poll: int = 5,
27011
+ timeout: Optional[int] = None,
27012
+ logs: Optional[bool] = False,
27013
+ ) -> None:
26881
27014
  """
26882
27015
  Wait for a TransformJob resource.
26883
27016
 
26884
27017
  Parameters:
26885
27018
  poll: The number of seconds to wait between each poll.
26886
27019
  timeout: The maximum number of seconds to wait before timing out.
27020
+ logs: Whether to print logs while waiting.
26887
27021
 
26888
27022
  Raises:
26889
27023
  TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
@@ -26902,18 +27036,32 @@ class TransformJob(Base):
26902
27036
  progress.add_task("Waiting for TransformJob...")
26903
27037
  status = Status("Current status:")
26904
27038
 
27039
+ instance_count = self.transform_resources.instance_count
27040
+ if logs:
27041
+ multi_stream_logger = MultiLogStreamHandler(
27042
+ log_group_name=f"/aws/sagemaker/TransformJobs",
27043
+ log_stream_name_prefix=self.get_name(),
27044
+ expected_stream_count=instance_count,
27045
+ )
27046
+
26905
27047
  with Live(
26906
27048
  Panel(
26907
27049
  Group(progress, status),
26908
27050
  title="Wait Log Panel",
26909
27051
  border_style=Style(color=Color.BLUE.value),
26910
- )
27052
+ ),
27053
+ transient=True,
26911
27054
  ):
26912
27055
  while True:
26913
27056
  self.refresh()
26914
27057
  current_status = self.transform_job_status
26915
27058
  status.update(f"Current status: [bold]{current_status}")
26916
27059
 
27060
+ if logs and multi_stream_logger.ready():
27061
+ stream_log_events = multi_stream_logger.get_latest_log_events()
27062
+ for stream_id, event in stream_log_events:
27063
+ logger.info(f"{stream_id}:\n{event['message']}")
27064
+
26917
27065
  if current_status in terminal_states:
26918
27066
  logger.info(f"Final Resource Status: [bold]{current_status}")
26919
27067
 
@@ -27729,7 +27877,8 @@ class TrialComponent(Base):
27729
27877
  Group(progress, status),
27730
27878
  title="Wait Log Panel",
27731
27879
  border_style=Style(color=Color.BLUE.value),
27732
- )
27880
+ ),
27881
+ transient=True,
27733
27882
  ):
27734
27883
  while True:
27735
27884
  self.refresh()
@@ -28390,7 +28539,8 @@ class UserProfile(Base):
28390
28539
  Group(progress, status),
28391
28540
  title="Wait Log Panel",
28392
28541
  border_style=Style(color=Color.BLUE.value),
28393
- )
28542
+ ),
28543
+ transient=True,
28394
28544
  ):
28395
28545
  while True:
28396
28546
  self.refresh()
@@ -28875,7 +29025,8 @@ class Workforce(Base):
28875
29025
  Group(progress, status),
28876
29026
  title="Wait Log Panel",
28877
29027
  border_style=Style(color=Color.BLUE.value),
28878
- )
29028
+ ),
29029
+ transient=True,
28879
29030
  ):
28880
29031
  while True:
28881
29032
  self.refresh()
@@ -3306,12 +3306,14 @@ class CodeEditorAppSettings(Base):
3306
3306
  custom_images: A list of custom SageMaker images that are configured to run as a Code Editor app.
3307
3307
  lifecycle_config_arns: The Amazon Resource Name (ARN) of the Code Editor application lifecycle configuration.
3308
3308
  app_lifecycle_management: Settings that are used to configure and manage the lifecycle of CodeEditor applications.
3309
+ built_in_lifecycle_config_arn: The lifecycle configuration that runs before the default lifecycle configuration. It can override changes made in the default lifecycle configuration.
3309
3310
  """
3310
3311
 
3311
3312
  default_resource_spec: Optional[ResourceSpec] = Unassigned()
3312
3313
  custom_images: Optional[List[CustomImage]] = Unassigned()
3313
3314
  lifecycle_config_arns: Optional[List[str]] = Unassigned()
3314
3315
  app_lifecycle_management: Optional[AppLifecycleManagement] = Unassigned()
3316
+ built_in_lifecycle_config_arn: Optional[str] = Unassigned()
3315
3317
 
3316
3318
 
3317
3319
  class CodeRepository(Base):
@@ -4235,6 +4237,7 @@ class JupyterLabAppSettings(Base):
4235
4237
  code_repositories: A list of Git repositories that SageMaker automatically displays to users for cloning in the JupyterLab application.
4236
4238
  app_lifecycle_management: Indicates whether idle shutdown is activated for JupyterLab applications.
4237
4239
  emr_settings: The configuration parameters that specify the IAM roles assumed by the execution role of SageMaker (assumable roles) and the cluster instances or job execution environments (execution roles or runtime roles) to manage and access resources required for running Amazon EMR clusters or Amazon EMR Serverless applications.
4240
+ built_in_lifecycle_config_arn: The lifecycle configuration that runs before the default lifecycle configuration. It can override changes made in the default lifecycle configuration.
4238
4241
  """
4239
4242
 
4240
4243
  default_resource_spec: Optional[ResourceSpec] = Unassigned()
@@ -4243,6 +4246,7 @@ class JupyterLabAppSettings(Base):
4243
4246
  code_repositories: Optional[List[CodeRepository]] = Unassigned()
4244
4247
  app_lifecycle_management: Optional[AppLifecycleManagement] = Unassigned()
4245
4248
  emr_settings: Optional[EmrSettings] = Unassigned()
4249
+ built_in_lifecycle_config_arn: Optional[str] = Unassigned()
4246
4250
 
4247
4251
 
4248
4252
  class DefaultEbsStorageSettings(Base):
@@ -160,6 +160,12 @@ def enable_textual_rich_console_and_traceback():
160
160
  textual_rich_console_and_traceback_enabled = True
161
161
 
162
162
 
163
+ def get_rich_handler():
164
+ handler = RichHandler(markup=True)
165
+ handler.setFormatter(logging.Formatter("%(message)s"))
166
+ return handler
167
+
168
+
163
169
  def get_textual_rich_logger(name: str, log_level: str = "INFO") -> logging.Logger:
164
170
  """
165
171
  Get a logger with textual rich handler.
@@ -175,7 +181,7 @@ def get_textual_rich_logger(name: str, log_level: str = "INFO") -> logging.Logge
175
181
 
176
182
  """
177
183
  enable_textual_rich_console_and_traceback()
178
- handler = RichHandler(markup=True)
184
+ handler = get_rich_handler()
179
185
  logging.basicConfig(level=getattr(logging, log_level), handlers=[handler])
180
186
  logger = logging.getLogger(name)
181
187
 
@@ -217,8 +223,8 @@ def configure_logging(log_level=None):
217
223
  # reset any currently associated handlers with log level
218
224
  for handler in _logger.handlers:
219
225
  _logger.removeHandler(handler)
220
- console_handler = RichHandler(markup=True)
221
- _logger.addHandler(console_handler)
226
+ rich_handler = get_rich_handler()
227
+ _logger.addHandler(rich_handler)
222
228
 
223
229
 
224
230
  def is_snake_case(s: str):
@@ -20,6 +20,8 @@ OBJECT_METHODS = set(
20
20
 
21
21
  TERMINAL_STATES = set(["Completed", "Stopped", "Deleted", "Failed", "Succeeded", "Cancelled"])
22
22
 
23
+ RESOURCE_WITH_LOGS = set(["TrainingJob", "ProcessingJob", "TransformJob"])
24
+
23
25
  CONFIGURABLE_ATTRIBUTE_SUBSTRINGS = [
24
26
  "kms",
25
27
  "s3",
@@ -29,6 +29,7 @@ from sagemaker_core.tools.constants import (
29
29
  CONFIG_SCHEMA_FILE_NAME,
30
30
  PYTHON_TYPES_TO_BASIC_JSON_TYPES,
31
31
  CONFIGURABLE_ATTRIBUTE_SUBSTRINGS,
32
+ RESOURCE_WITH_LOGS,
32
33
  )
33
34
  from sagemaker_core.tools.method import Method, MethodType
34
35
  from sagemaker_core.main.utils import (
@@ -71,6 +72,8 @@ from sagemaker_core.tools.templates import (
71
72
  GET_ALL_METHOD_WITH_ARGS_TEMPLATE,
72
73
  UPDATE_METHOD_TEMPLATE_WITHOUT_DECORATOR,
73
74
  RESOURCE_METHOD_EXCEPTION_DOCSTRING,
75
+ INIT_WAIT_LOGS_TEMPLATE,
76
+ PRINT_WAIT_LOGS,
74
77
  )
75
78
  from sagemaker_core.tools.data_extractor import (
76
79
  load_combined_shapes_data,
@@ -188,6 +191,7 @@ class ResourcesCodeGen:
188
191
  "from sagemaker_core.main.utils import SageMakerClient, ResourceIterator, Unassigned, get_textual_rich_logger, "
189
192
  "snake_to_pascal, pascal_to_snake, is_not_primitive, is_not_str_dict, is_primitive_list, serialize",
190
193
  "from sagemaker_core.main.intelligent_defaults_helper import load_default_configs_for_resource_name, get_config_value",
194
+ "from sagemaker_core.main.logs import MultiLogStreamHandler",
191
195
  "from sagemaker_core.main.shapes import *",
192
196
  "from sagemaker_core.main.exceptions import *",
193
197
  ]
@@ -1541,6 +1545,28 @@ class ResourcesCodeGen:
1541
1545
 
1542
1546
  return "'(Unknown)'"
1543
1547
 
1548
+ def _get_instance_count_ref(self, resource_name: str) -> str:
1549
+ """Get the instance count reference for a resource object.
1550
+ Args:
1551
+ resource_name (str): The resource name.
1552
+ Returns:
1553
+ str: The instance count reference for resource object
1554
+ """
1555
+
1556
+ if resource_name == "TrainingJob":
1557
+ return """(
1558
+ sum(instance_group.instance_count for instance_group in self.resource_config.instance_groups)
1559
+ if self.resource_config.instance_groups and not isinstance(self.resource_config.instance_groups, Unassigned)
1560
+ else self.resource_config.instance_count
1561
+ )
1562
+ """
1563
+ elif resource_name == "TransformJob":
1564
+ return "self.transform_resources.instance_count"
1565
+ elif resource_name == "ProcessingJob":
1566
+ return "self.processing_resources.cluster_config.instance_count"
1567
+
1568
+ raise ValueError(f"Instance count reference not found for resource {resource_name}")
1569
+
1544
1570
  def generate_wait_method(self, resource_name: str) -> str:
1545
1571
  """Auto-Generate WAIT Method for a waitable resource.
1546
1572
 
@@ -1573,11 +1599,32 @@ class ResourcesCodeGen:
1573
1599
  )
1574
1600
  formatted_failed_block = add_indent(formatted_failed_block, 16)
1575
1601
 
1602
+ logs_arg = ""
1603
+ logs_arg_doc = ""
1604
+ init_wait_logs = ""
1605
+ print_wait_logs = ""
1606
+ if resource_name in RESOURCE_WITH_LOGS:
1607
+ logs_arg = "logs: Optional[bool] = False,"
1608
+ logs_arg_doc = "logs: Whether to print logs while waiting.\n"
1609
+
1610
+ instance_count = self._get_instance_count_ref(resource_name)
1611
+ init_wait_logs = add_indent(
1612
+ INIT_WAIT_LOGS_TEMPLATE.format(
1613
+ get_instance_count=instance_count,
1614
+ job_type=resource_name,
1615
+ )
1616
+ )
1617
+ print_wait_logs = add_indent(PRINT_WAIT_LOGS, 12)
1618
+
1576
1619
  formatted_method = WAIT_METHOD_TEMPLATE.format(
1577
1620
  terminal_resource_states=terminal_resource_states,
1578
1621
  status_key_path=status_key_path,
1579
1622
  failed_error_block=formatted_failed_block,
1580
1623
  resource_name=resource_name,
1624
+ logs_arg=logs_arg,
1625
+ logs_arg_doc=logs_arg_doc,
1626
+ init_wait_logs=init_wait_logs,
1627
+ print_wait_logs=print_wait_logs,
1581
1628
  )
1582
1629
  return formatted_method
1583
1630
 
@@ -262,12 +262,31 @@ if "failed" in current_status.lower():
262
262
  raise FailedStatusError(resource_type="{resource_name}", status=current_status, reason={reason})
263
263
  """
264
264
 
265
+ INIT_WAIT_LOGS_TEMPLATE = """
266
+ instance_count = {get_instance_count}
267
+ if logs:
268
+ multi_stream_logger = MultiLogStreamHandler(
269
+ log_group_name=f"/aws/sagemaker/{job_type}s",
270
+ log_stream_name_prefix=self.get_name(),
271
+ expected_stream_count=instance_count
272
+ )
273
+ """
274
+
275
+ PRINT_WAIT_LOGS = """
276
+ if logs and multi_stream_logger.ready():
277
+ stream_log_events = multi_stream_logger.get_latest_log_events()
278
+ for stream_id, event in stream_log_events:
279
+ logger.info(f"{stream_id}:\\n{event['message']}")
280
+ """
281
+
282
+
265
283
  WAIT_METHOD_TEMPLATE = '''
266
284
  @Base.add_validate_call
267
285
  def wait(
268
286
  self,
269
287
  poll: int = 5,
270
- timeout: Optional[int] = None
288
+ timeout: Optional[int] = None,
289
+ {logs_arg}
271
290
  ) -> None:
272
291
  """
273
292
  Wait for a {resource_name} resource.
@@ -275,7 +294,7 @@ def wait(
275
294
  Parameters:
276
295
  poll: The number of seconds to wait between each poll.
277
296
  timeout: The maximum number of seconds to wait before timing out.
278
-
297
+ {logs_arg_doc}
279
298
  Raises:
280
299
  TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
281
300
  FailedStatusError: If the resource reaches a failed state.
@@ -291,13 +310,22 @@ def wait(
291
310
  )
292
311
  progress.add_task("Waiting for {resource_name}...")
293
312
  status = Status("Current status:")
294
-
295
- with Live(Panel(Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value))):
313
+ {init_wait_logs}
314
+
315
+ with Live(
316
+ Panel(
317
+ Group(progress, status),
318
+ title="Wait Log Panel",
319
+ border_style=Style(color=Color.BLUE.value
320
+ )
321
+ ),
322
+ transient=True
323
+ ):
296
324
  while True:
297
325
  self.refresh()
298
326
  current_status = self{status_key_path}
299
327
  status.update(f"Current status: [bold]{{current_status}}")
300
-
328
+ {print_wait_logs}
301
329
  if current_status in terminal_states:
302
330
  logger.info(f"Final Resource Status: [bold]{{current_status}}")
303
331
  {failed_error_block}
@@ -338,7 +366,15 @@ def wait_for_status(
338
366
  progress.add_task(f"Waiting for {resource_name} to reach [bold]{{target_status}} status...")
339
367
  status = Status("Current status:")
340
368
 
341
- with Live(Panel(Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value))):
369
+ with Live(
370
+ Panel(
371
+ Group(progress, status),
372
+ title="Wait Log Panel",
373
+ border_style=Style(color=Color.BLUE.value
374
+ )
375
+ ),
376
+ transient=True
377
+ ):
342
378
  while True:
343
379
  self.refresh()
344
380
  current_status = self{status_key_path}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sagemaker-core
3
- Version: 1.0.8
3
+ Version: 1.0.10
4
4
  Summary: An python package for sagemaker core functionalities
5
5
  Author-email: AWS <sagemaker-interests@amazon.com>
6
6
  Project-URL: Repository, https://github.com/aws/sagemaker-core.git
@@ -6,29 +6,30 @@ sagemaker_core/main/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSu
6
6
  sagemaker_core/main/config_schema.py,sha256=TeGoTobT4yotEVyfguLF0IdKYlOymsDZ45ySxXiCDuw,56998
7
7
  sagemaker_core/main/exceptions.py,sha256=87DUlrmHxaWoiYNlpNY9ixxFMPRk_dIGPsA2e_xdVwQ,5602
8
8
  sagemaker_core/main/intelligent_defaults_helper.py,sha256=5SDM6UavZtp-k5LhqRL7GRIDgzFB5UsC_p7YuiSPK9A,8334
9
- sagemaker_core/main/resources.py,sha256=VjjmTLzPcjjg3c8-1g9Fv88Gpwimd9KuNq02378B6cE,1316658
10
- sagemaker_core/main/shapes.py,sha256=Groaj8psouJwQpRljmZIDeRM6fyU6p2J3wgjzcvQ_5k,696205
9
+ sagemaker_core/main/logs.py,sha256=yfEH7uP91nbE1lefymOlBr81ziBzsDSIOF2Qyd54FJE,6241
10
+ sagemaker_core/main/resources.py,sha256=uJs7vsqZcu6-xjLcMCX-Ys8wDJQ7LAHtfaf3riMygv8,1321298
11
+ sagemaker_core/main/shapes.py,sha256=5loClZSSYkEmdshd8UaT0d4VzMtahZSMlLpIO600g7o,696705
11
12
  sagemaker_core/main/user_agent.py,sha256=4sZybDXkzRoZnOnVDQ8p8zFTfiRJdsH7amDWInVQ4xU,2708
12
- sagemaker_core/main/utils.py,sha256=lXkJyiCow5uf32l0EmkimB0RKVk2BS7OM2fYoLsOfD4,18346
13
+ sagemaker_core/main/utils.py,sha256=LCFDM6oxf6_e1i-_Dgtkm3ehl7YfoEpJ2kTTFTL6iOU,18471
13
14
  sagemaker_core/main/code_injection/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
15
  sagemaker_core/main/code_injection/base.py,sha256=11_Jif0nOzfbLGlXaacKf-wcizzfS64U0OSZGoVffFU,1733
15
16
  sagemaker_core/main/code_injection/codec.py,sha256=2DjmeD2uND307UqDefvVEpE0rZ8yfFU3Bi3TvQCQveI,7658
16
17
  sagemaker_core/main/code_injection/constants.py,sha256=2ICExGge8vAWx7lSTW0JGh-bH1korkvpOpDu5M63eI4,980
17
- sagemaker_core/main/code_injection/shape_dag.py,sha256=0OoQzH_r_TAYR2-KUKGQnGMD4U9Cb-NRMxc_SORIOdw,657895
18
+ sagemaker_core/main/code_injection/shape_dag.py,sha256=FyU_a0Jh3AVxTNVN2kjt0oHqqyHNjAgIy3G7kzqksm4,658657
18
19
  sagemaker_core/resources/__init__.py,sha256=EAYTFMN-nPjnPjjBbhIUeaL67FLKNPd7qbcbl9VIrws,31
19
20
  sagemaker_core/shapes/__init__.py,sha256=RnbIu9eTxKt-DNsOFJabrWIgrrtS9_SdAozP9JBl_ic,28
20
21
  sagemaker_core/tools/__init__.py,sha256=xX79JImxCVzrWMnjgntLCve2G5I-R4pRar5s20kT9Rs,56
21
22
  sagemaker_core/tools/codegen.py,sha256=mKWVi2pWnPxyIoWUEPYjEc9Gw7D9bCOrHqa00yzIZ1o,2005
22
- sagemaker_core/tools/constants.py,sha256=8oM0nHuzXWnOJTIdXsu50f9RsEz4rgmmZKWhLkrTP-s,3309
23
+ sagemaker_core/tools/constants.py,sha256=a2WjUDK7gzxgilZs99vp30qh4kQ-y6JKhrwwqVAA12o,3385
23
24
  sagemaker_core/tools/data_extractor.py,sha256=pNfmTA0NUA96IgfLrla7a36Qjc1NljbwgZYaOhouKqQ,2113
24
25
  sagemaker_core/tools/method.py,sha256=4Hmu4UWpiBgUTZljYdW1KIKDduDxf_nfhCyuWgLVMWI,717
25
- sagemaker_core/tools/resources_codegen.py,sha256=ShsWjABBkSpPlp3gByzpe36M6hfYfBofODlUynqVXTY,82405
26
+ sagemaker_core/tools/resources_codegen.py,sha256=ASirF9UMkGAOYZrrZxmFqK3gYQ4YbVYdAiUcWt6qdII,84360
26
27
  sagemaker_core/tools/resources_extractor.py,sha256=hN61ehZbPnhFW-2FIVDi7NsEz4rLvGr-WoglHQGfrug,14523
27
28
  sagemaker_core/tools/shapes_codegen.py,sha256=_ve959bwH8usZ6dPlpXxi2on9t0hLpcmhRWnaWHCWMQ,11745
28
29
  sagemaker_core/tools/shapes_extractor.py,sha256=4KjgDmhlPM4G1f1NeYbORKlXs1s7Q_sm_NK31S_ROQ0,11950
29
- sagemaker_core/tools/templates.py,sha256=nze_A01EpegYUwoR_gRv2qBNKNFruBY8L3RiIX5lz3M,22458
30
- sagemaker_core-1.0.8.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
31
- sagemaker_core-1.0.8.dist-info/METADATA,sha256=MvlRWwl72_MVfsHjlrirz1RooJ1vJBuyCrZe_HDJr9s,4877
32
- sagemaker_core-1.0.8.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
33
- sagemaker_core-1.0.8.dist-info/top_level.txt,sha256=R3GAZZ1zC5JxqdE_0x2Lu_WYi2Xfke7VsiP3L5zngfA,15
34
- sagemaker_core-1.0.8.dist-info/RECORD,,
30
+ sagemaker_core/tools/templates.py,sha256=yX2RQKeClgYwKS5Qu_mDpnWJIBCuj0yELrdm95aiTpk,23262
31
+ sagemaker_core-1.0.10.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
32
+ sagemaker_core-1.0.10.dist-info/METADATA,sha256=LJ5rjbEJgXzd3afuvD5kIzDQ2Yxx1d2wCobLb5Jy7OI,4878
33
+ sagemaker_core-1.0.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
34
+ sagemaker_core-1.0.10.dist-info/top_level.txt,sha256=R3GAZZ1zC5JxqdE_0x2Lu_WYi2Xfke7VsiP3L5zngfA,15
35
+ sagemaker_core-1.0.10.dist-info/RECORD,,