snowpark-checkpoints-validators 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. snowflake/snowpark_checkpoints/__init__.py +44 -0
  2. snowflake/snowpark_checkpoints/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints/checkpoint.py +580 -0
  4. snowflake/snowpark_checkpoints/errors.py +60 -0
  5. snowflake/snowpark_checkpoints/io_utils/__init__.py +26 -0
  6. snowflake/snowpark_checkpoints/io_utils/io_default_strategy.py +57 -0
  7. snowflake/snowpark_checkpoints/io_utils/io_env_strategy.py +133 -0
  8. snowflake/snowpark_checkpoints/io_utils/io_file_manager.py +76 -0
  9. snowflake/snowpark_checkpoints/job_context.py +128 -0
  10. snowflake/snowpark_checkpoints/singleton.py +23 -0
  11. snowflake/snowpark_checkpoints/snowpark_sampler.py +124 -0
  12. snowflake/snowpark_checkpoints/spark_migration.py +255 -0
  13. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  14. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  15. snowflake/snowpark_checkpoints/utils/extra_config.py +132 -0
  16. snowflake/snowpark_checkpoints/utils/logging_utils.py +67 -0
  17. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +399 -0
  18. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  19. snowflake/snowpark_checkpoints/utils/telemetry.py +939 -0
  20. snowflake/snowpark_checkpoints/utils/utils_checks.py +398 -0
  21. snowflake/snowpark_checkpoints/validation_result_metadata.py +159 -0
  22. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  23. snowpark_checkpoints_validators-0.3.0.dist-info/METADATA +325 -0
  24. snowpark_checkpoints_validators-0.3.0.dist-info/RECORD +26 -0
  25. snowpark_checkpoints_validators-0.2.0rc1.dist-info/METADATA +0 -514
  26. snowpark_checkpoints_validators-0.2.0rc1.dist-info/RECORD +0 -4
  27. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/WHEEL +0 -0
  28. {snowpark_checkpoints_validators-0.2.0rc1.dist-info → snowpark_checkpoints_validators-0.3.0.dist-info}/licenses/LICENSE +0 -0
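The largest addition is the new telemetry module, snowflake/snowpark_checkpoints/utils/telemetry.py (+939), reproduced in full below. As the report_telemetry decorator near the end of the module shows, telemetry is skipped entirely when the SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED environment variable is set to "false" or when the code runs inside a stored procedure. A minimal opt-out sketch (only the variable name comes from the module below; the rest is illustrative):

    import os

    # Must be set before any @report_telemetry-decorated function runs;
    # the decorator re-checks this variable on every call.
    os.environ["SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED"] = "false"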
snowflake/snowpark_checkpoints/utils/telemetry.py ADDED
@@ -0,0 +1,939 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import atexit
+ import datetime
+ import hashlib
+ import inspect
+ import json
+ import os
+ import re
+
+ from contextlib import suppress
+ from enum import IntEnum
+ from functools import wraps
+ from pathlib import Path
+ from platform import python_version
+ from sys import platform
+ from typing import Any, Callable, Optional, TypeVar
+ from uuid import getnode
+
+ from snowflake.connector.description import PLATFORM as CONNECTOR_PLATFORM
+ from snowflake.connector.telemetry import TelemetryClient
+ from snowflake.snowpark import VERSION as SNOWPARK_VERSION
+ from snowflake.snowpark import dataframe as snowpark_dataframe
+ from snowflake.snowpark.session import Session
+ from snowflake.snowpark_checkpoints.io_utils.io_file_manager import (
+     get_io_file_manager,
+ )
+
+
+ # These imports are used to log telemetry events in the Snowflake Connector.
+ try:
+     from snowflake.connector import (
+         SNOWFLAKE_CONNECTOR_VERSION,
+         time_util,
+     )
+     from snowflake.connector.constants import DIRS as SNOWFLAKE_DIRS
+     from snowflake.connector.network import SnowflakeRestful
+ except Exception:
+     # Fall back to default values when snowpark-checkpoints runs inside a
+     # stored procedure, where these connector internals are unavailable.
+     SNOWFLAKE_CONNECTOR_VERSION = ""
+     time_util = None
+     SNOWFLAKE_DIRS = ""
+     SnowflakeRestful = None
+
+
+ try:
+     from pyspark.sql import dataframe as spark_dataframe
+
+     def _is_spark_dataframe(df: Any) -> bool:
+         return isinstance(df, spark_dataframe.DataFrame)
+
+     def _get_spark_schema_types(df: spark_dataframe.DataFrame) -> list[str]:
+         return [str(schema_type.dataType) for schema_type in df.schema.fields]
+
+ except Exception:
+     # PySpark is an optional dependency; provide no-op fallbacks.
+     def _is_spark_dataframe(df: Any) -> bool:
+         return False
+
+     def _get_spark_schema_types(df: Any) -> list[str]:
+         return []
+
+
+ VERSION_VARIABLE_PATTERN = r"^__version__ = ['\"]([^'\"]*)['\"]"
+ VERSION_FILE_NAME = "__version__.py"
+
+
+ class TelemetryManager(TelemetryClient):
+     def __init__(
+         self, rest: Optional[SnowflakeRestful] = None, is_telemetry_enabled: bool = True
+     ):
+         """TelemetryManager class to log telemetry events."""
+         super().__init__(rest)
+         self.sc_folder_path = (
+             Path(SNOWFLAKE_DIRS.user_config_path) / "snowpark-checkpoints-telemetry"
+         )
+         self.sc_sf_path_telemetry = "/telemetry/send"
+         self.sc_flush_size = 25
+         self.sc_is_enabled = is_telemetry_enabled
+         self.sc_is_testing = self._sc_is_telemetry_testing()
+         self.sc_memory_limit = 5 * 1024 * 1024  # 5 MiB cap for locally stored events
+         self._sc_upload_local_telemetry()
+         self.sc_log_batch = []
+         self.sc_hypothesis_input_events = []
+         self.sc_version = _get_version()
+         if rest:
+             atexit.register(self._sc_close_at_exit)
+
+     def set_sc_output_path(self, path: Path) -> None:
+         """Set the output path for testing.
+
+         Args:
+             path: path to write telemetry.
+
+         """
+         get_io_file_manager().mkdir(str(path), exist_ok=True)
+         self.sc_folder_path = path
+
+     def sc_log_error(
+         self, event_name: str, parameters_info: Optional[dict] = None
+     ) -> None:
+         """Log an error telemetry event.
+
+         Args:
+             event_name (str): The name of the event.
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         """
+         if event_name is not None:
+             self._sc_log_telemetry(event_name, "error", parameters_info)
+
+     def sc_log_info(
+         self, event_name: str, parameters_info: Optional[dict] = None
+     ) -> None:
+         """Log an information telemetry event.
+
+         Args:
+             event_name (str): The name of the event.
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         """
+         if event_name is not None:
+             self._sc_log_telemetry(event_name, "info", parameters_info)
+
+     def _sc_log_telemetry(
+         self, event_name: str, event_type: str, parameters_info: Optional[dict] = None
+     ) -> dict:
+         """Log a telemetry event if telemetry is enabled.
+
+         Args:
+             event_name (str): The name of the event.
+             event_type (str): The type of the event (e.g., "error", "info").
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         Returns:
+             dict: The logged event.
+
+         """
+         if not self.sc_is_enabled:
+             return {}
+         event = _generate_event(
+             event_name, event_type, parameters_info, self.sc_version
+         )
+         self._sc_add_log_to_batch(event)
+         return event
+
+     def _sc_add_log_to_batch(self, event: dict) -> None:
+         """Add a log event to the batch. If the batch is full, send the events to the API.
+
+         Args:
+             event (dict): The event to add.
+
+         """
+         self.sc_log_batch.append(event)
+         if self.sc_is_testing:
+             self._sc_write_telemetry(self.sc_log_batch)
+             self.sc_log_batch = []
+             return
+
+         if len(self.sc_log_batch) >= self.sc_flush_size:
+             self.sc_send_batch(self.sc_log_batch)
+             self.sc_log_batch = []
+
+     def sc_send_batch(self, to_sent: list) -> bool:
+         """Send the events to the API. If there is no connection, write the events to the local folder instead.
+
+         Args:
+             to_sent (list): The batch of events to send.
+
+         Returns:
+             bool: True if the batch was sent successfully, False otherwise.
+
+         """
+         if not self.sc_is_enabled:
+             return False
+         if self._rest is None:
+             self._sc_write_telemetry(to_sent)
+             self.sc_log_batch = []
+             return False
+         if to_sent == []:
+             return False
+         body = {"logs": to_sent}
+         ret = self._rest.request(
+             self.sc_sf_path_telemetry,
+             body=body,
+             method="post",
+             client=None,
+             timeout=5,
+         )
+         if not ret.get("success"):
+             self._sc_write_telemetry(to_sent)
+             self.sc_log_batch = []
+             return False
+         return True
+
+     def _sc_write_telemetry(self, batch: list) -> None:
+         """Write telemetry events to the local folder. If the folder is full, free up space for the new events.
+
+         Args:
+             batch (list): The batch of events to write.
+
+         """
+         try:
+             get_io_file_manager().mkdir(str(self.sc_folder_path), exist_ok=True)
+             for event in batch:
+                 message = event.get("message")
+                 if message is not None:
+                     file_path = (
+                         self.sc_folder_path
+                         / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")}'
+                         f'_telemetry_{message.get("type")}.json'
+                     )
+                     json_content = self._sc_validate_folder_space(event)
+                     get_io_file_manager().write(str(file_path), json_content)
+         except Exception:
+             # Telemetry is best-effort and must never break the host application.
+             pass
+
+     def _sc_validate_folder_space(self, event: dict) -> str:
+         """Validate and manage folder space for the new telemetry events.
+
+         Args:
+             event (dict): The event to validate.
+
+         Returns:
+             str: The JSON content of the event.
+
+         """
+         json_content = json.dumps(event, indent=4, sort_keys=True)
+         new_json_file_size = len(json_content.encode("utf-8"))
+         telemetry_folder = self.sc_folder_path
+         folder_size = _get_folder_size(telemetry_folder)
+         if folder_size + new_json_file_size > self.sc_memory_limit:
+             _free_up_space(telemetry_folder, self.sc_memory_limit - new_json_file_size)
+         return json_content
+
+     def _sc_upload_local_telemetry(self) -> None:
+         """Send a request to the API to upload the local telemetry events."""
+         if not self.sc_is_enabled or self.sc_is_testing or not self._rest:
+             return
+         batch = []
+         for file in get_io_file_manager().ls(f"{self.sc_folder_path}/*.json"):
+             json_content = get_io_file_manager().read(file)
+             data_dict = json.loads(json_content)
+             batch.append(data_dict)
+         if batch == []:
+             return
+         body = {"logs": batch}
+         ret = self._rest.request(
+             self.sc_sf_path_telemetry,
+             body=body,
+             method="post",
+             client=None,
+             timeout=5,
+         )
+         if ret.get("success"):
+             for file_path in get_io_file_manager().ls(f"{self.sc_folder_path}/*.json"):
+                 file = get_io_file_manager().telemetry_path_files(file_path)
+                 file.unlink()
+
+     def _sc_is_telemetry_testing(self) -> bool:
+         is_testing = os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING") == "true"
+         if is_testing:
+             local_telemetry_path = (
+                 Path(get_io_file_manager().getcwd())
+                 / "snowpark-checkpoints-output"
+                 / "telemetry"
+             )
+             self.set_sc_output_path(local_telemetry_path)
+             self.sc_is_enabled = True
+         return is_testing
+
+     def _sc_is_telemetry_manager(self) -> bool:
+         """Check if the class is a telemetry manager.
+
+         Returns:
+             bool: True if the class is a telemetry manager, False otherwise.
+
+         """
+         return True
+
+     def sc_is_hypothesis_event_logged(self, event_name: tuple[str, int]) -> bool:
+         """Check if a Hypothesis event is logged.
+
+         Args:
+             event_name (tuple[str, int]): A tuple containing the name of the event and an integer identifier
+                 (0 for info, 1 for error).
+
+         Returns:
+             bool: True if the event is logged, False otherwise.
+
+         """
+         return event_name in self.sc_hypothesis_input_events
+
+     def _sc_close(self) -> None:
+         """Close the telemetry manager and upload collected events.
+
+         This function closes the telemetry manager, uploads any collected events,
+         and performs any necessary cleanup to ensure no data is lost.
+         """
+         atexit.unregister(self._sc_close_at_exit)
+         if self.sc_log_batch and self.sc_is_enabled and not self.sc_is_testing:
+             self.sc_send_batch(self.sc_log_batch)
+
+     def _sc_close_at_exit(self) -> None:
+         """Close the telemetry manager at exit and upload collected events.
+
+         This function ensures that the telemetry manager is closed and all collected events
+         are uploaded when the program exits, preventing data loss.
+         """
+         with suppress(Exception):
+             self._sc_close()
+
+
+ def _generate_event(
+     event_name: str,
+     event_type: str,
+     parameters_info: Optional[dict] = None,
+     sc_version: Optional[str] = None,
+ ) -> dict:
+     """Generate a telemetry event.
+
+     Args:
+         event_name (str): The name of the event.
+         event_type (str): The type of the event (e.g., "error", "info").
+         parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+         sc_version (str, optional): The version of the package. Defaults to None.
+
+     Returns:
+         dict: The generated event.
+
+     """
+     metadata = _get_metadata()
+     if sc_version is not None:
+         metadata["snowpark_checkpoints_version"] = sc_version
+     message = {
+         "event_type": event_type,
+         "type": "snowpark-checkpoints",
+         "event_name": event_name,
+         "driver_type": "PythonConnector",
+         "driver_version": SNOWFLAKE_CONNECTOR_VERSION,
+         "metadata": metadata,
+         "data": json.dumps(parameters_info or {}),
+     }
+     timestamp = time_util.get_time_millis()
+     event_base = {"message": message, "timestamp": str(timestamp)}
+
+     return event_base
+
+
+ def _get_metadata() -> dict:
+     """Get metadata for telemetry events.
+
+     Returns:
+         dict: The metadata including OS version, Python version, and device ID.
+
+     """
+     return {
+         "os_version": platform,
+         "python_version": python_version(),
+         "snowpark_version": ".".join(str(x) for x in SNOWPARK_VERSION if x is not None),
+         "device_id": _get_unique_id(),
+     }
+
+
+ def _get_version() -> Optional[str]:
+     """Get the version of the package.
+
+     Returns:
+         str: The version of the package, or None if it cannot be determined.
+
+     """
+     try:
+         directory_levels_up = 1
+         project_root = Path(__file__).resolve().parents[directory_levels_up]
+         version_file_path = project_root / VERSION_FILE_NAME
+         content = get_io_file_manager().read(str(version_file_path))
+         version_match = re.search(VERSION_VARIABLE_PATTERN, content, re.MULTILINE)
+         if version_match:
+             return version_match.group(1)
+         return None
+     except Exception:
+         return None
+
+
+ def _get_folder_size(folder_path: Path) -> int:
+     """Get the size of a folder. Only considers JSON files.
+
+     Args:
+         folder_path (Path): The path to the folder.
+
+     Returns:
+         int: The size of the folder in bytes.
+
+     """
+     sum_size = 0
+     for f in get_io_file_manager().ls(f"{folder_path}/*.json"):
+         sum_size += get_io_file_manager().telemetry_path_files(f).stat().st_size
+     return sum_size
+
+
+ def _free_up_space(folder_path: Path, max_size: int) -> None:
+     """Free up space in a folder by deleting the oldest files. Only considers JSON files.
+
+     Args:
+         folder_path (Path): The path to the folder.
+         max_size (int): The maximum allowed size of the folder in bytes.
+
+     """
+     files = sorted(
+         get_io_file_manager().ls(f"{folder_path}/*.json"),
+         key=lambda f: f.stat().st_mtime,
+     )
+     current_size = _get_folder_size(folder_path)
+     for file_path in files:
+         file = get_io_file_manager().telemetry_path_files(file_path)
+         if current_size <= max_size:
+             break
+         current_size -= file.stat().st_size
+         file.unlink()
+
+
+ def _get_unique_id() -> str:
+     """Get a unique device ID. The ID is generated based on the hashed MAC address.
+
+     Returns:
+         str: The hashed device ID.
+
+     """
+     node_id_str = str(getnode())
+     hashed_id = hashlib.sha256(node_id_str.encode()).hexdigest()
+     return hashed_id
+
+
+ def get_telemetry_manager() -> TelemetryManager:
+     """Get the telemetry manager.
+
+     Returns:
+         TelemetryManager: The telemetry manager.
+
+     """
+     try:
+         connection = Session.builder.getOrCreate().connection
+         if not hasattr(connection._telemetry, "_sc_is_telemetry_manager"):
+             connection._telemetry = TelemetryManager(
+                 connection._rest, connection.telemetry_enabled
+             )
+         return connection._telemetry
+     except Exception:
+         # Without an active session, fall back to a local manager that flushes
+         # every event to disk immediately.
+         telemetry_manager = TelemetryManager(None, is_telemetry_enabled=True)
+         telemetry_manager.sc_flush_size = 1
+         return telemetry_manager
+
+
+ def get_snowflake_schema_types(df: snowpark_dataframe.DataFrame) -> list[str]:
+     """Extract the data types of the schema fields from a Snowflake DataFrame.
+
+     Args:
+         df (snowpark_dataframe.DataFrame): The Snowflake DataFrame.
+
+     Returns:
+         list[str]: A list of data type names of the schema fields.
+
+     """
+     return [str(schema_type.datatype) for schema_type in df.schema.fields]
+
+
+ def _is_snowpark_dataframe(df: Any) -> bool:
+     """Check if the given dataframe is a Snowpark dataframe.
+
+     Args:
+         df: The dataframe to check.
+
+     Returns:
+         bool: True if the dataframe is a Snowpark dataframe, False otherwise.
+
+     """
+     return isinstance(df, snowpark_dataframe.DataFrame)
+
+
+ def get_load_json(json_schema: str) -> dict:
+     """Load and parse a JSON schema file.
+
+     Args:
+         json_schema (str): The path to the JSON schema file.
+
+     Returns:
+         dict: The parsed JSON content.
+
+     Raises:
+         ValueError: If there is an error reading or parsing the JSON file.
+
+     """
+     try:
+         file_content = get_io_file_manager().read(json_schema, encoding="utf-8")
+         return json.loads(file_content)
+     except (OSError, json.JSONDecodeError) as e:
+         raise ValueError(f"Error reading JSON schema file: {e}") from None
+
+
+ def _is_in_stored_procedure() -> bool:
+     """Check if the code is running in a stored procedure.
+
+     Returns:
+         bool: True if the code is running in a stored procedure, False otherwise.
+
+     """
+     return CONNECTOR_PLATFORM == "XP"
+
+
+ def extract_parameters(
+     func: Callable, args: tuple, kwargs: dict, params_list: Optional[list[str]]
+ ) -> dict:
+     """Extract parameters from the function arguments.
+
+     Args:
+         func (Callable): The function being decorated.
+         args (tuple): The positional arguments passed to the function.
+         kwargs (dict): The keyword arguments passed to the function.
+         params_list (list[str]): The list of parameters to extract.
+
+     Returns:
+         dict: A dictionary of extracted parameters.
+
+     """
+     parameters = inspect.signature(func).parameters
+     param_data = {}
+     if params_list:
+         for param in params_list:
+             if len(args) > 0:
+                 # Positional call: look the parameter up by its position.
+                 index = list(parameters.keys()).index(param)
+                 param_data[param] = args[index]
+             elif kwargs.get(param):
+                 param_data[param] = kwargs[param]
+     return param_data
+
+
+ def check_dataframe_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for checking dataframe schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
+     try:
+         telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         schema_types = []
+         for schema_type in pandera_schema.columns.values():
+             if schema_type.dtype is not None:
+                 schema_types.append(str(schema_type.dtype))
+         if schema_types:
+             telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
+     except Exception:
+         if param_data.get(STATUS_KEY):
+             telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         if pandera_schema:
+             schema_types = []
+             for schema_type in pandera_schema.columns.values():
+                 if schema_type.dtype is not None:
+                     schema_types.append(str(schema_type.dtype))
+             if schema_types:
+                 telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def check_output_or_input_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for checking output or input schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     try:
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         schema_types = []
+         for schema_type in pandera_schema.columns.values():
+             if schema_type.dtype is not None:
+                 schema_types.append(str(schema_type.dtype))
+         if schema_types:
+             telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
+     except Exception:
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def collect_dataframe_checkpoint_mode_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for collecting dataframe checkpoint mode schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
+     try:
+         schema_types = param_data.get("column_type_dict")
+         telemetry_data[SCHEMA_TYPES_KEY] = [
+             schema_types[schema_type].dataType.typeName()
+             for schema_type in schema_types
+         ]
+         return DATAFRAME_COLLECTION_SCHEMA, telemetry_data
+     except Exception:
+         return DATAFRAME_COLLECTION_ERROR, telemetry_data
+
+
+ def collect_dataframe_checkpoint_mode_dataframe_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for collecting dataframe checkpoint mode dataframe.
+
+     This function processes telemetry data for a dataframe checkpoint mode event. It updates the telemetry data
+     with the mode and schema types of the Spark DataFrame being collected.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary to be updated.
+         param_data (dict): The parameter data dictionary containing the DataFrame information.
+
+     Returns:
+         tuple: A tuple containing the event name and the updated telemetry data dictionary.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
+     try:
+         if _is_spark_dataframe(param_data.get(DF_PARAM_NAME)):
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(DF_PARAM_NAME)
+             )
+         return DATAFRAME_COLLECTION_DF, telemetry_data
+     except Exception:
+         return DATAFRAME_COLLECTION_ERROR, telemetry_data
+
+
+ def assert_return_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
+     """Handle telemetry event for asserting return values.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     if param_data.get(STATUS_KEY) is not None:
+         telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
+     try:
+         if _is_snowpark_dataframe(
+             param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+         ) and _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+                 param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+             )
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(SPARK_RESULTS_PARAM_NAME)
+             )
+             return DATAFRAME_VALIDATOR_MIRROR, telemetry_data
+         else:
+             return VALUE_VALIDATOR_MIRROR, telemetry_data
+     except Exception:
+         if _is_snowpark_dataframe(param_data.get(SNOWPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+                 param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+             )
+         if _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(SPARK_RESULTS_PARAM_NAME)
+             )
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def dataframe_strategy_event(
+     telemetry_data: dict, param_data: dict, telemetry_m: TelemetryManager
+ ) -> tuple[Optional[str], Optional[dict]]:
+     """Handle telemetry event for dataframe strategy.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+         telemetry_m (TelemetryManager): The telemetry manager.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     try:
+         test_function_name = inspect.stack()[2].function
+         is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 0))
+         if not is_logged:
+             schema_param = param_data.get(DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME)
+             if isinstance(schema_param, str):
+                 json_data = get_load_json(schema_param)["custom_data"]["columns"]
+                 telemetry_data[SCHEMA_TYPES_KEY] = [
+                     column["type"] for column in json_data
+                 ]
+             else:
+                 schema_types = []
+                 for schema_type in schema_param.columns.values():
+                     if schema_type.dtype is not None:
+                         schema_types.append(str(schema_type.dtype))
+                 if schema_types:
+                     telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+             telemetry_m.sc_hypothesis_input_events.append((test_function_name, 0))
+             if None in telemetry_data[SCHEMA_TYPES_KEY]:
+                 telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
+             else:
+                 telemetry_m.sc_log_info(HYPOTHESIS_INPUT_SCHEMA, telemetry_data)
+             telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
+         return None, None
+     except Exception:
+         test_function_name = inspect.stack()[2].function
+         is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 1))
+         if not is_logged:
+             # Record the event under the error identifier (1) that is checked above.
+             telemetry_m.sc_hypothesis_input_events.append((test_function_name, 1))
+             telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
+             telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
+         return None, None
+
+
+ def compare_data_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
+     """Handle telemetry event for comparing data.
+
+     This function processes telemetry data for a data comparison event. It updates the telemetry data
+     with the mode, status, and schema types of the Snowflake DataFrame being compared.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary to be updated.
+         param_data (dict): The parameter data dictionary containing the DataFrame and status information.
+
+     Returns:
+         tuple: A tuple containing the event name and the updated telemetry data dictionary.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
+     telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
+     try:
+         telemetry_data[SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+             param_data.get("df")
+         )
+         return DATAFRAME_VALIDATOR_DF, telemetry_data
+     except Exception:
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def handle_result(
+     func_name: str,
+     result: Any,
+     param_data: dict,
+     multiple_return: bool,
+     telemetry_m: TelemetryManager,
+     return_indexes: Optional[list[tuple[str, int]]] = None,
+ ) -> tuple[Optional[str], Optional[dict]]:
+     """Handle the result of the function and collect telemetry data.
+
+     Args:
+         func_name (str): The name of the function.
+         result: The result of the function.
+         param_data (dict): The extracted parameters.
+         multiple_return (bool): Whether the function returns multiple values.
+         telemetry_m (TelemetryManager): The telemetry manager.
+         return_indexes (list[tuple[str, int]]): The list of return values to report. Defaults to None.
+
+     Returns:
+         tuple: A tuple containing the event name (str) and telemetry data (dict).
+
+     """
+     if result is not None and return_indexes is not None:
+         if multiple_return:
+             for name, index in return_indexes:
+                 param_data[name] = result[index]
+         else:
+             param_data[return_indexes[0][0]] = result[return_indexes[0][1]]
+
+     telemetry_data = {
+         FUNCTION_KEY: func_name,
+     }
+
+     telemetry_event = None
+     data = None
+     if func_name == "_check_dataframe_schema":
+         telemetry_event, data = check_dataframe_schema_event(telemetry_data, param_data)
+     elif func_name in ["check_output_schema", "check_input_schema"]:
+         telemetry_event, data = check_output_or_input_schema_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_compare_data":
+         telemetry_event, data = compare_data_event(telemetry_data, param_data)
+     elif func_name == "_collect_dataframe_checkpoint_mode_schema":
+         telemetry_event, data = collect_dataframe_checkpoint_mode_schema_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_collect_dataframe_checkpoint_mode_dataframe":
+         telemetry_event, data = collect_dataframe_checkpoint_mode_dataframe_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_assert_return":
+         telemetry_event, data = assert_return_event(telemetry_data, param_data)
+     elif func_name == "dataframe_strategy":
+         telemetry_event, data = dataframe_strategy_event(
+             telemetry_data, param_data, telemetry_m
+         )
+     return telemetry_event, data
+
+
+ fn = TypeVar("fn", bound=Callable)
+
+
+ def report_telemetry(
+     params_list: Optional[list[str]] = None,
+     return_indexes: Optional[list[tuple[str, int]]] = None,
+     multiple_return: bool = False,
+ ) -> Callable[[fn], fn]:
+     """Report telemetry events for a function.
+
+     Args:
+         params_list (list[str], optional): The list of parameters to report. Defaults to None.
+         return_indexes (list[tuple[str, int]], optional): The list of return values to report. Defaults to None.
+         multiple_return (bool, optional): Whether the function returns multiple values. Defaults to False.
+
+     Returns:
+         Callable[[fn], fn]: The decorator function.
+
+     """
+
+     def report_telemetry_decorator(func):
+         func_name = func.__name__
+
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             func_exception = None
+             result = None
+             try:
+                 result = func(*args, **kwargs)
+             except Exception as err:
+                 func_exception = err
+
+             if (
+                 os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED") == "false"
+                 or _is_in_stored_procedure()
+             ):
+                 return result
+             telemetry_event = None
+             data = None
+             telemetry_m = None
+             try:
+                 param_data = extract_parameters(func, args, kwargs, params_list)
+                 telemetry_m = get_telemetry_manager()
+                 telemetry_event, data = handle_result(
+                     func_name,
+                     result,
+                     param_data,
+                     multiple_return,
+                     telemetry_m,
+                     return_indexes,
+                 )
+             except Exception:
+                 pass
+             finally:
+                 if func_exception is not None:
+                     if telemetry_m is not None:
+                         telemetry_m.sc_log_error(telemetry_event, data)
+                     raise func_exception
+                 if telemetry_m is not None:
+                     telemetry_m.sc_log_info(telemetry_event, data)
+
+             return result
+
+         return wrapper
+
+     return report_telemetry_decorator
+
+
+ # Constants for telemetry
+ DATAFRAME_COLLECTION_SCHEMA = "DataFrame_Collection_Schema"
+ DATAFRAME_COLLECTION_DF = "DataFrame_Collection_DF"
+ DATAFRAME_VALIDATOR_MIRROR = "DataFrame_Validator_Mirror"
+ VALUE_VALIDATOR_MIRROR = "Value_Validator_Mirror"
+ DATAFRAME_VALIDATOR_SCHEMA = "DataFrame_Validator_Schema"
+ DATAFRAME_VALIDATOR_DF = "DataFrame_Validator_DF"
+ HYPOTHESIS_INPUT_SCHEMA = "Hypothesis_Input_Schema"
+ DATAFRAME_COLLECTION_ERROR = "DataFrame_Collection_Error"
+ DATAFRAME_VALIDATOR_ERROR = "DataFrame_Validator_Error"
+ HYPOTHESIS_INPUT_SCHEMA_ERROR = "Hypothesis_Input_Schema_Error"
+
+ FUNCTION_KEY = "function"
+ STATUS_KEY = "status"
+ SCHEMA_TYPES_KEY = "schema_types"
+ ERROR_KEY = "error"
+ MODE_KEY = "mode"
+ SNOWFLAKE_SCHEMA_TYPES_KEY = "snowflake_schema_types"
+ SPARK_SCHEMA_TYPES_KEY = "spark_schema_types"
+
+ DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME = "schema"
+ PANDERA_SCHEMA_PARAM_NAME = "pandera_schema"
+ SNOWPARK_RESULTS_PARAM_NAME = "snowpark_results"
+ SPARK_RESULTS_PARAM_NAME = "spark_results"
+ DF_PARAM_NAME = "df"
+
+
+ class CheckpointMode(IntEnum):
+     SCHEMA = 1
+     DATAFRAME = 2
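
Taken together: report_telemetry is the module's public entry point. It wraps a function, extracts the parameters named in params_list (and, via return_indexes, selected return values), dispatches to one of the *_event handlers through handle_result, and logs the outcome as an info or error event on the shared TelemetryManager, which batches up to sc_flush_size (25) events and falls back to writing JSON files under the user config directory when no connection is available. A hypothetical usage sketch (the name _compare_data and the parameter names mirror the dispatch table and key constants above; the body is illustrative, not taken from the package):

    from snowflake.snowpark_checkpoints.utils.telemetry import report_telemetry

    # "_compare_data" is one of the names handle_result() dispatches on;
    # "df" and "status" match DF_PARAM_NAME and STATUS_KEY.
    @report_telemetry(params_list=["df", "status"])
    def _compare_data(df, status):
        ...  # actual comparison logic omitted
        return status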