pylegend 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,23 +24,37 @@ from pylegend._typing import (
24
24
  )
25
25
  from pylegend.core.language.pandas_api.pandas_api_aggregate_specification import (
26
26
  PyLegendAggFunc,
27
- PyLegendAggInput
27
+ PyLegendAggInput,
28
+ PyLegendAggList,
28
29
  )
29
30
  from pylegend.core.language.pandas_api.pandas_api_tds_row import PandasApiTdsRow
30
31
  from pylegend.core.language.shared.helpers import escape_column_name, generate_pure_lambda
31
32
  from pylegend.core.language.shared.literal_expressions import convert_literal_to_literal_expression
32
33
  from pylegend.core.language.shared.primitive_collection import PyLegendPrimitiveCollection, create_primitive_collection
34
+ from pylegend.core.language.shared.primitives.boolean import PyLegendBoolean
35
+ from pylegend.core.language.shared.primitives.date import PyLegendDate
36
+ from pylegend.core.language.shared.primitives.datetime import PyLegendDateTime
37
+ from pylegend.core.language.shared.primitives.float import PyLegendFloat
38
+ from pylegend.core.language.shared.primitives.integer import PyLegendInteger
39
+ from pylegend.core.language.shared.primitives.number import PyLegendNumber
33
40
  from pylegend.core.language.shared.primitives.primitive import PyLegendPrimitive, PyLegendPrimitiveOrPythonPrimitive
34
- from pylegend.core.sql.metamodel import QuerySpecification, SelectItem, SingleColumn
41
+ from pylegend.core.language.shared.primitives.strictdate import PyLegendStrictDate
42
+ from pylegend.core.language.shared.primitives.string import PyLegendString
43
+ from pylegend.core.sql.metamodel import (
44
+ QuerySpecification,
45
+ SelectItem,
46
+ SingleColumn,
47
+ )
35
48
  from pylegend.core.tds.pandas_api.frames.pandas_api_applied_function_tds_frame import PandasApiAppliedFunction
36
49
  from pylegend.core.tds.pandas_api.frames.pandas_api_base_tds_frame import PandasApiBaseTdsFrame
50
+ from pylegend.core.tds.pandas_api.frames.pandas_api_groupby_tds_frame import PandasApiGroupbyTdsFrame
37
51
  from pylegend.core.tds.sql_query_helpers import copy_query, create_sub_query
38
- from pylegend.core.tds.tds_column import TdsColumn
52
+ from pylegend.core.tds.tds_column import PrimitiveTdsColumn, TdsColumn
39
53
  from pylegend.core.tds.tds_frame import FrameToPureConfig, FrameToSqlConfig
40
54
 
41
55
 
42
56
  class AggregateFunction(PandasApiAppliedFunction):
43
- __base_frame: PandasApiBaseTdsFrame
57
+ __base_frame: PyLegendUnion[PandasApiBaseTdsFrame, PandasApiGroupbyTdsFrame]
44
58
  __func: PyLegendAggInput
45
59
  __axis: PyLegendUnion[int, str]
46
60
  __args: PyLegendSequence[PyLegendPrimitiveOrPythonPrimitive]
@@ -51,12 +65,12 @@ class AggregateFunction(PandasApiAppliedFunction):
51
65
  return "aggregate" # pragma: no cover
52
66
 
53
67
  def __init__(
54
- self,
55
- base_frame: PandasApiBaseTdsFrame,
56
- func: PyLegendAggInput,
57
- axis: PyLegendUnion[int, str],
58
- *args: PyLegendPrimitiveOrPythonPrimitive,
59
- **kwargs: PyLegendPrimitiveOrPythonPrimitive
68
+ self,
69
+ base_frame: PyLegendUnion[PandasApiBaseTdsFrame, PandasApiGroupbyTdsFrame],
70
+ func: PyLegendAggInput,
71
+ axis: PyLegendUnion[int, str],
72
+ *args: PyLegendPrimitiveOrPythonPrimitive,
73
+ **kwargs: PyLegendPrimitiveOrPythonPrimitive,
60
74
  ) -> None:
61
75
  self.__base_frame = base_frame
62
76
  self.__func = func
@@ -66,13 +80,14 @@ class AggregateFunction(PandasApiAppliedFunction):
66
80
 
67
81
  def to_sql(self, config: FrameToSqlConfig) -> QuerySpecification:
68
82
  db_extension = config.sql_to_string_generator().get_db_extension()
69
- base_query: QuerySpecification = self.__base_frame.to_sql_query_object(config)
83
+
84
+ base_query: QuerySpecification = self.base_frame().to_sql_query_object(config)
70
85
 
71
86
  should_create_sub_query = (
72
- len(base_query.groupBy) > 0 or
73
- base_query.select.distinct or
74
- base_query.offset is not None or
75
- base_query.limit is not None
87
+ len(base_query.groupBy) > 0
88
+ or base_query.select.distinct
89
+ or base_query.offset is not None
90
+ or base_query.limit is not None
76
91
  )
77
92
 
78
93
  new_query: QuerySpecification
@@ -83,37 +98,118 @@ class AggregateFunction(PandasApiAppliedFunction):
83
98
 
84
99
  new_select_items: PyLegendList[SelectItem] = []
85
100
 
101
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
102
+ columns_to_retain: PyLegendList[str] = [
103
+ db_extension.quote_identifier(x) for x in self.__base_frame.grouping_column_name_list()
104
+ ]
105
+ new_cols_with_index: PyLegendList[PyLegendTuple[int, "SelectItem"]] = []
106
+ for col in new_query.select.selectItems:
107
+ if not isinstance(col, SingleColumn):
108
+ raise ValueError(
109
+ "Group By operation not supported for queries " "with columns other than SingleColumn"
110
+ ) # pragma: no cover
111
+ if col.alias is None:
112
+ raise ValueError(
113
+ "Group By operation not supported for queries " "with SingleColumns with missing alias"
114
+ ) # pragma: no cover
115
+ if col.alias in columns_to_retain:
116
+ new_cols_with_index.append((columns_to_retain.index(col.alias), col))
117
+
118
+ new_select_items = [y[1] for y in sorted(new_cols_with_index, key=lambda x: x[0])]
119
+
86
120
  for agg in self.__aggregates_list:
87
121
  agg_sql_expr = agg[2].to_sql_expression({"r": new_query}, config)
88
- new_select_items.append(
89
- SingleColumn(alias=db_extension.quote_identifier(agg[0]), expression=agg_sql_expr)
90
- )
122
+
123
+ new_select_items.append(SingleColumn(alias=db_extension.quote_identifier(agg[0]), expression=agg_sql_expr))
91
124
 
92
125
  new_query.select.selectItems = new_select_items
126
+
127
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
128
+ tds_row = PandasApiTdsRow.from_tds_frame("r", self.base_frame())
129
+ new_query.groupBy = [
130
+ (lambda x: x[c])(tds_row).to_sql_expression({"r": new_query}, config)
131
+ for c in self.__base_frame.grouping_column_name_list()
132
+ ]
133
+
93
134
  return new_query
94
135
 
95
136
  def to_pure(self, config: FrameToPureConfig) -> str:
96
137
  agg_strings = []
97
138
  for agg in self.__aggregates_list:
98
- map_expr_string = (agg[1].to_pure_expression(config) if isinstance(agg[1], PyLegendPrimitive)
99
- else convert_literal_to_literal_expression(agg[1]).to_pure_expression(config))
139
+ map_expr_string = (
140
+ agg[1].to_pure_expression(config)
141
+ if isinstance(agg[1], PyLegendPrimitive)
142
+ else convert_literal_to_literal_expression(agg[1]).to_pure_expression(config)
143
+ )
100
144
  agg_expr_string = agg[2].to_pure_expression(config).replace(map_expr_string, "$c")
101
- agg_strings.append(f"{escape_column_name(agg[0])}:{generate_pure_lambda('r', map_expr_string)}:"
102
- f"{generate_pure_lambda('c', agg_expr_string)}")
145
+ agg_strings.append(
146
+ f"{escape_column_name(agg[0])}:{generate_pure_lambda('r', map_expr_string)}:"
147
+ f"{generate_pure_lambda('c', agg_expr_string)}"
148
+ )
149
+
150
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
151
+ group_strings = []
152
+ for col_name in self.__base_frame.grouping_column_name_list():
153
+ group_strings.append(escape_column_name(col_name))
154
+
155
+ pure_expression = (
156
+ f"{self.base_frame().to_pure(config)}{config.separator(1)}" + f"->groupBy({config.separator(2)}"
157
+ f"~[{', '.join(group_strings)}],{config.separator(2, True)}"
158
+ f"~[{', '.join(agg_strings)}]{config.separator(1)}"
159
+ f")"
160
+ )
103
161
 
104
- return (f"{self.__base_frame.to_pure(config)}{config.separator(1)}"
162
+ return pure_expression
163
+ else:
164
+ return (
165
+ f"{self.__base_frame.to_pure(config)}{config.separator(1)}"
105
166
  f"->aggregate({config.separator(2)}"
106
167
  f"~[{', '.join(agg_strings)}]{config.separator(1)}"
107
- f")")
168
+ f")"
169
+ )
108
170
 
109
171
  def base_frame(self) -> PandasApiBaseTdsFrame:
110
- return self.__base_frame
172
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
173
+ return self.__base_frame.base_frame()
174
+ else:
175
+ return self.__base_frame
111
176
 
112
177
  def tds_frame_parameters(self) -> PyLegendList["PandasApiBaseTdsFrame"]:
113
178
  return []
114
179
 
115
180
  def calculate_columns(self) -> PyLegendSequence["TdsColumn"]:
116
- return [c.copy() for c in self.__base_frame.columns()]
181
+ new_columns = []
182
+
183
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
184
+ base_cols_map = {c.get_name(): c for c in self.base_frame().columns()}
185
+ for group_col_name in self.__base_frame.grouping_column_name_list():
186
+ if group_col_name in base_cols_map:
187
+ new_columns.append(base_cols_map[group_col_name].copy())
188
+
189
+ for alias, _, agg_expr in self.__aggregates_list:
190
+ new_columns.append(self.__infer_column_from_expression(alias, agg_expr))
191
+
192
+ return new_columns
193
+
194
def __infer_column_from_expression(self, name: str, expr: PyLegendPrimitive) -> TdsColumn:
    """Build a typed ``PrimitiveTdsColumn`` named ``name`` for an aggregation result.

    Args:
        name: Output column name (the aggregate alias).
        expr: The aggregation result expression whose primitive type decides
            the column type.

    Returns:
        A ``PrimitiveTdsColumn`` of the matching primitive type.

    Raises:
        TypeError: If ``expr`` is not one of the supported primitive types.
    """
    # Checks are ordered most-specific first. PyLegendInteger/PyLegendFloat are
    # refinements of PyLegendNumber. PyLegendDateTime/PyLegendStrictDate are
    # refinements of PyLegendDate in pylegend's primitive hierarchy (confirm).
    # Testing PyLegendDate before them would make their branches unreachable and
    # mistype datetime/strictdate results as plain date columns.
    if isinstance(expr, PyLegendInteger):
        return PrimitiveTdsColumn.integer_column(name)
    elif isinstance(expr, PyLegendFloat):
        return PrimitiveTdsColumn.float_column(name)
    elif isinstance(expr, PyLegendNumber):
        return PrimitiveTdsColumn.number_column(name)
    elif isinstance(expr, PyLegendString):
        return PrimitiveTdsColumn.string_column(name)
    elif isinstance(expr, PyLegendBoolean):
        return PrimitiveTdsColumn.boolean_column(name)  # pragma: no cover
    elif isinstance(expr, PyLegendDateTime):
        return PrimitiveTdsColumn.datetime_column(name)
    elif isinstance(expr, PyLegendStrictDate):
        return PrimitiveTdsColumn.strictdate_column(name)
    elif isinstance(expr, PyLegendDate):
        return PrimitiveTdsColumn.date_column(name)
    else:
        raise TypeError(f"Could not infer TdsColumn type for aggregation result type: {type(expr)}")  # pragma: no cover
117
213
 
118
214
  def validate(self) -> bool:
119
215
  if self.__axis not in [0, "index"]:
@@ -127,35 +223,73 @@ class AggregateFunction(PandasApiAppliedFunction):
127
223
  "or keyword arguments. Please remove extra *args/**kwargs."
128
224
  )
129
225
 
130
- self.__aggregates_list: PyLegendList[
131
- PyLegendTuple[str, PyLegendPrimitiveOrPythonPrimitive, PyLegendPrimitive]
132
- ] = []
226
+ self.__aggregates_list: PyLegendList[PyLegendTuple[str, PyLegendPrimitiveOrPythonPrimitive, PyLegendPrimitive]] = []
227
+
228
+ normalized_func: dict[str, PyLegendUnion[PyLegendAggFunc, PyLegendAggList]] = (
229
+ self.__normalize_input_func_to_standard_dict(self.__func)
230
+ )
133
231
 
134
- normalized_func: dict[str, PyLegendAggFunc] = self.__normalize_input_func_to_standard_dict(self.__func)
135
- tds_row = PandasApiTdsRow.from_tds_frame("r", self.__base_frame)
232
+ tds_row = PandasApiTdsRow.from_tds_frame("r", self.base_frame())
136
233
 
137
- for column_name, aggregate_function in normalized_func.items():
234
+ for column_name, agg_input in normalized_func.items():
138
235
  mapper_function: PyLegendCallable[[PandasApiTdsRow], PyLegendPrimitiveOrPythonPrimitive] = eval(
139
- f'lambda r: r["{column_name}"]')
236
+ f'lambda r: r["{column_name}"]'
237
+ )
140
238
  map_result: PyLegendPrimitiveOrPythonPrimitive = mapper_function(tds_row)
141
239
  collection: PyLegendPrimitiveCollection = create_primitive_collection(map_result)
142
240
 
143
- normalized_aggregate_function = self.__normalize_agg_func_to_lambda_function(aggregate_function)
144
- agg_result: PyLegendPrimitive = normalized_aggregate_function(collection)
241
+ if isinstance(agg_input, list):
242
+ lambda_counter = 0
243
+ for func in agg_input:
244
+ is_anonymous_lambda = False
245
+ if not isinstance(func, str):
246
+ if getattr(func, "__name__", "<lambda>") == "<lambda>":
247
+ is_anonymous_lambda = True
145
248
 
146
- self.__aggregates_list.append((column_name, map_result, agg_result))
249
+ if is_anonymous_lambda:
250
+ lambda_counter += 1
251
+
252
+ normalized_agg_func = self.__normalize_agg_func_to_lambda_function(func)
253
+ agg_result = normalized_agg_func(collection)
254
+
255
+ alias = self._generate_column_alias(column_name, func, lambda_counter)
256
+ self.__aggregates_list.append((alias, map_result, agg_result))
257
+
258
+ else:
259
+ normalized_agg_func = self.__normalize_agg_func_to_lambda_function(agg_input)
260
+ agg_result = normalized_agg_func(collection)
261
+
262
+ self.__aggregates_list.append((column_name, map_result, agg_result))
147
263
 
148
264
  return True
149
265
 
150
266
  def __normalize_input_func_to_standard_dict(
151
- self,
152
- func_input: PyLegendAggInput
153
- ) -> dict[str, PyLegendAggFunc]:
267
+ self, func_input: PyLegendAggInput
268
+ ) -> dict[str, PyLegendUnion[PyLegendAggFunc, PyLegendAggList]]:
269
+
270
+ validation_columns: PyLegendList[str]
271
+ default_broadcast_columns: PyLegendList[str]
272
+ group_cols: set[str] = set()
273
+
274
+ all_cols = [col.get_name() for col in self.base_frame().columns()]
154
275
 
155
- column_names = [col.get_name() for col in self.calculate_columns()]
276
+ if isinstance(self.__base_frame, PandasApiGroupbyTdsFrame):
277
+ group_cols = set(self.__base_frame.grouping_column_name_list())
278
+
279
+ selected_cols = self.__base_frame.selected_columns()
280
+
281
+ if selected_cols is not None:
282
+ validation_columns = selected_cols
283
+ default_broadcast_columns = selected_cols
284
+ else:
285
+ validation_columns = all_cols
286
+ default_broadcast_columns = [c for c in all_cols if c not in group_cols]
287
+ else:
288
+ validation_columns = all_cols
289
+ default_broadcast_columns = all_cols
156
290
 
157
291
  if isinstance(func_input, collections.abc.Mapping):
158
- normalized: dict[str, PyLegendAggFunc] = {}
292
+ normalized: dict[str, PyLegendUnion[PyLegendAggFunc, PyLegendAggList]] = {}
159
293
 
160
294
  for key, value in func_input.items():
161
295
  if not isinstance(key, str):
@@ -164,73 +298,54 @@ class AggregateFunction(PandasApiAppliedFunction):
164
298
  f"When a dictionary is provided, all keys must be strings.\n"
165
299
  f"But got key: {key!r} (type: {type(key).__name__})\n"
166
300
  )
167
- if key not in column_names:
301
+
302
+ if key not in validation_columns:
168
303
  raise ValueError(
169
304
  f"Invalid `func` argument for the aggregate function.\n"
170
305
  f"When a dictionary is provided, all keys must be column names.\n"
171
- f"Available columns are: {sorted(column_names)}\n"
306
+ f"Available columns are: {sorted(validation_columns)}\n"
172
307
  f"But got key: {key!r} (type: {type(key).__name__})\n"
173
308
  )
174
309
 
175
310
  if isinstance(value, collections.abc.Sequence) and not isinstance(value, str):
176
- if len(value) != 1:
177
- raise ValueError(
178
- f"Invalid `func` argument for the aggregate function.\n"
179
- f"When providing a list of functions for a specific column, "
180
- f"the list must contain exactly one element (single aggregation only).\n"
181
- f"Column: {key!r}\n"
182
- f"List Length: {len(value)}\n"
183
- f"Value: {value!r}\n"
184
- )
185
-
186
- single_func = value[0]
187
-
188
- if not (callable(single_func) or isinstance(single_func, str) or isinstance(single_func, np.ufunc)):
189
- raise TypeError(
190
- f"Invalid `func` argument for the aggregate function.\n"
191
- f"The single element in the list for key {key!r} must be a callable, str, or np.ufunc.\n"
192
- f"But got element: {single_func!r} (type: {type(single_func).__name__})\n"
193
- )
194
-
195
- normalized[key] = single_func
311
+ for i, f in enumerate(value):
312
+ if not (callable(f) or isinstance(f, str) or isinstance(f, np.ufunc)):
313
+ raise TypeError(
314
+ f"Invalid `func` argument for the aggregate function.\n"
315
+ f"When a list is provided for a column, all elements must be callable, str, or np.ufunc.\n"
316
+ f"But got element at index {i}: {f!r} (type: {type(f).__name__})\n"
317
+ )
318
+ normalized[key] = value
196
319
 
197
320
  else:
198
321
  if not (callable(value) or isinstance(value, str) or isinstance(value, np.ufunc)):
199
322
  raise TypeError(
200
323
  f"Invalid `func` argument for the aggregate function.\n"
201
324
  f"When a dictionary is provided, the value must be a callable, str, or np.ufunc "
202
- f"(or a list containing exactly one of these).\n"
203
- f"But got value for key {key!r}: {value!r} (type: {type(value).__name__})\n"
325
+ f"(or a list containing these).\n"
326
+ f"But got value for key '{key}': {value} (type: {type(value).__name__})\n"
204
327
  )
205
- normalized[key] = value
328
+
329
+ if key in group_cols:
330
+ normalized[key] = [value]
331
+ else:
332
+ normalized[key] = value
206
333
 
207
334
  return normalized
208
335
 
209
336
  elif isinstance(func_input, collections.abc.Sequence) and not isinstance(func_input, str):
337
+ for i, f in enumerate(func_input):
338
+ if not (callable(f) or isinstance(f, str) or isinstance(f, np.ufunc)):
339
+ raise TypeError(
340
+ f"Invalid `func` argument for the aggregate function.\n"
341
+ f"When a list is provided as the main argument, all elements must be callable, str, or np.ufunc.\n"
342
+ f"But got element at index {i}: {f!r} (type: {type(f).__name__})\n"
343
+ )
210
344
 
211
- if len(func_input) != 1:
212
- raise ValueError(
213
- f"Invalid `func` argument for the aggregate function.\n"
214
- f"When providing a list as the func argument, it must contain exactly one element "
215
- f"(which will be applied to all columns).\n"
216
- f"Multiple functions are not supported.\n"
217
- f"List Length: {len(func_input)}\n"
218
- f"Input: {func_input!r}\n"
219
- )
220
-
221
- single_func = func_input[0]
222
-
223
- if not (callable(single_func) or isinstance(single_func, str) or isinstance(single_func, np.ufunc)):
224
- raise TypeError(
225
- f"Invalid `func` argument for the aggregate function.\n"
226
- f"The single element in the top-level list must be a callable, str, or np.ufunc.\n"
227
- f"But got element: {single_func!r} (type: {type(single_func).__name__})\n"
228
- )
229
-
230
- return {col: single_func for col in column_names}
345
+ return {col: func_input for col in default_broadcast_columns}
231
346
 
232
347
  elif callable(func_input) or isinstance(func_input, str) or isinstance(func_input, np.ufunc):
233
- return {col: func_input for col in column_names}
348
+ return {col: func_input for col in default_broadcast_columns}
234
349
 
235
350
  else:
236
351
  raise TypeError(
@@ -241,19 +356,17 @@ class AggregateFunction(PandasApiAppliedFunction):
241
356
  )
242
357
 
243
358
  def __normalize_agg_func_to_lambda_function(
244
- self,
245
- func: PyLegendAggFunc
359
+ self, func: PyLegendAggFunc
246
360
  ) -> PyLegendCallable[[PyLegendPrimitiveCollection], PyLegendPrimitive]:
247
361
 
248
362
  PYTHON_FUNCTION_TO_LEGEND_FUNCTION_MAPPING: PyLegendMapping[str, PyLegendList[str]] = {
249
- "average": ["mean", "average", "nanmean"],
250
- "sum": ["sum", "nansum"],
251
- "min": ["min", "amin", "minimum", "nanmin"],
252
- "max": ["max", "amax", "maximum", "nanmax"],
253
- "std_dev_sample": ["std", "std_dev", "nanstd"],
363
+ "average": ["mean", "average", "nanmean"],
364
+ "sum": ["sum", "nansum"],
365
+ "min": ["min", "amin", "minimum", "nanmin"],
366
+ "max": ["max", "amax", "maximum", "nanmax"],
367
+ "std_dev_sample": ["std", "std_dev", "nanstd"],
254
368
  "variance_sample": ["var", "variance", "nanvar"],
255
- "median": ["median", "nanmedian"],
256
- "count": ["count", "size", "len", "length"],
369
+ "count": ["count", "size", "len", "length"],
257
370
  }
258
371
 
259
372
  FLATTENED_FUNCTION_MAPPING: dict[str, str] = {}
@@ -300,6 +413,7 @@ class AggregateFunction(PandasApiAppliedFunction):
300
413
  final_lambda = eval(lambda_source)
301
414
  return final_lambda
302
415
  else:
416
+
303
417
  def validation_wrapper(x: PyLegendPrimitiveCollection) -> PyLegendPrimitive:
304
418
  result = func(x)
305
419
  if not isinstance(result, PyLegendPrimitive):
@@ -314,3 +428,14 @@ class AggregateFunction(PandasApiAppliedFunction):
314
428
 
315
429
  def _generate_lambda_source(self, internal_method_name: str) -> str:
316
430
  return f"lambda x: x.{internal_method_name}()"
431
+
432
def _generate_column_alias(self, col_name: str, func: PyLegendAggFunc, lambda_counter: int) -> str:
    """Build the output column name for one function of a multi-function aggregation.

    The result has the form ``<label>(<col_name>)``:

    - a string function uses the string itself as the label;
    - a callable uses its ``__name__`` as the label;
    - a callable whose name is ``"<lambda>"``, or that has no ``__name__``, uses
      ``lambda_<lambda_counter>`` so that several anonymous lambdas on one column
      get distinct names.
    """
    if isinstance(func, str):
        label = func
    else:
        label = getattr(func, "__name__", "<lambda>")
        if label == "<lambda>":
            label = f"lambda_{lambda_counter}"
    return f"{label}({col_name})"