toil 7.0.0__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. toil/__init__.py +121 -83
  2. toil/batchSystems/__init__.py +1 -0
  3. toil/batchSystems/abstractBatchSystem.py +137 -77
  4. toil/batchSystems/abstractGridEngineBatchSystem.py +211 -101
  5. toil/batchSystems/awsBatch.py +237 -128
  6. toil/batchSystems/cleanup_support.py +22 -16
  7. toil/batchSystems/contained_executor.py +30 -26
  8. toil/batchSystems/gridengine.py +85 -49
  9. toil/batchSystems/htcondor.py +164 -87
  10. toil/batchSystems/kubernetes.py +622 -386
  11. toil/batchSystems/local_support.py +17 -12
  12. toil/batchSystems/lsf.py +132 -79
  13. toil/batchSystems/lsfHelper.py +13 -11
  14. toil/batchSystems/mesos/__init__.py +41 -29
  15. toil/batchSystems/mesos/batchSystem.py +288 -149
  16. toil/batchSystems/mesos/executor.py +77 -49
  17. toil/batchSystems/mesos/test/__init__.py +31 -23
  18. toil/batchSystems/options.py +38 -29
  19. toil/batchSystems/registry.py +53 -19
  20. toil/batchSystems/singleMachine.py +293 -123
  21. toil/batchSystems/slurm.py +489 -137
  22. toil/batchSystems/torque.py +46 -32
  23. toil/bus.py +141 -73
  24. toil/common.py +630 -359
  25. toil/cwl/__init__.py +1 -1
  26. toil/cwl/cwltoil.py +1114 -532
  27. toil/cwl/utils.py +17 -22
  28. toil/deferred.py +62 -41
  29. toil/exceptions.py +5 -3
  30. toil/fileStores/__init__.py +5 -5
  31. toil/fileStores/abstractFileStore.py +88 -57
  32. toil/fileStores/cachingFileStore.py +711 -247
  33. toil/fileStores/nonCachingFileStore.py +113 -75
  34. toil/job.py +988 -315
  35. toil/jobStores/abstractJobStore.py +387 -243
  36. toil/jobStores/aws/jobStore.py +727 -403
  37. toil/jobStores/aws/utils.py +161 -109
  38. toil/jobStores/conftest.py +1 -0
  39. toil/jobStores/fileJobStore.py +289 -151
  40. toil/jobStores/googleJobStore.py +137 -70
  41. toil/jobStores/utils.py +36 -15
  42. toil/leader.py +614 -269
  43. toil/lib/accelerators.py +115 -18
  44. toil/lib/aws/__init__.py +55 -28
  45. toil/lib/aws/ami.py +122 -87
  46. toil/lib/aws/iam.py +284 -108
  47. toil/lib/aws/s3.py +31 -0
  48. toil/lib/aws/session.py +193 -58
  49. toil/lib/aws/utils.py +238 -218
  50. toil/lib/bioio.py +13 -5
  51. toil/lib/compatibility.py +11 -6
  52. toil/lib/conversions.py +83 -49
  53. toil/lib/docker.py +131 -103
  54. toil/lib/ec2.py +322 -209
  55. toil/lib/ec2nodes.py +174 -106
  56. toil/lib/encryption/_dummy.py +5 -3
  57. toil/lib/encryption/_nacl.py +10 -6
  58. toil/lib/encryption/conftest.py +1 -0
  59. toil/lib/exceptions.py +26 -7
  60. toil/lib/expando.py +4 -2
  61. toil/lib/ftp_utils.py +217 -0
  62. toil/lib/generatedEC2Lists.py +127 -19
  63. toil/lib/humanize.py +6 -2
  64. toil/lib/integration.py +341 -0
  65. toil/lib/io.py +99 -11
  66. toil/lib/iterables.py +4 -2
  67. toil/lib/memoize.py +12 -8
  68. toil/lib/misc.py +65 -18
  69. toil/lib/objects.py +2 -2
  70. toil/lib/resources.py +19 -7
  71. toil/lib/retry.py +115 -77
  72. toil/lib/threading.py +282 -80
  73. toil/lib/throttle.py +15 -14
  74. toil/options/common.py +834 -401
  75. toil/options/cwl.py +175 -90
  76. toil/options/runner.py +50 -0
  77. toil/options/wdl.py +70 -19
  78. toil/provisioners/__init__.py +111 -46
  79. toil/provisioners/abstractProvisioner.py +322 -157
  80. toil/provisioners/aws/__init__.py +62 -30
  81. toil/provisioners/aws/awsProvisioner.py +980 -627
  82. toil/provisioners/clusterScaler.py +541 -279
  83. toil/provisioners/gceProvisioner.py +282 -179
  84. toil/provisioners/node.py +147 -79
  85. toil/realtimeLogger.py +34 -22
  86. toil/resource.py +137 -75
  87. toil/server/app.py +127 -61
  88. toil/server/celery_app.py +3 -1
  89. toil/server/cli/wes_cwl_runner.py +82 -53
  90. toil/server/utils.py +54 -28
  91. toil/server/wes/abstract_backend.py +64 -26
  92. toil/server/wes/amazon_wes_utils.py +21 -15
  93. toil/server/wes/tasks.py +121 -63
  94. toil/server/wes/toil_backend.py +142 -107
  95. toil/server/wsgi_app.py +4 -3
  96. toil/serviceManager.py +58 -22
  97. toil/statsAndLogging.py +148 -64
  98. toil/test/__init__.py +263 -179
  99. toil/test/batchSystems/batchSystemTest.py +438 -195
  100. toil/test/batchSystems/batch_system_plugin_test.py +18 -7
  101. toil/test/batchSystems/test_gridengine.py +173 -0
  102. toil/test/batchSystems/test_lsf_helper.py +67 -58
  103. toil/test/batchSystems/test_slurm.py +93 -47
  104. toil/test/cactus/test_cactus_integration.py +20 -22
  105. toil/test/cwl/cwlTest.py +271 -71
  106. toil/test/cwl/measure_default_memory.cwl +12 -0
  107. toil/test/cwl/not_run_required_input.cwl +29 -0
  108. toil/test/cwl/scatter_duplicate_outputs.cwl +40 -0
  109. toil/test/docs/scriptsTest.py +60 -34
  110. toil/test/jobStores/jobStoreTest.py +412 -235
  111. toil/test/lib/aws/test_iam.py +116 -48
  112. toil/test/lib/aws/test_s3.py +16 -9
  113. toil/test/lib/aws/test_utils.py +5 -6
  114. toil/test/lib/dockerTest.py +118 -141
  115. toil/test/lib/test_conversions.py +113 -115
  116. toil/test/lib/test_ec2.py +57 -49
  117. toil/test/lib/test_integration.py +104 -0
  118. toil/test/lib/test_misc.py +12 -5
  119. toil/test/mesos/MesosDataStructuresTest.py +23 -10
  120. toil/test/mesos/helloWorld.py +7 -6
  121. toil/test/mesos/stress.py +25 -20
  122. toil/test/options/options.py +7 -2
  123. toil/test/provisioners/aws/awsProvisionerTest.py +293 -140
  124. toil/test/provisioners/clusterScalerTest.py +440 -250
  125. toil/test/provisioners/clusterTest.py +81 -42
  126. toil/test/provisioners/gceProvisionerTest.py +174 -100
  127. toil/test/provisioners/provisionerTest.py +25 -13
  128. toil/test/provisioners/restartScript.py +5 -4
  129. toil/test/server/serverTest.py +188 -141
  130. toil/test/sort/restart_sort.py +137 -68
  131. toil/test/sort/sort.py +134 -66
  132. toil/test/sort/sortTest.py +91 -49
  133. toil/test/src/autoDeploymentTest.py +140 -100
  134. toil/test/src/busTest.py +20 -18
  135. toil/test/src/checkpointTest.py +8 -2
  136. toil/test/src/deferredFunctionTest.py +49 -35
  137. toil/test/src/dockerCheckTest.py +33 -26
  138. toil/test/src/environmentTest.py +20 -10
  139. toil/test/src/fileStoreTest.py +538 -271
  140. toil/test/src/helloWorldTest.py +7 -4
  141. toil/test/src/importExportFileTest.py +61 -31
  142. toil/test/src/jobDescriptionTest.py +32 -17
  143. toil/test/src/jobEncapsulationTest.py +2 -0
  144. toil/test/src/jobFileStoreTest.py +74 -50
  145. toil/test/src/jobServiceTest.py +187 -73
  146. toil/test/src/jobTest.py +120 -70
  147. toil/test/src/miscTests.py +19 -18
  148. toil/test/src/promisedRequirementTest.py +82 -36
  149. toil/test/src/promisesTest.py +7 -6
  150. toil/test/src/realtimeLoggerTest.py +6 -6
  151. toil/test/src/regularLogTest.py +71 -37
  152. toil/test/src/resourceTest.py +80 -49
  153. toil/test/src/restartDAGTest.py +36 -22
  154. toil/test/src/resumabilityTest.py +9 -2
  155. toil/test/src/retainTempDirTest.py +45 -14
  156. toil/test/src/systemTest.py +12 -8
  157. toil/test/src/threadingTest.py +44 -25
  158. toil/test/src/toilContextManagerTest.py +10 -7
  159. toil/test/src/userDefinedJobArgTypeTest.py +8 -5
  160. toil/test/src/workerTest.py +33 -16
  161. toil/test/utils/toilDebugTest.py +70 -58
  162. toil/test/utils/toilKillTest.py +4 -5
  163. toil/test/utils/utilsTest.py +239 -102
  164. toil/test/wdl/wdltoil_test.py +789 -148
  165. toil/test/wdl/wdltoil_test_kubernetes.py +37 -23
  166. toil/toilState.py +52 -26
  167. toil/utils/toilConfig.py +13 -4
  168. toil/utils/toilDebugFile.py +44 -27
  169. toil/utils/toilDebugJob.py +85 -25
  170. toil/utils/toilDestroyCluster.py +11 -6
  171. toil/utils/toilKill.py +8 -3
  172. toil/utils/toilLaunchCluster.py +251 -145
  173. toil/utils/toilMain.py +37 -16
  174. toil/utils/toilRsyncCluster.py +27 -14
  175. toil/utils/toilSshCluster.py +45 -22
  176. toil/utils/toilStats.py +75 -36
  177. toil/utils/toilStatus.py +226 -119
  178. toil/utils/toilUpdateEC2Instances.py +3 -1
  179. toil/version.py +11 -11
  180. toil/wdl/utils.py +5 -5
  181. toil/wdl/wdltoil.py +3513 -1052
  182. toil/worker.py +269 -128
  183. toil-8.0.0.dist-info/METADATA +173 -0
  184. toil-8.0.0.dist-info/RECORD +253 -0
  185. {toil-7.0.0.dist-info → toil-8.0.0.dist-info}/WHEEL +1 -1
  186. toil-7.0.0.dist-info/METADATA +0 -158
  187. toil-7.0.0.dist-info/RECORD +0 -244
  188. {toil-7.0.0.dist-info → toil-8.0.0.dist-info}/LICENSE +0 -0
  189. {toil-7.0.0.dist-info → toil-8.0.0.dist-info}/entry_points.txt +0 -0
  190. {toil-7.0.0.dist-info → toil-8.0.0.dist-info}/top_level.txt +0 -0
@@ -22,83 +22,93 @@ import string
22
22
  import textwrap
23
23
  import time
24
24
  import uuid
25
+ from collections.abc import Collection, Iterable
25
26
  from functools import wraps
26
27
  from shlex import quote
27
- from typing import (Any,
28
- Callable,
29
- Collection,
30
- Dict,
31
- Iterable,
32
- List,
33
- Optional,
34
- Set,
35
- Union,
36
- Literal,
37
- cast,
38
- TypeVar)
39
- from urllib.parse import unquote
28
+ from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
40
29
 
41
30
  # We need these to exist as attributes we can get off of the boto object
42
31
  from botocore.exceptions import ClientError
43
- from mypy_boto3_autoscaling.client import AutoScalingClient
44
- from mypy_boto3_ec2.service_resource import Instance
45
- from mypy_boto3_iam.type_defs import InstanceProfileTypeDef, RoleTypeDef, ListRolePoliciesResponseTypeDef
46
- from mypy_extensions import VarArg, KwArg
47
32
 
48
- from toil.lib.aws import zone_to_region, AWSRegionName, AWSServerErrors
33
+ from toil.lib.aws import AWSRegionName, AWSServerErrors, zone_to_region
49
34
  from toil.lib.aws.ami import get_flatcar_ami
50
- from toil.lib.aws.iam import (CLUSTER_LAUNCHING_PERMISSIONS,
51
- get_policy_permissions,
52
- policy_permissions_allow)
35
+ from toil.lib.aws.iam import (
36
+ CLUSTER_LAUNCHING_PERMISSIONS,
37
+ create_iam_role,
38
+ get_policy_permissions,
39
+ policy_permissions_allow,
40
+ )
53
41
  from toil.lib.aws.session import AWSConnectionManager
54
- from toil.lib.aws.utils import create_s3_bucket, flatten_tags, boto3_pager
42
+ from toil.lib.aws.session import client as get_client
43
+ from toil.lib.aws.utils import boto3_pager, create_s3_bucket
55
44
  from toil.lib.conversions import human2bytes
56
- from toil.lib.ec2 import (a_short_time,
57
- create_auto_scaling_group,
58
- create_instances,
59
- create_launch_template,
60
- create_ondemand_instances,
61
- increase_instance_hop_limit,
62
- create_spot_instances,
63
- wait_instances_running,
64
- wait_transition,
65
- wait_until_instance_profile_arn_exists)
45
+ from toil.lib.ec2 import (
46
+ a_short_time,
47
+ create_auto_scaling_group,
48
+ create_instances,
49
+ create_launch_template,
50
+ create_ondemand_instances,
51
+ create_spot_instances,
52
+ increase_instance_hop_limit,
53
+ wait_instances_running,
54
+ wait_transition,
55
+ wait_until_instance_profile_arn_exists,
56
+ )
66
57
  from toil.lib.ec2nodes import InstanceType
67
58
  from toil.lib.generatedEC2Lists import E2Instances
68
59
  from toil.lib.memoize import memoize
69
60
  from toil.lib.misc import truncExpBackoff
70
- from toil.lib.retry import (ErrorCondition,
71
- get_error_body,
72
- get_error_code,
73
- get_error_message,
74
- get_error_status,
75
- old_retry,
76
- retry)
77
- from toil.provisioners import (ClusterCombinationNotSupportedException,
78
- NoSuchClusterException,
79
- NoSuchZoneException)
80
- from toil.provisioners.abstractProvisioner import (AbstractProvisioner,
81
- ManagedNodesNotSupportedException,
82
- Shape)
61
+ from toil.lib.retry import (
62
+ get_error_body,
63
+ get_error_code,
64
+ get_error_message,
65
+ get_error_status,
66
+ old_retry,
67
+ retry,
68
+ )
69
+ from toil.provisioners import (
70
+ ClusterCombinationNotSupportedException,
71
+ NoSuchClusterException,
72
+ NoSuchZoneException,
73
+ )
74
+ from toil.provisioners.abstractProvisioner import (
75
+ AbstractProvisioner,
76
+ ManagedNodesNotSupportedException,
77
+ Shape,
78
+ )
83
79
  from toil.provisioners.aws import get_best_aws_zone
84
80
  from toil.provisioners.node import Node
85
- from toil.lib.aws.session import client as get_client
86
81
 
87
- from mypy_boto3_ec2.client import EC2Client
88
- from mypy_boto3_iam.client import IAMClient
89
- from mypy_boto3_ec2.type_defs import DescribeInstancesResultTypeDef, InstanceTypeDef, TagTypeDef, BlockDeviceMappingTypeDef, EbsBlockDeviceTypeDef, FilterTypeDef, SpotInstanceRequestTypeDef, TagDescriptionTypeDef, SecurityGroupTypeDef, \
90
- CreateSecurityGroupResultTypeDef, IpPermissionTypeDef, ReservationTypeDef
91
- from mypy_boto3_s3.literals import BucketLocationConstraintType
82
+ if TYPE_CHECKING:
83
+ from mypy_boto3_autoscaling.client import AutoScalingClient
84
+ from mypy_boto3_ec2.client import EC2Client
85
+ from mypy_boto3_ec2.service_resource import Instance
86
+ from mypy_boto3_ec2.type_defs import (
87
+ BlockDeviceMappingTypeDef,
88
+ CreateSecurityGroupResultTypeDef,
89
+ DescribeInstancesResultTypeDef,
90
+ EbsBlockDeviceTypeDef,
91
+ FilterTypeDef,
92
+ InstanceTypeDef,
93
+ IpPermissionTypeDef,
94
+ ReservationTypeDef,
95
+ SecurityGroupTypeDef,
96
+ SpotInstanceRequestTypeDef,
97
+ TagDescriptionTypeDef,
98
+ TagTypeDef,
99
+ )
100
+ from mypy_boto3_iam.client import IAMClient
101
+ from mypy_boto3_iam.type_defs import InstanceProfileTypeDef, RoleTypeDef
92
102
 
93
103
  logger = logging.getLogger(__name__)
94
104
  logging.getLogger("boto").setLevel(logging.CRITICAL)
95
105
  # Role name (used as the suffix) for EC2 instance profiles that are automatically created by Toil.
96
- _INSTANCE_PROFILE_ROLE_NAME = 'toil'
106
+ _INSTANCE_PROFILE_ROLE_NAME = "toil"
97
107
  # The tag key that specifies the Toil node type ("leader" or "worker") so that
98
108
  # leader vs. worker nodes can be robustly identified.
99
- _TAG_KEY_TOIL_NODE_TYPE = 'ToilNodeType'
109
+ _TAG_KEY_TOIL_NODE_TYPE = "ToilNodeType"
100
110
  # The tag that specifies the cluster name on all nodes
101
- _TAG_KEY_TOIL_CLUSTER_NAME = 'clusterName'
111
+ _TAG_KEY_TOIL_CLUSTER_NAME = "clusterName"
102
112
  # How much storage on the root volume is expected to go to overhead and be
103
113
  # unavailable to jobs when the node comes up?
104
114
  # TODO: measure
@@ -106,7 +116,7 @@ _STORAGE_ROOT_OVERHEAD_GIGS = 4
106
116
  # The maximum length of a S3 bucket
107
117
  _S3_BUCKET_MAX_NAME_LEN = 63
108
118
  # The suffix of the S3 bucket associated with the cluster
109
- _S3_BUCKET_INTERNAL_SUFFIX = '--internal'
119
+ _S3_BUCKET_INTERNAL_SUFFIX = "--internal"
110
120
 
111
121
 
112
122
  def awsRetryPredicate(e: Exception) -> bool:
@@ -115,14 +125,14 @@ def awsRetryPredicate(e: Exception) -> bool:
115
125
  # socket.gaierror: [Errno -2] Name or service not known
116
126
  return True
117
127
  # boto/AWS gives multiple messages for the same error...
118
- if get_error_status(e) == 503 and 'Request limit exceeded' in get_error_body(e):
128
+ if get_error_status(e) == 503 and "Request limit exceeded" in get_error_body(e):
119
129
  return True
120
- elif get_error_status(e) == 400 and 'Rate exceeded' in get_error_body(e):
130
+ elif get_error_status(e) == 400 and "Rate exceeded" in get_error_body(e):
121
131
  return True
122
- elif get_error_status(e) == 400 and 'NotFound' in get_error_body(e):
132
+ elif get_error_status(e) == 400 and "NotFound" in get_error_body(e):
123
133
  # EC2 can take a while to propagate instance IDs to all servers.
124
134
  return True
125
- elif get_error_status(e) == 400 and get_error_code(e) == 'Throttling':
135
+ elif get_error_status(e) == 400 and get_error_code(e) == "Throttling":
126
136
  return True
127
137
  return False
128
138
 
@@ -136,7 +146,7 @@ def expectedShutdownErrors(e: Exception) -> bool:
136
146
  impossible or unnecessary (such as errors resulting from a thing not
137
147
  existing to be deleted).
138
148
  """
139
- return get_error_status(e) == 400 and 'dependent object' in get_error_body(e)
149
+ return get_error_status(e) == 400 and "dependent object" in get_error_body(e)
140
150
 
141
151
 
142
152
  F = TypeVar("F") # so mypy understands passed through types
@@ -151,28 +161,42 @@ def awsRetry(f: Callable[..., F]) -> Callable[..., F]:
151
161
 
152
162
  @wraps(f)
153
163
  def wrapper(*args: Any, **kwargs: Any) -> Any:
154
- for attempt in old_retry(delays=truncExpBackoff(),
155
- timeout=300,
156
- predicate=awsRetryPredicate):
164
+ for attempt in old_retry(
165
+ delays=truncExpBackoff(), timeout=300, predicate=awsRetryPredicate
166
+ ):
157
167
  with attempt:
158
168
  return f(*args, **kwargs)
159
169
 
160
170
  return wrapper
161
171
 
162
172
 
163
- def awsFilterImpairedNodes(nodes: List[InstanceTypeDef], boto3_ec2: EC2Client) -> List[InstanceTypeDef]:
173
+ def awsFilterImpairedNodes(
174
+ nodes: list[InstanceTypeDef], boto3_ec2: EC2Client
175
+ ) -> list[InstanceTypeDef]:
164
176
  # if TOIL_AWS_NODE_DEBUG is set don't terminate nodes with
165
177
  # failing status checks so they can be debugged
166
- nodeDebug = os.environ.get('TOIL_AWS_NODE_DEBUG') in ('True', 'TRUE', 'true', True)
178
+ nodeDebug = os.environ.get("TOIL_AWS_NODE_DEBUG") in ("True", "TRUE", "true", True)
167
179
  if not nodeDebug:
168
180
  return nodes
169
181
  nodeIDs = [node["InstanceId"] for node in nodes]
170
182
  statuses = boto3_ec2.describe_instance_status(InstanceIds=nodeIDs)
171
- statusMap = {status["InstanceId"]: status["InstanceStatus"]["Status"] for status in statuses["InstanceStatuses"]}
172
- healthyNodes = [node for node in nodes if statusMap.get(node["InstanceId"], None) != 'impaired']
173
- impairedNodes = [node["InstanceId"] for node in nodes if statusMap.get(node["InstanceId"], None) == 'impaired']
174
- logger.warning('TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so '
175
- 'will not be terminated.', ' '.join(impairedNodes))
183
+ statusMap = {
184
+ status["InstanceId"]: status["InstanceStatus"]["Status"]
185
+ for status in statuses["InstanceStatuses"]
186
+ }
187
+ healthyNodes = [
188
+ node for node in nodes if statusMap.get(node["InstanceId"], None) != "impaired"
189
+ ]
190
+ impairedNodes = [
191
+ node["InstanceId"]
192
+ for node in nodes
193
+ if statusMap.get(node["InstanceId"], None) == "impaired"
194
+ ]
195
+ logger.warning(
196
+ "TOIL_AWS_NODE_DEBUG is set and nodes %s have failed EC2 status checks so "
197
+ "will not be terminated.",
198
+ " ".join(impairedNodes),
199
+ )
176
200
  return healthyNodes
177
201
 
178
202
 
@@ -180,13 +204,13 @@ class InvalidClusterStateException(Exception):
180
204
  pass
181
205
 
182
206
 
183
- def collapse_tags(instance_tags: List[TagTypeDef]) -> Dict[str, str]:
207
+ def collapse_tags(instance_tags: list[TagTypeDef]) -> dict[str, str]:
184
208
  """
185
209
  Collapse tags from boto3 format to node format
186
- :param instance_tags: tags as list of TagTypeDef
210
+ :param instance_tags: tags as a list
187
211
  :return: Dict of tags
188
212
  """
189
- collapsed_tags: Dict[str, str] = dict()
213
+ collapsed_tags: dict[str, str] = dict()
190
214
  for tag in instance_tags:
191
215
  if tag.get("Key") is not None:
192
216
  collapsed_tags[tag["Key"]] = tag["Value"]
@@ -194,10 +218,17 @@ def collapse_tags(instance_tags: List[TagTypeDef]) -> Dict[str, str]:
194
218
 
195
219
 
196
220
  class AWSProvisioner(AbstractProvisioner):
197
- def __init__(self, clusterName: Optional[str], clusterType: Optional[str], zone: Optional[str],
198
- nodeStorage: int, nodeStorageOverrides: Optional[List[str]], sseKey: Optional[str],
199
- enable_fuse: bool):
200
- self.cloud = 'aws'
221
+ def __init__(
222
+ self,
223
+ clusterName: str | None,
224
+ clusterType: str | None,
225
+ zone: str | None,
226
+ nodeStorage: int,
227
+ nodeStorageOverrides: list[str] | None,
228
+ sseKey: str | None,
229
+ enable_fuse: bool,
230
+ ):
231
+ self.cloud = "aws"
201
232
  self._sseKey = sseKey
202
233
  # self._zone will be filled in by base class constructor
203
234
  # We will use it as the leader zone.
@@ -205,9 +236,11 @@ class AWSProvisioner(AbstractProvisioner):
205
236
 
206
237
  if zone is None:
207
238
  # Can't proceed without a real zone
208
- raise RuntimeError('No AWS availability zone specified. Configure in Boto '
209
- 'configuration file, TOIL_AWS_ZONE environment variable, or '
210
- 'on the command line.')
239
+ raise RuntimeError(
240
+ "No AWS availability zone specified. Configure in Boto "
241
+ "configuration file, TOIL_AWS_ZONE environment variable, or "
242
+ "on the command line."
243
+ )
211
244
 
212
245
  # Determine our region to work in, before readClusterSettings() which
213
246
  # might need it. TODO: support multiple regions in one cluster
@@ -218,24 +251,37 @@ class AWSProvisioner(AbstractProvisioner):
218
251
 
219
252
  # Set our architecture to the current machine architecture
220
253
  # Assume the same architecture unless specified differently in launchCluster()
221
- self._architecture = 'amd64' if platform.machine() in ['x86_64', 'amd64'] else 'arm64'
254
+ self._architecture = (
255
+ "amd64" if platform.machine() in ["x86_64", "amd64"] else "arm64"
256
+ )
222
257
 
223
258
  # Call base class constructor, which will call createClusterSettings()
224
259
  # or readClusterSettings()
225
- super().__init__(clusterName, clusterType, zone, nodeStorage, nodeStorageOverrides, enable_fuse)
260
+ super().__init__(
261
+ clusterName,
262
+ clusterType,
263
+ zone,
264
+ nodeStorage,
265
+ nodeStorageOverrides,
266
+ enable_fuse,
267
+ )
226
268
 
227
269
  if self._zone is None:
228
- logger.warning("Leader zone was never initialized before creating AWS provisioner. Defaulting to cluster zone.")
270
+ logger.warning(
271
+ "Leader zone was never initialized before creating AWS provisioner. Defaulting to cluster zone."
272
+ )
229
273
 
230
274
  self._leader_subnet: str = self._get_default_subnet(self._zone or zone)
231
- self._tags: Dict[str, Any] = {}
275
+ self._tags: dict[str, Any] = {}
232
276
 
233
277
  # After self.clusterName is set, generate a valid name for the S3 bucket associated with this cluster
234
278
  suffix = _S3_BUCKET_INTERNAL_SUFFIX
235
- self.s3_bucket_name = self.clusterName[:_S3_BUCKET_MAX_NAME_LEN - len(suffix)] + suffix
279
+ self.s3_bucket_name = (
280
+ self.clusterName[: _S3_BUCKET_MAX_NAME_LEN - len(suffix)] + suffix
281
+ )
236
282
 
237
- def supportedClusterTypes(self) -> Set[str]:
238
- return {'mesos', 'kubernetes'}
283
+ def supportedClusterTypes(self) -> set[str]:
284
+ return {"mesos", "kubernetes"}
239
285
 
240
286
  def createClusterSettings(self) -> None:
241
287
  """
@@ -253,8 +299,11 @@ class AWSProvisioner(AbstractProvisioner):
253
299
  the instance is the leader.
254
300
  """
255
301
  from ec2_metadata import ec2_metadata
256
- boto3_ec2 = self.aws.client(self._region, 'ec2')
257
- instance: InstanceTypeDef = boto3_ec2.describe_instances(InstanceIds=[ec2_metadata.instance_id])["Reservations"][0]["Instances"][0]
302
+
303
+ boto3_ec2 = self.aws.client(self._region, "ec2")
304
+ instance: InstanceTypeDef = boto3_ec2.describe_instances(
305
+ InstanceIds=[ec2_metadata.instance_id]
306
+ )["Reservations"][0]["Instances"][0]
258
307
  # The cluster name is the same as the name of the leader.
259
308
  self.clusterName: str = "default-toil-cluster-name"
260
309
  for tag in instance["Tags"]:
@@ -266,7 +315,11 @@ class AWSProvisioner(AbstractProvisioner):
266
315
  self._worker_subnets_by_zone = self._get_good_subnets_like(self._leader_subnet)
267
316
 
268
317
  self._leaderPrivateIP = ec2_metadata.private_ipv4 # this is PRIVATE IP
269
- self._tags = {k: v for k, v in (self.getLeader().tags or {}).items() if k != _TAG_KEY_TOIL_NODE_TYPE}
318
+ self._tags = {
319
+ k: v
320
+ for k, v in (self.getLeader().tags or {}).items()
321
+ if k != _TAG_KEY_TOIL_NODE_TYPE
322
+ }
270
323
  # Grab the ARN name of the instance profile (a str) to apply to workers
271
324
  leader_info = None
272
325
  for attempt in old_retry(timeout=300, predicate=lambda e: True):
@@ -283,9 +336,9 @@ class AWSProvisioner(AbstractProvisioner):
283
336
  # The existing metadata API returns a single string if there is one security group, but
284
337
  # a list when there are multiple: change the format to always be a list.
285
338
  rawSecurityGroups = ec2_metadata.security_groups
286
- self._leaderSecurityGroupNames: Set[str] = set(rawSecurityGroups)
339
+ self._leaderSecurityGroupNames: set[str] = set(rawSecurityGroups)
287
340
  # Since we have access to the names, we don't also need to use any IDs
288
- self._leaderSecurityGroupIDs: Set[str] = set()
341
+ self._leaderSecurityGroupIDs: set[str] = set()
289
342
 
290
343
  # Let the base provisioner work out how to deploy duly authorized
291
344
  # workers for this leader.
@@ -296,8 +349,8 @@ class AWSProvisioner(AbstractProvisioner):
296
349
  bucket_name = self.s3_bucket_name
297
350
 
298
351
  # Connect to S3
299
- s3 = self.aws.resource(self._region, 's3')
300
- s3_client = self.aws.client(self._region, 's3')
352
+ s3 = self.aws.resource(self._region, "s3")
353
+ s3_client = self.aws.client(self._region, "s3")
301
354
 
302
355
  # create bucket if needed, then write file to S3
303
356
  try:
@@ -306,14 +359,18 @@ class AWSProvisioner(AbstractProvisioner):
306
359
  bucket = s3.Bucket(bucket_name)
307
360
  except ClientError as err:
308
361
  if get_error_status(err) == 404:
309
- bucket = create_s3_bucket(s3, bucket_name=bucket_name, region=self._region)
362
+ bucket = create_s3_bucket(
363
+ s3, bucket_name=bucket_name, region=self._region
364
+ )
310
365
  bucket.wait_until_exists()
311
366
  bucket.Versioning().enable()
312
367
 
313
- owner_tag = os.environ.get('TOIL_OWNER_TAG')
368
+ owner_tag = os.environ.get("TOIL_OWNER_TAG")
314
369
  if owner_tag:
315
370
  bucket_tagging = s3.BucketTagging(bucket_name)
316
- bucket_tagging.put(Tagging={'TagSet': [{'Key': 'Owner', 'Value': owner_tag}]})
371
+ bucket_tagging.put(
372
+ Tagging={"TagSet": [{"Key": "Owner", "Value": owner_tag}]}
373
+ )
317
374
  else:
318
375
  raise
319
376
 
@@ -323,34 +380,38 @@ class AWSProvisioner(AbstractProvisioner):
323
380
  obj.put(Body=contents)
324
381
 
325
382
  obj.wait_until_exists()
326
- return f's3://{bucket_name}/{key}'
383
+ return f"s3://{bucket_name}/{key}"
327
384
 
328
385
  def _read_file_from_cloud(self, key: str) -> bytes:
329
386
  bucket_name = self.s3_bucket_name
330
- obj = self.aws.resource(self._region, 's3').Object(bucket_name, key)
387
+ obj = self.aws.resource(self._region, "s3").Object(bucket_name, key)
331
388
 
332
389
  try:
333
- return obj.get()['Body'].read()
390
+ return obj.get()["Body"].read()
334
391
  except ClientError as e:
335
392
  if get_error_status(e) == 404:
336
- logger.warning(f'Trying to read non-existent file "{key}" from {bucket_name}.')
393
+ logger.warning(
394
+ f'Trying to read non-existent file "{key}" from {bucket_name}.'
395
+ )
337
396
  raise
338
397
 
339
398
  def _get_user_data_limit(self) -> int:
340
399
  # See: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-add-user-data.html
341
- return human2bytes('16KB')
342
-
343
- def launchCluster(self,
344
- leaderNodeType: str,
345
- leaderStorage: int,
346
- owner: str,
347
- keyName: str,
348
- botoPath: str,
349
- userTags: Optional[Dict[str, str]],
350
- vpcSubnet: Optional[str],
351
- awsEc2ProfileArn: Optional[str],
352
- awsEc2ExtraSecurityGroupIds: Optional[List[str]],
353
- **kwargs: Dict[str, Any]) -> None:
400
+ return human2bytes("16KB")
401
+
402
+ def launchCluster(
403
+ self,
404
+ leaderNodeType: str,
405
+ leaderStorage: int,
406
+ owner: str,
407
+ keyName: str,
408
+ botoPath: str,
409
+ userTags: dict[str, str] | None,
410
+ vpcSubnet: str | None,
411
+ awsEc2ProfileArn: str | None,
412
+ awsEc2ExtraSecurityGroupIds: list[str] | None,
413
+ **kwargs: dict[str, Any],
414
+ ) -> None:
354
415
  """
355
416
  Starts a single leader node and populates this class with the leader's metadata.
356
417
 
@@ -366,29 +427,42 @@ class AWSProvisioner(AbstractProvisioner):
366
427
  :return: None
367
428
  """
368
429
 
369
- if 'network' in kwargs:
370
- logger.warning('AWS provisioner does not support a network parameter. Ignoring %s!', kwargs["network"])
430
+ if "network" in kwargs:
431
+ logger.warning(
432
+ "AWS provisioner does not support a network parameter. Ignoring %s!",
433
+ kwargs["network"],
434
+ )
371
435
 
372
436
  # First, pre-flight-check our permissions before making anything.
373
- if not policy_permissions_allow(get_policy_permissions(region=self._region), CLUSTER_LAUNCHING_PERMISSIONS):
437
+ if not policy_permissions_allow(
438
+ get_policy_permissions(region=self._region), CLUSTER_LAUNCHING_PERMISSIONS
439
+ ):
374
440
  # Function prints a more specific warning to the log, but give some context.
375
- logger.warning('Toil may not be able to properly launch (or destroy!) your cluster.')
441
+ logger.warning(
442
+ "Toil may not be able to properly launch (or destroy!) your cluster."
443
+ )
376
444
 
377
445
  leader_type = E2Instances[leaderNodeType]
378
446
 
379
- if self.clusterType == 'kubernetes':
447
+ if self.clusterType == "kubernetes":
380
448
  if leader_type.cores < 2:
381
449
  # Kubernetes won't run here.
382
- raise RuntimeError('Kubernetes requires 2 or more cores, and %s is too small' %
383
- leaderNodeType)
450
+ raise RuntimeError(
451
+ "Kubernetes requires 2 or more cores, and %s is too small"
452
+ % leaderNodeType
453
+ )
384
454
  self._keyName = keyName
385
455
  self._architecture = leader_type.architecture
386
456
 
387
- if self.clusterType == 'mesos' and self._architecture != 'amd64':
457
+ if self.clusterType == "mesos" and self._architecture != "amd64":
388
458
  # Mesos images aren't currently available for this architecture, so we can't start a Mesos cluster.
389
459
  # Complain about this before we create anything.
390
- raise ClusterCombinationNotSupportedException(type(self), self.clusterType, self._architecture,
391
- reason="Mesos is only available for amd64.")
460
+ raise ClusterCombinationNotSupportedException(
461
+ type(self),
462
+ self.clusterType,
463
+ self._architecture,
464
+ reason="Mesos is only available for amd64.",
465
+ )
392
466
 
393
467
  if vpcSubnet:
394
468
  # This is where we put the leader
@@ -401,45 +475,49 @@ class AWSProvisioner(AbstractProvisioner):
401
475
  bdms = self._getBoto3BlockDeviceMappings(leader_type, rootVolSize=leaderStorage)
402
476
 
403
477
  # Make up the tags
404
- self._tags = {'Name': self.clusterName,
405
- 'Owner': owner,
406
- _TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName}
478
+ self._tags = {
479
+ "Name": self.clusterName,
480
+ "Owner": owner,
481
+ _TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName,
482
+ }
407
483
 
408
484
  if userTags is not None:
409
485
  self._tags.update(userTags)
410
486
 
411
487
  # All user specified tags have been set
412
- userData = self._getIgnitionUserData('leader', architecture=self._architecture)
488
+ userData = self._getIgnitionUserData("leader", architecture=self._architecture)
413
489
 
414
- if self.clusterType == 'kubernetes':
490
+ if self.clusterType == "kubernetes":
415
491
  # All nodes need a tag putting them in the cluster.
416
492
  # This tag needs to be on there before the a leader can finish its startup.
417
- self._tags['kubernetes.io/cluster/' + self.clusterName] = ''
493
+ self._tags["kubernetes.io/cluster/" + self.clusterName] = ""
418
494
 
419
495
  # Make tags for the leader specifically
420
496
  leader_tags = dict(self._tags)
421
- leader_tags[_TAG_KEY_TOIL_NODE_TYPE] = 'leader'
422
- logger.debug('Launching leader with tags: %s', leader_tags)
423
-
424
- instances: List[Instance] = create_instances(self.aws.resource(self._region, 'ec2'),
425
- image_id=self._discoverAMI(),
426
- num_instances=1,
427
- key_name=self._keyName,
428
- security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []),
429
- instance_type=leader_type.name,
430
- user_data=userData,
431
- block_device_map=bdms,
432
- instance_profile_arn=profileArn,
433
- placement_az=self._zone,
434
- subnet_id=self._leader_subnet,
435
- tags=leader_tags)
497
+ leader_tags[_TAG_KEY_TOIL_NODE_TYPE] = "leader"
498
+ logger.debug("Launching leader with tags: %s", leader_tags)
499
+
500
+ instances: list[Instance] = create_instances(
501
+ self.aws.resource(self._region, "ec2"),
502
+ image_id=self._discoverAMI(),
503
+ num_instances=1,
504
+ key_name=self._keyName,
505
+ security_group_ids=createdSGs + (awsEc2ExtraSecurityGroupIds or []),
506
+ instance_type=leader_type.name,
507
+ user_data=userData,
508
+ block_device_map=bdms,
509
+ instance_profile_arn=profileArn,
510
+ placement_az=self._zone,
511
+ subnet_id=self._leader_subnet,
512
+ tags=leader_tags,
513
+ )
436
514
 
437
515
  # wait for the leader to exist at all
438
516
  leader = instances[0]
439
517
  leader.wait_until_exists()
440
518
 
441
519
  # Don't go on until the leader is started
442
- logger.info('Waiting for leader instance %s to be running', leader)
520
+ logger.info("Waiting for leader instance %s to be running", leader)
443
521
  leader.wait_until_running()
444
522
 
445
523
  # Now reload it to make sure all the IPs are set.
@@ -449,22 +527,31 @@ class AWSProvisioner(AbstractProvisioner):
449
527
  # Sometimes AWS just fails to assign a public IP when we really need one.
450
528
  # But sometimes people intend to use private IPs only in Toil-managed clusters.
451
529
  # TODO: detect if we have a route to the private IP and fail fast if not.
452
- logger.warning("AWS did not assign a public IP to the cluster leader. If you aren't "
453
- "connected to the private subnet, cluster setup will fail!")
530
+ logger.warning(
531
+ "AWS did not assign a public IP to the cluster leader. If you aren't "
532
+ "connected to the private subnet, cluster setup will fail!"
533
+ )
454
534
 
455
535
  # Remember enough about the leader to let us launch workers in its
456
536
  # cluster.
457
537
  self._leaderPrivateIP = leader.private_ip_address
458
538
  self._worker_subnets_by_zone = self._get_good_subnets_like(self._leader_subnet)
459
539
  self._leaderSecurityGroupNames = set()
460
- self._leaderSecurityGroupIDs = set(createdSGs + (awsEc2ExtraSecurityGroupIds or []))
540
+ self._leaderSecurityGroupIDs = set(
541
+ createdSGs + (awsEc2ExtraSecurityGroupIds or [])
542
+ )
461
543
  self._leaderProfileArn = profileArn
462
544
 
463
- leaderNode = Node(publicIP=leader.public_ip_address, privateIP=leader.private_ip_address,
464
- name=leader.id, launchTime=leader.launch_time,
465
- nodeType=leader_type.name, preemptible=False,
466
- tags=collapse_tags(leader.tags))
467
- leaderNode.waitForNode('toil_leader')
545
+ leaderNode = Node(
546
+ publicIP=leader.public_ip_address,
547
+ privateIP=leader.private_ip_address,
548
+ name=leader.id,
549
+ launchTime=leader.launch_time,
550
+ nodeType=leader_type.name,
551
+ preemptible=False,
552
+ tags=collapse_tags(leader.tags),
553
+ )
554
+ leaderNode.waitForNode("toil_leader")
468
555
 
469
556
  # Download credentials
470
557
  self._setLeaderWorkerAuthentication(leaderNode)
@@ -477,7 +564,7 @@ class AWSProvisioner(AbstractProvisioner):
477
564
  instance_base_tags = json.dumps(self._tags)
478
565
  return config + " -e TOIL_AWS_TAGS=" + quote(instance_base_tags)
479
566
 
480
- def _get_worker_subnets(self) -> List[str]:
567
+ def _get_worker_subnets(self) -> list[str]:
481
568
  """
482
569
  Get all worker subnets we should balance across, as a flat list.
483
570
  """
@@ -493,7 +580,7 @@ class AWSProvisioner(AbstractProvisioner):
493
580
  return collected
494
581
 
495
582
  @awsRetry
496
- def _get_good_subnets_like(self, base_subnet_id: str) -> Dict[str, List[str]]:
583
+ def _get_good_subnets_like(self, base_subnet_id: str) -> dict[str, list[str]]:
497
584
  """
498
585
  Given a subnet ID, get all the similar subnets (including it),
499
586
  organized by availability zone.
@@ -504,9 +591,9 @@ class AWSProvisioner(AbstractProvisioner):
504
591
  """
505
592
 
506
593
  # Grab the ec2 resource we need to make queries
507
- ec2 = self.aws.resource(self._region, 'ec2')
594
+ ec2 = self.aws.resource(self._region, "ec2")
508
595
  # And the client
509
- ec2_client = self.aws.client(self._region, 'ec2')
596
+ ec2_client = self.aws.client(self._region, "ec2")
510
597
 
511
598
  # What subnet are we basing this on?
512
599
  base_subnet = ec2.Subnet(base_subnet_id)
@@ -521,31 +608,33 @@ class AWSProvisioner(AbstractProvisioner):
521
608
  acls = set(self._get_subnet_acls(base_subnet_id))
522
609
 
523
610
  # Compose a filter that selects the subnets we might want
524
- filters: List[FilterTypeDef] = [{
525
- 'Name': 'vpc-id',
526
- 'Values': [vpc_id]
527
- }, {
528
- 'Name': 'default-for-az',
529
- 'Values': ['true' if is_default else 'false']
530
- }, {
531
- 'Name': 'state',
532
- 'Values': ['available']
533
- }]
611
+ filters: list[FilterTypeDef] = [
612
+ {"Name": "vpc-id", "Values": [vpc_id]},
613
+ {"Name": "default-for-az", "Values": ["true" if is_default else "false"]},
614
+ {"Name": "state", "Values": ["available"]},
615
+ ]
534
616
 
535
617
  # Fill in this collection
536
- by_az: Dict[str, List[str]] = {}
618
+ by_az: dict[str, list[str]] = {}
537
619
 
538
620
  # Go get all the subnets. There's no way to page manually here so it
539
621
  # must page automatically.
540
- for subnet in self.aws.resource(self._region, 'ec2').subnets.filter(Filters=filters):
622
+ for subnet in self.aws.resource(self._region, "ec2").subnets.filter(
623
+ Filters=filters
624
+ ):
541
625
  # For each subnet in the VPC
542
626
 
543
627
  # See if it has the right ACLs
544
628
  subnet_acls = set(self._get_subnet_acls(subnet.subnet_id))
545
629
  if subnet_acls != acls:
546
630
  # Reject this subnet because it has different ACLs
547
- logger.debug('Subnet %s is a lot like subnet %s but has ACLs of %s instead of %s; skipping',
548
- subnet.subnet_id, base_subnet_id, subnet_acls, acls)
631
+ logger.debug(
632
+ "Subnet %s is a lot like subnet %s but has ACLs of %s instead of %s; skipping",
633
+ subnet.subnet_id,
634
+ base_subnet_id,
635
+ subnet_acls,
636
+ acls,
637
+ )
549
638
  continue
550
639
 
551
640
  if subnet.availability_zone not in by_az:
@@ -557,24 +646,24 @@ class AWSProvisioner(AbstractProvisioner):
557
646
  return by_az
558
647
 
559
648
  @awsRetry
560
- def _get_subnet_acls(self, subnet: str) -> List[str]:
649
+ def _get_subnet_acls(self, subnet: str) -> list[str]:
561
650
  """
562
651
  Get all Network ACL IDs associated with a given subnet ID.
563
652
  """
564
653
 
565
654
  # Grab the connection we need to use for this operation.
566
- ec2 = self.aws.client(self._region, 'ec2')
655
+ ec2 = self.aws.client(self._region, "ec2")
567
656
 
568
657
  # Compose a filter that selects the default subnet in the AZ
569
- filters = [{
570
- 'Name': 'association.subnet-id',
571
- 'Values': [subnet]
572
- }]
658
+ filters = [{"Name": "association.subnet-id", "Values": [subnet]}]
573
659
 
574
660
  # TODO: Can't we use the resource's network_acls.filter(Filters=)?
575
- return [item['NetworkAclId'] for item in boto3_pager(ec2.describe_network_acls,
576
- 'NetworkAcls',
577
- Filters=filters)]
661
+ return [
662
+ item["NetworkAclId"]
663
+ for item in boto3_pager(
664
+ ec2.describe_network_acls, "NetworkAcls", Filters=filters
665
+ )
666
+ ]
578
667
 
579
668
  @awsRetry
580
669
  def _get_default_subnet(self, zone: str) -> str:
@@ -584,30 +673,32 @@ class AWSProvisioner(AbstractProvisioner):
584
673
  """
585
674
 
586
675
  # Compose a filter that selects the default subnet in the AZ
587
- filters: List[FilterTypeDef] = [{
588
- 'Name': 'default-for-az',
589
- 'Values': ['true']
590
- }, {
591
- 'Name': 'availability-zone',
592
- 'Values': [zone]
593
- }]
594
-
595
- for subnet in self.aws.resource(zone_to_region(zone), 'ec2').subnets.filter(Filters=filters):
676
+ filters: list[FilterTypeDef] = [
677
+ {"Name": "default-for-az", "Values": ["true"]},
678
+ {"Name": "availability-zone", "Values": [zone]},
679
+ ]
680
+
681
+ for subnet in self.aws.resource(zone_to_region(zone), "ec2").subnets.filter(
682
+ Filters=filters
683
+ ):
596
684
  # There should only be one result, so when we see it, return it
597
685
  return subnet.subnet_id
598
686
  # If we don't find a subnet, something is wrong. Maybe this zone was
599
687
  # added after your account?
600
- raise RuntimeError(f"No default subnet found in availability zone {zone}. "
601
- f"Note that Amazon does not add default subnets for new "
602
- f"zones to old accounts. Specify a VPC subnet ID to use, "
603
- f"or create a default subnet in the zone.")
688
+ raise RuntimeError(
689
+ f"No default subnet found in availability zone {zone}. "
690
+ f"Note that Amazon does not add default subnets for new "
691
+ f"zones to old accounts. Specify a VPC subnet ID to use, "
692
+ f"or create a default subnet in the zone."
693
+ )
604
694
 
605
- def getKubernetesAutoscalerSetupCommands(self, values: Dict[str, str]) -> str:
695
+ def getKubernetesAutoscalerSetupCommands(self, values: dict[str, str]) -> str:
606
696
  """
607
697
  Get the Bash commands necessary to configure the Kubernetes Cluster Autoscaler for AWS.
608
698
  """
609
699
 
610
- return textwrap.dedent('''\
700
+ return textwrap.dedent(
701
+ """\
611
702
  curl -sSL https://raw.githubusercontent.com/kubernetes/autoscaler/cluster-autoscaler-{AUTOSCALER_VERSION}/cluster-autoscaler/cloudprovider/aws/examples/cluster-autoscaler-run-on-master.yaml | \\
612
703
  sed "s|--nodes={{{{ node_asg_min }}}}:{{{{ node_asg_max }}}}:{{{{ name }}}}|--node-group-auto-discovery=asg:tag=k8s.io/cluster-autoscaler/enabled,k8s.io/cluster-autoscaler/{CLUSTER_NAME}|" | \\
613
704
  sed 's|kubernetes.io/role: master|node-role.kubernetes.io/master: ""|' | \\
@@ -615,36 +706,41 @@ class AWSProvisioner(AbstractProvisioner):
615
706
  sed '/value: "true"/d' | \\
616
707
  sed 's|path: "/etc/ssl/certs/ca-bundle.crt"|path: "/usr/share/ca-certificates/ca-certificates.crt"|' | \\
617
708
  kubectl apply -f -
618
- ''').format(**values)
709
+ """
710
+ ).format(**values)
619
711
 
620
- def getKubernetesCloudProvider(self) -> Optional[str]:
712
+ def getKubernetesCloudProvider(self) -> str | None:
621
713
  """
622
714
  Use the "aws" Kubernetes cloud provider when setting up Kubernetes.
623
715
  """
624
716
 
625
- return 'aws'
717
+ return "aws"
626
718
 
627
- def getNodeShape(self, instance_type: str, preemptible: bool=False) -> Shape:
719
+ def getNodeShape(self, instance_type: str, preemptible: bool = False) -> Shape:
628
720
  """
629
721
  Get the Shape for the given instance type (e.g. 't2.medium').
630
722
  """
631
723
  type_info = E2Instances[instance_type]
632
724
 
633
- disk = type_info.disks * type_info.disk_capacity * 2 ** 30
725
+ disk = type_info.disks * type_info.disk_capacity * 2**30
634
726
  if disk == 0:
635
727
  # This is an EBS-backed instance. We will use the root
636
728
  # volume, so add the amount of EBS storage requested for
637
729
  # the root volume
638
- disk = self._nodeStorageOverrides.get(instance_type, self._nodeStorage) * 2 ** 30
730
+ disk = (
731
+ self._nodeStorageOverrides.get(instance_type, self._nodeStorage) * 2**30
732
+ )
639
733
 
640
734
  # Underestimate memory by 100M to prevent autoscaler from disagreeing with
641
735
  # mesos about whether a job can run on a particular node type
642
- memory = (type_info.memory - 0.1) * 2 ** 30
643
- return Shape(wallTime=60 * 60,
644
- memory=int(memory),
645
- cores=type_info.cores,
646
- disk=int(disk),
647
- preemptible=preemptible)
736
+ memory = (type_info.memory - 0.1) * 2**30
737
+ return Shape(
738
+ wallTime=60 * 60,
739
+ memory=int(memory),
740
+ cores=type_info.cores,
741
+ disk=int(disk),
742
+ preemptible=preemptible,
743
+ )
648
744
 
649
745
  @staticmethod
650
746
  def retryPredicate(e: Exception) -> bool:
@@ -659,14 +755,14 @@ class AWSProvisioner(AbstractProvisioner):
659
755
  try:
660
756
  leader = self._getLeaderInstanceBoto3()
661
757
  vpcId = leader.get("VpcId")
662
- logger.info('Terminating the leader first ...')
758
+ logger.info("Terminating the leader first ...")
663
759
  self._terminateInstances([leader])
664
760
  except (NoSuchClusterException, InvalidClusterStateException):
665
761
  # It's ok if the leader is not found. We'll terminate any remaining
666
762
  # instances below anyway.
667
763
  pass
668
764
 
669
- logger.debug('Deleting autoscaling groups ...')
765
+ logger.debug("Deleting autoscaling groups ...")
670
766
  removed = False
671
767
 
672
768
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
@@ -674,20 +770,28 @@ class AWSProvisioner(AbstractProvisioner):
674
770
  for asgName in self._getAutoScalingGroupNames():
675
771
  try:
676
772
  # We delete the group and all the instances via ForceDelete.
677
- self.aws.client(self._region, 'autoscaling').delete_auto_scaling_group(AutoScalingGroupName=asgName, ForceDelete=True)
773
+ self.aws.client(
774
+ self._region, "autoscaling"
775
+ ).delete_auto_scaling_group(
776
+ AutoScalingGroupName=asgName, ForceDelete=True
777
+ )
678
778
  removed = True
679
779
  except ClientError as e:
680
- if get_error_code(e) == 'ValidationError' and 'AutoScalingGroup name not found' in get_error_message(e):
780
+ if get_error_code(
781
+ e
782
+ ) == "ValidationError" and "AutoScalingGroup name not found" in get_error_message(
783
+ e
784
+ ):
681
785
  # This ASG does not need to be removed (or a
682
786
  # previous delete returned an error but also
683
787
  # succeeded).
684
788
  pass
685
789
 
686
790
  if removed:
687
- logger.debug('... Successfully deleted autoscaling groups')
791
+ logger.debug("... Successfully deleted autoscaling groups")
688
792
 
689
793
  # Do the workers after the ASGs because some may belong to ASGs
690
- logger.info('Terminating any remaining workers ...')
794
+ logger.info("Terminating any remaining workers ...")
691
795
  removed = False
692
796
  instances = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True)
693
797
  spotIDs = self._getSpotRequestIDs()
@@ -696,15 +800,17 @@ class AWSProvisioner(AbstractProvisioner):
696
800
  boto3_ec2.cancel_spot_instance_requests(SpotInstanceRequestIds=spotIDs)
697
801
  # self.aws.boto2(self._region, 'ec2').cancel_spot_instance_requests(request_ids=spotIDs)
698
802
  removed = True
699
- instancesToTerminate = awsFilterImpairedNodes(instances, self.aws.client(self._region, 'ec2'))
803
+ instancesToTerminate = awsFilterImpairedNodes(
804
+ instances, self.aws.client(self._region, "ec2")
805
+ )
700
806
  if instancesToTerminate:
701
807
  vpcId = vpcId or instancesToTerminate[0].get("VpcId")
702
808
  self._terminateInstances(instancesToTerminate)
703
809
  removed = True
704
810
  if removed:
705
- logger.debug('... Successfully terminated workers')
811
+ logger.debug("... Successfully terminated workers")
706
812
 
707
- logger.info('Deleting launch templates ...')
813
+ logger.info("Deleting launch templates ...")
708
814
  removed = False
709
815
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
710
816
  with attempt:
@@ -713,7 +819,7 @@ class AWSProvisioner(AbstractProvisioner):
713
819
  mistake = False
714
820
  for ltID in self._get_launch_template_ids():
715
821
  response = boto3_ec2.delete_launch_template(LaunchTemplateId=ltID)
716
- if 'LaunchTemplate' not in response:
822
+ if "LaunchTemplate" not in response:
717
823
  mistake = True
718
824
  else:
719
825
  removed = True
@@ -721,49 +827,59 @@ class AWSProvisioner(AbstractProvisioner):
721
827
  # We missed something
722
828
  removed = False
723
829
  if removed:
724
- logger.debug('... Successfully deleted launch templates')
830
+ logger.debug("... Successfully deleted launch templates")
725
831
 
726
832
  if len(instances) == len(instancesToTerminate):
727
833
  # All nodes are gone now.
728
834
 
729
- logger.info('Deleting IAM roles ...')
835
+ logger.info("Deleting IAM roles ...")
730
836
  self._deleteRoles(self._getRoleNames())
731
837
  self._deleteInstanceProfiles(self._getInstanceProfileNames())
732
838
 
733
- logger.info('Deleting security group ...')
839
+ logger.info("Deleting security group ...")
734
840
  removed = False
735
841
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
736
842
  with attempt:
737
- security_groups: List[SecurityGroupTypeDef] = boto3_ec2.describe_security_groups()["SecurityGroups"]
843
+ security_groups: list[SecurityGroupTypeDef] = (
844
+ boto3_ec2.describe_security_groups()["SecurityGroups"]
845
+ )
738
846
  for security_group in security_groups:
739
847
  # TODO: If we terminate the leader and the workers but
740
848
  # miss the security group, we won't find it now because
741
849
  # we won't have vpcId set.
742
- if security_group.get("GroupName") == self.clusterName and vpcId and security_group.get("VpcId") == vpcId:
850
+ if (
851
+ security_group.get("GroupName") == self.clusterName
852
+ and vpcId
853
+ and security_group.get("VpcId") == vpcId
854
+ ):
743
855
  try:
744
- boto3_ec2.delete_security_group(GroupId=security_group["GroupId"])
856
+ boto3_ec2.delete_security_group(
857
+ GroupId=security_group["GroupId"]
858
+ )
745
859
  removed = True
746
860
  except ClientError as e:
747
- if get_error_code(e) == 'InvalidGroup.NotFound':
861
+ if get_error_code(e) == "InvalidGroup.NotFound":
748
862
  pass
749
863
  else:
750
864
  raise
751
865
  if removed:
752
- logger.debug('... Successfully deleted security group')
866
+ logger.debug("... Successfully deleted security group")
753
867
  else:
754
868
  assert len(instances) > len(instancesToTerminate)
755
869
  # the security group can't be deleted until all nodes are terminated
756
- logger.warning('The TOIL_AWS_NODE_DEBUG environment variable is set and some nodes '
757
- 'have failed health checks. As a result, the security group & IAM '
758
- 'roles will not be deleted.')
870
+ logger.warning(
871
+ "The TOIL_AWS_NODE_DEBUG environment variable is set and some nodes "
872
+ "have failed health checks. As a result, the security group & IAM "
873
+ "roles will not be deleted."
874
+ )
759
875
 
760
876
  # delete S3 buckets that might have been created by `self._write_file_to_cloud()`
761
- logger.info('Deleting S3 buckets ...')
877
+ logger.info("Deleting S3 buckets ...")
762
878
  removed = False
763
879
  for attempt in old_retry(timeout=300, predicate=awsRetryPredicate):
764
880
  with attempt:
765
881
  # Grab the S3 resource to use
766
- s3 = self.aws.resource(self._region, 's3')
882
+ s3 = self.aws.resource(self._region, "s3")
767
883
  try:
768
884
  bucket = s3.Bucket(self.s3_bucket_name)
769
885
 
@@ -779,13 +895,15 @@ class AWSProvisioner(AbstractProvisioner):
779
895
  else:
780
896
  raise # retry this
781
897
  if removed:
782
- print('... Successfully deleted S3 buckets')
898
+ print("... Successfully deleted S3 buckets")
783
899
 
784
- def terminateNodes(self, nodes: List[Node]) -> None:
900
+ def terminateNodes(self, nodes: list[Node]) -> None:
785
901
  if nodes:
786
902
  self._terminateIDs([x.name for x in nodes])
787
903
 
788
- def _recover_node_type_bid(self, node_type: Set[str], spot_bid: Optional[float]) -> Optional[float]:
904
+ def _recover_node_type_bid(
905
+ self, node_type: set[str], spot_bid: float | None
906
+ ) -> float | None:
789
907
  """
790
908
  The old Toil-managed autoscaler will tell us to make some nodes of
791
909
  particular instance types, and to just work out a bid, but it doesn't
@@ -812,15 +930,23 @@ class AWSProvisioner(AbstractProvisioner):
812
930
  break
813
931
  if spot_bid is None:
814
932
  # We didn't bid on any class including this type either
815
- raise RuntimeError("No spot bid given for a preemptible node request.")
933
+ raise RuntimeError(
934
+ "No spot bid given for a preemptible node request."
935
+ )
816
936
  else:
817
937
  raise RuntimeError("No spot bid given for a preemptible node request.")
818
938
 
819
939
  return spot_bid
820
940
 
821
- def addNodes(self, nodeTypes: Set[str], numNodes: int, preemptible: bool, spotBid: Optional[float]=None) -> int:
941
+ def addNodes(
942
+ self,
943
+ nodeTypes: set[str],
944
+ numNodes: int,
945
+ preemptible: bool,
946
+ spotBid: float | None = None,
947
+ ) -> int:
822
948
  # Grab the AWS connection we need
823
- boto3_ec2 = get_client(service_name='ec2', region_name=self._region)
949
+ boto3_ec2 = get_client(service_name="ec2", region_name=self._region)
824
950
  assert self._leaderPrivateIP
825
951
 
826
952
  if preemptible:
@@ -832,8 +958,7 @@ class AWSProvisioner(AbstractProvisioner):
832
958
  node_type = next(iter(nodeTypes))
833
959
  type_info = E2Instances[node_type]
834
960
  root_vol_size = self._nodeStorageOverrides.get(node_type, self._nodeStorage)
835
- bdm = self._getBoto3BlockDeviceMapping(type_info,
836
- rootVolSize=root_vol_size)
961
+ bdm = self._getBoto3BlockDeviceMapping(type_info, rootVolSize=root_vol_size)
837
962
 
838
963
  # Pick a zone and subnet_id to launch into
839
964
  if preemptible:
@@ -855,7 +980,9 @@ class AWSProvisioner(AbstractProvisioner):
855
980
  # Pick an arbitrary zone we can use.
856
981
  zone = next(iter(self._worker_subnets_by_zone.keys()))
857
982
  if zone is None:
858
- logger.exception("Could not find a valid zone. Make sure TOIL_AWS_ZONE is set or spot bids are not too low.")
983
+ logger.exception(
984
+ "Could not find a valid zone. Make sure TOIL_AWS_ZONE is set or spot bids are not too low."
985
+ )
859
986
  raise NoSuchZoneException()
860
987
  if self._leader_subnet in self._worker_subnets_by_zone.get(zone, []):
861
988
  # The leader's subnet is an option for this zone, so use it.
@@ -865,40 +992,38 @@ class AWSProvisioner(AbstractProvisioner):
865
992
  subnet_id = next(iter(self._worker_subnets_by_zone[zone]))
866
993
 
867
994
  keyPath = self._sseKey if self._sseKey else None
868
- userData: str = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture)
995
+ userData: str = self._getIgnitionUserData(
996
+ "worker", keyPath, preemptible, self._architecture
997
+ )
869
998
  userDataBytes: bytes = b""
870
999
  if isinstance(userData, str):
871
1000
  # Spot-market provisioning requires bytes for user data.
872
- userDataBytes = userData.encode('utf-8')
873
-
874
- spot_kwargs = {'KeyName': self._keyName,
875
- 'LaunchSpecification': {
876
- 'SecurityGroupIds': self._getSecurityGroupIDs(),
877
- 'InstanceType': type_info.name,
878
- 'UserData': userDataBytes,
879
- 'BlockDeviceMappings': bdm,
880
- 'IamInstanceProfile': {
881
- 'Arn': self._leaderProfileArn
882
- },
883
- 'Placement': {
884
- 'AvailabilityZone': zone
885
- },
886
- 'SubnetId': subnet_id}
887
- }
888
- on_demand_kwargs = {'KeyName': self._keyName,
889
- 'SecurityGroupIds': self._getSecurityGroupIDs(),
890
- 'InstanceType': type_info.name,
891
- 'UserData': userDataBytes,
892
- 'BlockDeviceMappings': bdm,
893
- 'IamInstanceProfile': {
894
- 'Arn': self._leaderProfileArn
895
- },
896
- 'Placement': {
897
- 'AvailabilityZone': zone
898
- },
899
- 'SubnetId': subnet_id}
900
-
901
- instancesLaunched: List[InstanceTypeDef] = []
1001
+ userDataBytes = userData.encode("utf-8")
1002
+
1003
+ spot_kwargs = {
1004
+ "KeyName": self._keyName,
1005
+ "LaunchSpecification": {
1006
+ "SecurityGroupIds": self._getSecurityGroupIDs(),
1007
+ "InstanceType": type_info.name,
1008
+ "UserData": userDataBytes,
1009
+ "BlockDeviceMappings": bdm,
1010
+ "IamInstanceProfile": {"Arn": self._leaderProfileArn},
1011
+ "Placement": {"AvailabilityZone": zone},
1012
+ "SubnetId": subnet_id,
1013
+ },
1014
+ }
1015
+ on_demand_kwargs = {
1016
+ "KeyName": self._keyName,
1017
+ "SecurityGroupIds": self._getSecurityGroupIDs(),
1018
+ "InstanceType": type_info.name,
1019
+ "UserData": userDataBytes,
1020
+ "BlockDeviceMappings": bdm,
1021
+ "IamInstanceProfile": {"Arn": self._leaderProfileArn},
1022
+ "Placement": {"AvailabilityZone": zone},
1023
+ "SubnetId": subnet_id,
1024
+ }
1025
+
1026
+ instancesLaunched: list[InstanceTypeDef] = []
902
1027
 
903
1028
  for attempt in old_retry(predicate=awsRetryPredicate):
904
1029
  with attempt:
@@ -906,49 +1031,85 @@ class AWSProvisioner(AbstractProvisioner):
906
1031
  # the biggest obstacle is AWS request throttling, so we retry on these errors at
907
1032
  # every request in this method
908
1033
  if not preemptible:
909
- logger.debug('Launching %s non-preemptible nodes', numNodes)
910
- instancesLaunched = create_ondemand_instances(boto3_ec2=boto3_ec2,
911
- image_id=self._discoverAMI(),
912
- spec=on_demand_kwargs, num_instances=numNodes)
1034
+ logger.debug("Launching %s non-preemptible nodes", numNodes)
1035
+ instancesLaunched = create_ondemand_instances(
1036
+ boto3_ec2=boto3_ec2,
1037
+ image_id=self._discoverAMI(),
1038
+ spec=on_demand_kwargs,
1039
+ num_instances=numNodes,
1040
+ )
913
1041
  else:
914
- logger.debug('Launching %s preemptible nodes', numNodes)
1042
+ logger.debug("Launching %s preemptible nodes", numNodes)
915
1043
  # force generator to evaluate
916
- generatedInstancesLaunched: List[DescribeInstancesResultTypeDef] = list(create_spot_instances(boto3_ec2=boto3_ec2,
917
- price=spotBid,
918
- image_id=self._discoverAMI(),
919
- tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName},
920
- spec=spot_kwargs,
921
- num_instances=numNodes,
922
- tentative=True)
923
- )
1044
+ generatedInstancesLaunched: list[DescribeInstancesResultTypeDef] = (
1045
+ list(
1046
+ create_spot_instances(
1047
+ boto3_ec2=boto3_ec2,
1048
+ price=spotBid,
1049
+ image_id=self._discoverAMI(),
1050
+ tags={_TAG_KEY_TOIL_CLUSTER_NAME: self.clusterName},
1051
+ spec=spot_kwargs,
1052
+ num_instances=numNodes,
1053
+ tentative=True,
1054
+ )
1055
+ )
1056
+ )
924
1057
  # flatten the list
925
- flatten_reservations: List[ReservationTypeDef] = [reservation for subdict in generatedInstancesLaunched for reservation in subdict["Reservations"] for key, value in subdict.items()]
1058
+ flatten_reservations: list[ReservationTypeDef] = [
1059
+ reservation
1060
+ for subdict in generatedInstancesLaunched
1061
+ for reservation in subdict["Reservations"]
1062
+ for key, value in subdict.items()
1063
+ ]
926
1064
  # get a flattened list of all requested instances, as before instancesLaunched is a dict of reservations which is a dict of instance requests
927
- instancesLaunched = [instance for instances in flatten_reservations for instance in instances['Instances']]
1065
+ instancesLaunched = [
1066
+ instance
1067
+ for instances in flatten_reservations
1068
+ for instance in instances["Instances"]
1069
+ ]
928
1070
 
929
1071
  for attempt in old_retry(predicate=awsRetryPredicate):
930
1072
  with attempt:
931
- list(wait_instances_running(boto3_ec2, instancesLaunched)) # ensure all instances are running
1073
+ list(
1074
+ wait_instances_running(boto3_ec2, instancesLaunched)
1075
+ ) # ensure all instances are running
932
1076
 
933
1077
  increase_instance_hop_limit(boto3_ec2, instancesLaunched)
934
1078
 
935
- self._tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
1079
+ self._tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
936
1080
  AWSProvisioner._addTags(boto3_ec2, instancesLaunched, self._tags)
937
1081
  if self._sseKey:
938
1082
  for i in instancesLaunched:
939
1083
  self._waitForIP(i)
940
- node = Node(publicIP=i['PublicIpAddress'], privateIP=i['PrivateIpAddress'], name=i['InstanceId'],
941
- launchTime=i['LaunchTime'], nodeType=i['InstanceType'], preemptible=preemptible,
942
- tags=collapse_tags(i['Tags']))
943
- node.waitForNode('toil_worker')
944
- node.coreRsync([self._sseKey, ':' + self._sseKey], applianceName='toil_worker')
945
- logger.debug('Launched %s new instance(s)', numNodes)
1084
+ node = Node(
1085
+ publicIP=i["PublicIpAddress"],
1086
+ privateIP=i["PrivateIpAddress"],
1087
+ name=i["InstanceId"],
1088
+ launchTime=i["LaunchTime"],
1089
+ nodeType=i["InstanceType"],
1090
+ preemptible=preemptible,
1091
+ tags=collapse_tags(i["Tags"]),
1092
+ )
1093
+ node.waitForNode("toil_worker")
1094
+ node.coreRsync(
1095
+ [self._sseKey, ":" + self._sseKey], applianceName="toil_worker"
1096
+ )
1097
+ logger.debug("Launched %s new instance(s)", numNodes)
946
1098
  return len(instancesLaunched)
947
1099
 
948
- def addManagedNodes(self, nodeTypes: Set[str], minNodes: int, maxNodes: int, preemptible: bool, spotBid: Optional[float] = None) -> None:
949
-
950
- if self.clusterType != 'kubernetes':
951
- raise ManagedNodesNotSupportedException("Managed nodes only supported for Kubernetes clusters")
1100
+ def addManagedNodes(
1101
+ self,
1102
+ nodeTypes: set[str],
1103
+ minNodes: int,
1104
+ maxNodes: int,
1105
+ preemptible: bool,
1106
+ spotBid: float | None = None,
1107
+ ) -> None:
1108
+
1109
+ if self.clusterType != "kubernetes":
1110
+ raise ManagedNodesNotSupportedException(
1111
+ "Managed nodes only supported for Kubernetes clusters"
1112
+ )
952
1113
 
953
1114
  assert self._leaderPrivateIP
954
1115
 
@@ -960,25 +1121,51 @@ class AWSProvisioner(AbstractProvisioner):
960
1121
 
961
1122
  # Make one template per node type, so we can apply storage overrides correctly
962
1123
  # TODO: deduplicate these if the same instance type appears in multiple sets?
963
- launch_template_ids = {n: self._get_worker_launch_template(n, preemptible=preemptible) for n in nodeTypes}
1124
+ launch_template_ids = {
1125
+ n: self._get_worker_launch_template(n, preemptible=preemptible)
1126
+ for n in nodeTypes
1127
+ }
964
1128
  # Make the ASG across all of them
965
- self._createWorkerAutoScalingGroup(launch_template_ids, nodeTypes, minNodes, maxNodes,
966
- spot_bid=spotBid)
1129
+ self._createWorkerAutoScalingGroup(
1130
+ launch_template_ids, nodeTypes, minNodes, maxNodes, spot_bid=spotBid
1131
+ )
967
1132
 
968
- def getProvisionedWorkers(self, instance_type: Optional[str] = None, preemptible: Optional[bool] = None) -> List[Node]:
1133
+ def getProvisionedWorkers(
1134
+ self, instance_type: str | None = None, preemptible: bool | None = None
1135
+ ) -> list[Node]:
969
1136
  assert self._leaderPrivateIP
970
1137
  entireCluster = self._get_nodes_in_cluster_boto3(instance_type=instance_type)
971
- logger.debug('All nodes in cluster: %s', entireCluster)
972
- workerInstances: List[InstanceTypeDef] = [i for i in entireCluster if i["PrivateIpAddress"] != self._leaderPrivateIP]
973
- logger.debug('All workers found in cluster: %s', workerInstances)
1138
+ logger.debug("All nodes in cluster: %s", entireCluster)
1139
+ workerInstances: list[InstanceTypeDef] = [
1140
+ i for i in entireCluster if i["PrivateIpAddress"] != self._leaderPrivateIP
1141
+ ]
1142
+ logger.debug("All workers found in cluster: %s", workerInstances)
974
1143
  if preemptible is not None:
975
- workerInstances = [i for i in workerInstances if preemptible == (i["SpotInstanceRequestId"] is not None)]
976
- logger.debug('%spreemptible workers found in cluster: %s', 'non-' if not preemptible else '', workerInstances)
977
- workerInstances = awsFilterImpairedNodes(workerInstances, self.aws.client(self._region, 'ec2'))
978
- return [Node(publicIP=i["PublicIpAddress"], privateIP=i["PrivateIpAddress"],
979
- name=i["InstanceId"], launchTime=i["LaunchTime"], nodeType=i["InstanceType"],
980
- preemptible=i["SpotInstanceRequestId"] is not None, tags=collapse_tags(i["Tags"]))
981
- for i in workerInstances]
1144
+ workerInstances = [
1145
+ i
1146
+ for i in workerInstances
1147
+ if preemptible == (i["SpotInstanceRequestId"] is not None)
1148
+ ]
1149
+ logger.debug(
1150
+ "%spreemptible workers found in cluster: %s",
1151
+ "non-" if not preemptible else "",
1152
+ workerInstances,
1153
+ )
1154
+ workerInstances = awsFilterImpairedNodes(
1155
+ workerInstances, self.aws.client(self._region, "ec2")
1156
+ )
1157
+ return [
1158
+ Node(
1159
+ publicIP=i["PublicIpAddress"],
1160
+ privateIP=i["PrivateIpAddress"],
1161
+ name=i["InstanceId"],
1162
+ launchTime=i["LaunchTime"],
1163
+ nodeType=i["InstanceType"],
1164
+ preemptible=i["SpotInstanceRequestId"] is not None,
1165
+ tags=collapse_tags(i["Tags"]),
1166
+ )
1167
+ for i in workerInstances
1168
+ ]
982
1169
 
983
1170
  @memoize
984
1171
  def _discoverAMI(self) -> str:
@@ -986,17 +1173,19 @@ class AWSProvisioner(AbstractProvisioner):
986
1173
  :return: The AMI ID (a string like 'ami-0a9a5d2b65cce04eb') for Flatcar.
987
1174
  :rtype: str
988
1175
  """
989
- return get_flatcar_ami(self.aws.client(self._region, 'ec2'), self._architecture)
1176
+ return get_flatcar_ami(self.aws.client(self._region, "ec2"), self._architecture)
990
1177
 
991
1178
  def _toNameSpace(self) -> str:
992
1179
  assert isinstance(self.clusterName, (str, bytes))
993
- if any(char.isupper() for char in self.clusterName) or '_' in self.clusterName:
994
- raise RuntimeError("The cluster name must be lowercase and cannot contain the '_' "
995
- "character.")
1180
+ if any(char.isupper() for char in self.clusterName) or "_" in self.clusterName:
1181
+ raise RuntimeError(
1182
+ "The cluster name must be lowercase and cannot contain the '_' "
1183
+ "character."
1184
+ )
996
1185
  namespace = self.clusterName
997
- if not namespace.startswith('/'):
998
- namespace = '/' + namespace + '/'
999
- return namespace.replace('-', '/')
1186
+ if not namespace.startswith("/"):
1187
+ namespace = "/" + namespace + "/"
1188
+ return namespace.replace("-", "/")
1000
1189
 
1001
1190
  def _namespace_name(self, name: str) -> str:
1002
1191
  """
@@ -1007,7 +1196,7 @@ class AWSProvisioner(AbstractProvisioner):
1007
1196
  # This logic is a bit weird, but it's what Boto2Context used to use.
1008
1197
  # Drop the leading / from the absolute-path-style "namespace" name and
1009
1198
  # then encode underscores and slashes.
1010
- return (self._toNameSpace() + name)[1:].replace('_', '__').replace('/', '_')
1199
+ return (self._toNameSpace() + name)[1:].replace("_", "__").replace("/", "_")
1011
1200
 
1012
1201
  def _is_our_namespaced_name(self, namespaced_name: str) -> bool:
1013
1202
  """
@@ -1015,30 +1204,37 @@ class AWSProvisioner(AbstractProvisioner):
1015
1204
  and was generated by _namespace_name().
1016
1205
  """
1017
1206
 
1018
- denamespaced = '/' + '_'.join(s.replace('_', '/') for s in namespaced_name.split('__'))
1207
+ denamespaced = "/" + "_".join(
1208
+ s.replace("_", "/") for s in namespaced_name.split("__")
1209
+ )
1019
1210
  return denamespaced.startswith(self._toNameSpace())
1020
1211
 
1021
1212
  def _getLeaderInstanceBoto3(self) -> InstanceTypeDef:
1022
1213
  """
1023
1214
  Get the Boto 3 instance for the cluster's leader.
1024
- :return: InstanceTypeDef
1025
1215
  """
1026
1216
  # Tags are stored differently in Boto 3
1027
- instances: List[InstanceTypeDef] = self._get_nodes_in_cluster_boto3(include_stopped_nodes=True)
1217
+ instances: list[InstanceTypeDef] = self._get_nodes_in_cluster_boto3(
1218
+ include_stopped_nodes=True
1219
+ )
1028
1220
  instances.sort(key=lambda x: x["LaunchTime"])
1029
1221
  try:
1030
1222
  leader = instances[0] # assume leader was launched first
1031
1223
  except IndexError:
1032
1224
  raise NoSuchClusterException(self.clusterName)
1033
1225
  if leader.get("Tags") is not None:
1034
- tag_value = next(item["Value"] for item in leader["Tags"] if item["Key"] == _TAG_KEY_TOIL_NODE_TYPE)
1226
+ tag_value = next(
1227
+ item["Value"]
1228
+ for item in leader["Tags"]
1229
+ if item["Key"] == _TAG_KEY_TOIL_NODE_TYPE
1230
+ )
1035
1231
  else:
1036
1232
  tag_value = None
1037
- if (tag_value or 'leader') != 'leader':
1233
+ if (tag_value or "leader") != "leader":
1038
1234
  raise InvalidClusterStateException(
1039
- 'Invalid cluster state! The first launched instance appears not to be the leader '
1235
+ "Invalid cluster state! The first launched instance appears not to be the leader "
1040
1236
  'as it is missing the "leader" tag. The safest recovery is to destroy the cluster '
1041
- 'and restart the job. Incorrect Leader ID: %s' % leader["InstanceId"]
1237
+ "and restart the job. Incorrect Leader ID: %s" % leader["InstanceId"]
1042
1238
  )
1043
1239
  return leader
1044
1240
 
@@ -1052,16 +1248,16 @@ class AWSProvisioner(AbstractProvisioner):
1052
1248
  leader: InstanceTypeDef = instances[0] # assume leader was launched first
1053
1249
  except IndexError:
1054
1250
  raise NoSuchClusterException(self.clusterName)
1055
- tagged_node_type: str = 'leader'
1251
+ tagged_node_type: str = "leader"
1056
1252
  for tag in leader["Tags"]:
1057
1253
  # If a tag specifying node type exists,
1058
1254
  if tag.get("Key") is not None and tag["Key"] == _TAG_KEY_TOIL_NODE_TYPE:
1059
1255
  tagged_node_type = tag["Value"]
1060
- if tagged_node_type != 'leader':
1256
+ if tagged_node_type != "leader":
1061
1257
  raise InvalidClusterStateException(
1062
- 'Invalid cluster state! The first launched instance appears not to be the leader '
1258
+ "Invalid cluster state! The first launched instance appears not to be the leader "
1063
1259
  'as it is missing the "leader" tag. The safest recovery is to destroy the cluster '
1064
- 'and restart the job. Incorrect Leader ID: %s' % leader["InstanceId"]
1260
+ "and restart the job. Incorrect Leader ID: %s" % leader["InstanceId"]
1065
1261
  )
1066
1262
  return leader
1067
1263
 
@@ -1071,28 +1267,41 @@ class AWSProvisioner(AbstractProvisioner):
1071
1267
  """
1072
1268
  leader: InstanceTypeDef = self._getLeaderInstanceBoto3()
1073
1269
 
1074
- leaderNode = Node(publicIP=leader["PublicIpAddress"], privateIP=leader["PrivateIpAddress"],
1075
- name=leader["InstanceId"], launchTime=leader["LaunchTime"], nodeType=None,
1076
- preemptible=False, tags=collapse_tags(leader["Tags"]))
1270
+ leaderNode = Node(
1271
+ publicIP=leader["PublicIpAddress"],
1272
+ privateIP=leader["PrivateIpAddress"],
1273
+ name=leader["InstanceId"],
1274
+ launchTime=leader["LaunchTime"],
1275
+ nodeType=None,
1276
+ preemptible=False,
1277
+ tags=collapse_tags(leader["Tags"]),
1278
+ )
1077
1279
  if wait:
1078
1280
  logger.debug("Waiting for toil_leader to enter 'running' state...")
1079
- wait_instances_running(self.aws.client(self._region, 'ec2'), [leader])
1080
- logger.debug('... toil_leader is running')
1281
+ wait_instances_running(self.aws.client(self._region, "ec2"), [leader])
1282
+ logger.debug("... toil_leader is running")
1081
1283
  self._waitForIP(leader)
1082
- leaderNode.waitForNode('toil_leader')
1284
+ leaderNode.waitForNode("toil_leader")
1083
1285
 
1084
1286
  return leaderNode
1085
1287
 
1086
1288
  @classmethod
1087
1289
  @awsRetry
1088
- def _addTag(cls, boto3_ec2: EC2Client, instance: InstanceTypeDef, key: str, value: str) -> None:
1089
- if instance.get('Tags') is None:
1090
- instance['Tags'] = []
1290
+ def _addTag(
1291
+ cls, boto3_ec2: EC2Client, instance: InstanceTypeDef, key: str, value: str
1292
+ ) -> None:
1293
+ if instance.get("Tags") is None:
1294
+ instance["Tags"] = []
1091
1295
  new_tag: TagTypeDef = {"Key": key, "Value": value}
1092
1296
  boto3_ec2.create_tags(Resources=[instance["InstanceId"]], Tags=[new_tag])
1093
1297
 
1094
1298
  @classmethod
1095
- def _addTags(cls, boto3_ec2: EC2Client, instances: List[InstanceTypeDef], tags: Dict[str, str]) -> None:
1299
+ def _addTags(
1300
+ cls,
1301
+ boto3_ec2: EC2Client,
1302
+ instances: list[InstanceTypeDef],
1303
+ tags: dict[str, str],
1304
+ ) -> None:
1096
1305
  for instance in instances:
1097
1306
  for key, value in tags.items():
1098
1307
  cls._addTag(boto3_ec2, instance, key, value)
@@ -1104,30 +1313,39 @@ class AWSProvisioner(AbstractProvisioner):
1104
1313
 
1105
1314
  :type instance: boto.ec2.instance.Instance
1106
1315
  """
1107
- logger.debug('Waiting for ip...')
1316
+ logger.debug("Waiting for ip...")
1108
1317
  while True:
1109
1318
  time.sleep(a_short_time)
1110
- if instance.get("PublicIpAddress") or instance.get("PublicDnsName") or instance.get("PrivateIpAddress"):
1111
- logger.debug('...got ip')
1319
+ if (
1320
+ instance.get("PublicIpAddress")
1321
+ or instance.get("PublicDnsName")
1322
+ or instance.get("PrivateIpAddress")
1323
+ ):
1324
+ logger.debug("...got ip")
1112
1325
  break
1113
1326
 
1114
- def _terminateInstances(self, instances: List[InstanceTypeDef]) -> None:
1327
+ def _terminateInstances(self, instances: list[InstanceTypeDef]) -> None:
1115
1328
  instanceIDs = [x["InstanceId"] for x in instances]
1116
1329
  self._terminateIDs(instanceIDs)
1117
- logger.info('... Waiting for instance(s) to shut down...')
1330
+ logger.info("... Waiting for instance(s) to shut down...")
1118
1331
  for instance in instances:
1119
- wait_transition(self.aws.client(region=self._region, service_name="ec2"), instance, {'pending', 'running', 'shutting-down', 'stopping', 'stopped'}, 'terminated')
1120
- logger.info('Instance(s) terminated.')
1332
+ wait_transition(
1333
+ self.aws.client(region=self._region, service_name="ec2"),
1334
+ instance,
1335
+ {"pending", "running", "shutting-down", "stopping", "stopped"},
1336
+ "terminated",
1337
+ )
1338
+ logger.info("Instance(s) terminated.")
1121
1339
 
1122
1340
  @awsRetry
1123
- def _terminateIDs(self, instanceIDs: List[str]) -> None:
1124
- logger.info('Terminating instance(s): %s', instanceIDs)
1341
+ def _terminateIDs(self, instanceIDs: list[str]) -> None:
1342
+ logger.info("Terminating instance(s): %s", instanceIDs)
1125
1343
  boto3_ec2 = self.aws.client(region=self._region, service_name="ec2")
1126
1344
  boto3_ec2.terminate_instances(InstanceIds=instanceIDs)
1127
- logger.info('Instance(s) terminated.')
1345
+ logger.info("Instance(s) terminated.")
1128
1346
 
1129
1347
  @awsRetry
1130
- def _deleteRoles(self, names: List[str]) -> None:
1348
+ def _deleteRoles(self, names: list[str]) -> None:
1131
1349
  """
1132
1350
  Delete all the given named IAM roles.
1133
1351
  Detatches but does not delete associated instance profiles.
@@ -1140,22 +1358,24 @@ class AWSProvisioner(AbstractProvisioner):
1140
1358
 
1141
1359
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1142
1360
  with attempt:
1143
- boto3_iam.remove_role_from_instance_profile(InstanceProfileName=profile_name,
1144
- RoleName=role_name)
1361
+ boto3_iam.remove_role_from_instance_profile(
1362
+ InstanceProfileName=profile_name, RoleName=role_name
1363
+ )
1145
1364
  # We also need to drop all inline policies
1146
1365
  for policy_name in self._getRoleInlinePolicyNames(role_name):
1147
1366
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1148
1367
  with attempt:
1149
- boto3_iam.delete_role_policy(PolicyName=policy_name,
1150
- RoleName=role_name)
1368
+ boto3_iam.delete_role_policy(
1369
+ PolicyName=policy_name, RoleName=role_name
1370
+ )
1151
1371
 
1152
1372
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1153
1373
  with attempt:
1154
1374
  boto3_iam.delete_role(RoleName=role_name)
1155
- logger.debug('... Successfully deleted IAM role %s', role_name)
1375
+ logger.debug("... Successfully deleted IAM role %s", role_name)
1156
1376
 
1157
1377
  @awsRetry
1158
- def _deleteInstanceProfiles(self, names: List[str]) -> None:
1378
+ def _deleteInstanceProfiles(self, names: list[str]) -> None:
1159
1379
  """
1160
1380
  Delete all the given named IAM instance profiles.
1161
1381
  All roles must already be detached.
@@ -1165,16 +1385,22 @@ class AWSProvisioner(AbstractProvisioner):
1165
1385
  for attempt in old_retry(timeout=300, predicate=expectedShutdownErrors):
1166
1386
  with attempt:
1167
1387
  boto3_iam.delete_instance_profile(InstanceProfileName=profile_name)
1168
- logger.debug('... Succesfully deleted instance profile %s', profile_name)
1388
+ logger.debug(
1389
+ "... Succesfully deleted instance profile %s", profile_name
1390
+ )
1169
1391
 
1170
1392
  @classmethod
1171
- def _getBoto3BlockDeviceMapping(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[BlockDeviceMappingTypeDef]:
1393
+ def _getBoto3BlockDeviceMapping(
1394
+ cls, type_info: InstanceType, rootVolSize: int = 50
1395
+ ) -> list[BlockDeviceMappingTypeDef]:
1172
1396
  # determine number of ephemeral drives via cgcloud-lib (actually this is moved into toil's lib
1173
- bdtKeys = [''] + [f'/dev/xvd{c}' for c in string.ascii_lowercase[1:]]
1174
- bdm_list: List[BlockDeviceMappingTypeDef] = []
1397
+ bdtKeys = [""] + [f"/dev/xvd{c}" for c in string.ascii_lowercase[1:]]
1398
+ bdm_list: list[BlockDeviceMappingTypeDef] = []
1175
1399
  # Change root volume size to allow for bigger Docker instances
1176
- root_vol: EbsBlockDeviceTypeDef = {"DeleteOnTermination": True,
1177
- "VolumeSize": rootVolSize}
1400
+ root_vol: EbsBlockDeviceTypeDef = {
1401
+ "DeleteOnTermination": True,
1402
+ "VolumeSize": rootVolSize,
1403
+ }
1178
1404
  bdm: BlockDeviceMappingTypeDef = {"DeviceName": "/dev/xvda", "Ebs": root_vol}
1179
1405
  bdm_list.append(bdm)
1180
1406
  # The first disk is already attached for us so start with 2nd.
@@ -1187,85 +1413,118 @@ class AWSProvisioner(AbstractProvisioner):
1187
1413
  # bdm["Ebs"] = root_vol.update({"VirtualName": f"ephemeral{disk - 1}"})
1188
1414
  bdm_list.append(bdm)
1189
1415
 
1190
- logger.debug('Device mapping: %s', bdm_list)
1416
+ logger.debug("Device mapping: %s", bdm_list)
1191
1417
  return bdm_list
1192
1418
 
1193
1419
  @classmethod
1194
- def _getBoto3BlockDeviceMappings(cls, type_info: InstanceType, rootVolSize: int = 50) -> List[BlockDeviceMappingTypeDef]:
1420
+ def _getBoto3BlockDeviceMappings(
1421
+ cls, type_info: InstanceType, rootVolSize: int = 50
1422
+ ) -> list[BlockDeviceMappingTypeDef]:
1195
1423
  """
1196
1424
  Get block device mappings for the root volume for a worker.
1197
1425
  """
1198
1426
 
1199
1427
  # Start with the root
1200
- bdms: List[BlockDeviceMappingTypeDef] = [{
1201
- 'DeviceName': '/dev/xvda',
1202
- 'Ebs': {
1203
- 'DeleteOnTermination': True,
1204
- 'VolumeSize': rootVolSize,
1205
- 'VolumeType': 'gp2'
1428
+ bdms: list[BlockDeviceMappingTypeDef] = [
1429
+ {
1430
+ "DeviceName": "/dev/xvda",
1431
+ "Ebs": {
1432
+ "DeleteOnTermination": True,
1433
+ "VolumeSize": rootVolSize,
1434
+ "VolumeType": "gp2",
1435
+ },
1206
1436
  }
1207
- }]
1437
+ ]
1208
1438
 
1209
1439
  # Get all the virtual drives we might have
1210
- bdtKeys = [f'/dev/xvd{c}' for c in string.ascii_lowercase]
1440
+ bdtKeys = [f"/dev/xvd{c}" for c in string.ascii_lowercase]
1211
1441
 
1212
1442
  # The first disk is already attached for us so start with 2nd.
1213
1443
  # Disk count is weirdly a float in our instance database, so make it an int here.
1214
1444
  for disk in range(1, int(type_info.disks) + 1):
1215
1445
  # Make a block device mapping to attach the ephemeral disk to a
1216
1446
  # virtual block device in the VM
1217
- bdms.append({
1218
- 'DeviceName': bdtKeys[disk],
1219
- 'VirtualName': f'ephemeral{disk - 1}' # ephemeral counts start at 0
1220
- })
1221
- logger.debug('Device mapping: %s', bdms)
1447
+ bdms.append(
1448
+ {
1449
+ "DeviceName": bdtKeys[disk],
1450
+ "VirtualName": f"ephemeral{disk - 1}", # ephemeral counts start at 0
1451
+ }
1452
+ )
1453
+ logger.debug("Device mapping: %s", bdms)
1222
1454
  return bdms
1223
1455
 
1224
1456
  @awsRetry
1225
- def _get_nodes_in_cluster_boto3(self, instance_type: Optional[str] = None, include_stopped_nodes: bool = False) -> List[InstanceTypeDef]:
1457
+ def _get_nodes_in_cluster_boto3(
1458
+ self, instance_type: str | None = None, include_stopped_nodes: bool = False
1459
+ ) -> list[InstanceTypeDef]:
1226
1460
  """
1227
1461
  Get Boto3 instance objects for all nodes in the cluster.
1228
1462
  """
1229
- boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name='ec2')
1230
- instance_filter: FilterTypeDef = {'Name': 'instance.group-name', 'Values': [self.clusterName]}
1231
- describe_response: DescribeInstancesResultTypeDef = boto3_ec2.describe_instances(Filters=[instance_filter])
1232
- all_instances: List[InstanceTypeDef] = []
1233
- for reservation in describe_response['Reservations']:
1234
- instances = reservation['Instances']
1463
+ boto3_ec2: EC2Client = self.aws.client(region=self._region, service_name="ec2")
1464
+ instance_filter: FilterTypeDef = {
1465
+ "Name": "instance.group-name",
1466
+ "Values": [self.clusterName],
1467
+ }
1468
+ describe_response: DescribeInstancesResultTypeDef = (
1469
+ boto3_ec2.describe_instances(Filters=[instance_filter])
1470
+ )
1471
+ all_instances: list[InstanceTypeDef] = []
1472
+ for reservation in describe_response["Reservations"]:
1473
+ instances = reservation["Instances"]
1235
1474
  all_instances.extend(instances)
1236
1475
 
1237
1476
  # all_instances = self.aws.boto2(self._region, 'ec2').get_only_instances(filters={'instance.group-name': self.clusterName})
1238
1477
 
1239
1478
  def instanceFilter(i: InstanceTypeDef) -> bool:
1240
1479
  # filter by type only if nodeType is true
1241
- rightType = not instance_type or i['InstanceType'] == instance_type
1242
- rightState = i['State']['Name'] == 'running' or i['State']['Name'] == 'pending'
1480
+ rightType = not instance_type or i["InstanceType"] == instance_type
1481
+ rightState = (
1482
+ i["State"]["Name"] == "running" or i["State"]["Name"] == "pending"
1483
+ )
1243
1484
  if include_stopped_nodes:
1244
- rightState = rightState or i['State']['Name'] == 'stopping' or i['State']['Name'] == 'stopped'
1485
+ rightState = (
1486
+ rightState
1487
+ or i["State"]["Name"] == "stopping"
1488
+ or i["State"]["Name"] == "stopped"
1489
+ )
1245
1490
  return rightType and rightState
1246
1491
 
1247
1492
  return [i for i in all_instances if instanceFilter(i)]
1248
1493
 
1249
- def _getSpotRequestIDs(self) -> List[str]:
1494
+ def _getSpotRequestIDs(self) -> list[str]:
1250
1495
  """
1251
1496
  Get the IDs of all spot requests associated with the cluster.
1252
1497
  """
1253
1498
 
1254
1499
  # Grab the connection we need to use for this operation.
1255
- ec2: EC2Client = self.aws.client(self._region, 'ec2')
1256
-
1257
- requests: List[SpotInstanceRequestTypeDef] = ec2.describe_spot_instance_requests()["SpotInstanceRequests"]
1258
- tag_filter: FilterTypeDef = {"Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME, "Values": [self.clusterName]}
1259
- tags: List[TagDescriptionTypeDef] = ec2.describe_tags(Filters=[tag_filter])["Tags"]
1500
+ ec2: EC2Client = self.aws.client(self._region, "ec2")
1501
+
1502
+ requests: list[SpotInstanceRequestTypeDef] = (
1503
+ ec2.describe_spot_instance_requests()["SpotInstanceRequests"]
1504
+ )
1505
+ tag_filter: FilterTypeDef = {
1506
+ "Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME,
1507
+ "Values": [self.clusterName],
1508
+ }
1509
+ tags: list[TagDescriptionTypeDef] = ec2.describe_tags(Filters=[tag_filter])[
1510
+ "Tags"
1511
+ ]
1260
1512
  idsToCancel = [tag["ResourceId"] for tag in tags]
1261
- return [request["SpotInstanceRequestId"] for request in requests if request["InstanceId"] in idsToCancel]
1513
+ return [
1514
+ request["SpotInstanceRequestId"]
1515
+ for request in requests
1516
+ if request["InstanceId"] in idsToCancel
1517
+ ]
1262
1518
 
1263
- def _createSecurityGroups(self) -> List[str]:
1519
+ def _createSecurityGroups(self) -> list[str]:
1264
1520
  """
1265
1521
  Create security groups for the cluster. Returns a list of their IDs.
1266
1522
  """
1523
+
1267
1524
  def group_not_found(e: ClientError) -> bool:
1268
- retry = (get_error_status(e) == 400 and 'does not exist in default VPC' in get_error_body(e))
1525
+ retry = get_error_status(
1526
+ e
1527
+ ) == 400 and "does not exist in default VPC" in get_error_body(e)
1269
1528
  return retry
1270
1529
 
1271
1530
  # Grab the connection we need to use for this operation.
@@ -1274,49 +1533,69 @@ class AWSProvisioner(AbstractProvisioner):
1274
1533
 
1275
1534
  vpc_id = None
1276
1535
  if self._leader_subnet:
1277
- subnets = boto3_ec2.describe_subnets(SubnetIds=[self._leader_subnet])["Subnets"]
1536
+ subnets = boto3_ec2.describe_subnets(SubnetIds=[self._leader_subnet])[
1537
+ "Subnets"
1538
+ ]
1278
1539
  if len(subnets) > 0:
1279
1540
  vpc_id = subnets[0]["VpcId"]
1280
1541
  try:
1281
1542
  # Security groups need to belong to the same VPC as the leader. If we
1282
1543
  # put the leader in a particular non-default subnet, it may be in a
1283
1544
  # particular non-default VPC, which we need to know about.
1284
- other = {"GroupName": self.clusterName, "Description": "Toil appliance security group"}
1545
+ other = {
1546
+ "GroupName": self.clusterName,
1547
+ "Description": "Toil appliance security group",
1548
+ }
1285
1549
  if vpc_id is not None:
1286
1550
  other["VpcId"] = vpc_id
1287
1551
  # mypy stubs don't explicitly state kwargs even though documentation allows it, and mypy gets confused
1288
1552
  web_response: CreateSecurityGroupResultTypeDef = boto3_ec2.create_security_group(**other) # type: ignore[arg-type]
1289
1553
  except ClientError as e:
1290
- if get_error_status(e) == 400 and 'already exists' in get_error_body(e):
1554
+ if get_error_status(e) == 400 and "already exists" in get_error_body(e):
1291
1555
  pass
1292
1556
  else:
1293
1557
  raise
1294
1558
  else:
1295
1559
  for attempt in old_retry(predicate=group_not_found, timeout=300):
1296
1560
  with attempt:
1297
- ip_permissions: List[IpPermissionTypeDef] = [{"IpProtocol": "tcp",
1298
- "FromPort": 22,
1299
- "ToPort": 22,
1300
- "IpRanges": [
1301
- {"CidrIp": "0.0.0.0/0"}
1302
- ],
1303
- "Ipv6Ranges": [{"CidrIpv6": "::/0"}]}]
1561
+ ip_permissions: list[IpPermissionTypeDef] = [
1562
+ {
1563
+ "IpProtocol": "tcp",
1564
+ "FromPort": 22,
1565
+ "ToPort": 22,
1566
+ "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
1567
+ "Ipv6Ranges": [{"CidrIpv6": "::/0"}],
1568
+ }
1569
+ ]
1304
1570
  for protocol in ("tcp", "udp"):
1305
- ip_permissions.append({"IpProtocol": protocol,
1306
- "FromPort": 0,
1307
- "ToPort": 65535,
1308
- "UserIdGroupPairs":
1309
- [{"GroupId": web_response["GroupId"],
1310
- "GroupName": self.clusterName}]})
1311
- boto3_ec2.authorize_security_group_ingress(IpPermissions=ip_permissions, GroupName=self.clusterName, GroupId=web_response["GroupId"])
1571
+ ip_permissions.append(
1572
+ {
1573
+ "IpProtocol": protocol,
1574
+ "FromPort": 0,
1575
+ "ToPort": 65535,
1576
+ "UserIdGroupPairs": [
1577
+ {
1578
+ "GroupId": web_response["GroupId"],
1579
+ "GroupName": self.clusterName,
1580
+ }
1581
+ ],
1582
+ }
1583
+ )
1584
+ boto3_ec2.authorize_security_group_ingress(
1585
+ IpPermissions=ip_permissions,
1586
+ GroupName=self.clusterName,
1587
+ GroupId=web_response["GroupId"],
1588
+ )
1312
1589
  out = []
1313
1590
  for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]:
1314
- if sg["GroupName"] == self.clusterName and (vpc_id is None or sg["VpcId"] == vpc_id):
1591
+ if sg["GroupName"] == self.clusterName and (
1592
+ vpc_id is None or sg["VpcId"] == vpc_id
1593
+ ):
1315
1594
  out.append(sg["GroupId"])
1316
1595
  return out
1317
1596
 
1318
1597
  @awsRetry
1319
- def _getSecurityGroupIDs(self) -> List[str]:
1598
+ def _getSecurityGroupIDs(self) -> list[str]:
1320
1599
  """
1321
1600
  Get all the security group IDs to apply to leaders and workers.
1322
1601
  """
@@ -1325,13 +1604,20 @@ class AWSProvisioner(AbstractProvisioner):
1325
1604
 
1326
1605
  # Depending on if we enumerated them on the leader or locally, we might
1327
1606
  # know the required security groups by name, ID, or both.
1328
- boto3_ec2 = self.aws.client(region=self._region, service_name='ec2')
1329
- return [sg["GroupId"] for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]
1330
- if (sg["GroupName"] in self._leaderSecurityGroupNames or
1331
- sg["GroupId"] in self._leaderSecurityGroupIDs)]
1607
+ boto3_ec2 = self.aws.client(region=self._region, service_name="ec2")
1608
+ return [
1609
+ sg["GroupId"]
1610
+ for sg in boto3_ec2.describe_security_groups()["SecurityGroups"]
1611
+ if (
1612
+ sg["GroupName"] in self._leaderSecurityGroupNames
1613
+ or sg["GroupId"] in self._leaderSecurityGroupIDs
1614
+ )
1615
+ ]
1332
1616
 
1333
1617
  @awsRetry
1334
- def _get_launch_template_ids(self, filters: Optional[List[FilterTypeDef]] = None) -> List[str]:
1618
+ def _get_launch_template_ids(
1619
+ self, filters: list[FilterTypeDef] | None = None
1620
+ ) -> list[str]:
1335
1621
  """
1336
1622
  Find all launch templates associated with the cluster.
1337
1623
 
@@ -1339,10 +1625,12 @@ class AWSProvisioner(AbstractProvisioner):
1339
1625
  """
1340
1626
 
1341
1627
  # Grab the connection we need to use for this operation.
1342
- ec2: EC2Client = self.aws.client(self._region, 'ec2')
1628
+ ec2: EC2Client = self.aws.client(self._region, "ec2")
1343
1629
 
1344
1630
  # How do we match the right templates?
1345
- combined_filters: List[FilterTypeDef] = [{'Name': 'tag:' + _TAG_KEY_TOIL_CLUSTER_NAME, 'Values': [self.clusterName]}]
1631
+ combined_filters: list[FilterTypeDef] = [
1632
+ {"Name": "tag:" + _TAG_KEY_TOIL_CLUSTER_NAME, "Values": [self.clusterName]}
1633
+ ]
1346
1634
 
1347
1635
  if filters:
1348
1636
  # Add any user-specified filters
@@ -1350,16 +1638,21 @@ class AWSProvisioner(AbstractProvisioner):
1350
1638
 
1351
1639
  allTemplateIDs = []
1352
1640
  # Get the first page with no NextToken
1353
- response = ec2.describe_launch_templates(Filters=combined_filters,
1354
- MaxResults=200)
1641
+ response = ec2.describe_launch_templates(
1642
+ Filters=combined_filters, MaxResults=200
1643
+ )
1355
1644
  while True:
1356
1645
  # Process the current page
1357
- allTemplateIDs += [item['LaunchTemplateId'] for item in response.get('LaunchTemplates', [])]
1358
- if 'NextToken' in response:
1646
+ allTemplateIDs += [
1647
+ item["LaunchTemplateId"] for item in response.get("LaunchTemplates", [])
1648
+ ]
1649
+ if "NextToken" in response:
1359
1650
  # There are more pages. Get the next one, supplying the token.
1360
- response = ec2.describe_launch_templates(Filters=filters or [],
1361
- NextToken=response['NextToken'],
1362
- MaxResults=200)
1651
+ response = ec2.describe_launch_templates(
1652
+ Filters=filters or [],
1653
+ NextToken=response["NextToken"],
1654
+ MaxResults=200,
1655
+ )
1363
1656
  else:
1364
1657
  # No more pages
1365
1658
  break
@@ -1367,7 +1660,9 @@ class AWSProvisioner(AbstractProvisioner):
1367
1660
  return allTemplateIDs
1368
1661
 
1369
1662
  @awsRetry
1370
- def _get_worker_launch_template(self, instance_type: str, preemptible: bool = False, backoff: float = 1.0) -> str:
1663
+ def _get_worker_launch_template(
1664
+ self, instance_type: str, preemptible: bool = False, backoff: float = 1.0
1665
+ ) -> str:
1371
1666
  """
1372
1667
  Get a launch template for instances with the given parameters. Only one
1373
1668
  such launch template will be created, no matter how many times the
@@ -1386,36 +1681,55 @@ class AWSProvisioner(AbstractProvisioner):
1386
1681
  :return: The ID of the template.
1387
1682
  """
1388
1683
 
1389
- lt_name = self._name_worker_launch_template(instance_type, preemptible=preemptible)
1684
+ lt_name = self._name_worker_launch_template(
1685
+ instance_type, preemptible=preemptible
1686
+ )
1390
1687
 
1391
1688
  # How do we match the right templates?
1392
- filters: List[FilterTypeDef] = [{'Name': 'launch-template-name', 'Values': [lt_name]}]
1689
+ filters: list[FilterTypeDef] = [
1690
+ {"Name": "launch-template-name", "Values": [lt_name]}
1691
+ ]
1393
1692
 
1394
1693
  # Get the templates
1395
- templates: List[str] = self._get_launch_template_ids(filters=filters)
1694
+ templates: list[str] = self._get_launch_template_ids(filters=filters)
1396
1695
 
1397
1696
  if len(templates) > 1:
1398
1697
  # There shouldn't ever be multiple templates with our reserved name
1399
- raise RuntimeError(f"Multiple launch templates already exist named {lt_name}; "
1400
- "something else is operating in our cluster namespace.")
1698
+ raise RuntimeError(
1699
+ f"Multiple launch templates already exist named {lt_name}; "
1700
+ "something else is operating in our cluster namespace."
1701
+ )
1401
1702
  elif len(templates) == 0:
1402
1703
  # Template doesn't exist so we can create it.
1403
1704
  try:
1404
- return self._create_worker_launch_template(instance_type, preemptible=preemptible)
1705
+ return self._create_worker_launch_template(
1706
+ instance_type, preemptible=preemptible
1707
+ )
1405
1708
  except ClientError as e:
1406
- if get_error_code(e) == 'InvalidLaunchTemplateName.AlreadyExistsException':
1709
+ if (
1710
+ get_error_code(e)
1711
+ == "InvalidLaunchTemplateName.AlreadyExistsException"
1712
+ ):
1407
1713
  # Someone got to it before us (or we couldn't read our own
1408
1714
  # writes). Recurse to try again, because now it exists.
1409
- logger.info('Waiting %f seconds for template %s to be available', backoff, lt_name)
1715
+ logger.info(
1716
+ "Waiting %f seconds for template %s to be available",
1717
+ backoff,
1718
+ lt_name,
1719
+ )
1410
1720
  time.sleep(backoff)
1411
- return self._get_worker_launch_template(instance_type, preemptible=preemptible, backoff=backoff * 2)
1721
+ return self._get_worker_launch_template(
1722
+ instance_type, preemptible=preemptible, backoff=backoff * 2
1723
+ )
1412
1724
  else:
1413
1725
  raise
1414
1726
  else:
1415
1727
  # There must be exactly one template
1416
1728
  return templates[0]
1417
1729
 
1418
- def _name_worker_launch_template(self, instance_type: str, preemptible: bool = False) -> str:
1730
+ def _name_worker_launch_template(
1731
+ self, instance_type: str, preemptible: bool = False
1732
+ ) -> str:
1419
1733
  """
1420
1734
  Get the name we should use for the launch template with the given parameters.
1421
1735
 
@@ -1426,13 +1740,15 @@ class AWSProvisioner(AbstractProvisioner):
1426
1740
  """
1427
1741
 
1428
1742
  # The name has the cluster name in it
1429
- lt_name = f'{self.clusterName}-lt-{instance_type}'
1743
+ lt_name = f"{self.clusterName}-lt-{instance_type}"
1430
1744
  if preemptible:
1431
- lt_name += '-spot'
1745
+ lt_name += "-spot"
1432
1746
 
1433
1747
  return lt_name
1434
1748
 
1435
- def _create_worker_launch_template(self, instance_type: str, preemptible: bool = False) -> str:
1749
+ def _create_worker_launch_template(
1750
+ self, instance_type: str, preemptible: bool = False
1751
+ ) -> str:
1436
1752
  """
1437
1753
  Create the launch template for launching worker instances for the cluster.
1438
1754
 
@@ -1452,27 +1768,33 @@ class AWSProvisioner(AbstractProvisioner):
1452
1768
  bdms = self._getBoto3BlockDeviceMappings(type_info, rootVolSize=rootVolSize)
1453
1769
 
1454
1770
  keyPath = self._sseKey if self._sseKey else None
1455
- userData = self._getIgnitionUserData('worker', keyPath, preemptible, self._architecture)
1771
+ userData = self._getIgnitionUserData(
1772
+ "worker", keyPath, preemptible, self._architecture
1773
+ )
1456
1774
 
1457
- lt_name = self._name_worker_launch_template(instance_type, preemptible=preemptible)
1775
+ lt_name = self._name_worker_launch_template(
1776
+ instance_type, preemptible=preemptible
1777
+ )
1458
1778
 
1459
1779
  # But really we find it by tag
1460
1780
  tags = dict(self._tags)
1461
- tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
1462
-
1463
- return create_launch_template(self.aws.client(self._region, 'ec2'),
1464
- template_name=lt_name,
1465
- image_id=self._discoverAMI(),
1466
- key_name=self._keyName,
1467
- security_group_ids=self._getSecurityGroupIDs(),
1468
- instance_type=instance_type,
1469
- user_data=userData,
1470
- block_device_map=bdms,
1471
- instance_profile_arn=self._leaderProfileArn,
1472
- tags=tags)
1781
+ tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
1782
+
1783
+ return create_launch_template(
1784
+ self.aws.client(self._region, "ec2"),
1785
+ template_name=lt_name,
1786
+ image_id=self._discoverAMI(),
1787
+ key_name=self._keyName,
1788
+ security_group_ids=self._getSecurityGroupIDs(),
1789
+ instance_type=instance_type,
1790
+ user_data=userData,
1791
+ block_device_map=bdms,
1792
+ instance_profile_arn=self._leaderProfileArn,
1793
+ tags=tags,
1794
+ )
1473
1795
 
1474
1796
  @awsRetry
1475
- def _getAutoScalingGroupNames(self) -> List[str]:
1797
+ def _getAutoScalingGroupNames(self) -> list[str]:
1476
1798
  """
1477
1799
  Find all auto-scaling groups associated with the cluster.
1478
1800
 
@@ -1480,29 +1802,33 @@ class AWSProvisioner(AbstractProvisioner):
1480
1802
  """
1481
1803
 
1482
1804
  # Grab the connection we need to use for this operation.
1483
- autoscaling: AutoScalingClient = self.aws.client(self._region, 'autoscaling')
1805
+ autoscaling: AutoScalingClient = self.aws.client(self._region, "autoscaling")
1484
1806
 
1485
1807
  # AWS won't filter ASGs server-side for us in describe_auto_scaling_groups.
1486
1808
  # So we search instances of applied tags for the ASGs they are on.
1487
1809
  # The ASGs tagged with our cluster are our ASGs.
1488
1810
  # The filtering is on different fields of the tag object itself.
1489
- filters: List[FilterTypeDef] = [{'Name': 'key',
1490
- 'Values': [_TAG_KEY_TOIL_CLUSTER_NAME]},
1491
- {'Name': 'value',
1492
- 'Values': [self.clusterName]}]
1811
+ filters: list[FilterTypeDef] = [
1812
+ {"Name": "key", "Values": [_TAG_KEY_TOIL_CLUSTER_NAME]},
1813
+ {"Name": "value", "Values": [self.clusterName]},
1814
+ ]
1493
1815
 
1494
1816
  matchedASGs = []
1495
1817
  # Get the first page with no NextToken
1496
1818
  response = autoscaling.describe_tags(Filters=filters)
1497
1819
  while True:
1498
1820
  # Process the current page
1499
- matchedASGs += [item['ResourceId'] for item in response.get('Tags', [])
1500
- if item['Key'] == _TAG_KEY_TOIL_CLUSTER_NAME and
1501
- item['Value'] == self.clusterName]
1502
- if 'NextToken' in response:
1821
+ matchedASGs += [
1822
+ item["ResourceId"]
1823
+ for item in response.get("Tags", [])
1824
+ if item["Key"] == _TAG_KEY_TOIL_CLUSTER_NAME
1825
+ and item["Value"] == self.clusterName
1826
+ ]
1827
+ if "NextToken" in response:
1503
1828
  # There are more pages. Get the next one, supplying the token.
1504
- response = autoscaling.describe_tags(Filters=filters,
1505
- NextToken=response['NextToken'])
1829
+ response = autoscaling.describe_tags(
1830
+ Filters=filters, NextToken=response["NextToken"]
1831
+ )
1506
1832
  else:
1507
1833
  # No more pages
1508
1834
  break
@@ -1510,16 +1836,18 @@ class AWSProvisioner(AbstractProvisioner):
1510
1836
  for name in matchedASGs:
1511
1837
  # Double check to make sure we definitely aren't finding non-Toil
1512
1838
  # things
1513
- assert name.startswith('toil-')
1839
+ assert name.startswith("toil-")
1514
1840
 
1515
1841
  return matchedASGs
1516
1842
 
1517
- def _createWorkerAutoScalingGroup(self,
1518
- launch_template_ids: Dict[str, str],
1519
- instance_types: Collection[str],
1520
- min_size: int,
1521
- max_size: int,
1522
- spot_bid: Optional[float] = None) -> str:
1843
+ def _createWorkerAutoScalingGroup(
1844
+ self,
1845
+ launch_template_ids: dict[str, str],
1846
+ instance_types: Collection[str],
1847
+ min_size: int,
1848
+ max_size: int,
1849
+ spot_bid: float | None = None,
1850
+ ) -> str:
1523
1851
  """
1524
1852
  Create an autoscaling group.
1525
1853
 
@@ -1549,8 +1877,12 @@ class AWSProvisioner(AbstractProvisioner):
1549
1877
  for instance_type in instance_types:
1550
1878
  spec = E2Instances[instance_type]
1551
1879
  spec_gigs = spec.disks * spec.disk_capacity
1552
- rootVolSize = self._nodeStorageOverrides.get(instance_type, self._nodeStorage)
1553
- storage_gigs.append(max(rootVolSize - _STORAGE_ROOT_OVERHEAD_GIGS, spec_gigs))
1880
+ rootVolSize = self._nodeStorageOverrides.get(
1881
+ instance_type, self._nodeStorage
1882
+ )
1883
+ storage_gigs.append(
1884
+ max(rootVolSize - _STORAGE_ROOT_OVERHEAD_GIGS, spec_gigs)
1885
+ )
1554
1886
  # Get the min storage we expect to see, but not less than 0.
1555
1887
  min_gigs = max(min(storage_gigs), 0)
1556
1888
 
@@ -1559,32 +1891,36 @@ class AWSProvisioner(AbstractProvisioner):
1559
1891
  tags = dict(self._tags)
1560
1892
 
1561
1893
  # We tag the ASG with the Toil type, although nothing cares.
1562
- tags[_TAG_KEY_TOIL_NODE_TYPE] = 'worker'
1894
+ tags[_TAG_KEY_TOIL_NODE_TYPE] = "worker"
1563
1895
 
1564
- if self.clusterType == 'kubernetes':
1896
+ if self.clusterType == "kubernetes":
1565
1897
  # We also need to tag it with Kubernetes autoscaler info (empty tags)
1566
- tags['k8s.io/cluster-autoscaler/' + self.clusterName] = ''
1567
- assert (self.clusterName != 'enabled')
1568
- tags['k8s.io/cluster-autoscaler/enabled'] = ''
1569
- tags['k8s.io/cluster-autoscaler/node-template/resources/ephemeral-storage'] = f'{min_gigs}G'
1898
+ tags["k8s.io/cluster-autoscaler/" + self.clusterName] = ""
1899
+ assert self.clusterName != "enabled"
1900
+ tags["k8s.io/cluster-autoscaler/enabled"] = ""
1901
+ tags[
1902
+ "k8s.io/cluster-autoscaler/node-template/resources/ephemeral-storage"
1903
+ ] = f"{min_gigs}G"
1570
1904
 
1571
1905
  # Now we need to make up a unique name
1572
1906
  # TODO: can we make this more semantic without risking collisions? Maybe count up in memory?
1573
- asg_name = 'toil-' + str(uuid.uuid4())
1574
-
1575
- create_auto_scaling_group(self.aws.client(self._region, 'autoscaling'),
1576
- asg_name=asg_name,
1577
- launch_template_ids=launch_template_ids,
1578
- vpc_subnets=self._get_worker_subnets(),
1579
- min_size=min_size,
1580
- max_size=max_size,
1581
- instance_types=instance_types,
1582
- spot_bid=spot_bid,
1583
- tags=tags)
1907
+ asg_name = "toil-" + str(uuid.uuid4())
1908
+
1909
+ create_auto_scaling_group(
1910
+ self.aws.client(self._region, "autoscaling"),
1911
+ asg_name=asg_name,
1912
+ launch_template_ids=launch_template_ids,
1913
+ vpc_subnets=self._get_worker_subnets(),
1914
+ min_size=min_size,
1915
+ max_size=max_size,
1916
+ instance_types=instance_types,
1917
+ spot_bid=spot_bid,
1918
+ tags=tags,
1919
+ )
1584
1920
 
1585
1921
  return asg_name
1586
1922
 
1587
- def _boto2_pager(self, requestor_callable: Callable[[...], Any], result_attribute_name: str) -> Iterable[Dict[str, Any]]: # type: ignore[misc]
1923
+ def _boto2_pager(self, requestor_callable: Callable[[...], Any], result_attribute_name: str) -> Iterable[dict[str, Any]]: # type: ignore[misc]
1588
1924
  """
1589
1925
  Yield all the results from calling the given Boto 2 method and paging
1590
1926
  through all the results using the "marker" field. Results are to be
@@ -1594,49 +1930,48 @@ class AWSProvisioner(AbstractProvisioner):
1594
1930
  while True:
1595
1931
  result = requestor_callable(marker=marker) # type: ignore[call-arg]
1596
1932
  yield from getattr(result, result_attribute_name)
1597
- if result.is_truncated == 'true':
1933
+ if result.is_truncated == "true":
1598
1934
  marker = result.marker
1599
1935
  else:
1600
1936
  break
1601
1937
 
1602
1938
  @awsRetry
1603
- def _getRoleNames(self) -> List[str]:
1939
+ def _getRoleNames(self) -> list[str]:
1604
1940
  """
1605
1941
  Get all the IAM roles belonging to the cluster, as names.
1606
1942
  """
1607
1943
 
1608
1944
  results = []
1609
- boto3_iam = self.aws.client(self._region, 'iam')
1610
- for result in boto3_pager(boto3_iam.list_roles, 'Roles'):
1945
+ boto3_iam = self.aws.client(self._region, "iam")
1946
+ for result in boto3_pager(boto3_iam.list_roles, "Roles"):
1611
1947
  # For each Boto2 role object
1612
1948
  # Grab out the name
1613
- cast(RoleTypeDef, result)
1614
- name = result['RoleName']
1949
+ result2 = cast("RoleTypeDef", result)
1950
+ name = result2["RoleName"]
1615
1951
  if self._is_our_namespaced_name(name):
1616
1952
  # If it looks like ours, it is ours.
1617
1953
  results.append(name)
1618
1954
  return results
1619
1955
 
1620
1956
  @awsRetry
1621
- def _getInstanceProfileNames(self) -> List[str]:
1957
+ def _getInstanceProfileNames(self) -> list[str]:
1622
1958
  """
1623
1959
  Get all the instance profiles belonging to the cluster, as names.
1624
1960
  """
1625
1961
 
1626
1962
  results = []
1627
- boto3_iam = self.aws.client(self._region, 'iam')
1628
- for result in boto3_pager(boto3_iam.list_instance_profiles,
1629
- 'InstanceProfiles'):
1963
+ boto3_iam = self.aws.client(self._region, "iam")
1964
+ for result in boto3_pager(boto3_iam.list_instance_profiles, "InstanceProfiles"):
1630
1965
  # Grab out the name
1631
- cast(InstanceProfileTypeDef, result)
1632
- name = result['InstanceProfileName']
1966
+ result2 = cast("InstanceProfileTypeDef", result)
1967
+ name = result2["InstanceProfileName"]
1633
1968
  if self._is_our_namespaced_name(name):
1634
1969
  # If it looks like ours, it is ours.
1635
1970
  results.append(name)
1636
1971
  return results
1637
1972
 
1638
1973
  @awsRetry
1639
- def _getRoleInstanceProfileNames(self, role_name: str) -> List[str]:
1974
+ def _getRoleInstanceProfileNames(self, role_name: str) -> list[str]:
1640
1975
  """
1641
1976
  Get all the instance profiles with the IAM role with the given name.
1642
1977
 
@@ -1644,14 +1979,19 @@ class AWSProvisioner(AbstractProvisioner):
1644
1979
  """
1645
1980
 
1646
1981
  # Grab the connection we need to use for this operation.
1647
- boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
1648
-
1649
- return [item['InstanceProfileName'] for item in boto3_pager(boto3_iam.list_instance_profiles_for_role,
1650
- 'InstanceProfiles',
1651
- RoleName=role_name)]
1982
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1983
+
1984
+ return [
1985
+ item["InstanceProfileName"]
1986
+ for item in boto3_pager(
1987
+ boto3_iam.list_instance_profiles_for_role,
1988
+ "InstanceProfiles",
1989
+ RoleName=role_name,
1990
+ )
1991
+ ]
1652
1992
 
1653
1993
  @awsRetry
1654
- def _getRolePolicyArns(self, role_name: str) -> List[str]:
1994
+ def _getRolePolicyArns(self, role_name: str) -> list[str]:
1655
1995
  """
1656
1996
  Get all the policies attached to the IAM role with the given name.
1657
1997
 
@@ -1661,34 +2001,44 @@ class AWSProvisioner(AbstractProvisioner):
1661
2001
  """
1662
2002
 
1663
2003
  # Grab the connection we need to use for this operation.
1664
- boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
2004
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1665
2005
 
1666
2006
  # TODO: we don't currently use attached policies.
1667
2007
 
1668
- return [item['PolicyArn'] for item in boto3_pager(boto3_iam.list_attached_role_policies,
1669
- 'AttachedPolicies',
1670
- RoleName=role_name)]
2008
+ return [
2009
+ item["PolicyArn"]
2010
+ for item in boto3_pager(
2011
+ boto3_iam.list_attached_role_policies,
2012
+ "AttachedPolicies",
2013
+ RoleName=role_name,
2014
+ )
2015
+ ]
1671
2016
 
1672
2017
  @awsRetry
1673
- def _getRoleInlinePolicyNames(self, role_name: str) -> List[str]:
2018
+ def _getRoleInlinePolicyNames(self, role_name: str) -> list[str]:
1674
2019
  """
1675
2020
  Get all the policies inline in the given IAM role.
1676
2021
  Returns policy names.
1677
2022
  """
1678
2023
 
1679
2024
  # Grab the connection we need to use for this operation.
1680
- boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
2025
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
1681
2026
 
1682
- return list(boto3_pager(boto3_iam.list_role_policies, 'PolicyNames', RoleName=role_name))
2027
+ return list(
2028
+ boto3_pager(boto3_iam.list_role_policies, "PolicyNames", RoleName=role_name)
2029
+ )
1683
2030
 
1684
- def full_policy(self, resource: str) -> Dict[str, Any]:
2031
+ def full_policy(self, resource: str) -> dict[str, Any]:
1685
2032
  """
1686
2033
  Produce a dict describing the JSON form of a full-access-granting AWS
1687
2034
  IAM policy for the service with the given name (e.g. 's3').
1688
2035
  """
1689
- return dict(Version="2012-10-17", Statement=[dict(Effect="Allow", Resource="*", Action=f"{resource}:*")])
2036
+ return dict(
2037
+ Version="2012-10-17",
2038
+ Statement=[dict(Effect="Allow", Resource="*", Action=f"{resource}:*")],
2039
+ )
1690
2040
 
1691
- def kubernetes_policy(self) -> Dict[str, Any]:
2041
+ def kubernetes_policy(self) -> dict[str, Any]:
1692
2042
  """
1693
2043
  Get the Kubernetes policy grants not provided by the full grants on EC2
1694
2044
  and IAM. See
@@ -1702,98 +2052,86 @@ class AWSProvisioner(AbstractProvisioner):
1702
2052
  Some of these are really only needed on the leader.
1703
2053
  """
1704
2054
 
1705
- return dict(Version="2012-10-17", Statement=[dict(Effect="Allow", Resource="*", Action=[
1706
- "ecr:GetAuthorizationToken",
1707
- "ecr:BatchCheckLayerAvailability",
1708
- "ecr:GetDownloadUrlForLayer",
1709
- "ecr:GetRepositoryPolicy",
1710
- "ecr:DescribeRepositories",
1711
- "ecr:ListImages",
1712
- "ecr:BatchGetImage",
1713
- "autoscaling:DescribeAutoScalingGroups",
1714
- "autoscaling:DescribeAutoScalingInstances",
1715
- "autoscaling:DescribeLaunchConfigurations",
1716
- "autoscaling:DescribeTags",
1717
- "autoscaling:SetDesiredCapacity",
1718
- "autoscaling:TerminateInstanceInAutoScalingGroup",
1719
- "elasticloadbalancing:AddTags",
1720
- "elasticloadbalancing:ApplySecurityGroupsToLoadBalancer",
1721
- "elasticloadbalancing:AttachLoadBalancerToSubnets",
1722
- "elasticloadbalancing:ConfigureHealthCheck",
1723
- "elasticloadbalancing:CreateListener",
1724
- "elasticloadbalancing:CreateLoadBalancer",
1725
- "elasticloadbalancing:CreateLoadBalancerListeners",
1726
- "elasticloadbalancing:CreateLoadBalancerPolicy",
1727
- "elasticloadbalancing:CreateTargetGroup",
1728
- "elasticloadbalancing:DeleteListener",
1729
- "elasticloadbalancing:DeleteLoadBalancer",
1730
- "elasticloadbalancing:DeleteLoadBalancerListeners",
1731
- "elasticloadbalancing:DeleteTargetGroup",
1732
- "elasticloadbalancing:DeregisterInstancesFromLoadBalancer",
1733
- "elasticloadbalancing:DeregisterTargets",
1734
- "elasticloadbalancing:DescribeListeners",
1735
- "elasticloadbalancing:DescribeLoadBalancerAttributes",
1736
- "elasticloadbalancing:DescribeLoadBalancerPolicies",
1737
- "elasticloadbalancing:DescribeLoadBalancers",
1738
- "elasticloadbalancing:DescribeTargetGroups",
1739
- "elasticloadbalancing:DescribeTargetHealth",
1740
- "elasticloadbalancing:DetachLoadBalancerFromSubnets",
1741
- "elasticloadbalancing:ModifyListener",
1742
- "elasticloadbalancing:ModifyLoadBalancerAttributes",
1743
- "elasticloadbalancing:ModifyTargetGroup",
1744
- "elasticloadbalancing:RegisterInstancesWithLoadBalancer",
1745
- "elasticloadbalancing:RegisterTargets",
1746
- "elasticloadbalancing:SetLoadBalancerPoliciesForBackendServer",
1747
- "elasticloadbalancing:SetLoadBalancerPoliciesOfListener",
1748
- "kms:DescribeKey"
1749
- ])])
1750
-
1751
- def _setup_iam_ec2_role(self, local_role_name: str, policies: Dict[str, Any]) -> str:
2055
+ return dict(
2056
+ Version="2012-10-17",
2057
+ Statement=[
2058
+ dict(
2059
+ Effect="Allow",
2060
+ Resource="*",
2061
+ Action=[
2062
+ "ecr:GetAuthorizationToken",
2063
+ "ecr:BatchCheckLayerAvailability",
2064
+ "ecr:GetDownloadUrlForLayer",
2065
+ "ecr:GetRepositoryPolicy",
2066
+ "ecr:DescribeRepositories",
2067
+ "ecr:ListImages",
2068
+ "ecr:BatchGetImage",
2069
+ "autoscaling:DescribeAutoScalingGroups",
2070
+ "autoscaling:DescribeAutoScalingInstances",
2071
+ "autoscaling:DescribeLaunchConfigurations",
2072
+ "autoscaling:DescribeTags",
2073
+ "autoscaling:SetDesiredCapacity",
2074
+ "autoscaling:TerminateInstanceInAutoScalingGroup",
2075
+ "elasticloadbalancing:AddTags",
2076
+ "elasticloadbalancing:ApplySecurityGroupsToLoadBalancer",
2077
+ "elasticloadbalancing:AttachLoadBalancerToSubnets",
2078
+ "elasticloadbalancing:ConfigureHealthCheck",
2079
+ "elasticloadbalancing:CreateListener",
2080
+ "elasticloadbalancing:CreateLoadBalancer",
2081
+ "elasticloadbalancing:CreateLoadBalancerListeners",
2082
+ "elasticloadbalancing:CreateLoadBalancerPolicy",
2083
+ "elasticloadbalancing:CreateTargetGroup",
2084
+ "elasticloadbalancing:DeleteListener",
2085
+ "elasticloadbalancing:DeleteLoadBalancer",
2086
+ "elasticloadbalancing:DeleteLoadBalancerListeners",
2087
+ "elasticloadbalancing:DeleteTargetGroup",
2088
+ "elasticloadbalancing:DeregisterInstancesFromLoadBalancer",
2089
+ "elasticloadbalancing:DeregisterTargets",
2090
+ "elasticloadbalancing:DescribeListeners",
2091
+ "elasticloadbalancing:DescribeLoadBalancerAttributes",
2092
+ "elasticloadbalancing:DescribeLoadBalancerPolicies",
2093
+ "elasticloadbalancing:DescribeLoadBalancers",
2094
+ "elasticloadbalancing:DescribeTargetGroups",
2095
+ "elasticloadbalancing:DescribeTargetHealth",
2096
+ "elasticloadbalancing:DetachLoadBalancerFromSubnets",
2097
+ "elasticloadbalancing:ModifyListener",
2098
+ "elasticloadbalancing:ModifyLoadBalancerAttributes",
2099
+ "elasticloadbalancing:ModifyTargetGroup",
2100
+ "elasticloadbalancing:RegisterInstancesWithLoadBalancer",
2101
+ "elasticloadbalancing:RegisterTargets",
2102
+ "elasticloadbalancing:SetLoadBalancerPoliciesForBackendServer",
2103
+ "elasticloadbalancing:SetLoadBalancerPoliciesOfListener",
2104
+ "kms:DescribeKey",
2105
+ ],
2106
+ )
2107
+ ],
2108
+ )
2109
+
2110
+ def _setup_iam_ec2_role(
2111
+ self, local_role_name: str, policies: dict[str, Any]
2112
+ ) -> str:
1752
2113
  """
1753
2114
  Create an IAM role with the given policies, using the given name in
1754
2115
  addition to the cluster name, and return its full name.
1755
2116
  """
1756
-
1757
- # Grab the connection we need to use for this operation.
1758
- boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
1759
-
1760
- # Make sure we can tell our roles apart from roles for other clusters
1761
- aws_role_name = self._namespace_name(local_role_name)
1762
- try:
1763
- # Make the role
1764
- logger.debug('Creating IAM role %s...', aws_role_name)
1765
- assume_role_policy_document = json.dumps({
2117
+ ec2_role_policy_document = json.dumps(
2118
+ {
1766
2119
  "Version": "2012-10-17",
1767
- "Statement": [{
1768
- "Effect": "Allow",
1769
- "Principal": {"Service": ["ec2.amazonaws.com"]},
1770
- "Action": ["sts:AssumeRole"]}
1771
- ]})
1772
- boto3_iam.create_role(RoleName=aws_role_name, AssumeRolePolicyDocument=assume_role_policy_document)
1773
- logger.debug('Created new IAM role')
1774
- except ClientError as e:
1775
- if get_error_status(e) == 409 and get_error_code(e) == 'EntityAlreadyExists':
1776
- logger.debug('IAM role already exists. Reusing.')
1777
- else:
1778
- raise
1779
-
1780
- # Delete superfluous policies
1781
- policy_names = set(boto3_iam.list_role_policies(RoleName=aws_role_name)["PolicyNames"])
1782
- for policy_name in policy_names.difference(set(list(policies.keys()))):
1783
- boto3_iam.delete_role_policy(RoleName=aws_role_name, PolicyName=policy_name)
1784
-
1785
- # Create expected policies
1786
- for policy_name, policy in policies.items():
1787
- current_policy = None
1788
- try:
1789
- current_policy = boto3_iam.get_role_policy(RoleName=aws_role_name, PolicyName=policy_name)["PolicyDocument"]
1790
- except boto3_iam.exceptions.NoSuchEntityException:
1791
- pass
1792
- if current_policy != policy:
1793
- boto3_iam.put_role_policy(RoleName=aws_role_name, PolicyName=policy_name, PolicyDocument=json.dumps(policy))
1794
-
1795
- # Now the role has the right policies so it is ready.
1796
- return aws_role_name
2120
+ "Statement": [
2121
+ {
2122
+ "Effect": "Allow",
2123
+ "Principal": {"Service": ["ec2.amazonaws.com"]},
2124
+ "Action": ["sts:AssumeRole"],
2125
+ }
2126
+ ],
2127
+ }
2128
+ )
2129
+ return create_iam_role(
2130
+ role_name=self._namespace_name(local_role_name),
2131
+ assume_role_policy_document=ec2_role_policy_document,
2132
+ policies=policies,
2133
+ region=self._region,
2134
+ )
1797
2135
 
1798
2136
  @awsRetry
1799
2137
  def _createProfileArn(self) -> str:
@@ -1805,22 +2143,30 @@ class AWSProvisioner(AbstractProvisioner):
1805
2143
  """
1806
2144
 
1807
2145
  # Grab the connection we need to use for this operation.
1808
- boto3_iam: IAMClient = self.aws.client(self._region, 'iam')
1809
-
1810
- policy = dict(iam_full=self.full_policy('iam'), ec2_full=self.full_policy('ec2'),
1811
- s3_full=self.full_policy('s3'), sbd_full=self.full_policy('sdb'))
1812
- if self.clusterType == 'kubernetes':
2146
+ boto3_iam: IAMClient = self.aws.client(self._region, "iam")
2147
+
2148
+ policy = dict(
2149
+ iam_full=self.full_policy("iam"),
2150
+ ec2_full=self.full_policy("ec2"),
2151
+ s3_full=self.full_policy("s3"),
2152
+ sbd_full=self.full_policy("sdb"),
2153
+ )
2154
+ if self.clusterType == "kubernetes":
1813
2155
  # We also need autoscaling groups and some other stuff for AWS-Kubernetes integrations.
1814
2156
  # TODO: We use one merged policy for leader and worker, but we could be more specific.
1815
- policy['kubernetes_merged'] = self.kubernetes_policy()
2157
+ policy["kubernetes_merged"] = self.kubernetes_policy()
1816
2158
  iamRoleName = self._setup_iam_ec2_role(_INSTANCE_PROFILE_ROLE_NAME, policy)
1817
2159
 
1818
2160
  try:
1819
- profile_result = boto3_iam.get_instance_profile(InstanceProfileName=iamRoleName)
2161
+ profile_result = boto3_iam.get_instance_profile(
2162
+ InstanceProfileName=iamRoleName
2163
+ )
1820
2164
  profile: InstanceProfileTypeDef = profile_result["InstanceProfile"]
1821
2165
  logger.debug("Have preexisting instance profile: %s", profile)
1822
2166
  except boto3_iam.exceptions.NoSuchEntityException:
1823
- profile_result = boto3_iam.create_instance_profile(InstanceProfileName=iamRoleName)
2167
+ profile_result = boto3_iam.create_instance_profile(
2168
+ InstanceProfileName=iamRoleName
2169
+ )
1824
2170
  profile = profile_result["InstanceProfile"]
1825
2171
  logger.debug("Created new instance profile: %s", profile)
1826
2172
  else:
@@ -1836,21 +2182,28 @@ class AWSProvisioner(AbstractProvisioner):
1836
2182
  # This is too many roles. We probably grabbed something we should
1837
2183
  # not have by mistake, and this is some important profile for
1838
2184
  # something else.
1839
- raise RuntimeError(f'Did not expect instance profile {profile_arn} to contain '
1840
- f'more than one role; is it really a Toil-managed profile?')
2185
+ raise RuntimeError(
2186
+ f"Did not expect instance profile {profile_arn} to contain "
2187
+ f"more than one role; is it really a Toil-managed profile?"
2188
+ )
1841
2189
  elif len(profile["Roles"]) == 1:
1842
2190
  if profile["Roles"][0]["RoleName"] == iamRoleName:
1843
2191
  return profile_arn
1844
2192
  else:
1845
2193
  # Drop this wrong role and use the fallback code for 0 roles
1846
- boto3_iam.remove_role_from_instance_profile(InstanceProfileName=iamRoleName,
1847
- RoleName=profile["Roles"][0]["RoleName"])
2194
+ boto3_iam.remove_role_from_instance_profile(
2195
+ InstanceProfileName=iamRoleName,
2196
+ RoleName=profile["Roles"][0]["RoleName"],
2197
+ )
1848
2198
 
1849
2199
  # If we get here, we had 0 roles on the profile, or we had 1 but we removed it.
1850
2200
  for attempt in old_retry(predicate=lambda err: get_error_status(err) == 404):
1851
2201
  with attempt:
1852
2202
  # Put the IAM role on the profile
1853
- boto3_iam.add_role_to_instance_profile(InstanceProfileName=profile["InstanceProfileName"], RoleName=iamRoleName)
2203
+ boto3_iam.add_role_to_instance_profile(
2204
+ InstanceProfileName=profile["InstanceProfileName"],
2205
+ RoleName=iamRoleName,
2206
+ )
1854
2207
  logger.debug("Associated role %s with profile", iamRoleName)
1855
2208
 
1856
2209
  return profile_arn