snowpark-checkpoints-collectors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.3.2"
+__version__ = "0.4.0"

@@ -48,11 +48,12 @@ STRUCT_COLUMN_TYPE = "struct"
 TIMESTAMP_COLUMN_TYPE = "timestamp"
 TIMESTAMP_NTZ_COLUMN_TYPE = "timestamp_ntz"
 
-PANDAS_BOOLEAN_DTYPE = "bool"
+PANDAS_BOOLEAN_DTYPE = "boolean"
 PANDAS_DATETIME_DTYPE = "datetime64[ns]"
 PANDAS_FLOAT_DTYPE = "float64"
-PANDAS_INTEGER_DTYPE = "int64"
+PANDAS_INTEGER_DTYPE = "Int64"
 PANDAS_OBJECT_DTYPE = "object"
+PANDAS_STRING_DTYPE = "string[python]"
 PANDAS_TIMEDELTA_DTYPE = "timedelta64[ns]"
 
 NUMERIC_TYPE_COLLECTION = [
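
The hunk above moves the pandas dtype constants from NumPy-backed dtypes ("bool", "int64") to pandas' nullable extension dtypes ("boolean", "Int64", "string[python]"). A minimal standalone sketch (not from the package) of why this matters for columns that can contain nulls:

```python
import pandas as pd

# NumPy-backed dtypes handle missing values poorly: "int64" rejects None
# outright, and "bool" silently degrades to object dtype. The nullable
# extension dtypes keep the logical type and represent missing data as pd.NA.
ints = pd.Series([1, 2, None], dtype="Int64")            # nullable integer
flags = pd.Series([True, None, False], dtype="boolean")  # nullable boolean
text = pd.Series(["a", None], dtype="string[python]")    # nullable string

print(ints.dtype, flags.dtype, text.dtype)  # Int64 boolean string
print(ints.isna().tolist())                 # [False, False, True]
```
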
@@ -142,6 +143,8 @@ BACKSLASH_TOKEN = "\\"
 SLASH_TOKEN = "/"
 PYSPARK_NONE_SIZE_VALUE = -1
 PANDAS_LONG_TYPE = "Int64"
+PANDAS_STRING_TYPE = "string"
+PANDAS_FLOAT_TYPE = "float64"
 
 # ENVIRONMENT VARIABLES
 SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR = (

@@ -22,6 +22,8 @@ from pathlib import Path
 from typing import Callable, Optional
 
 from snowflake.snowpark import Session
+from snowflake.snowpark.functions import col, expr
+from snowflake.snowpark.types import TimestampType
 from snowflake.snowpark_checkpoints_collector.collection_common import (
     DOT_PARQUET_EXTENSION,
 )

@@ -195,9 +197,28 @@ class SnowConnection:
             stage_directory_path,
         )
         dataframe = self.session.read.parquet(path=stage_directory_path)
+        dataframe = convert_timestamps_to_utc_date(dataframe)
         LOGGER.info("Creating table '%s' from parquet files", table_name)
         dataframe.write.save_as_table(table_name=table_name, mode="overwrite")
 
     def _create_snowpark_session(self) -> Session:
         LOGGER.info("Creating a Snowpark session using the default connection")
         return Session.builder.getOrCreate()
+
+
+def convert_timestamps_to_utc_date(df):
+    """Convert all timestamp columns to UTC-normalized timestamps.
+
+    Reading a parquet file written by Spark from a Snowpark session modifies
+    the original timestamps, so this function normalizes them for comparison.
+    """
+    new_cols = []
+    for field in df.schema.fields:
+        if isinstance(field.datatype, TimestampType):
+            utc_normalized_ts = expr(
+                f"convert_timezone('UTC', cast(to_date({field.name}) as timestamp_tz))"
+            ).alias(field.name)
+            new_cols.append(utc_normalized_ts)
+        else:
+            new_cols.append(col(field.name))
+    return df.select(new_cols)
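
Two details worth noting about `convert_timestamps_to_utc_date`: Snowpark's `StructField` exposes the column type as `datatype` (unlike PySpark's `dataType`), and the generated SQL truncates each timestamp to its date via `to_date` before casting to `timestamp_tz` and converting to UTC, so times of day are normalized away. A hypothetical usage sketch (the stage path and table name are placeholders, not from the package):

```python
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()

# Read parquet files written by Spark, then normalize timestamp columns
# before materializing the table, mirroring the change above.
df = session.read.parquet(path="@my_stage/checkpoint_data/")
df = convert_timestamps_to_utc_date(df)
df.write.save_as_table(table_name="CHECKPOINT_TABLE", mode="overwrite")
```
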
@@ -23,9 +23,15 @@ import pandera as pa
 
 from pyspark.sql import DataFrame as SparkDataFrame
 from pyspark.sql.functions import col
+from pyspark.sql.types import BinaryType as SparkBinaryType
+from pyspark.sql.types import BooleanType as SparkBooleanType
+from pyspark.sql.types import DateType as SparkDateType
 from pyspark.sql.types import DoubleType as SparkDoubleType
+from pyspark.sql.types import FloatType as SparkFloatType
+from pyspark.sql.types import IntegerType as SparkIntegerType
 from pyspark.sql.types import StringType as SparkStringType
-from pyspark.sql.types import StructField
+from pyspark.sql.types import StructField as SparkStructField
+from pyspark.sql.types import TimestampType as SparkTimestampType
 
 from snowflake.snowpark_checkpoints_collector.collection_common import (
     CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT,

@@ -36,8 +42,10 @@ from snowflake.snowpark_checkpoints_collector.collection_common import (
     DOT_PARQUET_EXTENSION,
     INTEGER_TYPE_COLLECTION,
     NULL_COLUMN_TYPE,
+    PANDAS_FLOAT_TYPE,
     PANDAS_LONG_TYPE,
     PANDAS_OBJECT_TYPE_COLLECTION,
+    PANDAS_STRING_TYPE,
     CheckpointMode,
 )
 from snowflake.snowpark_checkpoints_collector.collection_result.model import (

@@ -72,6 +80,16 @@ from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_tele
 
 LOGGER = logging.getLogger(__name__)
 
+default_null_types = {
+    SparkIntegerType(): 0,
+    SparkFloatType(): 0.0,
+    SparkDoubleType(): 0.0,
+    SparkStringType(): "",
+    SparkBooleanType(): False,
+    SparkTimestampType(): None,
+    SparkDateType(): None,
+}
+
 
 @log
 def collect_dataframe_checkpoint(
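
The new `default_null_types` mapping is keyed by PySpark type instances, which works because `pyspark.sql.types.DataType` implements value-based `__eq__` and `__hash__`: a freshly constructed `IntegerType()` looks up the entry stored above. Timestamp and date columns deliberately map to `None`, which the `normalize_missing_values` function introduced below treats as "no default to fill". A standalone sketch (not from the package):

```python
from pyspark.sql.types import IntegerType, StringType

defaults = {IntegerType(): 0, StringType(): ""}

# Fresh instances compare and hash equal to the keys stored in the mapping.
assert IntegerType() == IntegerType()
assert defaults.get(IntegerType()) == 0
assert defaults.get(StringType()) == ""
```
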
@@ -253,6 +271,7 @@ def _collect_dataframe_checkpoint_mode_schema(
     column_type_dict: dict[str, any],
     output_path: Optional[str] = None,
 ) -> None:
+    df = normalize_missing_values(df)
     sampled_df = df.sample(sample)
     if sampled_df.isEmpty():
         LOGGER.warning("Sampled DataFrame is empty. Collecting full DataFrame.")

@@ -327,7 +346,16 @@ def _collect_dataframe_checkpoint_mode_schema(
     )
 
 
-def _get_spark_column_types(df: SparkDataFrame) -> dict[str, StructField]:
+def normalize_missing_values(df: SparkDataFrame) -> SparkDataFrame:
+    """Normalize missing values in a PySpark DataFrame to ensure consistent handling of NA values."""
+    for field in df.schema.fields:
+        default_value = default_null_types.get(field.dataType, None)
+        if default_value is not None:
+            df = df.fillna({field.name: default_value})
+    return df
+
+
+def _get_spark_column_types(df: SparkDataFrame) -> dict[str, SparkStructField]:
     schema = df.schema
     column_type_collection = {}
     for field in schema.fields:
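
For reference, PySpark's `DataFrame.fillna` with a `{column: value}` dict replaces nulls only in the named column, which is how `normalize_missing_values` applies one default per field. A standalone sketch with assumed data (not from the package):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, None), (None, "b")], schema="id int, label string")

# Each column gets its own default; columns without a default stay untouched.
df = df.fillna({"id": 0}).fillna({"label": ""})
df.show()  # id -> [1, 0], label -> ["", "b"]
```
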
@@ -457,14 +485,83 @@ def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
     LOGGER.debug("Converting Spark DataFrame to Pandas DataFrame")
     pandas_df = sampled_df.toPandas()
     for field in sampled_df.schema.fields:
-        has_nan = pandas_df[field.name].isna().any()
         is_integer = field.dataType.typeName() in INTEGER_TYPE_COLLECTION
-        if has_nan and is_integer:
+        is_spark_string = isinstance(field.dataType, SparkStringType)
+        is_spark_binary = isinstance(field.dataType, SparkBinaryType)
+        is_spark_timestamp = isinstance(field.dataType, SparkTimestampType)
+        is_spark_float = isinstance(field.dataType, SparkFloatType)
+        is_spark_boolean = isinstance(field.dataType, SparkBooleanType)
+        is_spark_date = isinstance(field.dataType, SparkDateType)
+        if is_integer:
             LOGGER.debug(
-                "Converting column '%s' to '%s' type",
+                "Converting Spark integer column '%s' to Pandas nullable '%s' type",
                 field.name,
                 PANDAS_LONG_TYPE,
             )
-            pandas_df[field.name] = pandas_df[field.name].astype(PANDAS_LONG_TYPE)
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_LONG_TYPE).fillna(0)
+            )
+        elif is_spark_string or is_spark_binary:
+            LOGGER.debug(
+                "Converting Spark string column '%s' to Pandas nullable '%s' type",
+                field.name,
+                PANDAS_STRING_TYPE,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_STRING_TYPE).fillna("")
+            )
+        elif is_spark_timestamp:
+            LOGGER.debug(
+                "Converting Spark timestamp column '%s' to UTC naive Pandas datetime",
+                field.name,
+            )
+            pandas_df[field.name] = convert_all_to_utc_naive(
+                pandas_df[field.name]
+            ).fillna(pandas.NaT)
+        elif is_spark_float:
+            LOGGER.debug(
+                "Converting Spark float column '%s' to Pandas nullable float",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype(PANDAS_FLOAT_TYPE).fillna(0.0)
+            )
+        elif is_spark_boolean:
+            LOGGER.debug(
+                "Converting Spark boolean column '%s' to Pandas nullable boolean",
+                field.name,
+            )
+            pandas_df[field.name] = (
+                pandas_df[field.name].astype("boolean").fillna(False)
+            )
+        elif is_spark_date:
+            LOGGER.debug(
+                "Converting Spark date column '%s' to Pandas nullable datetime",
+                field.name,
+            )
+            pandas_df[field.name] = pandas_df[field.name].fillna(pandas.NaT)
 
     return pandas_df
+
+
+def convert_all_to_utc_naive(series: pandas.Series) -> pandas.Series:
+    """Convert all timezone-aware or naive timestamps in a series to UTC naive.
+
+    Naive timestamps are assumed to be in UTC and localized accordingly.
+    Timezone-aware timestamps are converted to UTC and then made naive.
+
+    Args:
+        series (pandas.Series): A Pandas Series of `pd.Timestamp` objects,
+            either naive or timezone-aware.
+
+    Returns:
+        pandas.Series: A Series of UTC-normalized naive timestamps (`tzinfo=None`).
+
+    """
+
+    def convert(ts):
+        if ts.tz is None:
+            ts = ts.tz_localize("UTC")
+        return ts.tz_convert("UTC").tz_localize(None)
+
+    return series.apply(convert)
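
A quick sketch of what `convert_all_to_utc_naive` produces (assumed values, not from the package): a naive timestamp and an equivalent timezone-aware one collapse to the same UTC-naive value, so the column can be compared consistently.

```python
import pandas as pd

s = pd.Series([
    pd.Timestamp("2024-01-01 12:00:00"),                   # naive, assumed UTC
    pd.Timestamp("2024-01-01 07:00:00", tz="US/Eastern"),  # aware, 12:00 UTC
])

print(convert_all_to_utc_naive(s).tolist())
# [Timestamp('2024-01-01 12:00:00'), Timestamp('2024-01-01 12:00:00')]
```
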
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowpark-checkpoints-collectors
-Version: 0.3.2
+Version: 0.4.0
 Summary: Snowpark column and table statistics collection
 Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
 Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/

@@ -1,8 +1,8 @@
 snowflake/snowpark_checkpoints_collector/__init__.py,sha256=g4NemuA6Mj4O2jkK0yLQ8sEV3owHiiJnBEz_OWvlW1I,1179
-snowflake/snowpark_checkpoints_collector/__version__.py,sha256=1W0aBeLTL5Svy-qrNkZc6gAKtQLDbncpMyN2SlnJhoU,632
-snowflake/snowpark_checkpoints_collector/collection_common.py,sha256=ff5vYffrTRjoJXZQvVQBaOlegAUj_vXBbl1IZidz8Qo,4510
+snowflake/snowpark_checkpoints_collector/__version__.py,sha256=mZG_4eaVJdzo54iJo1tR3khnIA6lKjmN2lUgMoangNY,632
+snowflake/snowpark_checkpoints_collector/collection_common.py,sha256=qHiBWOICEbc1bvpUbfZU_mkmRiy77TB_2eR12mg52oQ,4612
 snowflake/snowpark_checkpoints_collector/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
-snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256=kRJpVRE9Iy_uqeIPT-__Aan-YLWxQbgSjkJ3w4LpvCc,17214
+snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256=-KhVUcZX9z3_RmFxkcKa-31Ry9PRdcYN_U6O_cPYNhg,20984
 snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py,sha256=jZzx29WzrjH7C_6ZsBGoe4PxbW_oM4uIjySS1axIM34,1000
 snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py,sha256=XelL7LughZpKl1B_6bJoKOc_PqQg3UleX6zdgVXqTus,2926
 snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py,sha256=EY6WIIXRbvkTYC4bQn7jFALHh7D2PirVoiLZ5Kq8dNs,2659
@@ -31,13 +31,13 @@ snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py,sha256=
 snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py,sha256=kJMbg2VOKNXXdkGCt_tMMLGEZ2aUl1_nie1qYvx5M-c,3770
 snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py,sha256=M17EtANswD5gcgGnmT13OImO_W1uH4K3ewu2CXL9aes,2597
 snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py,sha256=kLjZId-aGCljK7lF6yeEw-syEqeTOJDxdXfpv9YxvZA,755
-snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=r3IPnmDMb8151PTgE4YojOhWnxWGPLyBWlgFvvhOfRY,7314
+snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=lM3oqHUHXShALDVVU5ZSuXGREUVfHYHprB5fy1r5T0I,8154
 snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py,sha256=Xc4k3JU6A96-79VFRR8NrNAUPeO3V1DEAhngg-hLlU4,1787
 snowflake/snowpark_checkpoints_collector/utils/extra_config.py,sha256=3kVf6WVA-EuyMpTO3ycTlXMSCHtytGtT6wkV4U2Hyjw,5195
 snowflake/snowpark_checkpoints_collector/utils/file_utils.py,sha256=5ztlNCv9GdSktUvtdfydv86cCFcmSXCdD4axZXJrOQQ,5125
 snowflake/snowpark_checkpoints_collector/utils/logging_utils.py,sha256=yyi6X5DqKeTg0HRhvsH6ymYp2P0wbnyKIzI2RzrQS7k,2278
 snowflake/snowpark_checkpoints_collector/utils/telemetry.py,sha256=ueN9vM8j5YNax7jMcnEj_UrgGkoeMv_hJHVKjN7hiJE,32161
-snowpark_checkpoints_collectors-0.3.2.dist-info/METADATA,sha256=ueYk6-aMlhiKfvH0CZbqjiEjlxUP1VQwKDejX28ju30,6613
-snowpark_checkpoints_collectors-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-snowpark_checkpoints_collectors-0.3.2.dist-info/licenses/LICENSE,sha256=DVQuDIgE45qn836wDaWnYhSdxoLXgpRRKH4RuTjpRZQ,10174
-snowpark_checkpoints_collectors-0.3.2.dist-info/RECORD,,
+snowpark_checkpoints_collectors-0.4.0.dist-info/METADATA,sha256=HMpSzXXczuG-5_RuKadoEbA8JjADUvadLj2sQWvu9MY,6613
+snowpark_checkpoints_collectors-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+snowpark_checkpoints_collectors-0.4.0.dist-info/licenses/LICENSE,sha256=DVQuDIgE45qn836wDaWnYhSdxoLXgpRRKH4RuTjpRZQ,10174
+snowpark_checkpoints_collectors-0.4.0.dist-info/RECORD,,