tellaro-query-language 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tql/stats_evaluator.py CHANGED
@@ -147,33 +147,56 @@ class TQLStatsEvaluator:
         return {"type": "multiple_aggregations", "results": results}
 
     def _grouped_aggregation(
-        self, records: List[Dict[str, Any]], aggregations: List[Dict[str, Any]], group_by_fields: List[str]
+        self, records: List[Dict[str, Any]], aggregations: List[Dict[str, Any]], group_by_fields: List[Any]
     ) -> Dict[str, Any]:
         """Perform aggregation with grouping.
 
         Args:
             records: Records to aggregate
             aggregations: Aggregation specifications
-            group_by_fields: Fields to group by
+            group_by_fields: Fields to group by (can be strings or dicts with bucket_size)
 
         Returns:
             Grouped aggregation results
         """
+        # Normalize group_by_fields to handle both old (string) and new (dict) formats
+        normalized_fields = []
+        for field in group_by_fields:
+            if isinstance(field, str):
+                # Old format: just field name
+                normalized_fields.append({"field": field, "bucket_size": None})
+            elif isinstance(field, dict):
+                # New format: {"field": "name", "bucket_size": N}
+                normalized_fields.append(field)
+            else:
+                # Shouldn't happen but handle gracefully
+                normalized_fields.append({"field": str(field), "bucket_size": None})
+
         # Group records
         groups = defaultdict(list)
+        key_mapping = {}  # Maps hashable key to original key
+
         for record in records:
             # Build group key
             key_parts = []
-            for field in group_by_fields:
-                value = self._get_field_value(record, field)
-                key_parts.append((field, value))
-            key = tuple(key_parts)
-            groups[key].append(record)
+            for field_spec in normalized_fields:
+                field_name = field_spec["field"]
+                value = self._get_field_value(record, field_name)
+                key_parts.append((field_name, value))
+
+            # Create hashable key - convert unhashable values to strings
+            hashable_key = self._make_hashable_key(key_parts)
+            groups[hashable_key].append(record)
+
+            # Store mapping from hashable key to original key
+            if hashable_key not in key_mapping:
+                key_mapping[hashable_key] = key_parts
 
         # Calculate aggregations for each group
         results = []
-        for key, group_records in groups.items():
-            group_result: Dict[str, Any] = {"key": dict(key), "doc_count": len(group_records)}
+        for hashable_key, group_records in groups.items():
+            original_key = key_mapping[hashable_key]
+            group_result: Dict[str, Any] = {"key": dict(original_key), "doc_count": len(group_records)}
 
             if len(aggregations) == 1:
                 # Single aggregation
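
A minimal sketch of what the normalization step above accepts and produces (the field names are hypothetical, used only for illustration):

    # Old format: bare field names
    group_by_fields = ["source.ip", "event.action"]
    # normalizes to:
    # [{"field": "source.ip", "bucket_size": None},
    #  {"field": "event.action", "bucket_size": None}]

    # New format: dicts carrying an optional per-field bucket limit
    group_by_fields = [{"field": "source.ip", "bucket_size": 10}]
    # passes through the isinstance(field, dict) branch unchanged

Because every spec is normalized before grouping, mixed lists of strings and dicts build the same kind of group key.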
@@ -194,7 +217,23 @@ class TQLStatsEvaluator:
         # Apply modifiers (top/bottom)
         results = self._apply_modifiers(results, aggregations)
 
-        return {"type": "grouped_aggregation", "group_by": group_by_fields, "results": results}
+        # Apply per-field bucket limits
+        results = self._apply_bucket_limits(results, normalized_fields)
+
+        # Extract just the field names for the response to ensure compatibility
+        # with frontend code that expects strings, not dictionaries
+        group_by_field_names = []
+        for field in group_by_fields:
+            if isinstance(field, str):
+                group_by_field_names.append(field)
+            elif isinstance(field, dict) and "field" in field:
+                group_by_field_names.append(field["field"])
+            else:
+                # Fallback for unexpected formats
+                group_by_field_names.append(str(field))
+
+        # Return group_by fields as strings for frontend compatibility
+        return {"type": "grouped_aggregation", "group_by": group_by_field_names, "results": results}
 
     def _calculate_aggregation(  # noqa: C901
         self, records: List[Dict[str, Any]], agg_spec: Dict[str, Any]
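
The net effect of this hunk is backward compatibility: "group_by" in the response is always a list of strings, even when the caller supplied dict specs. A hedged sketch of a call and its result shape (the evaluator instance, records, and aggregation specs are assumed placeholders, not values from this package):

    result = evaluator._grouped_aggregation(
        records,
        aggregations=aggs,  # spec shape is defined elsewhere in this module
        group_by_fields=[{"field": "source.ip", "bucket_size": 10}],
    )
    # result["type"] == "grouped_aggregation"
    # result["group_by"] == ["source.ip"]   <- plain strings for the frontend
    # result["results"] is a bucket-limited list of {"key": ..., "doc_count": ...} entries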
@@ -226,7 +265,7 @@ class TQLStatsEvaluator:
         if func == "count":
             return len(values)
         elif func == "unique_count":
-            return len(set(values))
+            return len(self._get_unique_values(values))
         elif func == "sum":
             return sum(self._to_numeric(v) for v in values) if values else 0
         elif func == "min":
@@ -283,7 +322,7 @@ class TQLStatsEvaluator:
             return result
         elif func in ["values", "unique", "cardinality"]:
             # Return unique values from the field
-            unique_values = list(set(values)) if values else []
+            unique_values = self._get_unique_values(values)
             # Sort the values for consistent output
             try:
                 # Try to sort if values are comparable
@@ -332,6 +371,81 @@ class TQLStatsEvaluator:
 
         return results
 
+    def _apply_bucket_limits(
+        self, results: List[Dict[str, Any]], normalized_fields: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Apply per-field bucket size limits to results.
+
+        Args:
+            results: Aggregation results
+            normalized_fields: Group by fields with bucket_size specifications
+
+        Returns:
+            Results with bucket limits applied
+        """
+        # Check if we have any bucket size limits
+        has_limits = any(field.get("bucket_size") is not None for field in normalized_fields)
+        if not has_limits:
+            return results
+
+        # For single-level grouping, apply the limit directly
+        if len(normalized_fields) == 1:
+            bucket_size = normalized_fields[0].get("bucket_size")
+            if bucket_size:
+                # Sort by doc_count (most common pattern) and limit
+                results = sorted(results, key=lambda x: x.get("doc_count", 0), reverse=True)
+                results = results[:bucket_size]
+            return results
+
+        # For multi-level grouping, we need to apply limits hierarchically
+        # First, group results by each level and apply limits
+
+        # Sort all results by doc_count first
+        results = sorted(results, key=lambda x: x.get("doc_count", 0), reverse=True)
+
+        # Build a hierarchical structure to apply limits at each level
+        filtered_results = []
+
+        # Track unique values at each level
+        level_values = {}
+        for level, field_spec in enumerate(normalized_fields):
+            level_values[level] = {}
+
+        for result in results:
+            # Check if this result should be included based on bucket limits at each level
+            should_include = True
+
+            # Build the key path for this result
+            key_path = []
+            for level, field_spec in enumerate(normalized_fields):
+                field_name = field_spec["field"]
+                field_value = result["key"].get(field_name)
+                key_path.append(field_value)
+
+                # For each level, check if we've hit the bucket limit
+                bucket_size = field_spec.get("bucket_size")
+                if bucket_size is not None:
+                    # Build parent key (all fields up to but not including current level)
+                    parent_key = tuple(key_path[:level]) if level > 0 else ()
+
+                    # Initialize tracking for this parent if needed
+                    if parent_key not in level_values[level]:
+                        level_values[level][parent_key] = set()
+
+                    # Check if adding this value would exceed the bucket limit
+                    if field_value not in level_values[level][parent_key]:
+                        if len(level_values[level][parent_key]) >= bucket_size:
+                            should_include = False
+                            break
+                        else:
+                            # Reserve this slot
+                            level_values[level][parent_key].add(field_value)
+
+            if should_include:
+                filtered_results.append(result)
+
+        return filtered_results
+
     def _get_field_value(self, record: Dict[str, Any], field_path: str) -> Any:
         """Get a field value from a record, supporting nested fields.
 
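To make the hierarchical limiting concrete, here is a small worked example under assumed inputs (field names and counts are invented). With normalized_fields = [{"field": "country", "bucket_size": 2}, {"field": "city", "bucket_size": 1}] and results already sorted by doc_count:

    {"key": {"country": "US", "city": "NYC"},    "doc_count": 9}  # kept: 1st country, 1st city under US
    {"key": {"country": "US", "city": "LA"},     "doc_count": 7}  # dropped: US already holds its 1 city
    {"key": {"country": "DE", "city": "Berlin"}, "doc_count": 5}  # kept: 2nd country
    {"key": {"country": "FR", "city": "Paris"},  "doc_count": 3}  # dropped: country limit of 2 reached

Slots are reserved greedily in doc_count order, so the highest-volume buckets at each level survive the cut.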
@@ -453,3 +567,47 @@ class TQLStatsEvaluator:
         rank = count_less / n * 100
 
         return round(rank, 2)
+
+    def _get_unique_values(self, values: List[Any]) -> List[Any]:
+        """Get unique values from a list, handling unhashable types like dicts.
+
+        Args:
+            values: List of values that may contain unhashable types
+
+        Returns:
+            List of unique values
+        """
+        if not values:
+            return []
+
+        # Try the fast path first - use set if all values are hashable
+        try:
+            return list(set(values))
+        except TypeError:
+            # Some values are unhashable, use slower but safe approach
+            unique_values = []
+            for value in values:
+                if value not in unique_values:
+                    unique_values.append(value)
+            return unique_values
+
+    def _make_hashable_key(self, key_parts: List[tuple]) -> tuple:
+        """Convert a key with potentially unhashable values to a hashable key.
+
+        Args:
+            key_parts: List of (field_name, value) tuples
+
+        Returns:
+            Hashable tuple that can be used as a dictionary key
+        """
+        hashable_parts = []
+        for field_name, value in key_parts:
+            try:
+                # Try to hash the value - if it works, use it as-is
+                hash(value)
+                hashable_parts.append((field_name, value))
+            except TypeError:
+                # Value is unhashable (like dict), convert to string representation
+                hashable_parts.append((field_name, str(value)))
+
+        return tuple(hashable_parts)
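
Both helpers exist because grouped field values can be dicts or lists, which cannot go into a set() or serve as dict keys. A quick illustration (evaluator is a hypothetical TQLStatsEvaluator instance; the values are made up):

    vals = [{"id": 1}, {"id": 1}, {"id": 2}]
    set(vals)  # would raise TypeError: unhashable type: 'dict'
    evaluator._get_unique_values(vals)
    # -> [{"id": 1}, {"id": 2}]  (falls back to the O(n^2) membership scan)

    evaluator._make_hashable_key([("tags", ["a", "b"])])
    # -> (("tags", "['a', 'b']"),)  the list is stringified so it can key the defaultdict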