snowpark-checkpoints-validators 0.1.0rc3__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. snowflake/snowpark_checkpoints/__init__.py +34 -0
  2. snowflake/snowpark_checkpoints/checkpoint.py +482 -0
  3. snowflake/snowpark_checkpoints/errors.py +60 -0
  4. snowflake/snowpark_checkpoints/job_context.py +85 -0
  5. snowflake/snowpark_checkpoints/singleton.py +23 -0
  6. snowflake/snowpark_checkpoints/snowpark_sampler.py +99 -0
  7. snowflake/snowpark_checkpoints/spark_migration.py +222 -0
  8. snowflake/snowpark_checkpoints/utils/__init__.py +14 -0
  9. snowflake/snowpark_checkpoints/utils/checkpoint_logger.py +52 -0
  10. snowflake/snowpark_checkpoints/utils/constants.py +134 -0
  11. snowflake/snowpark_checkpoints/utils/extra_config.py +84 -0
  12. snowflake/snowpark_checkpoints/utils/pandera_check_manager.py +358 -0
  13. snowflake/snowpark_checkpoints/utils/supported_types.py +65 -0
  14. snowflake/snowpark_checkpoints/utils/telemetry.py +900 -0
  15. snowflake/snowpark_checkpoints/utils/utils_checks.py +372 -0
  16. snowflake/snowpark_checkpoints/validation_result_metadata.py +116 -0
  17. snowflake/snowpark_checkpoints/validation_results.py +49 -0
  18. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.1.dist-info}/METADATA +4 -6
  19. snowpark_checkpoints_validators-0.1.1.dist-info/RECORD +21 -0
  20. snowpark_checkpoints_validators-0.1.0rc3.dist-info/RECORD +0 -4
  21. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.1.dist-info}/WHEEL +0 -0
  22. {snowpark_checkpoints_validators-0.1.0rc3.dist-info → snowpark_checkpoints_validators-0.1.1.dist-info}/licenses/LICENSE +0 -0
snowflake/snowpark_checkpoints/utils/telemetry.py
@@ -0,0 +1,900 @@
+ # Copyright 2025 Snowflake Inc.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import atexit
+ import datetime
+ import hashlib
+ import inspect
+ import json
+ import os
+ import re
+
+ from contextlib import suppress
+ from enum import IntEnum
+ from functools import wraps
+ from pathlib import Path
+ from platform import python_version
+ from sys import platform
+ from typing import Any, Callable, Optional, TypeVar
+ from uuid import getnode
+
+ from snowflake.connector import (
+     SNOWFLAKE_CONNECTOR_VERSION,
+     time_util,
+ )
+ from snowflake.connector.constants import DIRS as SNOWFLAKE_DIRS
+ from snowflake.connector.network import SnowflakeRestful
+ from snowflake.connector.telemetry import TelemetryClient
+ from snowflake.snowpark import VERSION as SNOWPARK_VERSION
+ from snowflake.snowpark import dataframe as snowpark_dataframe
+ from snowflake.snowpark.session import Session
+
+
+ try:
+     from pyspark.sql import dataframe as spark_dataframe
+
+     def _is_spark_dataframe(df: Any) -> bool:
+         return isinstance(df, spark_dataframe.DataFrame)
+
+     def _get_spark_schema_types(df: spark_dataframe.DataFrame) -> list[str]:
+         return [str(schema_type.dataType) for schema_type in df.schema.fields]
+
+ except Exception:
+     # PySpark is optional; fall back to no-op checks when it is unavailable.
+     def _is_spark_dataframe(df: Any) -> bool:
+         return False
+
+     def _get_spark_schema_types(df: Any) -> list[str]:
+         return []
+
+
+ VERSION_VARIABLE_PATTERN = r"^__version__ = ['\"]([^'\"]*)['\"]"
+ VERSION_FILE_NAME = "__version__.py"
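+ # For illustration: a version file containing the line
+ #     __version__ = "0.1.1"
+ # matches VERSION_VARIABLE_PATTERN, and the captured group yields "0.1.1".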
+
+
+ class TelemetryManager(TelemetryClient):
+     def __init__(
+         self, rest: Optional[SnowflakeRestful] = None, is_telemetry_enabled: bool = True
+     ):
+         """Initialize the TelemetryManager, which logs telemetry events."""
+         super().__init__(rest)
+         self.sc_folder_path = (
+             Path(SNOWFLAKE_DIRS.user_config_path) / "snowpark-checkpoints-telemetry"
+         )
+         self.sc_sf_path_telemetry = "/telemetry/send"
+         self.sc_flush_size = 25
+         self.sc_is_enabled = is_telemetry_enabled
+         self.sc_is_testing = self._sc_is_telemetry_testing()
+         self.sc_memory_limit = 5 * 1024 * 1024  # 5 MB cap for local telemetry files
+         self._sc_upload_local_telemetry()
+         self.sc_log_batch = []
+         self.sc_hypothesis_input_events = []
+         self.sc_version = _get_version()
+         if rest:
+             atexit.register(self._sc_close_at_exit)
+
+     def set_sc_output_path(self, path: Path) -> None:
+         """Set the output path for testing.
+
+         Args:
+             path: Path to write telemetry.
+
+         """
+         os.makedirs(path, exist_ok=True)
+         self.sc_folder_path = path
+
+     def sc_log_error(
+         self, event_name: str, parameters_info: Optional[dict] = None
+     ) -> None:
+         """Log an error telemetry event.
+
+         Args:
+             event_name (str): The name of the event.
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         """
+         if event_name is not None:
+             self._sc_log_telemetry(event_name, "error", parameters_info)
+
+     def sc_log_info(
+         self, event_name: str, parameters_info: Optional[dict] = None
+     ) -> None:
+         """Log an information telemetry event.
+
+         Args:
+             event_name (str): The name of the event.
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         """
+         if event_name is not None:
+             self._sc_log_telemetry(event_name, "info", parameters_info)
+
+     def _sc_log_telemetry(
+         self, event_name: str, event_type: str, parameters_info: Optional[dict] = None
+     ) -> dict:
+         """Log a telemetry event if telemetry is enabled.
+
+         Args:
+             event_name (str): The name of the event.
+             event_type (str): The type of the event (e.g., "error", "info").
+             parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+
+         Returns:
+             dict: The logged event.
+
+         """
+         if not self.sc_is_enabled:
+             return {}
+         event = _generate_event(
+             event_name, event_type, parameters_info, self.sc_version
+         )
+         self._sc_add_log_to_batch(event)
+         return event
+
+     def _sc_add_log_to_batch(self, event: dict) -> None:
+         """Add a log event to the batch. If the batch is full, send the events to the API.
+
+         Args:
+             event (dict): The event to add.
+
+         """
+         self.sc_log_batch.append(event)
+         if self.sc_is_testing:
+             self._sc_write_telemetry(self.sc_log_batch)
+             self.sc_log_batch = []
+             return
+
+         if len(self.sc_log_batch) >= self.sc_flush_size:
+             self.sc_send_batch(self.sc_log_batch)
+             self.sc_log_batch = []
+
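+     # Batching behavior of _sc_add_log_to_batch above, for illustration: with
+     # sc_flush_size = 25, the call that brings the batch to 25 events triggers
+     # sc_send_batch and clears the in-memory list; in testing mode every event
+     # is written straight to the local folder instead.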
+     def sc_send_batch(self, to_sent: list) -> bool:
+         """Send a request to the API to upload the events. If there is no connection, write the events to the local folder.
+
+         Args:
+             to_sent (list): The batch of events to send.
+
+         Returns:
+             bool: True if the batch was sent successfully, False otherwise.
+
+         """
+         if not self.sc_is_enabled:
+             return False
+         if self._rest is None:
+             self._sc_write_telemetry(to_sent)
+             self.sc_log_batch = []
+             return False
+         if to_sent == []:
+             return False
+         body = {"logs": to_sent}
+         ret = self._rest.request(
+             self.sc_sf_path_telemetry,
+             body=body,
+             method="post",
+             client=None,
+             timeout=5,
+         )
+         if not ret.get("success"):
+             self._sc_write_telemetry(to_sent)
+             self.sc_log_batch = []
+             return False
+         return True
+
+     def _sc_write_telemetry(self, batch: list) -> None:
+         """Write telemetry events to the local folder. If the folder is full, free up space for the new events.
+
+         Args:
+             batch (list): The batch of events to write.
+
+         """
+         try:
+             os.makedirs(self.sc_folder_path, exist_ok=True)
+             for event in batch:
+                 message = event.get("message")
+                 if message is not None:
+                     file_path = (
+                         self.sc_folder_path
+                         / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")}'
+                         f'_telemetry_{message.get("type")}.json'
+                     )
+                     json_content = self._sc_validate_folder_space(event)
+                     with open(file_path, "w") as json_file:
+                         json_file.write(json_content)
+         except Exception:
+             pass
+
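+     # Illustrative local file name produced by _sc_write_telemetry above:
+     #     2025-01-01_12-00-00-000000_telemetry_info.json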
+     def _sc_validate_folder_space(self, event: dict) -> str:
+         """Validate and manage folder space for the new telemetry events.
+
+         Args:
+             event (dict): The event to validate.
+
+         Returns:
+             str: The JSON content of the event.
+
+         """
+         json_content = json.dumps(event, indent=4, sort_keys=True)
+         new_json_file_size = len(json_content.encode("utf-8"))
+         telemetry_folder = self.sc_folder_path
+         folder_size = _get_folder_size(telemetry_folder)
+         if folder_size + new_json_file_size > self.sc_memory_limit:
+             _free_up_space(telemetry_folder, self.sc_memory_limit - new_json_file_size)
+         return json_content
+
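+     # For illustration: with the 5 MB cap, a 4.9 MB folder and a 200 KB event,
+     # _sc_validate_folder_space above calls _free_up_space to evict the oldest
+     # files until the folder fits within 5 MB - 200 KB, so the new event can be
+     # written without exceeding the cap.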
+     def _sc_upload_local_telemetry(self) -> None:
+         """Send a request to the API to upload the local telemetry events."""
+         if not self.sc_is_enabled or self.sc_is_testing or not self._rest:
+             return
+         batch = []
+         for file in self.sc_folder_path.glob("*.json"):
+             with open(file) as json_file:
+                 data_dict = json.load(json_file)
+                 batch.append(data_dict)
+         if batch == []:
+             return
+         body = {"logs": batch}
+         ret = self._rest.request(
+             self.sc_sf_path_telemetry,
+             body=body,
+             method="post",
+             client=None,
+             timeout=5,
+         )
+         if ret.get("success"):
+             for file in self.sc_folder_path.glob("*.json"):
+                 file.unlink()
+
+     def _sc_is_telemetry_testing(self) -> bool:
+         is_testing = os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING") == "true"
+         if is_testing:
+             local_telemetry_path = (
+                 Path(os.getcwd()) / "snowpark-checkpoints-output" / "telemetry"
+             )
+             self.set_sc_output_path(local_telemetry_path)
+             self.sc_is_enabled = True
+         return is_testing
+
+     def _sc_is_telemetry_manager(self) -> bool:
+         """Check whether this instance is a TelemetryManager.
+
+         Returns:
+             bool: True if the instance is a TelemetryManager, False otherwise.
+
+         """
+         return True
+
+     def sc_is_hypothesis_event_logged(self, event_name: tuple[str, int]) -> bool:
+         """Check if a Hypothesis event is logged.
+
+         Args:
+             event_name (tuple[str, int]): A tuple containing the name of the event and an integer identifier
+             (0 for info, 1 for error).
+
+         Returns:
+             bool: True if the event is logged, False otherwise.
+
+         """
+         return event_name in self.sc_hypothesis_input_events
+
+     def _sc_close(self) -> None:
+         """Close the telemetry manager and upload collected events.
+
+         This function closes the telemetry manager, uploads any collected events,
+         and performs any necessary cleanup to ensure no data is lost.
+         """
+         atexit.unregister(self._sc_close_at_exit)
+         if self.sc_log_batch and self.sc_is_enabled and not self.sc_is_testing:
+             self.sc_send_batch(self.sc_log_batch)
+
+     def _sc_close_at_exit(self) -> None:
+         """Close the telemetry manager at exit and upload collected events.
+
+         This function ensures that the telemetry manager is closed and all collected events
+         are uploaded when the program exits, preventing data loss.
+         """
+         with suppress(Exception):
+             self._sc_close()
+
+
+ def _generate_event(
+     event_name: str,
+     event_type: str,
+     parameters_info: Optional[dict] = None,
+     sc_version: Optional[str] = None,
+ ) -> dict:
+     """Generate a telemetry event.
+
+     Args:
+         event_name (str): The name of the event.
+         event_type (str): The type of the event (e.g., "error", "info").
+         parameters_info (dict, optional): Additional parameters for the event. Defaults to None.
+         sc_version (str, optional): The version of the package. Defaults to None.
+
+     Returns:
+         dict: The generated event.
+
+     """
+     metadata = _get_metadata()
+     if sc_version is not None:
+         metadata["snowpark_checkpoints_version"] = sc_version
+     message = {
+         "type": event_type,
+         "event_name": event_name,
+         "driver_type": "PythonConnector",
+         "driver_version": SNOWFLAKE_CONNECTOR_VERSION,
+         "source": "snowpark-checkpoints",
+         "metadata": metadata,
+         "data": json.dumps(parameters_info or {}),
+     }
+     timestamp = time_util.get_time_millis()
+     event_base = {"message": message, "timestamp": str(timestamp)}
+
+     return event_base
+
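+ # The event returned by _generate_event above has this shape, abridged for
+ # illustration:
+ # {
+ #     "message": {"type": "info", "event_name": "...", "metadata": {...}, "data": "{}"},
+ #     "timestamp": "1700000000000",
+ # }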
+
+ def _get_metadata() -> dict:
+     """Get metadata for telemetry events.
+
+     Returns:
+         dict: The metadata including OS version, Python version, and device ID.
+
+     """
+     return {
+         "os_version": platform,
+         "python_version": python_version(),
+         "snowpark_version": ".".join(str(x) for x in SNOWPARK_VERSION if x is not None),
+         "device_id": _get_unique_id(),
+     }
+
+
+ def _get_version() -> Optional[str]:
+     """Get the version of the package.
+
+     Returns:
+         Optional[str]: The version of the package, or None if it cannot be determined.
+
+     """
+     try:
+         directory_levels_up = 4
+         project_root = Path(__file__).resolve().parents[directory_levels_up]
+         version_file_path = project_root / VERSION_FILE_NAME
+         with open(version_file_path) as file:
+             content = file.read()
+         version_match = re.search(VERSION_VARIABLE_PATTERN, content, re.MULTILINE)
+         if version_match:
+             return version_match.group(1)
+         return None
+     except Exception:
+         return None
+
+
+ def _get_folder_size(folder_path: Path) -> int:
+     """Get the size of a folder. Only considers JSON files.
+
+     Args:
+         folder_path (Path): The path to the folder.
+
+     Returns:
+         int: The size of the folder in bytes.
+
+     """
+     return sum(f.stat().st_size for f in folder_path.glob("*.json") if f.is_file())
+
+
+ def _free_up_space(folder_path: Path, max_size: int) -> None:
+     """Free up space in a folder by deleting the oldest files. Only considers JSON files.
+
+     Args:
+         folder_path (Path): The path to the folder.
+         max_size (int): The maximum allowed size of the folder in bytes.
+
+     """
+     files = sorted(folder_path.glob("*.json"), key=lambda f: f.stat().st_mtime)
+     current_size = _get_folder_size(folder_path)
+     for file in files:
+         if current_size <= max_size:
+             break
+         current_size -= file.stat().st_size
+         file.unlink()
+
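+ # Illustrative eviction: with three 2 MB files (6 MB total) and max_size of
+ # 5 MB, _free_up_space deletes only the single oldest file, leaving 4 MB.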
+
+ def _get_unique_id() -> str:
+     """Get a unique device ID. The ID is generated based on the hashed MAC address.
+
+     Returns:
+         str: The hashed device ID.
+
+     """
+     node_id_str = str(getnode())
+     hashed_id = hashlib.sha256(node_id_str.encode()).hexdigest()
+     return hashed_id
+
+
+ def get_telemetry_manager() -> TelemetryManager:
+     """Get the telemetry manager.
+
+     Returns:
+         TelemetryManager: The telemetry manager.
+
+     """
+     try:
+         connection = Session.builder.getOrCreate().connection
+         if not hasattr(connection._telemetry, "_sc_is_telemetry_manager"):
+             connection._telemetry = TelemetryManager(
+                 connection._rest, connection.telemetry_enabled
+             )
+         return connection._telemetry
+     except Exception:
+         telemetry_manager = TelemetryManager(None, is_telemetry_enabled=True)
+         telemetry_manager.sc_flush_size = 1
+         return telemetry_manager
+
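+ # Usage sketch for get_telemetry_manager above (illustrative): it works with or
+ # without an active Snowpark session, falling back to a local-only manager.
+ #     telemetry_m = get_telemetry_manager()
+ #     telemetry_m.sc_log_info("My_Event", {"status": True})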
+
+ def get_snowflake_schema_types(df: snowpark_dataframe.DataFrame) -> list[str]:
+     """Extract the data types of the schema fields from a Snowflake DataFrame.
+
+     Args:
+         df (snowpark_dataframe.DataFrame): The Snowflake DataFrame.
+
+     Returns:
+         list[str]: A list of data type names of the schema fields.
+
+     """
+     return [str(schema_type.datatype) for schema_type in df.schema.fields]
+
+
+ def _is_snowpark_dataframe(df: Any) -> bool:
+     """Check if the given dataframe is a Snowpark dataframe.
+
+     Args:
+         df: The dataframe to check.
+
+     Returns:
+         bool: True if the dataframe is a Snowpark dataframe, False otherwise.
+
+     """
+     return isinstance(df, snowpark_dataframe.DataFrame)
+
+
+ def get_load_json(json_schema: str) -> dict:
+     """Load and parse a JSON schema file.
+
+     Args:
+         json_schema (str): The path to the JSON schema file.
+
+     Returns:
+         dict: The parsed JSON content.
+
+     Raises:
+         ValueError: If there is an error reading or parsing the JSON file.
+
+     """
+     try:
+         with open(json_schema, encoding="utf-8") as file:
+             return json.load(file)
+     except (OSError, json.JSONDecodeError) as e:
+         raise ValueError(f"Error reading JSON schema file: {e}") from None
+
+
+ def extract_parameters(
+     func: Callable, args: tuple, kwargs: dict, params_list: Optional[list[str]]
+ ) -> dict:
+     """Extract parameters from the function arguments.
+
+     Args:
+         func (Callable): The function being decorated.
+         args (tuple): The positional arguments passed to the function.
+         kwargs (dict): The keyword arguments passed to the function.
+         params_list (list[str]): The list of parameters to extract.
+
+     Returns:
+         dict: A dictionary of extracted parameters.
+
+     """
+     parameters = inspect.signature(func).parameters
+     param_data = {}
+     if params_list:
+         for param in params_list:
+             if len(args) > 0:
+                 index = list(parameters.keys()).index(param)
+                 param_data[param] = args[index]
+             elif kwargs.get(param):
+                 # Use .get() so a parameter absent from kwargs is skipped
+                 # instead of raising a KeyError.
+                 param_data[param] = kwargs[param]
+     return param_data
+
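+ # For illustration, given def f(a, b) and params_list=["b"]:
+ #     extract_parameters(f, (1, 2), {}, ["b"])    ->  {"b": 2}
+ #     extract_parameters(f, (), {"b": 2}, ["b"])  ->  {"b": 2}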
+
+ def check_dataframe_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for checking dataframe schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
+     try:
+         telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         schema_types = []
+         for schema_type in pandera_schema.columns.values():
+             if schema_type.dtype is not None:
+                 schema_types.append(str(schema_type.dtype))
+         if schema_types:
+             telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
+     except Exception:
+         # Best-effort fallback: collect whatever fields are still available.
+         if param_data.get(STATUS_KEY):
+             telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         if pandera_schema:
+             schema_types = []
+             for schema_type in pandera_schema.columns.values():
+                 if schema_type.dtype is not None:
+                     schema_types.append(str(schema_type.dtype))
+             if schema_types:
+                 telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def check_output_or_input_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for checking output or input schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     try:
+         pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
+         schema_types = []
+         for schema_type in pandera_schema.columns.values():
+             if schema_type.dtype is not None:
+                 schema_types.append(str(schema_type.dtype))
+         if schema_types:
+             telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+         return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
+     except Exception:
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def collect_dataframe_checkpoint_mode_schema_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for collecting dataframe checkpoint mode schema.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
+     try:
+         schema_types = param_data.get("column_type_dict")
+         telemetry_data[SCHEMA_TYPES_KEY] = [
+             schema_types[schema_type].dataType.typeName()
+             for schema_type in schema_types
+         ]
+         return DATAFRAME_COLLECTION_SCHEMA, telemetry_data
+     except Exception:
+         return DATAFRAME_COLLECTION_ERROR, telemetry_data
+
+
+ def collect_dataframe_checkpoint_mode_dataframe_event(
+     telemetry_data: dict, param_data: dict
+ ) -> tuple[str, dict]:
+     """Handle telemetry event for collecting dataframe checkpoint mode dataframe.
+
+     This function processes telemetry data for a dataframe checkpoint mode event. It updates the telemetry data
+     with the mode and schema types of the Spark DataFrame being collected.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary to be updated.
+         param_data (dict): The parameter data dictionary containing the DataFrame information.
+
+     Returns:
+         tuple: A tuple containing the event name and the updated telemetry data dictionary.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
+     try:
+         if _is_spark_dataframe(param_data.get(DF_PARAM_NAME)):
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(DF_PARAM_NAME)
+             )
+         return DATAFRAME_COLLECTION_DF, telemetry_data
+     except Exception:
+         return DATAFRAME_COLLECTION_ERROR, telemetry_data
+
+
+ def assert_return_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
+     """Handle telemetry event for asserting return values.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     if param_data.get(STATUS_KEY) is not None:
+         telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
+     try:
+         if _is_snowpark_dataframe(
+             param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+         ) and _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+                 param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+             )
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(SPARK_RESULTS_PARAM_NAME)
+             )
+             return DATAFRAME_VALIDATOR_MIRROR, telemetry_data
+         else:
+             return VALUE_VALIDATOR_MIRROR, telemetry_data
+     except Exception:
+         if _is_snowpark_dataframe(param_data.get(SNOWPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+                 param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
+             )
+         if _is_spark_dataframe(param_data.get(SPARK_RESULTS_PARAM_NAME)):
+             telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
+                 param_data.get(SPARK_RESULTS_PARAM_NAME)
+             )
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def dataframe_strategy_event(
+     telemetry_data: dict, param_data: dict, telemetry_m: TelemetryManager
+ ) -> tuple[Optional[str], Optional[dict]]:
+     """Handle telemetry event for dataframe strategy.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary.
+         param_data (dict): The parameter data dictionary.
+         telemetry_m (TelemetryManager): The telemetry manager.
+
+     Returns:
+         tuple: A tuple containing the event name and telemetry data.
+
+     """
+     try:
+         test_function_name = inspect.stack()[2].function
+         is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 0))
+         if not is_logged:
+             schema_param = param_data.get(DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME)
+             if isinstance(schema_param, str):
+                 json_data = get_load_json(schema_param)["custom_data"]["columns"]
+                 telemetry_data[SCHEMA_TYPES_KEY] = [
+                     column["type"] for column in json_data
+                 ]
+             else:
+                 schema_types = []
+                 for schema_type in schema_param.columns.values():
+                     if schema_type.dtype is not None:
+                         schema_types.append(str(schema_type.dtype))
+                 if schema_types:
+                     telemetry_data[SCHEMA_TYPES_KEY] = schema_types
+             telemetry_m.sc_hypothesis_input_events.append((test_function_name, 0))
+             if None in telemetry_data.get(SCHEMA_TYPES_KEY, []):
+                 telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
+             else:
+                 telemetry_m.sc_log_info(HYPOTHESIS_INPUT_SCHEMA, telemetry_data)
+             telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
+         return None, None
+     except Exception:
+         test_function_name = inspect.stack()[2].function
+         is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 1))
+         if not is_logged:
+             # Record the error marker (1) so the same error event is not
+             # logged twice for this test function.
+             telemetry_m.sc_hypothesis_input_events.append((test_function_name, 1))
+             telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
+             telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
+         return None, None
+
+
+ def compare_data_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
+     """Handle telemetry event for comparing data.
+
+     This function processes telemetry data for a data comparison event. It updates the telemetry data
+     with the mode, status, and schema types of the Snowflake DataFrame being compared.
+
+     Args:
+         telemetry_data (dict): The telemetry data dictionary to be updated.
+         param_data (dict): The parameter data dictionary containing the DataFrame and status information.
+
+     Returns:
+         tuple: A tuple containing the event name and the updated telemetry data dictionary.
+
+     """
+     telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
+     telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
+     try:
+         telemetry_data[SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
+             param_data.get("df")
+         )
+         return DATAFRAME_VALIDATOR_DF, telemetry_data
+     except Exception:
+         return DATAFRAME_VALIDATOR_ERROR, telemetry_data
+
+
+ def handle_result(
+     func_name: str,
+     result: Any,
+     param_data: dict,
+     multiple_return: bool,
+     telemetry_m: TelemetryManager,
+     return_indexes: Optional[list[tuple[str, int]]] = None,
+ ) -> tuple[Optional[str], Optional[dict]]:
+     """Handle the result of the function and collect telemetry data.
+
+     Args:
+         func_name (str): The name of the function.
+         result: The result of the function.
+         param_data (dict): The extracted parameters.
+         multiple_return (bool): Whether the function returns multiple values.
+         telemetry_m (TelemetryManager): The telemetry manager.
+         return_indexes (list[tuple[str, int]]): The list of return values to report. Defaults to None.
+
+     Returns:
+         tuple: A tuple containing the event name (str) and telemetry data (dict).
+
+     """
+     if result is not None and return_indexes is not None:
+         if multiple_return:
+             for name, index in return_indexes:
+                 param_data[name] = result[index]
+         else:
+             param_data[return_indexes[0][0]] = result[return_indexes[0][1]]
+
+     telemetry_data = {
+         FUNCTION_KEY: func_name,
+     }
+
+     telemetry_event = None
+     data = None
+     if func_name == "_check_dataframe_schema":
+         telemetry_event, data = check_dataframe_schema_event(telemetry_data, param_data)
+     elif func_name in ["check_output_schema", "check_input_schema"]:
+         telemetry_event, data = check_output_or_input_schema_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_compare_data":
+         telemetry_event, data = compare_data_event(telemetry_data, param_data)
+     elif func_name == "_collect_dataframe_checkpoint_mode_schema":
+         telemetry_event, data = collect_dataframe_checkpoint_mode_schema_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_collect_dataframe_checkpoint_mode_dataframe":
+         telemetry_event, data = collect_dataframe_checkpoint_mode_dataframe_event(
+             telemetry_data, param_data
+         )
+     elif func_name == "_assert_return":
+         telemetry_event, data = assert_return_event(telemetry_data, param_data)
+     elif func_name == "dataframe_strategy":
+         telemetry_event, data = dataframe_strategy_event(
+             telemetry_data, param_data, telemetry_m
+         )
+     return telemetry_event, data
+
+
+ fn = TypeVar("fn", bound=Callable)
+
+
+ def report_telemetry(
+     params_list: Optional[list[str]] = None,
+     return_indexes: Optional[list[tuple[str, int]]] = None,
+     multiple_return: bool = False,
+ ) -> Callable[[fn], fn]:
+     """Report telemetry events for a function.
+
+     Args:
+         params_list (list[str], optional): The list of parameters to report. Defaults to None.
+         return_indexes (list[tuple[str, int]], optional): The list of return values to report. Defaults to None.
+         multiple_return (bool, optional): Whether the function returns multiple values. Defaults to False.
+
+     Returns:
+         Callable[[fn], fn]: The decorator function.
+
+     """
+
+     def report_telemetry_decorator(func):
+         func_name = func.__name__
+
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             func_exception = None
+             result = None
+             try:
+                 result = func(*args, **kwargs)
+             except Exception as err:
+                 func_exception = err
+
+             if os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED") == "false":
+                 # Re-raise the wrapped function's exception even when
+                 # telemetry is disabled, so it is never silently swallowed.
+                 if func_exception is not None:
+                     raise func_exception
+                 return result
+             telemetry_event = None
+             data = None
+             telemetry_m = None
+             try:
+                 param_data = extract_parameters(func, args, kwargs, params_list)
+                 telemetry_m = get_telemetry_manager()
+                 telemetry_event, data = handle_result(
+                     func_name,
+                     result,
+                     param_data,
+                     multiple_return,
+                     telemetry_m,
+                     return_indexes,
+                 )
+             except Exception:
+                 pass
+             finally:
+                 if func_exception is not None:
+                     if telemetry_m is not None:
+                         telemetry_m.sc_log_error(telemetry_event, data)
+                     raise func_exception
+                 if telemetry_m is not None:
+                     telemetry_m.sc_log_info(telemetry_event, data)
+
+             return result
+
+         return wrapper
+
+     return report_telemetry_decorator
+
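+ # Decorator usage sketch (hypothetical function; the names in params_list must
+ # match parameters of the decorated signature):
+ #
+ #     @report_telemetry(params_list=["pandera_schema"])
+ #     def _check_dataframe_schema(df, pandera_schema, checkpoint_name):
+ #         ...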
+
+ # Constants for telemetry
+ DATAFRAME_COLLECTION_SCHEMA = "DataFrame_Collection_Schema"
+ DATAFRAME_COLLECTION_DF = "DataFrame_Collection_DF"
+ DATAFRAME_VALIDATOR_MIRROR = "DataFrame_Validator_Mirror"
+ VALUE_VALIDATOR_MIRROR = "Value_Validator_Mirror"
+ DATAFRAME_VALIDATOR_SCHEMA = "DataFrame_Validator_Schema"
+ DATAFRAME_VALIDATOR_DF = "DataFrame_Validator_DF"
+ HYPOTHESIS_INPUT_SCHEMA = "Hypothesis_Input_Schema"
+ DATAFRAME_COLLECTION_ERROR = "DataFrame_Collection_Error"
+ DATAFRAME_VALIDATOR_ERROR = "DataFrame_Validator_Error"
+ HYPOTHESIS_INPUT_SCHEMA_ERROR = "Hypothesis_Input_Schema_Error"
+
+ FUNCTION_KEY = "function"
+ STATUS_KEY = "status"
+ SCHEMA_TYPES_KEY = "schema_types"
+ ERROR_KEY = "error"
+ MODE_KEY = "mode"
+ SNOWFLAKE_SCHEMA_TYPES_KEY = "snowflake_schema_types"
+ SPARK_SCHEMA_TYPES_KEY = "spark_schema_types"
+
+ DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME = "schema"
+ PANDERA_SCHEMA_PARAM_NAME = "pandera_schema"
+ SNOWPARK_RESULTS_PARAM_NAME = "snowpark_results"
+ SPARK_RESULTS_PARAM_NAME = "spark_results"
+ DF_PARAM_NAME = "df"
+
+
+ class CheckpointMode(IntEnum):
+     SCHEMA = 1
+     DATAFRAME = 2
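+ # As used in the event handlers above, CheckpointMode tags each event with its
+ # validation mode: SCHEMA (1) for Pandera-schema checks, DATAFRAME (2) for
+ # full dataframe comparisons.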