openms-insight 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ from ..core.base import BaseComponent
8
8
  from ..core.registry import register_component
9
9
  from ..preprocessing.compression import (
10
10
  compute_compression_levels,
11
+ compute_optimal_bins,
11
12
  downsample_2d,
12
13
  downsample_2d_simple,
13
14
  downsample_2d_streaming,
@@ -76,17 +77,23 @@ class Heatmap(BaseComponent):
76
77
  interactivity: Optional[Dict[str, str]] = None,
77
78
  cache_path: str = ".",
78
79
  regenerate_cache: bool = False,
79
- min_points: int = 20000,
80
- x_bins: int = 400,
81
- y_bins: int = 50,
80
+ min_points: int = 10000,
81
+ display_aspect_ratio: float = 16 / 9,
82
+ x_bins: Optional[int] = None,
83
+ y_bins: Optional[int] = None,
82
84
  zoom_identifier: str = "heatmap_zoom",
83
85
  title: Optional[str] = None,
84
86
  x_label: Optional[str] = None,
85
87
  y_label: Optional[str] = None,
86
88
  colorscale: str = "Portland",
89
+ reversescale: bool = False,
87
90
  use_simple_downsample: bool = False,
88
91
  use_streaming: bool = True,
89
92
  categorical_filters: Optional[List[str]] = None,
93
+ category_column: Optional[str] = None,
94
+ category_colors: Optional[Dict[str, str]] = None,
95
+ log_scale: bool = True,
96
+ intensity_label: Optional[str] = None,
90
97
  **kwargs,
91
98
  ):
92
99
  """
@@ -106,10 +113,17 @@ class Heatmap(BaseComponent):
106
113
  point's value in the corresponding column.
107
114
  cache_path: Base path for cache storage. Default "." (current dir).
108
115
  regenerate_cache: If True, regenerate cache even if valid cache exists.
109
- min_points: Target size for smallest compression level and
110
- threshold for level selection (default: 20000)
111
- x_bins: Number of bins along x-axis for downsampling (default: 400)
112
- y_bins: Number of bins along y-axis for downsampling (default: 50)
116
+ min_points: Target number of points to display (default: 10000).
117
+ Cache levels are built at 2× this value; final downsample
118
+ at render time reduces to exactly min_points.
119
+ display_aspect_ratio: Expected display width/height ratio for
120
+ optimal bin computation during caching (default: 16/9).
121
+ At render time, the actual zoom region's aspect ratio is used.
122
+ x_bins: Number of bins along x-axis for downsampling. If None
123
+ (default), auto-computed from display_aspect_ratio such that
124
+ x_bins × y_bins ≈ 2×min_points with even spatial distribution.
125
+ y_bins: Number of bins along y-axis for downsampling. If None
126
+ (default), auto-computed from display_aspect_ratio.
113
127
  zoom_identifier: State key for storing zoom range (default: 'heatmap_zoom')
114
128
  title: Heatmap title displayed above the plot
115
129
  x_label: X-axis label (defaults to x_column)
@@ -124,12 +138,25 @@ class Heatmap(BaseComponent):
124
138
  are sent to the client regardless of filter selection. Should be
125
139
  used for filters with a small number of unique values (<20).
126
140
  Example: ['im_dimension'] for ion mobility filtering.
141
+ category_column: Optional column name for categorical coloring.
142
+ When provided, points are colored by discrete category values
143
+ instead of the continuous intensity colorscale. Useful for
144
+ condition-based heatmaps (e.g., coloring by sample group).
145
+ category_colors: Optional mapping of category values to colors.
146
+ Keys should match values in category_column.
147
+ Values should be CSS color strings (e.g., '#FF0000', 'red').
148
+ If not provided, default Plotly colors will be used.
149
+ log_scale: If True (default), apply log10 transformation to intensity
150
+ values for color mapping. Set to False for linear color mapping.
151
+ intensity_label: Custom label for the colorbar. Default is "Intensity".
152
+ Useful when displaying non-intensity values like scores or counts.
127
153
  **kwargs: Additional configuration options
128
154
  """
129
155
  self._x_column = x_column
130
156
  self._y_column = y_column
131
157
  self._intensity_column = intensity_column
132
158
  self._min_points = min_points
159
+ self._display_aspect_ratio = display_aspect_ratio
133
160
  self._x_bins = x_bins
134
161
  self._y_bins = y_bins
135
162
  self._zoom_identifier = zoom_identifier
@@ -137,7 +164,12 @@ class Heatmap(BaseComponent):
137
164
  self._x_label = x_label or x_column
138
165
  self._y_label = y_label or y_column
139
166
  self._colorscale = colorscale
167
+ self._reversescale = reversescale
140
168
  self._use_simple_downsample = use_simple_downsample
169
+ self._category_column = category_column
170
+ self._category_colors = category_colors or {}
171
+ self._log_scale = log_scale
172
+ self._intensity_label = intensity_label
141
173
  self._use_streaming = use_streaming
142
174
  self._categorical_filters = categorical_filters or []
143
175
 
@@ -155,6 +187,7 @@ class Heatmap(BaseComponent):
155
187
  y_column=y_column,
156
188
  intensity_column=intensity_column,
157
189
  min_points=min_points,
190
+ display_aspect_ratio=display_aspect_ratio,
158
191
  x_bins=x_bins,
159
192
  y_bins=y_bins,
160
193
  zoom_identifier=zoom_identifier,
@@ -165,6 +198,8 @@ class Heatmap(BaseComponent):
165
198
  use_simple_downsample=use_simple_downsample,
166
199
  use_streaming=use_streaming,
167
200
  categorical_filters=categorical_filters,
201
+ category_column=category_column,
202
+ category_colors=category_colors,
168
203
  **kwargs,
169
204
  )
170
205
 
@@ -180,6 +215,7 @@ class Heatmap(BaseComponent):
180
215
  "y_column": self._y_column,
181
216
  "intensity_column": self._intensity_column,
182
217
  "min_points": self._min_points,
218
+ "display_aspect_ratio": self._display_aspect_ratio,
183
219
  "x_bins": self._x_bins,
184
220
  "y_bins": self._y_bins,
185
221
  "use_simple_downsample": self._use_simple_downsample,
@@ -190,6 +226,10 @@ class Heatmap(BaseComponent):
190
226
  "x_label": self._x_label,
191
227
  "y_label": self._y_label,
192
228
  "colorscale": self._colorscale,
229
+ "category_column": self._category_column,
230
+ "log_scale": self._log_scale,
231
+ "intensity_label": self._intensity_label,
232
+ # Note: category_colors is render-time styling, doesn't affect cache
193
233
  }
194
234
 
195
235
  def _restore_cache_config(self, config: Dict[str, Any]) -> None:
@@ -197,7 +237,10 @@ class Heatmap(BaseComponent):
197
237
  self._x_column = config.get("x_column")
198
238
  self._y_column = config.get("y_column")
199
239
  self._intensity_column = config.get("intensity_column", "intensity")
200
- self._min_points = config.get("min_points", 20000)
240
+ self._min_points = config.get("min_points", 10000)
241
+ self._display_aspect_ratio = config.get("display_aspect_ratio", 16 / 9)
242
+ # x_bins/y_bins are computed during preprocessing and stored in cache
243
+ # Fallback to old defaults for backward compatibility with old caches
201
244
  self._x_bins = config.get("x_bins", 400)
202
245
  self._y_bins = config.get("y_bins", 50)
203
246
  self._use_simple_downsample = config.get("use_simple_downsample", False)
@@ -208,6 +251,10 @@ class Heatmap(BaseComponent):
208
251
  self._x_label = config.get("x_label", self._x_column)
209
252
  self._y_label = config.get("y_label", self._y_column)
210
253
  self._colorscale = config.get("colorscale", "Portland")
254
+ self._category_column = config.get("category_column")
255
+ self._log_scale = config.get("log_scale", True)
256
+ self._intensity_label = config.get("intensity_label")
257
+ # category_colors is not stored in cache (render-time styling)
211
258
 
212
259
  def get_state_dependencies(self) -> list:
213
260
  """
@@ -242,14 +289,116 @@ class Heatmap(BaseComponent):
242
289
  else:
243
290
  self._preprocess_eager()
244
291
 
292
def _build_cascading_levels(
    self,
    source_data: pl.LazyFrame,
    level_sizes: list,
    x_range: tuple,
    y_range: tuple,
    cache_dir,
    prefix: str = "level",
) -> dict:
    """
    Build cascading compression levels from source data.

    The full-resolution data is written first; every compressed level is
    then derived from the parquet file of the next larger level instead of
    from the raw input, so the raw data is only scanned for the initial
    count and the full-resolution write.

    NOTE(review): the cascade is intended to give the same result as
    downsampling each level from raw data, on the assumption that the
    downsampling helpers keep the highest-intensity points per bin -
    confirm against the helpers in ..preprocessing.compression.

    Args:
        source_data: LazyFrame with raw/filtered data
        level_sizes: List of target sizes for compressed levels (smallest first)
        x_range: (x_min, x_max) for consistent bin boundaries
        y_range: (y_min, y_max) for consistent bin boundaries
        cache_dir: Directory (path-like supporting ``/``) that receives the
            parquet files
        prefix: Filename prefix (e.g., "level" or "cat_level_im_0")

    Returns:
        Dict with level LazyFrames keyed by "{prefix}_{idx}" (index 0 is the
        smallest level, index len(level_sizes) is full resolution) and the
        level count under "num_levels"
    """
    import sys

    # Every level is written sorted by (x, y); hoisted because it never changes.
    sort_columns = [self._x_column, self._y_column]
    num_compressed = len(level_sizes)

    # Row count of the unreduced input
    total = source_data.select(pl.len()).collect().item()

    # First: save full resolution as the largest level
    full_res_path = cache_dir / f"{prefix}_{num_compressed}.parquet"
    source_data.sort(sort_columns).sink_parquet(full_res_path, compression="zstd")
    print(
        f"[HEATMAP] Saved {prefix}_{num_compressed} ({total:,} pts)",
        file=sys.stderr,
    )

    # Start cascading from the file that was just written
    current_source = pl.scan_parquet(full_res_path)
    current_size = total

    # Build compressed levels from largest to smallest
    for i, target_size in enumerate(reversed(level_sizes)):
        level_idx = num_compressed - 1 - i
        level_path = cache_dir / f"{prefix}_{level_idx}.parquet"

        if target_size >= current_size:
            # Source already fits the target - pass it through unchanged
            level = current_source
        elif self._use_simple_downsample:
            level = downsample_2d_simple(
                current_source,
                max_points=target_size,
                intensity_column=self._intensity_column,
            )
        else:
            level = downsample_2d_streaming(
                current_source,
                max_points=target_size,
                x_column=self._x_column,
                y_column=self._y_column,
                intensity_column=self._intensity_column,
                x_bins=self._x_bins,
                y_bins=self._y_bins,
                x_range=x_range,
                y_range=y_range,
            )

        # Sort and save immediately so the next level can read it from disk
        level.sort(sort_columns).sink_parquet(level_path, compression="zstd")

        print(
            f"[HEATMAP] Saved {prefix}_{level_idx} (target {target_size:,} pts)",
            file=sys.stderr,
        )

        # Next iteration uses this level as source (cascading)
        current_source = pl.scan_parquet(level_path)
        # Track an upper bound on the rows actually present. Assigning
        # target_size unconditionally would inflate the bound after a
        # pass-through and force a needless (possibly lossy) downsample
        # of data that already fits a later target.
        current_size = min(current_size, target_size)

    # Load all levels back as LazyFrames, smallest level first
    result = {
        f"{prefix}_{idx}": pl.scan_parquet(cache_dir / f"{prefix}_{idx}.parquet")
        for idx in range(num_compressed + 1)
    }
    result["num_levels"] = num_compressed + 1

    return result
390
+
245
391
  def _preprocess_with_categorical_filters(self) -> None:
246
392
  """
247
- Preprocess with per-filter-value compression levels.
393
+ Preprocess with per-filter-value compression levels using cascading.
248
394
 
249
395
  For each unique value of each categorical filter, creates separate
250
- compression levels. This ensures that when a filter is applied at
251
- render time, the resulting data has ~min_points regardless of the
252
- filter value selected.
396
+ compression levels using cascading (building smaller levels from larger).
397
+ This ensures that when a filter is applied at render time, the resulting
398
+ data has ~min_points regardless of the filter value selected.
399
+
400
+ Uses cascading downsampling for efficiency - each level is built from
401
+ the previous larger level rather than from raw data.
253
402
 
254
403
  Data is sorted by x, y columns for efficient range query predicate pushdown.
255
404
 
@@ -261,6 +410,7 @@ class Heatmap(BaseComponent):
261
410
  import sys
262
411
 
263
412
  # Get data ranges (for the full dataset)
413
+ # These ranges are used for ALL levels to ensure consistent binning
264
414
  x_range, y_range = get_data_range(
265
415
  self._raw_data,
266
416
  self._x_column,
@@ -269,10 +419,31 @@ class Heatmap(BaseComponent):
269
419
  self._preprocessed_data["x_range"] = x_range
270
420
  self._preprocessed_data["y_range"] = y_range
271
421
 
422
+ # Compute optimal bins if not provided
423
+ # Cache at 2×min_points, use display_aspect_ratio for bin computation
424
+ cache_target = 2 * self._min_points
425
+ if self._x_bins is None or self._y_bins is None:
426
+ # Use display aspect ratio (not data aspect ratio) for optimal bins
427
+ self._x_bins, self._y_bins = compute_optimal_bins(
428
+ cache_target,
429
+ (0, self._display_aspect_ratio), # Fake x_range matching aspect
430
+ (0, 1.0), # Fake y_range
431
+ )
432
+ print(
433
+ f"[HEATMAP] Auto-computed bins: {self._x_bins}x{self._y_bins} "
434
+ f"= {self._x_bins * self._y_bins:,} (cache target: {cache_target:,}, "
435
+ f"display aspect: {self._display_aspect_ratio:.2f})",
436
+ file=sys.stderr,
437
+ )
438
+
272
439
  # Get total count
273
440
  total = self._raw_data.select(pl.len()).collect().item()
274
441
  self._preprocessed_data["total"] = total
275
442
 
443
+ # Create cache directory for immediate level saving
444
+ cache_dir = self._cache_dir / "preprocessed"
445
+ cache_dir.mkdir(parents=True, exist_ok=True)
446
+
276
447
  # Store metadata about categorical filters
277
448
  self._preprocessed_data["has_categorical_filters"] = True
278
449
  self._preprocessed_data["categorical_filter_values"] = {}
@@ -309,7 +480,7 @@ class Heatmap(BaseComponent):
309
480
  unique_values
310
481
  )
311
482
 
312
- # Create compression levels for each filter value
483
+ # Create compression levels for each filter value using cascading
313
484
  for filter_value in unique_values:
314
485
  # Filter data to this value
315
486
  filtered_data = self._raw_data.filter(
@@ -317,10 +488,8 @@ class Heatmap(BaseComponent):
317
488
  )
318
489
  filtered_total = filtered_data.select(pl.len()).collect().item()
319
490
 
320
- # Compute level sizes for this filtered subset
321
- level_sizes = compute_compression_levels(
322
- self._min_points, filtered_total
323
- )
491
+ # Compute level sizes for this filtered subset (2× for cache buffer)
492
+ level_sizes = compute_compression_levels(cache_target, filtered_total)
324
493
 
325
494
  print(
326
495
  f"[HEATMAP] Value {filter_value}: {filtered_total:,} pts → levels {level_sizes}",
@@ -332,94 +501,71 @@ class Heatmap(BaseComponent):
332
501
  f"cat_level_sizes_{filter_id}_{filter_value}"
333
502
  ] = level_sizes
334
503
 
335
- # Build each compressed level
336
- for level_idx, target_size in enumerate(level_sizes):
337
- # If target size equals total, skip downsampling - use all data
338
- if target_size >= filtered_total:
339
- level = filtered_data
340
- elif self._use_simple_downsample:
341
- level = downsample_2d_simple(
342
- filtered_data,
343
- max_points=target_size,
344
- intensity_column=self._intensity_column,
345
- )
346
- else:
347
- level = downsample_2d_streaming(
348
- filtered_data,
349
- max_points=target_size,
350
- x_column=self._x_column,
351
- y_column=self._y_column,
352
- intensity_column=self._intensity_column,
353
- x_bins=self._x_bins,
354
- y_bins=self._y_bins,
355
- x_range=x_range,
356
- y_range=y_range,
357
- )
358
-
359
- # Sort by x, y for efficient range query predicate pushdown
360
- level = level.sort([self._x_column, self._y_column])
361
- # Store LazyFrame for streaming to disk
362
- level_key = f"cat_level_{filter_id}_{filter_value}_{level_idx}"
363
- self._preprocessed_data[level_key] = level # Keep lazy
364
-
365
- # Add full resolution as final level (for zoom fallback)
366
- # Also sorted for consistent predicate pushdown behavior
367
- num_compressed = len(level_sizes)
368
- full_res_key = f"cat_level_{filter_id}_{filter_value}_{num_compressed}"
369
- self._preprocessed_data[full_res_key] = filtered_data.sort(
370
- [self._x_column, self._y_column]
504
+ # Build cascading levels using helper
505
+ prefix = f"cat_level_{filter_id}_{filter_value}"
506
+ levels_result = self._build_cascading_levels(
507
+ source_data=filtered_data,
508
+ level_sizes=level_sizes,
509
+ x_range=x_range,
510
+ y_range=y_range,
511
+ cache_dir=cache_dir,
512
+ prefix=prefix,
371
513
  )
372
- self._preprocessed_data[
373
- f"cat_num_levels_{filter_id}_{filter_value}"
374
- ] = num_compressed + 1
514
+
515
+ # Copy results to preprocessed_data
516
+ for key, value in levels_result.items():
517
+ if key == "num_levels":
518
+ self._preprocessed_data[
519
+ f"cat_num_levels_{filter_id}_{filter_value}"
520
+ ] = value
521
+ else:
522
+ self._preprocessed_data[key] = value
375
523
 
376
524
  # Also create global levels for when no categorical filter is selected
377
- # (fallback to standard behavior)
378
- level_sizes = compute_compression_levels(self._min_points, total)
525
+ # (fallback to standard behavior) - using cascading with 2× cache buffer
526
+ level_sizes = compute_compression_levels(cache_target, total)
379
527
  self._preprocessed_data["level_sizes"] = level_sizes
380
528
 
381
- for i, size in enumerate(level_sizes):
382
- # If target size equals total, skip downsampling - use all data
383
- if size >= total:
384
- level = self._raw_data
385
- elif self._use_simple_downsample:
386
- level = downsample_2d_simple(
387
- self._raw_data,
388
- max_points=size,
389
- intensity_column=self._intensity_column,
390
- )
529
+ # Build global cascading levels using helper
530
+ levels_result = self._build_cascading_levels(
531
+ source_data=self._raw_data,
532
+ level_sizes=level_sizes,
533
+ x_range=x_range,
534
+ y_range=y_range,
535
+ cache_dir=cache_dir,
536
+ prefix="level",
537
+ )
538
+
539
+ # Copy results to preprocessed_data
540
+ for key, value in levels_result.items():
541
+ if key == "num_levels":
542
+ self._preprocessed_data["num_levels"] = value
391
543
  else:
392
- level = downsample_2d_streaming(
393
- self._raw_data,
394
- max_points=size,
395
- x_column=self._x_column,
396
- y_column=self._y_column,
397
- intensity_column=self._intensity_column,
398
- x_bins=self._x_bins,
399
- y_bins=self._y_bins,
400
- x_range=x_range,
401
- y_range=y_range,
402
- )
403
- # Sort by x, y for efficient range query predicate pushdown
404
- level = level.sort([self._x_column, self._y_column])
405
- self._preprocessed_data[f"level_{i}"] = level # Keep lazy
544
+ self._preprocessed_data[key] = value
406
545
 
407
- # Add full resolution as final level (for zoom fallback)
408
- # Also sorted for consistent predicate pushdown behavior
409
- num_compressed = len(level_sizes)
410
- self._preprocessed_data[f"level_{num_compressed}"] = self._raw_data.sort(
411
- [self._x_column, self._y_column]
412
- )
413
- self._preprocessed_data["num_levels"] = num_compressed + 1
546
+ # Mark that files are already saved
547
+ self._preprocessed_data["_files_already_saved"] = True
414
548
 
415
549
  def _preprocess_streaming(self) -> None:
416
550
  """
417
- Streaming preprocessing - levels stay lazy through caching.
551
+ Streaming preprocessing with cascading - builds smaller levels from larger.
552
+
553
+ Uses cascading downsampling: each level is built from the previous larger
554
+ level rather than from raw data. This is more efficient (raw data read once)
555
+ and produces identical results because the downsampling algorithm keeps
556
+ the TOP N highest-intensity points per bin - points that survive at a larger
557
+ level will also be selected at smaller levels.
558
+
559
+ Levels are saved to disk immediately after creation, then read back as the
560
+ source for the next smaller level. This keeps memory low while enabling
561
+ cascading.
418
562
 
419
- Builds lazy query plans that are streamed to disk via sink_parquet().
420
563
  Data is sorted by x, y columns for efficient range query predicate pushdown.
421
564
  """
565
+ import sys
566
+
422
567
  # Get data ranges (minimal collect - just 4 values)
568
+ # These ranges are used for ALL levels to ensure consistent binning
423
569
  x_range, y_range = get_data_range(
424
570
  self._raw_data,
425
571
  self._x_column,
@@ -428,55 +574,55 @@ class Heatmap(BaseComponent):
428
574
  self._preprocessed_data["x_range"] = x_range
429
575
  self._preprocessed_data["y_range"] = y_range
430
576
 
577
+ # Compute optimal bins if not provided
578
+ # Cache at 2×min_points, use display_aspect_ratio for bin computation
579
+ cache_target = 2 * self._min_points
580
+ if self._x_bins is None or self._y_bins is None:
581
+ # Use display aspect ratio (not data aspect ratio) for optimal bins
582
+ # This ensures even distribution in the expected display dimensions
583
+ self._x_bins, self._y_bins = compute_optimal_bins(
584
+ cache_target,
585
+ (0, self._display_aspect_ratio), # Fake x_range matching aspect
586
+ (0, 1.0), # Fake y_range
587
+ )
588
+ print(
589
+ f"[HEATMAP] Auto-computed bins: {self._x_bins}x{self._y_bins} "
590
+ f"= {self._x_bins * self._y_bins:,} (cache target: {cache_target:,}, "
591
+ f"display aspect: {self._display_aspect_ratio:.2f})",
592
+ file=sys.stderr,
593
+ )
594
+
431
595
  # Get total count
432
596
  total = self._raw_data.select(pl.len()).collect().item()
433
597
  self._preprocessed_data["total"] = total
434
598
 
435
- # Compute target sizes for levels
436
- level_sizes = compute_compression_levels(self._min_points, total)
599
+ # Compute target sizes for levels (use 2×min_points for smallest cache level)
600
+ level_sizes = compute_compression_levels(cache_target, total)
437
601
  self._preprocessed_data["level_sizes"] = level_sizes
438
602
 
439
- # Build and collect each level
440
- self._preprocessed_data["levels"] = []
603
+ # Create cache directory for immediate level saving
604
+ cache_dir = self._cache_dir / "preprocessed"
605
+ cache_dir.mkdir(parents=True, exist_ok=True)
606
+
607
+ # Build cascading levels using helper
608
+ levels_result = self._build_cascading_levels(
609
+ source_data=self._raw_data,
610
+ level_sizes=level_sizes,
611
+ x_range=x_range,
612
+ y_range=y_range,
613
+ cache_dir=cache_dir,
614
+ prefix="level",
615
+ )
441
616
 
442
- for i, size in enumerate(level_sizes):
443
- # If target size equals total, skip downsampling - use all data
444
- if size >= total:
445
- level = self._raw_data
446
- elif self._use_simple_downsample:
447
- level = downsample_2d_simple(
448
- self._raw_data,
449
- max_points=size,
450
- intensity_column=self._intensity_column,
451
- )
617
+ # Copy results to preprocessed_data
618
+ for key, value in levels_result.items():
619
+ if key == "num_levels":
620
+ self._preprocessed_data["num_levels"] = value
452
621
  else:
453
- level = downsample_2d_streaming(
454
- self._raw_data,
455
- max_points=size,
456
- x_column=self._x_column,
457
- y_column=self._y_column,
458
- intensity_column=self._intensity_column,
459
- x_bins=self._x_bins,
460
- y_bins=self._y_bins,
461
- x_range=x_range,
462
- y_range=y_range,
463
- )
464
- # Sort by x, y for efficient range query predicate pushdown
465
- # This clusters spatially close points together in row groups
466
- level = level.sort([self._x_column, self._y_column])
467
- # Store LazyFrame for streaming to disk
468
- # Base class will use sink_parquet() to stream without full materialization
469
- self._preprocessed_data[f"level_{i}"] = level # Keep lazy
622
+ self._preprocessed_data[key] = value
470
623
 
471
- # Add full resolution as final level (for zoom fallback)
472
- # Also sorted for consistent predicate pushdown behavior
473
- num_compressed = len(level_sizes)
474
- self._preprocessed_data[f"level_{num_compressed}"] = self._raw_data.sort(
475
- [self._x_column, self._y_column]
476
- )
477
-
478
- # Store number of levels for reconstruction (includes full resolution)
479
- self._preprocessed_data["num_levels"] = num_compressed + 1
624
+ # Mark that files are already saved (base class should skip saving)
625
+ self._preprocessed_data["_files_already_saved"] = True
480
626
 
481
627
  def _preprocess_eager(self) -> None:
482
628
  """
@@ -486,6 +632,8 @@ class Heatmap(BaseComponent):
486
632
  downsampling for better spatial distribution.
487
633
  Data is sorted by x, y columns for efficient range query predicate pushdown.
488
634
  """
635
+ import sys
636
+
489
637
  # Get data ranges
490
638
  x_range, y_range = get_data_range(
491
639
  self._raw_data,
@@ -495,12 +643,29 @@ class Heatmap(BaseComponent):
495
643
  self._preprocessed_data["x_range"] = x_range
496
644
  self._preprocessed_data["y_range"] = y_range
497
645
 
646
+ # Compute optimal bins if not provided
647
+ # Cache at 2×min_points, use display_aspect_ratio for bin computation
648
+ cache_target = 2 * self._min_points
649
+ if self._x_bins is None or self._y_bins is None:
650
+ # Use display aspect ratio (not data aspect ratio) for optimal bins
651
+ self._x_bins, self._y_bins = compute_optimal_bins(
652
+ cache_target,
653
+ (0, self._display_aspect_ratio), # Fake x_range matching aspect
654
+ (0, 1.0), # Fake y_range
655
+ )
656
+ print(
657
+ f"[HEATMAP] Auto-computed bins: {self._x_bins}x{self._y_bins} "
658
+ f"= {self._x_bins * self._y_bins:,} (cache target: {cache_target:,}, "
659
+ f"display aspect: {self._display_aspect_ratio:.2f})",
660
+ file=sys.stderr,
661
+ )
662
+
498
663
  # Get total count
499
664
  total = self._raw_data.select(pl.len()).collect().item()
500
665
  self._preprocessed_data["total"] = total
501
666
 
502
- # Compute compression level target sizes
503
- level_sizes = compute_compression_levels(self._min_points, total)
667
+ # Compute compression level target sizes (2× for cache buffer)
668
+ level_sizes = compute_compression_levels(cache_target, total)
504
669
  self._preprocessed_data["level_sizes"] = level_sizes
505
670
 
506
671
  # Build levels from largest to smallest
@@ -736,10 +901,18 @@ class Heatmap(BaseComponent):
736
901
  if count >= self._min_points:
737
902
  # This level has enough detail
738
903
  if count > self._min_points:
739
- # Over limit - downsample to stay at/under max
740
- # Use ZOOM range for binning (not global) to avoid sparse bins
904
+ # Over limit - downsample to exactly min_points
905
+ # Compute optimal bins from ACTUAL zoom region aspect ratio
741
906
  zoom_x_range = (x0, x1)
742
907
  zoom_y_range = (y0, y1)
908
+ render_x_bins, render_y_bins = compute_optimal_bins(
909
+ self._min_points, zoom_x_range, zoom_y_range
910
+ )
911
+ print(
912
+ f"[HEATMAP] Render downsample: {count:,} → {self._min_points:,} pts "
913
+ f"(bins: {render_x_bins}x{render_y_bins})",
914
+ file=sys.stderr,
915
+ )
743
916
  if self._use_streaming or self._use_simple_downsample:
744
917
  if self._use_simple_downsample:
745
918
  return downsample_2d_simple(
@@ -754,8 +927,8 @@ class Heatmap(BaseComponent):
754
927
  x_column=self._x_column,
755
928
  y_column=self._y_column,
756
929
  intensity_column=self._intensity_column,
757
- x_bins=self._x_bins,
758
- y_bins=self._y_bins,
930
+ x_bins=render_x_bins,
931
+ y_bins=render_y_bins,
759
932
  x_range=zoom_x_range,
760
933
  y_range=zoom_y_range,
761
934
  ).collect()
@@ -766,8 +939,8 @@ class Heatmap(BaseComponent):
766
939
  x_column=self._x_column,
767
940
  y_column=self._y_column,
768
941
  intensity_column=self._intensity_column,
769
- x_bins=self._x_bins,
770
- y_bins=self._y_bins,
942
+ x_bins=render_x_bins,
943
+ y_bins=render_y_bins,
771
944
  ).collect()
772
945
  return filtered
773
946
 
@@ -794,12 +967,15 @@ class Heatmap(BaseComponent):
794
967
 
795
968
  zoom = state.get(self._zoom_identifier)
796
969
 
797
- # Build columns to select
970
+ # Build columns to select (filter out None values)
798
971
  columns_to_select = [
799
- self._x_column,
800
- self._y_column,
801
- self._intensity_column,
972
+ col
973
+ for col in [self._x_column, self._y_column, self._intensity_column]
974
+ if col is not None
802
975
  ]
976
+ # Include category column if specified
977
+ if self._category_column and self._category_column not in columns_to_select:
978
+ columns_to_select.append(self._category_column)
803
979
  # Include columns needed for interactivity
804
980
  if self._interactivity:
805
981
  for col in self._interactivity.values():
@@ -852,17 +1028,25 @@ class Heatmap(BaseComponent):
852
1028
  columns=columns_to_select,
853
1029
  filter_defaults=self._filter_defaults,
854
1030
  )
855
- # Sort by intensity ascending so high-intensity points are drawn on top
856
- df_pandas = df_pandas.sort_values(self._intensity_column).reset_index(
857
- drop=True
858
- )
1031
+ # Sort by intensity ascending so high-intensity points are drawn on top (scattergl)
1032
+ if (
1033
+ self._intensity_column
1034
+ and self._intensity_column in df_pandas.columns
1035
+ ):
1036
+ df_pandas = df_pandas.sort_values(
1037
+ self._intensity_column, ascending=True
1038
+ ).reset_index(drop=True)
859
1039
  else:
860
1040
  # No filters to apply - levels already filtered by categorical filter
861
1041
  schema_names = data.collect_schema().names()
862
1042
  available_cols = [c for c in columns_to_select if c in schema_names]
863
1043
  df_polars = data.select(available_cols).collect()
864
- # Sort by intensity ascending so high-intensity points are drawn on top
865
- df_polars = df_polars.sort(self._intensity_column)
1044
+ # Sort by intensity ascending so high-intensity points are drawn on top (scattergl)
1045
+ if (
1046
+ self._intensity_column
1047
+ and self._intensity_column in df_polars.columns
1048
+ ):
1049
+ df_polars = df_polars.sort(self._intensity_column)
866
1050
  data_hash = compute_dataframe_hash(df_polars)
867
1051
  df_pandas = df_polars.to_pandas()
868
1052
  else:
@@ -874,8 +1058,9 @@ class Heatmap(BaseComponent):
874
1058
  # Select only needed columns
875
1059
  available_cols = [c for c in columns_to_select if c in df_polars.columns]
876
1060
  df_polars = df_polars.select(available_cols)
877
- # Sort by intensity ascending so high-intensity points are drawn on top
878
- df_polars = df_polars.sort(self._intensity_column)
1061
+ # Sort by intensity ascending so high-intensity points are drawn on top (scattergl)
1062
+ if self._intensity_column and self._intensity_column in df_polars.columns:
1063
+ df_polars = df_polars.sort(self._intensity_column)
879
1064
  print(
880
1065
  f"[HEATMAP] Selected {len(df_polars)} pts for zoom, levels={level_sizes}",
881
1066
  file=sys.stderr,
@@ -903,6 +1088,7 @@ class Heatmap(BaseComponent):
903
1088
  "xLabel": self._x_label,
904
1089
  "yLabel": self._y_label,
905
1090
  "colorscale": self._colorscale,
1091
+ "reversescale": self._reversescale,
906
1092
  "zoomIdentifier": self._zoom_identifier,
907
1093
  "interactivity": self._interactivity,
908
1094
  }
@@ -910,6 +1096,17 @@ class Heatmap(BaseComponent):
910
1096
  if self._title:
911
1097
  args["title"] = self._title
912
1098
 
1099
+ # Add category column configuration for categorical coloring mode
1100
+ if self._category_column:
1101
+ args["categoryColumn"] = self._category_column
1102
+ if self._category_colors:
1103
+ args["categoryColors"] = self._category_colors
1104
+
1105
+ # Add log scale and intensity label configuration
1106
+ args["logScale"] = self._log_scale
1107
+ if self._intensity_label:
1108
+ args["intensityLabel"] = self._intensity_label
1109
+
913
1110
  # Add any extra config options
914
1111
  args.update(self._config)
915
1112