tirex-mirror 2025.11.8__tar.gz → 2025.11.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/util.py +193 -90
  4. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex_mirror.egg-info/PKG-INFO +1 -1
  5. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_util_freq.py +6 -6
  6. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/LICENSE +0 -0
  7. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/LICENSE_MIRROR.txt +0 -0
  8. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/MANIFEST.in +0 -0
  9. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/NOTICE.txt +0 -0
  10. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/README.md +0 -0
  11. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/setup.cfg +0 -0
  12. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/__init__.py +0 -0
  13. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/api_adapter/__init__.py +0 -0
  14. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/api_adapter/forecast.py +0 -0
  15. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/api_adapter/gluon.py +0 -0
  16. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/api_adapter/hf_data.py +0 -0
  17. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/api_adapter/standard_adapter.py +0 -0
  18. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/base.py +0 -0
  19. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/__init__.py +0 -0
  20. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/patcher.py +0 -0
  21. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/slstm/block.py +0 -0
  22. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/slstm/cell.py +0 -0
  23. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/slstm/layer.py +0 -0
  24. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex/models/tirex.py +0 -0
  25. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
  26. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  27. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex_mirror.egg-info/requires.txt +0 -0
  28. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  29. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_chronos_zs.py +0 -0
  30. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_compile.py +0 -0
  31. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_forecast.py +0 -0
  32. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_forecast_adapter.py +0 -0
  33. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_patcher.py +0 -0
  34. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_slstm_torch_vs_cuda.py +0 -0
  35. {tirex_mirror-2025.11.8 → tirex_mirror-2025.11.13}/tests/test_standard_adapter.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.11.8
+Version: 2025.11.13
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tirex-mirror"
-version = "2025.11.08"
+version = "2025.11.13"
 description = "Unofficial mirror of NX-AI/tirex for packaging"
 readme = "README.md"
 requires-python = ">=3.11"

src/tirex/util.py
@@ -4,8 +4,8 @@
 from collections.abc import Callable, Sequence
 from dataclasses import fields
 from functools import partial
-from math import ceil
-from typing import Literal
+from math import ceil, isfinite
+from typing import Literal, Optional
 
 import numpy as np
 import torch
@@ -77,6 +77,7 @@ def frequency_resample(
     - For short horizons (prediction_length < 100), resampling is disabled and the factor is set to 1.0.
     - The factor is clamped to at most 1.0 to avoid upsampling the context.
     """
+
     sample_factor = frequency_factor(
         ts,
         max_period=max_period,
@@ -173,15 +174,17 @@ def frequency_factor(
     # NOTE: be careful when min_period is not matching patch_size, it can create unexpected scaling factors!
     min_period = patch_size
 
-    # Ensure CPU numpy array for FFT analysis
-    ts_np = ts.detach().cpu().numpy() if isinstance(ts, torch.Tensor) else np.asarray(ts)
+    if isinstance(ts, torch.Tensor):
+        ts_tensor = ts.to(torch.float32)
+    else:
+        ts_tensor = torch.as_tensor(ts, dtype=torch.float32)
 
     # NOTE: If the series is shorter than max_period *2, FFT may not be accurate, to avoid detecting these peaks, we don't scale
-    if ts_np.size < max_period * 2:
+    if ts_tensor.numel() < max_period * 2:
        return 1.0
 
     freqs, specs, peak_idc = run_fft_analysis(
-        ts_np,
+        ts_tensor,
         scaling="amplitude",
         peak_prominence=peak_prominence,
         min_period=min_period,
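
For illustration: frequency_factor now keeps its input as a float32 torch tensor instead of converting to NumPy, and still refuses to rescale when the series is shorter than twice max_period. A minimal standalone sketch of that guard; the helper name is made up here, and the max_period value of 1000 is only borrowed from the default of custom_find_peaks further down.

    import torch

    def long_enough_for_fft(ts, max_period: int = 1000) -> bool:
        # Mirror of the new guard: accept a tensor or array-like, force float32,
        # and require at least 2 * max_period samples before trusting an FFT peak.
        ts_tensor = ts.to(torch.float32) if isinstance(ts, torch.Tensor) else torch.as_tensor(ts, dtype=torch.float32)
        return ts_tensor.numel() >= max_period * 2

    print(long_enough_for_fft(torch.randn(500)))   # False: 500 < 2 * 1000
    print(long_enough_for_fft(torch.randn(4096)))  # True
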
@@ -190,18 +193,18 @@ def frequency_factor(
     )
 
     # No detectable peaks -> keep original sampling
-    if peak_idc.size == 0:
+    if peak_idc.numel() == 0:
         return 1.0
 
     # Choose initial candidate as the highest-amplitude peak
-    chosen_idx = int(peak_idc[0])
+    chosen_idx = int(peak_idc[0].item())
 
     # If two peaks exist, check for ~2x harmonic relation and prefer the higher/lower one
-    if peak_idc.size >= 2:
-        idx_a = int(peak_idc[0])  # highest amplitude
-        idx_b = int(peak_idc[1])  # second highest amplitude
-        f_a = float(freqs[idx_a])
-        f_b = float(freqs[idx_b])
+    if peak_idc.numel() >= 2:
+        idx_a = int(peak_idc[0].item())  # highest amplitude
+        idx_b = int(peak_idc[1].item())  # second highest amplitude
+        f_a = float(freqs[idx_a].item())
+        f_b = float(freqs[idx_b].item())
 
         # Determine lower/higher frequency
         low_f = min(f_a, f_b)
@@ -216,10 +219,10 @@
             elif selection_method == "high_harmonic":
                 chosen_idx = idx_a if f_a > f_b else idx_b
 
-    chosen_freq = float(freqs[chosen_idx])
+    chosen_freq = float(freqs[chosen_idx].item())
 
     # Guard against zero or non-finite frequency
-    if not np.isfinite(chosen_freq) or chosen_freq <= 0:
+    if not isfinite(chosen_freq) or chosen_freq <= 0:
         return 1.0
 
     # Convert to period and compute scaling factor so one period fits one patch
@@ -228,19 +231,33 @@
     factor = round(factor, 4)
 
     # Guard against factor being negative
-    if not np.isfinite(factor) or factor <= 0:
+    if not isfinite(factor) or factor <= 0:
         return 1.0
 
     # nearest interger fraction rounding (nifr)
     if nifr_enabled:
-        int_fractions = np.concatenate([[1], 1 / np.arange(nifr_start_integer, nifr_end_integer + 1)])
-        diff = np.abs(factor - int_fractions)
-        min_diff_idc = np.argmin(diff)
-        factor = int_fractions[min_diff_idc]
+        device = ts_tensor.device
+        dtype = torch.float32
+        base = torch.ones(1, device=device, dtype=dtype)
+        if nifr_end_integer >= nifr_start_integer:
+            denominators = torch.arange(nifr_start_integer, nifr_end_integer + 1, device=device, dtype=dtype)
+            candidate_factors = torch.cat([base, 1.0 / denominators])
+        else:
+            candidate_factors = base
+
+        factor_tensor = torch.tensor(factor, device=device, dtype=dtype)
+        diff = torch.abs(factor_tensor - candidate_factors)
+        min_idx = int(torch.argmin(diff).item())
+        factor_tensor = candidate_factors[min_idx]
 
         if nifr_clamp_large_factors:
             # Clamp everything between 1 and 1/nifr_start_integer to 1, that is no scaling
-            factor = factor if factor < int_fractions[1] else 1
+            if candidate_factors.numel() > 1:
+                clamp_threshold = candidate_factors[1]
+                one = torch.tensor(1.0, device=device, dtype=dtype)
+                factor_tensor = torch.where(factor_tensor < clamp_threshold, factor_tensor, one)
+
+        factor = float(factor_tensor.item())
 
     return float(factor)
 
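
For illustration: the nearest-integer-fraction rounding (nifr) block now builds its candidate factors {1, 1/nifr_start_integer, ..., 1/nifr_end_integer} as a torch tensor and snaps the estimated factor to the closest one. A self-contained sketch of that snapping step; the function name and the start/end defaults are placeholders, not the package's nifr_start_integer / nifr_end_integer values.

    import torch

    def snap_to_integer_fraction(factor: float, start: int = 2, end: int = 8) -> float:
        # Candidate factors: 1, 1/start, 1/(start+1), ..., 1/end.
        denominators = torch.arange(start, end + 1, dtype=torch.float32)
        candidates = torch.cat([torch.ones(1), 1.0 / denominators])
        # Pick the candidate closest to the estimated factor.
        idx = torch.argmin(torch.abs(torch.tensor(factor) - candidates))
        return float(candidates[idx].item())

    print(snap_to_integer_fraction(0.26))  # 0.25, i.e. downsample by a factor of 4
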
@@ -439,57 +456,69 @@ def run_fft_analysis(
     peaks_idx : ndarray
         Indices into f of detected peaks.
     """
-    y = np.asarray(y, dtype=float)
-    if y.ndim != 1:
-        y = y.reshape(-1)
-    n = y.size
+    if isinstance(y, torch.Tensor):
+        y_tensor = y.to(torch.float32)
+    else:
+        y_tensor = torch.as_tensor(y, dtype=torch.float32)
+
+    if y_tensor.ndim != 1:
+        y_tensor = y_tensor.reshape(-1)
+
+    n = y_tensor.numel()
+    device = y_tensor.device
+
     if n < 2:
-        return np.array([]), np.array([]), np.array([])
+        empty = torch.empty(0, dtype=y_tensor.dtype, device=device)
+        return empty, empty, empty
 
     # Fill NaNs linearly (handles edge NaNs as well)
-    y = _nan_linear_interpolate(y)
+    y_tensor = _nan_linear_interpolate(y_tensor)
 
     if detrend:
-        y = y - np.mean(y)
+        y_tensor = y_tensor - torch.mean(y_tensor)
 
     # Windowing
     if window == "hann":
-        w = np.hanning(n)
-        yw = y * w
+        w = torch.hann_window(n, device=device, dtype=y_tensor.dtype)
+        yw = y_tensor * w
         # average window power (for proper amplitude/power normalization)
-        w_power = np.sum(w**2) / n
+        w_power = torch.sum(w.square()) / n
     elif window is None:
-        yw = y
-        w_power = 1.0
+        yw = y_tensor
+        w_power = torch.tensor(1.0, device=device, dtype=y_tensor.dtype)
     else:
         raise ValueError("window must be either 'hann' or None")
 
     # FFT (one-sided)
-    Y = np.fft.rfft(yw)
-    f = np.fft.rfftfreq(n, d=dt)  # cycles per unit time
+    Y = torch.fft.rfft(yw)
+    f = torch.fft.rfftfreq(n, d=dt, device=device, dtype=y_tensor.dtype)  # cycles per unit time
 
     if scaling == "raw":
-        spec = np.abs(Y)
+        spec = torch.abs(Y)
     elif scaling == "amplitude":
         # One-sided amplitude with window power compensation
-        spec = np.abs(Y) / (n * np.sqrt(w_power))
-        if n % 2 == 0:
-            spec[1:-1] *= 2.0
-        else:
-            spec[1:] *= 2.0
+        spec = torch.abs(Y) / (n * torch.sqrt(w_power))
+        if spec.numel() > 1:
+            if n % 2 == 0 and spec.numel() > 2:
+                spec[1:-1] *= 2.0
+            else:
+                spec[1:] *= 2.0
     elif scaling == "power":
         # One-sided power (not PSD)
-        spec = (np.abs(Y) ** 2) / (n**2 * w_power)
-        if n % 2 == 0:
-            spec[1:-1] *= 2.0
-        else:
-            spec[1:] *= 2.0
+        spec = (torch.abs(Y) ** 2) / (n**2 * w_power)
+        if spec.numel() > 1:
+            if n % 2 == 0 and spec.numel() > 2:
+                spec[1:-1] *= 2.0
+            else:
+                spec[1:] *= 2.0
     else:
         raise ValueError("scaling must be 'amplitude', 'power', or 'raw'")
 
     # Normalize the spectrum by its maximum value
-    if spec.max() > 0:
-        spec = spec / spec.max()
+    if spec.numel() > 0:
+        max_val = torch.max(spec)
+        if max_val > 0:
+            spec = spec / max_val
 
     # Find peaks in the spectrum
     peaks_idx = custom_find_peaks(
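
For illustration: run_fft_analysis now stays in torch end to end, with torch.hann_window, torch.fft.rfft and torch.fft.rfftfreq replacing their NumPy counterparts and the same window-power compensation for the one-sided amplitude spectrum. A compact standalone sketch of that computation (not the package function):

    import math

    import torch

    def amplitude_spectrum(y: torch.Tensor, dt: float = 1.0):
        # One-sided amplitude spectrum of a 1D signal with a Hann window.
        y = y.to(torch.float32)
        y = y - y.mean()                      # simple detrend
        n = y.numel()
        w = torch.hann_window(n, dtype=y.dtype, device=y.device)
        w_power = torch.sum(w.square()) / n   # average window power
        Y = torch.fft.rfft(y * w)
        f = torch.fft.rfftfreq(n, d=dt)
        spec = torch.abs(Y) / (n * torch.sqrt(w_power))
        if n % 2 == 0:
            spec[1:-1] *= 2.0                 # keep DC and Nyquist unscaled
        else:
            spec[1:] *= 2.0                   # keep only DC unscaled
        return f, spec

    f, spec = amplitude_spectrum(torch.sin(2 * math.pi * torch.arange(256) / 16))
    print(f[spec.argmax()])                   # ~ 1/16 cycles per sample
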
@@ -505,20 +534,75 @@ def run_fft_analysis(
     return f, spec, peaks_idx
 
 
-def _nan_linear_interpolate(y: np.ndarray) -> np.ndarray:
-    y = y.astype(np.float32)
+def _nan_linear_interpolate(y: torch.Tensor) -> torch.Tensor:
+    """
+    Linearly interpolate NaN values in a 1D torch tensor.
+    """
+    y = y.to(torch.float32)
     if y.ndim != 1:
         y = y.reshape(-1)
-    n = y.size
-    mask = np.isfinite(y)
+    n = y.numel()
+    mask = torch.isfinite(y)
     if mask.all():
         return y
     if (~mask).all():
-        return np.zeros(n, dtype=np.float32)
-    idx = np.arange(n)
-    y_interp = y.copy()
-    y_interp[~mask] = np.interp(idx[~mask], idx[mask], y[mask])
-    return y_interp
+        return torch.zeros(n, dtype=y.dtype, device=y.device)
+
+    idx = torch.arange(n, device=y.device)
+    valid_idx = idx[mask]
+    valid_vals = y[mask]
+
+    insert_pos = torch.searchsorted(valid_idx, idx)
+    prev_pos = torch.clamp(insert_pos - 1, min=0)
+    next_pos = torch.clamp(insert_pos, max=valid_idx.numel() - 1)
+
+    prev_idx = valid_idx[prev_pos]
+    next_idx = valid_idx[next_pos]
+
+    prev_vals = valid_vals[prev_pos]
+    next_vals = valid_vals[next_pos]
+
+    has_prev = insert_pos > 0
+    has_next = insert_pos < valid_idx.numel()
+
+    result = y.clone()
+    missing = ~mask
+    if missing.any():
+        idx_missing = idx[missing]
+        prev_idx_missing = prev_idx[missing]
+        next_idx_missing = next_idx[missing]
+        prev_vals_missing = prev_vals[missing]
+        next_vals_missing = next_vals[missing]
+        has_prev_missing = has_prev[missing]
+        has_next_missing = has_next[missing]
+
+        interp_vals = torch.empty_like(idx_missing, dtype=y.dtype)
+
+        both_mask = has_prev_missing & has_next_missing
+        if both_mask.any():
+            denom = (next_idx_missing[both_mask] - prev_idx_missing[both_mask]).to(y.dtype)
+            denom = torch.where(denom == 0, torch.ones_like(denom), denom)
+            t = (idx_missing[both_mask].to(y.dtype) - prev_idx_missing[both_mask].to(y.dtype)) / denom
+            interp_vals[both_mask] = (
+                prev_vals_missing[both_mask] + (next_vals_missing[both_mask] - prev_vals_missing[both_mask]) * t
+            )
+
+        left_only = has_prev_missing & ~has_next_missing
+        if left_only.any():
+            interp_vals[left_only] = prev_vals_missing[left_only]
+
+        right_only = ~has_prev_missing & has_next_missing
+        if right_only.any():
+            interp_vals[right_only] = next_vals_missing[right_only]
+
+        # Handle corner case where neither prev nor next exists (shouldn't happen due to earlier checks)
+        neither = ~(both_mask | left_only | right_only)
+        if neither.any():
+            interp_vals[neither] = 0.0
+
+        result[missing] = interp_vals
+
+    return result
 
 
 def resampling_factor(inverted_freq, path_size):
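
For illustration: _nan_linear_interpolate is rewritten on top of torch.searchsorted, so interior NaNs are interpolated between the nearest valid neighbours, edge NaNs take the nearest valid value, and an all-NaN input collapses to zeros, matching the previous np.interp behaviour. A small behavioural sketch; the import path tirex.util is inferred from the file layout, and the function is an internal helper.

    import torch
    from tirex.util import _nan_linear_interpolate  # internal helper, path inferred

    y = torch.tensor([float("nan"), 1.0, float("nan"), 3.0, float("nan")])
    print(_nan_linear_interpolate(y))
    # Expected: tensor([1., 1., 2., 3., 3.]) -- the interior NaN is filled halfway
    # between its neighbours, the edge NaNs take the nearest valid value.
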
@@ -532,13 +616,14 @@ def resampling_factor(inverted_freq, path_size):
 
 
 def custom_find_peaks(
-    f,
-    spec,
-    max_peaks=5,
-    prominence_threshold=0.1,
-    min_period=64,
-    max_period=1000,
-    bandpass_filter=True,
+    f: torch.Tensor,
+    spec: torch.Tensor,
+    *,
+    max_peaks: int = 5,
+    prominence_threshold: float = 0.1,
+    min_period: int = 64,
+    max_period: int = 1000,
+    bandpass_filter: bool = True,
 ):
     """
     Finds prominent peaks in a spectrum using a simple custom logic.
@@ -556,60 +641,78 @@ def custom_find_peaks(
         The maximum number of peaks to return.
     prominence_threshold : float
         The minimum height for a peak to be considered prominent.
+    min_period : int
+        Minimum period to consider for peaks.
+    max_period : int
+        Maximum period to consider for peaks.
+    bandpass_filter : bool
+        If True, suppress very low frequencies below 1 / max_period before peak search.
 
     Returns
     -------
-    np.ndarray
-        An array of indices of the detected peaks in the spectrum. Returns an
-        empty array if no prominent peaks are found.
+    torch.Tensor
+        Long tensor of indices of detected peaks in descending order of prominence.
     """
-    if len(spec) < 5:  # Need at least 5 points to exclude last two bins
-        return np.array([], dtype=int)
+    if spec.numel() < 5:  # Need at least 5 points to exclude last two bins
+        return spec.new_empty(0, dtype=torch.long)
 
     if bandpass_filter:  # only truly filter low frequencies, high frequencies are dealt with later
         min_freq = 1 / max_period
-        freq_mask = f >= min_freq
+        freq_mask = (f >= min_freq).to(spec.dtype)
         spec = spec * freq_mask
 
     # Find all local maxima, excluding the last two bins
-    local_maxima_indices = []
-    for i in range(1, len(spec) - 2):
-        if spec[i] > spec[i - 1] and spec[i] > spec[i + 1]:
-            local_maxima_indices.append(i)
+    candidates = torch.arange(1, spec.size(0) - 2, device=spec.device, dtype=torch.long)
+    if candidates.numel() == 0:
+        return spec.new_empty(0, dtype=torch.long)
 
-    if not local_maxima_indices:
-        return np.array([], dtype=int)
+    center = spec[candidates]
+    left = spec[candidates - 1]
+    right = spec[candidates + 1]
+    local_mask = (center > left) & (center > right)
+
+    if not local_mask.any():
+        return spec.new_empty(0, dtype=torch.long)
+
+    local_maxima_indices = candidates[local_mask]
 
     # Filter by prominence (height)
-    prominent_peaks = []
-    for idx in local_maxima_indices:
-        if spec[idx] > prominence_threshold:
-            prominent_peaks.append((idx, spec[idx]))
+    heights = spec[local_maxima_indices]
+    prominence_mask = heights > prominence_threshold
+    if not prominence_mask.any():
+        return spec.new_empty(0, dtype=torch.long)
+
+    prominent_indices = local_maxima_indices[prominence_mask]
 
-    # If no peaks are above the threshold, return an empty list
-    if not prominent_peaks:
-        return np.array([], dtype=int)
+    prominent_heights = spec[prominent_indices]
 
     # Check for clear peaks below min_period (do lowpass filter)
-    for idx, _ in prominent_peaks:
-        period = 1 / f[idx]
+    for idx in prominent_indices.tolist():
+        freq_val = float(f[idx].item())
+        if freq_val <= 0:
+            continue
+        period = 1.0 / freq_val
         if period < min_period:
-            return np.array([], dtype=int)
+            return spec.new_empty(0, dtype=torch.long)
 
     # Filter by period
     period_filtered_peaks = []
-    for idx, prominence in prominent_peaks:
-        period = 1 / f[idx]
+    for idx, prominence in zip(prominent_indices.tolist(), prominent_heights.tolist()):
+        freq_val = float(f[idx].item())
+        if freq_val <= 0:
+            continue
+        period = 1.0 / freq_val
 
         if min_period <= period <= max_period:
             period_filtered_peaks.append((idx, prominence))
 
     if not period_filtered_peaks:
-        return np.array([], dtype=int)
+        return spec.new_empty(0, dtype=torch.long)
 
     # Sort by height and return the top `max_peaks`
     period_filtered_peaks.sort(key=lambda x: x[1], reverse=True)
-    peak_indices = np.array([p[0] for p in period_filtered_peaks[:max_peaks]], dtype=int)
+    top_indices = [p[0] for p in period_filtered_peaks[:max_peaks]]
+    peak_indices = torch.tensor(top_indices, dtype=torch.long, device=spec.device)
 
     return peak_indices
 
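
For illustration: custom_find_peaks drops the Python loop over bins in favour of a vectorised comparison, where every interior bin (excluding the last two) is a candidate and counts as a local maximum when it exceeds both neighbours. A standalone sketch of just that step, with the prominence and period filtering omitted:

    import torch

    def local_maxima(spec: torch.Tensor) -> torch.Tensor:
        # Candidate indices 1 .. len(spec) - 3, mirroring the "exclude last two bins" rule.
        candidates = torch.arange(1, spec.size(0) - 2, device=spec.device)
        center = spec[candidates]
        left = spec[candidates - 1]
        right = spec[candidates + 1]
        return candidates[(center > left) & (center > right)]

    spec = torch.tensor([0.0, 0.2, 1.0, 0.3, 0.8, 0.1, 0.0, 0.0])
    print(local_maxima(spec))  # tensor([2, 4])
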

src/tirex_mirror.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.11.8
+Version: 2025.11.13
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT

tests/test_util_freq.py
@@ -87,9 +87,9 @@ def test_frequency_resample_nifr_clamps_large_factors_to_identity():
 
 def test_run_fft_analysis_short_series_returns_empty():
     freqs, spec, peaks = run_fft_analysis(np.array([1.0]))
-    assert freqs.size == 0
-    assert spec.size == 0
-    assert peaks.size == 0
+    assert freqs.numel() == 0
+    assert spec.numel() == 0
+    assert peaks.numel() == 0
 
 
 def test_run_fft_analysis_detects_primary_frequency_with_nans():
@@ -106,7 +106,7 @@ def test_run_fft_analysis_detects_primary_frequency_with_nans():
         peak_prominence=0.05,
     )
 
-    assert peaks.size > 0
-    dominant_freq = freqs[peaks[0]]
+    assert peaks.numel() > 0
+    dominant_freq = freqs[peaks[0]].item()
     assert math.isclose(dominant_freq, 1 / period, rel_tol=1e-2)
-    assert math.isclose(float(spec.max()), 1.0, rel_tol=1e-5)
+    assert math.isclose(spec.max().item(), 1.0, rel_tol=1e-5)
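
For illustration: because run_fft_analysis now returns torch tensors even for NumPy input, callers switch from .size / float(...) to .numel() / .item(), as the updated tests do. A minimal usage sketch in the same spirit; the import path is inferred from src/tirex/util.py, and the unspecified defaults (dt, min_period, max_period) are assumed to leave a 128-sample period detectable.

    import numpy as np
    import torch
    from tirex.util import run_fft_analysis  # path inferred from the package layout

    t = np.arange(1024, dtype=np.float32)
    signal = np.sin(2 * np.pi * t / 128)  # dominant period of 128 samples

    freqs, spec, peaks = run_fft_analysis(signal, scaling="amplitude", peak_prominence=0.05)
    assert isinstance(spec, torch.Tensor)
    if peaks.numel() > 0:
        print(freqs[peaks[0]].item())     # expected to be close to 1/128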