prismiq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prismiq/transforms.py ADDED
@@ -0,0 +1,471 @@
1
+ """Data transformation utilities for Prismiq analytics.
2
+
3
+ This module provides functions for transforming query results including
4
+ pivot, transpose, null filling, running totals, and percentage
5
+ calculations.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import contextlib
11
+ from collections import defaultdict
12
+ from typing import Any
13
+
14
+ from prismiq.types import QueryResult
15
+
16
+
17
def pivot_data(
    result: QueryResult,
    row_column: str,
    pivot_column: str,
    value_column: str,
    aggregation: str = "sum",
) -> QueryResult:
    """Pivot data from long to wide format.

    Args:
        result: Query result to pivot.
        row_column: Column to use as row labels.
        pivot_column: Column whose unique values become new columns.
        value_column: Column containing the values to pivot.
        aggregation: Aggregation for duplicate values ('sum', 'avg', 'min',
            'max', 'count').  Unknown names fall back to 'sum'.

    Returns:
        New QueryResult with pivoted data.

    Raises:
        ValueError: If any of the named columns is missing from the result.

    Example:
        Input:
            region | month | sales
            East   | Jan   | 100
            East   | Feb   | 150
            West   | Jan   | 200

        Output (pivot on month):
            region | Jan | Feb
            East   | 100 | 150
            West   | 200 | None
    """
    if not result.rows:
        return QueryResult(
            columns=[row_column],
            column_types=["text"],
            rows=[],
            row_count=0,
            truncated=False,
            execution_time_ms=0,
        )

    # Find column indices
    try:
        row_idx = result.columns.index(row_column)
        pivot_idx = result.columns.index(pivot_column)
        value_idx = result.columns.index(value_column)
    except ValueError as e:
        raise ValueError(f"Column not found in result: {e}") from e

    # Unique pivot values in first-seen order (these become new columns)
    pivot_values: list[Any] = []
    seen_values: set[Any] = set()
    for row in result.rows:
        val = row[pivot_idx]
        if val not in seen_values:
            pivot_values.append(val)
            seen_values.add(val)

    # Group numeric values by (row label, pivot value)
    data_map: dict[Any, dict[Any, list[float]]] = defaultdict(lambda: defaultdict(list))

    for row in result.rows:
        row_val = row[row_idx]
        pivot_val = row[pivot_idx]
        value = row[value_idx]

        # Touch the outer entry unconditionally so that a row label whose
        # values are all NULL still appears in the output.  (Previously such
        # rows were silently dropped because the entry was only created when
        # a non-null value was seen.)
        cells = data_map[row_val]

        if value is not None:
            try:
                cells[pivot_val].append(float(value))
            except (ValueError, TypeError):
                # Non-numeric value: contribute 0 so 'count' still sees it
                cells[pivot_val].append(0)

    # Aggregation dispatch table; unknown names default to 'sum'
    aggregators: dict[str, Any] = {
        "sum": sum,
        "avg": lambda vals: sum(vals) / len(vals),
        "min": min,
        "max": max,
        "count": lambda vals: float(len(vals)),
    }
    agg_fn = aggregators.get(aggregation, sum)

    def aggregate(values: list[float]) -> float | None:
        """Apply the selected aggregation; empty cells become None."""
        return agg_fn(values) if values else None

    # Build output rows (dict preserves first-seen row-label order)
    output_rows: list[list[Any]] = []
    for row_val in data_map:
        output_row: list[Any] = [row_val]
        for pivot_val in pivot_values:
            output_row.append(aggregate(data_map[row_val].get(pivot_val, [])))
        output_rows.append(output_row)

    # Build column names and types
    output_columns = [row_column] + [str(v) for v in pivot_values]
    output_types = [result.column_types[row_idx]] + ["numeric"] * len(pivot_values)

    return QueryResult(
        columns=output_columns,
        column_types=output_types,
        rows=output_rows,
        row_count=len(output_rows),
        truncated=False,
        execution_time_ms=0,
    )
128
+
129
+
130
def transpose_data(result: QueryResult) -> QueryResult:
    """Transpose rows and columns.

    Each original column becomes a row whose first cell is the column
    name; the output header is "Column" followed by "Row 1", "Row 2", ...

    Args:
        result: Query result to transpose.

    Returns:
        New QueryResult with transposed data.
    """
    if not result.rows:
        # No data: emit one row per original column name.
        return QueryResult(
            columns=["Column"],
            column_types=["text"],
            rows=[[name] for name in result.columns],
            row_count=len(result.columns),
            truncated=False,
            execution_time_ms=0,
        )

    # Header: "Column" plus one label per original row.
    header = ["Column"] + [f"Row {n}" for n in range(1, len(result.rows) + 1)]

    # Flip the grid: column index walks each source row; short rows pad
    # with None.
    transposed: list[list[Any]] = [
        [name]
        + [
            src_row[idx] if idx < len(src_row) else None
            for src_row in result.rows
        ]
        for idx, name in enumerate(result.columns)
    ]

    return QueryResult(
        columns=header,
        column_types=["text"] * len(header),
        rows=transposed,
        row_count=len(transposed),
        truncated=False,
        execution_time_ms=0,
    )
176
+
177
+
178
def fill_nulls(
    result: QueryResult,
    column: str | None = None,
    value: Any = 0,
    method: str | None = None,
) -> QueryResult:
    """Fill null values in result data.

    Args:
        result: Query result to process.
        column: Specific column to fill, or None for all columns.
        value: Static fill value (used if method is None; also seeds
            ffill/bfill when there is no earlier/later non-null value).
        method: Fill method ('ffill' for forward fill, 'bfill' for
            backward fill).

    Returns:
        New QueryResult with nulls filled.

    Raises:
        ValueError: If *column* is given but not present in the result.
    """
    if not result.rows:
        return result

    # Copy each row so the caller's result is left untouched.
    filled = [row[:] for row in result.rows]

    # Resolve the set of column indices to process.
    if column is None:
        targets: list[int] = list(range(len(result.columns)))
    else:
        try:
            targets = [result.columns.index(column)]
        except ValueError as e:
            raise ValueError(f"Column '{column}' not found in result") from e

    for idx in targets:
        if method == "ffill":
            # Carry the most recent non-null value forward.
            carry: Any = value
            for row in filled:
                if row[idx] is None:
                    row[idx] = carry
                else:
                    carry = row[idx]
        elif method == "bfill":
            # Walk bottom-up, carrying the next non-null value backward.
            carry = value
            for row in reversed(filled):
                if row[idx] is None:
                    row[idx] = carry
                else:
                    carry = row[idx]
        else:
            # No method: plain static replacement.
            for row in filled:
                if row[idx] is None:
                    row[idx] = value

    return QueryResult(
        columns=result.columns,
        column_types=result.column_types,
        rows=filled,
        row_count=result.row_count,
        truncated=result.truncated,
        execution_time_ms=0,
    )
241
+
242
+
243
def calculate_running_total(
    result: QueryResult,
    value_column: str,
    order_column: str | None = None,
    group_column: str | None = None,
) -> QueryResult:
    """Add a running total column.

    Args:
        result: Query result to process.
        value_column: Column containing values to sum.
        order_column: Column to order by (uses existing order if None).
            Rows with a NULL order value sort after all non-null values.
        group_column: Column to group by (calculates running total within
            each group).

    Returns:
        New QueryResult with a ``<value_column>_running_total`` column
        appended; rows keep their original order.

    Raises:
        ValueError: If a named column is missing from the result.
    """
    if not result.rows:
        return QueryResult(
            columns=[*result.columns, f"{value_column}_running_total"],
            column_types=[*result.column_types, "numeric"],
            rows=[],
            row_count=0,
            truncated=False,
            execution_time_ms=0,
        )

    try:
        value_idx = result.columns.index(value_column)
    except ValueError as e:
        raise ValueError(f"Column '{value_column}' not found in result") from e

    group_idx = None
    if group_column is not None:
        try:
            group_idx = result.columns.index(group_column)
        except ValueError as e:
            raise ValueError(f"Group column '{group_column}' not found") from e

    order_idx = None
    if order_column is not None:
        try:
            order_idx = result.columns.index(order_column)
        except ValueError as e:
            raise ValueError(f"Order column '{order_column}' not found") from e

    # Keep the original position of each row so output order is preserved.
    indexed_rows = list(enumerate(result.rows))

    if order_idx is not None:
        # Sort NULL order values last.  The previous key used `val or 0`,
        # which raised TypeError for non-numeric order columns containing
        # None and conflated falsy values (0, '', False) with 0.  The
        # None-last tuple key matches sort_result's convention.
        indexed_rows.sort(
            key=lambda item: (item[1][order_idx] is None, item[1][order_idx])
        )

    # Accumulate per-group totals in sorted order; non-numeric values are
    # ignored (they leave the total unchanged).
    running_totals: dict[Any, float] = defaultdict(float)
    row_totals: dict[int, float] = {}

    for original_idx, row in indexed_rows:
        group_key = row[group_idx] if group_idx is not None else "__all__"
        val = row[value_idx]

        if val is not None:
            with contextlib.suppress(ValueError, TypeError):
                running_totals[group_key] += float(val)

        row_totals[original_idx] = running_totals[group_key]

    # Emit rows in their original order with the running total appended.
    output_rows: list[list[Any]] = []
    for i, row in enumerate(result.rows):
        output_rows.append([*row, row_totals.get(i, 0)])

    return QueryResult(
        columns=[*result.columns, f"{value_column}_running_total"],
        column_types=[*result.column_types, "numeric"],
        rows=output_rows,
        row_count=result.row_count,
        truncated=result.truncated,
        execution_time_ms=0,
    )
323
+
324
+
325
def calculate_percent_of_total(
    result: QueryResult,
    value_column: str,
    group_column: str | None = None,
) -> QueryResult:
    """Add a percentage of total column.

    Args:
        result: Query result to process.
        value_column: Column containing values.
        group_column: Column to group by (calculates percentage within
            each group).

    Returns:
        New QueryResult with a ``<value_column>_pct`` column appended.
        Cells are None when the value is NULL, non-numeric, or the group
        total is zero.

    Raises:
        ValueError: If a named column is missing from the result.
    """
    if not result.rows:
        return QueryResult(
            columns=[*result.columns, f"{value_column}_pct"],
            column_types=[*result.column_types, "numeric"],
            rows=[],
            row_count=0,
            truncated=False,
            execution_time_ms=0,
        )

    try:
        value_idx = result.columns.index(value_column)
    except ValueError as e:
        raise ValueError(f"Column '{value_column}' not found in result") from e

    group_idx = None
    if group_column is not None:
        try:
            group_idx = result.columns.index(group_column)
        except ValueError as e:
            raise ValueError(f"Group column '{group_column}' not found") from e

    # Calculate totals per group; non-numeric values are ignored.
    group_totals: dict[Any, float] = defaultdict(float)

    for row in result.rows:
        group_key = row[group_idx] if group_idx is not None else "__all__"
        val = row[value_idx]

        if val is not None:
            with contextlib.suppress(ValueError, TypeError):
                group_totals[group_key] += float(val)

    # Calculate percentages.  Guard against a zero total only; the
    # previous `total > 0` check also refused well-defined percentages
    # for groups whose total is negative.
    output_rows: list[list[Any]] = []
    for row in result.rows:
        group_key = row[group_idx] if group_idx is not None else "__all__"
        val = row[value_idx]
        total = group_totals[group_key]

        if val is not None and total != 0:
            try:
                pct = (float(val) / total) * 100
            except (ValueError, TypeError):
                pct = None
        else:
            pct = None

        output_rows.append([*row, pct])

    return QueryResult(
        columns=[*result.columns, f"{value_column}_pct"],
        column_types=[*result.column_types, "numeric"],
        rows=output_rows,
        row_count=result.row_count,
        truncated=result.truncated,
        execution_time_ms=0,
    )
398
+
399
+
400
def sort_result(
    result: QueryResult,
    column: str,
    descending: bool = False,
) -> QueryResult:
    """Sort query result by a column.

    Args:
        result: Query result to sort.
        column: Column to sort by.
        descending: Sort in descending order if True.

    Returns:
        New QueryResult with sorted rows.  NULL values sort after all
        non-null values (before them when *descending* is True).

    Raises:
        ValueError: If *column* is not present in the result.
    """
    if not result.rows:
        return result

    try:
        col_idx = result.columns.index(column)
    except ValueError as e:
        raise ValueError(f"Column '{column}' not found in result") from e

    # Sort with None values at the end.  The previous key used `val or 0`,
    # which coerced falsy values ('' / False / 0) to the integer 0 and
    # raised TypeError for string columns containing ''.  When the first
    # tuple element differs (None vs non-None) the second is never
    # compared, so pairing None with itself is safe.
    def sort_key(row: list[Any]) -> tuple[bool, Any]:
        val = row[col_idx]
        return (val is None, val)

    sorted_rows = sorted(result.rows, key=sort_key, reverse=descending)

    return QueryResult(
        columns=result.columns,
        column_types=result.column_types,
        rows=sorted_rows,
        row_count=result.row_count,
        truncated=result.truncated,
        execution_time_ms=0,
    )
439
+
440
+
441
def limit_result(
    result: QueryResult,
    limit: int,
    offset: int = 0,
) -> QueryResult:
    """Limit and offset query result rows.

    Args:
        result: Query result to limit.
        limit: Maximum number of rows to return (negative treated as 0).
        offset: Number of rows to skip (negative treated as 0).

    Returns:
        New QueryResult with the selected window of rows; ``truncated``
        is set when rows were cut off here or already truncated upstream.
    """
    # Clamp negative inputs to zero.
    start = max(offset, 0)
    count = max(limit, 0)

    window = result.rows[start : start + count]
    was_cut = (start + count) < len(result.rows) or result.truncated

    return QueryResult(
        columns=result.columns,
        column_types=result.column_types,
        rows=window,
        row_count=len(window),
        truncated=was_cut,
        execution_time_ms=0,
    )