great-expectations-cloud 20240523.0.dev0__py3-none-any.whl → 20251124.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. great_expectations_cloud/agent/__init__.py +3 -0
  2. great_expectations_cloud/agent/actions/__init__.py +8 -5
  3. great_expectations_cloud/agent/actions/agent_action.py +21 -6
  4. great_expectations_cloud/agent/actions/draft_datasource_config_action.py +45 -24
  5. great_expectations_cloud/agent/actions/generate_data_quality_check_expectations_action.py +557 -0
  6. great_expectations_cloud/agent/actions/list_asset_names.py +65 -0
  7. great_expectations_cloud/agent/actions/run_checkpoint.py +74 -27
  8. great_expectations_cloud/agent/actions/run_metric_list_action.py +11 -5
  9. great_expectations_cloud/agent/actions/run_scheduled_checkpoint.py +67 -0
  10. great_expectations_cloud/agent/actions/run_window_checkpoint.py +66 -0
  11. great_expectations_cloud/agent/actions/utils.py +35 -0
  12. great_expectations_cloud/agent/agent.py +444 -101
  13. great_expectations_cloud/agent/cli.py +2 -2
  14. great_expectations_cloud/agent/config.py +19 -5
  15. great_expectations_cloud/agent/event_handler.py +49 -12
  16. great_expectations_cloud/agent/exceptions.py +9 -0
  17. great_expectations_cloud/agent/message_service/asyncio_rabbit_mq_client.py +80 -14
  18. great_expectations_cloud/agent/message_service/subscriber.py +8 -5
  19. great_expectations_cloud/agent/models.py +197 -20
  20. great_expectations_cloud/agent/utils.py +84 -0
  21. great_expectations_cloud/logging/logging_cfg.py +20 -4
  22. great_expectations_cloud/py.typed +0 -0
  23. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/METADATA +54 -46
  24. great_expectations_cloud-20251124.0.dev1.dist-info/RECORD +34 -0
  25. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/WHEEL +1 -1
  26. great_expectations_cloud/agent/actions/data_assistants/__init__.py +0 -8
  27. great_expectations_cloud/agent/actions/data_assistants/run_missingness_data_assistant.py +0 -45
  28. great_expectations_cloud/agent/actions/data_assistants/run_onboarding_data_assistant.py +0 -45
  29. great_expectations_cloud/agent/actions/data_assistants/utils.py +0 -123
  30. great_expectations_cloud/agent/actions/list_table_names.py +0 -76
  31. great_expectations_cloud-20240523.0.dev0.dist-info/RECORD +0 -32
  32. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/entry_points.txt +0 -0
  33. {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info/licenses}/LICENSE +0 -0
@@ -91,12 +91,12 @@ def main() -> None:
91
91
  )
92
92
 
93
93
  if args.version:
94
- from great_expectations_cloud.agent import get_version
94
+ from great_expectations_cloud.agent import get_version # noqa: PLC0415
95
95
 
96
96
  print(f"GX Agent version: {get_version()}")
97
97
  return
98
98
 
99
- from great_expectations_cloud.agent import run_agent
99
+ from great_expectations_cloud.agent import run_agent # noqa: PLC0415
100
100
 
101
101
  run_agent()
102
102
 
@@ -1,19 +1,33 @@
1
1
  from __future__ import annotations
2
2
 
3
- from great_expectations.compatibility.pydantic import AnyUrl, BaseSettings, ValidationError
3
+ from typing import Optional
4
+
4
5
  from great_expectations.data_context.cloud_constants import CLOUD_DEFAULT_BASE_URL
6
+ from pydantic.v1 import AnyUrl, BaseSettings, ValidationError
5
7
 
6
8
 
7
- class GxAgentEnvVars(BaseSettings): # type: ignore[misc] # BaseSettings is has Any type
8
- # pydantic will coerce this string to AnyUrl type
9
- gx_cloud_base_url: AnyUrl = CLOUD_DEFAULT_BASE_URL
9
+ class GxAgentEnvVars(BaseSettings):
10
+ gx_cloud_base_url: AnyUrl = AnyUrl(url=CLOUD_DEFAULT_BASE_URL, scheme="https")
10
11
  gx_cloud_organization_id: str
11
12
  gx_cloud_access_token: str
13
+ enable_progress_bars: bool = True
14
+
15
+ amqp_host_override: Optional[str] = None # noqa: UP045 # pipe not working with 3.9
16
+ amqp_port_override: Optional[int] = None # noqa: UP045 # pipe not working with 3.9
17
+
18
+ def __init__(self, **overrides: str | AnyUrl) -> None:
19
+ """
20
+ Custom __init__ to prevent type error when relying on environment variables.
21
+
22
+ TODO:once mypy fully support annoting **kwargs with a Unpack[TypedDict], we should do that.
23
+ https://peps.python.org/pep-0692/
24
+ """
25
+ super().__init__(**overrides)
12
26
 
13
27
 
14
28
  def generate_config_validation_error_text(validation_error: ValidationError) -> str:
15
29
  missing_variables = ", ".join(
16
- [validation_error["loc"][0] for validation_error in validation_error.errors()]
30
+ [str(validation_error["loc"][0]) for validation_error in validation_error.errors()]
17
31
  )
18
32
  error_text = f"Missing or badly formed environment variable(s). Make sure to set the following environment variable(s): {missing_variables}"
19
33
  return error_text
@@ -2,18 +2,24 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from collections import defaultdict
5
+ from datetime import datetime, timezone
5
6
  from json import JSONDecodeError
6
7
  from typing import TYPE_CHECKING, Any, Final
8
+ from uuid import UUID
7
9
 
8
10
  import great_expectations as gx
9
- from great_expectations.compatibility import pydantic
10
11
  from packaging.version import Version
11
12
  from packaging.version import parse as parse_version
13
+ from pydantic import v1 as pydantic_v1
12
14
 
13
15
  from great_expectations_cloud.agent.actions.unknown import UnknownEventAction
16
+ from great_expectations_cloud.agent.exceptions import GXAgentError
14
17
  from great_expectations_cloud.agent.models import (
18
+ DomainContext,
15
19
  Event,
20
+ EventType,
16
21
  UnknownEvent,
22
+ get_event_union,
17
23
  )
18
24
 
19
25
  if TYPE_CHECKING:
@@ -41,7 +47,7 @@ _EVENT_ACTION_MAP: dict[str, dict[str, type[AgentAction[Any]]]] = defaultdict(di
41
47
 
42
48
 
43
49
  def register_event_action(
44
- version: str, event_type: type[Event], action_class: type[AgentAction[Any]]
50
+ version: str, event_type: EventType, action_class: type[AgentAction[Any]]
45
51
  ) -> None:
46
52
  """Register an event type to an action class."""
47
53
  if version in _EVENT_ACTION_MAP and event_type.__name__ in _EVENT_ACTION_MAP[version]:
@@ -61,34 +67,65 @@ class EventHandler:
61
67
  def __init__(self, context: CloudDataContext) -> None:
62
68
  self._context = context
63
69
 
64
- def get_event_action(self, event: Event) -> AgentAction[Any]:
70
+ def get_event_action(
71
+ self, event: Event, base_url: str, auth_key: str, domain_context: DomainContext
72
+ ) -> AgentAction[Any]:
65
73
  """Get the action that should be run for the given event."""
74
+
75
+ if not self._check_event_organization_id(event, domain_context.organization_id):
76
+ # Making message more generic
77
+ raise GXAgentError("Unable to process job. Invalid input.") # noqa: TRY003
78
+
66
79
  action_map = _EVENT_ACTION_MAP.get(_GX_MAJOR_VERSION)
67
80
  if action_map is None:
68
81
  raise NoVersionImplementationError(version=_GX_MAJOR_VERSION)
69
82
  action_class = action_map.get(_get_event_name(event))
70
83
  if action_class is None:
71
84
  action_class = UnknownEventAction
72
- return action_class(context=self._context)
73
-
74
- def handle_event(self, event: Event, id: str) -> ActionResult:
85
+ return action_class(
86
+ context=self._context,
87
+ base_url=base_url,
88
+ domain_context=domain_context,
89
+ auth_key=auth_key,
90
+ )
91
+
92
+ def handle_event(
93
+ self, event: Event, id: str, base_url: str, auth_key: str, domain_context: DomainContext
94
+ ) -> ActionResult:
75
95
  """Transform an Event into an ActionResult."""
76
- action = self.get_event_action(event=event)
96
+ start_time = datetime.now(tz=timezone.utc)
97
+ action = self.get_event_action(
98
+ event=event, base_url=base_url, auth_key=auth_key, domain_context=domain_context
99
+ )
77
100
  LOGGER.info(f"Handling event: {event.type} -> {action.__class__.__name__}")
78
101
  action_result = action.run(event=event, id=id)
102
+ end_time = datetime.now(tz=timezone.utc)
103
+ action_result.job_duration = end_time - start_time
79
104
  return action_result
80
105
 
81
106
  @classmethod
82
107
  def parse_event_from(cls, msg_body: bytes) -> Event:
108
+ event_union = get_event_union()
83
109
  try:
84
- event: Event = pydantic.parse_raw_as(Event, msg_body)
85
- except (pydantic.ValidationError, JSONDecodeError):
110
+ event: Event = pydantic_v1.parse_raw_as(event_union, msg_body)
111
+ except (pydantic_v1.ValidationError, JSONDecodeError):
86
112
  # Log as bytes
87
113
  LOGGER.exception("Unable to parse event type", extra={"msg_body": f"{msg_body!r}"})
88
114
  return UnknownEvent()
89
115
 
90
116
  return event
91
117
 
118
+ @staticmethod
119
+ def _check_event_organization_id(event: Event, organization_id: UUID) -> bool:
120
+ """Check if the organization_id in the event matches the given organization_id.
121
+
122
+ This prevents processing events that are not intended for the current organization, and potentially
123
+ leaking sensitive information across organizations.
124
+ """
125
+ if hasattr(event, "organization_id") and event.organization_id != organization_id:
126
+ return False
127
+ return True
128
+
92
129
 
93
130
  class EventError(Exception): ...
94
131
 
@@ -121,12 +158,12 @@ def _get_major_version(version: str) -> str:
121
158
  return str(parsed.major)
122
159
 
123
160
 
124
- version = gx.__version__ # type: ignore[attr-defined] # TODO: fix this
161
+ version = gx.__version__
125
162
  _GX_MAJOR_VERSION = _get_major_version(str(version))
126
163
 
127
164
 
128
165
  def _get_event_name(event: Event) -> str:
129
166
  try:
130
- return str(event.__name__)
167
+ return str(event.__name__) # type: ignore[union-attr] # FIXME
131
168
  except AttributeError:
132
- return str(event.__class__.__name__)
169
+ return event.__class__.__name__
@@ -32,3 +32,12 @@ class ErrorCode(str, enum.Enum):
32
32
 
33
33
  GENERIC_UNHANDLED_ERROR = "generic-unhandled-error"
34
34
  WRONG_USERNAME_OR_PASSWORD = "wrong-username-or-password" # noqa: S105 # Not a hardcoded password
35
+
36
+
37
+ class GXAgentError(Exception): ...
38
+
39
+
40
+ class GXAgentConfigError(GXAgentError): ...
41
+
42
+
43
+ class GXAgentUnrecoverableConnectionError(GXAgentError): ...
@@ -1,19 +1,26 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import logging
4
5
  import ssl
5
6
  from asyncio import AbstractEventLoop
6
7
  from dataclasses import dataclass
7
8
  from functools import partial
8
- from typing import TYPE_CHECKING, Callable, Protocol
9
+ from typing import TYPE_CHECKING, Callable, Final, Protocol
9
10
 
10
11
  import pika
11
12
  from pika.adapters.asyncio_connection import AsyncioConnection
13
+ from pika.exceptions import ChannelClosed, ConnectionClosed
14
+
15
+ from great_expectations_cloud.agent.exceptions import GXAgentUnrecoverableConnectionError
12
16
 
13
17
  if TYPE_CHECKING:
14
18
  from pika.channel import Channel
15
19
  from pika.spec import Basic, BasicProperties
16
20
 
21
+ LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
22
+ LOGGER.setLevel(logging.DEBUG)
23
+
17
24
 
18
25
  @dataclass(frozen=True)
19
26
  class OnMessagePayload:
@@ -43,6 +50,7 @@ class AsyncRabbitMQClient:
43
50
  self._closing = False
44
51
  self._consumer_tag = None
45
52
  self._consuming = False
53
+ self._is_unrecoverable = False
46
54
 
47
55
  def run(self, queue: str, on_message: OnMessageFn) -> None:
48
56
  """Run an async connection to RabbitMQ.
@@ -67,21 +75,27 @@ class AsyncRabbitMQClient:
67
75
  )
68
76
  self._connection = connection
69
77
  connection.ioloop.run_forever()
78
+ if self._is_unrecoverable:
79
+ raise GXAgentUnrecoverableConnectionError( # noqa: TRY003
80
+ "AsyncRabbitMQClient has encountered an unrecoverable error."
81
+ )
70
82
 
71
83
  def stop(self) -> None:
72
84
  """Close the connection to RabbitMQ."""
73
85
  if self._connection is None:
74
86
  return
87
+ LOGGER.debug("Shutting down the connection to RabbitMQ.")
75
88
  if not self._closing:
76
89
  self._closing = True
77
- if self._consuming:
78
- self._stop_consuming()
79
- self._connection.ioloop.run_forever()
80
- else:
81
- self._connection.ioloop.stop()
90
+
91
+ if self._consuming:
92
+ self._stop_consuming()
93
+ self._connection.ioloop.stop()
94
+ LOGGER.debug("The connection to RabbitMQ has been shut down.")
82
95
 
83
96
  def reset(self) -> None:
84
97
  """Reset client to allow a restart."""
98
+ LOGGER.debug("Resetting client")
85
99
  self.should_reconnect = False
86
100
  self.was_consuming = False
87
101
 
@@ -148,7 +162,7 @@ class AsyncRabbitMQClient:
148
162
  nack = partial(channel.basic_nack, delivery_tag=delivery_tag, requeue=requeue)
149
163
  loop.call_soon_threadsafe(callback=nack)
150
164
 
151
- def _callback_handler( # noqa: PLR0913
165
+ def _callback_handler(
152
166
  self,
153
167
  channel: Channel,
154
168
  method_frame: Basic.Deliver,
@@ -167,66 +181,118 @@ class AsyncRabbitMQClient:
167
181
 
168
182
  def _start_consuming(self, queue: str, on_message: OnMessageFn, channel: Channel) -> None:
169
183
  """Consume from a channel with the on_message callback."""
184
+ LOGGER.debug("Issuing consumer-related RPC commands")
170
185
  channel.add_on_cancel_callback(self._on_consumer_canceled)
186
+ # set RabbitMQ prefetch count to equal the max_threads value in the GX Agent's ThreadPoolExecutor
187
+ channel.basic_qos(prefetch_count=1)
171
188
  self._consumer_tag = channel.basic_consume(queue=queue, on_message_callback=on_message)
172
189
 
173
190
  def _on_consumer_canceled(self, method_frame: Basic.Cancel) -> None:
174
191
  """Callback invoked when the broker cancels the client's connection."""
175
192
  if self._channel is not None:
193
+ LOGGER.info(
194
+ "Consumer was cancelled remotely, shutting down",
195
+ extra={
196
+ "method_frame": method_frame,
197
+ },
198
+ )
176
199
  self._channel.close()
177
200
 
178
201
  def _reconnect(self) -> None:
179
202
  """Prepare the client to reconnect."""
203
+ LOGGER.debug("Preparing client to reconnect")
180
204
  self.should_reconnect = True
181
205
  self.stop()
182
206
 
183
207
  def _stop_consuming(self) -> None:
184
208
  """Cancel the channel, if it exists."""
185
209
  if self._channel is not None:
210
+ LOGGER.debug("Sending a Basic.Cancel RPC command to RabbitMQ")
186
211
  self._channel.basic_cancel(self._consumer_tag, callback=self._on_cancel_ok)
187
212
 
188
- def _on_cancel_ok(self, method_frame: Basic.CancelOk) -> None:
213
+ def _on_cancel_ok(self, _unused_frame: Basic.CancelOk) -> None:
189
214
  """Callback invoked after broker confirms cancel."""
190
215
  self._consuming = False
191
216
  if self._channel is not None:
217
+ LOGGER.debug("RabbitMQ acknowledged the cancellation of the consumer")
192
218
  self._channel.close()
193
219
 
194
220
  def _on_connection_open(
195
221
  self, connection: AsyncioConnection, queue: str, on_message: OnMessageFn
196
222
  ) -> None:
197
223
  """Callback invoked after the broker opens the connection."""
224
+ LOGGER.debug("Connection to RabbitMQ has been opened")
198
225
  on_channel_open = partial(self._on_channel_open, queue=queue, on_message=on_message)
199
226
  connection.channel(on_open_callback=on_channel_open)
200
227
 
201
- def _on_connection_open_error(self, connection: AsyncioConnection, reason: str) -> None:
228
+ def _on_connection_open_error(
229
+ self, _unused_connection: AsyncioConnection, reason: pika.Exception
230
+ ) -> None:
202
231
  """Callback invoked when there is an error while opening connection."""
203
232
  self._reconnect()
233
+ self._log_pika_exception("Connection open failed", reason)
204
234
 
205
- def _on_connection_closed(self, connection: AsyncioConnection, reason: str) -> None:
235
+ def _on_connection_closed(
236
+ self, connection: AsyncioConnection, _unused_reason: pika.Exception
237
+ ) -> None:
206
238
  """Callback invoked after the broker closes the connection"""
239
+ LOGGER.debug("Connection to RabbitMQ has been closed")
207
240
  self._channel = None
241
+ self._is_unrecoverable = True
208
242
  if self._closing:
209
243
  connection.ioloop.stop()
210
244
  else:
211
245
  self._reconnect()
212
246
 
213
- def _close_connection(self) -> None:
247
+ def _close_connection(self, reason: pika.Exception) -> None:
214
248
  """Close the connection to the broker."""
215
249
  self._consuming = False
216
250
  if self._connection is None or self._connection.is_closing or self._connection.is_closed:
251
+ LOGGER.debug("Connection to RabbitMQ is closing or is already closed")
217
252
  pass
218
253
  else:
219
- self._connection.close()
254
+ LOGGER.debug("Closing connection to RabbitMQ")
255
+
256
+ if isinstance(reason, (ConnectionClosed, ChannelClosed)):
257
+ reply_code = reason.reply_code
258
+ reply_text = reason.reply_text
259
+ else:
260
+ reply_code = 999 # arbitrary value, not in the list of AMQP reply codes: https://www.rabbitmq.com/amqp-0-9-1-reference#constants
261
+ reply_text = str(reason)
262
+ self._connection.close(reply_code=reply_code, reply_text=reply_text)
263
+
264
+ def _log_pika_exception(
265
+ self, message: str, reason: pika.Exception, extra: dict[str, str] | None = None
266
+ ) -> None:
267
+ """Log a pika exception. Extra is key-value pairs to include in the log message."""
268
+ if not extra:
269
+ extra = {}
270
+ if isinstance(reason, (ConnectionClosed, ChannelClosed)):
271
+ default_extra: dict[str, str] = {
272
+ "reply_code": str(reason.reply_code),
273
+ "reply_text": str(reason.reply_text),
274
+ }
275
+ LOGGER.error(
276
+ message,
277
+ # mypy not happy with dict | dict, so we use the dict constructor
278
+ extra={**default_extra, **extra},
279
+ )
280
+ else:
281
+ default_extra = {"reason": str(reason)}
282
+ # mypy not happy with dict | dict, so we use the dict constructor
283
+ LOGGER.error(message, extra={**default_extra, **extra})
220
284
 
221
285
  def _on_channel_open(self, channel: Channel, queue: str, on_message: OnMessageFn) -> None:
222
286
  """Callback invoked after the broker opens the channel."""
287
+ LOGGER.debug("Channel opened")
223
288
  self._channel = channel
224
289
  channel.add_on_close_callback(self._on_channel_closed)
225
290
  self._start_consuming(queue=queue, on_message=on_message, channel=channel)
226
291
 
227
- def _on_channel_closed(self, channel: Channel, reason: str) -> None:
292
+ def _on_channel_closed(self, channel: Channel, reason: ChannelClosed) -> None:
228
293
  """Callback invoked after the broker closes the channel."""
229
- self._close_connection()
294
+ self._log_pika_exception("Channel closed", reason, extra={"channel": channel})
295
+ self._close_connection(reason)
230
296
 
231
297
  def _build_client_parameters(self, url: str) -> pika.URLParameters:
232
298
  """Configure parameters used to connect to the broker."""
@@ -2,9 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import time
5
+ from collections.abc import Coroutine
5
6
  from dataclasses import dataclass
6
7
  from functools import partial
7
- from typing import TYPE_CHECKING, Callable, Coroutine, Protocol
8
+ from typing import TYPE_CHECKING, Callable, Protocol
8
9
 
9
10
  from pika.exceptions import (
10
11
  AMQPError,
@@ -13,6 +14,7 @@ from pika.exceptions import (
13
14
  )
14
15
 
15
16
  from great_expectations_cloud.agent.event_handler import EventHandler
17
+ from great_expectations_cloud.agent.exceptions import GXAgentUnrecoverableConnectionError
16
18
 
17
19
  if TYPE_CHECKING:
18
20
  from great_expectations_cloud.agent.message_service.asyncio_rabbit_mq_client import (
@@ -95,11 +97,13 @@ class Subscriber:
95
97
  self.client.stop()
96
98
  reconnect_delay = self._get_reconnect_delay()
97
99
  time.sleep(reconnect_delay) # todo: update this blocking call to asyncio.sleep
100
+ raise
98
101
  except KeyboardInterrupt as e:
99
102
  self.client.stop()
100
103
  raise KeyboardInterrupt from e
101
- if self.client.should_reconnect:
102
- self.client.reset()
104
+ except GXAgentUnrecoverableConnectionError:
105
+ self.client.stop()
106
+ raise
103
107
  else:
104
108
  break # exit
105
109
 
@@ -159,8 +163,7 @@ class Subscriber:
159
163
  self._reconnect_delay = 0
160
164
  else:
161
165
  self._reconnect_delay += 1
162
- if self._reconnect_delay > 30: # noqa: PLR2004
163
- self._reconnect_delay = 30
166
+ self._reconnect_delay = min(self._reconnect_delay, 30)
164
167
  return self._reconnect_delay
165
168
 
166
169
  def close(self) -> None: