toil-6.1.0a1-py3-none-any.whl → toil-8.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. toil/__init__.py +122 -315
  2. toil/batchSystems/__init__.py +1 -0
  3. toil/batchSystems/abstractBatchSystem.py +173 -89
  4. toil/batchSystems/abstractGridEngineBatchSystem.py +272 -148
  5. toil/batchSystems/awsBatch.py +244 -135
  6. toil/batchSystems/cleanup_support.py +26 -16
  7. toil/batchSystems/contained_executor.py +31 -28
  8. toil/batchSystems/gridengine.py +86 -50
  9. toil/batchSystems/htcondor.py +166 -89
  10. toil/batchSystems/kubernetes.py +632 -382
  11. toil/batchSystems/local_support.py +20 -15
  12. toil/batchSystems/lsf.py +134 -81
  13. toil/batchSystems/lsfHelper.py +13 -11
  14. toil/batchSystems/mesos/__init__.py +41 -29
  15. toil/batchSystems/mesos/batchSystem.py +290 -151
  16. toil/batchSystems/mesos/executor.py +79 -50
  17. toil/batchSystems/mesos/test/__init__.py +31 -23
  18. toil/batchSystems/options.py +46 -28
  19. toil/batchSystems/registry.py +53 -19
  20. toil/batchSystems/singleMachine.py +296 -125
  21. toil/batchSystems/slurm.py +603 -138
  22. toil/batchSystems/torque.py +47 -33
  23. toil/bus.py +186 -76
  24. toil/common.py +664 -368
  25. toil/cwl/__init__.py +1 -1
  26. toil/cwl/cwltoil.py +1136 -483
  27. toil/cwl/utils.py +17 -22
  28. toil/deferred.py +63 -42
  29. toil/exceptions.py +5 -3
  30. toil/fileStores/__init__.py +5 -5
  31. toil/fileStores/abstractFileStore.py +140 -60
  32. toil/fileStores/cachingFileStore.py +717 -269
  33. toil/fileStores/nonCachingFileStore.py +116 -87
  34. toil/job.py +1225 -368
  35. toil/jobStores/abstractJobStore.py +416 -266
  36. toil/jobStores/aws/jobStore.py +863 -477
  37. toil/jobStores/aws/utils.py +201 -120
  38. toil/jobStores/conftest.py +3 -2
  39. toil/jobStores/fileJobStore.py +292 -154
  40. toil/jobStores/googleJobStore.py +140 -74
  41. toil/jobStores/utils.py +36 -15
  42. toil/leader.py +668 -272
  43. toil/lib/accelerators.py +115 -18
  44. toil/lib/aws/__init__.py +74 -31
  45. toil/lib/aws/ami.py +122 -87
  46. toil/lib/aws/iam.py +284 -108
  47. toil/lib/aws/s3.py +31 -0
  48. toil/lib/aws/session.py +214 -39
  49. toil/lib/aws/utils.py +287 -231
  50. toil/lib/bioio.py +13 -5
  51. toil/lib/compatibility.py +11 -6
  52. toil/lib/conversions.py +104 -47
  53. toil/lib/docker.py +131 -103
  54. toil/lib/ec2.py +361 -199
  55. toil/lib/ec2nodes.py +174 -106
  56. toil/lib/encryption/_dummy.py +5 -3
  57. toil/lib/encryption/_nacl.py +10 -6
  58. toil/lib/encryption/conftest.py +1 -0
  59. toil/lib/exceptions.py +26 -7
  60. toil/lib/expando.py +5 -3
  61. toil/lib/ftp_utils.py +217 -0
  62. toil/lib/generatedEC2Lists.py +127 -19
  63. toil/lib/humanize.py +6 -2
  64. toil/lib/integration.py +341 -0
  65. toil/lib/io.py +141 -15
  66. toil/lib/iterables.py +4 -2
  67. toil/lib/memoize.py +12 -8
  68. toil/lib/misc.py +66 -21
  69. toil/lib/objects.py +2 -2
  70. toil/lib/resources.py +68 -15
  71. toil/lib/retry.py +126 -81
  72. toil/lib/threading.py +299 -82
  73. toil/lib/throttle.py +16 -15
  74. toil/options/common.py +843 -409
  75. toil/options/cwl.py +175 -90
  76. toil/options/runner.py +50 -0
  77. toil/options/wdl.py +73 -17
  78. toil/provisioners/__init__.py +117 -46
  79. toil/provisioners/abstractProvisioner.py +332 -157
  80. toil/provisioners/aws/__init__.py +70 -33
  81. toil/provisioners/aws/awsProvisioner.py +1145 -715
  82. toil/provisioners/clusterScaler.py +541 -279
  83. toil/provisioners/gceProvisioner.py +282 -179
  84. toil/provisioners/node.py +155 -79
  85. toil/realtimeLogger.py +34 -22
  86. toil/resource.py +137 -75
  87. toil/server/app.py +128 -62
  88. toil/server/celery_app.py +3 -1
  89. toil/server/cli/wes_cwl_runner.py +82 -53
  90. toil/server/utils.py +54 -28
  91. toil/server/wes/abstract_backend.py +64 -26
  92. toil/server/wes/amazon_wes_utils.py +21 -15
  93. toil/server/wes/tasks.py +121 -63
  94. toil/server/wes/toil_backend.py +142 -107
  95. toil/server/wsgi_app.py +4 -3
  96. toil/serviceManager.py +58 -22
  97. toil/statsAndLogging.py +224 -70
  98. toil/test/__init__.py +282 -183
  99. toil/test/batchSystems/batchSystemTest.py +460 -210
  100. toil/test/batchSystems/batch_system_plugin_test.py +90 -0
  101. toil/test/batchSystems/test_gridengine.py +173 -0
  102. toil/test/batchSystems/test_lsf_helper.py +67 -58
  103. toil/test/batchSystems/test_slurm.py +110 -49
  104. toil/test/cactus/__init__.py +0 -0
  105. toil/test/cactus/test_cactus_integration.py +56 -0
  106. toil/test/cwl/cwlTest.py +496 -287
  107. toil/test/cwl/measure_default_memory.cwl +12 -0
  108. toil/test/cwl/not_run_required_input.cwl +29 -0
  109. toil/test/cwl/scatter_duplicate_outputs.cwl +40 -0
  110. toil/test/cwl/seqtk_seq.cwl +1 -1
  111. toil/test/docs/scriptsTest.py +69 -46
  112. toil/test/jobStores/jobStoreTest.py +427 -264
  113. toil/test/lib/aws/test_iam.py +118 -50
  114. toil/test/lib/aws/test_s3.py +16 -9
  115. toil/test/lib/aws/test_utils.py +5 -6
  116. toil/test/lib/dockerTest.py +118 -141
  117. toil/test/lib/test_conversions.py +113 -115
  118. toil/test/lib/test_ec2.py +58 -50
  119. toil/test/lib/test_integration.py +104 -0
  120. toil/test/lib/test_misc.py +12 -5
  121. toil/test/mesos/MesosDataStructuresTest.py +23 -10
  122. toil/test/mesos/helloWorld.py +7 -6
  123. toil/test/mesos/stress.py +25 -20
  124. toil/test/options/__init__.py +13 -0
  125. toil/test/options/options.py +42 -0
  126. toil/test/provisioners/aws/awsProvisionerTest.py +320 -150
  127. toil/test/provisioners/clusterScalerTest.py +440 -250
  128. toil/test/provisioners/clusterTest.py +166 -44
  129. toil/test/provisioners/gceProvisionerTest.py +174 -100
  130. toil/test/provisioners/provisionerTest.py +25 -13
  131. toil/test/provisioners/restartScript.py +5 -4
  132. toil/test/server/serverTest.py +188 -141
  133. toil/test/sort/restart_sort.py +137 -68
  134. toil/test/sort/sort.py +134 -66
  135. toil/test/sort/sortTest.py +91 -49
  136. toil/test/src/autoDeploymentTest.py +141 -101
  137. toil/test/src/busTest.py +20 -18
  138. toil/test/src/checkpointTest.py +8 -2
  139. toil/test/src/deferredFunctionTest.py +49 -35
  140. toil/test/src/dockerCheckTest.py +32 -24
  141. toil/test/src/environmentTest.py +135 -0
  142. toil/test/src/fileStoreTest.py +539 -272
  143. toil/test/src/helloWorldTest.py +7 -4
  144. toil/test/src/importExportFileTest.py +61 -31
  145. toil/test/src/jobDescriptionTest.py +46 -21
  146. toil/test/src/jobEncapsulationTest.py +2 -0
  147. toil/test/src/jobFileStoreTest.py +74 -50
  148. toil/test/src/jobServiceTest.py +187 -73
  149. toil/test/src/jobTest.py +121 -71
  150. toil/test/src/miscTests.py +19 -18
  151. toil/test/src/promisedRequirementTest.py +82 -36
  152. toil/test/src/promisesTest.py +7 -6
  153. toil/test/src/realtimeLoggerTest.py +10 -6
  154. toil/test/src/regularLogTest.py +71 -37
  155. toil/test/src/resourceTest.py +80 -49
  156. toil/test/src/restartDAGTest.py +36 -22
  157. toil/test/src/resumabilityTest.py +9 -2
  158. toil/test/src/retainTempDirTest.py +45 -14
  159. toil/test/src/systemTest.py +12 -8
  160. toil/test/src/threadingTest.py +44 -25
  161. toil/test/src/toilContextManagerTest.py +10 -7
  162. toil/test/src/userDefinedJobArgTypeTest.py +8 -5
  163. toil/test/src/workerTest.py +73 -23
  164. toil/test/utils/toilDebugTest.py +103 -33
  165. toil/test/utils/toilKillTest.py +4 -5
  166. toil/test/utils/utilsTest.py +245 -106
  167. toil/test/wdl/wdltoil_test.py +818 -149
  168. toil/test/wdl/wdltoil_test_kubernetes.py +91 -0
  169. toil/toilState.py +120 -35
  170. toil/utils/toilConfig.py +13 -4
  171. toil/utils/toilDebugFile.py +44 -27
  172. toil/utils/toilDebugJob.py +214 -27
  173. toil/utils/toilDestroyCluster.py +11 -6
  174. toil/utils/toilKill.py +8 -3
  175. toil/utils/toilLaunchCluster.py +256 -140
  176. toil/utils/toilMain.py +37 -16
  177. toil/utils/toilRsyncCluster.py +32 -14
  178. toil/utils/toilSshCluster.py +49 -22
  179. toil/utils/toilStats.py +356 -273
  180. toil/utils/toilStatus.py +292 -139
  181. toil/utils/toilUpdateEC2Instances.py +3 -1
  182. toil/version.py +12 -12
  183. toil/wdl/utils.py +5 -5
  184. toil/wdl/wdltoil.py +3913 -1033
  185. toil/worker.py +367 -184
  186. {toil-6.1.0a1.dist-info → toil-8.0.0.dist-info}/LICENSE +25 -0
  187. toil-8.0.0.dist-info/METADATA +173 -0
  188. toil-8.0.0.dist-info/RECORD +253 -0
  189. {toil-6.1.0a1.dist-info → toil-8.0.0.dist-info}/WHEEL +1 -1
  190. toil-6.1.0a1.dist-info/METADATA +0 -125
  191. toil-6.1.0a1.dist-info/RECORD +0 -237
  192. {toil-6.1.0a1.dist-info → toil-8.0.0.dist-info}/entry_points.txt +0 -0
  193. {toil-6.1.0a1.dist-info → toil-8.0.0.dist-info}/top_level.txt +0 -0
toil/provisioners/aws/awsProvisioner.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import json
 import logging
 import os
@@ -20,75 +22,93 @@ import string
 import textwrap
 import time
 import uuid
+from collections.abc import Collection, Iterable
 from functools import wraps
 from shlex import quote
-from typing import (Any,
-                    Callable,
-                    Collection,
-                    Dict,
-                    Iterable,
-                    List,
-                    Optional,
-                    Set)
-from urllib.parse import unquote
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
 
 # We need these to exist as attributes we can get off of the boto object
-import boto.ec2
-import boto.iam
-import boto.vpc
-from boto.ec2.blockdevicemapping import \
-    BlockDeviceMapping as Boto2BlockDeviceMapping
-from boto.ec2.blockdevicemapping import BlockDeviceType as Boto2BlockDeviceType
-from boto.ec2.instance import Instance as Boto2Instance
-from boto.exception import BotoServerError, EC2ResponseError
-from boto.utils import get_instance_metadata
 from botocore.exceptions import ClientError
 
-from toil.lib.aws import zone_to_region
+from toil.lib.aws import AWSRegionName, AWSServerErrors, zone_to_region
 from toil.lib.aws.ami import get_flatcar_ami
-from toil.lib.aws.iam import (CLUSTER_LAUNCHING_PERMISSIONS,
-                              get_policy_permissions,
-                              policy_permissions_allow)
+from toil.lib.aws.iam import (
+    CLUSTER_LAUNCHING_PERMISSIONS,
+    create_iam_role,
+    get_policy_permissions,
+    policy_permissions_allow,
+)
 from toil.lib.aws.session import AWSConnectionManager
-from toil.lib.aws.utils import create_s3_bucket
+from toil.lib.aws.session import client as get_client
+from toil.lib.aws.utils import boto3_pager, create_s3_bucket
 from toil.lib.conversions import human2bytes
-from toil.lib.ec2 import (a_short_time,
-                          create_auto_scaling_group,
-                          create_instances,
-                          create_launch_template,
-                          create_ondemand_instances,
-                          create_spot_instances,
-                          wait_instances_running,
-                          wait_transition,
-                          wait_until_instance_profile_arn_exists)
+from toil.lib.ec2 import (
+    a_short_time,
+    create_auto_scaling_group,
+    create_instances,
+    create_launch_template,
+    create_ondemand_instances,
+    create_spot_instances,
+    increase_instance_hop_limit,
+    wait_instances_running,
+    wait_transition,
+    wait_until_instance_profile_arn_exists,
+)
 from toil.lib.ec2nodes import InstanceType
 from toil.lib.generatedEC2Lists import E2Instances
 from toil.lib.memoize import memoize
 from toil.lib.misc import truncExpBackoff
-from toil.lib.retry import (ErrorCondition,
-                            get_error_body,
-                            get_error_code,
-                            get_error_message,
-                            get_error_status,
-                            old_retry,
-                            retry)
-from toil.provisioners import (ClusterCombinationNotSupportedException,
-                               NoSuchClusterException)
-from toil.provisioners.abstractProvisioner import (AbstractProvisioner,
-                                                   ManagedNodesNotSupportedException,
-                                                   Shape)
+from toil.lib.retry import (
+    get_error_body,
+    get_error_code,
+    get_error_message,
+    get_error_status,
+    old_retry,
+    retry,
+)
+from toil.provisioners import (
+    ClusterCombinationNotSupportedException,
+    NoSuchClusterException,
+    NoSuchZoneException,
+)
+from toil.provisioners.abstractProvisioner import (
+    AbstractProvisioner,
+    ManagedNodesNotSupportedException,
+    Shape,
+)
 from toil.provisioners.aws import get_best_aws_zone
 from toil.provisioners.node import Node
 
+if TYPE_CHECKING:
+    from mypy_boto3_autoscaling.client import AutoScalingClient
+    from mypy_boto3_ec2.client import EC2Client
+    from mypy_boto3_ec2.service_resource import Instance
+    from mypy_boto3_ec2.type_defs import (
+        BlockDeviceMappingTypeDef,
+        CreateSecurityGroupResultTypeDef,
+        DescribeInstancesResultTypeDef,
+        EbsBlockDeviceTypeDef,
+        FilterTypeDef,
+        InstanceTypeDef,
+        IpPermissionTypeDef,
+        ReservationTypeDef,
+        SecurityGroupTypeDef,
+        SpotInstanceRequestTypeDef,
+        TagDescriptionTypeDef,
+        TagTypeDef,
+    )
+    from mypy_boto3_iam.client import IAMClient
+    from mypy_boto3_iam.type_defs import InstanceProfileTypeDef, RoleTypeDef
+
 logger = logging.getLogger(__name__)
 logging.getLogger("boto").setLevel(logging.CRITICAL)
 # Role name (used as the suffix) for EC2 instance profiles that are automatically created by Toil.
-_INSTANCE_PROFILE_ROLE_NAME = 'toil'
+_INSTANCE_PROFILE_ROLE_NAME = "toil"
 # The tag key that specifies the Toil node type ("leader" or "worker") so that
 # leader vs. worker nodes can be robustly identified.
-_TAG_KEY_TOIL_NODE_TYPE = 'ToilNodeType'
+_TAG_KEY_TOIL_NODE_TYPE = "ToilNodeType"
 # The tag that specifies the cluster name on all nodes
-_TAG_KEY_TOIL_CLUSTER_NAME = 'clusterName'
+_TAG_KEY_TOIL_CLUSTER_NAME = "clusterName"
 # How much storage on the root volume is expected to go to overhead and be
 # unavailable to jobs when the node comes up?
 # TODO: measure
@@ -96,29 +116,23 @@ _STORAGE_ROOT_OVERHEAD_GIGS = 4
 # The maximum length of a S3 bucket
 _S3_BUCKET_MAX_NAME_LEN = 63
 # The suffix of the S3 bucket associated with the cluster
-_S3_BUCKET_INTERNAL_SUFFIX = '--internal'
-
-# prevent removal of these imports
-str(boto.ec2)
-str(boto.iam)
-str(boto.vpc)
-
+_S3_BUCKET_INTERNAL_SUFFIX = "--internal"
 
 
-def awsRetryPredicate(e):
+def awsRetryPredicate(e: Exception) -> bool:
     if isinstance(e, socket.gaierror):
         # Could be a DNS outage:
         # socket.gaierror: [Errno -2] Name or service not known
         return True
     # boto/AWS gives multiple messages for the same error...
-    if get_error_status(e) == 503 and 'Request limit exceeded' in get_error_body(e):
+    if get_error_status(e) == 503 and "Request limit exceeded" in get_error_body(e):
         return True
-    elif get_error_status(e) == 400 and 'Rate exceeded' in get_error_body(e):
+    elif get_error_status(e) == 400 and "Rate exceeded" in get_error_body(e):
         return True
-    elif get_error_status(e) == 400 and 'NotFound' in get_error_body(e):
+    elif get_error_status(e) == 400 and "NotFound" in get_error_body(e):
         # EC2 can take a while to propagate instance IDs to all servers.
         return True
-    elif get_error_status(e) == 400 and get_error_code(e) == 'Throttling':
+    elif get_error_status(e) == 400 and get_error_code(e) == "Throttling":
         return True
     return False
 
@@ -132,47 +146,89 @@ def expectedShutdownErrors(e: Exception) -> bool:
     impossible or unnecessary (such as errors resulting from a thing not
     existing to be deleted).
     """
-    return get_error_status(e) == 400 and 'dependent object' in get_error_body(e)
+    return get_error_status(e) == 400 and "dependent object" in get_error_body(e)
 
 
-def awsRetry(f):
+F = TypeVar("F")  # so mypy understands passed through types
+
+
+def awsRetry(f: Callable[..., F]) -> Callable[..., F]:
     """
-    This decorator retries the wrapped function if aws throws unexpected errors
-    errors.
+    This decorator retries the wrapped function if aws throws unexpected errors.
+
     It should wrap any function that makes use of boto
     """
+
     @wraps(f)
-    def wrapper(*args, **kwargs):
-        for attempt in old_retry(delays=truncExpBackoff(),
-                                 timeout=300,
-                                 predicate=awsRetryPredicate):
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        for attempt in old_retry(
+            delays=truncExpBackoff(), timeout=300, predicate=awsRetryPredicate
+        ):
            with attempt:
                return f(*args, **kwargs)
+
    return wrapper
 
 
-def awsFilterImpairedNodes(nodes, ec2):
+def awsFilterImpairedNodes(
+    nodes: list[InstanceTypeDef], boto3_ec2: EC2Client
+) -> list[InstanceTypeDef]:
     # if TOIL_AWS_NODE_DEBUG is set don't terminate nodes with
     # failing status checks so they can be debugged
-    nodeDebug = os.environ.get('TOIL_AWS_NODE_DEBUG') in ('True', 'TRUE', 'true', True)
+    nodeDebug = os.environ.get("TOIL_AWS_NODE_DEBUG") in ("True", "TRUE", "true", True)
     if not nodeDebug:
         return nodes
-    nodeIDs = [node.id for node in nodes]
-    statuses = ec2.get_all_instance_status(instance_ids=nodeIDs)
-    statusMap = {status.id: status.instance_status for status in statuses}
-    healthyNodes = [node for node in nodes if statusMap.get(node.id, None) != 'impaired']
-    impairedNodes = [node.id for node in nodes if statusMap.get(node.id, None) == 'impaired']
-    logger.warning('TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so '
-                   'will not be terminated.', ' '.join(impairedNodes))
+    nodeIDs = [node["InstanceId"] for node in nodes]
+    statuses = boto3_ec2.describe_instance_status(InstanceIds=nodeIDs)
+    statusMap = {
+        status["InstanceId"]: status["InstanceStatus"]["Status"]
+        for status in statuses["InstanceStatuses"]
+    }
+    healthyNodes = [
+        node for node in nodes if statusMap.get(node["InstanceId"], None) != "impaired"
+    ]
+    impairedNodes = [
+        node["InstanceId"]
+        for node in nodes
+        if statusMap.get(node["InstanceId"], None) == "impaired"
+    ]
+    logger.warning(
+        "TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so "
+        "will not be terminated.",
+        " ".join(impairedNodes),
+    )
    return healthyNodes
 
 
 class InvalidClusterStateException(Exception):
     pass
 
+
+def collapse_tags(instance_tags: list[TagTypeDef]) -> dict[str, str]:
+    """
+    Collapse tags from boto3 format to node format
+    :param instance_tags: tags as a list
+    :return: Dict of tags
+    """
+    collapsed_tags: dict[str, str] = dict()
+    for tag in instance_tags:
+        if tag.get("Key") is not None:
+            collapsed_tags[tag["Key"]] = tag["Value"]
+    return collapsed_tags
+
+
 class AWSProvisioner(AbstractProvisioner):
-    def __init__(self, clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides, sseKey):
-        self.cloud = 'aws'
+    def __init__(
+        self,
+        clusterName: str | None,
+        clusterType: str | None,
+        zone: str | None,
+        nodeStorage: int,
+        nodeStorageOverrides: list[str] | None,
+        sseKey: str | None,
+        enable_fuse: bool,
+    ):
+        self.cloud = "aws"
         self._sseKey = sseKey
         # self._zone will be filled in by base class constructor
         # We will use it as the leader zone.
@@ -180,33 +236,54 @@ class AWSProvisioner(AbstractProvisioner):
 
         if zone is None:
             # Can't proceed without a real zone
-            raise RuntimeError('No AWS availability zone specified. Configure in Boto '
-                               'configuration file, TOIL_AWS_ZONE environment variable, or '
-                               'on the command line.')
+            raise RuntimeError(
+                "No AWS availability zone specified. Configure in Boto "
+                "configuration file, TOIL_AWS_ZONE environment variable, or "
+                "on the command line."
+            )
 
         # Determine our region to work in, before readClusterSettings() which
         # might need it. TODO: support multiple regions in one cluster
-        self._region = zone_to_region(zone)
+        self._region: AWSRegionName = zone_to_region(zone)
 
         # Set up our connections to AWS
         self.aws = AWSConnectionManager()
 
         # Set our architecture to the current machine architecture
         # Assume the same architecture unless specified differently in launchCluster()
-        self._architecture = 'amd64' if platform.machine() in ['x86_64', 'amd64'] else 'arm64'
+        self._architecture = (
+            "amd64" if platform.machine() in ["x86_64", "amd64"] else "arm64"
+        )
 
         # Call base class constructor, which will call createClusterSettings()
         # or readClusterSettings()
-        super().__init__(clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides)
+        super().__init__(
+            clusterName,
+            clusterType,
+            zone,
+            nodeStorage,
+            nodeStorageOverrides,
+            enable_fuse,
+        )
+
+        if self._zone is None:
+            logger.warning(
+                "Leader zone was never initialized before creating AWS provisioner. Defaulting to cluster zone."
+            )
+
+        self._leader_subnet: str = self._get_default_subnet(self._zone or zone)
+        self._tags: dict[str, Any] = {}
 
         # After self.clusterName is set, generate a valid name for the S3 bucket associated with this cluster
         suffix = _S3_BUCKET_INTERNAL_SUFFIX
-        self.s3_bucket_name = self.clusterName[:_S3_BUCKET_MAX_NAME_LEN - len(suffix)] + suffix
+        self.s3_bucket_name = (
+            self.clusterName[: _S3_BUCKET_MAX_NAME_LEN - len(suffix)] + suffix
+        )
 
-    def supportedClusterTypes(self):
-        return {'mesos', 'kubernetes'}
+    def supportedClusterTypes(self) -> set[str]:
+        return {"mesos", "kubernetes"}
 
-    def createClusterSettings(self):
+    def createClusterSettings(self) -> None:
         """
         Create a new set of cluster settings for a cluster to be deployed into
         AWS.
@@ -216,47 +293,64 @@ class AWSProvisioner(AbstractProvisioner):
         # constructor.
         assert self._zone is not None
 
-    def readClusterSettings(self):
+    def readClusterSettings(self) -> None:
         """
         Reads the cluster settings from the instance metadata, which assumes
         the instance is the leader.
         """
-        instanceMetaData = get_instance_metadata()
-        ec2 = self.aws.boto2(self._region, 'ec2')
-        instance = ec2.get_all_instances(instance_ids=[instanceMetaData["instance-id"]])[0].instances[0]
+        from ec2_metadata import ec2_metadata
+
+        boto3_ec2 = self.aws.client(self._region, "ec2")
+        instance: InstanceTypeDef = boto3_ec2.describe_instances(
+            InstanceIds=[ec2_metadata.instance_id]
+        )["Reservations"][0]["Instances"][0]
         # The cluster name is the same as the name of the leader.
-        self.clusterName = str(instance.tags["Name"])
+        self.clusterName: str = "default-toil-cluster-name"
+        for tag in instance["Tags"]:
+            if tag.get("Key") == "Name":
+                self.clusterName = tag["Value"]
         # Determine what subnet we, the leader, are in
-        self._leader_subnet = instance.subnet_id
+        self._leader_subnet = instance["SubnetId"]
         # Determine where to deploy workers.
         self._worker_subnets_by_zone = self._get_good_subnets_like(self._leader_subnet)
 
-        self._leaderPrivateIP = instanceMetaData['local-ipv4'] # this is PRIVATE IP
-        self._keyName = list(instanceMetaData['public-keys'].keys())[0]
-        self._tags = {k: v for k, v in self.getLeader().tags.items() if k != _TAG_KEY_TOIL_NODE_TYPE}
+        self._leaderPrivateIP = ec2_metadata.private_ipv4  # this is PRIVATE IP
+        self._tags = {
+            k: v
+            for k, v in (self.getLeader().tags or {}).items()
+            if k != _TAG_KEY_TOIL_NODE_TYPE
+        }
         # Grab the ARN name of the instance profile (a str) to apply to workers
-        self._leaderProfileArn = instanceMetaData['iam']['info']['InstanceProfileArn']
+        leader_info = None
+        for attempt in old_retry(timeout=300, predicate=lambda e: True):
+            with attempt:
+                leader_info = ec2_metadata.iam_info
+                if leader_info is None:
+                    raise RuntimeError("Could not get EC2 metadata IAM info")
+        if leader_info is None:
+            # This is more for mypy as it is unable to see that the retry will guarantee this is not None
+            # and that this is not reachable
+            raise RuntimeError(f"Leader IAM metadata is unreachable.")
+        self._leaderProfileArn = leader_info["InstanceProfileArn"]
+
         # The existing metadata API returns a single string if there is one security group, but
         # a list when there are multiple: change the format to always be a list.
-        rawSecurityGroups = instanceMetaData['security-groups']
-        self._leaderSecurityGroupNames = {rawSecurityGroups} if not isinstance(rawSecurityGroups, list) else set(rawSecurityGroups)
+        rawSecurityGroups = ec2_metadata.security_groups
+        self._leaderSecurityGroupNames: set[str] = set(rawSecurityGroups)
         # Since we have access to the names, we don't also need to use any IDs
-        self._leaderSecurityGroupIDs = set()
+        self._leaderSecurityGroupIDs: set[str] = set()
 
         # Let the base provisioner work out how to deploy duly authorized
         # workers for this leader.
         self._setLeaderWorkerAuthentication()
 
-    @retry(errors=[ErrorCondition(
-        error=ClientError,
-        error_codes=[404, 500, 502, 503, 504]
-    )])
+    @retry(errors=[AWSServerErrors])
     def _write_file_to_cloud(self, key: str, contents: bytes) -> str:
         bucket_name = self.s3_bucket_name
 
         # Connect to S3
-        s3 = self.aws.resource(self._region, 's3')
-        s3_client = self.aws.client(self._region, 's3')
+        s3 = self.aws.resource(self._region, "s3")
+        s3_client = self.aws.client(self._region, "s3")
 
         # create bucket if needed, then write file to S3
         try:
@@ -265,14 +359,18 @@ class AWSProvisioner(AbstractProvisioner):
             bucket = s3.Bucket(bucket_name)
         except ClientError as err:
             if get_error_status(err) == 404:
-                bucket = create_s3_bucket(s3, bucket_name=bucket_name, region=self._region)
+                bucket = create_s3_bucket(
+                    s3, bucket_name=bucket_name, region=self._region
+                )
                 bucket.wait_until_exists()
                 bucket.Versioning().enable()
 
-                owner_tag = os.environ.get('TOIL_OWNER_TAG')
+                owner_tag = os.environ.get("TOIL_OWNER_TAG")
                 if owner_tag:
                     bucket_tagging = s3.BucketTagging(bucket_name)
-                    bucket_tagging.put(Tagging={'TagSet': [{'Key': 'Owner', 'Value': owner_tag}]})
+                    bucket_tagging.put(
+                        Tagging={"TagSet": [{"Key": "Owner", "Value": owner_tag}]}
+                    )
             else:
                 raise
 
@@ -282,34 +380,38 @@ class AWSProvisioner(AbstractProvisioner):
         obj.put(Body=contents)
 
         obj.wait_until_exists()
-        return f's3://{bucket_name}/{key}'
+        return f"s3://{bucket_name}/{key}"
 
     def _read_file_from_cloud(self, key: str) -> bytes:
         bucket_name = self.s3_bucket_name
-        obj = self.aws.resource(self._region, 's3').Object(bucket_name, key)
+        obj = self.aws.resource(self._region, "s3").Object(bucket_name, key)
 
         try:
-            return obj.get().get('Body').read()
+            return obj.get()["Body"].read()
         except ClientError as e:
             if get_error_status(e) == 404:
-                logger.warning(f'Trying to read non-existent file "{key}" from {bucket_name}.')
+                logger.warning(
+                    f'Trying to read non-existent file "{key}" from {bucket_name}.'
+                )
             raise
 
     def _get_user_data_limit(self) -> int:
         # See: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-add-user-data.html
-        return human2bytes('16KB')
-
-    def launchCluster(self,
-                      leaderNodeType: str,
-                      leaderStorage: int,
-                      owner: str,
-                      keyName: str,
-                      botoPath: str,
-                      userTags: Optional[dict],
-                      vpcSubnet: Optional[str],
-                      awsEc2ProfileArn: Optional[str],
-                      awsEc2ExtraSecurityGroupIds: Optional[list],
-                      **kwargs):
+        return human2bytes("16KB")
+
+    def launchCluster(
+        self,
+        leaderNodeType: str,
+        leaderStorage: int,
+        owner: str,
+        keyName: str,
+        botoPath: str,
+        userTags: dict[str, str] | None,
+        vpcSubnet: str | None,
+        awsEc2ProfileArn: str | None,
+        awsEc2ExtraSecurityGroupIds: list[str] | None,
+        **kwargs: dict[str, Any],
+    ) -> None:
         """
         Starts a single leader node and populates this class with the leader's metadata.
 
@@ -325,36 +427,46 @@ class AWSProvisioner(AbstractProvisioner):
         :return: None
         """
 
-        if 'network' in kwargs:
-            logger.warning('AWS provisioner does not support a network parameter. Ignoring %s!', kwargs["network"])
+        if "network" in kwargs:
+            logger.warning(
+                "AWS provisioner does not support a network parameter. Ignoring %s!",
+                kwargs["network"],
+            )
 
         # First, pre-flight-check our permissions before making anything.
-        if not policy_permissions_allow(get_policy_permissions(region=self._region), CLUSTER_LAUNCHING_PERMISSIONS):
+        if not policy_permissions_allow(
+            get_policy_permissions(region=self._region), CLUSTER_LAUNCHING_PERMISSIONS
+        ):
             # Function prints a more specific warning to the log, but give some context.
-            logger.warning('Toil may not be able to properly launch (or destroy!) your cluster.')
+            logger.warning(
+                "Toil may not be able to properly launch (or destroy!) your cluster."
+            )
 
         leader_type = E2Instances[leaderNodeType]
 
-        if self.clusterType == 'kubernetes':
+        if self.clusterType == "kubernetes":
             if leader_type.cores < 2:
                 # Kubernetes won't run here.
-                raise RuntimeError('Kubernetes requires 2 or more cores, and %s is too small' %
-                                   leaderNodeType)
+                raise RuntimeError(
+                    "Kubernetes requires 2 or more cores, and %s is too small"
+                    % leaderNodeType
+                )
         self._keyName = keyName
         self._architecture = leader_type.architecture
 
-        if self.clusterType == 'mesos' and self._architecture != 'amd64':
+        if self.clusterType == "mesos" and self._architecture != "amd64":
             # Mesos images aren't currently available for this architecture, so we can't start a Mesos cluster.
             # Complain about this before we create anything.
-            raise ClusterCombinationNotSupportedException(type(self), self.clusterType, self._architecture,
-                                                          reason="Mesos is only available for amd64.")
+            raise ClusterCombinationNotSupportedException(
+                type(self),
+                self.clusterType,
+                self._architecture,
+                reason="Mesos is only available for amd64.",
+            )
 
         if vpcSubnet:
             # This is where we put the leader
             self._leader_subnet = vpcSubnet
-        else:
-            # Find the default subnet for the zone
-            self._leader_subnet = self._get_default_subnet(self._zone)
 
         profileArn = awsEc2ProfileArn or self._createProfileArn()
 
@@ -363,45 +475,49 @@ class AWSProvisioner(AbstractProvisioner):
         bdms = self._getBoto3BlockDeviceMappings(leader_type, rootVolSize=leaderStorage)
 
         # Make up the tags
-        self._tags = {'Name': self.clusterName,
-                      'Owner': owner,
-                      _TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}
+        self._tags = {
+            "Name": self.clusterName,
+            "Owner": owner,
+            _TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName,
+        }
 
         if userTags is not None:
             self._tags.update(userTags)
 
-        #All user specified tags have been set
-        userData = self._getIgnitionUserData('leader', architecture=self._architecture)
+        # All user specified tags have been set
+        userData = self._getIgnitionUserData("leader", architecture=self._architecture)
 
-        if self.clusterType == 'kubernetes':
+        if self.clusterType == "kubernetes":
             # All nodes need a tag putting them in the cluster.
             # This tag needs to be on there before the a leader can finish its startup.
-            self._tags['kubernetes.io/cluster/' + self.clusterName] = ''
+            self._tags["kubernetes.io/cluster/" + self.clusterName] = ""
 
         # Make tags for the leader specifically
         leader_tags = dict(self._tags)
-        leader_tags[_TAG_KEY_TOIL_NODE_TYPE] = 'leader'
-        logger.debug('Launching leader with tags: %s', leader_tags)
-
-        instances = create_instances(self.aws.resource(self._region, 'ec2'),
-                                     image_id=self._discoverAMI(),
-                                     num_instances=1,
-                                     key_name=self._keyName,
-                                     security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []),
-                                     instance_type=leader_type.name,
-                                     user_data=userData,
-                                     block_device_map=bdms,
-                                     instance_profile_arn=profileArn,
-                                     placement_az=self._zone,
-                                     subnet_id=self._leader_subnet,
-                                     tags=leader_tags)
+        leader_tags[_TAG_KEY_TOIL_NODE_TYPE] = "leader"
+        logger.debug("Launching leader with tags: %s", leader_tags)
+
+        instances: list[Instance] = create_instances(
+            self.aws.resource(self._region, "ec2"),
+            image_id=self._discoverAMI(),
+            num_instances=1,
+            key_name=self._keyName,
+            security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []),
+            instance_type=leader_type.name,
+            user_data=userData,
+            block_device_map=bdms,
+            instance_profile_arn=profileArn,
+            placement_az=self._zone,
+            subnet_id=self._leader_subnet,
+            tags=leader_tags,
+        )
 
         # wait for the leader to exist at all
         leader = instances[0]
         leader.wait_until_exists()
 
         # Don't go on until the leader is started
-        logger.info('Waiting for leader instance %s to be running', leader)
+        logger.info("Waiting for leader instance %s to be running", leader)
         leader.wait_until_running()
 
         # Now reload it to make sure all the IPs are set.
@@ -411,22 +527,31 @@ class AWSProvisioner(AbstractProvisioner):
             # Sometimes AWS just fails to assign a public IP when we really need one.
             # But sometimes people intend to use private IPs only in Toil-managed clusters.
             # TODO: detect if we have a route to the private IP and fail fast if not.
-            logger.warning("AWS did not assign a public IP to the cluster leader. If you aren't "
-                           "connected to the private subnet, cluster setup will fail!")
+            logger.warning(
+                "AWS did not assign a public IP to the cluster leader. If you aren't "
+                "connected to the private subnet, cluster setup will fail!"
+            )
 
         # Remember enough about the leader to let us launch workers in its
         # cluster.
         self._leaderPrivateIP = leader.private_ip_address
         self._worker_subnets_by_zone = self._get_good_subnets_like(self._leader_subnet)
         self._leaderSecurityGroupNames = set()
-        self._leaderSecurityGroupIDs = set(createdSGs + (awsEc2ExtraSecurityGroupIds or []))
+        self._leaderSecurityGroupIDs = set(
+            createdSGs + (awsEc2ExtraSecurityGroupIds or [])
+        )
         self._leaderProfileArn = profileArn
 
-        leaderNode = Node(publicIP=leader.public_ip_address, privateIP=leader.private_ip_address,
-                          name=leader.id, launchTime=leader.launch_time,
-                          nodeType=leader_type.name, preemptible=False,
-                          tags=leader.tags)
-        leaderNode.waitForNode('toil_leader')
+        leaderNode = Node(
+            publicIP=leader.public_ip_address,
+            privateIP=leader.private_ip_address,
+            name=leader.id,
+            launchTime=leader.launch_time,
+            nodeType=leader_type.name,
+            preemptible=False,
+            tags=collapse_tags(leader.tags),
+        )
+        leaderNode.waitForNode("toil_leader")
 
         # Download credentials
         self._setLeaderWorkerAuthentication(leaderNode)
@@ -439,7 +564,7 @@ class AWSProvisioner(AbstractProvisioner):
         instance_base_tags = json.dumps(self._tags)
         return config + " -e TOIL_AWS_TAGS=" + quote(instance_base_tags)
 
-    def _get_worker_subnets(self) -> List[str]:
+    def _get_worker_subnets(self) -> list[str]:
         """
         Get all worker subnets we should balance across, as a flat list.
         """
@@ -455,7 +580,7 @@ class AWSProvisioner(AbstractProvisioner):
         return collected
 
     @awsRetry
-    def _get_good_subnets_like(self, base_subnet_id: str) -> Dict[str, List[str]]:
+    def _get_good_subnets_like(self, base_subnet_id: str) -> dict[str, list[str]]:
         """
         Given a subnet ID, get all the similar subnets (including it),
         organized by availability zone.
@@ -466,9 +591,9 @@ class AWSProvisioner(AbstractProvisioner):
         """
 
         # Grab the ec2 resource we need to make queries
-        ec2 = self.aws.resource(self._region, 'ec2')
+        ec2 = self.aws.resource(self._region, "ec2")
         # And the client
-        ec2_client = self.aws.client(self._region, 'ec2')
+        ec2_client = self.aws.client(self._region, "ec2")
 
         # What subnet are we basing this on?
         base_subnet = ec2.Subnet(base_subnet_id)
@@ -483,31 +608,33 @@ class AWSProvisioner(AbstractProvisioner):
         acls = set(self._get_subnet_acls(base_subnet_id))
 
         # Compose a filter that selects the subnets we might want
-        filters = [{
-            'Name': 'vpc-id',
-            'Values': [vpc_id]
-        }, {
-            'Name': 'default-for-az',
-            'Values': ['true' if is_default else 'false']
-        }, {
-            'Name': 'state',
-            'Values': ['available']
-        }]
+        filters: list[FilterTypeDef] = [
+            {"Name": "vpc-id", "Values": [vpc_id]},
+            {"Name": "default-for-az", "Values": ["true" if is_default else "false"]},
+            {"Name": "state", "Values": ["available"]},
+        ]
 
         # Fill in this collection
-        by_az = {}
+        by_az: dict[str, list[str]] = {}
 
         # Go get all the subnets. There's no way to page manually here so it
         # must page automatically.
-        for subnet in self.aws.resource(self._region, 'ec2').subnets.filter(Filters=filters):
+        for subnet in self.aws.resource(self._region, "ec2").subnets.filter(
+            Filters=filters
+        ):
             # For each subnet in the VPC
 
             # See if it has the right ACLs
             subnet_acls = set(self._get_subnet_acls(subnet.subnet_id))
             if subnet_acls != acls:
                 # Reject this subnet because it has different ACLs
-                logger.debug('Subnet %s is a lot like subnet %s but has ACLs of %s instead of %s; skipping',
-                             subnet.subnet_id, base_subnet_id, subnet_acls, acls)
+                logger.debug(
+                    "Subnet %s is a lot like subnet %s but has ACLs of %s instead of %s; skipping",
+                    subnet.subnet_id,
+                    base_subnet_id,
+                    subnet_acls,
+                    acls,
+                )
                 continue
 
             if subnet.availability_zone not in by_az:
@@ -519,24 +646,24 @@ class AWSProvisioner(AbstractProvisioner):
         return by_az
 
     @awsRetry
-    def _get_subnet_acls(self, subnet: str) -> List[str]:
+    def _get_subnet_acls(self, subnet: str) -> list[str]:
         """
         Get all Network ACL IDs associated with a given subnet ID.
         """
 
         # Grab the connection we need to use for this operation.
-        ec2 = self.aws.client(self._region, 'ec2')
+        ec2 = self.aws.client(self._region, "ec2")
 
         # Compose a filter that selects the default subnet in the AZ
-        filters = [{
-            'Name': 'association.subnet-id',
-            'Values': [subnet]
-        }]
+        filters = [{"Name": "association.subnet-id", "Values": [subnet]}]
 
         # TODO: Can't we use the resource's network_acls.filter(Filters=)?
-        return [item['NetworkAclId'] for item in self._pager(ec2.describe_network_acls,
-                                                             'NetworkAcls',
-                                                             Filters=filters)]
+        return [
+            item["NetworkAclId"]
+            for item in boto3_pager(
+                ec2.describe_network_acls, "NetworkAcls", Filters=filters
+            )
+        ]
 
     @awsRetry
     def _get_default_subnet(self, zone: str) -> str:
@@ -546,30 +673,32 @@ class AWSProvisioner(AbstractProvisioner):
         """
 
         # Compose a filter that selects the default subnet in the AZ
-        filters = [{
-            'Name': 'default-for-az',
-            'Values': ['true']
-        }, {
-            'Name': 'availability-zone',
-            'Values': [zone]
-        }]
-
-        for subnet in self.aws.resource(zone_to_region(zone), 'ec2').subnets.filter(Filters=filters):
+        filters: list[FilterTypeDef] = [
+            {"Name": "default-for-az", "Values": ["true"]},
+            {"Name": "availability-zone", "Values": [zone]},
+        ]
+
+        for subnet in self.aws.resource(zone_to_region(zone), "ec2").subnets.filter(
+            Filters=filters
+        ):
            # There should only be one result, so when we see it, return it
            return subnet.subnet_id
         # If we don't find a subnet, something is wrong. Maybe this zone was
         # added after your account?
-        raise RuntimeError(f"No default subnet found in availability zone {zone}. "
-                           f"Note that Amazon does not add default subnets for new "
-                           f"zones to old accounts. Specify a VPC subnet ID to use, "
-                           f"or create a default subnet in the zone.")
+        raise RuntimeError(
+            f"No default subnet found in availability zone {zone}. "
+            f"Note that Amazon does not add default subnets for new "
+            f"zones to old accounts. Specify a VPC subnet ID to use, "
+            f"or create a default subnet in the zone."
+        )
 
-    def getKubernetesAutoscalerSetupCommands(self, values: Dict[str, str]) -> str:
+    def getKubernetesAutoscalerSetupCommands(self, values: dict[str, str]) -> str:
         """
         Get the Bash commands necessary to configure the Kubernetes Cluster Autoscaler for AWS.
         """
 
-        return textwrap.dedent('''\
+        return textwrap.dedent(
+            """\
            curl -sSL https://raw.githubusercontent.com/kubernetes/autoscaler/cluster-autoscaler-{AUTOSCALER_VERSION}/cluster-autoscaler/cloudprovider/aws/examples/cluster-autoscaler-run-on-master.yaml | \\
            sed "s|--nodes={{{{ node_asg_min }}}}:{{{{ node_asg_max }}}}:{{{{ name }}}}|--node-group-auto-discovery=asg:tag=k8s.io/cluster-autoscaler/enabled,k8s.io/cluster-autoscaler/{CLUSTER_NAME}|" | \\
            sed 's|kubernetes.io/role: master|node-role.kubernetes.io/master: ""|' | \\
@@ -577,39 +706,44 @@ class AWSProvisioner(AbstractProvisioner):
            sed '/value: "true"/d' | \\
            sed 's|path: "/etc/ssl/certs/ca-bundle.crt"|path: "/usr/share/ca-certificates/ca-certificates.crt"|' | \\
            kubectl apply -f -
-            ''').format(**values)
+            """
+        ).format(**values)
 
-    def getKubernetesCloudProvider(self) -> Optional[str]:
+    def getKubernetesCloudProvider(self) -> str | None:
         """
         Use the "aws" Kubernetes cloud provider when setting up Kubernetes.
         """
 
-        return 'aws'
+        return "aws"
 
-    def getNodeShape(self, instance_type: str, preemptible=False) -> Shape:
+    def getNodeShape(self, instance_type: str, preemptible: bool = False) -> Shape:
         """
         Get the Shape for the given instance type (e.g. 't2.medium').
         """
         type_info = E2Instances[instance_type]
 
-        disk = type_info.disks * type_info.disk_capacity * 2 ** 30
+        disk = type_info.disks * type_info.disk_capacity * 2**30
         if disk == 0:
             # This is an EBS-backed instance. We will use the root
             # volume, so add the amount of EBS storage requested for
             # the root volume
-            disk = self._nodeStorageOverrides.get(instance_type, self._nodeStorage) * 2 ** 30
+            disk = (
+                self._nodeStorageOverrides.get(instance_type, self._nodeStorage) * 2**30
+            )
 
         # Underestimate memory by 100M to prevent autoscaler from disagreeing with
         # mesos about whether a job can run on a particular node type
-        memory = (type_info.memory - 0.1) * 2 ** 30
-        return Shape(wallTime=60 * 60,
-                     memory=memory,
-                     cores=type_info.cores,
-                     disk=disk,
-                     preemptible=preemptible)
+        memory = (type_info.memory - 0.1) * 2**30
+        return Shape(
+            wallTime=60 * 60,
+            memory=int(memory),
+            cores=type_info.cores,
+            disk=int(disk),
+            preemptible=preemptible,
+        )
 
     @staticmethod
-    def retryPredicate(e):
+    def retryPredicate(e: Exception) -> bool:
         return awsRetryPredicate(e)
 
     def destroyCluster(self) -> None:
@@ -619,16 +753,16 @@ class AWSProvisioner(AbstractProvisioner):
         # The leader may create more instances while we're terminating the workers.
         vpcId = None
         try:
-            leader = self._getLeaderInstance()
-            vpcId = leader.vpc_id
-            logger.info('Terminating the leader first ...')
+            leader = self._getLeaderInstanceBoto3()
+            vpcId = leader.get("VpcId")
+            logger.info("Terminating the leader first ...")
            self._terminateInstances([leader])
         except (NoSuchClusterException, InvalidClusterStateException):
            # It's ok if the leader is not found. We'll terminate any remaining
            # instances below anyway.
            pass
 
-        logger.debug('Deleting autoscaling groups ...')
+        logger.debug("Deleting autoscaling groups ...")
        removed = False
 
        for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
@@ -636,35 +770,47 @@ class AWSProvisioner(AbstractProvisioner):
            for asgName in self._getAutoScalingGroupNames():
                try:
                    # We delete the group and all the instances via ForceDelete.
-                    self.aws.client(self._region, 'autoscaling').delete_auto_scaling_group(AutoScalingGroupName=asgName, ForceDelete=True)
+                    self.aws.client(
+                        self._region, "autoscaling"
+                    ).delete_auto_scaling_group(
+                        AutoScalingGroupName=asgName, ForceDelete=True
+                    )
                    removed = True
                except ClientError as e:
-                    if get_error_code(e) == 'ValidationError' and 'AutoScalingGroup name not found' in get_error_message(e):
+                    if get_error_code(
+                        e
+                    ) == "ValidationError" and "AutoScalingGroup name not found" in get_error_message(
+                        e
+                    ):
                        # This ASG does not need to be removed (or a
                        # previous delete returned an error but also
                        # succeeded).
                        pass
 
        if removed:
-            logger.debug('... Successfully deleted autoscaling groups')
+            logger.debug("... Successfully deleted autoscaling groups")
 
        # Do the workers after the ASGs because some may belong to ASGs
-        logger.info('Terminating any remaining workers ...')
+        logger.info("Terminating any remaining workers ...")
        removed = False
-        instances = self._get_nodes_in_cluster(include_stopped_nodes=True)
+        instances = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True)
        spotIDs = self._getSpotRequestIDs()
+        boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2")
        if spotIDs:
-            self.aws.boto2(self._region, 'ec2').cancel_spot_instance_requests(request_ids=spotIDs)
+            boto3_ec2.cancel_spot_instance_requests(SpotInstanceRequestIds=spotIDs)
+            # self.aws.boto2(self._region, 'ec2').cancel_spot_instance_requests(request_ids=spotIDs)
            removed = True
-        instancesToTerminate = awsFilterImpairedNodes(instances, self.aws.boto2(self._region, 'ec2'))
+        instancesToTerminate = awsFilterImpairedNodes(
+            instances, self.aws.client(self._region, "ec2")
+        )
        if instancesToTerminate:
-            vpcId = vpcId or instancesToTerminate[0].vpc_id
+            vpcId = vpcId or instancesToTerminate[0].get("VpcId")
            self._terminateInstances(instancesToTerminate)
            removed = True
        if removed:
-            logger.debug('... Successfully terminated workers')
+            logger.debug("... Successfully terminated workers")
 
-        logger.info('Deleting launch templates ...')
+        logger.info("Deleting launch templates ...")
        removed = False
        for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
            with attempt:
@@ -672,8 +818,8 @@ class AWSProvisioner(AbstractProvisioner):
                # for some LuanchTemplate.
                mistake = False
                for ltID in self._get_launch_template_ids():
-                    response = self.aws.client(self._region, 'ec2').delete_launch_template(LaunchTemplateId=ltID)
-                    if 'LaunchTemplate' not in response:
+                    response = boto3_ec2.delete_launch_template(LaunchTemplateId=ltID)
+                    if "LaunchTemplate" not in response:
                        mistake = True
                    else:
                        removed = True
@@ -681,48 +827,59 @@ class AWSProvisioner(AbstractProvisioner):
                    # We missed something
                    removed = False
        if removed:
-            logger.debug('... Successfully deleted launch templates')
+            logger.debug("... Successfully deleted launch templates")
 
        if len(instances) == len(instancesToTerminate):
            # All nodes are gone now.
 
-            logger.info('Deleting IAM roles ...')
+            logger.info("Deleting IAM roles ...")
            self._deleteRoles(self._getRoleNames())
            self._deleteInstanceProfiles(self._getInstanceProfileNames())
 
-            logger.info('Deleting security group ...')
+            logger.info("Deleting security group ...")
            removed = False
            for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
                with attempt:
-                    for sg in self.aws.boto2(self._region, 'ec2').get_all_security_groups():
+                    security_groups: list[SecurityGroupTypeDef] = (
+                        boto3_ec2.describe_security_groups()["SecurityGroups"]
+                    )
+                    for security_group in security_groups:
                        # TODO: If we terminate the leader and the workers but
                        # miss the security group, we won't find it now because
                        # we won't have vpcId set.
-                        if sg.name == self.clusterName and vpcId and sg.vpc_id == vpcId:
+                        if (
+                            security_group.get("GroupName") == self.clusterName
+                            and vpcId
+                            and security_group.get("VpcId") == vpcId
+                        ):
                            try:
-                                self.aws.boto2(self._region, 'ec2').delete_security_group(group_id=sg.id)
+                                boto3_ec2.delete_security_group(
+                                    GroupId=security_group["GroupId"]
+                                )
                                removed = True
-                            except BotoServerError as e:
-                                if e.error_code == 'InvalidGroup.NotFound':
+                            except ClientError as e:
+                                if get_error_code(e) == "InvalidGroup.NotFound":
                                    pass
                                else:
                                    raise
            if removed:
-                logger.debug('... Successfully deleted security group')
+                logger.debug("... Successfully deleted security group")
        else:
            assert len(instances) > len(instancesToTerminate)
            # the security group can't be deleted until all nodes are terminated
-            logger.warning('The TOIL_AWS_NODE_DEBUG environment variable is set and some nodes '
-                           'have failed health checks. As a result, the security group & IAM '
-                           'roles will not be deleted.')
+            logger.warning(
+                "The TOIL_AWS_NODE_DEBUG environment variable is set and some nodes "
+                "have failed health checks. As a result, the security group & IAM "
+                "roles will not be deleted."
+            )
 
        # delete S3 buckets that might have been created by `self._write_file_to_cloud()`
-        logger.info('Deleting S3 buckets ...')
+        logger.info("Deleting S3 buckets ...")
        removed = False
        for attempt in old_retry(timeout=300, predicate=awsRetryPredicate):
            with attempt:
                # Grab the S3 resource to use
-                s3 = self.aws.resource(self._region, 's3')
+                s3 = self.aws.resource(self._region, "s3")
                try:
                    bucket = s3.Bucket(self.s3_bucket_name)
 
@@ -738,13 +895,15 @@ class AWSProvisioner(AbstractProvisioner):
                    else:
                        raise  # retry this
        if removed:
-            print('... Successfully deleted S3 buckets')
+            print("... Successfully deleted S3 buckets")
 
-    def terminateNodes(self, nodes: List[Node]) -> None:
+    def terminateNodes(self, nodes: list[Node]) -> None:
        if nodes:
            self._terminateIDs([x.name for x in nodes])
 
-    def _recover_node_type_bid(self, node_type: Set[str], spot_bid: Optional[float]) -> Optional[float]:
+    def _recover_node_type_bid(
+        self, node_type: set[str], spot_bid: float | None
+    ) -> float | None:
        """
        The old Toil-managed autoscaler will tell us to make some nodes of
        particular instance types, and to just work out a bid, but it doesn't
@@ -771,16 +930,23 @@ class AWSProvisioner(AbstractProvisioner):
                        break
                if spot_bid is None:
                    # We didn't bid on any class including this type either
-                    raise RuntimeError("No spot bid given for a preemptible node request.")
+                    raise RuntimeError(
+                        "No spot bid given for a preemptible node request."
+                    )
            else:
                raise RuntimeError("No spot bid given for a preemptible node request.")
 
        return spot_bid
 
-    def addNodes(self, nodeTypes: Set[str], numNodes, preemptible, spotBid=None) -> int:
+    def addNodes(
+        self,
+        nodeTypes: set[str],
+        numNodes: int,
+        preemptible: bool,
+        spotBid: float | None = None,
+    ) -> int:
        # Grab the AWS connection we need
-        ec2 = self.aws.boto2(self._region, 'ec2')
-
+        boto3_ec2 = get_client(service_name="ec2", region_name=self._region)
        assert self._leaderPrivateIP
 
        if preemptible:
@@ -792,8 +958,7 @@ class AWSProvisioner(AbstractProvisioner):
        node_type = next(iter(nodeTypes))
        type_info = E2Instances[node_type]
        root_vol_size = self._nodeStorageOverrides.get(node_type, self._nodeStorage)
-        bdm = self._getBoto2BlockDeviceMapping(type_info,
-                                               rootVolSize=root_vol_size)
+        bdm = self._getBoto3BlockDeviceMapping(type_info, rootVolSize=root_vol_size)
 
        # Pick a zone and subnet_id to launch into
        if preemptible:
@@ -803,7 +968,7 @@ class AWSProvisioner(AbstractProvisioner):
            # We're allowed to pick from any of these zones.
            zone_options = list(self._worker_subnets_by_zone.keys())
 
-            zone = get_best_aws_zone(spotBid, type_info.name, ec2, zone_options)
+            zone = get_best_aws_zone(spotBid, type_info.name, boto3_ec2, zone_options)
        else:
            # We don't need to ever do any balancing across zones for on-demand
            # instances. Just pick a zone.
@@ -814,6 +979,11 @@ class AWSProvisioner(AbstractProvisioner):
                # The workers aren't allowed in the leader's zone.
                # Pick an arbitrary zone we can use.
                zone = next(iter(self._worker_subnets_by_zone.keys()))
+        if zone is None:
+            logger.exception(
+                "Could not find a valid zone. Make sure TOIL_AWS_ZONE is set or spot bids are not too low."
+            )
+            raise NoSuchZoneException()
        if self._leader_subnet in self._worker_subnets_by_zone.get(zone, []):
            # The leader's subnet is an option for this zone, so use it.
            subnet_id = self._leader_subnet
@@ -822,21 +992,38 @@ class AWSProvisioner(AbstractProvisioner):
            subnet_id = next(iter(self._worker_subnets_by_zone[zone]))
 
        keyPath = self._sseKey if self._sseKey else None
-        userData = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture)
+        userData: str = self._getIgnitionUserData(
+            "worker", keyPath, preemptible, self._architecture
+        )
+        userDataBytes: bytes = b""
        if isinstance(userData, str):
            # Spot-market provisioning requires bytes for user data.
-            userData = userData.encode('utf-8')
-
-        kwargs = {'key_name': self._keyName,
-                  'security_group_ids': self._getSecurityGroupIDs(),
-                  'instance_type': type_info.name,
-                  'user_data': userData,
-                  'block_device_map': bdm,
-                  'instance_profile_arn': self._leaderProfileArn,
-                  'placement': zone,
-                  'subnet_id': subnet_id}
-
-        instancesLaunched = []
+            userDataBytes = userData.encode("utf-8")
+
+        spot_kwargs = {
+            "KeyName": self._keyName,
+            "LaunchSpecification": {
+                "SecurityGroupIds": self._getSecurityGroupIDs(),
+                "InstanceType": type_info.name,
+                "UserData": userDataBytes,
+                "BlockDeviceMappings": bdm,
+                "IamInstanceProfile": {"Arn": self._leaderProfileArn},
+                "Placement": {"AvailabilityZone": zone},
+                "SubnetId": subnet_id,
+            },
+        }
+        on_demand_kwargs = {
+            "KeyName": self._keyName,
+            "SecurityGroupIds": self._getSecurityGroupIDs(),
+            "InstanceType": type_info.name,
+            "UserData": userDataBytes,
+            "BlockDeviceMappings": bdm,
+            "IamInstanceProfile": {"Arn": self._leaderProfileArn},
+            "Placement": {"AvailabilityZone": zone},
+            "SubnetId": subnet_id,
+        }
+
+        instancesLaunched: list[InstanceTypeDef] = []
 
        for attempt in old_retry(predicate=awsRetryPredicate):
            with attempt:
@@ -844,45 +1031,85 @@ class AWSProvisioner(AbstractProvisioner):
844
1031
  # the biggest obstacle is AWS request throttling, so we retry on these errors at
845
1032
  # every request in this method
846
1033
  if not preemptible:
847
- logger.debug('Launching %s non-preemptible nodes', numNodes)
848
- instancesLaunched = create_ondemand_instances(ec2,
849
- image_id=self._discoverAMI(),
850
- spec=kwargs, num_instances=numNodes)
1034
+ logger.debug("Launching %s non-preemptible nodes", numNodes)
1035
+ instancesLaunched = create_ondemand_instances(
1036
+ boto3_ec2=boto3_ec2,
1037
+ image_id=self._discoverAMI(),
1038
+ spec=on_demand_kwargs,
1039
+ num_instances=numNodes,
1040
+ )
851
1041
  else:
852
- logger.debug('Launching %s preemptible nodes', numNodes)
1042
+ logger.debug("Launching %s preemptible nodes", numNodes)
853
1043
  # force generator to evaluate
854
- instancesLaunched = list(create_spot_instances(ec2=ec2,
855
- price=spotBid,
856
- image_id=self._discoverAMI(),
857
- tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName},
858
- spec=kwargs,
859
- num_instances=numNodes,
860
- tentative=True)
861
- )
1044
+ generatedInstancesLaunched: list[DescribeInstancesResultTypeDef] = (
1045
+ list(
1046
+ create_spot_instances(
1047
+ boto3_ec2=boto3_ec2,
1048
+ price=spotBid,
1049
+ image_id=self._discoverAMI(),
1050
+ tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName},
1051
+ spec=spot_kwargs,
1052
+ num_instances=numNodes,
1053
+ tentative=True,
1054
+ )
1055
+ )
1056
+ )
862
1057
  # flatten the list
863
- instancesLaunched = [item for sublist in instancesLaunched for item in sublist]
1058
+ flatten_reservations: list[ReservationTypeDef] = [
1059
+ reservation
1060
+ for subdict in generatedInstancesLaunched
1061
+ for reservation in subdict["Reservations"]
1062
+ for key, value in subdict.items()
1063
+ ]
1064
+ # get a flattened list of all requested instances; each response in generatedInstancesLaunched holds a list of reservations, and each reservation holds a list of instances
1065
+ instancesLaunched = [
1066
+ instance
1067
+ for instances in flatten_reservations
1068
+ for instance in instances["Instances"]
1069
+ ]
864
1070
 
865
1071
  for attempt in old_retry(predicate=awsRetryPredicate):
866
1072
  with attempt:
867
- wait_instances_running(ec2, instancesLaunched)
1073
+ list(
1074
+ wait_instances_running(boto3_ec2, instancesLaunched)
1075
+ ) # ensure all instances are running
868
1076
 
869
- self._tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
870
- AWSProvisioner._addTags(instancesLaunched, self._tags)
1077
+ increase_instance_hop_limit(boto3_ec2, instancesLaunched)
1078
+
1079
+ self._tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
1080
+ AWSProvisioner._addTags(boto3_ec2, instancesLaunched, self._tags)
871
1081
  if self._sseKey:
872
1082
  for i in instancesLaunched:
873
1083
  self._waitForIP(i)
874
- node = Node(publicIP=i.ip_address, privateIP=i.private_ip_address, name=i.id,
875
- launchTime=i.launch_time, nodeType=i.instance_type, preemptible=preemptible,
876
- tags=i.tags)
877
- node.waitForNode('toil_worker')
878
- node.coreRsync([self._sseKey, ':' + self._sseKey], applianceName='toil_worker')
879
- logger.debug('Launched %s new instance(s)', numNodes)
1084
+ node = Node(
1085
+ publicIP=i["PublicIpAddress"],
1086
+ privateIP=i["PrivateIpAddress"],
1087
+ name=i["InstanceId"],
1088
+ launchTime=i["LaunchTime"],
1089
+ nodeType=i["InstanceType"],
1090
+ preemptible=preemptible,
1091
+ tags=collapse_tags(i["Tags"]),
1092
+ )
1093
+ node.waitForNode("toil_worker")
1094
+ node.coreRsync(
1095
+ [self._sseKey, ":" + self._sseKey], applianceName="toil_worker"
1096
+ )
1097
+ logger.debug("Launched %s new instance(s)", numNodes)
880
1098
  return len(instancesLaunched)
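
The spot branch above unpacks instances from the nested structure that boto3 describe-instances style responses use (a list of responses, each holding Reservations, each holding Instances). A minimal stand-alone sketch of that unpacking, separate from the provisioner's own code:

from typing import Any

def flatten_instances(responses: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Collect every instance dict out of Reservations -> Instances.
    instances: list[dict[str, Any]] = []
    for response in responses:
        for reservation in response.get("Reservations", []):
            instances.extend(reservation.get("Instances", []))
    return instances
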
881
1099
 
882
- def addManagedNodes(self, nodeTypes: Set[str], minNodes, maxNodes, preemptible, spotBid=None) -> None:
883
-
884
- if self.clusterType != 'kubernetes':
885
- raise ManagedNodesNotSupportedException("Managed nodes only supported for Kubernetes clusters")
1100
+ def addManagedNodes(
1101
+ self,
1102
+ nodeTypes: set[str],
1103
+ minNodes: int,
1104
+ maxNodes: int,
1105
+ preemptible: bool,
1106
+ spotBid: float | None = None,
1107
+ ) -> None:
1108
+
1109
+ if self.clusterType != "kubernetes":
1110
+ raise ManagedNodesNotSupportedException(
1111
+ "Managed nodes only supported for Kubernetes clusters"
1112
+ )
886
1113
 
887
1114
  assert self._leaderPrivateIP
888
1115
 
@@ -894,25 +1121,51 @@ class AWSProvisioner(AbstractProvisioner):
894
1121
 
895
1122
  # Make one template per node type, so we can apply storage overrides correctly
896
1123
  # TODO: deduplicate these if the same instance type appears in multiple sets?
897
- launch_template_ids = {n: self._get_worker_launch_template(n, preemptible=preemptible) for n in nodeTypes}
1124
+ launch_template_ids = {
1125
+ n: self._get_worker_launch_template(n, preemptible=preemptible)
1126
+ for n in nodeTypes
1127
+ }
898
1128
  # Make the ASG across all of them
899
- self._createWorkerAutoScalingGroup(launch_template_ids, nodeTypes, minNodes, maxNodes,
900
- spot_bid=spotBid)
1129
+ self._createWorkerAutoScalingGroup(
1130
+ launch_template_ids, nodeTypes, minNodes, maxNodes, spot_bid=spotBid
1131
+ )
901
1132
 
902
- def getProvisionedWorkers(self, instance_type: Optional[str] = None, preemptible: Optional[bool] = None) -> List[Node]:
1133
+ def getProvisionedWorkers(
1134
+ self, instance_type: str | None = None, preemptible: bool | None = None
1135
+ ) -> list[Node]:
903
1136
  assert self._leaderPrivateIP
904
- entireCluster = self._get_nodes_in_cluster(instance_type=instance_type)
905
- logger.debug('All nodes in cluster: %s', entireCluster)
906
- workerInstances = [i for i in entireCluster if i.private_ip_address != self._leaderPrivateIP]
907
- logger.debug('All workers found in cluster: %s', workerInstances)
1137
+ entireCluster = self._get_nodes_in_cluster_boto3(instance_type=instance_type)
1138
+ logger.debug("All nodes in cluster: %s", entireCluster)
1139
+ workerInstances: list[InstanceTypeDef] = [
1140
+ i for i in entireCluster if i["PrivateIpAddress"] != self._leaderPrivateIP
1141
+ ]
1142
+ logger.debug("All workers found in cluster: %s", workerInstances)
908
1143
  if preemptible is not None:
909
- workerInstances = [i for i in workerInstances if preemptible == (i.spot_instance_request_id is not None)]
910
- logger.debug('%spreemptible workers found in cluster: %s', 'non-' if not preemptible else '', workerInstances)
911
- workerInstances = awsFilterImpairedNodes(workerInstances, self.aws.boto2(self._region, 'ec2'))
912
- return [Node(publicIP=i.ip_address, privateIP=i.private_ip_address,
913
- name=i.id, launchTime=i.launch_time, nodeType=i.instance_type,
914
- preemptible=i.spot_instance_request_id is not None, tags=i.tags)
915
- for i in workerInstances]
1144
+ workerInstances = [
1145
+ i
1146
+ for i in workerInstances
1147
+ if preemptible == (i["SpotInstanceRequestId"] is not None)
1148
+ ]
1149
+ logger.debug(
1150
+ "%spreemptible workers found in cluster: %s",
1151
+ "non-" if not preemptible else "",
1152
+ workerInstances,
1153
+ )
1154
+ workerInstances = awsFilterImpairedNodes(
1155
+ workerInstances, self.aws.client(self._region, "ec2")
1156
+ )
1157
+ return [
1158
+ Node(
1159
+ publicIP=i["PublicIpAddress"],
1160
+ privateIP=i["PrivateIpAddress"],
1161
+ name=i["InstanceId"],
1162
+ launchTime=i["LaunchTime"],
1163
+ nodeType=i["InstanceType"],
1164
+ preemptible=i["SpotInstanceRequestId"] is not None,
1165
+ tags=collapse_tags(i["Tags"]),
1166
+ )
1167
+ for i in workerInstances
1168
+ ]
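
getProvisionedWorkers now reads Boto 3 instance dicts instead of Boto 2 attributes. As an illustration only (real responses may omit "SpotInstanceRequestId" for on-demand instances, hence .get() here), the same worker selection can be sketched as:

from typing import Any

def pick_workers(instances: list[dict[str, Any]], leader_private_ip: str,
                 preemptible: bool | None = None) -> list[dict[str, Any]]:
    # Drop the leader, then optionally keep only (non-)preemptible workers.
    workers = [i for i in instances if i.get("PrivateIpAddress") != leader_private_ip]
    if preemptible is not None:
        workers = [i for i in workers
                   if preemptible == (i.get("SpotInstanceRequestId") is not None)]
    return workers
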
916
1169
 
917
1170
  @memoize
918
1171
  def _discoverAMI(self) -> str:
@@ -920,17 +1173,19 @@ class AWSProvisioner(AbstractProvisioner):
920
1173
  :return: The AMI ID (a string like 'ami-0a9a5d2b65cce04eb') for Flatcar.
921
1174
  :rtype: str
922
1175
  """
923
- return get_flatcar_ami(self.aws.client(self._region, 'ec2'), self._architecture)
1176
+ return get_flatcar_ami(self.aws.client(self._region, "ec2"), self._architecture)
924
1177
 
925
1178
  def _toNameSpace(self) -> str:
926
1179
  assert isinstance(self.clusterName, (str, bytes))
927
- if any(char.isupper() for char in self.clusterName) or '_' in self.clusterName:
928
- raise RuntimeError("The cluster name must be lowercase and cannot contain the '_' "
929
- "character.")
1180
+ if any(char.isupper() for char in self.clusterName) or "_" in self.clusterName:
1181
+ raise RuntimeError(
1182
+ "The cluster name must be lowercase and cannot contain the '_' "
1183
+ "character."
1184
+ )
930
1185
  namespace = self.clusterName
931
- if not namespace.startswith('/'):
932
- namespace = '/' + namespace + '/'
933
- return namespace.replace('-', '/')
1186
+ if not namespace.startswith("/"):
1187
+ namespace = "/" + namespace + "/"
1188
+ return namespace.replace("-", "/")
934
1189
 
935
1190
  def _namespace_name(self, name: str) -> str:
936
1191
  """
@@ -941,7 +1196,7 @@ class AWSProvisioner(AbstractProvisioner):
941
1196
  # This logic is a bit weird, but it's what Boto2Context used to use.
942
1197
  # Drop the leading / from the absolute-path-style "namespace" name and
943
1198
  # then encode underscores and slashes.
944
- return (self._toNameSpace() + name)[1:].replace('_', '__').replace('/', '_')
1199
+ return (self._toNameSpace() + name)[1:].replace("_", "__").replace("/", "_")
945
1200
 
946
1201
  def _is_our_namespaced_name(self, namespaced_name: str) -> bool:
947
1202
  """
@@ -949,93 +1204,153 @@ class AWSProvisioner(AbstractProvisioner):
949
1204
  and was generated by _namespace_name().
950
1205
  """
951
1206
 
952
- denamespaced = '/' + '_'.join(s.replace('_', '/') for s in namespaced_name.split('__'))
1207
+ denamespaced = "/" + "_".join(
1208
+ s.replace("_", "/") for s in namespaced_name.split("__")
1209
+ )
953
1210
  return denamespaced.startswith(self._toNameSpace())
954
1211
 
1212
+ def _getLeaderInstanceBoto3(self) -> InstanceTypeDef:
1213
+ """
1214
+ Get the Boto 3 instance for the cluster's leader.
1215
+ """
1216
+ # Tags are stored differently in Boto 3
1217
+ instances: list[InstanceTypeDef] = self._get_nodes_in_cluster_boto3(
1218
+ include_stopped_nodes=True
1219
+ )
1220
+ instances.sort(key=lambda x: x["LaunchTime"])
1221
+ try:
1222
+ leader = instances[0] # assume leader was launched first
1223
+ except IndexError:
1224
+ raise NoSuchClusterException(self.clusterName)
1225
+ if leader.get("Tags") is not None:
1226
+ tag_value = next(
1227
+ item["Value"]
1228
+ for item in leader["Tags"]
1229
+ if item["Key"] == _TAG_KEY_TOIL_NODE_TYPE
1230
+ )
1231
+ else:
1232
+ tag_value = None
1233
+ if (tag_value or "leader") != "leader":
1234
+ raise InvalidClusterStateException(
1235
+ "Invalid cluster state! The first launched instance appears not to be the leader "
1236
+ 'as it is missing the "leader" tag. The safest recovery is to destroy the cluster '
1237
+ "and restart the job. Incorrect Leader ID: %s" % leader["InstanceId"]
1238
+ )
1239
+ return leader
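
The leader lookup above has to read a tag value out of the Boto 3 Tags list, which is a list of Key/Value dicts rather than a mapping. A hedged sketch of that lookup (the kind of handling a collapse_tags-style helper also needs):

from typing import Any

def get_tag(instance: dict[str, Any], key: str) -> str | None:
    # Boto 3 tags look like [{"Key": "...", "Value": "..."}, ...] and may be absent.
    for tag in instance.get("Tags") or []:
        if tag.get("Key") == key:
            return tag.get("Value")
    return None
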
955
1240
 
956
- def _getLeaderInstance(self) -> Boto2Instance:
1241
+ def _getLeaderInstance(self) -> InstanceTypeDef:
957
1242
  """
958
1243
  Get the Boto 3 instance for the cluster's leader.
959
1244
  """
960
- instances = self._get_nodes_in_cluster(include_stopped_nodes=True)
961
- instances.sort(key=lambda x: x.launch_time)
1245
+ instances = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True)
1246
+ instances.sort(key=lambda x: x["LaunchTime"])
962
1247
  try:
963
- leader = instances[0] # assume leader was launched first
1248
+ leader: InstanceTypeDef = instances[0] # assume leader was launched first
964
1249
  except IndexError:
965
1250
  raise NoSuchClusterException(self.clusterName)
966
- if (leader.tags.get(_TAG_KEY_TOIL_NODE_TYPE) or 'leader') != 'leader':
1251
+ tagged_node_type: str = "leader"
1252
+ for tag in leader["Tags"]:
1253
+ # If a tag specifying the node type exists, remember its value
1254
+ if tag.get("Key") is not None and tag["Key"] == _TAG_KEY_TOIL_NODE_TYPE:
1255
+ tagged_node_type = tag["Value"]
1256
+ if tagged_node_type != "leader":
967
1257
  raise InvalidClusterStateException(
968
- 'Invalid cluster state! The first launched instance appears not to be the leader '
1258
+ "Invalid cluster state! The first launched instance appears not to be the leader "
969
1259
  'as it is missing the "leader" tag. The safest recovery is to destroy the cluster '
970
- 'and restart the job. Incorrect Leader ID: %s' % leader.id
1260
+ "and restart the job. Incorrect Leader ID: %s" % leader["InstanceId"]
971
1261
  )
972
1262
  return leader
973
1263
 
974
- def getLeader(self, wait=False) -> Node:
1264
+ def getLeader(self, wait: bool = False) -> Node:
975
1265
  """
976
1266
  Get the leader for the cluster as a Toil Node object.
977
1267
  """
978
- leader = self._getLeaderInstance()
979
-
980
- leaderNode = Node(publicIP=leader.ip_address, privateIP=leader.private_ip_address,
981
- name=leader.id, launchTime=leader.launch_time, nodeType=None,
982
- preemptible=False, tags=leader.tags)
1268
+ leader: InstanceTypeDef = self._getLeaderInstanceBoto3()
1269
+
1270
+ leaderNode = Node(
1271
+ publicIP=leader["PublicIpAddress"],
1272
+ privateIP=leader["PrivateIpAddress"],
1273
+ name=leader["InstanceId"],
1274
+ launchTime=leader["LaunchTime"],
1275
+ nodeType=None,
1276
+ preemptible=False,
1277
+ tags=collapse_tags(leader["Tags"]),
1278
+ )
983
1279
  if wait:
984
1280
  logger.debug("Waiting for toil_leader to enter 'running' state...")
985
- wait_instances_running(self.aws.boto2(self._region, 'ec2'), [leader])
986
- logger.debug('... toil_leader is running')
1281
+ wait_instances_running(self.aws.client(self._region, "ec2"), [leader])
1282
+ logger.debug("... toil_leader is running")
987
1283
  self._waitForIP(leader)
988
- leaderNode.waitForNode('toil_leader')
1284
+ leaderNode.waitForNode("toil_leader")
989
1285
 
990
1286
  return leaderNode
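
getLeader now hands Boto 3 instance records to wait_instances_running. For reference, the built-in Boto 3 waiter gives a similar "wait until running" effect; this is only a sketch, and Toil's own helper may behave differently:

import boto3

def wait_until_running(region: str, instance_ids: list[str]) -> None:
    # Poll EC2 until every listed instance reports the 'running' state.
    ec2 = boto3.client("ec2", region_name=region)
    ec2.get_waiter("instance_running").wait(InstanceIds=instance_ids)
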
991
1287
 
992
1288
  @classmethod
993
1289
  @awsRetry
994
- def _addTag(cls, instance: Boto2Instance, key: str, value: str):
995
- instance.add_tag(key, value)
1290
+ def _addTag(
1291
+ cls, boto3_ec2: EC2Client, instance: InstanceTypeDef, key: str, value: str
1292
+ ) -> None:
1293
+ if instance.get("Tags") is None:
1294
+ instance["Tags"] = []
1295
+ new_tag: TagTypeDef = {"Key": key, "Value": value}
1296
+ boto3_ec2.create_tags(Resources=[instance["InstanceId"]], Tags=[new_tag])
996
1297
 
997
1298
  @classmethod
998
- def _addTags(cls, instances: List[Boto2Instance], tags: Dict[str, str]):
1299
+ def _addTags(
1300
+ cls,
1301
+ boto3_ec2: EC2Client,
1302
+ instances: list[InstanceTypeDef],
1303
+ tags: dict[str, str],
1304
+ ) -> None:
999
1305
  for instance in instances:
1000
1306
  for key, value in tags.items():
1001
- cls._addTag(instance, key, value)
1307
+ cls._addTag(boto3_ec2, instance, key, value)
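
_addTag and _addTags issue one create_tags call per instance and key. Boto 3's create_tags also accepts many resources and tags in a single request, so an equivalent batched form (a sketch, not the provisioner's code) is:

import boto3

def tag_instances(region: str, instance_ids: list[str], tags: dict[str, str]) -> None:
    # One create_tags call can tag every instance with every key/value pair.
    ec2 = boto3.client("ec2", region_name=region)
    ec2.create_tags(
        Resources=instance_ids,
        Tags=[{"Key": k, "Value": v} for k, v in tags.items()],
    )
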
1002
1308
 
1003
1309
  @classmethod
1004
- def _waitForIP(cls, instance: Boto2Instance):
1310
+ def _waitForIP(cls, instance: InstanceTypeDef) -> None:
1005
1311
  """
1006
1312
  Wait until the instance has a public IP address assigned to it.
1007
1313
 
1008
1314
  :type instance: InstanceTypeDef
1009
1315
  """
1010
- logger.debug('Waiting for ip...')
1316
+ logger.debug("Waiting for ip...")
1011
1317
  while True:
1012
1318
  time.sleep(a_short_time)
1013
- instance.update()
1014
- if instance.ip_address or instance.public_dns_name or instance.private_ip_address:
1015
- logger.debug('...got ip')
1319
+ if (
1320
+ instance.get("PublicIpAddress")
1321
+ or instance.get("PublicDnsName")
1322
+ or instance.get("PrivateIpAddress")
1323
+ ):
1324
+ logger.debug("...got ip")
1016
1325
  break
1017
1326
 
1018
- def _terminateInstances(self, instances: List[Boto2Instance]):
1019
- instanceIDs = [x.id for x in instances]
1327
+ def _terminateInstances(self, instances: list[InstanceTypeDef]) -> None:
1328
+ instanceIDs = [x["InstanceId"] for x in instances]
1020
1329
  self._terminateIDs(instanceIDs)
1021
- logger.info('... Waiting for instance(s) to shut down...')
1330
+ logger.info("... Waiting for instance(s) to shut down...")
1022
1331
  for instance in instances:
1023
- wait_transition(instance, {'pending', 'running', 'shutting-down', 'stopping', 'stopped'}, 'terminated')
1024
- logger.info('Instance(s) terminated.')
1332
+ wait_transition(
1333
+ self.aws.client(region=self._region, service_name="ec2"),
1334
+ instance,
1335
+ {"pending", "running", "shutting-down", "stopping", "stopped"},
1336
+ "terminated",
1337
+ )
1338
+ logger.info("Instance(s) terminated.")
1025
1339
 
1026
1340
  @awsRetry
1027
- def _terminateIDs(self, instanceIDs: List[str]):
1028
- logger.info('Terminating instance(s): %s', instanceIDs)
1029
- self.aws.boto2(self._region, 'ec2').terminate_instances(instance_ids=instanceIDs)
1030
- logger.info('Instance(s) terminated.')
1341
+ def _terminateIDs(self, instanceIDs: list[str]) -> None:
1342
+ logger.info("Terminating instance(s): %s", instanceIDs)
1343
+ boto3_ec2 = self.aws.client(region=self._region, service_name="ec2")
1344
+ boto3_ec2.terminate_instances(InstanceIds=instanceIDs)
1345
+ logger.info("Instance(s) terminated.")
1031
1346
 
1032
1347
  @awsRetry
1033
- def _deleteRoles(self, names: List[str]):
1348
+ def _deleteRoles(self, names: list[str]) -> None:
1034
1349
  """
1035
1350
  Delete all the given named IAM roles.
1036
1351
  Detaches but does not delete associated instance profiles.
1037
1352
  """
1038
-
1353
+ boto3_iam = self.aws.client(region=self._region, service_name="iam")
1039
1354
  for role_name in names:
1040
1355
  for profile_name in self._getRoleInstanceProfileNames(role_name):
1041
1356
  # We can't delete either the role or the profile while they
@@ -1043,177 +1358,244 @@ class AWSProvisioner(AbstractProvisioner):
1043
1358
 
1044
1359
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1045
1360
  with attempt:
1046
- self.aws.client(self._region, 'iam').remove_role_from_instance_profile(InstanceProfileName=profile_name,
1047
- RoleName=role_name)
1361
+ boto3_iam.remove_role_from_instance_profile(
1362
+ InstanceProfileName=profile_name, RoleName=role_name
1363
+ )
1048
1364
  # We also need to drop all inline policies
1049
1365
  for policy_name in self._getRoleInlinePolicyNames(role_name):
1050
1366
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1051
1367
  with attempt:
1052
- self.aws.client(self._region, 'iam').delete_role_policy(PolicyName=policy_name,
1053
- RoleName=role_name)
1368
+ boto3_iam.delete_role_policy(
1369
+ PolicyName=policy_name, RoleName=role_name
1370
+ )
1054
1371
 
1055
1372
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1056
1373
  with attempt:
1057
- self.aws.client(self._region, 'iam').delete_role(RoleName=role_name)
1058
- logger.debug('... Successfully deleted IAM role %s', role_name)
1059
-
1374
+ boto3_iam.delete_role(RoleName=role_name)
1375
+ logger.debug("... Successfully deleted IAM role %s", role_name)
1060
1376
 
1061
1377
  @awsRetry
1062
- def _deleteInstanceProfiles(self, names: List[str]):
1378
+ def _deleteInstanceProfiles(self, names: list[str]) -> None:
1063
1379
  """
1064
1380
  Delete all the given named IAM instance profiles.
1065
1381
  All roles must already be detached.
1066
1382
  """
1067
-
1383
+ boto3_iam = self.aws.client(region=self._region, service_name="iam")
1068
1384
  for profile_name in names:
1069
1385
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1070
1386
  with attempt:
1071
- self.aws.client(self._region, 'iam').delete_instance_profile(InstanceProfileName=profile_name)
1072
- logger.debug('... Succesfully deleted instance profile %s', profile_name)
1387
+ boto3_iam.delete_instance_profile(InstanceProfileName=profile_name)
1388
+ logger.debug(
1389
+ "... Succesfully deleted instance profile %s", profile_name
1390
+ )
1073
1391
 
1074
1392
  @classmethod
1075
- def _getBoto2BlockDeviceMapping(cls, type_info: InstanceType, rootVolSize: int = 50) -> Boto2BlockDeviceMapping:
1393
+ def _getBoto3BlockDeviceMapping(
1394
+ cls, type_info: InstanceType, rootVolSize: int = 50
1395
+ ) -> list[BlockDeviceMappingTypeDef]:
1076
1396
  # determine number of ephemeral drives via cgcloud-lib (actually this is moved into toil's lib)
1077
- bdtKeys = [''] + [f'/dev/xvd{c}' for c in string.ascii_lowercase[1:]]
1078
- bdm = Boto2BlockDeviceMapping()
1397
+ bdtKeys = [""] + [f"/dev/xvd{c}" for c in string.ascii_lowercase[1:]]
1398
+ bdm_list: list[BlockDeviceMappingTypeDef] = []
1079
1399
  # Change root volume size to allow for bigger Docker instances
1080
- root_vol = Boto2BlockDeviceType(delete_on_termination=True)
1081
- root_vol.size = rootVolSize
1082
- bdm["/dev/xvda"] = root_vol
1400
+ root_vol: EbsBlockDeviceTypeDef = {
1401
+ "DeleteOnTermination": True,
1402
+ "VolumeSize": rootVolSize,
1403
+ }
1404
+ bdm: BlockDeviceMappingTypeDef = {"DeviceName": "/dev/xvda", "Ebs": root_vol}
1405
+ bdm_list.append(bdm)
1083
1406
  # The first disk is already attached for us so start with 2nd.
1084
1407
  # Disk count is weirdly a float in our instance database, so make it an int here.
1085
1408
  for disk in range(1, int(type_info.disks) + 1):
1086
- bdm[bdtKeys[disk]] = Boto2BlockDeviceType(
1087
- ephemeral_name=f'ephemeral{disk - 1}') # ephemeral counts start at 0
1409
+ bdm = {}
1410
+ bdm["DeviceName"] = bdtKeys[disk]
1411
+ bdm["VirtualName"] = f"ephemeral{disk - 1}" # ephemeral counts start at 0
1412
+ bdm["Ebs"] = root_vol # default
1413
+ # bdm["Ebs"] = root_vol.update({"VirtualName": f"ephemeral{disk - 1}"})
1414
+ bdm_list.append(bdm)
1088
1415
 
1089
- logger.debug('Device mapping: %s', bdm)
1090
- return bdm
1416
+ logger.debug("Device mapping: %s", bdm_list)
1417
+ return bdm_list
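
The two mapping helpers above both build one EBS root entry plus VirtualName entries for the instance-store drives. The overall shape they produce can be sketched as follows; sizes and device names here are placeholders:

import string
from typing import Any

def block_device_mappings(root_gb: int, ephemeral_disks: int) -> list[dict[str, Any]]:
    mappings: list[dict[str, Any]] = [{
        "DeviceName": "/dev/xvda",
        "Ebs": {"DeleteOnTermination": True, "VolumeSize": root_gb, "VolumeType": "gp2"},
    }]
    # Attach each instance-store ("ephemeral") drive to the next virtual device name.
    devices = [f"/dev/xvd{c}" for c in string.ascii_lowercase[1:]]
    for n in range(ephemeral_disks):
        mappings.append({"DeviceName": devices[n], "VirtualName": f"ephemeral{n}"})
    return mappings
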
1091
1418
 
1092
1419
  @classmethod
1093
- def _getBoto3BlockDeviceMappings(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[dict]:
1420
+ def _getBoto3BlockDeviceMappings(
1421
+ cls, type_info: InstanceType, rootVolSize: int = 50
1422
+ ) -> list[BlockDeviceMappingTypeDef]:
1094
1423
  """
1095
1424
  Get block device mappings for the root volume for a worker.
1096
1425
  """
1097
1426
 
1098
1427
  # Start with the root
1099
- bdms = [{
1100
- 'DeviceName': '/dev/xvda',
1101
- 'Ebs': {
1102
- 'DeleteOnTermination': True,
1103
- 'VolumeSize': rootVolSize,
1104
- 'VolumeType': 'gp2'
1428
+ bdms: list[BlockDeviceMappingTypeDef] = [
1429
+ {
1430
+ "DeviceName": "/dev/xvda",
1431
+ "Ebs": {
1432
+ "DeleteOnTermination": True,
1433
+ "VolumeSize": rootVolSize,
1434
+ "VolumeType": "gp2",
1435
+ },
1105
1436
  }
1106
- }]
1437
+ ]
1107
1438
 
1108
1439
  # Get all the virtual drives we might have
1109
- bdtKeys = [f'/dev/xvd{c}' for c in string.ascii_lowercase]
1440
+ bdtKeys = [f"/dev/xvd{c}" for c in string.ascii_lowercase]
1110
1441
 
1111
1442
  # The first disk is already attached for us so start with 2nd.
1112
1443
  # Disk count is weirdly a float in our instance database, so make it an int here.
1113
1444
  for disk in range(1, int(type_info.disks) + 1):
1114
1445
  # Make a block device mapping to attach the ephemeral disk to a
1115
1446
  # virtual block device in the VM
1116
- bdms.append({
1117
- 'DeviceName': bdtKeys[disk],
1118
- 'VirtualName': f'ephemeral{disk - 1}' # ephemeral counts start at 0
1119
- })
1120
- logger.debug('Device mapping: %s', bdms)
1447
+ bdms.append(
1448
+ {
1449
+ "DeviceName": bdtKeys[disk],
1450
+ "VirtualName": f"ephemeral{disk - 1}", # ephemeral counts start at 0
1451
+ }
1452
+ )
1453
+ logger.debug("Device mapping: %s", bdms)
1121
1454
  return bdms
1122
1455
 
1123
1456
  @awsRetry
1124
- def _get_nodes_in_cluster(self, instance_type: Optional[str] = None, include_stopped_nodes=False) -> List[Boto2Instance]:
1125
- """
1126
- Get Boto2 instance objects for all nodes in the cluster.
1127
- """
1128
-
1129
- all_instances = self.aws.boto2(self._region, 'ec2').get_only_instances(filters={'instance.group-name': self.clusterName})
1130
-
1131
- def instanceFilter(i):
1457
+ def _get_nodes_in_cluster_boto3(
1458
+ self, instance_type: str | None = None, include_stopped_nodes: bool = False
1459
+ ) -> list[InstanceTypeDef]:
1460
+ """
1461
+ Get Boto3 instance objects for all nodes in the cluster.
1462
+ """
1463
+ boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2")
1464
+ instance_filter: FilterTypeDef = {
1465
+ "Name": "instance.group-name",
1466
+ "Values": [self.clusterName],
1467
+ }
1468
+ describe_response: DescribeInstancesResultTypeDef = (
1469
+ boto3_ec2.describe_instances(Filters=[instance_filter])
1470
+ )
1471
+ all_instances: list[InstanceTypeDef] = []
1472
+ for reservation in describe_response["Reservations"]:
1473
+ instances = reservation["Instances"]
1474
+ all_instances.extend(instances)
1475
+
1476
+ # all_instances = self.aws.boto2(self._region, 'ec2').get_only_instances(filters={'instance.group-name': self.clusterName})
1477
+
1478
+ def instanceFilter(i: InstanceTypeDef) -> bool:
1132
1479
  # filter by type only if nodeType is true
1133
- rightType = not instance_type or i.instance_type == instance_type
1134
- rightState = i.state == 'running' or i.state == 'pending'
1480
+ rightType = not instance_type or i["InstanceType"] == instance_type
1481
+ rightState = (
1482
+ i["State"]["Name"] == "running" or i["State"]["Name"] == "pending"
1483
+ )
1135
1484
  if include_stopped_nodes:
1136
- rightState = rightState or i.state == 'stopping' or i.state == 'stopped'
1485
+ rightState = (
1486
+ rightState
1487
+ or i["State"]["Name"] == "stopping"
1488
+ or i["State"]["Name"] == "stopped"
1489
+ )
1137
1490
  return rightType and rightState
1138
1491
 
1139
1492
  return [i for i in all_instances if instanceFilter(i)]
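
_get_nodes_in_cluster_boto3 issues one describe_instances call and filters states in Python. For clusters that might exceed a single response page, the same query can be written with Boto 3's paginator and server-side state filters; this is a sketch under those assumptions, not Toil's implementation:

from typing import Any

import boto3

def cluster_instances(region: str, cluster_name: str,
                      states: tuple[str, ...] = ("running", "pending")) -> list[dict[str, Any]]:
    ec2 = boto3.client("ec2", region_name=region)
    found: list[dict[str, Any]] = []
    pages = ec2.get_paginator("describe_instances").paginate(Filters=[
        {"Name": "instance.group-name", "Values": [cluster_name]},
        {"Name": "instance-state-name", "Values": list(states)},
    ])
    for page in pages:
        for reservation in page["Reservations"]:
            found.extend(reservation["Instances"])
    return found
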
1140
1493
 
1141
- def _filter_nodes_in_cluster(self, instance_type: Optional[str] = None, preemptible: bool = False) -> List[Boto2Instance]:
1142
- """
1143
- Get Boto2 instance objects for the nodes in the cluster filtered by preemptability.
1144
- """
1145
-
1146
- instances = self._get_nodes_in_cluster(instance_type, include_stopped_nodes=False)
1147
-
1148
- if preemptible:
1149
- return [i for i in instances if i.spot_instance_request_id is not None]
1150
-
1151
- return [i for i in instances if i.spot_instance_request_id is None]
1152
-
1153
- def _getSpotRequestIDs(self) -> List[str]:
1494
+ def _getSpotRequestIDs(self) -> list[str]:
1154
1495
  """
1155
1496
  Get the IDs of all spot requests associated with the cluster.
1156
1497
  """
1157
1498
 
1158
1499
  # Grab the connection we need to use for this operation.
1159
- ec2 = self.aws.boto2(self._region, 'ec2')
1160
-
1161
- requests = ec2.get_all_spot_instance_requests()
1162
- tags = ec2.get_all_tags({'tag:': {_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}})
1163
- idsToCancel = [tag.id for tag in tags]
1164
- return [request for request in requests if request.id in idsToCancel]
1165
-
1166
- def _createSecurityGroups(self) -> List[str]:
1500
+ ec2: EC2Client = self.aws.client(self._region, "ec2")
1501
+
1502
+ requests: list[SpotInstanceRequestTypeDef] = (
1503
+ ec2.describe_spot_instance_requests()["SpotInstanceRequests"]
1504
+ )
1505
+ tag_filter: FilterTypeDef = {
1506
+ "Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME,
1507
+ "Values": [self.clusterName],
1508
+ }
1509
+ tags: list[TagDescriptionTypeDef] = ec2.describe_tags(Filters=[tag_filter])[
1510
+ "Tags"
1511
+ ]
1512
+ idsToCancel = [tag["ResourceId"] for tag in tags]
1513
+ return [
1514
+ request["SpotInstanceRequestId"]
1515
+ for request in requests
1516
+ if request["InstanceId"] in idsToCancel
1517
+ ]
1518
+
1519
+ def _createSecurityGroups(self) -> list[str]:
1167
1520
  """
1168
1521
  Create security groups for the cluster. Returns a list of their IDs.
1169
1522
  """
1170
1523
 
1524
+ def group_not_found(e: ClientError) -> bool:
1525
+ retry = get_error_status(
1526
+ e
1527
+ ) == 400 and "does not exist in default VPC" in get_error_body(e)
1528
+ return retry
1529
+
1171
1530
  # Grab the connection we need to use for this operation.
1172
1531
  # The VPC connection can do anything the EC2 one can do, but also look at subnets.
1173
- vpc = self.aws.boto2(self._region, 'vpc')
1532
+ boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2")
1174
1533
 
1175
- def groupNotFound(e):
1176
- retry = (e.status == 400 and 'does not exist in default VPC' in e.body)
1177
- return retry
1178
- # Security groups need to belong to the same VPC as the leader. If we
1179
- # put the leader in a particular non-default subnet, it may be in a
1180
- # particular non-default VPC, which we need to know about.
1181
- vpcId = None
1534
+ vpc_id = None
1182
1535
  if self._leader_subnet:
1183
- subnets = vpc.get_all_subnets(subnet_ids=[self._leader_subnet])
1536
+ subnets = boto3_ec2.describe_subnets(SubnetIds=[self._leader_subnet])[
1537
+ "Subnets"
1538
+ ]
1184
1539
  if len(subnets) > 0:
1185
- vpcId = subnets[0].vpc_id
1186
- # security group create/get. ssh + all ports open within the group
1540
+ vpc_id = subnets[0]["VpcId"]
1187
1541
  try:
1188
- web = vpc.create_security_group(self.clusterName,
1189
- 'Toil appliance security group', vpc_id=vpcId)
1190
- except EC2ResponseError as e:
1191
- if e.status == 400 and 'already exists' in e.body:
1192
- pass # group exists- nothing to do
1542
+ # Security groups need to belong to the same VPC as the leader. If we
1543
+ # put the leader in a particular non-default subnet, it may be in a
1544
+ # particular non-default VPC, which we need to know about.
1545
+ other = {
1546
+ "GroupName": self.clusterName,
1547
+ "Description": "Toil appliance security group",
1548
+ }
1549
+ if vpc_id is not None:
1550
+ other["VpcId"] = vpc_id
1551
+ # mypy stubs don't explicitly state kwargs even though documentation allows it, and mypy gets confused
1552
+ web_response: CreateSecurityGroupResultTypeDef = boto3_ec2.create_security_group(**other) # type: ignore[arg-type]
1553
+ except ClientError as e:
1554
+ if get_error_status(e) == 400 and "already exists" in get_error_body(e):
1555
+ pass
1193
1556
  else:
1194
1557
  raise
1195
1558
  else:
1196
- for attempt in old_retry(predicate=groupNotFound, timeout=300):
1197
- with attempt:
1198
- # open port 22 for ssh-ing
1199
- web.authorize(ip_protocol='tcp', from_port=22, to_port=22, cidr_ip='0.0.0.0/0')
1200
- # TODO: boto2 doesn't support IPv6 here but we need to.
1201
- for attempt in old_retry(predicate=groupNotFound, timeout=300):
1559
+ for attempt in old_retry(predicate=group_not_found, timeout=300):
1202
1560
  with attempt:
1203
- # the following authorizes all TCP access within the web security group
1204
- web.authorize(ip_protocol='tcp', from_port=0, to_port=65535, src_group=web)
1205
- for attempt in old_retry(predicate=groupNotFound, timeout=300):
1206
- with attempt:
1207
- # We also want to open up UDP, both for user code and for the RealtimeLogger
1208
- web.authorize(ip_protocol='udp', from_port=0, to_port=65535, src_group=web)
1561
+ ip_permissions: list[IpPermissionTypeDef] = [
1562
+ {
1563
+ "IpProtocol": "tcp",
1564
+ "FromPort": 22,
1565
+ "ToPort": 22,
1566
+ "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
1567
+ "Ipv6Ranges": [{"CidrIpv6": "::/0"}],
1568
+ }
1569
+ ]
1570
+ for protocol in ("tcp", "udp"):
1571
+ ip_permissions.append(
1572
+ {
1573
+ "IpProtocol": protocol,
1574
+ "FromPort": 0,
1575
+ "ToPort": 65535,
1576
+ "UserIdGroupPairs": [
1577
+ {
1578
+ "GroupId": web_response["GroupId"],
1579
+ "GroupName": self.clusterName,
1580
+ }
1581
+ ],
1582
+ }
1583
+ )
1584
+ boto3_ec2.authorize_security_group_ingress(
1585
+ IpPermissions=ip_permissions,
1586
+ GroupName=self.clusterName,
1587
+ GroupId=web_response["GroupId"],
1588
+ )
1209
1589
  out = []
1210
- for sg in vpc.get_all_security_groups():
1211
- if sg.name == self.clusterName and (vpcId is None or sg.vpc_id == vpcId):
1212
- out.append(sg)
1213
- return [sg.id for sg in out]
1590
+ for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]:
1591
+ if sg["GroupName"] == self.clusterName and (
1592
+ vpc_id is None or sg["VpcId"] == vpc_id
1593
+ ):
1594
+ out.append(sg["GroupId"])
1595
+ return out
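
_createSecurityGroups closes by scanning every security group client-side to find the cluster's group. Purely as an illustration of the describe_security_groups Filters parameter, the same lookup can be pushed server-side:

import boto3

def find_cluster_security_group_ids(region: str, cluster_name: str,
                                    vpc_id: str | None = None) -> list[str]:
    ec2 = boto3.client("ec2", region_name=region)
    filters = [{"Name": "group-name", "Values": [cluster_name]}]
    if vpc_id is not None:
        filters.append({"Name": "vpc-id", "Values": [vpc_id]})
    groups = ec2.describe_security_groups(Filters=filters)["SecurityGroups"]
    return [sg["GroupId"] for sg in groups]
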
1214
1596
 
1215
1597
  @awsRetry
1216
- def _getSecurityGroupIDs(self) -> List[str]:
1598
+ def _getSecurityGroupIDs(self) -> list[str]:
1217
1599
  """
1218
1600
  Get all the security group IDs to apply to leaders and workers.
1219
1601
  """
@@ -1222,13 +1604,20 @@ class AWSProvisioner(AbstractProvisioner):
1222
1604
 
1223
1605
  # Depending on if we enumerated them on the leader or locally, we might
1224
1606
  # know the required security groups by name, ID, or both.
1225
- sgs = [sg for sg in self.aws.boto2(self._region, 'ec2').get_all_security_groups()
1226
- if (sg.name in self._leaderSecurityGroupNames or
1227
- sg.id in self._leaderSecurityGroupIDs)]
1228
- return [sg.id for sg in sgs]
1607
+ boto3_ec2 = self.aws.client(region=self._region, service_name="ec2")
1608
+ return [
1609
+ sg["GroupId"]
1610
+ for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]
1611
+ if (
1612
+ sg["GroupName"] in self._leaderSecurityGroupNames
1613
+ or sg["GroupId"] in self._leaderSecurityGroupIDs
1614
+ )
1615
+ ]
1229
1616
 
1230
1617
  @awsRetry
1231
- def _get_launch_template_ids(self, filters: Optional[List[Dict[str, List[str]]]] = None) -> List[str]:
1618
+ def _get_launch_template_ids(
1619
+ self, filters: list[FilterTypeDef] | None = None
1620
+ ) -> list[str]:
1232
1621
  """
1233
1622
  Find all launch templates associated with the cluster.
1234
1623
 
@@ -1236,10 +1625,12 @@ class AWSProvisioner(AbstractProvisioner):
1236
1625
  """
1237
1626
 
1238
1627
  # Grab the connection we need to use for this operation.
1239
- ec2 = self.aws.client(self._region, 'ec2')
1628
+ ec2: EC2Client = self.aws.client(self._region, "ec2")
1240
1629
 
1241
1630
  # How do we match the right templates?
1242
- combined_filters = [{'Name': 'tag:' + _TAG_KEY_TOIL_CLUSTER_NAME, 'Values': [self.clusterName]}]
1631
+ combined_filters: list[FilterTypeDef] = [
1632
+ {"Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME, "Values": [self.clusterName]}
1633
+ ]
1243
1634
 
1244
1635
  if filters:
1245
1636
  # Add any user-specified filters
@@ -1247,16 +1638,21 @@ class AWSProvisioner(AbstractProvisioner):
1247
1638
 
1248
1639
  allTemplateIDs = []
1249
1640
  # Get the first page with no NextToken
1250
- response = ec2.describe_launch_templates(Filters=combined_filters,
1251
- MaxResults=200)
1641
+ response = ec2.describe_launch_templates(
1642
+ Filters=combined_filters, MaxResults=200
1643
+ )
1252
1644
  while True:
1253
1645
  # Process the current page
1254
- allTemplateIDs += [item['LaunchTemplateId'] for item in response.get('LaunchTemplates', [])]
1255
- if 'NextToken' in response:
1646
+ allTemplateIDs += [
1647
+ item["LaunchTemplateId"] for item in response.get("LaunchTemplates", [])
1648
+ ]
1649
+ if "NextToken" in response:
1256
1650
  # There are more pages. Get the next one, supplying the token.
1257
- response = ec2.describe_launch_templates(Filters=filters,
1258
- NextToken=response['NextToken'],
1259
- MaxResults=200)
1651
+ response = ec2.describe_launch_templates(
1652
+ Filters=filters or [],
1653
+ NextToken=response["NextToken"],
1654
+ MaxResults=200,
1655
+ )
1260
1656
  else:
1261
1657
  # No more pages
1262
1658
  break
@@ -1264,7 +1660,9 @@ class AWSProvisioner(AbstractProvisioner):
1264
1660
  return allTemplateIDs
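
_get_launch_template_ids walks describe_launch_templates pages by hand with NextToken. The same loop, expressed with Boto 3's built-in paginator, is sketched below (assuming a describe_launch_templates paginator is available in the installed boto3, which is the normal case):

import boto3

def launch_template_ids(region: str, tag_key: str, cluster_name: str) -> list[str]:
    ec2 = boto3.client("ec2", region_name=region)
    pages = ec2.get_paginator("describe_launch_templates").paginate(
        Filters=[{"Name": "tag:" + tag_key, "Values": [cluster_name]}]
    )
    return [lt["LaunchTemplateId"]
            for page in pages
            for lt in page.get("LaunchTemplates", [])]
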
1265
1661
 
1266
1662
  @awsRetry
1267
- def _get_worker_launch_template(self, instance_type: str, preemptible: bool = False, backoff: float = 1.0) -> str:
1663
+ def _get_worker_launch_template(
1664
+ self, instance_type: str, preemptible: bool = False, backoff: float = 1.0
1665
+ ) -> str:
1268
1666
  """
1269
1667
  Get a launch template for instances with the given parameters. Only one
1270
1668
  such launch template will be created, no matter how many times the
@@ -1283,36 +1681,55 @@ class AWSProvisioner(AbstractProvisioner):
1283
1681
  :return: The ID of the template.
1284
1682
  """
1285
1683
 
1286
- lt_name = self._name_worker_launch_template(instance_type, preemptible=preemptible)
1684
+ lt_name = self._name_worker_launch_template(
1685
+ instance_type, preemptible=preemptible
1686
+ )
1287
1687
 
1288
1688
  # How do we match the right templates?
1289
- filters = [{'Name': 'launch-template-name', 'Values': [lt_name]}]
1689
+ filters: list[FilterTypeDef] = [
1690
+ {"Name": "launch-template-name", "Values": [lt_name]}
1691
+ ]
1290
1692
 
1291
1693
  # Get the templates
1292
- templates = self._get_launch_template_ids(filters=filters)
1694
+ templates: list[str] = self._get_launch_template_ids(filters=filters)
1293
1695
 
1294
1696
  if len(templates) > 1:
1295
1697
  # There shouldn't ever be multiple templates with our reserved name
1296
- raise RuntimeError(f"Multiple launch templates already exist named {lt_name}; "
1297
- "something else is operating in our cluster namespace.")
1698
+ raise RuntimeError(
1699
+ f"Multiple launch templates already exist named {lt_name}; "
1700
+ "something else is operating in our cluster namespace."
1701
+ )
1298
1702
  elif len(templates) == 0:
1299
1703
  # Template doesn't exist so we can create it.
1300
1704
  try:
1301
- return self._create_worker_launch_template(instance_type, preemptible=preemptible)
1705
+ return self._create_worker_launch_template(
1706
+ instance_type, preemptible=preemptible
1707
+ )
1302
1708
  except ClientError as e:
1303
- if get_error_code(e) == 'InvalidLaunchTemplateName.AlreadyExistsException':
1709
+ if (
1710
+ get_error_code(e)
1711
+ == "InvalidLaunchTemplateName.AlreadyExistsException"
1712
+ ):
1304
1713
  # Someone got to it before us (or we couldn't read our own
1305
1714
  # writes). Recurse to try again, because now it exists.
1306
- logger.info('Waiting %f seconds for template %s to be available', backoff, lt_name)
1715
+ logger.info(
1716
+ "Waiting %f seconds for template %s to be available",
1717
+ backoff,
1718
+ lt_name,
1719
+ )
1307
1720
  time.sleep(backoff)
1308
- return self._get_worker_launch_template(instance_type, preemptible=preemptible, backoff=backoff*2)
1721
+ return self._get_worker_launch_template(
1722
+ instance_type, preemptible=preemptible, backoff=backoff * 2
1723
+ )
1309
1724
  else:
1310
1725
  raise
1311
1726
  else:
1312
1727
  # There must be exactly one template
1313
1728
  return templates[0]
1314
1729
 
1315
- def _name_worker_launch_template(self, instance_type: str, preemptible: bool = False) -> str:
1730
+ def _name_worker_launch_template(
1731
+ self, instance_type: str, preemptible: bool = False
1732
+ ) -> str:
1316
1733
  """
1317
1734
  Get the name we should use for the launch template with the given parameters.
1318
1735
 
@@ -1323,13 +1740,15 @@ class AWSProvisioner(AbstractProvisioner):
1323
1740
  """
1324
1741
 
1325
1742
  # The name has the cluster name in it
1326
- lt_name = f'{self.clusterName}-lt-{instance_type}'
1743
+ lt_name = f"{self.clusterName}-lt-{instance_type}"
1327
1744
  if preemptible:
1328
- lt_name += '-spot'
1745
+ lt_name += "-spot"
1329
1746
 
1330
1747
  return lt_name
1331
1748
 
1332
- def _create_worker_launch_template(self, instance_type: str, preemptible: bool = False) -> str:
1749
+ def _create_worker_launch_template(
1750
+ self, instance_type: str, preemptible: bool = False
1751
+ ) -> str:
1333
1752
  """
1334
1753
  Create the launch template for launching worker instances for the cluster.
1335
1754
 
@@ -1345,31 +1764,37 @@ class AWSProvisioner(AbstractProvisioner):
1345
1764
 
1346
1765
  assert self._leaderPrivateIP
1347
1766
  type_info = E2Instances[instance_type]
1348
- rootVolSize=self._nodeStorageOverrides.get(instance_type, self._nodeStorage)
1767
+ rootVolSize = self._nodeStorageOverrides.get(instance_type, self._nodeStorage)
1349
1768
  bdms = self._getBoto3BlockDeviceMappings(type_info, rootVolSize=rootVolSize)
1350
1769
 
1351
1770
  keyPath = self._sseKey if self._sseKey else None
1352
- userData = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture)
1771
+ userData = self._getIgnitionUserData(
1772
+ "worker", keyPath, preemptible, self._architecture
1773
+ )
1353
1774
 
1354
- lt_name = self._name_worker_launch_template(instance_type, preemptible=preemptible)
1775
+ lt_name = self._name_worker_launch_template(
1776
+ instance_type, preemptible=preemptible
1777
+ )
1355
1778
 
1356
1779
  # But really we find it by tag
1357
1780
  tags = dict(self._tags)
1358
- tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
1359
-
1360
- return create_launch_template(self.aws.client(self._region, 'ec2'),
1361
- template_name=lt_name,
1362
- image_id=self._discoverAMI(),
1363
- key_name=self._keyName,
1364
- security_group_ids=self._getSecurityGroupIDs(),
1365
- instance_type=instance_type,
1366
- user_data=userData,
1367
- block_device_map=bdms,
1368
- instance_profile_arn=self._leaderProfileArn,
1369
- tags=tags)
1781
+ tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
1782
+
1783
+ return create_launch_template(
1784
+ self.aws.client(self._region, "ec2"),
1785
+ template_name=lt_name,
1786
+ image_id=self._discoverAMI(),
1787
+ key_name=self._keyName,
1788
+ security_group_ids=self._getSecurityGroupIDs(),
1789
+ instance_type=instance_type,
1790
+ user_data=userData,
1791
+ block_device_map=bdms,
1792
+ instance_profile_arn=self._leaderProfileArn,
1793
+ tags=tags,
1794
+ )
1370
1795
 
1371
1796
  @awsRetry
1372
- def _getAutoScalingGroupNames(self) -> List[str]:
1797
+ def _getAutoScalingGroupNames(self) -> list[str]:
1373
1798
  """
1374
1799
  Find all auto-scaling groups associated with the cluster.
1375
1800
 
@@ -1377,29 +1802,33 @@ class AWSProvisioner(AbstractProvisioner):
1377
1802
  """
1378
1803
 
1379
1804
  # Grab the connection we need to use for this operation.
1380
- autoscaling = self.aws.client(self._region, 'autoscaling')
1805
+ autoscaling: AutoScalingClient = self.aws.client(self._region, "autoscaling")
1381
1806
 
1382
1807
  # AWS won't filter ASGs server-side for us in describe_auto_scaling_groups.
1383
1808
  # So we search instances of applied tags for the ASGs they are on.
1384
1809
  # The ASGs tagged with our cluster are our ASGs.
1385
1810
  # The filtering is on different fields of the tag object itself.
1386
- filters = [{'Name': 'key',
1387
- 'Values': [_TAG_KEY_TOIL_CLUSTER_NAME]},
1388
- {'Name': 'value',
1389
- 'Values': [self.clusterName]}]
1811
+ filters: list[FilterTypeDef] = [
1812
+ {"Name": "key", "Values": [_TAG_KEY_TOIL_CLUSTER_NAME]},
1813
+ {"Name": "value", "Values": [self.clusterName]},
1814
+ ]
1390
1815
 
1391
1816
  matchedASGs = []
1392
1817
  # Get the first page with no NextToken
1393
1818
  response = autoscaling.describe_tags(Filters=filters)
1394
1819
  while True:
1395
1820
  # Process the current page
1396
- matchedASGs += [item['ResourceId'] for item in response.get('Tags', [])
1397
- if item['Key'] == _TAG_KEY_TOIL_CLUSTER_NAME and
1398
- item['Value'] == self.clusterName]
1399
- if 'NextToken' in response:
1821
+ matchedASGs += [
1822
+ item["ResourceId"]
1823
+ for item in response.get("Tags", [])
1824
+ if item["Key"] == _TAG_KEY_TOIL_CLUSTER_NAME
1825
+ and item["Value"] == self.clusterName
1826
+ ]
1827
+ if "NextToken" in response:
1400
1828
  # There are more pages. Get the next one, supplying the token.
1401
- response = autoscaling.describe_tags(Filters=filters,
1402
- NextToken=response['NextToken'])
1829
+ response = autoscaling.describe_tags(
1830
+ Filters=filters, NextToken=response["NextToken"]
1831
+ )
1403
1832
  else:
1404
1833
  # No more pages
1405
1834
  break
@@ -1407,16 +1836,18 @@ class AWSProvisioner(AbstractProvisioner):
1407
1836
  for name in matchedASGs:
1408
1837
  # Double check to make sure we definitely aren't finding non-Toil
1409
1838
  # things
1410
- assert name.startswith('toil-')
1839
+ assert name.startswith("toil-")
1411
1840
 
1412
1841
  return matchedASGs
1413
1842
 
1414
- def _createWorkerAutoScalingGroup(self,
1415
- launch_template_ids: Dict[str, str],
1416
- instance_types: Collection[str],
1417
- min_size: int,
1418
- max_size: int,
1419
- spot_bid: Optional[float] = None) -> str:
1843
+ def _createWorkerAutoScalingGroup(
1844
+ self,
1845
+ launch_template_ids: dict[str, str],
1846
+ instance_types: Collection[str],
1847
+ min_size: int,
1848
+ max_size: int,
1849
+ spot_bid: float | None = None,
1850
+ ) -> str:
1420
1851
  """
1421
1852
  Create an autoscaling group.
1422
1853
 
@@ -1446,8 +1877,12 @@ class AWSProvisioner(AbstractProvisioner):
1446
1877
  for instance_type in instance_types:
1447
1878
  spec = E2Instances[instance_type]
1448
1879
  spec_gigs = spec.disks * spec.disk_capacity
1449
- rootVolSize = self._nodeStorageOverrides.get(instance_type, self._nodeStorage)
1450
- storage_gigs.append(max(rootVolSize - _STORAGE_ROOT_OVERHEAD_GIGS, spec_gigs))
1880
+ rootVolSize = self._nodeStorageOverrides.get(
1881
+ instance_type, self._nodeStorage
1882
+ )
1883
+ storage_gigs.append(
1884
+ max(rootVolSize - _STORAGE_ROOT_OVERHEAD_GIGS, spec_gigs)
1885
+ )
1451
1886
  # Get the min storage we expect to see, but not less than 0.
1452
1887
  min_gigs = max(min(storage_gigs), 0)
1453
1888
 
@@ -1456,32 +1891,36 @@ class AWSProvisioner(AbstractProvisioner):
1456
1891
  tags = dict(self._tags)
1457
1892
 
1458
1893
  # We tag the ASG with the Toil type, although nothing cares.
1459
- tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
1894
+ tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
1460
1895
 
1461
- if self.clusterType == 'kubernetes':
1896
+ if self.clusterType == "kubernetes":
1462
1897
  # We also need to tag it with Kubernetes autoscaler info (empty tags)
1463
- tags['k8s.io/cluster-autoscaler/' + self.clusterName] = ''
1464
- assert(self.clusterName != 'enabled')
1465
- tags['k8s.io/cluster-autoscaler/enabled'] = ''
1466
- tags['k8s.io/cluster-autoscaler/node-template/resources/ephemeral-storage'] = f'{min_gigs}G'
1898
+ tags["k8s.io/cluster-autoscaler/" + self.clusterName] = ""
1899
+ assert self.clusterName != "enabled"
1900
+ tags["k8s.io/cluster-autoscaler/enabled"] = ""
1901
+ tags[
1902
+ "k8s.io/cluster-autoscaler/node-template/resources/ephemeral-storage"
1903
+ ] = f"{min_gigs}G"
1467
1904
 
1468
1905
  # Now we need to make up a unique name
1469
1906
  # TODO: can we make this more semantic without risking collisions? Maybe count up in memory?
1470
- asg_name = 'toil-' + str(uuid.uuid4())
1471
-
1472
- create_auto_scaling_group(self.aws.client(self._region, 'autoscaling'),
1473
- asg_name=asg_name,
1474
- launch_template_ids=launch_template_ids,
1475
- vpc_subnets=self._get_worker_subnets(),
1476
- min_size=min_size,
1477
- max_size=max_size,
1478
- instance_types=instance_types,
1479
- spot_bid=spot_bid,
1480
- tags=tags)
1907
+ asg_name = "toil-" + str(uuid.uuid4())
1908
+
1909
+ create_auto_scaling_group(
1910
+ self.aws.client(self._region, "autoscaling"),
1911
+ asg_name=asg_name,
1912
+ launch_template_ids=launch_template_ids,
1913
+ vpc_subnets=self._get_worker_subnets(),
1914
+ min_size=min_size,
1915
+ max_size=max_size,
1916
+ instance_types=instance_types,
1917
+ spot_bid=spot_bid,
1918
+ tags=tags,
1919
+ )
1481
1920
 
1482
1921
  return asg_name
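
_createWorkerAutoScalingGroup delegates to Toil's create_auto_scaling_group helper. For orientation, the minimal underlying Boto 3 call for a plain single-template group looks roughly like this; the real helper also has to handle spot bids and mixed instance types, which are not shown:

import boto3

def create_simple_asg(region: str, asg_name: str, launch_template_id: str,
                      subnet_ids: list[str], min_size: int, max_size: int) -> None:
    autoscaling = boto3.client("autoscaling", region_name=region)
    autoscaling.create_auto_scaling_group(
        AutoScalingGroupName=asg_name,
        LaunchTemplate={"LaunchTemplateId": launch_template_id, "Version": "$Latest"},
        MinSize=min_size,
        MaxSize=max_size,
        VPCZoneIdentifier=",".join(subnet_ids),  # comma-separated subnet IDs
    )
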
1483
1922
 
1484
- def _boto2_pager(self, requestor_callable: Callable, result_attribute_name: str) -> Iterable[Dict[str, Any]]:
1923
+ def _boto2_pager(self, requestor_callable: Callable[[...], Any], result_attribute_name: str) -> Iterable[dict[str, Any]]: # type: ignore[misc]
1485
1924
  """
1486
1925
  Yield all the results from calling the given Boto 2 method and paging
1487
1926
  through all the results using the "marker" field. Results are to be
@@ -1489,68 +1928,50 @@ class AWSProvisioner(AbstractProvisioner):
1489
1928
  """
1490
1929
  marker = None
1491
1930
  while True:
1492
- result = requestor_callable(marker=marker)
1931
+ result = requestor_callable(marker=marker) # type: ignore[call-arg]
1493
1932
  yield from getattr(result, result_attribute_name)
1494
- if result.is_truncated == 'true':
1933
+ if result.is_truncated == "true":
1495
1934
  marker = result.marker
1496
1935
  else:
1497
1936
  break
1498
1937
 
1499
- def _pager(self, requestor_callable: Callable, result_attribute_name: str, **kwargs) -> Iterable[Dict[str, Any]]:
1500
- """
1501
- Yield all the results from calling the given Boto 3 method with the
1502
- given keyword arguments, paging through the results using the Marker or
1503
- NextToken, and fetching out and looping over the list in the response
1504
- with the given attribute name.
1505
- """
1506
-
1507
- # Recover the Boto3 client, and the name of the operation
1508
- client = requestor_callable.__self__
1509
- op_name = requestor_callable.__name__
1510
-
1511
- # grab a Boto 3 built-in paginator. See
1512
- # <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html>
1513
- paginator = client.get_paginator(op_name)
1514
-
1515
- for page in paginator.paginate(**kwargs):
1516
- # Invoke it and go through the pages, yielding from them
1517
- yield from page.get(result_attribute_name, [])
1518
-
1519
1938
  @awsRetry
1520
- def _getRoleNames(self) -> List[str]:
1939
+ def _getRoleNames(self) -> list[str]:
1521
1940
  """
1522
1941
  Get all the IAM roles belonging to the cluster, as names.
1523
1942
  """
1524
1943
 
1525
1944
  results = []
1526
- for result in self._boto2_pager(self.aws.boto2(self._region, 'iam').list_roles, 'roles'):
1945
+ boto3_iam = self.aws.client(self._region, "iam")
1946
+ for result in boto3_pager(boto3_iam.list_roles, "Roles"):
1527
1947
  # For each Boto2 role object
1528
1948
  # Grab out the name
1529
- name = result['role_name']
1949
+ result2 = cast("RoleTypeDef", result)
1950
+ name = result2["RoleName"]
1530
1951
  if self._is_our_namespaced_name(name):
1531
1952
  # If it looks like ours, it is ours.
1532
1953
  results.append(name)
1533
1954
  return results
1534
1955
 
1535
1956
  @awsRetry
1536
- def _getInstanceProfileNames(self) -> List[str]:
1957
+ def _getInstanceProfileNames(self) -> list[str]:
1537
1958
  """
1538
1959
  Get all the instance profiles belonging to the cluster, as names.
1539
1960
  """
1540
1961
 
1541
1962
  results = []
1542
- for result in self._boto2_pager(self.aws.boto2(self._region, 'iam').list_instance_profiles,
1543
- 'instance_profiles'):
1544
- # For each Boto2 role object
1963
+ boto3_iam = self.aws.client(self._region, "iam")
1964
+ for result in boto3_pager(boto3_iam.list_instance_profiles, "InstanceProfiles"):
1545
1965
  # Grab out the name
1546
- name = result['instance_profile_name']
1966
+ result2 = cast("InstanceProfileTypeDef", result)
1967
+ name = result2["InstanceProfileName"]
1547
1968
  if self._is_our_namespaced_name(name):
1548
1969
  # If it looks like ours, it is ours.
1549
1970
  results.append(name)
1550
1971
  return results
1551
1972
 
1552
1973
  @awsRetry
1553
- def _getRoleInstanceProfileNames(self, role_name: str) -> List[str]:
1974
+ def _getRoleInstanceProfileNames(self, role_name: str) -> list[str]:
1554
1975
  """
1555
1976
  Get all the instance profiles with the IAM role with the given name.
1556
1977
 
@@ -1558,14 +1979,19 @@ class AWSProvisioner(AbstractProvisioner):
1558
1979
  """
1559
1980
 
1560
1981
  # Grab the connection we need to use for this operation.
1561
- iam = self.aws.client(self._region, 'iam')
1562
-
1563
- return [item['InstanceProfileName'] for item in self._pager(iam.list_instance_profiles_for_role,
1564
- 'InstanceProfiles',
1565
- RoleName=role_name)]
1982
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1983
+
1984
+ return [
1985
+ item["InstanceProfileName"]
1986
+ for item in boto3_pager(
1987
+ boto3_iam.list_instance_profiles_for_role,
1988
+ "InstanceProfiles",
1989
+ RoleName=role_name,
1990
+ )
1991
+ ]
1566
1992
 
1567
1993
  @awsRetry
1568
- def _getRolePolicyArns(self, role_name: str) -> List[str]:
1994
+ def _getRolePolicyArns(self, role_name: str) -> list[str]:
1569
1995
  """
1570
1996
  Get all the policies attached to the IAM role with the given name.
1571
1997
 
@@ -1575,36 +2001,44 @@ class AWSProvisioner(AbstractProvisioner):
1575
2001
  """
1576
2002
 
1577
2003
  # Grab the connection we need to use for this operation.
1578
- iam = self.aws.client(self._region, 'iam')
2004
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1579
2005
 
1580
2006
  # TODO: we don't currently use attached policies.
1581
2007
 
1582
- return [item['PolicyArn'] for item in self._pager(iam.list_attached_role_policies,
1583
- 'AttachedPolicies',
1584
- RoleName=role_name)]
2008
+ return [
2009
+ item["PolicyArn"]
2010
+ for item in boto3_pager(
2011
+ boto3_iam.list_attached_role_policies,
2012
+ "AttachedPolicies",
2013
+ RoleName=role_name,
2014
+ )
2015
+ ]
1585
2016
 
1586
2017
  @awsRetry
1587
- def _getRoleInlinePolicyNames(self, role_name: str) -> List[str]:
2018
+ def _getRoleInlinePolicyNames(self, role_name: str) -> list[str]:
1588
2019
  """
1589
2020
  Get all the policies inline in the given IAM role.
1590
2021
  Returns policy names.
1591
2022
  """
1592
2023
 
1593
2024
  # Grab the connection we need to use for this operation.
1594
- iam = self.aws.client(self._region, 'iam')
2025
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1595
2026
 
1596
- return list(self._pager(iam.list_role_policies,
1597
- 'PolicyNames',
1598
- RoleName=role_name))
2027
+ return list(
2028
+ boto3_pager(boto3_iam.list_role_policies, "PolicyNames", RoleName=role_name)
2029
+ )
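
The IAM lookups above now go through a boto3_pager utility. Conceptually, and mirroring the _pager method this diff removes, such a helper can be written as below; the name paginate_items is made up, and Toil's real boto3_pager (in its AWS library) may differ in detail:

from typing import Any, Callable, Iterator

def paginate_items(bound_method: Callable[..., Any], field: str, **kwargs: Any) -> Iterator[Any]:
    # Boto 3 client methods are bound, so the client and operation name can be
    # recovered from the method itself and handed to the built-in paginator.
    client = bound_method.__self__
    paginator = client.get_paginator(bound_method.__name__)
    for page in paginator.paginate(**kwargs):
        yield from page.get(field, [])
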
1599
2030
 
1600
- def full_policy(self, resource: str) -> dict:
2031
+ def full_policy(self, resource: str) -> dict[str, Any]:
1601
2032
  """
1602
2033
  Produce a dict describing the JSON form of a full-access-granting AWS
1603
2034
  IAM policy for the service with the given name (e.g. 's3').
1604
2035
  """
1605
- return dict(Version="2012-10-17", Statement=[dict(Effect="Allow", Resource="*", Action=f"{resource}:*")])
2036
+ return dict(
2037
+ Version="2012-10-17",
2038
+ Statement=[dict(Effect="Allow", Resource="*", Action=f"{resource}:*")],
2039
+ )
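
For concreteness, the dict returned by full_policy("s3") above works out to this IAM policy document:

{
    "Version": "2012-10-17",
    "Statement": [{"Effect": "Allow", "Resource": "*", "Action": "s3:*"}],
}
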
1606
2040
 
1607
- def kubernetes_policy(self) -> dict:
+ def kubernetes_policy(self) -> dict[str, Any]:
  """
  Get the Kubernetes policy grants not provided by the full grants on EC2
  and IAM. See
@@ -1618,101 +2052,86 @@ class AWSProvisioner(AbstractProvisioner):
  Some of these are really only needed on the leader.
  """

- return dict(Version="2012-10-17", Statement=[dict(Effect="Allow", Resource="*", Action=[
- "ecr:GetAuthorizationToken",
- "ecr:BatchCheckLayerAvailability",
- "ecr:GetDownloadUrlForLayer",
- "ecr:GetRepositoryPolicy",
- "ecr:DescribeRepositories",
- "ecr:ListImages",
- "ecr:BatchGetImage",
- "autoscaling:DescribeAutoScalingGroups",
- "autoscaling:DescribeAutoScalingInstances",
- "autoscaling:DescribeLaunchConfigurations",
- "autoscaling:DescribeTags",
- "autoscaling:SetDesiredCapacity",
- "autoscaling:TerminateInstanceInAutoScalingGroup",
- "elasticloadbalancing:AddTags",
- "elasticloadbalancing:ApplySecurityGroupsToLoadBalancer",
- "elasticloadbalancing:AttachLoadBalancerToSubnets",
- "elasticloadbalancing:ConfigureHealthCheck",
- "elasticloadbalancing:CreateListener",
- "elasticloadbalancing:CreateLoadBalancer",
- "elasticloadbalancing:CreateLoadBalancerListeners",
- "elasticloadbalancing:CreateLoadBalancerPolicy",
- "elasticloadbalancing:CreateTargetGroup",
- "elasticloadbalancing:DeleteListener",
- "elasticloadbalancing:DeleteLoadBalancer",
- "elasticloadbalancing:DeleteLoadBalancerListeners",
- "elasticloadbalancing:DeleteTargetGroup",
- "elasticloadbalancing:DeregisterInstancesFromLoadBalancer",
- "elasticloadbalancing:DeregisterTargets",
- "elasticloadbalancing:DescribeListeners",
- "elasticloadbalancing:DescribeLoadBalancerAttributes",
- "elasticloadbalancing:DescribeLoadBalancerPolicies",
- "elasticloadbalancing:DescribeLoadBalancers",
- "elasticloadbalancing:DescribeTargetGroups",
- "elasticloadbalancing:DescribeTargetHealth",
- "elasticloadbalancing:DetachLoadBalancerFromSubnets",
- "elasticloadbalancing:ModifyListener",
- "elasticloadbalancing:ModifyLoadBalancerAttributes",
- "elasticloadbalancing:ModifyTargetGroup",
- "elasticloadbalancing:RegisterInstancesWithLoadBalancer",
- "elasticloadbalancing:RegisterTargets",
- "elasticloadbalancing:SetLoadBalancerPoliciesForBackendServer",
- "elasticloadbalancing:SetLoadBalancerPoliciesOfListener",
- "kms:DescribeKey"
- ])])
-
- def _setup_iam_ec2_role(self, local_role_name: str, policies: Dict[str, Any]) -> str:
+ return dict(
+ Version="2012-10-17",
+ Statement=[
+ dict(
+ Effect="Allow",
+ Resource="*",
+ Action=[
+ "ecr:GetAuthorizationToken",
+ "ecr:BatchCheckLayerAvailability",
+ "ecr:GetDownloadUrlForLayer",
+ "ecr:GetRepositoryPolicy",
+ "ecr:DescribeRepositories",
+ "ecr:ListImages",
+ "ecr:BatchGetImage",
+ "autoscaling:DescribeAutoScalingGroups",
+ "autoscaling:DescribeAutoScalingInstances",
+ "autoscaling:DescribeLaunchConfigurations",
+ "autoscaling:DescribeTags",
+ "autoscaling:SetDesiredCapacity",
+ "autoscaling:TerminateInstanceInAutoScalingGroup",
+ "elasticloadbalancing:AddTags",
+ "elasticloadbalancing:ApplySecurityGroupsToLoadBalancer",
+ "elasticloadbalancing:AttachLoadBalancerToSubnets",
+ "elasticloadbalancing:ConfigureHealthCheck",
+ "elasticloadbalancing:CreateListener",
+ "elasticloadbalancing:CreateLoadBalancer",
+ "elasticloadbalancing:CreateLoadBalancerListeners",
+ "elasticloadbalancing:CreateLoadBalancerPolicy",
+ "elasticloadbalancing:CreateTargetGroup",
+ "elasticloadbalancing:DeleteListener",
+ "elasticloadbalancing:DeleteLoadBalancer",
+ "elasticloadbalancing:DeleteLoadBalancerListeners",
+ "elasticloadbalancing:DeleteTargetGroup",
+ "elasticloadbalancing:DeregisterInstancesFromLoadBalancer",
+ "elasticloadbalancing:DeregisterTargets",
+ "elasticloadbalancing:DescribeListeners",
+ "elasticloadbalancing:DescribeLoadBalancerAttributes",
+ "elasticloadbalancing:DescribeLoadBalancerPolicies",
+ "elasticloadbalancing:DescribeLoadBalancers",
+ "elasticloadbalancing:DescribeTargetGroups",
+ "elasticloadbalancing:DescribeTargetHealth",
+ "elasticloadbalancing:DetachLoadBalancerFromSubnets",
+ "elasticloadbalancing:ModifyListener",
+ "elasticloadbalancing:ModifyLoadBalancerAttributes",
+ "elasticloadbalancing:ModifyTargetGroup",
+ "elasticloadbalancing:RegisterInstancesWithLoadBalancer",
+ "elasticloadbalancing:RegisterTargets",
+ "elasticloadbalancing:SetLoadBalancerPoliciesForBackendServer",
+ "elasticloadbalancing:SetLoadBalancerPoliciesOfListener",
+ "kms:DescribeKey",
+ ],
+ )
+ ],
+ )
+
+ def _setup_iam_ec2_role(
+ self, local_role_name: str, policies: dict[str, Any]
+ ) -> str:
  """
  Create an IAM role with the given policies, using the given name in
  addition to the cluster name, and return its full name.
  """
-
- # Grab the connection we need to use for this operation.
- iam = self.aws.boto2(self._region, 'iam')
-
- # Make sure we can tell our roles apart from roles for other clusters
- aws_role_name = self._namespace_name(local_role_name)
- try:
- # Make the role
- logger.debug('Creating IAM role %s...', aws_role_name)
- iam.create_role(aws_role_name, assume_role_policy_document=json.dumps({
+ ec2_role_policy_document = json.dumps(
+ {
  "Version": "2012-10-17",
- "Statement": [{
- "Effect": "Allow",
- "Principal": {"Service": ["ec2.amazonaws.com"]},
- "Action": ["sts:AssumeRole"]}
- ]}))
- logger.debug('Created new IAM role')
- except BotoServerError as e:
- if e.status == 409 and e.error_code == 'EntityAlreadyExists':
- logger.debug('IAM role already exists. Reusing.')
- else:
- raise
-
- # Delete superfluous policies
- policy_names = set(iam.list_role_policies(aws_role_name).policy_names)
- for policy_name in policy_names.difference(set(list(policies.keys()))):
- iam.delete_role_policy(aws_role_name, policy_name)
-
- # Create expected policies
- for policy_name, policy in policies.items():
- current_policy = None
- try:
- current_policy = json.loads(unquote(
- iam.get_role_policy(aws_role_name, policy_name).policy_document))
- except BotoServerError as e:
- if e.status == 404 and e.error_code == 'NoSuchEntity':
- pass
- else:
- raise
- if current_policy != policy:
- iam.put_role_policy(aws_role_name, policy_name, json.dumps(policy))
-
- # Now the role has the right policies so it is ready.
- return aws_role_name
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Principal": {"Service": ["ec2.amazonaws.com"]},
+ "Action": ["sts:AssumeRole"],
+ }
+ ],
+ }
+ )
+ return create_iam_role(
+ role_name=self._namespace_name(local_role_name),
+ assume_role_policy_document=ec2_role_policy_document,
+ policies=policies,
+ region=self._region,
+ )

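`_setup_iam_ec2_role` now only builds the EC2 trust document and delegates the create-or-update work to a shared `create_iam_role` helper whose body is not shown in this hunk. The sketch below is a hypothetical stand-in, illustrating with plain boto3 calls the kind of idempotent create-then-sync-inline-policies flow that the removed boto2 code performed; the function name and exact behavior are assumptions, not Toil's implementation.

    import json
    from typing import Any

    import boto3

    def create_iam_role_sketch(
        role_name: str,
        assume_role_policy_document: str,
        policies: dict[str, Any],
        region: str,
    ) -> str:
        # Hypothetical stand-in for the create_iam_role helper called above.
        iam = boto3.client("iam", region_name=region)
        try:
            # Create the role if it does not exist yet.
            iam.create_role(
                RoleName=role_name,
                AssumeRolePolicyDocument=assume_role_policy_document,
            )
        except iam.exceptions.EntityAlreadyExistsException:
            pass  # Reuse the existing role, as the old boto2 code did.
        # Drop inline policies we no longer want (pagination ignored for brevity).
        existing = set(iam.list_role_policies(RoleName=role_name)["PolicyNames"])
        for name in existing - set(policies):
            iam.delete_role_policy(RoleName=role_name, PolicyName=name)
        # Write (or overwrite) the expected inline policies.
        for name, policy in policies.items():
            iam.put_role_policy(
                RoleName=role_name,
                PolicyName=name,
                PolicyDocument=json.dumps(policy),
            )
        return role_name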
  @awsRetry
  def _createProfileArn(self) -> str:
@@ -1724,56 +2143,67 @@ class AWSProvisioner(AbstractProvisioner):
  """

  # Grab the connection we need to use for this operation.
- iam = self.aws.boto2(self._region, 'iam')
-
- policy = dict(iam_full=self.full_policy('iam'), ec2_full=self.full_policy('ec2'),
- s3_full=self.full_policy('s3'), sbd_full=self.full_policy('sdb'))
- if self.clusterType == 'kubernetes':
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
+
+ policy = dict(
+ iam_full=self.full_policy("iam"),
+ ec2_full=self.full_policy("ec2"),
+ s3_full=self.full_policy("s3"),
+ sbd_full=self.full_policy("sdb"),
+ )
+ if self.clusterType == "kubernetes":
  # We also need autoscaling groups and some other stuff for AWS-Kubernetes integrations.
  # TODO: We use one merged policy for leader and worker, but we could be more specific.
- policy['kubernetes_merged'] = self.kubernetes_policy()
+ policy["kubernetes_merged"] = self.kubernetes_policy()
  iamRoleName = self._setup_iam_ec2_role(_INSTANCE_PROFILE_ROLE_NAME, policy)

  try:
- profile = iam.get_instance_profile(iamRoleName)
- logger.debug("Have preexisting instance profile: %s", profile.get_instance_profile_response.get_instance_profile_result.instance_profile)
- except BotoServerError as e:
- if e.status == 404:
- profile = iam.create_instance_profile(iamRoleName)
- profile = profile.create_instance_profile_response.create_instance_profile_result
- logger.debug("Created new instance profile: %s", profile.instance_profile)
- else:
- raise
+ profile_result = boto3_iam.get_instance_profile(
+ InstanceProfileName=iamRoleName
+ )
+ profile: InstanceProfileTypeDef = profile_result["InstanceProfile"]
+ logger.debug("Have preexisting instance profile: %s", profile)
+ except boto3_iam.exceptions.NoSuchEntityException:
+ profile_result = boto3_iam.create_instance_profile(
+ InstanceProfileName=iamRoleName
+ )
+ profile = profile_result["InstanceProfile"]
+ logger.debug("Created new instance profile: %s", profile)
  else:
- profile = profile.get_instance_profile_response.get_instance_profile_result
- profile = profile.instance_profile
+ profile = profile_result["InstanceProfile"]

- profile_arn = profile.arn
+ profile_arn: str = profile["Arn"]

  # Now we have the profile ARN, but we want to make sure it really is
  # visible by name in a different session.
  wait_until_instance_profile_arn_exists(profile_arn)

- if len(profile.roles) > 1:
+ if len(profile["Roles"]) > 1:
  # This is too many roles. We probably grabbed something we should
  # not have by mistake, and this is some important profile for
  # something else.
- raise RuntimeError(f'Did not expect instance profile {profile_arn} to contain '
- f'more than one role; is it really a Toil-managed profile?')
- elif len(profile.roles) == 1:
- # this should be profile.roles[0].role_name
- if profile.roles.member.role_name == iamRoleName:
+ raise RuntimeError(
+ f"Did not expect instance profile {profile_arn} to contain "
+ f"more than one role; is it really a Toil-managed profile?"
+ )
+ elif len(profile["Roles"]) == 1:
+ if profile["Roles"][0]["RoleName"] == iamRoleName:
  return profile_arn
  else:
  # Drop this wrong role and use the fallback code for 0 roles
- iam.remove_role_from_instance_profile(iamRoleName,
- profile.roles.member.role_name)
+ boto3_iam.remove_role_from_instance_profile(
+ InstanceProfileName=iamRoleName,
+ RoleName=profile["Roles"][0]["RoleName"],
+ )

  # If we get here, we had 0 roles on the profile, or we had 1 but we removed it.
- for attempt in old_retry(predicate=lambda err: err.status == 404):
+ for attempt in old_retry(predicate=lambda err: get_error_status(err) == 404):
  with attempt:
  # Put the IAM role on the profile
- iam.add_role_to_instance_profile(profile.instance_profile_name, iamRoleName)
+ boto3_iam.add_role_to_instance_profile(
+ InstanceProfileName=profile["InstanceProfileName"],
+ RoleName=iamRoleName,
+ )
  logger.debug("Associated role %s with profile", iamRoleName)

  return profile_arn
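The rewritten `_createProfileArn` follows the standard boto3 get-or-create pattern for instance profiles, catching the client's `NoSuchEntityException` instead of inspecting boto2 status codes. A condensed, standalone restatement of that pattern (illustrative names; Toil's retry, namespacing, and multi-role safety checks omitted) looks roughly like this:

    import boto3

    def ensure_instance_profile_with_role(profile_name: str, role_name: str, region: str) -> str:
        # Condensed restatement of the get-or-create-and-attach flow above.
        iam = boto3.client("iam", region_name=region)
        try:
            profile = iam.get_instance_profile(InstanceProfileName=profile_name)["InstanceProfile"]
        except iam.exceptions.NoSuchEntityException:
            profile = iam.create_instance_profile(InstanceProfileName=profile_name)["InstanceProfile"]
        if not any(role["RoleName"] == role_name for role in profile["Roles"]):
            # Attach the desired role; the Toil code above additionally detaches
            # any unexpected role before doing this.
            iam.add_role_to_instance_profile(
                InstanceProfileName=profile_name, RoleName=role_name
            )
        return profile["Arn"]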