snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints/__init__.py +44 -0
- snowflake/snowpark_checkpoints/__version__.py +16 -0
- snowflake/snowpark_checkpoints/checkpoint.py +580 -0
- snowflake/snowpark_checkpoints/errors.py +60 -0
- snowflake/snowpark_checkpoints/job_context.py +128 -0
- snowflake/snowpark_checkpoints/singleton.py +23 -0
- snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
- snowflake/snowpark_checkpoints/spark_migration.py +255 -0
- snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
- snowflake/snowpark_checkpoints/utils/constants.py +134 -0
- snowflake/snowpark_checkpoints/utils/extra_config.py +89 -0
- snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
- snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
- snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
- snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
- snowflake/snowpark_checkpoints/utils/utils_checks.py +395 -0
- snowflake/snowpark_checkpoints/validation_result_metadata.py +155 -0
- snowflake/snowpark_checkpoints/validation_results.py +49 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/METADATA +323 -0
- snowpark_checkpoints_validators-0.2.1.dist-info/RECORD +22 -0
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
- snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.2.1.dist-info}/licenses/LICENSE +0 -0
snowflake/snowpark_checkpoints/utils/telemetry.py
@@ -0,0 +1,900 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import datetime
import hashlib
import inspect
import json
import os
import re

from contextlib import suppress
from enum import IntEnum
from functools import wraps
from pathlib import Path
from platform import python_version
from sys import platform
from typing import Any, Callable, Optional, TypeVar
from uuid import getnode

from snowflake.connector import (
    SNOWFLAKE_CONNECTOR_VERSION,
    time_util,
)
from snowflake.connector.constants import DIRS as SNOWFLAKE_DIRS
from snowflake.connector.network import SnowflakeRestful
from snowflake.connector.telemetry import TelemetryClient
from snowflake.snowpark import VERSION as SNOWPARK_VERSION
from snowflake.snowpark import dataframe as snowpark_dataframe
from snowflake.snowpark.session import Session


try:
    from pyspark.sql import dataframe as spark_dataframe

    def _is_spark_dataframe(df: Any) -> bool:
        return isinstance(df, spark_dataframe.DataFrame)

    def _get_spark_schema_types(df: spark_dataframe.DataFrame) -> list[str]:
        return [str(schema_type.dataType) for schema_type in df.schema.fields]

except Exception:

    def _is_spark_dataframe(df: Any):
        pass

    def _get_spark_schema_types(df: Any):
        pass


VERSION_VARIABLE_PATTERN = r"^__version__ = ['\"]([^'\"]*)['\"]"
VERSION_FILE_NAME = "__version__.py"


class TelemetryManager(TelemetryClient):
    def __init__(
        self, rest: Optional[SnowflakeRestful] = None, is_telemetry_enabled: bool = True
    ):
        """TelemetryManager class to log telemetry events."""
        super().__init__(rest)
        self.sc_folder_path = (
            Path(SNOWFLAKE_DIRS.user_config_path) / "snowpark-checkpoints-telemetry"
        )
        self.sc_sf_path_telemetry = "/telemetry/send"
        self.sc_flush_size = 25
        self.sc_is_enabled = is_telemetry_enabled
        self.sc_is_testing = self._sc_is_telemetry_testing()
        self.sc_memory_limit = 5 * 1024 * 1024
        self._sc_upload_local_telemetry()
        self.sc_log_batch = []
        self.sc_hypothesis_input_events = []
        self.sc_version = _get_version()
        if rest:
            atexit.register(self._sc_close_at_exit)

    def set_sc_output_path(self, path: Path) -> None:
        """Set the output path for testing.

        Args:
            path: path to write telemetry.

        """
        os.makedirs(path, exist_ok=True)
        self.sc_folder_path = path

    def sc_log_error(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an error telemetry event.

        Args:
            event_name (str): The name of the event.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "error", parameters_info)

    def sc_log_info(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an information telemetry event.

        Args:
            event_name (str): The name of the event.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "info", parameters_info)

    def _sc_log_telemetry(
        self, event_name: str, event_type: str, parameters_info: Optional[dict] = None
    ) -> dict:
        """Log a telemetry event if telemetry is enabled.

        Args:
            event_name (str): The name of the event.
            event_type (str): The type of the event (e.g., "error", "info").
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        Returns:
            dict: The logged event.

        """
        if not self.sc_is_enabled:
            return {}
        event = _generate_event(
            event_name, event_type, parameters_info, self.sc_version
        )
        self._sc_add_log_to_batch(event)
        return event

    def _sc_add_log_to_batch(self, event: dict) -> None:
        """Add a log event to the batch. If the batch is full, send the events to the API.

        Args:
            event (dict): The event to add.

        """
        self.sc_log_batch.append(event)
        if self.sc_is_testing:
            self._sc_write_telemetry(self.sc_log_batch)
            self.sc_log_batch = []
            return

        if len(self.sc_log_batch) >= self.sc_flush_size:
            self.sc_send_batch(self.sc_log_batch)
            self.sc_log_batch = []

    def sc_send_batch(self, to_sent: list) -> bool:
        """Send a request to the API to upload the events. If there is no connection, write the events to the local folder.

        Args:
            to_sent (list): The batch of events to send.

        Returns:
            bool: True if the batch was sent successfully, False otherwise.

        """
        if not self.sc_is_enabled:
            return False
        if self._rest is None:
            self._sc_write_telemetry(to_sent)
            self.sc_log_batch = []
            return False
        if to_sent == []:
            return False
        body = {"logs": to_sent}
        ret = self._rest.request(
            self.sc_sf_path_telemetry,
            body=body,
            method="post",
            client=None,
            timeout=5,
        )
        if not ret.get("success"):
            self._sc_write_telemetry(to_sent)
            self.sc_log_batch = []
            return False
        return True

    def _sc_write_telemetry(self, batch: list) -> None:
        """Write telemetry events to the local folder. If the folder is full, free up space for the new events.

        Args:
            batch (list): The batch of events to write.

        """
        try:
            os.makedirs(self.sc_folder_path, exist_ok=True)
            for event in batch:
                message = event.get("message")
                if message is not None:
                    file_path = (
                        self.sc_folder_path
                        / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")}'
                        f'_telemetry_{message.get("type")}.json'
                    )
                    json_content = self._sc_validate_folder_space(event)
                    with open(file_path, "w") as json_file:
                        json_file.write(json_content)
        except Exception:
            pass

    def _sc_validate_folder_space(self, event: dict) -> str:
        """Validate and manage folder space for the new telemetry events.

        Args:
            event (dict): The event to validate.

        Returns:
            str: The JSON content of the event.

        """
        json_content = json.dumps(event, indent=4, sort_keys=True)
        new_json_file_size = len(json_content.encode("utf-8"))
        telemetry_folder = self.sc_folder_path
        folder_size = _get_folder_size(telemetry_folder)
        if folder_size + new_json_file_size > self.sc_memory_limit:
            _free_up_space(telemetry_folder, self.sc_memory_limit - new_json_file_size)
        return json_content

    def _sc_upload_local_telemetry(self) -> None:
        """Send a request to the API to upload the local telemetry events."""
        if not self.sc_is_enabled or self.sc_is_testing or not self._rest:
            return
        batch = []
        for file in self.sc_folder_path.glob("*.json"):
            with open(file) as json_file:
                data_dict = json.load(json_file)
                batch.append(data_dict)
        if batch == []:
            return
        body = {"logs": batch}
        ret = self._rest.request(
            self.sc_sf_path_telemetry,
            body=body,
            method="post",
            client=None,
            timeout=5,
        )
        if ret.get("success"):
            for file in self.sc_folder_path.glob("*.json"):
                file.unlink()

    def _sc_is_telemetry_testing(self) -> bool:
        is_testing = os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING") == "true"
        if is_testing:
            local_telemetry_path = (
                Path(os.getcwd()) / "snowpark-checkpoints-output" / "telemetry"
            )
            self.set_sc_output_path(local_telemetry_path)
            self.sc_is_enabled = True
        return is_testing

    def _sc_is_telemetry_manager(self) -> bool:
        """Check if the class is a telemetry manager.

        Returns:
            bool: True if the class is a telemetry manager, False otherwise.

        """
        return True

    def sc_is_hypothesis_event_logged(self, event_name: tuple[str, int]) -> bool:
        """Check if a Hypothesis event is logged.

        Args:
            event_name (tuple[str, int]): A tuple containing the name of the event and an integer identifier
            (0 for info, 1 for error).

        Returns:
            bool: True if the event is logged, False otherwise.

        """
        return event_name in self.sc_hypothesis_input_events

    def _sc_close(self) -> None:
        """Close the telemetry manager and upload collected events.

        This function closes the telemetry manager, uploads any collected events,
        and performs any necessary cleanup to ensure no data is lost.
        """
        atexit.unregister(self._sc_close_at_exit)
        if self.sc_log_batch and self.sc_is_enabled and not self.sc_is_testing:
            self.sc_send_batch(self.sc_log_batch)

    def _sc_close_at_exit(self) -> None:
        """Close the telemetry manager at exit and upload collected events.

        This function ensures that the telemetry manager is closed and all collected events
        are uploaded when the program exits, preventing data loss.
        """
        with suppress(Exception):
            self._sc_close()


def _generate_event(
    event_name: str,
    event_type: str,
    parameters_info: Optional[dict] = None,
    sc_version: Optional[str] = None,
) -> dict:
    """Generate a telemetry event.

    Args:
        event_name (str): The name of the event.
        event_type (str): The type of the event (e.g., "error", "info").
        parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
        sc_version (str, optional): The version of the package. Defaults to None.

    Returns:
        dict: The generated event.

    """
    metadata = _get_metadata()
    if sc_version is not None:
        metadata["snowpark_checkpoints_version"] = sc_version
    message = {
        "event_type": event_type,
        "type": "snowpark-checkpoints",
        "event_name": event_name,
        "driver_type": "PythonConnector",
        "driver_version": SNOWFLAKE_CONNECTOR_VERSION,
        "metadata": metadata,
        "data": json.dumps(parameters_info or {}),
    }
    timestamp = time_util.get_time_millis()
    event_base = {"message": message, "timestamp": str(timestamp)}

    return event_base

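For orientation, here is roughly the payload `_generate_event` assembles; the values below are illustrative placeholders, not captured output, and the metadata fields come from `_get_metadata` just below:

```python
# Illustrative shape of _generate_event("my_event", "info", {"status": True}, "0.2.1"):
{
    "message": {
        "event_type": "info",
        "type": "snowpark-checkpoints",
        "event_name": "my_event",
        "driver_type": "PythonConnector",
        "driver_version": "<snowflake-connector-python version>",
        "metadata": {
            "os_version": "<sys.platform>",
            "python_version": "<python version>",
            "snowpark_version": "<snowpark version>",
            "device_id": "<sha256 hash of the MAC address>",
            "snowpark_checkpoints_version": "0.2.1",
        },
        "data": '{"status": true}',  # parameters_info serialized to JSON
    },
    "timestamp": "<milliseconds since epoch, as a string>",
}
```
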
def _get_metadata() -> dict:
    """Get metadata for telemetry events.

    Returns:
        dict: The metadata including OS version, Python version, and device ID.

    """
    return {
        "os_version": platform,
        "python_version": python_version(),
        "snowpark_version": ".".join(str(x) for x in SNOWPARK_VERSION if x is not None),
        "device_id": _get_unique_id(),
    }


def _get_version() -> str:
    """Get the version of the package.

    Returns:
        str: The version of the package.

    """
    try:
        directory_levels_up = 1
        project_root = Path(__file__).resolve().parents[directory_levels_up]
        version_file_path = project_root / VERSION_FILE_NAME
        with open(version_file_path) as file:
            content = file.read()
        version_match = re.search(VERSION_VARIABLE_PATTERN, content, re.MULTILINE)
        if version_match:
            return version_match.group(1)
        return None
    except Exception:
        return None


def _get_folder_size(folder_path: Path) -> int:
    """Get the size of a folder. Only considers JSON files.

    Args:
        folder_path (Path): The path to the folder.

    Returns:
        int: The size of the folder in bytes.

    """
    return sum(f.stat().st_size for f in folder_path.glob("*.json") if f.is_file())


def _free_up_space(folder_path: Path, max_size: int) -> None:
    """Free up space in a folder by deleting the oldest files. Only considers JSON files.

    Args:
        folder_path (Path): The path to the folder.
        max_size (int): The maximum allowed size of the folder in bytes.

    """
    files = sorted(folder_path.glob("*.json"), key=lambda f: f.stat().st_mtime)
    current_size = _get_folder_size(folder_path)
    for file in files:
        if current_size <= max_size:
            break
        current_size -= file.stat().st_size
        file.unlink()


def _get_unique_id() -> str:
    """Get a unique device ID. The ID is generated based on the hashed MAC address.

    Returns:
        str: The hashed device ID.

    """
    node_id_str = str(getnode())
    hashed_id = hashlib.sha256(node_id_str.encode()).hexdigest()
    return hashed_id


def get_telemetry_manager() -> TelemetryManager:
    """Get the telemetry manager.

    Returns:
        TelemetryManager: The telemetry manager.

    """
    try:
        connection = Session.builder.getOrCreate().connection
        if not hasattr(connection._telemetry, "_sc_is_telemetry_manager"):
            connection._telemetry = TelemetryManager(
                connection._rest, connection.telemetry_enabled
            )
        return connection._telemetry
    except Exception:
        telemetry_manager = TelemetryManager(None, is_telemetry_enabled=True)
        telemetry_manager.sc_flush_size = 1
        return telemetry_manager

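A minimal usage sketch, assuming the package is installed; the event name and parameters here are hypothetical:

```python
from snowflake.snowpark_checkpoints.utils.telemetry import get_telemetry_manager

# Reuses the active Snowpark session's telemetry channel when one exists;
# otherwise falls back to a local manager (sc_flush_size = 1) that writes
# each event to a JSON file on disk.
telemetry = get_telemetry_manager()
telemetry.sc_log_info("My_Custom_Event", {"status": True})  # hypothetical event
telemetry.sc_send_batch(telemetry.sc_log_batch)
```
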
def get_snowflake_schema_types(df: snowpark_dataframe.DataFrame) -> list[str]:
    """Extract the data types of the schema fields from a Snowflake DataFrame.

    Args:
        df (snowpark_dataframe.DataFrame): The Snowflake DataFrame.

    Returns:
        list[str]: A list of data type names of the schema fields.

    """
    return [str(schema_type.datatype) for schema_type in df.schema.fields]


def _is_snowpark_dataframe(df: Any) -> bool:
    """Check if the given dataframe is a Snowpark dataframe.

    Args:
        df: The dataframe to check.

    Returns:
        bool: True if the dataframe is a Snowpark dataframe, False otherwise.

    """
    return isinstance(df, snowpark_dataframe.DataFrame)


def get_load_json(json_schema: str) -> dict:
    """Load and parse a JSON schema file.

    Args:
        json_schema (str): The path to the JSON schema file.

    Returns:
        dict: The parsed JSON content.

    Raises:
        ValueError: If there is an error reading or parsing the JSON file.

    """
    try:
        with open(json_schema, encoding="utf-8") as file:
            return json.load(file)
    except (OSError, json.JSONDecodeError) as e:
        raise ValueError(f"Error reading JSON schema file: {e}") from None


def extract_parameters(
    func: Callable, args: tuple, kwargs: dict, params_list: Optional[list[str]]
) -> dict:
    """Extract parameters from the function arguments.

    Args:
        func (Callable): The function being decorated.
        args (tuple): The positional arguments passed to the function.
        kwargs (dict): The keyword arguments passed to the function.
        params_list (list[str]): The list of parameters to extract.

    Returns:
        dict: A dictionary of extracted parameters.

    """
    parameters = inspect.signature(func).parameters
    param_data = {}
    if params_list:
        for _, param in enumerate(params_list):
            if len(args) > 0:
                index = list(parameters.keys()).index(param)
                param_data[param] = args[index]
            else:
                if kwargs[param]:
                    param_data[param] = kwargs[param]
    return param_data

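To make the branching above concrete: with a positional call the parameter is resolved by its index in the signature, while a keyword call reads it from `kwargs`. A sketch with a hypothetical function:

```python
def _sample(df, checkpoint_name, mode=None):  # hypothetical decorated function
    pass

# Positional call: "checkpoint_name" sits at index 1 of the signature,
# so the value is taken from args[1].
extract_parameters(_sample, ("my_df", "cp1"), {}, ["checkpoint_name"])
# -> {"checkpoint_name": "cp1"}

# Keyword call: args is empty, so the value is read from kwargs
# (and silently skipped when it is falsy).
extract_parameters(_sample, (), {"checkpoint_name": "cp1"}, ["checkpoint_name"])
# -> {"checkpoint_name": "cp1"}
```
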
def check_dataframe_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for checking dataframe schema.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
        pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        schema_types = []
        for schema_type in pandera_schema.columns.values():
            if schema_type.dtype is not None:
                schema_types.append(str(schema_type.dtype))
        if schema_types:
            telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
    except Exception:
        if param_data.get(STATUS_KEY):
            telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
        pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        if pandera_schema:
            schema_types = []
            for schema_type in pandera_schema.columns.values():
                if schema_type.dtype is not None:
                    schema_types.append(str(schema_type.dtype))
            if schema_types:
                telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data


def check_output_or_input_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for checking output or input schema.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    try:
        pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        schema_types = []
        for schema_type in pandera_schema.columns.values():
            if schema_type.dtype is not None:
                schema_types.append(str(schema_type.dtype))
        if schema_types:
            telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data


def collect_dataframe_checkpoint_mode_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for collecting dataframe checkpoint mode schema.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        schema_types = param_data.get("column_type_dict")
        telemetry_data[SCHEMA_TYPES_KEY] = [
            schema_types[schema_type].dataType.typeName()
            for schema_type in schema_types
        ]
        return DATAFRAME_COLLECTION_SCHEMA, telemetry_data
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data


def collect_dataframe_checkpoint_mode_dataframe_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for collecting dataframe checkpoint mode dataframe.

    This function processes telemetry data for a dataframe checkpoint mode event. It updates the telemetry data
    with the mode and schema types of the Spark DataFrame being collected.

    Args:
        telemetry_data (dict): The telemetry data dictionary to be updated.
        param_data (dict): The parameter data dictionary containing the DataFrame information.

    Returns:
        tuple: A tuple containing the event name and the updated telemetry data dictionary.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    try:
        if _is_spark_dataframe(param_data.get(DF_PARAM_NAME)):
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                param_data.get(DF_PARAM_NAME)
            )
        return DATAFRAME_COLLECTION_DF, telemetry_data
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data


def assert_return_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Handle telemetry event for asserting return values.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    if param_data.get(STATUS_KEY) is not None:
        telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
    try:
        if _is_snowpark_dataframe(
            param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
        ) and _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
            telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
                param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
            )
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                param_data.get(SPARK_RESULTS_PARAM_NAME)
            )
            return DATAFRAME_VALIDATOR_MIRROR, telemetry_data
        else:
            return VALUE_VALIDATOR_MIRROR, telemetry_data
    except Exception:
        if _is_snowpark_dataframe(param_data.get(SNOWPARK_RESULTS_PARAM_NAME)):
            telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
                param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
            )
        if _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                param_data.get(SPARK_RESULTS_PARAM_NAME)
            )
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data


def dataframe_strategy_event(
    telemetry_data: dict, param_data: dict, telemetry_m: TelemetryManager
) -> tuple[Optional[str], Optional[dict]]:
    """Handle telemetry event for dataframe strategy.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.
        telemetry_m (TelemetryManager): The telemetry manager.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    try:
        test_function_name = inspect.stack()[2].function
        is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 0))
        if not is_logged:
            schema_param = param_data.get(DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME)
            if isinstance(schema_param, str):
                json_data = get_load_json(schema_param)["custom_data"]["columns"]
                telemetry_data[SCHEMA_TYPES_KEY] = [
                    column["type"] for column in json_data
                ]
            else:
                schema_types = []
                for schema_type in schema_param.columns.values():
                    if schema_type.dtype is not None:
                        schema_types.append(str(schema_type.dtype))
                if schema_types:
                    telemetry_data[SCHEMA_TYPES_KEY] = schema_types
            telemetry_m.sc_hypothesis_input_events.append((test_function_name, 0))
            if None in telemetry_data[SCHEMA_TYPES_KEY]:
                telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
            else:
                telemetry_m.sc_log_info(HYPOTHESIS_INPUT_SCHEMA, telemetry_data)
            telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None
    except Exception:
        test_function_name = inspect.stack()[2].function
        is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 1))
        if not is_logged:
            telemetry_m.sc_hypothesis_input_events.append((test_function_name, 0))
            telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
            telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None


def compare_data_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Handle telemetry event for comparing data.

    This function processes telemetry data for a data comparison event. It updates the telemetry data
    with the mode, status, and schema types of the Snowflake DataFrame being compared.

    Args:
        telemetry_data (dict): The telemetry data dictionary to be updated.
        param_data (dict): The parameter data dictionary containing the DataFrame and status information.

    Returns:
        tuple: A tuple containing the event name and the updated telemetry data dictionary.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
    try:
        telemetry_data[SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
            param_data.get("df")
        )
        return DATAFRAME_VALIDATOR_DF, telemetry_data
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data


def handle_result(
    func_name: str,
    result: Any,
    param_data: dict,
    multiple_return: bool,
    telemetry_m: TelemetryManager,
    return_indexes: Optional[list[tuple[str, int]]] = None,
) -> tuple[Optional[str], Optional[dict]]:
    """Handle the result of the function and collect telemetry data.

    Args:
        func_name (str): The name of the function.
        result: The result of the function.
        param_data (dict): The extracted parameters.
        multiple_return (bool): Whether the function returns multiple values.
        telemetry_m (TelemetryManager): The telemetry manager.
        return_indexes (list[tuple[str, int]]): The list of return values to report. Defaults to None.

    Returns:
        tuple: A tuple containing the event name (str) and telemetry data (dict).

    """
    if result is not None and return_indexes is not None:
        if multiple_return:
            for name, index in return_indexes:
                param_data[name] = result[index]
        else:
            param_data[return_indexes[0][0]] = result[return_indexes[0][1]]

    telemetry_data = {
        FUNCTION_KEY: func_name,
    }

    telemetry_event = None
    data = None
    if func_name == "_check_dataframe_schema":
        telemetry_event, data = check_dataframe_schema_event(telemetry_data, param_data)
    elif func_name in ["check_output_schema", "check_input_schema"]:
        telemetry_event, data = check_output_or_input_schema_event(
            telemetry_data, param_data
        )
    if func_name == "_compare_data":
        telemetry_event, data = compare_data_event(telemetry_data, param_data)
    elif func_name == "_collect_dataframe_checkpoint_mode_schema":
        telemetry_event, data = collect_dataframe_checkpoint_mode_schema_event(
            telemetry_data, param_data
        )
    elif func_name == "_collect_dataframe_checkpoint_mode_dataframe":
        telemetry_event, data = collect_dataframe_checkpoint_mode_dataframe_event(
            telemetry_data, param_data
        )
    elif func_name == "_assert_return":
        telemetry_event, data = assert_return_event(telemetry_data, param_data)
    elif func_name == "dataframe_strategy":
        telemetry_event, data = dataframe_strategy_event(
            telemetry_data, param_data, telemetry_m
        )
    return telemetry_event, data

fn = TypeVar("fn", bound=Callable)


def report_telemetry(
    params_list: list[str] = None,
    return_indexes: list[tuple[str, int]] = None,
    multiple_return: bool = False,
) -> Callable[[fn], fn]:
    """Report telemetry events for a function.

    Args:
        params_list (list[str], optional): The list of parameters to report. Defaults to None.
        return_indexes (list[tuple[str, int]], optional): The list of return values to report. Defaults to None.
        multiple_return (bool, optional): Whether the function returns multiple values. Defaults to False.

    Returns:
        Callable[[fn], fn]: The decorator function.

    """

    def report_telemetry_decorator(func):
        func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            func_exception = None
            result = None
            try:
                result = func(*args, **kwargs)
            except Exception as err:
                func_exception = err

            if os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED") == "false":
                return result
            telemetry_event = None
            data = None
            telemetry_m = None
            try:
                param_data = extract_parameters(func, args, kwargs, params_list)
                telemetry_m = get_telemetry_manager()
                telemetry_event, data = handle_result(
                    func_name,
                    result,
                    param_data,
                    multiple_return,
                    telemetry_m,
                    return_indexes,
                )
            except Exception:
                pass
            finally:
                if func_exception is not None:
                    if telemetry_m is not None:
                        telemetry_m.sc_log_error(telemetry_event, data)
                    raise func_exception
                if telemetry_m is not None:
                    telemetry_m.sc_log_info(telemetry_event, data)

            return result

        return wrapper

    return report_telemetry_decorator

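A sketch of how the decorator is meant to be applied; the function body is hypothetical, but the function name and parameter names match what `handle_result` and the event handlers above dispatch on:

```python
@report_telemetry(params_list=["pandera_schema"], return_indexes=[("status", 0)])
def _check_dataframe_schema(df, pandera_schema, checkpoint_name):
    # hypothetical body standing in for the real validator
    return True, None

# The wrapper runs the function, extracts "pandera_schema" from the call,
# maps element 0 of the return tuple to "status", and logs a
# DataFrame_Validator_Schema info event, or an error event (re-raising
# the original exception) if the wrapped call threw.
```
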
# Constants for telemetry
DATAFRAME_COLLECTION_SCHEMA = "DataFrame_Collection_Schema"
DATAFRAME_COLLECTION_DF = "DataFrame_Collection_DF"
DATAFRAME_VALIDATOR_MIRROR = "DataFrame_Validator_Mirror"
VALUE_VALIDATOR_MIRROR = "Value_Validator_Mirror"
DATAFRAME_VALIDATOR_SCHEMA = "DataFrame_Validator_Schema"
DATAFRAME_VALIDATOR_DF = "DataFrame_Validator_DF"
HYPOTHESIS_INPUT_SCHEMA = "Hypothesis_Input_Schema"
DATAFRAME_COLLECTION_ERROR = "DataFrame_Collection_Error"
DATAFRAME_VALIDATOR_ERROR = "DataFrame_Validator_Error"
HYPOTHESIS_INPUT_SCHEMA_ERROR = "Hypothesis_Input_Schema_Error"

FUNCTION_KEY = "function"
STATUS_KEY = "status"
SCHEMA_TYPES_KEY = "schema_types"
ERROR_KEY = "error"
MODE_KEY = "mode"
SNOWFLAKE_SCHEMA_TYPES_KEY = "snowflake_schema_types"
SPARK_SCHEMA_TYPES_KEY = "spark_schema_types"

DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME = "schema"
PANDERA_SCHEMA_PARAM_NAME = "pandera_schema"
SNOWPARK_RESULTS_PARAM_NAME = "snowpark_results"
SPARK_RESULTS_PARAM_NAME = "spark_results"
DF_PARAM_NAME = "df"


class CheckpointMode(IntEnum):
    SCHEMA = 1
    DATAFRAME = 2
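Two environment variables gate the behavior above: `SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED` and `SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING`. A small sketch of how a caller might set them:

```python
import os

# Skip telemetry reporting entirely: the report_telemetry wrapper returns
# right after running the wrapped function when this is "false".
os.environ["SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED"] = "false"

# Test mode: route every event to ./snowpark-checkpoints-output/telemetry/
# as a JSON file instead of posting it to the API. This flag is read once,
# when the TelemetryManager is constructed.
os.environ["SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING"] = "true"
```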