sws-spark-dissemination-helper 0.0.185__tar.gz → 0.0.194__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/PKG-INFO +1 -1
  2. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/pyproject.toml +1 -1
  3. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSGoldIcebergSparkHelper.py +159 -42
  4. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSPostgresSparkReader.py +1 -43
  5. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSSilverIcebergSparkHelper.py +1 -1
  6. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/constants.py +1 -1
  7. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/.gitignore +0 -0
  8. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/LICENSE +0 -0
  9. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/README.md +0 -0
  10. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSBronzeIcebergSparkHelper.py +0 -0
  11. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSDatatablesExportHelper.py +0 -0
  12. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/SWSEasyIcebergSparkHelper.py +0 -0
  13. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/__init__.py +0 -0
  14. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/src/sws_spark_dissemination_helper/utils.py +0 -0
  15. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/tests/__init__.py +0 -0
  16. {sws_spark_dissemination_helper-0.0.185 → sws_spark_dissemination_helper-0.0.194}/tests/test.py +0 -0
--- sws_spark_dissemination_helper-0.0.185/PKG-INFO
+++ sws_spark_dissemination_helper-0.0.194/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sws-spark-dissemination-helper
-Version: 0.0.185
+Version: 0.0.194
 Summary: A Python helper package providing streamlined Spark functions for efficient data dissemination processes
 Project-URL: Repository, https://github.com/un-fao/fao-sws-it-python-spark-dissemination-helper
 Author-email: Daniele Mansillo <danielemansillo@gmail.com>
--- sws_spark_dissemination_helper-0.0.185/pyproject.toml
+++ sws_spark_dissemination_helper-0.0.194/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "sws-spark-dissemination-helper"
-version = "0.0.185"
+version = "0.0.194"
 dependencies = [
     "annotated-types==0.7.0",
     "boto3>=1.40.0",
--- sws_spark_dissemination_helper-0.0.185/src/sws_spark_dissemination_helper/SWSGoldIcebergSparkHelper.py
+++ sws_spark_dissemination_helper-0.0.194/src/sws_spark_dissemination_helper/SWSGoldIcebergSparkHelper.py
@@ -5,13 +5,19 @@ from typing import List, Tuple
 import pyspark.sql.functions as F
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col, lit
+from pyspark.sql.types import DecimalType, FloatType
 from sws_api_client import Tags
 from sws_api_client.tags import BaseDisseminatedTagTable, TableLayer, TableType
 
-from .constants import IcebergDatabases, IcebergTables, DatasetDatatables
+from .constants import DatasetDatatables, IcebergDatabases, IcebergTables
 from .SWSPostgresSparkReader import SWSPostgresSparkReader
 from .utils import get_or_create_tag, save_cache_csv, upsert_disseminated_table
 
+SIMPLE_NUMERIC_REGEX = r"^[+-]?\d*(\.\d+)?$"
+NUMERIC_REGEX = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
+# Regex to extract decimal places: matches the decimal part and counts its length
+DECIMAL_PLACES_REGEX = r"\.(\d+)$"
+
 
 class SWSGoldIcebergSparkHelper:
     def __init__(
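Note on the three module-level constants added above: SIMPLE_NUMERIC_REGEX accepts plain signed decimals, NUMERIC_REGEX additionally accepts scientific notation, and DECIMAL_PLACES_REGEX captures the digits after the decimal point. A minimal standalone check of their behavior (illustrative sketch only, not part of the package):

    import re

    SIMPLE_NUMERIC_REGEX = r"^[+-]?\d*(\.\d+)?$"
    NUMERIC_REGEX = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
    DECIMAL_PLACES_REGEX = r"\.(\d+)$"

    for s in ["42", "-3.14", ".5", "1.2e-3", "abc"]:
        simple = bool(re.match(SIMPLE_NUMERIC_REGEX, s))
        full = bool(re.match(NUMERIC_REGEX, s))
        m = re.search(DECIMAL_PLACES_REGEX, s)
        # "1.2e-3" matches only NUMERIC_REGEX; "abc" matches neither;
        # "-3.14" has 2 decimal places per DECIMAL_PLACES_REGEX
        print(s, simple, full, len(m.group(1)) if m else 0)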
@@ -62,10 +68,12 @@ class SWSGoldIcebergSparkHelper:
             if col_name in self.dim_columns
         }
 
-        self.display_decimals = (
-            self.sws_postgres_spark_reader.get_display_decimals_datatable(
-                domain_code=domain_code
-            )
+        self.display_decimals_df = self.sws_postgres_spark_reader.read_pg_table(
+            pg_table=DatasetDatatables.DISPLAY_DECIMALS.id,
+            custom_schema=DatasetDatatables.DISPLAY_DECIMALS.schema,
+        ).filter(
+            (col("domain") == lit(self.domain_code))
+            | ((col("domain") == lit("DEFAULT")))
         )
 
     def _get_dim_time_flag_columns(self) -> Tuple[List[str], List[str], str, List[str]]:
@@ -99,53 +107,162 @@ class SWSGoldIcebergSparkHelper:
             cols_to_keep_sws = cols_to_keep_sws + ["unit_of_measure_symbol"]
         return df.select(*cols_to_keep_sws)
 
-    def round_to_display_decimals(self, df: DataFrame):
-        col1_name, col2_name = (
-            self.display_decimals.select("column_1_name", "column_2_name")
-            .distinct()
-            .collect()[0]
-        )
-        if col1_name.lower() not in [column.lower() for column in df.columns]:
-            raise ValueError(
-                f"{col1_name} is not part of the columns available for this dataset ({df.columns})"
+    def round_to_display_decimals(
+        self,
+        df: DataFrame,
+        value_column: str = "value",
+    ) -> DataFrame:
+
+        df = df.withColumn("unrounded_value", col(value_column).cast("string"))
+
+        general_default_decimals = (
+            self.display_decimals_df.filter(col("domain") == lit("DEFAULT"))
+            .select("display_decimals")
+            .collect()[0][0]
+        )
+        domain_default_decimals = self.display_decimals_df.filter(
+            (col("domain") == lit(self.domain_code))
+            & col("column_1_name").isNull()
+            & col("column_2_name").isNull()
+        ).select("display_decimals")
+
+        default_decimals = int(
+            general_default_decimals
+            if domain_default_decimals.count() == 0
+            else domain_default_decimals.collect()[0][0]
+        )
+
+        domain_specific_rules = self.display_decimals_df.filter(
+            (col("domain") == lit(self.domain_code))
+            & (col("column_1_name").isNotNull() & col("column_1_value").isNotNull())
+            | (col("column_2_name").isNotNull() & col("column_2_value").isNotNull())
+        )
+
+        when_decimals = None
+        when_rounded = None
+
+        for rule in domain_specific_rules.collect():
+            condition = lit(True)
+            if rule["column_1_name"] != "" and rule["column_1_value"] != "":
+                column_1_name = rule["column_1_name"]
+                column_1_value_str = rule["column_1_value"]
+
+                column_1_value_list = [
+                    v.strip() for v in str(column_1_value_str).split(",")
+                ]
+                condition &= col(column_1_name).isin(column_1_value_list)
+
+            if (
+                rule["column_2_name"] is not None
+                and rule["column_2_name"] != ""
+                and rule["column_2_value"] is not None
+                and rule["column_2_value"] != ""
+            ):
+                column_2_name = rule["column_2_name"]
+                column_2_value_str = rule["column_2_value"]
+
+                column_2_value_list = [
+                    v.strip() for v in str(column_2_value_str).split(",")
+                ]
+                condition &= col(column_2_name).isin(column_2_value_list)
+
+            display_decimals = int(rule["display_decimals"])
+
+            # Count actual decimal places in the current value
+            # Handle both regular decimals and scientific notation
+            # Convert scientific notation to decimal format first
+            value_str_normalized = F.when(
+                F.col(value_column).cast("string").rlike("[eE]"),
+                F.format_number(F.col(value_column).cast("double"), 20),
+            ).otherwise(F.col(value_column).cast("string"))
+
+            actual_decimals = F.length(
+                F.regexp_extract(value_str_normalized, DECIMAL_PLACES_REGEX, 1)
            )
-        if col2_name.lower() not in [column.lower() for column in df.columns]:
-            raise ValueError(
-                f"{col2_name} is not part of the columns available for this dataset ({df.columns})"
+
+            # Add decimals condition
+            when_decimals = (
+                F.when(condition, lit(display_decimals))
+                if when_decimals is None
+                else when_decimals.when(condition, lit(display_decimals))
            )
 
-        df = (
-            df.alias("d")
-            .join(
-                self.display_decimals.alias("dd"),
-                on=(col(f"d.{col1_name}") == col("dd.column_1_value"))
-                & (col(f"d.{col2_name}") == col("dd.column_2_value")),
-                how="left",
+            # Add rounding condition based on display_decimals
+            # Only apply rounding if current decimals >= target decimals
+            if display_decimals > 6:
+                # For large targets no scale shift is needed:
+                # cast to DECIMAL with precision 38 and scale display_decimals
+                precision = 38
+                decimals = display_decimals
+                rounded_value = col(value_column).cast(DecimalType(precision, decimals))
+            else:
+                # Cast to DECIMAL with precision 38 and decimals as display_decimals + 2
+                precision = 38
+                decimals = display_decimals + 2
+                decimal_value = col(value_column).cast(DecimalType(precision, decimals))
+                scale = pow(lit(10), lit(display_decimals)).cast(
+                    DecimalType(precision, decimals)
+                )
+                rounded_value = F.round(decimal_value * scale) / scale
+
+            # Only round if actual decimals >= target decimals, otherwise keep original
+            rounded_value = F.when(
+                actual_decimals >= lit(display_decimals), rounded_value
+            ).otherwise(col(value_column))
+
+            when_rounded = (
+                F.when(condition, rounded_value)
+                if when_rounded is None
+                else when_rounded.when(condition, rounded_value)
            )
-            .select("d.*", "dd.display_decimals")
-        )
 
-        df.filter(col("display_decimals").isNull()).select(
-            col1_name, col2_name
-        ).distinct()
-        logging.warning(
-            f"The following combinations of {col1_name} and {col2_name} are not available in the table {DatasetDatatables.DISPLAY_DECIMALS.name} and will be assigned to 0"
+        # Add otherwise with default value for decimals
+        when_decimals = (
+            lit(default_decimals)
+            if when_decimals is None
+            else when_decimals.otherwise(lit(default_decimals))
        )
 
-        df = df.withColumn(
-            "display_decimals",
-            F.coalesce(col("display_decimals"), lit("0")).cast("INT"),
-        ).withColumn(
-            "value",
-            F.round(
-                F.col("value").cast("FLOAT") * F.pow(10, F.col("display_decimals")), 0
+        # Add otherwise with default rounding for value
+        if default_decimals > 6:
+            default_rounded = F.round(
+                col(value_column).cast(FloatType()), default_decimals
+            )
+        else:
+            precision = 38
+            decimals = default_decimals + 2
+            default_decimal_value = col(value_column).cast(
+                DecimalType(precision, decimals)
+            )
+            default_scale = pow(lit(10), lit(default_decimals)).cast(
+                DecimalType(precision, decimals)
+            )
+            default_rounded = (
+                F.round(default_decimal_value * default_scale) / default_scale
            )
-            / F.pow(10, F.col("display_decimals")).cast("STRING"),
+
+        # Only round if actual decimals >= target decimals, otherwise keep original
+        # Handle both regular decimals and scientific notation for default case
+        value_str_normalized_default = F.when(
+            F.col(value_column).cast("string").rlike("[eE]"),
+            F.format_number(F.col(value_column).cast("double"), 20),
+        ).otherwise(F.col(value_column).cast("string"))
+
+        actual_decimals_default = F.length(
+            F.regexp_extract(value_str_normalized_default, DECIMAL_PLACES_REGEX, 1)
+        )
+        default_rounded = F.when(
+            actual_decimals_default >= lit(default_decimals), default_rounded
+        ).otherwise(col(value_column))
+
+        when_rounded = (
+            default_rounded
+            if when_rounded is None
+            else when_rounded.otherwise(default_rounded)
        )
 
-        # F.round(
-        #     col("value").cast("FLOAT"), col("display_decimals").cast("INT")
-        # ).cast("STRING"),
+        df = df.withColumn("display_decimals", when_decimals)
+        df = df.withColumn(value_column, when_rounded)
 
         return df
 
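Note on the rounding strategy above: instead of rounding a FLOAT directly, the value is cast to DECIMAL(38, display_decimals + 2), shifted by 10^display_decimals, rounded to an integer, and divided back, which avoids binary floating-point artifacts; values that already have fewer decimal places than the target are left untouched. A self-contained sketch of the same round-at-scale trick, using made-up data and a hard-coded target (illustrative only, not part of the package):

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, lit
    from pyspark.sql.types import DecimalType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("2.34567",), ("1.2",)], ["value"])

    display_decimals = 2  # hypothetical target taken from the datatable
    # Round at a shifted scale on DECIMAL(38, display_decimals + 2)
    decimal_value = col("value").cast(DecimalType(38, display_decimals + 2))
    scale = F.pow(lit(10), lit(display_decimals)).cast(
        DecimalType(38, display_decimals + 2)
    )
    rounded = F.round(decimal_value * scale) / scale

    # Count actual decimal places from the string representation
    actual_decimals = F.length(F.regexp_extract(col("value"), r"\.(\d+)$", 1))
    df.withColumn(
        "value_rounded",
        F.when(actual_decimals >= lit(display_decimals), rounded)
        .otherwise(col("value")),
    ).show()
    # "2.34567" rounds to 2.35 (Spark may render extra trailing zeros);
    # "1.2" already has fewer decimals than the target and is kept as-is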
--- sws_spark_dissemination_helper-0.0.185/src/sws_spark_dissemination_helper/SWSPostgresSparkReader.py
+++ sws_spark_dissemination_helper-0.0.194/src/sws_spark_dissemination_helper/SWSPostgresSparkReader.py
@@ -468,7 +468,7 @@ class SWSPostgresSparkReader:
                 correct_domain_filter, domain=domain_code, unique_columns=["code"]
             )
             for col_type in mapping_dim_col_name_type.values()
-            if col_type != "other"
+            if col_type not in ("year", "other")
         }
 
     def import_diss_exceptions_datatable(
@@ -497,45 +497,3 @@ class SWSPostgresSparkReader:
                 "aggregation",
             ],
         )
-
-    def get_display_decimals_datatable(
-        self,
-        domain_code: str,
-    ) -> DataFrame:
-        df = self.read_pg_table(
-            pg_table=DatasetDatatables.DISPLAY_DECIMALS.id,
-            custom_schema=DatasetDatatables.DISPLAY_DECIMALS.schema,
-        ).filter(col("domain") == lit(domain_code))
-
-        pairs = df.select("column_1_name", "column_2_name").distinct().collect()
-
-        # If no config exists for this domain, fail early
-        if not pairs:
-            msg = (
-                f'No display-decimals configuration found for domain "{domain_code}". '
-                f'Please add an entry in table "{DatasetDatatables.DISPLAY_DECIMALS.id}".'
-            )
-            logging.error(msg)
-            # raise ValueError(msg)
-
-        # If more than one mapping exists, it's invalid
-        if len(pairs) > 1:
-            formatted_pairs = [(p["column_1_name"], p["column_2_name"]) for p in pairs]
-
-            msg = (
-                f'Invalid configuration for domain "{domain_code}". '
-                f"Expected exactly one (column_1_name, column_2_name) pair, but found {len(pairs)}: "
-                f"{formatted_pairs}. "
-                f'Please correct the table "{DatasetDatatables.DISPLAY_DECIMALS.id}".'
-            )
-
-            logging.error(
-                "Multiple display-decimals column pairs detected",
-                extra={
-                    "domain": domain_code,
-                    "pairs_found": formatted_pairs,
-                },
-            )
-            raise ValueError(msg)
-
-        return df
--- sws_spark_dissemination_helper-0.0.185/src/sws_spark_dissemination_helper/SWSSilverIcebergSparkHelper.py
+++ sws_spark_dissemination_helper-0.0.194/src/sws_spark_dissemination_helper/SWSSilverIcebergSparkHelper.py
@@ -444,7 +444,7 @@ class SWSSilverIcebergSparkHelper:
         logging.info("Checking the dissemination flag for each dimension (except year)")
 
         for col_name, col_type in self.mapping_dim_col_name_type.items():
-            if col_type != "other":
+            if col_type not in ("other", "year"):
                 df = self._check_diss_dim_list(
                     df,
                     self.dfs_diss_flags[col_type],
--- sws_spark_dissemination_helper-0.0.185/src/sws_spark_dissemination_helper/constants.py
+++ sws_spark_dissemination_helper-0.0.194/src/sws_spark_dissemination_helper/constants.py
@@ -168,7 +168,7 @@ class DatasetTables:
         self.OBSERVATION = self.__SWSTable(
             postgres_id=f"{self.__dataset_id}.observation",
             iceberg_id=f"{IcebergDatabases.STAGING_DATABASE}.{self.__dataset_id}_observation",
-            schema="id BIGINT, observation_coordinates BIGINT, version INT, value FLOAT, flag_obs_status STRING, flag_method STRING, created_on TIMESTAMP, created_by INT, replaced_on TIMESTAMP",
+            schema="id BIGINT, observation_coordinates BIGINT, version INT, value STRING, flag_obs_status STRING, flag_method STRING, created_on TIMESTAMP, created_by INT, replaced_on TIMESTAMP",
         )
         self.OBSERVATION_COORDINATE = self.__SWSTable(
             postgres_id=f"{self.__dataset_id}.observation_coordinate",
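The constants.py change (value FLOAT to STRING in the staging observation schema) complements the rounding rework above: only the raw string form preserves the exact number of decimal places that DECIMAL_PLACES_REGEX counts, because a float round-trip drops trailing zeros and can perturb the digits. A plain-Python illustration (not part of the package):

    raw = "0.10"                   # value as stored in the STRING column
    print(len(raw.split(".")[1]))  # 2 decimal places, recoverable from the string
    print(str(float(raw)))         # '0.1': the trailing zero is lost as a float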