great-expectations-cloud 20240523.0.dev0__py3-none-any.whl → 20251124.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. great_expectations_cloud/agent/__init__.py +3 -0
  2. great_expectations_cloud/agent/actions/__init__.py +8 -5
  3. great_expectations_cloud/agent/actions/agent_action.py +21 -6
  4. great_expectations_cloud/agent/actions/draft_datasource_config_action.py +45 -24
  5. great_expectations_cloud/agent/actions/generate_data_quality_check_expectations_action.py +557 -0
  6. great_expectations_cloud/agent/actions/list_asset_names.py +65 -0
  7. great_expectations_cloud/agent/actions/run_checkpoint.py +74 -27
  8. great_expectations_cloud/agent/actions/run_metric_list_action.py +11 -5
  9. great_expectations_cloud/agent/actions/run_scheduled_checkpoint.py +67 -0
  10. great_expectations_cloud/agent/actions/run_window_checkpoint.py +66 -0
  11. great_expectations_cloud/agent/actions/utils.py +35 -0
  12. great_expectations_cloud/agent/agent.py +444 -101
  13. great_expectations_cloud/agent/cli.py +2 -2
  14. great_expectations_cloud/agent/config.py +19 -5
  15. great_expectations_cloud/agent/event_handler.py +49 -12
  16. great_expectations_cloud/agent/exceptions.py +9 -0
  17. great_expectations_cloud/agent/message_service/asyncio_rabbit_mq_client.py +80 -14
  18. great_expectations_cloud/agent/message_service/subscriber.py +8 -5
  19. great_expectations_cloud/agent/models.py +197 -20
  20. great_expectations_cloud/agent/utils.py +84 -0
  21. great_expectations_cloud/logging/logging_cfg.py +20 -4
  22. great_expectations_cloud/py.typed +0 -0
  23. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/METADATA +54 -46
  24. great_expectations_cloud-20251124.0.dev1.dist-info/RECORD +34 -0
  25. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/WHEEL +1 -1
  26. great_expectations_cloud/agent/actions/data_assistants/__init__.py +0 -8
  27. great_expectations_cloud/agent/actions/data_assistants/run_missingness_data_assistant.py +0 -45
  28. great_expectations_cloud/agent/actions/data_assistants/run_onboarding_data_assistant.py +0 -45
  29. great_expectations_cloud/agent/actions/data_assistants/utils.py +0 -123
  30. great_expectations_cloud/agent/actions/list_table_names.py +0 -76
  31. great_expectations_cloud-20240523.0.dev0.dist-info/RECORD +0 -32
  32. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/entry_points.txt +0 -0
  33. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info/licenses}/LICENSE +0 -0
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import asyncio
 import logging
+import os
+import signal
 import traceback
 import warnings
 from collections import defaultdict
@@ -9,24 +11,45 @@ from concurrent.futures import Future
 from concurrent.futures.thread import ThreadPoolExecutor
 from functools import partial
 from importlib.metadata import version as metadata_version
-from typing import TYPE_CHECKING, Any, Dict, Final
-
-from great_expectations import get_context  # type: ignore[attr-defined] # TODO: fix this
-from great_expectations.compatibility import pydantic
-from great_expectations.compatibility.pydantic import AmqpDsn, AnyUrl
+from typing import TYPE_CHECKING, Any, Callable, Final, Literal
+from urllib.parse import urljoin, urlparse
+from uuid import UUID
+
+import orjson
+import requests
+from great_expectations import __version__, get_context
+from great_expectations.core import http
 from great_expectations.core.http import create_session
 from great_expectations.data_context.cloud_constants import CLOUD_DEFAULT_BASE_URL
-from packaging.version import Version
-from pika.exceptions import AuthenticationError, ProbableAuthenticationError
-from tenacity import after_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
+from great_expectations.data_context.types.base import ProgressBarsConfig
+from pika.adapters.utils.connection_workflow import AMQPConnectorException
+from pika.exceptions import (
+    AMQPConnectionError,
+    AMQPError,
+    AuthenticationError,
+    ChannelError,
+    ProbableAuthenticationError,
+)
+from pydantic import v1 as pydantic_v1
+from pydantic.v1 import AmqpDsn, AnyUrl
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 from great_expectations_cloud.agent.config import (
     GxAgentEnvVars,
     generate_config_validation_error_text,
 )
 from great_expectations_cloud.agent.constants import USER_AGENT_HEADER, HeaderName
-from great_expectations_cloud.agent.event_handler import (
-    EventHandler,
+from great_expectations_cloud.agent.event_handler import EventHandler
+from great_expectations_cloud.agent.exceptions import (
+    GXAgentConfigError,
+    GXAgentError,
+    GXAgentUnrecoverableConnectionError,
 )
 from great_expectations_cloud.agent.message_service.asyncio_rabbit_mq_client import (
     AsyncRabbitMQClient,
@@ -39,16 +62,20 @@ from great_expectations_cloud.agent.message_service.subscriber import (
     SubscriberError,
 )
 from great_expectations_cloud.agent.models import (
-    AgentBaseModel,
+    AgentBaseExtraForbid,
+    CreateScheduledJobAndSetJobStarted,
+    CreateScheduledJobAndSetJobStartedRequest,
+    DomainContext,
     JobCompleted,
     JobStarted,
     JobStatus,
+    ScheduledEventBase,
     UnknownEvent,
+    UpdateJobStatusRequest,
     build_failed_job_completed_status,
 )
 
 if TYPE_CHECKING:
-    import requests
     from great_expectations.data_context import CloudDataContext
     from typing_extensions import Self
 
@@ -56,11 +83,11 @@ if TYPE_CHECKING:
 
 LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
 # TODO Set in log dict
-LOGGER.setLevel(logging.INFO)
-HandlerMap = Dict[str, OnMessageCallback]
+LOGGER.setLevel(logging.DEBUG)
+HandlerMap = dict[str, OnMessageCallback]
 
 
-class GXAgentConfig(AgentBaseModel):
+class GXAgentConfig(AgentBaseExtraForbid):
     """GXAgent configuration.
     Attributes:
         queue: name of queue
@@ -69,10 +96,33 @@ class GXAgentConfig(AgentBaseModel):
 
     queue: str
     connection_string: AmqpDsn
-    # pydantic will coerce this string to AnyUrl type
-    gx_cloud_base_url: AnyUrl = CLOUD_DEFAULT_BASE_URL
+    gx_cloud_base_url: AnyUrl = AnyUrl(url=CLOUD_DEFAULT_BASE_URL, scheme="https")
     gx_cloud_organization_id: str
     gx_cloud_access_token: str
+    enable_progress_bars: bool = True
+
+
+def orjson_dumps(v: Any, *, default: Callable[[Any], Any] | None) -> str:
+    # orjson.dumps returns bytes, to match standard json.dumps we need to decode
+    # Typing using example from https://github.com/ijl/orjson?tab=readme-ov-file#serialize
+    return orjson.dumps(
+        v,
+        default=default,
+    ).decode()
+
+
+def orjson_loads(v: bytes | bytearray | memoryview | str) -> Any:
+    # Typing using example from https://github.com/ijl/orjson?tab=readme-ov-file#deserialize
+    return orjson.loads(v)
+
+
+class Payload(AgentBaseExtraForbid):
+    data: dict[str, Any]
+
+    class Config:
+        extra = "forbid"
+        json_dumps = orjson_dumps
+        json_loads = orjson_loads
 
 
 class GXAgent:
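
Note on the `Payload` model added above: pydantic v1 lets a model swap its JSON codec via `Config.json_dumps` / `Config.json_loads`, which is how orjson gets wired in here. A minimal self-contained sketch of the same technique, using a plain `pydantic.v1.BaseModel` (since `AgentBaseExtraForbid` is internal to this package):

    from typing import Any, Callable

    import orjson
    from pydantic.v1 import BaseModel

    def orjson_dumps(v: Any, *, default: Callable[[Any], Any] | None) -> str:
        # orjson returns bytes; decode so the signature matches json.dumps
        return orjson.dumps(v, default=default).decode()

    class Payload(BaseModel):
        data: dict[str, Any]

        class Config:
            json_dumps = orjson_dumps  # used by model.json()
            json_loads = orjson.loads  # used by Model.parse_raw()

    p = Payload(data={"status": "ok"})
    assert p.json() == '{"data":{"status":"ok"}}'
    assert Payload.parse_raw(p.json()) == p
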
@@ -88,19 +138,17 @@ class GXAgent:
     _PYPI_GREAT_EXPECTATIONS_PACKAGE_NAME = "great_expectations"
 
     def __init__(self: Self):
-        agent_version: str = self.get_current_gx_agent_version()
-        print(f"GX Agent version: {agent_version}")
-        print("Initializing the GX Agent.")
-        self._set_http_session_headers()
-        self._config = self._get_config()
-        print("Loading a DataContext - this might take a moment.")
-
-        with warnings.catch_warnings():
-            # suppress warnings about GX version
-            warnings.filterwarnings("ignore", message="You are using great_expectations version")
-            self._context: CloudDataContext = get_context(cloud_mode=True)
+        self._config = self._create_config()
 
-        print("DataContext is ready.")
+        agent_version: str = self.get_current_gx_agent_version()
+        great_expectations_version: str = self._get_current_great_expectations_version()
+        LOGGER.info(
+            "Initializing GX Agent.",
+            extra={
+                "agent_version": agent_version,
+                "great_expectations_version": great_expectations_version,
+            },
+        )
 
         # Create a thread pool with a single worker, so we can run long-lived
         # GX processes and maintain our connection to the broker. Note that
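
Note on the `extra={...}` pattern introduced above (and used throughout this file): stdlib logging attaches those keys as attributes on the `LogRecord`, so they only appear in output if a formatter or JSON log handler references them. A small illustration:

    import logging

    logging.basicConfig(
        # attributes supplied via `extra` can be referenced by name in the format string
        format="%(levelname)s %(message)s agent_version=%(agent_version)s"
    )
    logging.getLogger("demo").warning(
        "Initializing GX Agent.", extra={"agent_version": "0.0.0"}
    )
    # -> WARNING Initializing GX Agent. agent_version=0.0.0
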
@@ -110,45 +158,63 @@ class GXAgent:
         self._current_task: Future[Any] | None = None
         self._redeliver_msg_task: asyncio.Task[Any] | None = None
         self._correlation_ids: defaultdict[str, int] = defaultdict(lambda: 0)
+        self._listen_tries = 0
 
     def run(self) -> None:
         """Open a connection to GX Cloud."""
 
-        print("Opening connection to GX Cloud.")
+        LOGGER.debug("Opening connection to GX Cloud.")
+        self._listen_tries = 0
         self._listen()
-        print("The connection to GX Cloud has been closed.")
+        LOGGER.debug("The connection to GX Cloud has been closed.")
 
     # ZEL-505: A race condition can occur if two or more agents are started at the same time
     # due to the generation of passwords for rabbitMQ queues. This can be mitigated
     # by adding a delay and retrying the connection. Retrying with new credentials
     # requires calling get_config again, which handles the password generation.
     @retry(
-        retry=retry_if_exception_type((AuthenticationError, ProbableAuthenticationError)),
-        wait=wait_exponential(multiplier=1, min=1, max=10),
+        retry=retry_if_exception_type(
+            (AuthenticationError, ProbableAuthenticationError, AMQPError, ChannelError)
+        ),
+        wait=wait_random_exponential(multiplier=1, min=1, max=10),
         stop=stop_after_attempt(3),
         after=after_log(LOGGER, logging.DEBUG),
     )
     def _listen(self) -> None:
         """Manage connection lifecycle."""
         subscriber = None
+        # force refresh if we're retrying
+        force_creds_refresh = self._listen_tries > 0
+        self._listen_tries += 1
+
+        config = self._get_config(force_refresh=force_creds_refresh)
+
         try:
-            client = AsyncRabbitMQClient(url=str(self._config.connection_string))
+            client = AsyncRabbitMQClient(url=str(config.connection_string))
             subscriber = Subscriber(client=client)
-            print("The GX Agent is ready.")
+            LOGGER.info("The GX Agent is ready.")
             # Open a connection until encountering a shutdown event
             subscriber.consume(
-                queue=self._config.queue,
+                queue=config.queue,
                 on_message=self._handle_event_as_thread_enter,
             )
         except KeyboardInterrupt:
-            print("Received request to shut down.")
+            LOGGER.debug("Received request to shut down.")
         except (SubscriberError, ClientError):
-            print("The connection to GX Cloud has encountered an error.")
-        except (AuthenticationError, ProbableAuthenticationError):
-            # Retry with new credentials
-            self._config = self._get_config()
+            LOGGER.exception("The connection to GX Cloud has encountered an error.")
+        except GXAgentUnrecoverableConnectionError:
+            LOGGER.exception("The connection to GX Cloud has encountered an unrecoverable error.")
+            os.kill(os.getpid(), signal.SIGTERM)
+        except (
+            AuthenticationError,
+            ProbableAuthenticationError,
+            AMQPConnectorException,
+            AMQPConnectionError,
+        ):
             # Raise to use the retry decorator to handle the retry logic
+            LOGGER.exception("Failed authentication to MQ.")
             raise
+
         finally:
             if subscriber is not None:
                 subscriber.close()
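
Note on the `@retry` decorator above: this is standard tenacity usage, and the switch to `wait_random_exponential` adds jitter so several agents racing on queue credentials (the ZEL-505 scenario) do not retry in lockstep. A reduced sketch of the same pattern, with a stand-in exception type:

    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_random_exponential,
    )

    class TransientBrokerError(Exception):
        """Stand-in for the pika auth/connection errors caught above."""

    attempts = 0

    @retry(
        retry=retry_if_exception_type(TransientBrokerError),
        wait=wait_random_exponential(multiplier=1, min=1, max=10),  # jittered 1-10s backoff
        stop=stop_after_attempt(3),  # then give up and re-raise
    )
    def connect() -> None:
        global attempts
        attempts += 1
        if attempts < 3:
            # re-raising inside the decorated function is what triggers the retry,
            # mirroring the bare `raise` in _listen above
            raise TransientBrokerError("broker unavailable")

    connect()
    assert attempts == 3
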
@@ -158,6 +224,11 @@
         version: str = metadata_version(cls._PYPI_GX_AGENT_PACKAGE_NAME)
         return version
 
+    @classmethod
+    def _get_current_great_expectations_version(cls) -> str:
+        version: str = metadata_version(cls._PYPI_GREAT_EXPECTATIONS_PACKAGE_NAME)
+        return version
+
     def _handle_event_as_thread_enter(self, event_context: EventContext) -> None:
         """Schedule _handle_event to run in a thread.
 
@@ -172,17 +243,28 @@
             event_context.processed_with_failures()
             return
         elif self._can_accept_new_task() is not True:
+            LOGGER.warning(
+                "Cannot accept new task, redelivering.",
+                extra={
+                    "event_type": event_context.event.type,
+                    "correlation_id": event_context.correlation_id,
+                    "organization_id": self.get_organization_id(event_context),
+                    "workspace_id": str(self.get_workspace_id(event_context)),
+                    "schedule_id": event_context.event.schedule_id
+                    if isinstance(event_context.event, ScheduledEventBase)
+                    else None,
+                },
+            )
             # request that this message is redelivered later
             loop = asyncio.get_event_loop()
             # store a reference to the task to ensure it isn't garbage collected
             self._redeliver_msg_task = loop.create_task(event_context.redeliver_message())
             return
 
-        # ensure that great_expectations.http requests to GX Cloud include the job_id/correlation_id
-        self._set_http_session_headers(correlation_id=event_context.correlation_id)
-
-        # send this message to a thread for processing
-        self._current_task = self._executor.submit(self._handle_event, event_context=event_context)
+        self._current_task = self._executor.submit(
+            self._handle_event,
+            event_context=event_context,
+        )
 
         if self._current_task is not None:
             # add a callback for when the thread exits and pass it the event context
@@ -191,6 +273,44 @@
         )
         self._current_task.add_done_callback(on_exit_callback)
 
+    def get_data_context(self, event_context: EventContext) -> CloudDataContext:
+        """Create a new CloudDataContext for each job using the event's workspace_id."""
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="You are using great_expectations version")
+            workspace_id = self.get_workspace_id(event_context)
+
+            LOGGER.debug("Loading a DataContext - this might take a moment.")
+
+            context: CloudDataContext = get_context(
+                cloud_mode=True,
+                user_agent_str=self.user_agent_str,
+                cloud_workspace_id=str(workspace_id),
+            )
+            self._configure_progress_bars(data_context=context)
+
+            LOGGER.debug("DataContext is ready.")
+
+            return context
+
+    def get_organization_id(self, event_context: EventContext) -> UUID:
+        """Helper method to get the organization ID. Overridden in GX-Runner."""
+        return UUID(self._get_config().gx_cloud_organization_id)
+
+    def get_auth_key(self) -> str:
+        """Helper method to get the auth key. Overridden in GX-Runner."""
+        return self._get_config().gx_cloud_access_token
+
+    def get_workspace_id(self, event_context: EventContext) -> UUID:
+        """Helper method to get the workspace ID from the event."""
+        workspace_id: UUID | None = getattr(event_context.event, "workspace_id", None)
+        if workspace_id is None:
+            raise GXAgentError()
+        return workspace_id
+
+    def _set_sentry_tags(self, event_context: EventContext) -> None:
+        """Used by GX-Runner to set tags for Sentry logging. No-op in the Agent."""
+        pass
+
     def _handle_event(self, event_context: EventContext) -> ActionResult:
         """Pass events to EventHandler.
 
@@ -201,18 +321,51 @@
             event_context: event with related properties and actions.
         """
         # warning: this method will not be executed in the main thread
-        self._update_status(job_id=event_context.correlation_id, status=JobStarted())
-        print(f"Starting job {event_context.event.type} ({event_context.correlation_id}) ")
+
+        data_context = self.get_data_context(event_context=event_context)
+        # ensure that great_expectations.http requests to GX Cloud include the job_id/correlation_id
+        self._set_http_session_headers(
+            correlation_id=event_context.correlation_id, data_context=data_context
+        )
+
+        org_id = self.get_organization_id(event_context)
+        workspace_id = self.get_workspace_id(event_context)
+        base_url = self._get_config().gx_cloud_base_url
+        auth_key = self.get_auth_key()
+
+        if isinstance(event_context.event, ScheduledEventBase):
+            self._create_scheduled_job_and_set_started(event_context, org_id, workspace_id)
+        else:
+            self._update_status(
+                correlation_id=event_context.correlation_id,
+                status=JobStarted(),
+                org_id=org_id,
+                workspace_id=workspace_id,
+            )
         LOGGER.info(
             "Starting job",
             extra={
                 "event_type": event_context.event.type,
                 "correlation_id": event_context.correlation_id,
+                "organization_id": str(org_id),
+                "workspace_id": str(workspace_id),
+                "schedule_id": event_context.event.schedule_id
+                if isinstance(event_context.event, ScheduledEventBase)
+                else None,
             },
         )
-        handler = EventHandler(context=self._context)
+
+        self._set_sentry_tags(event_context)
+
+        handler = EventHandler(context=data_context)
         # This method might raise an exception. Allow it and handle in _handle_event_as_thread_exit
-        result = handler.handle_event(event=event_context.event, id=event_context.correlation_id)
+        result = handler.handle_event(
+            event=event_context.event,
+            id=event_context.correlation_id,
+            base_url=base_url,
+            auth_key=auth_key,
+            domain_context=DomainContext(organization_id=org_id, workspace_id=workspace_id),
+        )
         return result
 
     def _handle_event_as_thread_exit(
@@ -226,6 +379,9 @@
         """
         # warning: this method will not be executed in the main thread
 
+        org_id = self.get_organization_id(event_context)
+        workspace_id = self.get_workspace_id(event_context)
+
         # get results or errors from the thread
         error = future.exception()
         if error is None:
@@ -236,24 +392,39 @@
                     success=False,
                     created_resources=[],
                     error_stack_trace="The version of the GX Agent you are using does not support this functionality. Please upgrade to the most recent image tagged with `stable`.",
+                    processed_by=self._get_processed_by(),
                 )
                 LOGGER.error(
                     "Job completed with error. Ensure agent is up-to-date.",
                     extra={
                         "event_type": event_context.event.type,
                        "id": event_context.correlation_id,
+                        "organization_id": str(org_id),
+                        "workspace_id": str(workspace_id),
+                        "schedule_id": event_context.event.schedule_id
+                        if isinstance(event_context.event, ScheduledEventBase)
+                        else None,
                     },
                 )
             else:
                 status = JobCompleted(
                     success=True,
                     created_resources=result.created_resources,
+                    processed_by=self._get_processed_by(),
                 )
                 LOGGER.info(
                     "Completed job",
                     extra={
                         "event_type": event_context.event.type,
                         "correlation_id": event_context.correlation_id,
+                        "job_duration": (
+                            result.job_duration.total_seconds() if result.job_duration else None
+                        ),
+                        "organization_id": str(org_id),
+                        "workspace_id": str(workspace_id),
+                        "schedule_id": event_context.event.schedule_id
+                        if isinstance(event_context.event, ScheduledEventBase)
+                        else None,
                     },
                 )
         else:
@@ -264,15 +435,44 @@
                 extra={
                     "event_type": event_context.event.type,
                     "correlation_id": event_context.correlation_id,
+                    "organization_id": str(org_id),
+                    "workspace_id": str(workspace_id),
                 },
             )
 
-        self._update_status(job_id=event_context.correlation_id, status=status)
+        try:
+            self._update_status(
+                correlation_id=event_context.correlation_id,
+                status=status,
+                org_id=org_id,
+                workspace_id=workspace_id,
+            )
+        except Exception:
+            LOGGER.exception(
+                "Error updating status, removing message from queue",
+                extra={
+                    "correlation_id": event_context.correlation_id,
+                    "status": str(status),
+                    "organization_id": str(org_id),
+                    "workspace_id": str(workspace_id),
+                },
+            )
+            # We do not want to cause an infinite loop of errors
+            # If the status update fails, remove the message from the queue
+            # Otherwise, it would attempt to handle the error again via this done callback
+            event_context.processed_with_failures()
+            self._current_task = None
+            # Return so we don't also ack as processed successfully
+            return
 
         # ack message and cleanup resources
         event_context.processed_successfully()
         self._current_task = None
 
+    def _get_processed_by(self) -> Literal["agent", "runner"]:
+        """Return the name of the service that processed the event."""
+        return "runner" if self._get_config().queue == "gx-runner" else "agent"
+
     def _can_accept_new_task(self) -> bool:
         """Are we currently processing a task or are we free to take a new one?"""
         return self._current_task is None or self._current_task.done()
@@ -292,28 +492,33 @@
         self._correlation_ids.clear()
         return should_reject
 
+    def _get_config(self, force_refresh: bool = False) -> GXAgentConfig:
+        if force_refresh:
+            self._config = self._create_config()
+        return self._config
+
     @classmethod
-    def _get_config(cls) -> GXAgentConfig:
+    def _create_config(cls) -> GXAgentConfig:
        """Construct GXAgentConfig."""
 
         # ensure we have all required env variables, and provide a useful error if not
 
         try:
             env_vars = GxAgentEnvVars()
-        except pydantic.ValidationError as validation_err:
+        except pydantic_v1.ValidationError as validation_err:
             raise GXAgentConfigError(
                 generate_config_validation_error_text(validation_err)
             ) from validation_err
 
         # obtain the broker url and queue name from Cloud
-        agent_sessions_url = (
-            f"{env_vars.gx_cloud_base_url}/organizations/"
-            f"{env_vars.gx_cloud_organization_id}/agent-sessions"
+        agent_sessions_url = urljoin(
+            env_vars.gx_cloud_base_url,
+            f"/api/v1/organizations/{env_vars.gx_cloud_organization_id}/agent-sessions",
         )
 
         session = create_session(access_token=env_vars.gx_cloud_access_token)
-
         response = session.post(agent_sessions_url)
+        session.close()
         if response.ok is not True:
             raise GXAgentError(  # noqa: TRY003 # TODO: use AuthenticationError
                 "Unable to authenticate to GX Cloud. Please check your credentials."
@@ -323,6 +528,19 @@
         queue = json_response["queue"]
         connection_string = json_response["connection_string"]
 
+        # if overrides are set, we update the connection string. This is useful for local development
+        # to set the host to localhost, for example.
+        parsed = urlparse(connection_string)
+        if env_vars.amqp_host_override:
+            netloc = (
+                f"{parsed.username}:{parsed.password}@{env_vars.amqp_host_override}:{parsed.port}"
+            )
+            parsed = parsed._replace(netloc=netloc)  # documented in urllib docs
+        if env_vars.amqp_port_override:
+            netloc = f"{parsed.username}:{parsed.password}@{parsed.hostname}:{env_vars.amqp_port_override}"
+            parsed = parsed._replace(netloc=netloc)  # documented in urllib docs
+        connection_string = parsed.geturl()
+
         try:
             # pydantic will coerce the url to the correct type
             return GXAgentConfig(
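
Two URL-handling details in `_create_config` above are worth spelling out. First, `urljoin` with a path that starts with `/` replaces any path already on the base URL instead of appending to it. Second, the AMQP host/port override rebuilds the connection string through `urlparse()._replace(netloc=...)` plus `geturl()`, which is the documented way to modify URL components. Both in isolation (example URLs and credentials are made up):

    from urllib.parse import urljoin, urlparse

    # urljoin: an absolute path replaces whatever path the base URL had
    base = "https://api.greatexpectations.io/some/mount"
    assert (
        urljoin(base, "/api/v1/organizations/abc/agent-sessions")
        == "https://api.greatexpectations.io/api/v1/organizations/abc/agent-sessions"
    )

    # urlparse + _replace: swap only the host, keeping credentials and port
    parsed = urlparse("amqp://user:secret@mq.example.com:5671/vhost")
    netloc = f"{parsed.username}:{parsed.password}@localhost:{parsed.port}"
    assert parsed._replace(netloc=netloc).geturl() == "amqp://user:secret@localhost:5671/vhost"
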
@@ -331,71 +549,192 @@
                 gx_cloud_base_url=env_vars.gx_cloud_base_url,
                 gx_cloud_organization_id=env_vars.gx_cloud_organization_id,
                 gx_cloud_access_token=env_vars.gx_cloud_access_token,
+                enable_progress_bars=env_vars.enable_progress_bars,
             )
-        except pydantic.ValidationError as validation_err:
+        except pydantic_v1.ValidationError as validation_err:
             raise GXAgentConfigError(
                 generate_config_validation_error_text(validation_err)
             ) from validation_err
 
-    def _update_status(self, job_id: str, status: JobStatus) -> None:
+    def _configure_progress_bars(self, data_context: CloudDataContext) -> None:
+        progress_bars_enabled = self._get_config().enable_progress_bars
+
+        try:
+            data_context.variables.progress_bars = ProgressBarsConfig(
+                globally=progress_bars_enabled,
+                metric_calculations=progress_bars_enabled,
+            )
+            data_context.variables.save()
+        except Exception:
+            # Progress bars are not critical, so log and continue
+            # This is a known issue with FastAPI mercury V1 API for data-context-variables
+            LOGGER.warning(
+                "Failed to {set} progress bars".format(
+                    set="enable" if progress_bars_enabled else "disable"
+                )
+            )
+
+    def _update_status(
+        self, correlation_id: str, status: JobStatus, org_id: UUID, workspace_id: UUID
+    ) -> None:
         """Update GX Cloud on the status of a job.
 
         Args:
-            job_id: job identifier, also known as correlation_id
-            status: pydantic model encapsulating the current status
+            correlation_id: job identifier
+            status: pydantic model encapsulating the current status.
         """
-        LOGGER.info("Updating status", extra={"job_id": job_id, "status": str(status)})
-        agent_sessions_url = (
-            f"{self._config.gx_cloud_base_url}/organizations/{self._config.gx_cloud_organization_id}"
-            + f"/agent-jobs/{job_id}"
+        LOGGER.info(
+            "Updating status",
+            extra={
+                "correlation_id": correlation_id,
+                "status": str(status),
+                "organization_id": str(org_id),
+                "workspace_id": str(workspace_id),
+            },
+        )
+        agent_sessions_url = urljoin(
+            self._get_config().gx_cloud_base_url,
+            f"/api/v1/organizations/{org_id}/workspaces/{workspace_id}/agent-jobs/{correlation_id}",
         )
-        session = create_session(access_token=self._config.gx_cloud_access_token)
-        data = status.json()
-        session.patch(agent_sessions_url, data=data)
+        with create_session(access_token=self.get_auth_key()) as session:
+            data = UpdateJobStatusRequest(data=status).json()
+            response = session.patch(agent_sessions_url, data=data)
+            LOGGER.info(
+                "Status updated",
+                extra={
+                    "correlation_id": correlation_id,
+                    "status": str(status),
+                    "organization_id": str(org_id),
+                    "workspace_id": str(workspace_id),
+                },
+            )
+        GXAgent._log_http_error(
+            response, message="Status Update action had an error while connecting to GX Cloud."
+        )
+
+    def _create_scheduled_job_and_set_started(
+        self, event_context: EventContext, org_id: UUID, workspace_id: UUID
+    ) -> None:
+        """Create a job in GX Cloud for scheduled events.
+
+        This is because the scheduler + lambda create the event in the queue, and the agent consumes it. The agent then
+        sends a request to the agent-jobs endpoint to create the job in mercury to keep track of the job status.
+        Non-scheduled events by contrast create the job in mercury and the event in the queue at the same time.
 
-    def _set_http_session_headers(self, correlation_id: str | None = None) -> None:
+        Args:
+            event_context: event with related properties and actions.
         """
-        Set the the session headers for requests to GX Cloud.
+        if not isinstance(event_context.event, ScheduledEventBase):
+            raise GXAgentError(  # noqa: TRY003
+                "Unable to create a scheduled job for a non-scheduled event."
+            )
+
+        LOGGER.info(
+            "Creating scheduled job and setting started",
+            extra={
+                "correlation_id": str(event_context.correlation_id),
+                "event_type": str(event_context.event.type),
+                "organization_id": str(org_id),
+                "workspace_id": str(workspace_id),
+                "schedule_id": str(event_context.event.schedule_id),
+            },
+        )
+
+        agent_sessions_url = urljoin(
+            self._get_config().gx_cloud_base_url,
+            f"/api/v1/organizations/{org_id}/workspaces/{workspace_id}/agent-jobs",
+        )
+        data = CreateScheduledJobAndSetJobStarted(
+            type="run_scheduled_checkpoint.received",
+            correlation_id=UUID(event_context.correlation_id),
+            schedule_id=event_context.event.schedule_id,
+            checkpoint_id=event_context.event.checkpoint_id,
+            datasource_names_to_asset_names=event_context.event.datasource_names_to_asset_names,
+            splitter_options=event_context.event.splitter_options,
+            checkpoint_name=event_context.event.checkpoint_name,
+        )
+        with create_session(access_token=self.get_auth_key()) as session:
+            payload = CreateScheduledJobAndSetJobStartedRequest(data=data).json()
+            response = session.post(agent_sessions_url, data=payload)
+            LOGGER.info(
+                "Created scheduled job and set started",
+                extra={
+                    "correlation_id": str(event_context.correlation_id),
+                    "event_type": str(event_context.event.type),
+                    "organization_id": str(org_id),
+                    "schedule_id": str(event_context.event.schedule_id),
+                    "workspace_id": str(workspace_id),
+                },
+            )
+        GXAgent._log_http_error(
+            response,
+            message="Create schedule job action had an error while connecting to GX Cloud.",
+        )
+
+    def get_header_name(self) -> type[HeaderName]:
+        return HeaderName
+
+    def get_user_agent_header(self) -> str:
+        return USER_AGENT_HEADER
+
+    def _get_version(self) -> str:
+        return self.get_current_gx_agent_version()
+
+    def _set_data_context_store_headers(
+        self, data_context: CloudDataContext, headers: dict[HeaderName, str]
+    ) -> None:
+        """
+        Sets headers on all stores in the data context.
+        """
+        from great_expectations.data_context.store.gx_cloud_store_backend import (  # noqa: PLC0415
+            GXCloudStoreBackend,
+        )
+
+        # OSS doesn't use the same session for all requests, so we need to set the header for each store
+        stores = list(data_context.stores.values())
+        # some stores are treated differently
+        stores.extend([data_context._datasource_store, data_context._data_asset_store])
+        for store in stores:
+            backend = store._store_backend
+            if isinstance(backend, GXCloudStoreBackend):
+                backend._session.headers.update({str(key): value for key, value in headers.items()})
+
+    @property
+    def user_agent_str(self) -> str:
+        user_agent_header_prefix = self.get_user_agent_header()
+        agent_version = self._get_version()
+        return f"{user_agent_header_prefix}/{agent_version}"
+
+    def _set_http_session_headers(
+        self, data_context: CloudDataContext, correlation_id: str | None = None
+    ) -> None:
+        """
+        Set the session headers for requests to GX Cloud.
         In particular, set the User-Agent header to identify the GX Agent and the correlation_id as
         Agent-Job-Id if provided.
 
         Note: the Agent-Job-Id header value will be set for all GX Cloud requests until this method is
         called again.
         """
-        from great_expectations import __version__  # type: ignore[attr-defined] # TODO: fix this
-        from great_expectations.core import http
-        from great_expectations.data_context.store.gx_cloud_store_backend import GXCloudStoreBackend
 
-        if Version(__version__) > Version(
-            "0.19"  # using 0.19 instead of 1.0 to account for pre-releases
-        ):
-            # TODO: public API should be available in v1
-            LOGGER.info(
-                "Unable to set header for requests to GX Cloud",
-                extra={
-                    "user_agent": HeaderName.USER_AGENT,
-                    "agent_job_id": HeaderName.AGENT_JOB_ID,
-                },
-            )
-            return
+        header_name = self.get_header_name()
+        user_agent_header_value = self.user_agent_str
 
-        agent_version = self.get_current_gx_agent_version()
         LOGGER.debug(
             "Setting session headers for GX Cloud",
             extra={
-                "user_agent": HeaderName.USER_AGENT,
-                "agent_version": agent_version,
-                "job_id": HeaderName.AGENT_JOB_ID,
+                "user_agent_header_name": header_name.USER_AGENT,
+                "user_agent_header_value": user_agent_header_value,
+                "correlation_id_header_name": header_name.AGENT_JOB_ID,
+                "correlation_id_header_value": correlation_id,
                 "correlation_id": correlation_id,
             },
         )
 
+        core_headers = {header_name.USER_AGENT: user_agent_header_value}
         if correlation_id:
-            # OSS doesn't use the same session for all requests, so we need to set the header for each store
-            for store in self._context.stores.values():
-                backend = store._store_backend
-                if isinstance(backend, GXCloudStoreBackend):
-                    backend._session.headers[HeaderName.AGENT_JOB_ID] = correlation_id
+            core_headers.update({header_name.AGENT_JOB_ID: correlation_id})
+        self._set_data_context_store_headers(data_context=data_context, headers=core_headers)
 
     def _update_headers_agent_patch(
         session: requests.Session, access_token: str
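
Net effect of the header plumbing above: every GX Cloud request carries a `User-Agent` of the form `<USER_AGENT_HEADER>/<agent version>`, plus an `Agent-Job-Id` while a job is in flight. Roughly, with illustrative literal header names and prefix standing in for the real `HeaderName` enum and `USER_AGENT_HEADER` constant:

    def build_core_headers(agent_version: str, correlation_id: str | None) -> dict[str, str]:
        # "User-Agent" / "Agent-Job-Id" are stand-ins for HeaderName.USER_AGENT / HeaderName.AGENT_JOB_ID,
        # and "gx-agent" for USER_AGENT_HEADER
        headers = {"User-Agent": f"gx-agent/{agent_version}"}
        if correlation_id:
            headers["Agent-Job-Id"] = correlation_id
        return headers

    assert build_core_headers("1.2.3", "job-123") == {
        "User-Agent": "gx-agent/1.2.3",
        "Agent-Job-Id": "job-123",
    }
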
@@ -407,10 +746,10 @@ class GXAgent:
                 "Content-Type": "application/vnd.api+json",
                 "Authorization": f"Bearer {access_token}",
                 "Gx-Version": __version__,
-                HeaderName.USER_AGENT: f"{USER_AGENT_HEADER}/{agent_version}",
+                header_name.USER_AGENT: user_agent_header_value,
             }
             if correlation_id:
-                headers[HeaderName.AGENT_JOB_ID] = correlation_id
+                headers[header_name.AGENT_JOB_ID] = correlation_id
             session.headers.update(headers)
             return session
 
@@ -418,8 +757,12 @@
         # use a public API once it is available
         http._update_headers = _update_headers_agent_patch
 
-
-class GXAgentError(Exception): ...
-
-
-class GXAgentConfigError(GXAgentError): ...
+    @staticmethod
+    def _log_http_error(response: requests.Response, message: str) -> None:
+        """
+        Log the http error if the response is not successful.
+        """
+        try:
+            response.raise_for_status()
+        except requests.HTTPError:
+            LOGGER.exception(message, extra={"response": response})
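
The new `_log_http_error` helper is the usual requests log-without-raising idiom: `raise_for_status()` is a no-op for 2xx responses and raises `requests.HTTPError` for 4xx/5xx, so a failed status callback gets recorded without crashing the worker. Demonstrated offline with a hand-built response object:

    import requests

    response = requests.Response()
    response.status_code = 503  # simulate a failed GX Cloud call
    try:
        response.raise_for_status()
    except requests.HTTPError as err:
        print(f"logged rather than raised: {err}")
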