snowpark-checkpoints-collectors 0.2.0rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. snowflake/snowpark_checkpoints_collector/__init__.py +30 -0
  2. snowflake/snowpark_checkpoints_collector/__version__.py +16 -0
  3. snowflake/snowpark_checkpoints_collector/collection_common.py +160 -0
  4. snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py +24 -0
  5. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py +91 -0
  6. snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py +76 -0
  7. snowflake/snowpark_checkpoints_collector/column_collection/__init__.py +22 -0
  8. snowflake/snowpark_checkpoints_collector/column_collection/column_collector_manager.py +276 -0
  9. snowflake/snowpark_checkpoints_collector/column_collection/model/__init__.py +75 -0
  10. snowflake/snowpark_checkpoints_collector/column_collection/model/array_column_collector.py +113 -0
  11. snowflake/snowpark_checkpoints_collector/column_collection/model/binary_column_collector.py +87 -0
  12. snowflake/snowpark_checkpoints_collector/column_collection/model/boolean_column_collector.py +71 -0
  13. snowflake/snowpark_checkpoints_collector/column_collection/model/column_collector_base.py +95 -0
  14. snowflake/snowpark_checkpoints_collector/column_collection/model/date_column_collector.py +74 -0
  15. snowflake/snowpark_checkpoints_collector/column_collection/model/day_time_interval_column_collector.py +67 -0
  16. snowflake/snowpark_checkpoints_collector/column_collection/model/decimal_column_collector.py +92 -0
  17. snowflake/snowpark_checkpoints_collector/column_collection/model/empty_column_collector.py +88 -0
  18. snowflake/snowpark_checkpoints_collector/column_collection/model/map_column_collector.py +120 -0
  19. snowflake/snowpark_checkpoints_collector/column_collection/model/null_column_collector.py +49 -0
  20. snowflake/snowpark_checkpoints_collector/column_collection/model/numeric_column_collector.py +108 -0
  21. snowflake/snowpark_checkpoints_collector/column_collection/model/string_column_collector.py +70 -0
  22. snowflake/snowpark_checkpoints_collector/column_collection/model/struct_column_collector.py +102 -0
  23. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_column_collector.py +75 -0
  24. snowflake/snowpark_checkpoints_collector/column_collection/model/timestamp_ntz_column_collector.py +75 -0
  25. snowflake/snowpark_checkpoints_collector/column_pandera_checks/__init__.py +20 -0
  26. snowflake/snowpark_checkpoints_collector/column_pandera_checks/pandera_column_checks_manager.py +241 -0
  27. snowflake/snowpark_checkpoints_collector/io_utils/__init__.py +26 -0
  28. snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py +61 -0
  29. snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py +142 -0
  30. snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py +79 -0
  31. snowflake/snowpark_checkpoints_collector/singleton.py +23 -0
  32. snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py +20 -0
  33. snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +203 -0
  34. snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +409 -0
  35. snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py +53 -0
  36. snowflake/snowpark_checkpoints_collector/utils/extra_config.py +164 -0
  37. snowflake/snowpark_checkpoints_collector/utils/file_utils.py +137 -0
  38. snowflake/snowpark_checkpoints_collector/utils/logging_utils.py +67 -0
  39. snowflake/snowpark_checkpoints_collector/utils/telemetry.py +928 -0
  40. snowpark_checkpoints_collectors-0.3.0.dist-info/METADATA +159 -0
  41. snowpark_checkpoints_collectors-0.3.0.dist-info/RECORD +43 -0
  42. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/licenses/LICENSE +0 -25
  43. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/METADATA +0 -347
  44. snowpark_checkpoints_collectors-0.2.0rc1.dist-info/RECORD +0 -4
  45. {snowpark_checkpoints_collectors-0.2.0rc1.dist-info → snowpark_checkpoints_collectors-0.3.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,928 @@
1
+ #
2
+ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3
+ #
4
+
5
+ import atexit
6
+ import datetime
7
+ import hashlib
8
+ import inspect
9
+ import json
10
+ import os
11
+ import re
12
+
13
+ from contextlib import suppress
14
+ from enum import IntEnum
15
+ from functools import wraps
16
+ from pathlib import Path
17
+ from platform import python_version
18
+ from sys import platform
19
+ from typing import Any, Callable, Optional, TypeVar
20
+ from uuid import getnode
21
+
22
+ from snowflake.connector.description import PLATFORM as CONNECTOR_PLATFORM
23
+ from snowflake.connector.telemetry import TelemetryClient
24
+ from snowflake.snowpark import VERSION as SNOWPARK_VERSION
25
+ from snowflake.snowpark import dataframe as snowpark_dataframe
26
+ from snowflake.snowpark.session import Session
27
+ from snowflake.snowpark_checkpoints_collector.io_utils.io_file_manager import (
28
+ get_io_file_manager,
29
+ )
30
+
31
+
32
+ try:
33
+ """
34
+ The following imports are used to log telemetry events in the Snowflake Connector.
35
+ """
36
+ from snowflake.connector import (
37
+ SNOWFLAKE_CONNECTOR_VERSION,
38
+ time_util,
39
+ )
40
+ from snowflake.connector.constants import DIRS as SNOWFLAKE_DIRS
41
+ from snowflake.connector.network import SnowflakeRestful
42
+ except Exception:
43
+ """
44
+ Set default import values for the Snowflake Connector when using snowpark-checkpoints in stored procedures.
45
+ """
46
+ SNOWFLAKE_CONNECTOR_VERSION = ""
47
+ time_util = None
48
+ SNOWFLAKE_DIRS = ""
49
+ SnowflakeRestful = None
50
+
51
+
52
+ try:
53
+ from pyspark.sql import dataframe as spark_dataframe
54
+
55
+ def _is_spark_dataframe(df: Any) -> bool:
56
+ return isinstance(df, spark_dataframe.DataFrame)
57
+
58
+ def _get_spark_schema_types(df: spark_dataframe.DataFrame) -> list[str]:
59
+ return [str(schema_type.dataType) for schema_type in df.schema.fields]
60
+
61
+ except Exception:
62
+
63
+ def _is_spark_dataframe(df: Any):
64
+ pass
65
+
66
+ def _get_spark_schema_types(df: Any):
67
+ pass
68
+
69
+
70
# Regex (used with re.MULTILINE) matching a ``__version__ = '...'`` / ``"..."``
# line; group 1 captures the version string.
VERSION_VARIABLE_PATTERN = r"^__version__ = ['\"]([^'\"]*)['\"]"
# File name of the version module that ``_get_version`` reads.
VERSION_FILE_NAME = "__version__.py"
72
+
73
+
74
class TelemetryManager(TelemetryClient):
    """Connector telemetry client that batches snowpark-checkpoints events.

    Events are posted to Snowflake in batches through the connector REST client
    when one is available. Otherwise, or when posting fails, they are written as
    JSON files to a local folder. Those files are uploaded the next time a manager
    with a REST client is constructed.
    """

    def __init__(
        self, rest: Optional[SnowflakeRestful] = None, is_telemetry_enabled: bool = True
    ):
        """Initialize the manager and upload any telemetry left on disk.

        Args:
            rest: Connector REST client used to post events. ``None`` means events
                are only written to the local telemetry folder.
            is_telemetry_enabled: When False, logging and sending are no-ops,
                unless the testing environment variable forces telemetry on.

        """
        super().__init__(rest)
        # When the connector imports fail (e.g. stored procedures), SNOWFLAKE_DIRS
        # is the placeholder "" and has no ``user_config_path``. Fall back to a
        # folder relative to the working directory instead of raising.
        user_config_path = getattr(SNOWFLAKE_DIRS, "user_config_path", "")
        self.sc_folder_path = Path(user_config_path) / "snowpark-checkpoints-telemetry"
        self.sc_sf_path_telemetry = "/telemetry/send"
        self.sc_flush_size = 25
        self.sc_is_enabled = is_telemetry_enabled
        self.sc_is_testing = self._sc_is_telemetry_testing()
        self.sc_memory_limit = 5 * 1024 * 1024  # 5 MiB cap for the local folder
        self._sc_upload_local_telemetry()
        self.sc_log_batch = []
        self.sc_hypothesis_input_events = []
        self.sc_version = _get_version()
        if rest:
            atexit.register(self._sc_close_at_exit)

    def set_sc_output_path(self, path: Path) -> None:
        """Set the output path for testing.

        Args:
            path: path to write telemetry. It is created if missing.

        """
        get_io_file_manager().mkdir(str(path), exist_ok=True)
        self.sc_folder_path = path

    def sc_log_error(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an error telemetry event.

        Args:
            event_name (str): The name of the event. ``None`` is ignored.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "error", parameters_info)

    def sc_log_info(
        self, event_name: str, parameters_info: Optional[dict] = None
    ) -> None:
        """Log an information telemetry event.

        Args:
            event_name (str): The name of the event. ``None`` is ignored.
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        """
        if event_name is not None:
            self._sc_log_telemetry(event_name, "info", parameters_info)

    def _sc_log_telemetry(
        self, event_name: str, event_type: str, parameters_info: Optional[dict] = None
    ) -> dict:
        """Log a telemetry event if telemetry is enabled.

        Args:
            event_name (str): The name of the event.
            event_type (str): The type of the event (e.g., "error", "info").
            parameters_info (dict, optional): Additional parameters for the event. Defaults to None.

        Returns:
            dict: The logged event, or ``{}`` when telemetry is disabled.

        """
        if not self.sc_is_enabled:
            return {}
        event = _generate_event(
            event_name, event_type, parameters_info, self.sc_version
        )
        self._sc_add_log_to_batch(event)
        return event

    def _sc_add_log_to_batch(self, event: dict) -> None:
        """Add an event to the batch; flush once the batch reaches ``sc_flush_size``.

        In testing mode every event is written locally immediately.

        Args:
            event (dict): The event to add.

        """
        self.sc_log_batch.append(event)
        if self.sc_is_testing:
            self._sc_write_telemetry(self.sc_log_batch)
            self.sc_log_batch = []
            return

        if len(self.sc_log_batch) >= self.sc_flush_size:
            self.sc_send_batch(self.sc_log_batch)
            self.sc_log_batch = []

    def sc_send_batch(self, to_sent: list) -> bool:
        """Upload a batch of events; on failure, persist them to the local folder.

        Args:
            to_sent (list): The batch of events to send.

        Returns:
            bool: True if the batch was sent successfully, False otherwise.

        """
        if not self.sc_is_enabled:
            return False
        if self._rest is None:
            self._sc_write_telemetry(to_sent)
            self.sc_log_batch = []
            return False
        if not to_sent:
            return False
        if self._sc_post_logs(to_sent):
            return True
        # Sending failed (error response or network error): keep the events on
        # disk so they can be uploaded later.
        self._sc_write_telemetry(to_sent)
        self.sc_log_batch = []
        return False

    def _sc_post_logs(self, logs: list) -> bool:
        """POST ``logs`` to the telemetry endpoint.

        Returns:
            bool: True only for a response whose ``success`` field is truthy. Any
            exception raised by the request, or a missing response, counts as a
            failure, so telemetry never propagates errors to the caller.

        """
        try:
            ret = self._rest.request(
                self.sc_sf_path_telemetry,
                body={"logs": logs},
                method="post",
                client=None,
                timeout=5,
            )
            return bool(ret and ret.get("success"))
        except Exception:
            return False

    def _sc_write_telemetry(self, batch: list) -> None:
        """Write telemetry events to the local folder, freeing space if needed.

        This is best-effort: any failure is deliberately ignored so that local
        persistence can never break the caller.

        Args:
            batch (list): The batch of events to write.

        """
        with suppress(Exception):
            io_manager = get_io_file_manager()
            io_manager.mkdir(str(self.sc_folder_path), exist_ok=True)
            for event in batch:
                message = event.get("message")
                if message is None:
                    continue
                file_path = (
                    self.sc_folder_path
                    / f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")}'
                    f'_telemetry_{message.get("type")}.json'
                )
                json_content = self._sc_validate_folder_space(event)
                io_manager.write(str(file_path), json_content)

    def _sc_validate_folder_space(self, event: dict) -> str:
        """Serialize ``event`` and evict old files if it would exceed the folder cap.

        Args:
            event (dict): The event to validate.

        Returns:
            str: The JSON content of the event.

        """
        json_content = json.dumps(event, indent=4, sort_keys=True)
        new_json_file_size = len(json_content.encode("utf-8"))
        telemetry_folder = self.sc_folder_path
        folder_size = _get_folder_size(telemetry_folder)
        if folder_size + new_json_file_size > self.sc_memory_limit:
            _free_up_space(telemetry_folder, self.sc_memory_limit - new_json_file_size)
        return json_content

    def _sc_upload_local_telemetry(self) -> None:
        """Upload locally persisted events and delete the files that were uploaded.

        Unreadable or corrupted files are skipped instead of aborting construction.
        Only files that were part of the uploaded batch are removed.
        """
        if not self.sc_is_enabled or self.sc_is_testing or not self._rest:
            return
        io_manager = get_io_file_manager()
        batch = []
        uploaded_files = []
        for file in io_manager.ls(f"{self.sc_folder_path}/*.json"):
            try:
                batch.append(json.loads(io_manager.read(file)))
            except (OSError, ValueError):  # JSONDecodeError is a ValueError
                continue
            uploaded_files.append(file)
        if not batch:
            return
        if self._sc_post_logs(batch):
            for file_path in uploaded_files:
                with suppress(OSError):
                    io_manager.telemetry_path_files(file_path).unlink()

    def _sc_is_telemetry_testing(self) -> bool:
        """Return True when SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING is "true".

        In that case, also redirect output to
        ``<cwd>/snowpark-checkpoints-output/telemetry`` and force telemetry on.
        """
        is_testing = os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_TESTING") == "true"
        if is_testing:
            local_telemetry_path = (
                Path(get_io_file_manager().getcwd())
                / "snowpark-checkpoints-output"
                / "telemetry"
            )
            self.set_sc_output_path(local_telemetry_path)
            self.sc_is_enabled = True
        return is_testing

    def _sc_is_telemetry_manager(self) -> bool:
        """Check if the class is telemetry manager.

        Returns:
            bool: True if the class is telemetry manager, False otherwise.

        """
        return True

    def sc_is_hypothesis_event_logged(self, event_name: tuple[str, int]) -> bool:
        """Check if a Hypothesis event is logged.

        Args:
            event_name (tuple[str, int]): A tuple containing the name of the event and an integer identifier
                (0 for info, 1 for error).

        Returns:
            bool: True if the event is logged, False otherwise.

        """
        return event_name in self.sc_hypothesis_input_events

    def _sc_close(self) -> None:
        """Close the telemetry manager and upload any pending events.

        The pending batch is cleared afterwards so that a second close does not
        resend it.
        """
        atexit.unregister(self._sc_close_at_exit)
        if self.sc_log_batch and self.sc_is_enabled and not self.sc_is_testing:
            self.sc_send_batch(self.sc_log_batch)
            self.sc_log_batch = []

    def _sc_close_at_exit(self) -> None:
        """Close the manager at interpreter exit, suppressing any error."""
        with suppress(Exception):
            self._sc_close()
318
+
319
+
320
def _generate_event(
    event_name: str,
    event_type: str,
    parameters_info: Optional[dict] = None,
    sc_version: Optional[str] = None,
) -> dict:
    """Generate a telemetry event.

    Args:
        event_name (str): The name of the event.
        event_type (str): The type of the event (e.g., "error", "info").
        parameters_info (dict, optional): Additional parameters for the event,
            JSON-encoded into ``message["data"]``. Defaults to None.
        sc_version (str, optional): The version of the package. Defaults to None.

    Returns:
        dict: ``{"message": {...}, "timestamp": "<epoch millis>"}``.

    """
    metadata = _get_metadata()
    if sc_version is not None:
        metadata["snowpark_checkpoints_version"] = sc_version
    message = {
        "event_type": event_type,
        "type": "snowpark-checkpoints",
        "event_name": event_name,
        "driver_type": "PythonConnector",
        "driver_version": SNOWFLAKE_CONNECTOR_VERSION,
        "metadata": metadata,
        "data": json.dumps(parameters_info or {}),
    }
    if time_util is not None:
        timestamp = time_util.get_time_millis()
    else:
        # time_util is None when the connector imports failed (e.g. stored
        # procedures). Compute epoch milliseconds locally instead of raising.
        timestamp = int(datetime.datetime.now().timestamp() * 1000)
    return {"message": message, "timestamp": str(timestamp)}
354
+
355
+
356
def _get_metadata() -> dict:
    """Collect the environment details attached to every telemetry event.

    Returns:
        dict: ``os_version`` (``sys.platform``), ``python_version``,
        ``snowpark_version`` (the non-None parts of Snowpark's VERSION joined
        with dots) and ``device_id`` (hashed node id).

    """
    version_parts = [str(part) for part in SNOWPARK_VERSION if part is not None]
    metadata = {}
    metadata["os_version"] = platform
    metadata["python_version"] = python_version()
    metadata["snowpark_version"] = ".".join(version_parts)
    metadata["device_id"] = _get_unique_id()
    return metadata
369
+
370
+
371
def _get_version() -> Optional[str]:
    """Read the package version from the sibling ``__version__.py`` module.

    Returns:
        Optional[str]: The captured version string, or None when the file cannot
        be read or contains no ``__version__`` assignment.

    """
    try:
        package_dir = Path(__file__).resolve().parents[1]
        raw_text = get_io_file_manager().read(str(package_dir / VERSION_FILE_NAME))
        found = re.search(VERSION_VARIABLE_PATTERN, raw_text, re.MULTILINE)
        return found.group(1) if found else None
    except Exception:
        return None
389
+
390
+
391
def _get_folder_size(folder_path: Path) -> int:
    """Return the total size in bytes of the ``*.json`` files directly in ``folder_path``.

    Args:
        folder_path (Path): The folder to measure.

    Returns:
        int: Sum of the JSON files' sizes, 0 if there are none.

    """
    io_manager = get_io_file_manager()
    json_files = io_manager.ls(f"{folder_path}/*.json")
    return sum(
        io_manager.telemetry_path_files(name).stat().st_size for name in json_files
    )
405
+
406
+
407
def _free_up_space(folder_path: Path, max_size: int) -> None:
    """Delete the oldest ``*.json`` files until their total size is within ``max_size``.

    Args:
        folder_path (Path): The folder holding telemetry JSON files.
        max_size (int): The maximum allowed total size of those files, in bytes.

    """
    io_manager = get_io_file_manager()
    entries = []
    for name in io_manager.ls(f"{folder_path}/*.json"):
        # Convert each listing entry through ``telemetry_path_files`` before any
        # filesystem call, as ``_get_folder_size`` does. Calling ``.stat()`` on the
        # raw entry fails if the listing yields plain strings.
        path = io_manager.telemetry_path_files(name)
        stat_result = path.stat()
        entries.append((stat_result.st_mtime, stat_result.st_size, path))
    entries.sort(key=lambda entry: entry[0])  # oldest first
    # Size is computed from the same listing rather than re-listing the folder.
    current_size = sum(entry[1] for entry in entries)
    for _, size, path in entries:
        if current_size <= max_size:
            break
        path.unlink()
        current_size -= size
426
+
427
+
428
+ def _get_unique_id() -> str:
429
+ """Get a unique device ID. The ID is generated based on the hashed MAC address.
430
+
431
+ Returns:
432
+ str: The hashed device ID.
433
+
434
+ """
435
+ node_id_str = str(getnode())
436
+ hashed_id = hashlib.sha256(node_id_str.encode()).hexdigest()
437
+ return hashed_id
438
+
439
+
440
def get_telemetry_manager() -> TelemetryManager:
    """Return the session connection's TelemetryManager, installing one if needed.

    If no Snowpark session can be obtained, a standalone manager without a REST
    client is returned. Its flush size is 1, so each event is written out
    immediately.

    Returns:
        TelemetryManager: The telemetry manager.

    """
    try:
        conn = Session.builder.getOrCreate().connection
        already_installed = hasattr(conn._telemetry, "_sc_is_telemetry_manager")
        if not already_installed:
            conn._telemetry = TelemetryManager(conn._rest, conn.telemetry_enabled)
        return conn._telemetry
    except Exception:
        standalone = TelemetryManager(None, is_telemetry_enabled=True)
        standalone.sc_flush_size = 1
        return standalone
458
+
459
+
460
def get_snowflake_schema_types(df: snowpark_dataframe.DataFrame) -> list[str]:
    """Return the string form of each schema field's ``datatype`` in a Snowpark DataFrame.

    Args:
        df (snowpark_dataframe.DataFrame): The Snowflake DataFrame.

    Returns:
        list[str]: One entry per schema field, in schema order.

    """
    type_names: list[str] = []
    for field in df.schema.fields:
        type_names.append(str(field.datatype))
    return type_names
471
+
472
+
473
def _is_snowpark_dataframe(df: Any) -> bool:
    """Tell whether ``df`` is an instance of Snowpark's DataFrame class.

    Args:
        df: Any object.

    Returns:
        bool: True for Snowpark DataFrames, False otherwise.

    """
    is_snowpark = isinstance(df, snowpark_dataframe.DataFrame)
    return is_snowpark
484
+
485
+
486
def get_load_json(json_schema: str) -> dict:
    """Read a UTF-8 JSON file through the IO file manager and parse it.

    Args:
        json_schema (str): Path of the JSON schema file.

    Returns:
        dict: The parsed JSON content.

    Raises:
        ValueError: If the file cannot be read (OSError) or is not valid JSON.
            The original exception is not chained.

    """
    try:
        raw_text = get_io_file_manager().read(json_schema, encoding="utf-8")
        parsed = json.loads(raw_text)
    except (OSError, json.JSONDecodeError) as err:
        raise ValueError(f"Error reading JSON schema file: {err}") from None
    return parsed
504
+
505
+
506
def _is_in_stored_procedure() -> bool:
    """Tell whether the Snowflake connector reports the stored-procedure platform.

    Returns:
        bool: True when the connector's PLATFORM constant equals "XP".

    """
    stored_procedure_platform = "XP"
    return CONNECTOR_PLATFORM == stored_procedure_platform
514
+
515
+
516
def extract_parameters(
    func: Callable, args: tuple, kwargs: dict, params_list: Optional[list[str]]
) -> dict:
    """Extract selected argument values from a call to ``func``.

    Each requested name is looked up in ``kwargs`` first. Otherwise its position
    in ``func``'s signature is used to pick it from ``args``. Names that were not
    supplied in the call are omitted. Falsy values such as ``0`` or ``False``
    are kept.

    Args:
        func (Callable): The function being decorated.
        args (tuple): The positional arguments passed to the function.
        kwargs (dict): The keyword arguments passed to the function.
        params_list (list[str], optional): The parameter names to extract.

    Returns:
        dict: Mapping of each supplied parameter name to its value.

    """
    # Computed once instead of per requested name.
    param_names = list(inspect.signature(func).parameters)
    param_data = {}
    for param in params_list or []:
        if param in kwargs:
            param_data[param] = kwargs[param]
            continue
        # A mix of positional and keyword arguments is common. Only use the
        # positional slot if it was actually filled.
        if param in param_names:
            index = param_names.index(param)
            if index < len(args):
                param_data[param] = args[index]
    return param_data
542
+
543
+
544
def check_dataframe_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Handle telemetry event for checking dataframe schema.

    Records the mode, the validation status and the non-None column dtypes of the
    pandera schema in ``telemetry_data``.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: ``(DATAFRAME_VALIDATOR_SCHEMA, telemetry_data)`` on success, or
        ``(DATAFRAME_VALIDATOR_ERROR, telemetry_data)`` if the schema could not
        be inspected.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
        pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        schema_types = [
            str(column.dtype)
            for column in pandera_schema.columns.values()
            if column.dtype is not None
        ]
        if schema_types:
            telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
    except Exception:
        # Best-effort recovery of whatever data is still available. Each step is
        # guarded so the handler cannot raise the same error it is handling.
        with suppress(Exception):
            if param_data.get(STATUS_KEY):
                telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY)
        with suppress(Exception):
            pandera_schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
            if pandera_schema:
                schema_types = [
                    str(column.dtype)
                    for column in pandera_schema.columns.values()
                    if column.dtype is not None
                ]
                if schema_types:
                    telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
580
+
581
+
582
def check_output_or_input_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Record the pandera schema's column dtypes for input/output schema checks.

    Args:
        telemetry_data (dict): Telemetry payload, updated in place.
        param_data (dict): Extracted call parameters.

    Returns:
        tuple: ``(DATAFRAME_VALIDATOR_SCHEMA, telemetry_data)``, or the
        ``DATAFRAME_VALIDATOR_ERROR`` event if anything raises.

    """
    try:
        schema = param_data.get(PANDERA_SCHEMA_PARAM_NAME)
        collected = [
            str(col.dtype) for col in schema.columns.values() if col.dtype is not None
        ]
        if collected:
            telemetry_data[SCHEMA_TYPES_KEY] = collected
        return DATAFRAME_VALIDATOR_SCHEMA, telemetry_data
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
606
+
607
+
608
def collect_dataframe_checkpoint_mode_schema_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Record schema-mode collection telemetry.

    Stores the ``typeName()`` of each entry in ``param_data["column_type_dict"]``
    (iterated by key) under the schema-types key.

    Args:
        telemetry_data (dict): Telemetry payload, updated in place.
        param_data (dict): Extracted call parameters.

    Returns:
        tuple: ``(DATAFRAME_COLLECTION_SCHEMA, telemetry_data)``, or
        ``DATAFRAME_COLLECTION_ERROR`` if anything raises.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.SCHEMA.value
    try:
        column_types = param_data.get("column_type_dict")
        type_names = [
            column_types[column_name].dataType.typeName()
            for column_name in column_types
        ]
        telemetry_data[SCHEMA_TYPES_KEY] = type_names
        return DATAFRAME_COLLECTION_SCHEMA, telemetry_data
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data
631
+
632
+
633
def collect_dataframe_checkpoint_mode_dataframe_event(
    telemetry_data: dict, param_data: dict
) -> tuple[str, dict]:
    """Record dataframe-mode collection telemetry.

    When the collected object is a PySpark DataFrame, its field types are stored
    under the Spark schema-types key.

    Args:
        telemetry_data (dict): Telemetry payload, updated in place.
        param_data (dict): Extracted call parameters holding the DataFrame.

    Returns:
        tuple: ``(DATAFRAME_COLLECTION_DF, telemetry_data)``, or
        ``DATAFRAME_COLLECTION_ERROR`` if anything raises.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    try:
        frame = param_data.get(DF_PARAM_NAME)
        if _is_spark_dataframe(frame):
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(frame)
        return DATAFRAME_COLLECTION_DF, telemetry_data
    except Exception:
        return DATAFRAME_COLLECTION_ERROR, telemetry_data
658
+
659
+
660
def assert_return_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Handle telemetry event for asserting return values.

    When both results are DataFrames (Snowpark and PySpark), their schema types
    are recorded and the DataFrame mirror event is returned. Otherwise the value
    mirror event is returned.

    Args:
        telemetry_data (dict): The telemetry data dictionary (updated in place).
        param_data (dict): The parameter data dictionary.

    Returns:
        tuple: A tuple containing the event name and telemetry data.
        ``DATAFRAME_VALIDATOR_ERROR`` is returned if schema extraction raises.

    """
    if param_data.get(STATUS_KEY) is not None:
        telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
    snowpark_results = param_data.get(SNOWPARK_RESULTS_PARAM_NAME)
    spark_results = param_data.get(SPARK_RESULTS_PARAM_NAME)
    try:
        if _is_snowpark_dataframe(snowpark_results) and _is_spark_dataframe(
            spark_results
        ):
            telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = get_snowflake_schema_types(
                snowpark_results
            )
            telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                spark_results
            )
            return DATAFRAME_VALIDATOR_MIRROR, telemetry_data
        return VALUE_VALIDATOR_MIRROR, telemetry_data
    except Exception:
        # Record whichever schema can still be read. Each extraction is guarded
        # so the handler cannot re-raise the error it is handling.
        with suppress(Exception):
            if _is_snowpark_dataframe(snowpark_results):
                telemetry_data[SNOWFLAKE_SCHEMA_TYPES_KEY] = (
                    get_snowflake_schema_types(snowpark_results)
                )
        with suppress(Exception):
            if _is_spark_dataframe(spark_results):
                telemetry_data[SPARK_SCHEMA_TYPES_KEY] = _get_spark_schema_types(
                    spark_results
                )
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
696
+
697
+
698
def dataframe_strategy_event(
    telemetry_data: dict, param_data: dict, telemetry_m: TelemetryManager
) -> tuple[Optional[str], Optional[dict]]:
    """Log the Hypothesis input-schema event once per calling test function.

    The event is logged and sent directly through ``telemetry_m``. Markers
    ``(name, 0)`` (info) and ``(name, 1)`` (error) in
    ``telemetry_m.sc_hypothesis_input_events`` prevent logging it again.

    Args:
        telemetry_data (dict): The telemetry data dictionary.
        param_data (dict): The parameter data dictionary.
        telemetry_m (TelemetryManager): The telemetry manager.

    Returns:
        tuple: Always ``(None, None)``; the caller has nothing left to log.

    """
    # Frame 2 is two levels above this function. It is presumably the test
    # function being reported; verify against the caller chain. The call must
    # stay directly in this function so the depth is unchanged. context=0 skips
    # reading source lines for every frame.
    test_function_name = inspect.stack(0)[2].function
    info_marker = (test_function_name, 0)
    error_marker = (test_function_name, 1)
    try:
        if telemetry_m.sc_is_hypothesis_event_logged(
            info_marker
        ) or telemetry_m.sc_is_hypothesis_event_logged(error_marker):
            return None, None
        schema_param = param_data.get(DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME)
        if isinstance(schema_param, str):
            columns = get_load_json(schema_param)["custom_data"]["columns"]
            telemetry_data[SCHEMA_TYPES_KEY] = [column["type"] for column in columns]
        else:
            schema_types = [
                str(column.dtype)
                for column in schema_param.columns.values()
                if column.dtype is not None
            ]
            if schema_types:
                telemetry_data[SCHEMA_TYPES_KEY] = schema_types
        logged_types = telemetry_data.get(SCHEMA_TYPES_KEY)
        # No types at all, or an unknown (None) type, is reported as an error.
        if logged_types is None or None in logged_types:
            telemetry_m.sc_hypothesis_input_events.append(error_marker)
            telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
        else:
            telemetry_m.sc_hypothesis_input_events.append(info_marker)
            telemetry_m.sc_log_info(HYPOTHESIS_INPUT_SCHEMA, telemetry_data)
        telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None
    except Exception:
        # Record the error marker (1) that is checked here. The old code checked
        # 1 but recorded 0, contradicting the 0=info / 1=error convention.
        if not telemetry_m.sc_is_hypothesis_event_logged(error_marker):
            telemetry_m.sc_hypothesis_input_events.append(error_marker)
            telemetry_m.sc_log_error(HYPOTHESIS_INPUT_SCHEMA_ERROR, telemetry_data)
            telemetry_m.sc_send_batch(telemetry_m.sc_log_batch)
        return None, None
744
+
745
+
746
def compare_data_event(telemetry_data: dict, param_data: dict) -> tuple[str, dict]:
    """Record dataframe-mode comparison telemetry.

    Sets the mode and status (None if absent), then stores the schema types of
    the Snowpark DataFrame found under ``param_data["df"]``.

    Args:
        telemetry_data (dict): Telemetry payload, updated in place.
        param_data (dict): Extracted call parameters (DataFrame and status).

    Returns:
        tuple: ``(DATAFRAME_VALIDATOR_DF, telemetry_data)``, or
        ``DATAFRAME_VALIDATOR_ERROR`` if reading the schema raises.

    """
    telemetry_data[MODE_KEY] = CheckpointMode.DATAFRAME.value
    telemetry_data[STATUS_KEY] = param_data.get(STATUS_KEY, None)
    try:
        snowpark_df = param_data.get("df")
        telemetry_data[SCHEMA_TYPES_KEY] = get_snowflake_schema_types(snowpark_df)
        return DATAFRAME_VALIDATOR_DF, telemetry_data
    except Exception:
        return DATAFRAME_VALIDATOR_ERROR, telemetry_data
769
+
770
+
771
def handle_result(
    func_name: str,
    result: Any,
    param_data: dict,
    multiple_return: bool,
    telemetry_m: TelemetryManager,
    return_indexes: Optional[list[tuple[str, int]]] = None,
) -> tuple[Optional[str], Optional[dict]]:
    """Route a decorated function's outcome to its telemetry event builder.

    Selected return values are first copied into ``param_data`` under the names
    given by ``return_indexes``.

    Args:
        func_name (str): Name of the decorated function; selects the handler.
        result: The decorated function's return value.
        param_data (dict): Extracted parameters, updated in place.
        multiple_return (bool): If True, every ``(name, index)`` pair is used;
            otherwise only the first.
        telemetry_m (TelemetryManager): The telemetry manager.
        return_indexes (list[tuple[str, int]]): ``(name, index)`` pairs into
            ``result``. Defaults to None.

    Returns:
        tuple: ``(event_name, telemetry_data)``, or ``(None, None)`` for
        unrecognised function names.

    """
    if result is not None and return_indexes is not None:
        if multiple_return:
            for key, position in return_indexes:
                param_data[key] = result[position]
        else:
            first_key, first_position = return_indexes[0][0], return_indexes[0][1]
            param_data[first_key] = result[first_position]

    telemetry_data = {FUNCTION_KEY: func_name}

    # dataframe_strategy_event inspects the call stack by depth, so it must be
    # called directly from here, not through a wrapper.
    if func_name == "dataframe_strategy":
        event, data = dataframe_strategy_event(telemetry_data, param_data, telemetry_m)
        return event, data

    handlers = {
        "_check_dataframe_schema": check_dataframe_schema_event,
        "check_output_schema": check_output_or_input_schema_event,
        "check_input_schema": check_output_or_input_schema_event,
        "_compare_data": compare_data_event,
        "_collect_dataframe_checkpoint_mode_schema": collect_dataframe_checkpoint_mode_schema_event,
        "_collect_dataframe_checkpoint_mode_dataframe": collect_dataframe_checkpoint_mode_dataframe_event,
        "_assert_return": assert_return_event,
    }
    handler = handlers.get(func_name)
    if handler is None:
        return None, None
    event, data = handler(telemetry_data, param_data)
    return event, data
829
+
830
+
831
fn = TypeVar("fn", bound=Callable)


def report_telemetry(
    params_list: Optional[list[str]] = None,
    return_indexes: Optional[list[tuple[str, int]]] = None,
    multiple_return: bool = False,
) -> Callable[[fn], fn]:
    """Report telemetry events for a function.

    The decorated function always runs first. Its result is returned, or its
    exception is re-raised, regardless of whether telemetry is enabled and
    regardless of whether telemetry collection or logging fails. Telemetry is
    strictly best-effort.

    Telemetry is skipped entirely when the environment variable
    ``SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED`` is ``"false"`` or when running
    inside a stored procedure.

    Args:
        params_list (list[str], optional): The list of parameters to report. Defaults to None.
        return_indexes (list[tuple[str, int]], optional): The list of return values to report. Defaults to None.
        multiple_return (bool, optional): Whether the function returns multiple values. Defaults to False.

    Returns:
        Callable[[fn], fn]: The decorator function.

    """

    def report_telemetry_decorator(func):
        func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            func_exception = None
            result = None
            try:
                result = func(*args, **kwargs)
            except Exception as err:
                # Deferred so the failure can be logged before re-raising.
                func_exception = err

            if (
                os.getenv("SNOWPARK_CHECKPOINTS_TELEMETRY_ENABLED") == "false"
                or _is_in_stored_procedure()
            ):
                # Telemetry disabled: behave exactly like the undecorated
                # function, including propagating its exception.
                if func_exception is not None:
                    raise func_exception
                return result

            telemetry_event = None
            data = None
            telemetry_m = None
            try:
                param_data = extract_parameters(func, args, kwargs, params_list)
                telemetry_m = get_telemetry_manager()
                telemetry_event, data = handle_result(
                    func_name,
                    result,
                    param_data,
                    multiple_return,
                    telemetry_m,
                    return_indexes,
                )
            except Exception:
                # Best-effort: telemetry collection must never affect the caller.
                pass

            if func_exception is not None:
                if telemetry_m is not None:
                    try:
                        telemetry_m.sc_log_error(telemetry_event, data)
                    except Exception:
                        # Never let a logging failure mask the user's exception.
                        pass
                raise func_exception

            if telemetry_m is not None:
                try:
                    telemetry_m.sc_log_info(telemetry_event, data)
                except Exception:
                    # Never let a logging failure replace the user's result.
                    pass

            return result

        return wrapper

    return report_telemetry_decorator
897
+
898
+
899
# Constants for telemetry

# --- Event names: successful collection / validation steps -----------------
DATAFRAME_COLLECTION_SCHEMA = "DataFrame_Collection_Schema"
DATAFRAME_COLLECTION_DF = "DataFrame_Collection_DF"
DATAFRAME_VALIDATOR_SCHEMA = "DataFrame_Validator_Schema"
DATAFRAME_VALIDATOR_DF = "DataFrame_Validator_DF"
DATAFRAME_VALIDATOR_MIRROR = "DataFrame_Validator_Mirror"
VALUE_VALIDATOR_MIRROR = "Value_Validator_Mirror"
HYPOTHESIS_INPUT_SCHEMA = "Hypothesis_Input_Schema"

# --- Event names: error variants --------------------------------------------
# e.g. DATAFRAME_VALIDATOR_ERROR is emitted when extracting the Snowflake
# schema types for a data comparison fails.
DATAFRAME_COLLECTION_ERROR = "DataFrame_Collection_Error"
DATAFRAME_VALIDATOR_ERROR = "DataFrame_Validator_Error"
HYPOTHESIS_INPUT_SCHEMA_ERROR = "Hypothesis_Input_Schema_Error"

# --- Keys of the telemetry payload dictionary -------------------------------
FUNCTION_KEY = "function"
MODE_KEY = "mode"
STATUS_KEY = "status"
ERROR_KEY = "error"
SCHEMA_TYPES_KEY = "schema_types"
SNOWFLAKE_SCHEMA_TYPES_KEY = "snowflake_schema_types"
SPARK_SCHEMA_TYPES_KEY = "spark_schema_types"

# --- Names looked up in the extracted parameter/return data -----------------
DF_PARAM_NAME = "df"
DATAFRAME_STRATEGY_SCHEMA_PARAM_NAME = "schema"
PANDERA_SCHEMA_PARAM_NAME = "pandera_schema"
SNOWPARK_RESULTS_PARAM_NAME = "snowpark_results"
SPARK_RESULTS_PARAM_NAME = "spark_results"
925
+
926
class CheckpointMode(IntEnum):
    """Checkpoint mode identifier reported in telemetry.

    The telemetry event builders store the member's integer ``value`` under
    ``MODE_KEY`` (e.g. ``CheckpointMode.DATAFRAME.value``).
    NOTE(review): names suggest schema-only vs. full-DataFrame checkpoints;
    confirm against the collector's mode definitions.
    """

    SCHEMA = 1
    DATAFRAME = 2