snowpark-checkpoints-collectors 0.1.0rc3__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. snowflake/snowpark_checkpoints_collector/__init__.py +22 -0
  2. snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  6. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +69 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +253 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  26. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +223 -0
  27. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  28. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  29. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +172 -0
  30. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +366 -0
  31. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  32. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +112 -0
  33. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +132 -0
  34. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +889 -0
  35. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/METADATA +4 -6
  36. snowpark_checkpoints_collectors-0.1.2.dist-info/RECORD +38 -0
  37. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/licenses/LICENSE +0 -25
  38. snowpark_checkpoints_collectors-0.1.0rc3.dist-info/RECORD +0 -4
  39. {snowpark_checkpoints_collectors-0.1.0rc3.dist-info → snowpark_checkpoints_collectors-0.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,889 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import atexit
6
+ import datetime
7
+ import hashlib
8
+ import inspect
9
+ import json
10
+ import os
11
+ import re
12
+
13
+ from contextlib import suppress
14
+ from enum import IntEnum
15
+ from functools import wraps
16
+ from pathlib import Path
17
+ from platform import python_version
18
+ from sys import platform
19
+ from typing import Any, Callable, Optional, TypeVar
20
+ from uuid import getnode
21
+
22
+ from snowflake.connector import (
23
+ SNOWFLAKE_CONNECTOR_VERSION,
24
+ time_util,
25
+ )
26
+ from snowflake.connector.constants import DIRS as SNOWFLAKE_DIRS
27
+ from snowflake.connector.network import SnowflakeRestful
28
+ from snowflake.connector.telemetry import TelemetryClient
29
+ from snowflake.snowpark import VERSION as SNOWPARK_VERSION
30
+ from snowflake.snowpark import dataframe as snowpark_dataframe
31
+ from snowflake.snowpark.session import Session
32
+
33
+
34
+ try:
35
+ from pyspark.sql import dataframe as spark_dataframe
36
+
37
+ def _is_spark_dataframe(df: Any) -> bool:
38
+ return isinstance(df, spark_dataframe.DataFrame)
39
+
40
+ def _get_spark_schema_types(df: spark_dataframe.DataFrame) -> list[str]:
41
+ return [str(schema_type.dataType) for schema_type in df.schema.fields]
42
+
43
+ except Exception:
44
+
45
+ def _is_spark_dataframe(df: Any):
46
+ pass
47
+
48
+ def _get_spark_schema_types(df: Any):
49
+ pass
50
+
51
+
52
# File, located one directory above this module, that declares the version.
VERSION_FILE_NAME = "__version__.py"
# Matches a line of the form `__version__ = "<value>"` (single or double
# quotes) and captures <value> in group 1. Used with re.MULTILINE.
VERSION_VARIABLE_PATTERN = r"^__version__ = ['\"]([^'\"]*)['\"]"
54
+
55
+
56
class TelemetryManager(TelemetryClient):
    """Telemetry client that batches snowpark-checkpoints events.

    Events are accumulated in ``sc_log_batch`` and posted to the telemetry
    endpoint once ``sc_flush_size`` events are pending. When no REST client is
    available, or a send fails, the events are written as JSON files under
    ``sc_folder_path``; those files are uploaded the next time a manager with a
    REST client is created.
    """

    def __init__(
        self, rest: Optional[SnowflakeRestful] = None, is_telemetry_enabled: bool = True
    ):
        """Initialize the manager and try to upload locally stored events.

        Args:
            rest (SnowflakeRestful, optional): REST client used to post events.
                Defaults to None, in which case events are only written locally.
            is_telemetry_enabled (bool): Whether telemetry is enabled. Defaults to True.

        """
        super().__init__(rest)
        self.sc_folder_path = (
            Path(SNOWFLAKE_DIRS.user_config_path) / "snowpark-checkpoints-telemetry"
        )
        self.sc_sf_path_telemetry = "/telemetry/send"
        self.sc_flush_size = 25
        self.sc_is_enabled = is_telemetry_enabled
        # May override sc_folder_path and force sc_is_enabled to True.
        self.sc_is_testing = self._sc_is_telemetry_testing()
        self.sc_memory_limit = 5 * 1024 * 1024
        # All state is initialized before any upload I/O is attempted.
        self.sc_log_batch = []
        self.sc_hypothesis_input_events = []
        self.sc_version = _get_version()
        self._sc_upload_local_telemetry()
        if rest:
            atexit.register(self._sc_close_at_exit)

    def set_sc_output_path(self, path: Path) -> None:
        """Set the output path for testing.

        Args:
            path: path to write telemetry. It is created when missing.

        """
        os.makedirs(path, exist_ok=True)
        self.sc_folder_path = path

    def sc_log_error(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an error telemetry event. Does nothing when ``event_name`` is None.

        Args:
            event_name (str): The name of the event.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "error", parameters_info)

    def sc_log_info(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an information telemetry event. Does nothing when ``event_name`` is None.

        Args:
            event_name (str): The name of the event.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "info", parameters_info)

    def _sc_log_telemetry(
        self, event_name: str, event_type: str, parameters_info: Optional[dict] = None
    ) -> dict:
        """Log a telemetry event if telemetry is enabled.

        Args:
            event_name (str): The name of the event.
            event_type (str): The type of the event (e.g., "error", "info").
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        Returns:
            dict: The logged event, or an empty dict when telemetry is disabled.

        """
        if not self.sc_is_enabled:
            return {}
        event = _generate_event(
            event_name, event_type, parameters_info, self.sc_version
        )
        self._sc_add_log_to_batch(event)
        return event

    def _sc_add_log_to_batch(self, event: dict) -> None:
        """Add a log event to the batch. If the batch is full, send the events to the API.

        In testing mode the batch is written to the local folder right away
        instead of being sent.

        Args:
            event (dict): The event to add.

        """
        self.sc_log_batch.append(event)
        if self.sc_is_testing:
            self._sc_write_telemetry(self.sc_log_batch)
            self.sc_log_batch = []
            return

        if len(self.sc_log_batch) >= self.sc_flush_size:
            self.sc_send_batch(self.sc_log_batch)
            self.sc_log_batch = []

    def sc_send_batch(self, to_sent: list) -> bool:
        """Send a request to the API to upload the events.

        Without a REST client, or when the request fails or raises, the events
        are written to the local folder instead. Request errors never propagate
        to the caller.

        Args:
            to_sent (list): The batch of events to send.

        Returns:
            bool: True if the batch was sent successfully, False otherwise.

        """
        if not self.sc_is_enabled:
            return False
        if self._rest is None:
            self._sc_write_telemetry(to_sent)
            self.sc_log_batch = []
            return False
        if not to_sent:
            return False
        # Remember whether the caller handed us the pending batch itself, so it
        # can be cleared after a successful send and is not sent a second time.
        is_pending_batch = to_sent is self.sc_log_batch
        body = {"logs": to_sent}
        try:
            ret = self._rest.request(
                self.sc_sf_path_telemetry,
                body=body,
                method="post",
                client=None,
                timeout=5,
            )
            was_sent = bool(ret.get("success"))
        except Exception:
            # A telemetry failure must not surface in user code; the events
            # are kept on disk below.
            was_sent = False
        if not was_sent:
            self._sc_write_telemetry(to_sent)
            self.sc_log_batch = []
            return False
        if is_pending_batch:
            self.sc_log_batch = []
        return True

    def _sc_write_telemetry(self, batch: list) -> None:
        """Write telemetry events to local folder. If the folder is full, free up space for the new events.

        Each event that has a "message" entry is written to its own JSON file
        named after the current timestamp and the message type. Writing is
        best effort: any error is ignored.

        Args:
            batch (list): The batch of events to write.

        """
        try:
            os.makedirs(self.sc_folder_path, exist_ok=True)
            for event in batch:
                message = event.get("message")
                if message is not None:
                    file_path = (
                        self.sc_folder_path
                        / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")}'
                        f'_telemetry_{message.get("type")}.json'
                    )
                    json_content = self._sc_validate_folder_space(event)
                    with open(file_path, "w", encoding="utf-8") as json_file:
                        json_file.write(json_content)
        except Exception:
            # Deliberate best effort: losing telemetry is preferable to
            # failing the caller.
            pass

    def _sc_validate_folder_space(self, event: dict) -> str:
        """Validate and manage folder space for the new telemetry events.

        When the folder size plus the serialized event would exceed
        ``sc_memory_limit``, the oldest JSON files are removed first.

        Args:
            event (dict): The event to validate.

        Returns:
            str: The JSON content of the event.

        """
        json_content = json.dumps(event, indent=4, sort_keys=True)
        new_json_file_size = len(json_content.encode("utf-8"))
        telemetry_folder = self.sc_folder_path
        folder_size = _get_folder_size(telemetry_folder)
        if folder_size + new_json_file_size > self.sc_memory_limit:
            _free_up_space(telemetry_folder, self.sc_memory_limit - new_json_file_size)
        return json_content

    def _sc_upload_local_telemetry(self) -> None:
        """Send a request to the API to upload the local telemetry events.

        Nothing is done when telemetry is disabled, in testing mode, or without
        a REST client. Files that cannot be read or parsed are skipped and left
        in place. Only the files that were part of a successful upload are
        deleted afterwards.
        """
        if not self.sc_is_enabled or self.sc_is_testing or not self._rest:
            return
        batch = []
        uploaded_files = []
        for file in self.sc_folder_path.glob("*.json"):
            try:
                with open(file, encoding="utf-8") as json_file:
                    data_dict = json.load(json_file)
            except (OSError, ValueError):
                # One unreadable or corrupt file must not block the others.
                continue
            batch.append(data_dict)
            uploaded_files.append(file)
        if not batch:
            return
        body = {"logs": batch}
        try:
            ret = self._rest.request(
                self.sc_sf_path_telemetry,
                body=body,
                method="post",
                client=None,
                timeout=5,
            )
            was_sent = bool(ret.get("success"))
        except Exception:
            # Keep the files for a later attempt.
            was_sent = False
        if was_sent:
            for file in uploaded_files:
                with suppress(OSError):
                    file.unlink()

    def _sc_is_telemetry_testing(self) -> bool:
        """Check whether telemetry testing mode is requested through the environment.

        When SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING is "true", the output path
        is redirected to ``snowpark-checkpoints-output/telemetry`` under the
        current working directory and telemetry is force-enabled.

        Returns:
            bool: True if testing mode is active, False otherwise.

        """
        is_testing = os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING") == "true"
        if is_testing:
            local_telemetry_path = (
                Path(os.getcwd()) / "snowpark-checkpoints-output" / "telemetry"
            )
            self.set_sc_output_path(local_telemetry_path)
            self.sc_is_enabled = True
        return is_testing

    def _sc_is_telemetry_manager(self) -> bool:
        """Check if the class is telemetry manager.

        The presence of this attribute is what ``get_telemetry_manager`` tests
        with ``hasattr`` to recognize an already installed manager.

        Returns:
            bool: True if the class is telemetry manager, False otherwise.

        """
        return True

    def sc_is_hypothesis_event_logged(self, event_name: tuple[str, int]) -> bool:
        """Check if a Hypothesis event is logged.

        Args:
            event_name (tuple[str, int]): A tuple containing the name of the event and an integer identifier
                (0 for info, 1 for error).

        Returns:
            bool: True if the event is logged, False otherwise.

        """
        return event_name in self.sc_hypothesis_input_events

    def _sc_close(self) -> None:
        """Close the telemetry manager and upload collected events.

        Unregisters the at-exit hook and, outside testing mode, sends any
        events still pending in the batch.
        """
        atexit.unregister(self._sc_close_at_exit)
        if self.sc_log_batch and self.sc_is_enabled and not self.sc_is_testing:
            self.sc_send_batch(self.sc_log_batch)

    def _sc_close_at_exit(self) -> None:
        """Close the telemetry manager at exit and upload collected events.

        Any exception raised while closing is suppressed so that interpreter
        shutdown is not disturbed.
        """
        with suppress(Exception):
            self._sc_close()
298
+
299
+
300
def _generate_event(
    event_name: str,
    event_type: str,
    parameters_info: Optional[dict] = None,
    sc_version: Optional[str] = None,
) -> dict:
    """Build the dictionary that represents one telemetry event.

    Args:
        event_name (str): The name of the event.
        event_type (str): The type of the event (e.g., "error", "info").
        parameters_info (dict, optional): Extra parameters; serialized to a JSON
            string under "data". Defaults to None, which serializes as "{}".
        sc_version (str, optional): Package version, added to the metadata
            when it is not None. Defaults to None.

    Returns:
        dict: A dict with the "message" payload and a "timestamp" string in
        milliseconds.

    """
    event_metadata = _get_metadata()
    if sc_version is not None:
        event_metadata["snowpark_checkpoints_version"] = sc_version

    serialized_parameters = json.dumps(parameters_info if parameters_info else {})
    event_message = {
        "type": event_type,
        "event_name": event_name,
        "driver_type": "PythonConnector",
        "driver_version": SNOWFLAKE_CONNECTOR_VERSION,
        "source": "snowpark-checkpoints",
        "metadata": event_metadata,
        "data": serialized_parameters,
    }
    return {
        "message": event_message,
        "timestamp": str(time_util.get_time_millis()),
    }
334
+
335
+
336
def _get_metadata() -> dict:
    """Collect the environment metadata attached to every telemetry event.

    Returns:
        dict: The platform identifier, the Python version, the Snowpark version
        (its non-None components joined with dots) and the hashed device ID.

    """
    snowpark_version_parts = [
        str(component) for component in SNOWPARK_VERSION if component is not None
    ]
    metadata = {}
    metadata["os_version"] = platform
    metadata["python_version"] = python_version()
    metadata["snowpark_version"] = ".".join(snowpark_version_parts)
    metadata["device_id"] = _get_unique_id()
    return metadata
349
+
350
+
351
+ def _get_version() -> str:
352
+ """Get the version of the package.
353
+
354
+ Returns:
355
+ str: The version of the package.
356
+
357
+ """
358
+ try:
359
+ directory_levels_up = 1
360
+ project_root = Path(__file__).resolve().parents[directory_levels_up]
361
+ version_file_path = project_root / VERSION_FILE_NAME
362
+ with open(version_file_path) as file:
363
+ content = file.read()
364
+ version_match = re.search(VERSION_VARIABLE_PATTERN, content, re.MULTILINE)
365
+ if version_match:
366
+ return version_match.group(1)
367
+ return None
368
+ except Exception:
369
+ return None
370
+
371
+
372
+ def _get_folder_size(folder_path: Path) -> int:
373
+ """Get the size of a folder. Only considers JSON files.
374
+
375
+ Args:
376
+ folder_path (Path): The path to the folder.
377
+
378
+ Returns:
379
+ int: The size of the folder in bytes.
380
+
381
+ """
382
+ return sum(f.stat().st_size for f in folder_path.glob("*.json") if f.is_file())
383
+
384
+
385
def _free_up_space(folder_path: Path, max_size: int) -> None:
    """Delete the oldest JSON files of a folder until it fits within a size limit.

    Files are visited from oldest to newest modification time and removed one
    by one while the remaining total is still above ``max_size``.

    Args:
        folder_path (Path): The folder to clean up.
        max_size (int): The maximum allowed size of the folder in bytes.

    """
    json_files = list(folder_path.glob("*.json"))
    json_files.sort(key=lambda entry: entry.stat().st_mtime)
    remaining_bytes = _get_folder_size(folder_path)
    for oldest in json_files:
        if remaining_bytes <= max_size:
            return
        remaining_bytes -= oldest.stat().st_size
        oldest.unlink()
400
+
401
+
402
+ def _get_unique_id() -> str:
403
+ """Get a unique device ID. The ID is generated based on the hashed MAC address.
404
+
405
+ Returns:
406
+ str: The hashed device ID.
407
+
408
+ """
409
+ node_id_str = str(getnode())
410
+ hashed_id = hashlib.sha256(node_id_str.encode()).hexdigest()
411
+ return hashed_id
412
+
413
+
414
def get_telemetry_manager() -> TelemetryManager:
    """Return the telemetry manager of the current Snowpark session.

    The telemetry client of the session's connection is replaced by a
    ``TelemetryManager`` unless it already exposes the
    ``_sc_is_telemetry_manager`` attribute. If anything fails, a manager
    without a REST client and with a flush size of 1 is returned instead.

    Returns:
        TelemetryManager: The telemetry manager.

    """
    try:
        active_connection = Session.builder.getOrCreate().connection
        already_installed = hasattr(
            active_connection._telemetry, "_sc_is_telemetry_manager"
        )
        if not already_installed:
            active_connection._telemetry = TelemetryManager(
                active_connection._rest, active_connection.telemetry_enabled
            )
        return active_connection._telemetry
    except Exception:
        fallback_manager = TelemetryManager(None, is_telemetry_enabled=True)
        fallback_manager.sc_flush_size = 1
        return fallback_manager
432
+
433
+
434
def get_snowflake_schema_types(df: snowpark_dataframe.DataFrame) -> list[str]:
    """List the data types of a Snowpark DataFrame's schema fields as strings.

    Args:
        df (snowpark_dataframe.DataFrame): The Snowpark DataFrame to inspect.

    Returns:
        list[str]: ``str(field.datatype)`` for each field of ``df.schema``, in order.

    """
    type_names = []
    for field in df.schema.fields:
        type_names.append(str(field.datatype))
    return type_names
445
+
446
+
447
def _is_snowpark_dataframe(df: Any) -> bool:
    """Tell whether an object is a Snowpark DataFrame.

    Args:
        df: The object to check.

    Returns:
        bool: True when ``df`` is an instance of the Snowpark DataFrame class.

    """
    snowpark_dataframe_type = snowpark_dataframe.DataFrame
    return isinstance(df, snowpark_dataframe_type)
458
+
459
+
460
def get_load_json(json_schema: str) -> dict:
    """Read a JSON schema file and return its parsed content.

    Args:
        json_schema (str): The path to the JSON schema file (read as UTF-8).

    Returns:
        dict: The parsed JSON content.

    Raises:
        ValueError: If the file cannot be opened or read, or is not valid JSON.

    """
    try:
        with open(json_schema, encoding="utf-8") as schema_file:
            raw_text = schema_file.read()
        return json.loads(raw_text)
    except (OSError, json.JSONDecodeError) as error:
        # The original cause is intentionally hidden from the traceback.
        message = f"Error reading JSON schema file: {error}"
        raise ValueError(message) from None
478
+
479
+
480
def extract_parameters(
    func: Callable, args: tuple, kwargs: dict, params_list: Optional[list[str]]
) -> dict:
    """Extract parameters from the function arguments.

    The call arguments are bound to the signature of ``func``, so a requested
    parameter is found whether it was passed positionally or by keyword, also
    when both styles are mixed in one call. Requested parameters that the
    caller did not supply are left out of the result; defaults are not applied.

    Args:
        func (Callable): The function being decorated.
        args (tuple): The positional arguments passed to the function.
        kwargs (dict): The keyword arguments passed to the function.
        params_list (list[str]): The list of parameters to extract.

    Returns:
        dict: A dictionary of extracted parameters, keyed by parameter name.

    Raises:
        TypeError: If ``args`` and ``kwargs`` do not fit the signature of ``func``.

    """
    if not params_list:
        return {}
    supplied = inspect.signature(func).bind_partial(*args, **kwargs).arguments
    return {name: supplied[name] for name in params_list if name in supplied}
506
+
507
+
508
def check_dataframe_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for checking dataframe schema.

    Records the schema mode, the status found in ``param_data`` and the string
    form of every non-None column dtype of the Pandera schema. If the schema is
    missing or its columns cannot be read, the error event is returned with
    whatever was recorded up to that point.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
        pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        schema_types = [
            str(column.dtype)
            for column in pandera_schema.columns.values()
            if column.dtype is not None
        ]
    except Exception:
        # Do not retry the failing extraction inside the handler: it would
        # fail the same way and raise out of this function.
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
    if schema_types:
        telemetry_data[SCHEMA_TYPES_KEY] = schema_types
    return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
544
+
545
+
546
def check_output_or_input_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Build the telemetry event for an input or output schema check.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary holding the Pandera schema.

    Returns:
        tuple: The schema event name, or the error event name when the schema
        types cannot be read, together with the telemetry data.

    """
    try:
        schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        dtype_names = [
            str(column.dtype)
            for column in schema.columns.values()
            if column.dtype is not None
        ]
        # The key is only written when at least one dtype was found.
        if dtype_names:
            telemetry_data[SCHEMA_TYPES_KEY] = dtype_names
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
    return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
570
+
571
+
572
def collect_dataframe_checkpoint_mode_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Build the telemetry event for a schema-mode checkpoint collection.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary; its "column_type_dict"
            entry is iterated and each value's ``dataType.typeName()`` is recorded.

    Returns:
        tuple: The collection event name, or the error event name when the
        type names cannot be read, together with the telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        column_types = param_data.get("column_type_dict")
        type_names = []
        for column_name in column_types:
            type_names.append(column_types[column_name].dataType.typeName())
        # Stored only after every column was processed successfully.
        telemetry_data[SCHEMA_TYPES_KEY] = type_names
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data
    return DATAFRAME_COLLECTION_SCHEMA, telemetry_data
595
+
596
+
597
def collect_dataframe_checkpoint_mode_dataframe_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Build the telemetry event for a dataframe-mode checkpoint collection.

    The mode is recorded first. When the collected object is a Spark DataFrame,
    the types of its schema fields are recorded as well.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary containing the DataFrame.

    Returns:
        tuple: The collection event name, or the error event name on failure,
        together with the telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    try:
        collected_df = param_data.get(DF_PARAM_NAME)
        if _is_spark_dataframe(collected_df):
            spark_types = _get_spark_schema_types(collected_df)
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = spark_types
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data
    return DATAFRAME_COLLECTION_DF, telemetry_data
622
+
623
+
624
def assert_return_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Build the telemetry event for a return-value assertion.

    When both results are DataFrames (Snowpark and Spark), their schema types
    are recorded and the DataFrame mirror event is returned; otherwise the
    value mirror event is returned. If that fails, the handler records the
    schema types of whichever result is a DataFrame and returns the error event.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.

    """
    status = param_data.get(STATUS_KEY)
    if status is not None:
        telemetry_data[STATUS_KEY] = status

    snowpark_results = param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
    spark_results = param_data.get(SPARK_RESULTS_PARAM_NAME)
    try:
        both_are_dataframes = _is_snowpark_dataframe(
            snowpark_results
        ) and _is_spark_dataframe(spark_results)
        if not both_are_dataframes:
            return VALUE_VALIDATOR_MIRROR, telemetry_data
        telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
            snowpark_results
        )
        telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
            spark_results
        )
        return DATAFRAME_VALIDATOR_MIRROR, telemetry_data
    except Exception:
        # Record each side independently so one failing side does not hide the other.
        if _is_snowpark_dataframe(snowpark_results):
            telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
                snowpark_results
            )
        if _is_spark_dataframe(spark_results):
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                spark_results
            )
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
660
+
661
+
662
def dataframe_strategy_event(
    telemetry_data: dict, param_data: dict, telemetry_m: TelemetryManager
) -> tuple[Optional[str], Optional[dict]]:
    """Handle telemetry event for dataframe strategy.

    The event is logged and sent directly through ``telemetry_m`` and at most
    once per calling function: ``(function_name, 0)`` marks a logged input
    schema event and ``(function_name, 1)`` marks a logged error event.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.
        telemetry_m (TelemetryManager): The telemetry manager.

    Returns:
        tuple: Always ``(None, None)``, because the event has already been
        logged here and the caller has nothing left to log.

    """
    # Name of the function two frames above this one on the call stack.
    test_function_name = inspect.stack()[2].function
    try:
        is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 0))
        if not is_logged:
            schema_param = param_data.get(DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME)
            if isinstance(schema_param, str):
                # A string is treated as the path of a JSON schema file.
                json_data = get_load_json(schema_param)["custom_data"]["columns"]
                telemetry_data[SCHEMA_TYPES_KEY] = [
                    column["type"] for column in json_data
                ]
            else:
                schema_types = []
                for schema_type in schema_param.columns.values():
                    if schema_type.dtype is not None:
                        schema_types.append(str(schema_type.dtype))
                if schema_types:
                    telemetry_data[SCHEMA_TYPES_KEY] = schema_types
            telemetry_m.sc_hypothesis_input_events.append((test_function_name, 0))
            # When no schema type was recorded this lookup raises KeyError and
            # the error path below takes over.
            if None in telemetry_data[SCHEMA_TYPES_KEY]:
                telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
            else:
                telemetry_m.sc_log_info(HYPOTHESIS_INPUT_SCHEMA, telemetry_data)
            telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None
    except Exception:
        is_logged = telemetry_m.sc_is_hypothesis_event_logged((test_function_name, 1))
        if not is_logged:
            # Record the error marker (1), the same one checked just above, so
            # the error event is reported only once per function.
            telemetry_m.sc_hypothesis_input_events.append((test_function_name, 1))
            telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
            telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None
708
+
709
+
710
def compare_data_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Build the telemetry event for a data comparison.

    The dataframe mode and the status are recorded first; then the schema
    types of the DataFrame stored under "df" are added.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary with the DataFrame and status.

    Returns:
        tuple: The validator event name, or the error event name when the
        schema types cannot be read, together with the telemetry data.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
    try:
        compared_df = param_data.get("df")
        telemetry_data[SCHEMA_TYPES_KEY] = get_snowflake_schema_types(compared_df)
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
    return DATAFRAME_VALIDATOR_DF, telemetry_data
733
+
734
+
735
def handle_result(
    func_name: str,
    result: Any,
    param_data: dict,
    multiple_return: bool,
    telemetry_m: TelemetryManager,
    return_indexes: Optional[list[tuple[str, int]]] = None,
) -> tuple[Optional[str], Optional[dict]]:
    """Turn a function call outcome into a telemetry event name and payload.

    Selected items of ``result`` are first copied into ``param_data`` under the
    names given by ``return_indexes``; then the handler matching ``func_name``
    builds the event.

    Args:
        func_name (str): The name of the function.
        result: The result of the function.
        param_data (dict): The extracted parameters (updated in place).
        multiple_return (bool): Whether every entry of ``return_indexes`` is
            used, or only the first one.
        telemetry_m (TelemetryManager): The telemetry manager.
        return_indexes (list[tuple[str, int]]): Pairs of parameter name and
            index into ``result``. Defaults to None.

    Returns:
        tuple: The event name and telemetry data, or ``(None, None)`` when no
        handler exists for ``func_name``.

    """
    if not (result is None or return_indexes is None):
        if multiple_return:
            for name, index in return_indexes:
                param_data[name] = result[index]
        else:
            first_entry = return_indexes[0]
            param_data[first_entry[0]] = result[first_entry[1]]

    telemetry_data = {FUNCTION_KEY: func_name}

    # Handlers that take (telemetry_data, param_data).
    two_argument_handlers = {
        "_check_dataframe_schema": check_dataframe_schema_event,
        "check_output_schema": check_output_or_input_schema_event,
        "check_input_schema": check_output_or_input_schema_event,
        "_compare_data": compare_data_event,
        "_collect_dataframe_checkpoint_mode_schema": collect_dataframe_checkpoint_mode_schema_event,
        "_collect_dataframe_checkpoint_mode_dataframe": collect_dataframe_checkpoint_mode_dataframe_event,
        "_assert_return": assert_return_event,
    }

    telemetry_event, data = None, None
    if func_name in two_argument_handlers:
        telemetry_event, data = two_argument_handlers[func_name](
            telemetry_data, param_data
        )
    elif func_name == "dataframe_strategy":
        # This handler additionally needs the manager to log the event itself.
        telemetry_event, data = dataframe_strategy_event(
            telemetry_data, param_data, telemetry_m
        )
    return telemetry_event, data
793
+
794
+
795
# Type variable for the decorated callable.
fn = TypeVar("fn", bound=Callable)


def report_telemetry(
    params_list: Optional[list[str]] = None,
    return_indexes: Optional[list[tuple[str, int]]] = None,
    multiple_return: bool = False,
) -> Callable[[fn], fn]:
    """Report telemetry events for a function.

    The decorated function always keeps its own behavior: its return value is
    returned and an exception it raises is re-raised, whether or not telemetry
    is enabled and whether or not telemetry reporting itself fails.

    Telemetry is skipped when the environment variable
    SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED is set to "false".

    Args:
        params_list (list[str], optional): The list of parameters to report. Defaults to None.
        return_indexes (list[tuple[str, int]], optional): The list of return values to report. Defaults to None.
        multiple_return (bool, optional): Whether the function returns multiple values. Defaults to False.

    Returns:
        Callable[[fn], fn]: The decorator function.

    """

    def report_telemetry_decorator(func):
        func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            func_exception = None
            result = None
            try:
                result = func(*args, **kwargs)
            except Exception as err:
                # Held back so the failure can be reported before re-raising.
                func_exception = err

            if os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED") != "false":
                # Telemetry is best effort: nothing raised while collecting or
                # logging it may reach the caller.
                with suppress(Exception):
                    param_data = extract_parameters(func, args, kwargs, params_list)
                    telemetry_m = get_telemetry_manager()
                    telemetry_event, data = handle_result(
                        func_name,
                        result,
                        param_data,
                        multiple_return,
                        telemetry_m,
                        return_indexes,
                    )
                    if func_exception is not None:
                        telemetry_m.sc_log_error(telemetry_event, data)
                    else:
                        telemetry_m.sc_log_info(telemetry_event, data)

            # Re-raise even when telemetry is disabled, so errors of the
            # decorated function are never swallowed.
            if func_exception is not None:
                raise func_exception
            return result

        return wrapper

    return report_telemetry_decorator
858
+
859
+
860
# Constants for telemetry

# Event names reported on success.
DATAFRAME_COLLECTION_SCHEMA = "DataFrame_Collection_Schema"
DATAFRAME_COLLECTION_DF = "DataFrame_Collection_DF"
DATAFRAME_VALIDATOR_MIRROR = "DataFrame_Validator_Mirror"
VALUE_VALIDATOR_MIRROR = "Value_Validator_Mirror"
DATAFRAME_VALIDATOR_SCHEMA = "DataFrame_Validator_Schema"
DATAFRAME_VALIDATOR_DF = "DataFrame_Validator_DF"
HYPOTHESIS_INPUT_SCHEMA = "Hypothesis_Input_Schema"

# Event names reported when building or collecting the event fails.
DATAFRAME_COLLECTION_ERROR = "DataFrame_Collection_Error"
DATAFRAME_VALIDATOR_ERROR = "DataFrame_Validator_Error"
HYPOTHESIS_INPUT_SCHEMA_ERROR = "Hypothesis_Input_Schema_Error"

# Keys of the telemetry data dictionary.
FUNCTION_KEY = "function"
STATUS_KEY = "status"
SCHEMA_TYPES_KEY = "schema_types"
ERROR_KEY = "error"
MODE_KEY = "mode"
SNOWFLAKE_SCHEMA_TYPES_KEY = "snowflake_schema_types"
SPARK_SCHEMA_TYPES_KEY = "spark_schema_types"

# Names under which values are looked up in the extracted parameter data.
DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME = "schema"
PANDERA_SCHEMA_PARAM_NAME = "pandera_schema"
SNOWPARK_RESULTS_PARAM_NAME = "snowpark_results"
SPARK_RESULTS_PARAM_NAME = "spark_results"
DF_PARAM_NAME = "df"
885
+
886
+
887
class CheckpointMode(IntEnum):
    """Checkpoint modes; the integer value is what telemetry stores under the mode key."""

    # Schema-based checkpoint.
    SCHEMA = 1
    # DataFrame-based checkpoint.
    DATAFRAME = 2