sqlframe 3.16.0__py3-none-any.whl → 3.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '3.16.0'
-__version_tuple__ = version_tuple = (3, 16, 0)
+__version__ = version = '3.17.0'
+__version_tuple__ = version_tuple = (3, 17, 0)
sqlframe/base/column.py CHANGED
@@ -291,6 +291,7 @@ class Column:
             this=self.column_expression,
             alias=alias.this if isinstance(alias, exp.Column) else alias,
         )
+        new_expression._meta = {"display_name": name, **(new_expression._meta or {})}
         return Column(new_expression)
 
     def asc(self) -> Column:
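The `Column.alias` change above stashes the user-supplied alias string under a `display_name` meta key so later planning steps can recover its original casing. A minimal standalone sketch of that pattern using sqlglot directly (the column and alias names are illustrative, not from the package):

    from sqlglot import exp

    # Keep the alias exactly as the user typed it alongside the expression,
    # mirroring the meta update added to Column.alias above.
    name = "FirstName"  # hypothetical user-supplied alias
    new_expression = exp.alias_(exp.column("first_name"), name)
    new_expression._meta = {"display_name": name, **(new_expression._meta or {})}

    print(new_expression.sql())                 # first_name AS FirstName
    print(new_expression.meta["display_name"])  # FirstName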
sqlframe/base/dataframe.py CHANGED
@@ -233,6 +233,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         last_op: Operation = Operation.INIT,
         pending_hints: t.Optional[t.List[exp.Expression]] = None,
         output_expression_container: t.Optional[OutputExpressionContainer] = None,
+        display_name_mapping: t.Optional[t.Dict[str, str]] = None,
         **kwargs,
     ):
         self.session = session
@@ -246,6 +247,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         self.pending_hints = pending_hints or []
         self.output_expression_container = output_expression_container or exp.Select()
         self.temp_views: t.List[exp.Select] = []
+        self.display_name_mapping = display_name_mapping or {}
 
     def __getattr__(self, column_name: str) -> Column:
         return self[column_name]
@@ -385,13 +387,14 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         return Column.ensure_cols(ensure_list(cols))  # type: ignore
 
     def _ensure_and_normalize_cols(
-        self, cols, expression: t.Optional[exp.Select] = None
+        self, cols, expression: t.Optional[exp.Select] = None, skip_star_expansion: bool = False
     ) -> t.List[Column]:
         from sqlframe.base.normalize import normalize
 
         cols = self._ensure_list_of_columns(cols)
         normalize(self.session, expression or self.expression, cols)
-        cols = list(flatten([self._expand_star(col) for col in cols]))
+        if not skip_star_expansion:
+            cols = list(flatten([self._expand_star(col) for col in cols]))
         self._resolve_ambiguous_columns(cols)
         return cols
 
@@ -592,6 +595,23 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
             )
         return [col]
 
+    def _update_display_name_mapping(
+        self, normalized_columns: t.List[Column], user_input: t.Iterable[ColumnOrName]
+    ) -> None:
+        from sqlframe.base.column import Column
+
+        normalized_aliases = [x.alias_or_name for x in normalized_columns]
+        user_display_names = [
+            x.expression.meta.get("display_name") if isinstance(x, Column) else x
+            for x in user_input
+        ]
+        zipped = {
+            k: v
+            for k, v in dict(zip(normalized_aliases, user_display_names)).items()
+            if v is not None
+        }
+        self.display_name_mapping.update(zipped)
+
     def _get_expressions(
         self,
         optimize: bool = True,
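The `_update_display_name_mapping` helper added above pairs each normalized output name with the display name recorded from user input and keeps only the pairs where a display name exists. A simplified, self-contained sketch of that zip-and-filter step (the sample names are invented):

    import typing as t

    def update_display_name_mapping(
        mapping: t.Dict[str, str],
        normalized_aliases: t.List[str],
        user_display_names: t.List[t.Optional[str]],
    ) -> None:
        # Only columns with a recorded display name make it into the mapping.
        mapping.update(
            {k: v for k, v in zip(normalized_aliases, user_display_names) if v is not None}
        )

    display_name_mapping: t.Dict[str, str] = {}
    update_display_name_mapping(
        display_name_mapping,
        ["first_name", "employee_id"],  # normalized output names
        ["FirstName", None],            # display names captured from user input
    )
    print(display_name_mapping)  # {'first_name': 'FirstName'}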
@@ -611,6 +631,16 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
             select_expression = select_expression.transform(
                 replace_id_value, replacement_mapping
             ).assert_is(exp.Select)
+            for index, column in enumerate(select_expression.expressions):
+                column_name = quote_preserving_alias_or_name(column)
+                if column_name in self.display_name_mapping:
+                    display_name_identifier = exp.to_identifier(
+                        self.display_name_mapping[column_name], quoted=True
+                    )
+                    display_name_identifier._meta = {"case_sensitive": True, **(column._meta or {})}
+                    select_expression.expressions[index] = exp.alias_(
+                        column.unalias(), display_name_identifier, quoted=True
+                    )
             if optimize:
                 select_expression = t.cast(
                     exp.Select,
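The `_get_expressions` hook above is where a recorded display name turns back into a quoted alias on the generated SELECT. A hedged sketch of that re-aliasing step on a toy sqlglot query (table and column names are illustrative, and `quote_preserving_alias_or_name` is replaced by plain `alias_or_name` for brevity):

    from sqlglot import exp

    display_name_mapping = {"first_name": "FirstName"}  # normalized name -> user's casing

    select_expression = exp.select("first_name").from_("employee")
    for index, column in enumerate(select_expression.expressions):
        column_name = column.alias_or_name
        if column_name in display_name_mapping:
            display_name_identifier = exp.to_identifier(
                display_name_mapping[column_name], quoted=True
            )
            # Re-alias the projection so the output column keeps the user's casing.
            select_expression.expressions[index] = exp.alias_(
                column.unalias(), display_name_identifier, quoted=True
            )

    print(select_expression.sql())  # SELECT first_name AS "FirstName" FROM employee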
@@ -803,6 +833,17 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         if isinstance(cols[0], list):
             cols = cols[0]  # type: ignore
         columns = self._ensure_and_normalize_cols(cols)
+        if "skip_update_display_name_mapping" not in kwargs:
+            unexpanded_columns = self._ensure_and_normalize_cols(cols, skip_star_expansion=True)
+            user_cols = list(cols)
+            star_columns = []
+            for index, user_col in enumerate(cols):
+                if "*" in (user_col if isinstance(user_col, str) else user_col.alias_or_name):
+                    star_columns.append(index)
+            for index in star_columns:
+                unexpanded_columns.pop(index)
+                user_cols.pop(index)
+            self._update_display_name_mapping(unexpanded_columns, user_cols)
         kwargs["append"] = kwargs.get("append", False)
         # If an expression is `CAST(x AS DATETYPE)` then we want to alias so that `x` is the result column name
         columns = [
@@ -852,6 +893,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
     @operation(Operation.SELECT)
     def agg(self, *exprs, **kwargs) -> Self:
         cols = self._ensure_and_normalize_cols(exprs)
+        self._update_display_name_mapping(cols, exprs)
         return self.groupBy().agg(*cols)
 
     @operation(Operation.FROM)
@@ -1051,7 +1093,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         new_df = self.copy(expression=join_expression)
         new_df.pending_join_hints.extend(self.pending_join_hints)
         new_df.pending_hints.extend(other_df.pending_hints)
-        new_df = new_df.select.__wrapped__(new_df, *select_column_names)  # type: ignore
+        new_df = new_df.select.__wrapped__(  # type: ignore
+            new_df, *select_column_names, skip_update_display_name_mapping=True
+        )
         return new_df
 
     @operation(Operation.ORDER_BY)
@@ -1441,20 +1485,18 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
     def withColumnRenamed(self, existing: str, new: str) -> Self:
         expression = self.expression.copy()
         existing = self.session._normalize_string(existing)
-        new = self.session._normalize_string(new)
-        existing_columns = [
-            expression
-            for expression in expression.expressions
-            if expression.alias_or_name == existing
-        ]
-        if not existing_columns:
+        columns = self._get_outer_select_columns(expression)
+        results = []
+        found_match = False
+        for column in columns:
+            if column.alias_or_name == existing:
+                column = column.alias(new)
+                self._update_display_name_mapping([column], [new])
+                found_match = True
+            results.append(column)
+        if not found_match:
             raise ValueError("Tried to rename a column that doesn't exist")
-        for existing_column in existing_columns:
-            if isinstance(existing_column, exp.Column):
-                existing_column.replace(exp.alias_(existing_column, new))
-            else:
-                existing_column.set("alias", exp.to_identifier(new))
-        return self.copy(expression=expression)
+        return self.select.__wrapped__(self, *results, skip_update_display_name_mapping=True)  # type: ignore
 
     @operation(Operation.SELECT)
     def withColumns(self, *colsMap: t.Dict[str, Column]) -> Self:
@@ -1495,23 +1537,27 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         if len(colsMap) != 1:
             raise ValueError("Only a single map is supported")
         col_map = {
-            self._ensure_and_normalize_col(k).alias_or_name: self._ensure_and_normalize_col(v)
+            self._ensure_and_normalize_col(k): (self._ensure_and_normalize_col(v), k)
             for k, v in colsMap[0].items()
         }
         existing_cols = self._get_outer_select_columns(self.expression)
         existing_col_names = [x.alias_or_name for x in existing_cols]
         select_columns = existing_cols
-        for column_name, col_value in col_map.items():
+        for col, (col_value, display_name) in col_map.items():
+            column_name = col.alias_or_name
             existing_col_index = (
                 existing_col_names.index(column_name) if column_name in existing_col_names else None
             )
             if existing_col_index is not None:
                 select_columns[existing_col_index] = col_value.alias(  # type: ignore
-                    column_name
-                ).expression
+                    display_name
+                )
             else:
-                select_columns.append(col_value.alias(column_name))
-        return self.select.__wrapped__(self, *select_columns)  # type: ignore
+                select_columns.append(col_value.alias(display_name))
+        self._update_display_name_mapping(
+            [col for col in col_map], [name for _, name in col_map.values()]
+        )
+        return self.select.__wrapped__(self, *select_columns, skip_update_display_name_mapping=True)  # type: ignore
 
     @operation(Operation.SELECT)
     def drop(self, *cols: t.Union[str, Column]) -> Self:
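Taken together, the select, withColumnRenamed, and withColumns changes mean the DataFrame now remembers the casing a user typed and restores it when SQL is generated. A hypothetical usage sketch of the intended effect (not taken from the project's docs; it assumes the DuckDBSession entry point, a local duckdb install, and README-style createDataFrame/sql behaviour):

    from sqlframe.duckdb import DuckDBSession

    session = DuckDBSession()
    df = session.createDataFrame([{"employee_id": 1, "fname": "Jack"}])

    # The rename target "FirstName" is expected to survive into the generated SQL
    # as a case-preserving alias rather than being normalized away.
    renamed = df.withColumnRenamed("fname", "FirstName")
    print(renamed.sql())
    print(renamed.collect())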
sqlframe/base/functions.py CHANGED
@@ -39,11 +39,19 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
 
     dialect = _BaseSession().input_dialect
     if isinstance(column_name, str):
-        return Column(
-            expression.to_column(column_name, dialect=dialect).transform(
-                dialect.normalize_identifier
-            )
+        col_expression = expression.to_column(column_name, dialect=dialect).transform(
+            dialect.normalize_identifier
         )
+        case_sensitive_expression = expression.to_column(column_name, dialect=dialect)
+        if not isinstance(
+            case_sensitive_expression, (expression.Star, expression.Literal, expression.Null)
+        ):
+            col_expression._meta = {
+                "display_name": case_sensitive_expression.this.this,
+                **(col_expression._meta or {}),
+            }
+
+        return Column(col_expression)
     return Column(column_name)
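The reworked `col()` builds the normalized expression as before, then parses the raw string a second time purely to capture the original identifier text in meta. A stripped-down sketch of that two-parse idea (normalization is faked with `.lower()` here because the real rule depends on the session's input dialect):

    from sqlglot import expressions as expression

    column_name = "FirstName"  # illustrative user input

    # Normalized form used for SQL generation (stand-in for dialect.normalize_identifier).
    col_expression = expression.to_column(column_name.lower())
    # A second parse keeps the identifier exactly as the user wrote it.
    case_sensitive_expression = expression.to_column(column_name)

    if not isinstance(
        case_sensitive_expression, (expression.Star, expression.Literal, expression.Null)
    ):
        col_expression._meta = {
            "display_name": case_sensitive_expression.this.this,
            **(col_expression._meta or {}),
        }

    print(col_expression.sql())                 # firstname
    print(col_expression.meta["display_name"])  # FirstName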
sqlframe/base/session.py CHANGED
@@ -507,9 +507,14 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGIS
         result = self._cur.fetchall()
         if not self._cur.description:
             return []
+        case_sensitive_cols = []
+        for col in self._cur.description:
+            col_id = exp.parse_identifier(col[0], dialect=self.execution_dialect)
+            col_id._meta = {"case_sensitive": True, **(col_id._meta or {})}
+            case_sensitive_cols.append(col_id)
         columns = [
-            normalize_string(x[0], from_dialect="execution", to_dialect="output", is_column=True)
-            for x in self._cur.description
+            normalize_string(x, from_dialect="execution", to_dialect="output")
+            for x in case_sensitive_cols
         ]
         return [self._to_row(columns, row) for row in result]
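The fetch path now wraps every cursor column name in an identifier flagged as case-sensitive before it is normalized for output, so names coming back from the engine keep their casing. A small hedged sketch of that wrapping step (the fake `description` tuples stand in for a DB-API cursor):

    from sqlglot import exp

    # Stand-in for cursor.description: DB-API cursors return one tuple per column,
    # with the column name in position 0.
    description = [("FirstName", None), ("employee_id", None)]

    case_sensitive_cols = []
    for col in description:
        col_id = exp.parse_identifier(col[0])
        # Flag the identifier so downstream normalization can leave its casing alone.
        col_id._meta = {"case_sensitive": True, **(col_id._meta or {})}
        case_sensitive_cols.append(col_id)

    print([c.name for c in case_sensitive_cols])  # ['FirstName', 'employee_id']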
515
520
 
sqlframe/spark/session.py CHANGED
@@ -79,17 +79,18 @@ class SparkSession(
         if skip_rows:
             return []
         assert self._last_df is not None
-        return [
-            Row(
-                **{
-                    normalize_string(
-                        k, from_dialect="execution", to_dialect="output", is_column=True
-                    ): v
-                    for k, v in row.asDict().items()
-                }
-            )
-            for row in self._last_df.collect()
-        ]
+        results = []
+        for row in self._last_df.collect():
+            rows_normalized = {}
+            for k, v in row.asDict().items():
+                col_id = exp.parse_identifier(k, dialect=self.execution_dialect)
+                col_id._meta = {"case_sensitive": True, **(col_id._meta or {})}
+                col_name = normalize_string(
+                    col_id, from_dialect="execution", to_dialect="output", is_column=True
+                )
+                rows_normalized[col_name] = v
+            results.append(Row(**rows_normalized))
+        return results
 
     def _execute(self, sql: str) -> None:
         self._last_df = self.spark_session.sql(sql)
sqlframe-3.16.0.dist-info/METADATA → sqlframe-3.17.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sqlframe
-Version: 3.16.0
+Version: 3.17.0
 Summary: Turning PySpark Into a Universal DataFrame API
 Home-page: https://github.com/eakmanrq/sqlframe
 Author: Ryan Eakman
sqlframe-3.16.0.dist-info/RECORD → sqlframe-3.17.0.dist-info/RECORD CHANGED
@@ -1,19 +1,19 @@
 sqlframe/__init__.py,sha256=wfqm98eLoLid9oV_FzzpG5loKC6LxOhj2lXpfN7SARo,3138
-sqlframe/_version.py,sha256=CtTis8a_OeN0EsLFoVgtqX-ARqHjuin2ATomgRROY1Y,413
+sqlframe/_version.py,sha256=KdbrTz1mygb-tPODYZu2E4Sk2KYmeTUCHVpQLRpXAXo,413
 sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298
 sqlframe/base/catalog.py,sha256=SzFQalTWdhWzxUY-4ut1f9TfOECp_JmJEgNPfrRKCe0,38457
-sqlframe/base/column.py,sha256=wRghgieYAA51aw4WuFQWOvl0TFOToZbBhBuIamEzxx4,18011
-sqlframe/base/dataframe.py,sha256=KKBwtn73xNGt2gRwUB8Vri7Ee6_ivP5a_qij4Eq96zE,76622
+sqlframe/base/column.py,sha256=oHVwkSWABO3ZlAbgBShsxSSlgbI06BOup5XJrRhgqJI,18097
+sqlframe/base/dataframe.py,sha256=SQtwoQKpq-12WXuplOPN21fXQPvjF_D9WLcPPFA12Zs,78973
 sqlframe/base/decorators.py,sha256=ms-CvDOIW3T8IVB9VqDmLwAiaEsqXLYRXEqVQaxktiM,1890
 sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
 sqlframe/base/function_alternatives.py,sha256=NV31IaEhVYmfUSWetAEFISAvLzs2DxQ7bp-iMNgj0hQ,53786
-sqlframe/base/functions.py,sha256=o8zwbS8zCsyNe5arcb6dbAGBL8a1tH99rGyRimwzzUk,220614
+sqlframe/base/functions.py,sha256=1LHxazgC9tZ_GzyWNsjU945SRnAsQjUH2easMJLU3h4,221012
 sqlframe/base/group.py,sha256=fsyG5990_Pd7gFPjTFrH9IEoAquL_wEkVpIlBAIkZJU,4091
 sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
 sqlframe/base/operations.py,sha256=xSPw74e59wYvNd6U1AlwziNCTG6Aftrbl4SybN9u9VE,3450
 sqlframe/base/readerwriter.py,sha256=w8926cqIrXF7NGHiINw5UHzP_3xpjsqbijTBTzycBRM,26605
-sqlframe/base/session.py,sha256=s9M9_nbtOQQgLyEBZs-ijkMeHkYkILHfBc8JsU2SLmU,26369
+sqlframe/base/session.py,sha256=0eBE_HYEb3npyyOGM7zS_VR8WgzvfgVI-PFLCK9Hy0M,26628
 sqlframe/base/table.py,sha256=rCeh1W5SWbtEVfkLAUiexzrZwNgmZeptLEmLcM1ABkE,6961
 sqlframe/base/transforms.py,sha256=y0j3SGDz3XCmNGrvassk1S-owllUWfkHyMgZlY6SFO4,467
 sqlframe/base/types.py,sha256=iBNk9bpFtb2NBIogYS8i7OlQZMRvpR6XxqzBebsjQDU,12280
@@ -110,7 +110,7 @@ sqlframe/spark/functions.py,sha256=MYCgHsjRQWylT-rezWRBuLV6BivcaVarbaQtP4T0toQ,3
 sqlframe/spark/functions.pyi,sha256=GyOdUzv2Z7Qt99JAKEPKgV2t2Rn274OuqwAfcoAXlN0,24259
 sqlframe/spark/group.py,sha256=MrvV_v-YkBc6T1zz882WrEqtWjlooWIyHBCmTQg3fCA,379
 sqlframe/spark/readwriter.py,sha256=zXZcCPWpQMMN90wdIx8AD4Y5tWBcpRSL4-yKX2aZyik,874
-sqlframe/spark/session.py,sha256=1kgi69uztJxJ6bJpgkpRxllOYgVrizKXA5iT88-jWKA,5421
+sqlframe/spark/session.py,sha256=9qG-J5L8gmiy384GZFSBT2tHF8akqqJNij23Y3pheMs,5651
 sqlframe/spark/table.py,sha256=puWV8h_CqA64zwpzq0ydY9LoygMAvprkODyxyzZeF9M,186
 sqlframe/spark/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
 sqlframe/spark/udf.py,sha256=owB8NDaGVkUQ0WGm7SZt2t9zfvLFCfi0W48QiPfgjck,1153
@@ -129,8 +129,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
 sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
 sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
 sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
-sqlframe-3.16.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
-sqlframe-3.16.0.dist-info/METADATA,sha256=SMpgyXmxbVMqeeRuByF19qKm9iLDYubcniTCYBUmyNo,8970
-sqlframe-3.16.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-sqlframe-3.16.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
-sqlframe-3.16.0.dist-info/RECORD,,
+sqlframe-3.17.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
+sqlframe-3.17.0.dist-info/METADATA,sha256=K8kfOT5t6cEBs4YsIK76QCFBPW2NEcDcsPMkEhWCLUI,8970
+sqlframe-3.17.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sqlframe-3.17.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
+sqlframe-3.17.0.dist-info/RECORD,,