sws-spark-dissemination-helper 0.0.180__tar.gz → 0.0.190__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/PKG-INFO +2 -2
  2. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/pyproject.toml +2 -2
  3. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSGoldIcebergSparkHelper.py +193 -51
  4. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSPostgresSparkReader.py +1 -43
  5. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSSilverIcebergSparkHelper.py +3 -3
  6. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/constants.py +1 -1
  7. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/.gitignore +0 -0
  8. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/LICENSE +0 -0
  9. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/README.md +0 -0
  10. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSBronzeIcebergSparkHelper.py +0 -0
  11. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSDatatablesExportHelper.py +0 -0
  12. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/SWSEasyIcebergSparkHelper.py +0 -0
  13. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/__init__.py +0 -0
  14. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/src/sws_spark_dissemination_helper/utils.py +0 -0
  15. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/tests/__init__.py +0 -0
  16. {sws_spark_dissemination_helper-0.0.180 → sws_spark_dissemination_helper-0.0.190}/tests/test.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sws-spark-dissemination-helper
-Version: 0.0.180
+Version: 0.0.190
 Summary: A Python helper package providing streamlined Spark functions for efficient data dissemination processes
 Project-URL: Repository, https://github.com/un-fao/fao-sws-it-python-spark-dissemination-helper
 Author-email: Daniele Mansillo <danielemansillo@gmail.com>
@@ -49,7 +49,7 @@ Requires-Dist: pytz==2025.2
 Requires-Dist: requests==2.32.3
 Requires-Dist: s3transfer>=0.11.2
 Requires-Dist: six==1.17.0
-Requires-Dist: sws-api-client==2.3.0
+Requires-Dist: sws-api-client==2.7.3
 Requires-Dist: typing-extensions>=4.12.2
 Requires-Dist: tzdata==2025.2
 Requires-Dist: urllib3==1.26.20

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "sws-spark-dissemination-helper"
-version = "0.0.180"
+version = "0.0.190"
 dependencies = [
     "annotated-types==0.7.0",
     "boto3>=1.40.0",
@@ -25,7 +25,7 @@ dependencies = [
     "requests==2.32.3",
     "s3transfer>=0.11.2",
     "six==1.17.0",
-    "sws_api_client==2.3.0",
+    "sws_api_client==2.7.3",
     "typing_extensions>=4.12.2",
     "tzdata==2025.2",
     "urllib3==1.26.20"

src/sws_spark_dissemination_helper/SWSGoldIcebergSparkHelper.py
@@ -5,13 +5,19 @@ from typing import List, Tuple
 import pyspark.sql.functions as F
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col, lit
+from pyspark.sql.types import DecimalType, FloatType
 from sws_api_client import Tags
 from sws_api_client.tags import BaseDisseminatedTagTable, TableLayer, TableType

-from .constants import IcebergDatabases, IcebergTables, DatasetDatatables
+from .constants import DatasetDatatables, IcebergDatabases, IcebergTables
 from .SWSPostgresSparkReader import SWSPostgresSparkReader
 from .utils import get_or_create_tag, save_cache_csv, upsert_disseminated_table

+SIMPLE_NUMERIC_REGEX = r"^[+-]?\d*(\.\d+)?$"
+NUMERIC_REGEX = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
+# Regex to extract decimal places: matches the decimal part and counts its length
+DECIMAL_PLACES_REGEX = r"\.(\d+)$"
+

 class SWSGoldIcebergSparkHelper:
     def __init__(
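
Note: the module-level regex constants added above drive the reworked rounding logic in round_to_display_decimals further down. A minimal sketch (plain Python, illustrative only, not part of the package) of how they behave:

    import re

    SIMPLE_NUMERIC_REGEX = r"^[+-]?\d*(\.\d+)?$"
    NUMERIC_REGEX = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
    DECIMAL_PLACES_REGEX = r"\.(\d+)$"

    for value in ["12.340", "-0.5", "1e3", "abc"]:
        is_numeric = bool(re.match(NUMERIC_REGEX, value))
        # Length of the captured decimal part = number of decimal places
        match = re.search(DECIMAL_PLACES_REGEX, value)
        places = len(match.group(1)) if match else 0
        print(value, is_numeric, places)
    # "12.340" -> numeric, 3 places; "1e3" matches NUMERIC_REGEX but not SIMPLE_NUMERIC_REGEX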
@@ -62,10 +68,9 @@ class SWSGoldIcebergSparkHelper:
             if col_name in self.dim_columns
         }

-        self.display_decimals = (
-            self.sws_postgres_spark_reader.get_display_decimals_datatable(
-                domain_code=domain_code
-            )
+        self.display_decimals_df = self.sws_postgres_spark_reader.read_pg_table(
+            pg_table=DatasetDatatables.DISPLAY_DECIMALS.id,
+            custom_schema=DatasetDatatables.DISPLAY_DECIMALS.schema,
         )

     def _get_dim_time_flag_columns(self) -> Tuple[List[str], List[str], str, List[str]]:
@@ -88,61 +93,164 @@ class SWSGoldIcebergSparkHelper:
     def apply_diss_flag_filter(self, df: DataFrame) -> DataFrame:
         return df.filter(col("diss_flag"))

-    def keep_dim_val_attr_columns(self, df: DataFrame):
+    def keep_dim_val_attr_columns(
+        self, df: DataFrame, additional_columns: List[str] = []
+    ):
         cols_to_keep_sws = self.cols_to_keep_sws
-        if "note" in df.columns:
-            cols_to_keep_sws = cols_to_keep_sws + ["note"]
+        for additional_column in additional_columns:
+            if additional_column in df.columns:
+                cols_to_keep_sws = cols_to_keep_sws + [additional_column]
         if "unit_of_measure_symbol" in df.columns:
             cols_to_keep_sws = cols_to_keep_sws + ["unit_of_measure_symbol"]
         return df.select(*cols_to_keep_sws)

-    def round_to_display_decimals(self, df: DataFrame):
-        col1_name, col2_name = (
-            self.display_decimals.select("column_1_name", "column_2_name")
-            .distinct()
-            .collect()[0]
-        )
-        if col1_name.lower() not in [column.lower() for column in df.columns]:
-            raise ValueError(
-                f"{col1_name} is not part of the columns available for this dataset ({df.columns})"
+    def round_to_display_decimals(
+        self,
+        df: DataFrame,
+        value_column: str = "value",
+    ) -> DataFrame:
+
+        df = df.withColumn("unrounded_value", col(value_column))
+
+        general_default_decimals = (
+            self.display_decimals_df.filter(col("domain") == lit("DEFAULT"))
+            .select("display_decimals")
+            .collect()[0][0]
+        )
+        domain_default_decimals = self.display_decimals_df.filter(
+            (col("domain") == lit(self.domain_code))
+            & col("column_1_name").isNull()
+            & col("column_2_name").isNull()
+        ).select("display_decimals")
+
+        default_decimals = int(
+            general_default_decimals
+            if domain_default_decimals.count() == 0
+            else domain_default_decimals.collect()[0][0]
+        )
+
+        domain_specific_rules = self.display_decimals_df.filter(
+            (col("domain") == lit(self.domain_code))
+            & (col("column_1_name").isNotNull() & col("column_1_value").isNotNull())
+            | (col("column_2_name").isNotNull() & col("column_2_value").isNotNull())
+        )
+
+        when_decimals = None
+        when_rounded = None
+
+        for rule in domain_specific_rules.collect():
+            condition = lit(True)
+            if rule["column_1_name"] != "" and rule["column_1_value"] != "":
+                column_1_name = rule["column_1_name"]
+                column_1_value_str = rule["column_1_value"]
+
+                column_1_value_list = [
+                    v.strip() for v in str(column_1_value_str).split(",")
+                ]
+                condition &= col(column_1_name).isin(column_1_value_list)
+
+            if (
+                rule["column_2_name"] is not None
+                and rule["column_2_name"] != ""
+                and rule["column_2_value"] is not None
+                and rule["column_2_value"] != ""
+            ):
+                column_2_name = rule["column_2_name"]
+                column_2_value_str = rule["column_2_value"]
+
+                column_2_value_list = [
+                    v.strip() for v in str(column_2_value_str).split(",")
+                ]
+                condition &= col(column_2_name).isin(column_2_value_list)
+
+            display_decimals = int(rule["display_decimals"])
+
+            # Count actual decimal places in the current value
+            # If the value already has fewer decimals than target, skip rounding
+            actual_decimals = F.length(
+                F.regexp_extract(
+                    F.col(value_column).cast("string"), DECIMAL_PLACES_REGEX, 1
+                )
             )
-        if col2_name.lower() not in [column.lower() for column in df.columns]:
-            raise ValueError(
-                f"{col2_name} is not part of the columns available for this dataset ({df.columns})"
+
+            # Add decimals condition
+            when_decimals = (
+                F.when(condition, lit(display_decimals))
+                if when_decimals is None
+                else when_decimals.when(condition, lit(display_decimals))
             )

-        df = (
-            df.alias("d")
-            .join(
-                self.display_decimals.alias("dd"),
-                on=(col(f"d.{col1_name}") == col("dd.column_1_value"))
-                & (col(f"d.{col2_name}") == col("dd.column_2_value")),
-                how="left",
+            # Add rounding condition based on display_decimals
+            # Only apply rounding if current decimals >= target decimals
+            if display_decimals > 6:
+                # Cast to float and round
+                rounded_value = F.round(
+                    col(value_column).cast(FloatType()), display_decimals
+                )
+            else:
+                # Cast to DECIMAL with precision 38 and decimals as display_decimals + 2
+                precision = 38
+                decimals = display_decimals + 2
+                decimal_value = col(value_column).cast(DecimalType(precision, decimals))
+                scale = pow(lit(10), lit(display_decimals)).cast(
+                    DecimalType(precision, decimals)
+                )
+                rounded_value = F.round(decimal_value * scale) / scale
+
+            # Only round if actual decimals >= target decimals, otherwise keep original
+            rounded_value = F.when(
+                actual_decimals >= lit(display_decimals), rounded_value
+            ).otherwise(col(value_column))
+
+            when_rounded = (
+                F.when(condition, rounded_value)
+                if when_rounded is None
+                else when_rounded.when(condition, rounded_value)
             )
-            .select("d.*", "dd.display_decimals")
-        )

-        df.filter(col("display_decimals").isNull()).select(
-            col1_name, col2_name
-        ).distinct()
-        logging.warning(
-            f"The following combinations of {col1_name} and {col2_name} are not available in the table {DatasetDatatables.DISPLAY_DECIMALS.name} and will be assigned to 0"
+        # Add otherwise with default value for decimals
+        when_decimals = (
+            lit(default_decimals)
+            if when_decimals is None
+            else when_decimals.otherwise(lit(default_decimals))
         )

-        df = df.withColumn(
-            "display_decimals",
-            F.coalesce(col("display_decimals"), lit("0")).cast("INT"),
-        ).withColumn(
-            "value",
-            F.round(
-                F.col("value").cast("FLOAT") * F.pow(10, F.col("display_decimals")), 0
+        # Add otherwise with default rounding for value
+        if default_decimals > 6:
+            default_rounded = F.round(
+                col(value_column).cast(FloatType()), default_decimals
             )
-            / F.pow(10, F.col("display_decimals")).cast("STRING"),
+        else:
+            precision = 38
+            decimals = default_decimals + 2
+            default_decimal_value = col(value_column).cast(
+                DecimalType(precision, decimals)
+            )
+            default_scale = pow(lit(10), lit(default_decimals)).cast(
+                DecimalType(precision, decimals)
+            )
+            default_rounded = (
+                F.round(default_decimal_value * default_scale) / default_scale
+            )
+
+        # Only round if actual decimals >= target decimals, otherwise keep original
+        actual_decimals_default = F.length(
+            F.regexp_extract(
+                F.col(value_column).cast("string"), DECIMAL_PLACES_REGEX, 1
+            )
+        )
+        default_rounded = F.when(
+            actual_decimals_default >= lit(default_decimals), default_rounded
+        ).otherwise(col(value_column))
+
+        when_rounded = (
+            default_rounded
+            if when_rounded is None
+            else when_rounded.otherwise(default_rounded)
         )

-        # F.round(
-        #     col("value").cast("FLOAT"), col("display_decimals").cast("INT")
-        # ).cast("STRING"),
+        df = df.withColumn("display_decimals", when_decimals)
+        df = df.withColumn(value_column, when_rounded)

         return df

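The rewritten round_to_display_decimals builds one chained F.when expression per configuration rule and, for targets of 6 decimals or fewer, rounds through DECIMAL(38, d+2) as round(value * 10^d) / 10^d rather than a FLOAT round. A self-contained sketch of that scaled-decimal technique with made-up data (session and column names are illustrative only):

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, lit
    from pyspark.sql.types import DecimalType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("2.675",), ("1.005",)], ["value"])

    display_decimals = 2
    precision, scale_digits = 38, display_decimals + 2
    decimal_value = col("value").cast(DecimalType(precision, scale_digits))
    scale = pow(lit(10), lit(display_decimals)).cast(DecimalType(precision, scale_digits))

    # DECIMAL arithmetic rounds half-up on the exact value: "2.675" -> 2.68,
    # whereas a FLOAT cast stores roughly 2.67499999... and would round to 2.67
    df.withColumn("rounded", F.round(decimal_value * scale) / scale).show()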
@@ -156,18 +264,26 @@ class SWSGoldIcebergSparkHelper:
             self.iceberg_tables.SILVER.iceberg_id
         )

-    def gen_gold_sws_disseminated_data(self) -> DataFrame:
+    def gen_gold_sws_disseminated_data(
+        self, additional_columns: List[str] = []
+    ) -> DataFrame:
         return (
             self.read_silver_data()
             .transform(self.apply_diss_flag_filter)
-            .transform(self.keep_dim_val_attr_columns)
+            .transform(self.keep_dim_val_attr_columns, additional_columns)
         )

-    def gen_gold_sws_data(self) -> DataFrame:
-        return self.read_bronze_data().transform(self.keep_dim_val_attr_columns)
+    def gen_gold_sws_data(self, additional_columns: List[str] = []) -> DataFrame:
+        return self.read_bronze_data().transform(
+            self.keep_dim_val_attr_columns, additional_columns
+        )

-    def gen_gold_sws_validated_data(self) -> DataFrame:
-        return self.read_silver_data().transform(self.keep_dim_val_attr_columns)
+    def gen_gold_sws_validated_data(
+        self, additional_columns: List[str] = []
+    ) -> DataFrame:
+        return self.read_silver_data().transform(
+            self.keep_dim_val_attr_columns, additional_columns
+        )

     def write_gold_sws_validated_data_to_iceberg_and_csv(
         self, df: DataFrame
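
With these signature changes, callers can ask the gold generators to carry extra columns through keep_dim_val_attr_columns, which retains each requested name only if the source DataFrame actually has it. A hypothetical call (helper is an already-constructed SWSGoldIcebergSparkHelper; the column names are examples, not package constants):

    gold_df = helper.gen_gold_sws_validated_data(additional_columns=["note", "diss_note"])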
@@ -724,3 +840,29 @@ class SWSGoldIcebergSparkHelper:
         logging.debug(f"Tag with Added csv Table: {tag}")

         return df
+
+
+1
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+1
+1
+2
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+2
+1
+1
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+1
+1
+2
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+2
+1
+1
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+1
+1
+1
+frozenset({"1", "2", "6", "7", "5", "8", "0", "4", "3", "9"})
+1
+1

src/sws_spark_dissemination_helper/SWSPostgresSparkReader.py
@@ -468,7 +468,7 @@ class SWSPostgresSparkReader:
                 correct_domain_filter, domain=domain_code, unique_columns=["code"]
             )
            for col_type in mapping_dim_col_name_type.values()
-            if col_type != "other"
+            if col_type not in ("year", "other")
         }

     def import_diss_exceptions_datatable(
@@ -497,45 +497,3 @@ class SWSPostgresSparkReader:
                 "aggregation",
             ],
         )
-
-    def get_display_decimals_datatable(
-        self,
-        domain_code: str,
-    ) -> DataFrame:
-        df = self.read_pg_table(
-            pg_table=DatasetDatatables.DISPLAY_DECIMALS.id,
-            custom_schema=DatasetDatatables.DISPLAY_DECIMALS.schema,
-        ).filter(col("domain") == lit(domain_code))
-
-        pairs = df.select("column_1_name", "column_2_name").distinct().collect()
-
-        # If no config exists for this domain, fail early
-        if not pairs:
-            msg = (
-                f'No display-decimals configuration found for domain "{domain_code}". '
-                f'Please add an entry in table "{DatasetDatatables.DISPLAY_DECIMALS.id}".'
-            )
-            logging.error(msg)
-            # raise ValueError(msg)
-
-        # If more than one mapping exists, it's invalid
-        if len(pairs) > 1:
-            formatted_pairs = [(p["column_1_name"], p["column_2_name"]) for p in pairs]
-
-            msg = (
-                f'Invalid configuration for domain "{domain_code}". '
-                f"Expected exactly one (column_1_name, column_2_name) pair, but found {len(pairs)}: "
-                f"{formatted_pairs}. "
-                f'Please correct the table "{DatasetDatatables.DISPLAY_DECIMALS.id}".'
-            )
-
-            logging.error(
-                "Multiple display-decimals column pairs detected",
-                extra={
-                    "domain": domain_code,
-                    "pairs_found": formatted_pairs,
-                },
-            )
-            raise ValueError(msg)
-
-        return df
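
With get_display_decimals_datatable removed, the Gold helper now loads the whole datatable via read_pg_table and resolves defaults itself: domain-specific rules win over a domain-wide row, which wins over the global DEFAULT row. A sketch of that fallback resolution under the same schema assumptions (resolve_default_decimals is a hypothetical name, not a package function):

    from pyspark.sql import DataFrame
    from pyspark.sql.functions import col, lit

    def resolve_default_decimals(display_decimals_df: DataFrame, domain_code: str) -> int:
        # Global fallback: the DEFAULT row of the datatable
        general = (
            display_decimals_df.filter(col("domain") == lit("DEFAULT"))
            .select("display_decimals")
            .collect()[0][0]
        )
        # Domain-wide fallback: a domain row with no column constraints
        domain_rows = display_decimals_df.filter(
            (col("domain") == lit(domain_code))
            & col("column_1_name").isNull()
            & col("column_2_name").isNull()
        ).select("display_decimals")
        return int(general if domain_rows.count() == 0 else domain_rows.collect()[0][0])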

src/sws_spark_dissemination_helper/SWSSilverIcebergSparkHelper.py
@@ -209,7 +209,7 @@ class SWSSilverIcebergSparkHelper:
                 F.array_append(
                     col("d.diss_note"),
                     F.concat(
-                        col("sy.diss_note"),
+                        col("sy.note"),
                         lit(" from "),
                         col("sy.old_code"),
                         lit(" to "),
@@ -225,7 +225,7 @@
                 F.array_append(
                     col("new_diss_note"),
                     F.concat(
-                        col("ey.diss_note"),
+                        col("ey.note"),
                         lit(" from "),
                         col("ey.old_code"),
                         lit(" to "),
@@ -444,7 +444,7 @@ class SWSSilverIcebergSparkHelper:
         logging.info("Checking the dissemination flag for each dimension (except year)")

         for col_name, col_type in self.mapping_dim_col_name_type.items():
-            if col_type != "other":
+            if col_type not in ("other", "year"):
                 df = self._check_diss_dim_list(
                     df,
                     self.dfs_diss_flags[col_type],

src/sws_spark_dissemination_helper/constants.py
@@ -168,7 +168,7 @@ class DatasetTables:
         self.OBSERVATION = self.__SWSTable(
             postgres_id=f"{self.__dataset_id}.observation",
             iceberg_id=f"{IcebergDatabases.STAGING_DATABASE}.{self.__dataset_id}_observation",
-            schema="id BIGINT, observation_coordinates BIGINT, version INT, value FLOAT, flag_obs_status STRING, flag_method STRING, created_on TIMESTAMP, created_by INT, replaced_on TIMESTAMP",
+            schema="id BIGINT, observation_coordinates BIGINT, version INT, value STRING, flag_obs_status STRING, flag_method STRING, created_on TIMESTAMP, created_by INT, replaced_on TIMESTAMP",
         )
         self.OBSERVATION_COORDINATE = self.__SWSTable(
             postgres_id=f"{self.__dataset_id}.observation_coordinate",
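
The schema change above (value FLOAT to value STRING) supports the new rounding path: read as a string, a value keeps its literal decimal representation, which DECIMAL_PLACES_REGEX needs in order to count actual decimal places. A small illustration of what a FLOAT round-trip would discard (sample data only):

    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("123.4500",)], ["value"])

    df.select(
        # Kept as a string, the trailing zeros survive: 4 decimal places
        F.length(F.regexp_extract(F.col("value"), r"\.(\d+)$", 1)).alias("places_str"),
        # Cast through float and back, only "123.45" remains: 2 decimal places
        F.length(
            F.regexp_extract(F.col("value").cast("float").cast("string"), r"\.(\d+)$", 1)
        ).alias("places_float"),
    ).show()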