snowpark-checkpoints-collectors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/snowpark_checkpoints_collector/__version__.py +1 -1
- snowflake/snowpark_checkpoints_collector/collection_common.py +5 -2
- snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py +21 -0
- snowflake/snowpark_checkpoints_collector/summary_stats_collector.py +103 -6
- {snowpark_checkpoints_collectors-0.3.2.dist-info → snowpark_checkpoints_collectors-0.4.0.dist-info}/METADATA +1 -1
- {snowpark_checkpoints_collectors-0.3.2.dist-info → snowpark_checkpoints_collectors-0.4.0.dist-info}/RECORD +8 -8
- {snowpark_checkpoints_collectors-0.3.2.dist-info → snowpark_checkpoints_collectors-0.4.0.dist-info}/WHEEL +0 -0
- {snowpark_checkpoints_collectors-0.3.2.dist-info → snowpark_checkpoints_collectors-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -48,11 +48,12 @@ STRUCT_COLUMN_TYPE = "struct"
|
|
48
48
|
TIMESTAMP_COLUMN_TYPE = "timestamp"
|
49
49
|
TIMESTAMP_NTZ_COLUMN_TYPE = "timestamp_ntz"
|
50
50
|
|
51
|
-
PANDAS_BOOLEAN_DTYPE = "
|
51
|
+
PANDAS_BOOLEAN_DTYPE = "boolean"
|
52
52
|
PANDAS_DATETIME_DTYPE = "datetime64[ns]"
|
53
53
|
PANDAS_FLOAT_DTYPE = "float64"
|
54
|
-
PANDAS_INTEGER_DTYPE = "
|
54
|
+
PANDAS_INTEGER_DTYPE = "Int64"
|
55
55
|
PANDAS_OBJECT_DTYPE = "object"
|
56
|
+
PANDAS_STRING_DTYPE = "string[python]"
|
56
57
|
PANDAS_TIMEDELTA_DTYPE = "timedelta64[ns]"
|
57
58
|
|
58
59
|
NUMERIC_TYPE_COLLECTION = [
|
@@ -142,6 +143,8 @@ BACKSLASH_TOKEN = "\\"
|
|
142
143
|
SLASH_TOKEN = "/"
|
143
144
|
PYSPARK_NONE_SIZE_VALUE = -1
|
144
145
|
PANDAS_LONG_TYPE = "Int64"
|
146
|
+
PANDAS_STRING_TYPE = "string"
|
147
|
+
PANDAS_FLOAT_TYPE = "float64"
|
145
148
|
|
146
149
|
# ENVIRONMENT VARIABLES
|
147
150
|
SNOWFLAKE_CHECKPOINT_CONTRACT_FILE_PATH_ENV_VAR = (
|
@@ -22,6 +22,8 @@ from pathlib import Path
|
|
22
22
|
from typing import Callable, Optional
|
23
23
|
|
24
24
|
from snowflake.snowpark import Session
|
25
|
+
from snowflake.snowpark.functions import col, expr
|
26
|
+
from snowflake.snowpark.types import TimestampType
|
25
27
|
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
26
28
|
DOT_PARQUET_EXTENSION,
|
27
29
|
)
|
@@ -195,9 +197,28 @@ class SnowConnection:
|
|
195
197
|
stage_directory_path,
|
196
198
|
)
|
197
199
|
dataframe = self.session.read.parquet(path=stage_directory_path)
|
200
|
+
dataframe = convert_timestamps_to_utc_date(dataframe)
|
198
201
|
LOGGER.info("Creating table '%s' from parquet files", table_name)
|
199
202
|
dataframe.write.save_as_table(table_name=table_name, mode="overwrite")
|
200
203
|
|
201
204
|
def _create_snowpark_session(self) -> Session:
    """Build the Snowpark session used by this connection.

    Logs the attempt, then returns whatever ``Session.builder.getOrCreate()``
    yields for the default connection configuration.
    """
    LOGGER.info("Creating a Snowpark session using the default connection")
    session_builder = Session.builder
    return session_builder.getOrCreate()
|
207
|
+
|
208
|
+
|
209
|
+
def convert_timestamps_to_utc_date(df):
    """Reduce every timestamp column to its date, re-expressed in UTC.

    Reading a parquet file written by Spark from a Snowpark session can shift
    the stored timestamp values. To make such columns comparable, each
    ``TimestampType`` column is passed through ``to_date`` (which discards the
    time-of-day), cast to ``timestamp_tz`` and converted to UTC with
    ``convert_timezone``. NOTE(review): the DATE -> TIMESTAMP_TZ cast is
    resolved by Snowflake using the session timezone, so the result is not
    necessarily midnight UTC; confirm this is the intended comparison key.

    Every other column is selected unchanged. Column names and order are
    preserved.

    Args:
        df: Snowpark DataFrame (as returned by ``session.read.parquet``).

    Returns:
        A Snowpark DataFrame with the same columns, timestamp columns replaced
        by their UTC-converted date values.
    """

    def _project(field):
        # Non-timestamp columns are passed through as-is.
        if not isinstance(field.datatype, TimestampType):
            return col(field.name)
        sql_text = (
            f"convert_timezone('UTC', cast(to_date({field.name}) as timestamp_tz))"
        )
        # Keep the original column name on the derived expression.
        return expr(sql_text).alias(field.name)

    return df.select([_project(field) for field in df.schema.fields])
|
@@ -23,9 +23,15 @@ import pandera as pa
|
|
23
23
|
|
24
24
|
from pyspark.sql import DataFrame as SparkDataFrame
|
25
25
|
from pyspark.sql.functions import col
|
26
|
+
from pyspark.sql.types import BinaryType as SparkBinaryType
|
27
|
+
from pyspark.sql.types import BooleanType as SparkBooleanType
|
28
|
+
from pyspark.sql.types import DateType as SparkDateType
|
26
29
|
from pyspark.sql.types import DoubleType as SparkDoubleType
|
30
|
+
from pyspark.sql.types import FloatType as SparkFloatType
|
31
|
+
from pyspark.sql.types import IntegerType as SparkIntegerType
|
27
32
|
from pyspark.sql.types import StringType as SparkStringType
|
28
|
-
from pyspark.sql.types import StructField
|
33
|
+
from pyspark.sql.types import StructField as SparkStructField
|
34
|
+
from pyspark.sql.types import TimestampType as SparkTimestampType
|
29
35
|
|
30
36
|
from snowflake.snowpark_checkpoints_collector.collection_common import (
|
31
37
|
CHECKPOINT_JSON_OUTPUT_FILE_NAME_FORMAT,
|
@@ -36,8 +42,10 @@ from snowflake.snowpark_checkpoints_collector.collection_common import (
|
|
36
42
|
DOT_PARQUET_EXTENSION,
|
37
43
|
INTEGER_TYPE_COLLECTION,
|
38
44
|
NULL_COLUMN_TYPE,
|
45
|
+
PANDAS_FLOAT_TYPE,
|
39
46
|
PANDAS_LONG_TYPE,
|
40
47
|
PANDAS_OBJECT_TYPE_COLLECTION,
|
48
|
+
PANDAS_STRING_TYPE,
|
41
49
|
CheckpointMode,
|
42
50
|
)
|
43
51
|
from snowflake.snowpark_checkpoints_collector.collection_result.model import (
|
@@ -72,6 +80,16 @@ from snowflake.snowpark_checkpoints_collector.utils.telemetry import report_tele
|
|
72
80
|
|
73
81
|
LOGGER = logging.getLogger(__name__)
|
74
82
|
|
83
|
+
# Null replacement value per Spark data type, consumed by
# ``normalize_missing_values``. Entries mapped to None are listed for
# completeness only: ``normalize_missing_values`` skips None values, so nulls in
# timestamp and date columns are left untouched.
default_null_types = {
    # Numeric types -> zero of the matching Python type.
    SparkIntegerType(): 0,
    SparkFloatType(): 0.0,
    SparkDoubleType(): 0.0,
    # Text and boolean types -> their "empty" value.
    SparkStringType(): "",
    SparkBooleanType(): False,
    # Temporal types -> no replacement.
    SparkTimestampType(): None,
    SparkDateType(): None,
}
|
92
|
+
|
75
93
|
|
76
94
|
@log
|
77
95
|
def collect_dataframe_checkpoint(
|
@@ -253,6 +271,7 @@ def _collect_dataframe_checkpoint_mode_schema(
|
|
253
271
|
column_type_dict: dict[str, any],
|
254
272
|
output_path: Optional[str] = None,
|
255
273
|
) -> None:
|
274
|
+
df = normalize_missing_values(df)
|
256
275
|
sampled_df = df.sample(sample)
|
257
276
|
if sampled_df.isEmpty():
|
258
277
|
LOGGER.warning("Sampled DataFrame is empty. Collecting full DataFrame.")
|
@@ -327,7 +346,16 @@ def _collect_dataframe_checkpoint_mode_schema(
|
|
327
346
|
)
|
328
347
|
|
329
348
|
|
330
|
-
def
|
349
|
+
def normalize_missing_values(df: SparkDataFrame) -> SparkDataFrame:
    """Normalize missing values in a PySpark DataFrame to ensure consistent handling of NA values.

    Each column whose Spark data type has a non-None entry in
    ``default_null_types`` has its nulls replaced by that value (for example
    0 for integers, "" for strings, False for booleans). Columns whose type is
    absent from the mapping, or mapped to None, are left untouched.

    Args:
        df: The PySpark DataFrame to normalize.

    Returns:
        A PySpark DataFrame with the per-type fill applied, or ``df`` itself
        when no column qualifies.
    """
    fill_values = {}
    for field in df.schema.fields:
        default_value = default_null_types.get(field.dataType)
        if default_value is not None:
            fill_values[field.name] = default_value

    if not fill_values:
        return df
    # A single fillna with a column -> value mapping is equivalent to chaining
    # one fillna per column, but adds one projection to the plan instead of N.
    return df.fillna(fill_values)
|
356
|
+
|
357
|
+
|
358
|
+
def _get_spark_column_types(df: SparkDataFrame) -> dict[str, SparkStructField]:
|
331
359
|
schema = df.schema
|
332
360
|
column_type_collection = {}
|
333
361
|
for field in schema.fields:
|
@@ -457,14 +485,83 @@ def _to_pandas(sampled_df: SparkDataFrame) -> pandas.DataFrame:
|
|
457
485
|
LOGGER.debug("Converting Spark DataFrame to Pandas DataFrame")
|
458
486
|
pandas_df = sampled_df.toPandas()
|
459
487
|
for field in sampled_df.schema.fields:
|
460
|
-
has_nan = pandas_df[field.name].isna().any()
|
461
488
|
is_integer = field.dataType.typeName() in INTEGER_TYPE_COLLECTION
|
462
|
-
|
489
|
+
is_spark_string = isinstance(field.dataType, SparkStringType)
|
490
|
+
is_spark_binary = isinstance(field.dataType, SparkBinaryType)
|
491
|
+
is_spark_timestamp = isinstance(field.dataType, SparkTimestampType)
|
492
|
+
is_spark_float = isinstance(field.dataType, SparkFloatType)
|
493
|
+
is_spark_boolean = isinstance(field.dataType, SparkBooleanType)
|
494
|
+
is_spark_date = isinstance(field.dataType, SparkDateType)
|
495
|
+
if is_integer:
|
463
496
|
LOGGER.debug(
|
464
|
-
"Converting column '%s' to '%s' type",
|
497
|
+
"Converting Spark integer column '%s' to Pandas nullable '%s' type",
|
465
498
|
field.name,
|
466
499
|
PANDAS_LONG_TYPE,
|
467
500
|
)
|
468
|
-
pandas_df[field.name] =
|
501
|
+
pandas_df[field.name] = (
|
502
|
+
pandas_df[field.name].astype(PANDAS_LONG_TYPE).fillna(0)
|
503
|
+
)
|
504
|
+
elif is_spark_string or is_spark_binary:
|
505
|
+
LOGGER.debug(
|
506
|
+
"Converting Spark string column '%s' to Pandas nullable '%s' type",
|
507
|
+
field.name,
|
508
|
+
PANDAS_STRING_TYPE,
|
509
|
+
)
|
510
|
+
pandas_df[field.name] = (
|
511
|
+
pandas_df[field.name].astype(PANDAS_STRING_TYPE).fillna("")
|
512
|
+
)
|
513
|
+
elif is_spark_timestamp:
|
514
|
+
LOGGER.debug(
|
515
|
+
"Converting Spark timestamp column '%s' to UTC naive Pandas datetime",
|
516
|
+
field.name,
|
517
|
+
)
|
518
|
+
pandas_df[field.name] = convert_all_to_utc_naive(
|
519
|
+
pandas_df[field.name]
|
520
|
+
).fillna(pandas.NaT)
|
521
|
+
elif is_spark_float:
|
522
|
+
LOGGER.debug(
|
523
|
+
"Converting Spark float column '%s' to Pandas nullable float",
|
524
|
+
field.name,
|
525
|
+
)
|
526
|
+
pandas_df[field.name] = (
|
527
|
+
pandas_df[field.name].astype(PANDAS_FLOAT_TYPE).fillna(0.0)
|
528
|
+
)
|
529
|
+
elif is_spark_boolean:
|
530
|
+
LOGGER.debug(
|
531
|
+
"Converting Spark boolean column '%s' to Pandas nullable boolean",
|
532
|
+
field.name,
|
533
|
+
)
|
534
|
+
pandas_df[field.name] = (
|
535
|
+
pandas_df[field.name].astype("boolean").fillna(False)
|
536
|
+
)
|
537
|
+
elif is_spark_date:
|
538
|
+
LOGGER.debug(
|
539
|
+
"Converting Spark date column '%s' to Pandas nullable datetime",
|
540
|
+
field.name,
|
541
|
+
)
|
542
|
+
pandas_df[field.name] = pandas_df[field.name].fillna(pandas.NaT)
|
469
543
|
|
470
544
|
return pandas_df
|
545
|
+
|
546
|
+
|
547
|
+
def convert_all_to_utc_naive(series: pandas.Series) -> pandas.Series:
    """Convert all timezone-aware or naive timestamps in a series to UTC naive.

    Naive timestamps are assumed to be in UTC and localized accordingly.
    Timezone-aware timestamps are converted to UTC and then made naive.
    Missing values (``None`` / ``NaT`` / ``NaN``) become ``NaT``.

    A ``datetime64`` series (naive, or sharing a single timezone) is converted
    with vectorized ``.dt`` operations. Any other series (e.g. ``object`` dtype
    holding a mix of naive and aware values) is converted element by element.

    Args:
        series (pandas.Series): A Pandas Series of timestamps, either naive or
            timezone-aware. Elements may be ``pd.Timestamp`` or
            ``datetime.datetime`` objects, or missing values.

    Returns:
        pandas.Series: A Series of UTC-normalized naive timestamps
        (``tzinfo=None``), with the same index and name as ``series``.

    """
    if pandas.api.types.is_datetime64_any_dtype(series):
        # Homogeneous dtype: the whole column shares one timezone (or none),
        # so the conversion is a single vectorized pass. This also yields a
        # naive dtype for empty tz-aware input.
        if series.dt.tz is None:
            aware = series.dt.tz_localize("UTC")
        else:
            aware = series
        return aware.dt.tz_convert("UTC").dt.tz_localize(None)

    def convert(ts):
        if pandas.isna(ts):
            return pandas.NaT
        # datetime.datetime has no ``.tz``; normalize to pd.Timestamp first.
        ts = pandas.Timestamp(ts)
        if ts.tz is None:
            ts = ts.tz_localize("UTC")
        return ts.tz_convert("UTC").tz_localize(None)

    return series.apply(convert)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: snowpark-checkpoints-collectors
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: Snowpark column and table statistics collection
|
5
5
|
Project-URL: Bug Tracker, https://github.com/snowflakedb/snowpark-checkpoints/issues
|
6
6
|
Project-URL: Source code, https://github.com/snowflakedb/snowpark-checkpoints/
|
@@ -1,8 +1,8 @@
|
|
1
1
|
snowflake/snowpark_checkpoints_collector/__init__.py,sha256=g4NemuA6Mj4O2jkK0yLQ8sEV3owHiiJnBEz_OWvlW1I,1179
|
2
|
-
snowflake/snowpark_checkpoints_collector/__version__.py,sha256=
|
3
|
-
snowflake/snowpark_checkpoints_collector/collection_common.py,sha256=
|
2
|
+
snowflake/snowpark_checkpoints_collector/__version__.py,sha256=mZG_4eaVJdzo54iJo1tR3khnIA6lKjmN2lUgMoangNY,632
|
3
|
+
snowflake/snowpark_checkpoints_collector/collection_common.py,sha256=qHiBWOICEbc1bvpUbfZU_mkmRiy77TB_2eR12mg52oQ,4612
|
4
4
|
snowflake/snowpark_checkpoints_collector/singleton.py,sha256=7AgIHQBXVRvPBBCkmBplzkdrrm-xVWf_N8svzA2vF8E,836
|
5
|
-
snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256
|
5
|
+
snowflake/snowpark_checkpoints_collector/summary_stats_collector.py,sha256=-KhVUcZX9z3_RmFxkcKa-31Ry9PRdcYN_U6O_cPYNhg,20984
|
6
6
|
snowflake/snowpark_checkpoints_collector/collection_result/model/__init__.py,sha256=jZzx29WzrjH7C_6ZsBGoe4PxbW_oM4uIjySS1axIM34,1000
|
7
7
|
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result.py,sha256=XelL7LughZpKl1B_6bJoKOc_PqQg3UleX6zdgVXqTus,2926
|
8
8
|
snowflake/snowpark_checkpoints_collector/collection_result/model/collection_point_result_manager.py,sha256=EY6WIIXRbvkTYC4bQn7jFALHh7D2PirVoiLZ5Kq8dNs,2659
|
@@ -31,13 +31,13 @@ snowflake/snowpark_checkpoints_collector/io_utils/io_default_strategy.py,sha256=
|
|
31
31
|
snowflake/snowpark_checkpoints_collector/io_utils/io_env_strategy.py,sha256=kJMbg2VOKNXXdkGCt_tMMLGEZ2aUl1_nie1qYvx5M-c,3770
|
32
32
|
snowflake/snowpark_checkpoints_collector/io_utils/io_file_manager.py,sha256=M17EtANswD5gcgGnmT13OImO_W1uH4K3ewu2CXL9aes,2597
|
33
33
|
snowflake/snowpark_checkpoints_collector/snow_connection_model/__init__.py,sha256=kLjZId-aGCljK7lF6yeEw-syEqeTOJDxdXfpv9YxvZA,755
|
34
|
-
snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=
|
34
|
+
snowflake/snowpark_checkpoints_collector/snow_connection_model/snow_connection.py,sha256=lM3oqHUHXShALDVVU5ZSuXGREUVfHYHprB5fy1r5T0I,8154
|
35
35
|
snowflake/snowpark_checkpoints_collector/utils/checkpoint_name_utils.py,sha256=Xc4k3JU6A96-79VFRR8NrNAUPeO3V1DEAhngg-hLlU4,1787
|
36
36
|
snowflake/snowpark_checkpoints_collector/utils/extra_config.py,sha256=3kVf6WVA-EuyMpTO3ycTlXMSCHtytGtT6wkV4U2Hyjw,5195
|
37
37
|
snowflake/snowpark_checkpoints_collector/utils/file_utils.py,sha256=5ztlNCv9GdSktUvtdfydv86cCFcmSXCdD4axZXJrOQQ,5125
|
38
38
|
snowflake/snowpark_checkpoints_collector/utils/logging_utils.py,sha256=yyi6X5DqKeTg0HRhvsH6ymYp2P0wbnyKIzI2RzrQS7k,2278
|
39
39
|
snowflake/snowpark_checkpoints_collector/utils/telemetry.py,sha256=ueN9vM8j5YNax7jMcnEj_UrgGkoeMv_hJHVKjN7hiJE,32161
|
40
|
-
snowpark_checkpoints_collectors-0.
|
41
|
-
snowpark_checkpoints_collectors-0.
|
42
|
-
snowpark_checkpoints_collectors-0.
|
43
|
-
snowpark_checkpoints_collectors-0.
|
40
|
+
snowpark_checkpoints_collectors-0.4.0.dist-info/METADATA,sha256=HMpSzXXczuG-5_RuKadoEbA8JjADUvadLj2sQWvu9MY,6613
|
41
|
+
snowpark_checkpoints_collectors-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
42
|
+
snowpark_checkpoints_collectors-0.4.0.dist-info/licenses/LICENSE,sha256=DVQuDIgE45qn836wDaWnYhSdxoLXgpRRKH4RuTjpRZQ,10174
|
43
|
+
snowpark_checkpoints_collectors-0.4.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|