locust-cloud 1.12.4__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locust_cloud/cloud.py CHANGED
@@ -1,21 +1,29 @@
1
1
  import base64
2
2
  import gzip
3
+ import importlib.metadata
4
+ import json
3
5
  import logging
4
6
  import os
7
+ import pathlib
5
8
  import sys
6
9
  import threading
10
+ import time
7
11
  import tomllib
8
12
  import urllib.parse
13
+ import webbrowser
9
14
  from argparse import Namespace
10
15
  from collections import OrderedDict
16
+ from dataclasses import dataclass
11
17
  from typing import IO, Any
12
18
 
13
19
  import configargparse
20
+ import jwt
21
+ import platformdirs
14
22
  import requests
15
23
  import socketio
16
24
  import socketio.exceptions
17
- from locust_cloud import __version__
18
- from locust_cloud.credential_manager import CredentialError, CredentialManager
25
+
26
+ __version__ = importlib.metadata.version("locust-cloud")
19
27
 
20
28
 
21
29
  class LocustTomlConfigParser(configargparse.TomlConfigParser):
@@ -49,7 +57,7 @@ parser = configargparse.ArgumentParser(
49
57
  "cloud.conf",
50
58
  ],
51
59
  auto_env_var_prefix="LOCUSTCLOUD_",
52
- formatter_class=configargparse.RawDescriptionHelpFormatter,
60
+ formatter_class=configargparse.RawTextHelpFormatter,
53
61
  config_file_parser_class=configargparse.CompositeConfigParser(
54
62
  [
55
63
  LocustTomlConfigParser(["tool.locust"]),
@@ -108,36 +116,16 @@ advanced.add_argument(
108
116
  help="Optional requirements.txt file that contains your external libraries.",
109
117
  )
110
118
  advanced.add_argument(
111
- "--region",
112
- type=str,
113
- default=os.environ.get("AWS_DEFAULT_REGION"),
114
- help="Sets the AWS region to use for the deployed cluster, e.g. us-east-1. It defaults to use AWS_DEFAULT_REGION env var, like AWS tools.",
115
- )
116
- parser.add_argument(
117
- "--aws-access-key-id",
118
- type=str,
119
- help=configargparse.SUPPRESS,
120
- env_var="AWS_ACCESS_KEY_ID",
121
- default=None,
122
- )
123
- parser.add_argument(
124
- "--aws-secret-access-key",
125
- type=str,
126
- help=configargparse.SUPPRESS,
127
- env_var="AWS_SECRET_ACCESS_KEY",
128
- default=None,
129
- )
130
- parser.add_argument(
131
- "--username",
132
- type=str,
119
+ "--login",
120
+ action="store_true",
121
+ default=False,
133
122
  help=configargparse.SUPPRESS,
134
- default=os.getenv("LOCUST_CLOUD_USERNAME", None), # backwards compatitibility for dmdb
135
123
  )
136
- parser.add_argument(
137
- "--password",
138
- type=str,
139
- help=configargparse.SUPPRESS,
140
- default=os.getenv("LOCUST_CLOUD_PASSWORD", None), # backwards compatitibility for dmdb
124
+ advanced.add_argument(
125
+ "--non-interactive",
126
+ action="store_true",
127
+ default=False,
128
+ help="This can be set when, for example, running in a CI/CD environment to ensure no interactive steps while executing.\nRequires that LOCUSTCLOUD_USERNAME, LOCUSTCLOUD_PASSWORD and LOCUSTCLOUD_REGION environment variables are set.",
141
129
  )
142
130
  parser.add_argument(
143
131
  "--workers",
@@ -153,7 +141,7 @@ parser.add_argument(
153
141
  parser.add_argument(
154
142
  "--image-tag",
155
143
  type=str,
156
- default="latest",
144
+ default=None,
157
145
  help=configargparse.SUPPRESS, # overrides the locust-cloud docker image tag. for internal use
158
146
  )
159
147
  parser.add_argument(
@@ -169,6 +157,7 @@ parser.add_argument(
169
157
  )
170
158
 
171
159
  options, locust_options = parser.parse_known_args()
160
+
172
161
  options: Namespace
173
162
  locust_options: list
174
163
 
@@ -178,58 +167,426 @@ logging.basicConfig(
178
167
  )
179
168
  logger = logging.getLogger(__name__)
180
169
  # Restore log level for other libs. Yes, this can be done more nicely
181
- logging.getLogger("botocore").setLevel(logging.INFO)
182
- logging.getLogger("boto3").setLevel(logging.INFO)
183
170
  logging.getLogger("requests").setLevel(logging.INFO)
184
171
  logging.getLogger("urllib3").setLevel(logging.INFO)
185
172
 
173
+ cloud_conf_file = pathlib.Path(platformdirs.user_config_dir(appname="locust-cloud")) / "config"
174
+ valid_regions = ["us-east-1", "eu-north-1"]
186
175
 
187
- api_url = os.environ.get("LOCUSTCLOUD_DEPLOYER_URL", f"https://api.{options.region}.locust.cloud/1")
188
176
 
177
def get_api_url(region):
    """Build the base URL of the locust-cloud deployer API for *region*.

    The LOCUSTCLOUD_DEPLOYER_URL environment variable, when present, takes
    precedence over the URL derived from the region.
    """
    regional_url = f"https://api.{region}.locust.cloud/1"
    return os.environ.get("LOCUSTCLOUD_DEPLOYER_URL", regional_url)
189
179
 
190
- def main() -> None:
191
- if options.version:
192
- print(f"locust-cloud version {__version__}")
193
- sys.exit(0)
194
180
 
195
- if not options.region:
196
- logger.error(
197
- "Setting a region is required to use Locust Cloud. Please ensure the AWS_DEFAULT_REGION env variable or the --region flag is set."
198
- )
181
+ @dataclass
182
+ class CloudConfig:
183
+ id_token: str | None = None
184
+ refresh_token: str | None = None
185
+ refresh_token_expires: int = 0
186
+ region: str | None = None
187
+
188
+
189
def read_cloud_config() -> CloudConfig:
    """Load the CLI login state stored in ``cloud_conf_file``.

    Returns:
        The stored CloudConfig. A default (empty) CloudConfig is returned when
        the file does not exist, cannot be decoded as JSON, or does not hold a
        JSON object. An empty config has ``refresh_token_expires == 0``, which
        makes the CLI ask for a new login instead of crashing with a traceback.

    Keys that CloudConfig does not define (e.g. written by a different
    locust-cloud version) are ignored instead of raising TypeError.
    """
    from dataclasses import fields

    # EAFP instead of exists() + open(): no race, and one code path.
    try:
        with open(cloud_conf_file, encoding="utf-8") as f:
            stored = json.load(f)
    except FileNotFoundError:
        return CloudConfig()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError: a truncated or corrupted
        # file should lead to a fresh login, not a crash.
        return CloudConfig()

    if not isinstance(stored, dict):
        return CloudConfig()

    known_fields = {field.name for field in fields(CloudConfig)}
    return CloudConfig(**{key: value for key, value in stored.items() if key in known_fields})
195
+
196
+
197
def write_cloud_config(config: CloudConfig) -> None:
    """Persist *config* as JSON to ``cloud_conf_file``.

    The file contains the id and refresh tokens, so it is opened with
    owner-only read/write permissions (0600). This avoids the process umask,
    which typically (umask 022) leaves new files readable by other users.
    A file written earlier with looser permissions is tightened as well.
    The containing directory is created on demand.
    """
    from dataclasses import asdict

    cloud_conf_file.parent.mkdir(parents=True, exist_ok=True)

    # The mode passed to os.open only applies when the file is newly created...
    fd = os.open(cloud_conf_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        # ...so restrict an already-existing file before any secret is written.
        os.chmod(cloud_conf_file, 0o600)
        json.dump(asdict(config), f)
202
+
203
+
204
def _prompt_for_region() -> str:
    """Ask the user to pick one of ``valid_regions`` and return it.

    Exits with status 1 when the answer is not one of the listed numbers.
    """
    print("Enter the number for the region to authenticate against")
    print()
    for i, valid_region in enumerate(valid_regions, start=1):
        print(f" {i}. {valid_region}")
    print()
    choice = input("> ")
    try:
        region_index = int(choice) - 1
    except ValueError:
        region_index = -1
    # Explicit range check rather than ``assert``: asserts are removed under
    # ``python -O``, and a negative index would then silently pick the last region.
    if not 0 <= region_index < len(valid_regions):
        print(f"Not a valid choice: '{choice}'")
        sys.exit(1)
    return valid_regions[region_index]


def _start_cli_auth(region: str) -> tuple[str, str]:
    """Create a CLI authorization request against the API of *region*.

    Returns ``(authentication_url, result_url)``. The first is the page the
    user opens to approve the request. The second is polled for the outcome.
    Exits with status 1 if the request cannot be created.
    """
    try:
        response = requests.post(f"{get_api_url(region)}/cli-auth", timeout=30)
        response.raise_for_status()
        response_data = response.json()
        return response_data["authentication_url"], response_data["result_url"]
    except Exception as e:
        print("Something went wrong trying to authorize the locust-cloud CLI:", str(e))
        sys.exit(1)


def _wait_for_authorization(result_url: str, poll_interval: float, timeout: float) -> dict:
    """Poll *result_url* until the request is authorized and return its payload.

    Exits with status 1 in any of these cases:
    - the request fails or is rejected,
    - the server answers with an unknown state,
    - *timeout* seconds pass without a decision.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            response = requests.get(result_url, timeout=30)
        except requests.exceptions.RequestException as e:
            print(f"\nFailed to check the authorization status: {e}")
            sys.exit(1)

        if not response.ok:
            print("Oh no!")
            print(response.text)
            sys.exit(1)

        data = response.json()
        state = data.get("state")

        if state == "authorized":
            print("\nAuthorization succeeded")
            return data
        if state == "failed":
            print(f"\nFailed to authorize CLI: {data['reason']}")
            sys.exit(1)
        if state != "pending":
            print("\nGot unexpected response when authorizing CLI")
            sys.exit(1)

        if time.monotonic() >= deadline:
            print(f"\nTimed out after {timeout:g} seconds waiting for the CLI to be authorized")
            sys.exit(1)
        time.sleep(poll_interval)


def web_login(*, poll_interval: float = 1.0, timeout: float = 600.0) -> None:
    """Authorize the CLI through the browser-based SSO flow and store the tokens.

    The user picks a region, the SSO page is opened in the default browser (and
    its URL printed), and the result endpoint is polled every *poll_interval*
    seconds. Polling stops after at most *timeout* seconds. On success the
    tokens and region are saved with write_cloud_config. Every failure prints
    a message and exits with status 1.
    """
    region = _prompt_for_region()
    authentication_url, result_url = _start_cli_auth(region)

    message = (
        "Attempting to automatically open the SSO authorization page in your default browser.\n"
        "If the browser does not open or you wish to use a different device to authorize this request, "
        "open the following URL:\n"
        "\n"
        f"{authentication_url}"
    )
    print()
    print(message)

    webbrowser.open_new_tab(authentication_url)

    data = _wait_for_authorization(result_url, poll_interval, timeout)

    config = CloudConfig(
        id_token=data["id_token"],
        refresh_token=data["refresh_token"],
        refresh_token_expires=data["refresh_token_expires"],
        region=region,
    )
    write_cloud_config(config)
271
+
272
+
273
class ApiSession(requests.Session):
    """requests.Session bound to the locust-cloud API of a single region.

    The URL passed to request() (and therefore get()/post()/delete()/...) is
    appended to ``api_url``. Every request carries an ``Authorization: Bearer``
    id token, which is refreshed one minute before the token's ``exp`` claim.

    With ``--non-interactive`` the session logs in with the LOCUSTCLOUD_USERNAME,
    LOCUSTCLOUD_PASSWORD and LOCUSTCLOUD_REGION environment variables.
    Otherwise it uses the tokens stored by ``locust-cloud --login``.
    Authentication problems print a message and exit the process with status 1.
    """

    # Seconds to wait for the authentication endpoint before giving up.
    _AUTH_TIMEOUT = 60

    def __init__(self) -> None:
        super().__init__()

        if options.non_interactive:
            id_token = self.__login_with_password()
        else:
            id_token = self.__load_stored_login()

        decoded = self.__install_id_token(id_token)
        self.__sub = decoded["sub"]
        self.headers["X-Client-Version"] = __version__

    def __login_with_password(self) -> str:
        """Log in with credentials from the environment; return the id token."""
        username = os.getenv("LOCUSTCLOUD_USERNAME")
        password = os.getenv("LOCUSTCLOUD_PASSWORD")
        region = os.getenv("LOCUSTCLOUD_REGION")

        if not all([username, password, region]):
            print(
                "Running with --non-interactive requires that LOCUSTCLOUD_USERNAME, LOCUSTCLOUD_PASSWORD and LOCUSTCLOUD_REGION environment variables are set."
            )
            sys.exit(1)

        if region not in valid_regions:
            print("Environment variable LOCUSTCLOUD_REGION needs to be set to one of", ", ".join(valid_regions))
            sys.exit(1)

        self.__configure_for_region(region)
        response_data = self.__post_login({"username": username, "password": password}, print)
        self.__refresh_token = response_data["refresh_token"]
        return response_data["cognito_client_id_token"]

    def __load_stored_login(self) -> str:
        """Use the tokens saved by ``--login``; return the id token."""
        config = read_cloud_config()

        message = "You need to authenticate before proceeding. Please run:\n locust-cloud --login"
        # Require at least one more day of refresh-token validity.
        expired = config.refresh_token_expires < time.time() + 24 * 60 * 60
        # Explicit check instead of ``assert`` (stripped under ``python -O``):
        # a config missing its region or tokens cannot be used either.
        incomplete = not (config.region and config.refresh_token and config.id_token)
        if expired or incomplete:
            print(message)
            sys.exit(1)

        self.__configure_for_region(config.region)
        self.__refresh_token = config.refresh_token
        return config.id_token

    def __configure_for_region(self, region: str) -> None:
        self.__region = region
        self.api_url = get_api_url(region)
        self.__login_url = f"{self.api_url}/auth/login"

        logger.debug(f"Lambda url: {self.api_url}")

    def __post_login(self, payload: dict, report) -> dict:
        """POST *payload* to the login endpoint and return the decoded JSON body.

        *report* (``print`` or ``logger.error``) receives the error message
        before the process exits with status 1 on a network error or a
        non-2xx response.
        """
        try:
            response = requests.post(
                self.__login_url,
                json=payload,
                headers={"X-Client-Version": __version__},
                timeout=self._AUTH_TIMEOUT,
            )
        except requests.exceptions.RequestException as e:
            report(f"Authentication failed: {e}")
            sys.exit(1)

        if not response.ok:
            report(f"Authentication failed: {response.text}")
            sys.exit(1)

        return response.json()

    def __install_id_token(self, id_token: str) -> dict:
        """Use *id_token* for subsequent requests and return its decoded claims."""
        # Signature is not verified: the claims are only read locally for the expiry/sub.
        decoded = jwt.decode(id_token, options={"verify_signature": False})
        self.__expiry_time = decoded["exp"] - 60  # Refresh 1 minute before expiry
        self.headers["Authorization"] = f"Bearer {id_token}"
        return decoded

    def __ensure_valid_authorization_header(self) -> None:
        if self.__expiry_time > time.time():
            return

        logger.info(f"Authenticating ({self.__region}, v{__version__})")

        response_data = self.__post_login(
            {"user_sub_id": self.__sub, "refresh_token": self.__refresh_token},
            logger.error,
        )

        # TODO: Technically the /login endpoint can return a challenge for you
        # to change your password. Don't know how we should handle that
        # in the cli.

        id_token = response_data["cognito_client_id_token"]
        self.__install_id_token(id_token)

        if not options.non_interactive:
            config = read_cloud_config()
            config.id_token = id_token
            write_cloud_config(config)

    def request(self, method, url, *args, **kwargs) -> requests.Response:
        self.__ensure_valid_authorization_header()
        return super().request(method, f"{self.api_url}{url}", *args, **kwargs)
367
+
368
+
369
class SessionMismatchError(Exception):
    """The locust master's session does not match the one started by this CLI run."""
371
+
372
+
373
class WebsocketTimeout(Exception):
    """A (re)connection to the locust master's websocket did not succeed in time."""
375
+
376
+
377
class Websocket:
    """
    Encapsulates all the logic involved in the websocket implementation.

    Once a connection has been established, the socketio client tries to
    reconnect forever if the connection is lost. The way to cancel this is to
    set the _reconnect_abort (threading.Event) on the client. The client then
    simply proceeds with shutting down without giving any indication of an
    error. This class handles timeouts for connection attempts, as well as some
    logic around when the socket can be shut down. See the methods for details.
    """

    def __init__(self) -> None:
        self.__shutdown_allowed = threading.Event()
        self.__timeout_on_disconnect = True
        # Timer started by __set_connection_timeout, or None before the first
        # connection attempt. It is tested with ``is not None`` rather than
        # hasattr(): private attributes are name-mangled (stored as
        # ``_Websocket__connect_timeout_timer``), so
        # ``hasattr(self, "__connect_timeout_timer")`` is always False.
        self.__connect_timeout_timer: threading.Timer | None = None
        self.initial_connect_timeout = 120
        self.reconnect_timeout = 10
        self.wait_timeout = 0
        self.exception: None | Exception = None

        self.sio = socketio.Client(handle_sigint=False)
        # The _reconnect_abort value on the socketio client will be populated with a newly created
        # threading.Event if it's not already set, and there is no way to pass it to the constructor.
        # This event is the only way to interrupt the retry logic when the connection is attempted.
        self.sio._reconnect_abort = threading.Event()

        self.sio.on("connect", self.__on_connect)
        self.sio.on("disconnect", self.__on_disconnect)
        self.sio.on("connect_error", self.__on_connect_error)
        self.sio.on("events", self.__on_events)

        # Ids of "events" payloads already handled; repeated ids are ignored.
        self.__processed_events: set[int] = set()

    def __cancel_connection_timeout(self) -> None:
        """Cancel the pending connection-timeout timer, if one has been started."""
        if self.__connect_timeout_timer is not None:
            self.__connect_timeout_timer.cancel()

    def __set_connection_timeout(self, timeout) -> None:
        """
        Start a threading.Timer that, after *timeout* seconds:
        - sets the threading.Event on the socketio client that aborts any
          further attempts to reconnect,
        - sets an exception on the websocket that will be raised from the
          wait method,
        - sets the threading.Event __shutdown_allowed that tells the wait
          method it should stop blocking.
        Any previously started timer is cancelled first, so at most one is live.
        """

        def _timeout():
            logger.debug(f"Websocket connection timed out after {timeout} seconds")
            self.sio._reconnect_abort.set()
            self.exception = WebsocketTimeout("Timed out connecting to locust master")
            self.__shutdown_allowed.set()

        self.__cancel_connection_timeout()
        self.__connect_timeout_timer = threading.Timer(timeout, _timeout)
        self.__connect_timeout_timer.daemon = True
        logger.debug(f"Setting websocket connection timeout to {timeout} seconds")
        self.__connect_timeout_timer.start()

    def connect(self, url, *, auth) -> None:
        """
        Send along retry=True when initiating the socketio client connection,
        so that it uses its built-in logic for retrying failed connections
        (usually used for reconnections). That logic retries forever.
        A timer is started first. When it fires, it disables the retry logic
        and a WebsocketTimeout exception is raised.
        """
        ws_connection_info = urllib.parse.urlparse(url)
        self.__set_connection_timeout(self.initial_connect_timeout)
        try:
            self.sio.connect(
                f"{ws_connection_info.scheme}://{ws_connection_info.netloc}",
                auth=auth,
                retry=True,
                **({"socketio_path": ws_connection_info.path} if ws_connection_info.path else {}),
            )
        except socketio.exceptions.ConnectionError:
            if self.exception:
                raise self.exception

            raise

    def shutdown(self) -> None:
        """
        Shutting down the socketio client fires a disconnect event.
        Before doing so, disable starting a new reconnect-timeout timer, since
        no further reconnect attempts will be made, and cancel any timer that
        is already running.
        """
        self.__timeout_on_disconnect = False
        self.__cancel_connection_timeout()
        self.sio.shutdown()

    def wait(self, timeout=False) -> bool:
        """
        Block until the threading.Event __shutdown_allowed is set.
        If *timeout* is truthy, wait at most ``wait_timeout`` seconds;
        otherwise wait indefinitely.
        If an exception has been set on the websocket (by a connection timeout
        timer or by __on_connect_error), raise it.
        Returns whether __shutdown_allowed was set.
        """
        timeout = self.wait_timeout if timeout else None
        wait_for = f"{timeout}s" if timeout is not None else "ever"
        logger.debug(f"Waiting for shutdown for {wait_for}")
        res = self.__shutdown_allowed.wait(timeout)
        if self.exception:
            raise self.exception
        return res

    def __on_connect(self) -> None:
        """
        Called whenever a connection is successfully established.
        Cancel the running threading.Timer that would abort reconnect attempts
        and raise a WebsocketTimeout exception.
        wait_timeout starts at zero. Once a connection has been established it
        is raised, so the server gets the chance to send all the logs and an
        official shutdown event.
        """
        self.__cancel_connection_timeout()
        self.wait_timeout = 90
        logger.debug("Websocket connected")

    def __on_disconnect(self) -> None:
        """
        Called whenever a connection is lost.
        The socketio client will try to reconnect forever. Unless this has been
        disabled, start a threading.Timer that aborts the reconnect attempts and
        raises a WebsocketTimeout exception.
        """
        if self.__timeout_on_disconnect:
            self.__set_connection_timeout(self.reconnect_timeout)
        logger.debug("Websocket disconnected")

    def __on_events(self, data):
        """
        Handle events explicitly sent by the websocket server.
        These are either messages to write to stdout/stderr, or an indication
        that the CLI can shut down. In the latter case, set the threading.Event
        __shutdown_allowed, which tells the wait method to stop blocking.
        Payloads whose "id" was already processed are ignored.
        """
        shutdown = False
        shutdown_message = ""

        if data["id"] in self.__processed_events:
            logger.debug(f"Got duplicate data on websocket, id {data['id']}")
            return

        self.__processed_events.add(data["id"])

        for event in data["events"]:
            event_type = event["type"]

            if event_type == "shutdown":
                shutdown = True
                shutdown_message = event["message"]
            elif event_type == "stdout":
                sys.stdout.write(event["message"])
            elif event_type == "stderr":
                sys.stderr.write(event["message"])
            else:
                raise Exception("Unexpected event type")

        if shutdown:
            logger.debug("Got shutdown from locust master")
            if shutdown_message:
                print(shutdown_message)

            self.__shutdown_allowed.set()

    def __on_connect_error(self, data) -> None:
        """
        Called whenever there's an error during connection attempts.
        Only one case is handled: the auth parameter of the connection does not
        match the session ID on the server. That happens when the connection is
        attempted towards an instance of locust not started by this CLI.

        In that case:
        - cancel the running threading.Timer that would abort reconnect
          attempts and raise a WebsocketTimeout exception,
        - set an exception on the websocket that will be raised from the wait
          method,
        - cancel further reconnect attempts,
        - set the threading.Event __shutdown_allowed, which tells the wait
          method to stop blocking.
        """
        # Do nothing if it's not the specific case we know how to deal with
        if not (isinstance(data, dict) and data.get("message") == "Session mismatch"):
            return

        self.__cancel_connection_timeout()
        self.exception = SessionMismatchError(
            "The session from this run of locust-cloud did not match the one on the server"
        )
        self.sio._reconnect_abort.set()
        self.__shutdown_allowed.set()
565
+
566
+
567
+ def main() -> None:
568
+ if options.version:
569
+ print(f"locust-cloud version {__version__}")
570
+ sys.exit(0)
208
571
  if not options.locustfile:
209
572
  logger.error("A locustfile is required to run a test.")
210
573
  sys.exit(1)
211
574
 
212
- try:
213
- logger.info(f"Authenticating ({options.region}, v{__version__})")
214
- logger.debug(f"Lambda url: {api_url}")
215
- credential_manager = CredentialManager(
216
- lambda_url=api_url,
217
- access_key=options.aws_access_key_id,
218
- secret_key=options.aws_secret_access_key,
219
- username=options.username,
220
- password=options.password,
221
- )
575
+ if options.login:
576
+ try:
577
+ web_login()
578
+ except KeyboardInterrupt:
579
+ pass
580
+ sys.exit()
222
581
 
223
- credentials = credential_manager.get_current_credentials()
224
- cognito_client_id_token = credentials["cognito_client_id_token"]
225
- aws_access_key_id = credentials.get("access_key")
226
- aws_secret_access_key = credentials.get("secret_key")
227
- aws_session_token = credentials.get("token", "")
582
+ session = ApiSession()
583
+ websocket = Websocket()
228
584
 
229
- if options.delete:
230
- delete(credential_manager)
231
- return
585
+ if options.delete:
586
+ delete(session)
587
+ sys.exit()
232
588
 
589
+ try:
233
590
  try:
234
591
  with open(options.locustfile, "rb") as f:
235
592
  locustfile_data = base64.b64encode(gzip.compress(f.read())).decode()
@@ -249,54 +606,46 @@ def main() -> None:
249
606
 
250
607
  logger.info("Deploying load generators")
251
608
  locust_env_variables = [
252
- {"name": env_variable, "value": str(os.environ[env_variable])}
609
+ {"name": env_variable, "value": os.environ[env_variable]}
253
610
  for env_variable in os.environ
254
611
  if env_variable.startswith("LOCUST_")
255
- and not env_variable
256
- in [
612
+ and env_variable
613
+ not in [
257
614
  "LOCUST_LOCUSTFILE",
258
615
  "LOCUST_USERS",
259
616
  "LOCUST_WEB_HOST_DISPLAY_NAME",
260
617
  "LOCUST_SKIP_MONKEY_PATCH",
261
618
  ]
262
- and os.environ[env_variable]
263
619
  ]
264
- deploy_endpoint = f"{api_url}/deploy"
265
620
  payload = {
266
621
  "locust_args": [
267
622
  {"name": "LOCUST_USERS", "value": str(options.users)},
268
623
  {"name": "LOCUST_FLAGS", "value": " ".join(locust_options)},
269
- {"name": "LOCUSTCLOUD_DEPLOYER_URL", "value": api_url},
624
+ {"name": "LOCUSTCLOUD_DEPLOYER_URL", "value": session.api_url},
270
625
  {"name": "LOCUSTCLOUD_PROFILE", "value": options.profile},
271
626
  *locust_env_variables,
272
627
  ],
273
628
  "locustfile": {"filename": options.locustfile, "data": locustfile_data},
274
629
  "user_count": options.users,
275
- "image_tag": options.image_tag,
276
630
  "mock_server": options.mock_server,
277
631
  }
632
+
633
+ if options.image_tag is not None:
634
+ payload["image_tag"] = options.image_tag
635
+
278
636
  if options.workers is not None:
279
637
  payload["worker_count"] = options.workers
638
+
280
639
  if options.requirements:
281
640
  payload["requirements"] = {"filename": options.requirements, "data": requirements_data}
282
- headers = {
283
- "Authorization": f"Bearer {cognito_client_id_token}",
284
- "Content-Type": "application/json",
285
- "AWS_ACCESS_KEY_ID": aws_access_key_id,
286
- "AWS_SECRET_ACCESS_KEY": aws_secret_access_key,
287
- "AWS_SESSION_TOKEN": aws_session_token,
288
- "X-Client-Version": __version__,
289
- }
641
+
290
642
  try:
291
- # logger.info(payload) # might be useful when debugging sometimes
292
- response = requests.post(deploy_endpoint, json=payload, headers=headers)
643
+ response = session.post("/deploy", json=payload)
293
644
  except requests.exceptions.RequestException as e:
294
645
  logger.error(f"Failed to deploy the load generators: {e}")
295
646
  sys.exit(1)
296
647
 
297
- if response.status_code == 200:
298
- log_ws_url = response.json()["log_ws_url"]
299
- else:
648
+ if response.status_code != 200:
300
649
  try:
301
650
  logger.error(f"{response.json()['Message']} (HTTP {response.status_code}/{response.reason})")
302
651
  except Exception:
@@ -304,98 +653,50 @@ def main() -> None:
304
653
  f"HTTP {response.status_code}/{response.reason} - Response: {response.text} - URL: {response.request.url}"
305
654
  )
306
655
  sys.exit(1)
307
- except CredentialError as ce:
308
- logger.error(f"Credential error: {ce}")
309
- sys.exit(1)
310
- except KeyboardInterrupt:
311
- logger.debug("Interrupted by user")
312
- sys.exit(0)
313
-
314
- logger.debug("Load generators deployed successfully!")
315
- logger.info("Waiting for pods to be ready...")
316
-
317
- shutdown_allowed = threading.Event()
318
- shutdown_allowed.set()
319
- reconnect_aborted = threading.Event()
320
- connect_timeout = threading.Timer(2 * 60, reconnect_aborted.set)
321
-
322
- try:
323
- ws_connection_info = urllib.parse.urlparse(log_ws_url)
324
- sio = socketio.Client(handle_sigint=False)
325
-
326
- @sio.event
327
- def connect():
328
- shutdown_allowed.clear()
329
- connect_timeout.cancel()
330
- logger.debug("Websocket connection established, switching to Locust logs")
331
-
332
- @sio.event
333
- def disconnect():
334
- logger.debug("Websocket disconnected")
335
656
 
336
- @sio.event
337
- def stderr(message):
338
- sys.stderr.write(message)
657
+ log_ws_url = response.json()["log_ws_url"]
658
+ session_id = response.json()["session_id"]
659
+ logger.debug(f"Session ID is {session_id}")
339
660
 
340
- @sio.event
341
- def stdout(message):
342
- sys.stdout.write(message)
343
-
344
- @sio.event
345
- def shutdown(message):
346
- logger.debug("Got shutdown from locust master")
347
- if message:
348
- print(message)
349
-
350
- shutdown_allowed.set()
351
-
352
- # The _reconnect_abort value on the socketio client will be populated with a newly created threading.Event if it's not already set.
353
- # There is no way to set this by passing it in the constructor.
354
- # This event is the only way to interupt the retry logic when the connection is attempted.
355
- sio._reconnect_abort = reconnect_aborted
356
- connect_timeout.start()
357
- sio.connect(
358
- f"{ws_connection_info.scheme}://{ws_connection_info.netloc}",
359
- socketio_path=ws_connection_info.path,
360
- retry=True,
661
+ logger.info("Waiting for pods to be ready...")
662
+ websocket.connect(
663
+ log_ws_url,
664
+ auth=session_id,
361
665
  )
362
- logger.debug("Waiting for shutdown")
363
- shutdown_allowed.wait()
666
+ websocket.wait()
364
667
 
365
668
  except KeyboardInterrupt:
366
669
  logger.debug("Interrupted by user")
367
- delete(credential_manager)
368
- shutdown_allowed.wait(timeout=90)
670
+ delete(session)
671
+ try:
672
+ websocket.wait(timeout=True)
673
+ except (WebsocketTimeout, SessionMismatchError) as e:
674
+ logger.error(str(e))
675
+ sys.exit(1)
676
+ except WebsocketTimeout as e:
677
+ logger.error(str(e))
678
+ delete(session)
679
+ sys.exit(1)
680
+ except SessionMismatchError as e:
681
+ # In this case we do not trigger the teardown since the running instance is not ours
682
+ logger.error(str(e))
683
+ sys.exit(1)
369
684
  except Exception as e:
370
685
  logger.exception(e)
371
- delete(credential_manager)
686
+ delete(session)
372
687
  sys.exit(1)
373
688
  else:
374
- delete(credential_manager)
689
+ delete(session)
375
690
  finally:
376
- sio.shutdown()
691
+ logger.debug("Shutting down websocket")
692
+ websocket.shutdown()
377
693
 
378
694
 
379
- def delete(credential_manager):
695
+ def delete(session):
380
696
  try:
381
697
  logger.info("Tearing down Locust cloud...")
382
- credential_manager.refresh_credentials()
383
- refreshed_credentials = credential_manager.get_current_credentials()
384
-
385
- headers = {
386
- "AWS_ACCESS_KEY_ID": refreshed_credentials.get("access_key", ""),
387
- "AWS_SECRET_ACCESS_KEY": refreshed_credentials.get("secret_key", ""),
388
- "Authorization": f"Bearer {refreshed_credentials.get('cognito_client_id_token', '')}",
389
- "X-Client-Version": __version__,
390
- }
391
-
392
- token = refreshed_credentials.get("token")
393
- if token:
394
- headers["AWS_SESSION_TOKEN"] = token
395
-
396
- response = requests.delete(
397
- f"{api_url}/teardown",
398
- headers=headers,
698
+ response = session.delete(
699
+ "/teardown",
399
700
  )
400
701
 
401
702
  if response.status_code == 200:
@@ -407,7 +708,7 @@ def delete(credential_manager):
407
708
  except Exception as e:
408
709
  logger.error(f"Could not automatically tear down Locust Cloud: {e.__class__.__name__}:{e}")
409
710
 
410
- logger.info("Done! ✨")
711
+ logger.info("Done! ✨") # FIXME: Should probably not say it's done since at this point it could still be running
411
712
 
412
713
 
413
714
  if __name__ == "__main__":