great-expectations-cloud 20240523.0.dev0__py3-none-any.whl → 20251124.0.dev1__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- great_expectations_cloud/agent/__init__.py +3 -0
- great_expectations_cloud/agent/actions/__init__.py +8 -5
- great_expectations_cloud/agent/actions/agent_action.py +21 -6
- great_expectations_cloud/agent/actions/draft_datasource_config_action.py +45 -24
- great_expectations_cloud/agent/actions/generate_data_quality_check_expectations_action.py +557 -0
- great_expectations_cloud/agent/actions/list_asset_names.py +65 -0
- great_expectations_cloud/agent/actions/run_checkpoint.py +74 -27
- great_expectations_cloud/agent/actions/run_metric_list_action.py +11 -5
- great_expectations_cloud/agent/actions/run_scheduled_checkpoint.py +67 -0
- great_expectations_cloud/agent/actions/run_window_checkpoint.py +66 -0
- great_expectations_cloud/agent/actions/utils.py +35 -0
- great_expectations_cloud/agent/agent.py +444 -101
- great_expectations_cloud/agent/cli.py +2 -2
- great_expectations_cloud/agent/config.py +19 -5
- great_expectations_cloud/agent/event_handler.py +49 -12
- great_expectations_cloud/agent/exceptions.py +9 -0
- great_expectations_cloud/agent/message_service/asyncio_rabbit_mq_client.py +80 -14
- great_expectations_cloud/agent/message_service/subscriber.py +8 -5
- great_expectations_cloud/agent/models.py +197 -20
- great_expectations_cloud/agent/utils.py +84 -0
- great_expectations_cloud/logging/logging_cfg.py +20 -4
- great_expectations_cloud/py.typed +0 -0
- {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/METADATA +54 -46
- great_expectations_cloud-20251124.0.dev1.dist-info/RECORD +34 -0
- {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/WHEEL +1 -1
- great_expectations_cloud/agent/actions/data_assistants/__init__.py +0 -8
- great_expectations_cloud/agent/actions/data_assistants/run_missingness_data_assistant.py +0 -45
- great_expectations_cloud/agent/actions/data_assistants/run_onboarding_data_assistant.py +0 -45
- great_expectations_cloud/agent/actions/data_assistants/utils.py +0 -123
- great_expectations_cloud/agent/actions/list_table_names.py +0 -76
- great_expectations_cloud-20240523.0.dev0.dist-info/RECORD +0 -32
- {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info}/entry_points.txt +0 -0
- {great_expectations_cloud-20240523.0.dev0.dist-info → great_expectations_cloud-20251124.0.dev1.dist-info/licenses}/LICENSE +0 -0
great_expectations_cloud/agent/cli.py

@@ -91,12 +91,12 @@ def main() -> None:
     )
 
     if args.version:
-        from great_expectations_cloud.agent import get_version
+        from great_expectations_cloud.agent import get_version  # noqa: PLC0415
 
         print(f"GX Agent version: {get_version()}")
         return
 
-    from great_expectations_cloud.agent import run_agent
+    from great_expectations_cloud.agent import run_agent  # noqa: PLC0415
 
     run_agent()
 
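The cli.py change only adds # noqa: PLC0415 suppressions (ruff's rule against imports outside the top of a module) to the existing deferred imports, which keep --version from importing the full agent package. A minimal standalone sketch of the same pattern; the argparse wiring and distribution name below are illustrative assumptions, not the package's code:

from __future__ import annotations

import argparse


def main() -> None:
    parser = argparse.ArgumentParser(prog="gx-agent")
    parser.add_argument("--version", action="store_true")
    args = parser.parse_args()

    if args.version:
        # Deferred import keeps --version fast; PLC0415 flags imports that are
        # not at the top of the module, hence the suppressions in the real code.
        from importlib.metadata import version  # noqa: PLC0415

        print(f"GX Agent version: {version('great-expectations-cloud')}")
        return

    # Only the start-up path pays for the heavier agent import.
    print("starting agent ...")


if __name__ == "__main__":
    main()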
great_expectations_cloud/agent/config.py

@@ -1,19 +1,33 @@
 from __future__ import annotations
 
-from
+from typing import Optional
+
 from great_expectations.data_context.cloud_constants import CLOUD_DEFAULT_BASE_URL
+from pydantic.v1 import AnyUrl, BaseSettings, ValidationError
 
 
-class GxAgentEnvVars(BaseSettings):
-
-    gx_cloud_base_url: AnyUrl = CLOUD_DEFAULT_BASE_URL
+class GxAgentEnvVars(BaseSettings):
+    gx_cloud_base_url: AnyUrl = AnyUrl(url=CLOUD_DEFAULT_BASE_URL, scheme="https")
     gx_cloud_organization_id: str
     gx_cloud_access_token: str
+    enable_progress_bars: bool = True
+
+    amqp_host_override: Optional[str] = None  # noqa: UP045 # pipe not working with 3.9
+    amqp_port_override: Optional[int] = None  # noqa: UP045 # pipe not working with 3.9
+
+    def __init__(self, **overrides: str | AnyUrl) -> None:
+        """
+        Custom __init__ to prevent type error when relying on environment variables.
+
+        TODO:once mypy fully support annoting **kwargs with a Unpack[TypedDict], we should do that.
+        https://peps.python.org/pep-0692/
+        """
+        super().__init__(**overrides)
 
 
 def generate_config_validation_error_text(validation_error: ValidationError) -> str:
     missing_variables = ", ".join(
-        [validation_error["loc"][0] for validation_error in validation_error.errors()]
+        [str(validation_error["loc"][0]) for validation_error in validation_error.errors()]
     )
     error_text = f"Missing or badly formed environment variable(s). Make sure to set the following environment variable(s): {missing_variables}"
     return error_text
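For context on the config.py rewrite: pydantic.v1.BaseSettings resolves any field that is not passed explicitly from environment variables, matching field names case-insensitively. The sketch below is a standalone illustration of that behaviour and of the error text built by generate_config_validation_error_text; the class name and values are invented for the example:

from __future__ import annotations

import os

from pydantic.v1 import AnyUrl, BaseSettings, ValidationError


class ExampleEnvVars(BaseSettings):
    # Fields without defaults must come from the environment (or from kwargs).
    gx_cloud_organization_id: str
    gx_cloud_access_token: str
    gx_cloud_base_url: AnyUrl = AnyUrl(url="https://api.example.com", scheme="https")


os.environ["GX_CLOUD_ORGANIZATION_ID"] = "11111111-2222-3333-4444-555555555555"
os.environ["GX_CLOUD_ACCESS_TOKEN"] = "example-token"

try:
    settings = ExampleEnvVars()  # values resolved from the environment
    print(settings.gx_cloud_base_url)
except ValidationError as err:
    # Mirrors generate_config_validation_error_text: list the offending variables.
    missing = ", ".join(str(e["loc"][0]) for e in err.errors())
    print(f"Missing or badly formed environment variable(s): {missing}")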
great_expectations_cloud/agent/event_handler.py

@@ -2,18 +2,24 @@ from __future__ import annotations
 
 import logging
 from collections import defaultdict
+from datetime import datetime, timezone
 from json import JSONDecodeError
 from typing import TYPE_CHECKING, Any, Final
+from uuid import UUID
 
 import great_expectations as gx
-from great_expectations.compatibility import pydantic
 from packaging.version import Version
 from packaging.version import parse as parse_version
+from pydantic import v1 as pydantic_v1
 
 from great_expectations_cloud.agent.actions.unknown import UnknownEventAction
+from great_expectations_cloud.agent.exceptions import GXAgentError
 from great_expectations_cloud.agent.models import (
+    DomainContext,
     Event,
+    EventType,
     UnknownEvent,
+    get_event_union,
 )
 
 if TYPE_CHECKING:

@@ -41,7 +47,7 @@ _EVENT_ACTION_MAP: dict[str, dict[str, type[AgentAction[Any]]]] = defaultdict(di
 
 
 def register_event_action(
-    version: str, event_type:
+    version: str, event_type: EventType, action_class: type[AgentAction[Any]]
 ) -> None:
     """Register an event type to an action class."""
     if version in _EVENT_ACTION_MAP and event_type.__name__ in _EVENT_ACTION_MAP[version]:

@@ -61,34 +67,65 @@ class EventHandler:
     def __init__(self, context: CloudDataContext) -> None:
         self._context = context
 
-    def get_event_action(
+    def get_event_action(
+        self, event: Event, base_url: str, auth_key: str, domain_context: DomainContext
+    ) -> AgentAction[Any]:
         """Get the action that should be run for the given event."""
+
+        if not self._check_event_organization_id(event, domain_context.organization_id):
+            # Making message more generic
+            raise GXAgentError("Unable to process job. Invalid input.")  # noqa: TRY003
+
         action_map = _EVENT_ACTION_MAP.get(_GX_MAJOR_VERSION)
         if action_map is None:
             raise NoVersionImplementationError(version=_GX_MAJOR_VERSION)
         action_class = action_map.get(_get_event_name(event))
         if action_class is None:
             action_class = UnknownEventAction
-        return action_class(
-
-
+        return action_class(
+            context=self._context,
+            base_url=base_url,
+            domain_context=domain_context,
+            auth_key=auth_key,
+        )
+
+    def handle_event(
+        self, event: Event, id: str, base_url: str, auth_key: str, domain_context: DomainContext
+    ) -> ActionResult:
         """Transform an Event into an ActionResult."""
-
+        start_time = datetime.now(tz=timezone.utc)
+        action = self.get_event_action(
+            event=event, base_url=base_url, auth_key=auth_key, domain_context=domain_context
+        )
         LOGGER.info(f"Handling event: {event.type} -> {action.__class__.__name__}")
         action_result = action.run(event=event, id=id)
+        end_time = datetime.now(tz=timezone.utc)
+        action_result.job_duration = end_time - start_time
         return action_result
 
     @classmethod
     def parse_event_from(cls, msg_body: bytes) -> Event:
+        event_union = get_event_union()
         try:
-            event: Event =
-        except (
+            event: Event = pydantic_v1.parse_raw_as(event_union, msg_body)
+        except (pydantic_v1.ValidationError, JSONDecodeError):
             # Log as bytes
             LOGGER.exception("Unable to parse event type", extra={"msg_body": f"{msg_body!r}"})
             return UnknownEvent()
 
         return event
 
+    @staticmethod
+    def _check_event_organization_id(event: Event, organization_id: UUID) -> bool:
+        """Check if the organization_id in the event matches the given organization_id.
+
+        This prevents processing events that are not intended for the current organization, and potentially
+        leaking sensitive information across organizations.
+        """
+        if hasattr(event, "organization_id") and event.organization_id != organization_id:
+            return False
+        return True
+
 
 class EventError(Exception): ...
 

@@ -121,12 +158,12 @@ def _get_major_version(version: str) -> str:
     return str(parsed.major)
 
 
-version = gx.__version__
+version = gx.__version__
 _GX_MAJOR_VERSION = _get_major_version(str(version))
 
 
 def _get_event_name(event: Event) -> str:
     try:
-        return str(event.__name__)
+        return str(event.__name__)  # type: ignore[union-attr] # FIXME
     except AttributeError:
-        return
+        return event.__class__.__name__
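The reworked parse_event_from delegates to pydantic.v1.parse_raw_as with a union of event models and falls back to an unknown-event sentinel when the payload does not parse. A standalone sketch of that pattern, with event classes invented for illustration rather than taken from the package's models module:

from __future__ import annotations

import json
from typing import Literal, Union

from pydantic import v1 as pydantic_v1


class RunCheckpointEvent(pydantic_v1.BaseModel):
    type: Literal["run_checkpoint_request"] = "run_checkpoint_request"
    checkpoint_id: str


class UnknownEvent(pydantic_v1.BaseModel):
    type: str = "unknown_event"


ExampleEvent = Union[RunCheckpointEvent, UnknownEvent]


def parse_event(msg_body: bytes) -> ExampleEvent:
    try:
        # parse_raw_as tries the union members in order until one validates.
        return pydantic_v1.parse_raw_as(ExampleEvent, msg_body)
    except (pydantic_v1.ValidationError, json.JSONDecodeError):
        # Bad payloads degrade to a sentinel instead of crashing the handler.
        return UnknownEvent()


print(parse_event(b'{"type": "run_checkpoint_request", "checkpoint_id": "abc"}'))
print(parse_event(b"not json"))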
great_expectations_cloud/agent/exceptions.py

@@ -32,3 +32,12 @@ class ErrorCode(str, enum.Enum):
 
     GENERIC_UNHANDLED_ERROR = "generic-unhandled-error"
     WRONG_USERNAME_OR_PASSWORD = "wrong-username-or-password"  # noqa: S105 # Not a hardcoded password
+
+
+class GXAgentError(Exception): ...
+
+
+class GXAgentConfigError(GXAgentError): ...
+
+
+class GXAgentUnrecoverableConnectionError(GXAgentError): ...
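The new exception hierarchy lets callers distinguish configuration problems from unrecoverable broker failures while still catching everything agent-related through the base class. A self-contained sketch of hypothetical caller-side handling; the local class definitions simply mirror the additions above:

# These mirror the classes added in exceptions.py.
class GXAgentError(Exception): ...


class GXAgentConfigError(GXAgentError): ...


class GXAgentUnrecoverableConnectionError(GXAgentError): ...


def start_agent() -> None:
    # Stand-in for agent start-up; the real code raises these on failure.
    raise GXAgentUnrecoverableConnectionError("broker connection lost")


try:
    start_agent()
except GXAgentConfigError:
    print("fix the environment variables and retry")
except GXAgentUnrecoverableConnectionError:
    print("do not reconnect; exit and let the orchestrator restart the agent")
except GXAgentError:
    print("generic agent failure")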
great_expectations_cloud/agent/message_service/asyncio_rabbit_mq_client.py

@@ -1,19 +1,26 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 import ssl
 from asyncio import AbstractEventLoop
 from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Callable, Protocol
+from typing import TYPE_CHECKING, Callable, Final, Protocol
 
 import pika
 from pika.adapters.asyncio_connection import AsyncioConnection
+from pika.exceptions import ChannelClosed, ConnectionClosed
+
+from great_expectations_cloud.agent.exceptions import GXAgentUnrecoverableConnectionError
 
 if TYPE_CHECKING:
     from pika.channel import Channel
     from pika.spec import Basic, BasicProperties
 
+LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
+LOGGER.setLevel(logging.DEBUG)
+
 
 @dataclass(frozen=True)
 class OnMessagePayload:

@@ -43,6 +50,7 @@ class AsyncRabbitMQClient:
         self._closing = False
         self._consumer_tag = None
         self._consuming = False
+        self._is_unrecoverable = False
 
     def run(self, queue: str, on_message: OnMessageFn) -> None:
         """Run an async connection to RabbitMQ.

@@ -67,21 +75,27 @@ class AsyncRabbitMQClient:
         )
         self._connection = connection
         connection.ioloop.run_forever()
+        if self._is_unrecoverable:
+            raise GXAgentUnrecoverableConnectionError(  # noqa: TRY003
+                "AsyncRabbitMQClient has encountered an unrecoverable error."
+            )
 
     def stop(self) -> None:
         """Close the connection to RabbitMQ."""
         if self._connection is None:
             return
+        LOGGER.debug("Shutting down the connection to RabbitMQ.")
         if not self._closing:
             self._closing = True
-
-
-
-
-
+
+            if self._consuming:
+                self._stop_consuming()
+            self._connection.ioloop.stop()
+            LOGGER.debug("The connection to RabbitMQ has been shut down.")
 
     def reset(self) -> None:
         """Reset client to allow a restart."""
+        LOGGER.debug("Resetting client")
         self.should_reconnect = False
         self.was_consuming = False
 

@@ -148,7 +162,7 @@ class AsyncRabbitMQClient:
         nack = partial(channel.basic_nack, delivery_tag=delivery_tag, requeue=requeue)
         loop.call_soon_threadsafe(callback=nack)
 
-    def _callback_handler(
+    def _callback_handler(
         self,
         channel: Channel,
         method_frame: Basic.Deliver,

@@ -167,66 +181,118 @@ class AsyncRabbitMQClient:
 
     def _start_consuming(self, queue: str, on_message: OnMessageFn, channel: Channel) -> None:
         """Consume from a channel with the on_message callback."""
+        LOGGER.debug("Issuing consumer-related RPC commands")
         channel.add_on_cancel_callback(self._on_consumer_canceled)
+        # set RabbitMQ prefetch count to equal the max_threads value in the GX Agent's ThreadPoolExecutor
+        channel.basic_qos(prefetch_count=1)
         self._consumer_tag = channel.basic_consume(queue=queue, on_message_callback=on_message)
 
     def _on_consumer_canceled(self, method_frame: Basic.Cancel) -> None:
         """Callback invoked when the broker cancels the client's connection."""
         if self._channel is not None:
+            LOGGER.info(
+                "Consumer was cancelled remotely, shutting down",
+                extra={
+                    "method_frame": method_frame,
+                },
+            )
             self._channel.close()
 
     def _reconnect(self) -> None:
         """Prepare the client to reconnect."""
+        LOGGER.debug("Preparing client to reconnect")
         self.should_reconnect = True
         self.stop()
 
     def _stop_consuming(self) -> None:
         """Cancel the channel, if it exists."""
         if self._channel is not None:
+            LOGGER.debug("Sending a Basic.Cancel RPC command to RabbitMQ")
             self._channel.basic_cancel(self._consumer_tag, callback=self._on_cancel_ok)
 
-    def _on_cancel_ok(self,
+    def _on_cancel_ok(self, _unused_frame: Basic.CancelOk) -> None:
         """Callback invoked after broker confirms cancel."""
         self._consuming = False
         if self._channel is not None:
+            LOGGER.debug("RabbitMQ acknowledged the cancellation of the consumer")
             self._channel.close()
 
     def _on_connection_open(
         self, connection: AsyncioConnection, queue: str, on_message: OnMessageFn
     ) -> None:
         """Callback invoked after the broker opens the connection."""
+        LOGGER.debug("Connection to RabbitMQ has been opened")
         on_channel_open = partial(self._on_channel_open, queue=queue, on_message=on_message)
         connection.channel(on_open_callback=on_channel_open)
 
-    def _on_connection_open_error(
+    def _on_connection_open_error(
+        self, _unused_connection: AsyncioConnection, reason: pika.Exception
+    ) -> None:
         """Callback invoked when there is an error while opening connection."""
         self._reconnect()
+        self._log_pika_exception("Connection open failed", reason)
 
-    def _on_connection_closed(
+    def _on_connection_closed(
+        self, connection: AsyncioConnection, _unused_reason: pika.Exception
+    ) -> None:
         """Callback invoked after the broker closes the connection"""
+        LOGGER.debug("Connection to RabbitMQ has been closed")
         self._channel = None
+        self._is_unrecoverable = True
         if self._closing:
             connection.ioloop.stop()
         else:
             self._reconnect()
 
-    def _close_connection(self) -> None:
+    def _close_connection(self, reason: pika.Exception) -> None:
         """Close the connection to the broker."""
         self._consuming = False
         if self._connection is None or self._connection.is_closing or self._connection.is_closed:
+            LOGGER.debug("Connection to RabbitMQ is closing or is already closed")
             pass
         else:
-
+            LOGGER.debug("Closing connection to RabbitMQ")
+
+            if isinstance(reason, (ConnectionClosed, ChannelClosed)):
+                reply_code = reason.reply_code
+                reply_text = reason.reply_text
+            else:
+                reply_code = 999  # arbitrary value, not in the list of AMQP reply codes: https://www.rabbitmq.com/amqp-0-9-1-reference#constants
+                reply_text = str(reason)
+            self._connection.close(reply_code=reply_code, reply_text=reply_text)
+
+    def _log_pika_exception(
+        self, message: str, reason: pika.Exception, extra: dict[str, str] | None = None
+    ) -> None:
+        """Log a pika exception. Extra is key-value pairs to include in the log message."""
+        if not extra:
+            extra = {}
+        if isinstance(reason, (ConnectionClosed, ChannelClosed)):
+            default_extra: dict[str, str] = {
+                "reply_code": str(reason.reply_code),
+                "reply_text": str(reason.reply_text),
+            }
+            LOGGER.error(
+                message,
+                # mypy not happy with dict | dict, so we use the dict constructor
+                extra={**default_extra, **extra},
+            )
+        else:
+            default_extra = {"reason": str(reason)}
+            # mypy not happy with dict | dict, so we use the dict constructor
+            LOGGER.error(message, extra={**default_extra, **extra})
 
     def _on_channel_open(self, channel: Channel, queue: str, on_message: OnMessageFn) -> None:
         """Callback invoked after the broker opens the channel."""
+        LOGGER.debug("Channel opened")
         self._channel = channel
         channel.add_on_close_callback(self._on_channel_closed)
         self._start_consuming(queue=queue, on_message=on_message, channel=channel)
 
-    def _on_channel_closed(self, channel: Channel, reason:
+    def _on_channel_closed(self, channel: Channel, reason: ChannelClosed) -> None:
         """Callback invoked after the broker closes the channel."""
-        self.
+        self._log_pika_exception("Channel closed", reason, extra={"channel": channel})
+        self._close_connection(reason)
 
     def _build_client_parameters(self, url: str) -> pika.URLParameters:
         """Configure parameters used to connect to the broker."""
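The client changes are mostly logging, a prefetch limit, and an unrecoverable-error flag layered on pika's callback-driven API. The sketch below strips that down to the wiring the diff touches: connection open -> channel open -> basic_qos -> basic_consume, with ioloop.run_forever() driving the callbacks and the flag checked once the loop stops. The MiniConsumer class, URL, and queue name are placeholders for illustration, not the package's API:

from __future__ import annotations

import pika
from pika.adapters.asyncio_connection import AsyncioConnection


class MiniConsumer:
    """Minimal callback chain: connect -> open channel -> qos -> consume."""

    def __init__(self, url: str, queue: str) -> None:
        self._parameters = pika.URLParameters(url)
        self._queue = queue
        self._is_unrecoverable = False

    def run(self) -> None:
        connection = AsyncioConnection(
            parameters=self._parameters,
            on_open_callback=self._on_connection_open,
            on_open_error_callback=self._on_open_error,
            on_close_callback=self._on_connection_closed,
        )
        connection.ioloop.run_forever()
        # Mirrors the new flag in AsyncRabbitMQClient.run: surface fatal errors
        # to the caller instead of returning silently after the loop stops.
        if self._is_unrecoverable:
            raise RuntimeError("consumer hit an unrecoverable connection error")

    def _on_connection_open(self, connection):
        connection.channel(on_open_callback=self._on_channel_open)

    def _on_channel_open(self, channel):
        channel.basic_qos(prefetch_count=1)  # bound unacknowledged deliveries
        channel.basic_consume(queue=self._queue, on_message_callback=self._on_message)

    def _on_message(self, channel, method, properties, body):
        print(f"received {body!r}")
        channel.basic_ack(delivery_tag=method.delivery_tag)

    def _on_open_error(self, connection, reason):
        self._is_unrecoverable = True
        connection.ioloop.stop()

    def _on_connection_closed(self, connection, reason):
        self._is_unrecoverable = True
        connection.ioloop.stop()


# Usage (assumes a reachable RabbitMQ broker):
# MiniConsumer("amqp://guest:guest@localhost:5672/%2F", "example-queue").run()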
great_expectations_cloud/agent/message_service/subscriber.py

@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import asyncio
 import time
+from collections.abc import Coroutine
 from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Callable,
+from typing import TYPE_CHECKING, Callable, Protocol
 
 from pika.exceptions import (
     AMQPError,

@@ -13,6 +14,7 @@ from pika.exceptions import (
 )
 
 from great_expectations_cloud.agent.event_handler import EventHandler
+from great_expectations_cloud.agent.exceptions import GXAgentUnrecoverableConnectionError
 
 if TYPE_CHECKING:
     from great_expectations_cloud.agent.message_service.asyncio_rabbit_mq_client import (

@@ -95,11 +97,13 @@ class Subscriber:
                 self.client.stop()
                 reconnect_delay = self._get_reconnect_delay()
                 time.sleep(reconnect_delay)  # todo: update this blocking call to asyncio.sleep
+                raise
             except KeyboardInterrupt as e:
                 self.client.stop()
                 raise KeyboardInterrupt from e
-
-                self.client.
+            except GXAgentUnrecoverableConnectionError:
+                self.client.stop()
+                raise
             else:
                 break  # exit
 

@@ -159,8 +163,7 @@ class Subscriber:
             self._reconnect_delay = 0
         else:
             self._reconnect_delay += 1
-
-            self._reconnect_delay = 30
+            self._reconnect_delay = min(self._reconnect_delay, 30)
         return self._reconnect_delay
 
     def close(self) -> None:
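_get_reconnect_delay now caps the backoff with min(..., 30) in place of the removed two-line assignment. A small standalone sketch of the resulting behaviour; the was_consuming reset branch is assumed from the surrounding context rather than shown in the hunk:

class ReconnectBackoff:
    """Linear backoff, capped at 30 seconds."""

    def __init__(self) -> None:
        self.was_consuming = False
        self._reconnect_delay = 0

    def get_reconnect_delay(self) -> int:
        if self.was_consuming:
            # A consumer that was healthy reconnects immediately.
            self._reconnect_delay = 0
        else:
            self._reconnect_delay += 1
            self._reconnect_delay = min(self._reconnect_delay, 30)
        return self._reconnect_delay


backoff = ReconnectBackoff()
print([backoff.get_reconnect_delay() for _ in range(35)][-3:])  # delay saturates at 30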