ionworks-api 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ionworks/validators.py CHANGED
@@ -1,18 +1,527 @@
- """
- Reusable validator functions and composable pipelines for inbound/outbound value
- normalization (e.g., converting between pandas DataFrames and dictionaries).
+ """Reusable validator functions and composable pipelines for value normalization.
+
+ Provides functions for composable inbound/outbound value normalization
+ (e.g., converting between pandas DataFrames and dictionaries).
  """

+ from collections.abc import Callable, Iterable
  import math
+ import os
  import pathlib
- from typing import Any, Callable, Iterable
+ from typing import Any

+ from dotenv import load_dotenv
  import numpy as np
  import pandas as pd
  import polars as pl
  import pybamm
  from pybamm.expression_tree.operations.serialise import convert_symbol_to_json

+ from .errors import IonworksError
+
+ # --- DataFrame Backend Configuration ---------------------------------------- #
+
+ # Load .env file before reading environment variables
+ load_dotenv()
+
+ # Type alias for DataFrame (pandas or polars)
+ DataFrame = pd.DataFrame | pl.DataFrame
+
+
+ def _get_default_backend() -> str:
+     """Get default backend from environment variable or fall back to 'polars'."""
+     env_val = os.getenv("IONWORKS_DATAFRAME_BACKEND", "polars").lower()
+     if env_val not in ("polars", "pandas"):
+         return "polars"
+     return env_val
+
+
+ # Module-level configuration for DataFrame return type
+ # Initialized from IONWORKS_DATAFRAME_BACKEND env var, defaults to "polars"
+ _dataframe_backend: str = _get_default_backend()
+
+
+ def set_dataframe_backend(backend: str) -> None:
+     """Set the default DataFrame backend for data fetching.
+
+     This overrides the IONWORKS_DATAFRAME_BACKEND environment variable.
+
+     Parameters
+     ----------
+     backend : str
+         DataFrame backend to use: "polars" or "pandas".
+
+     Raises
+     ------
+     ValueError
+         If backend is not "polars" or "pandas".
+     """
+     global _dataframe_backend
+     if backend not in ("polars", "pandas"):
+         raise ValueError(f"backend must be 'polars' or 'pandas', got '{backend}'")
+     _dataframe_backend = backend
+
+
+ def get_dataframe_backend() -> str:
+     """Get the current DataFrame backend setting.
+
+     Returns
+     -------
+     str
+         Current backend: "polars" or "pandas".
+     """
+     return _dataframe_backend
+
+
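Usage sketch for the new backend toggle (not part of the diff). The environment variable is read once at import time via load_dotenv()/os.getenv, so it must be set before the module is imported; set_dataframe_backend() overrides it afterwards. The import path below is assumed from the file location ionworks/validators.py.

import os

os.environ["IONWORKS_DATAFRAME_BACKEND"] = "pandas"  # must be set before import

from ionworks import validators  # assumed import path

validators.set_dataframe_backend("polars")  # runtime override of the env var
assert validators.get_dataframe_backend() == "polars"

try:
    validators.set_dataframe_backend("dask")
except ValueError as exc:
    print(exc)  # backend must be 'polars' or 'pandas', got 'dask'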
+ # --- Measurement Data Validators -------------------------------------------- #
+
+
+ class MeasurementValidationError(IonworksError):
+     """Exception raised when measurement data validation fails."""
+
+     def __init__(self, message: str, errors: list[str] | None = None) -> None:
+         super().__init__(message)
+         self.errors = errors or []
+
+
+ def _get_column(df: DataFrame, col: str) -> np.ndarray:
+     """
+     Extract a column as a numpy array from either pandas or polars DataFrame.
+
+     Parameters
+     ----------
+     df : DataFrame
+         pandas or polars DataFrame.
+     col : str
+         Column name.
+
+     Returns
+     -------
+     np.ndarray
+         Column values as numpy array.
+     """
+     if isinstance(df, pl.DataFrame):
+         return df.get_column(col).to_numpy()
+     return df[col].to_numpy()
+
+
+ def _has_column(df: DataFrame, col: str) -> bool:
+     """Check if a column exists in the DataFrame."""
+     return col in df.columns
+
+
+ def _get_step_group_indices(step_data: np.ndarray) -> np.ndarray:
+     """Compute step group indices for each row (0-indexed, based on contiguous groups).
+
+     Parameters
+     ----------
+     step_data : np.ndarray
+         Array of step numbers/identifiers.
+
+     Returns
+     -------
+     np.ndarray
+         Array where each element is the step group index (0, 1, 2, ...) for
+         that row.
+     """
+     changes = np.concatenate([[True], np.diff(step_data) != 0])
+     return np.cumsum(changes) - 1
+
+
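Worked example of the contiguous-grouping trick used by _get_step_group_indices (not part of the diff): a step number that reappears later starts a new group rather than rejoining its earlier one.

import numpy as np

step_data = np.array([1, 1, 2, 2, 2, 1, 1])  # step "1" recurs non-contiguously
changes = np.concatenate([[True], np.diff(step_data) != 0])
print(np.cumsum(changes) - 1)  # [0 0 1 1 1 2 2] -- the second run of step 1 is group 2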
+ def validate_positive_current_is_discharge(  # noqa: PLR0913
+     df: DataFrame,
+     current_col: str = "Current [A]",
+     voltage_col: str = "Voltage [V]",
+     step_col: str | None = None,
+     rest_tol: float = 1e-3,
+     current_std_tol: float = 0.01,
+ ) -> list[str]:
+     """
+     Validate that positive current corresponds to discharge.
+
+     Discharge should cause voltage to decrease. This function analyzes the
+     relationship between current direction and voltage change to verify the
+     sign convention is correct.
+
+     Parameters
+     ----------
+     df : DataFrame
+         Time series data with current and voltage columns (pandas or polars).
+     current_col : str
+         Name of the current column.
+     voltage_col : str
+         Name of the voltage column.
+     step_col : str, optional
+         Name of the step column. If provided, analyzes per-step. Otherwise,
+         infers steps from current sign changes.
+     rest_tol : float
+         Tolerance for considering current as zero (rest).
+     current_std_tol : float
+         Tolerance for standard deviation to consider constant current.
+
+     Returns
+     -------
+     list[str]
+         List of validation error messages. Empty if validation passes.
+     """
+     if not _has_column(df, current_col) or not _has_column(df, voltage_col):
+         return []
+
+     current = _get_column(df, current_col)
+     voltage = _get_column(df, voltage_col)
+
+     if len(current) == 0:
+         return []
+
+     # Determine step groups
+     if step_col and _has_column(df, step_col):
+         step_data = _get_column(df, step_col)
+     else:
+         # Infer steps from current sign changes
+         max_abs = np.max(np.abs(current))
+         if max_abs == 0:
+             return []
+         normalized = current / max_abs
+         step_data = np.sign(normalized * (np.abs(normalized) > rest_tol))
+
+     step_groups = _get_step_group_indices(step_data)
+     num_steps = step_groups[-1] + 1
+
+     # Vectorized computation of per-step statistics using bincount
+     # Mean current per step
+     step_current_sum = np.bincount(step_groups, weights=current, minlength=num_steps)
+     step_counts = np.bincount(step_groups, minlength=num_steps).astype(float)
+     step_counts[step_counts == 0] = 1  # Avoid division by zero
+     mean_current = step_current_sum / step_counts
+
+     # Std current per step: std = sqrt(E[x^2] - E[x]^2)
+     step_current_sq_sum = np.bincount(
+         step_groups, weights=current**2, minlength=num_steps
+     )
+     variance = step_current_sq_sum / step_counts - mean_current**2
+     variance = np.maximum(variance, 0)  # Numerical stability
+     std_current = np.sqrt(variance)
+
+     # First and last voltage per step
+     first_voltage = np.zeros(num_steps)
+     last_voltage = np.zeros(num_steps)
+     # Use searchsorted on step boundaries for vectorized first/last
+     step_boundaries = np.where(np.diff(step_groups, prepend=-1) != 0)[0]
+     step_end_boundaries = np.concatenate([step_boundaries[1:], [len(voltage)]])
+     first_voltage = voltage[step_boundaries]
+     last_voltage = voltage[step_end_boundaries - 1]
+
+     delta_v = last_voltage - first_voltage
+
+     # Filter: non-rest steps with constant current
+     is_non_rest = np.abs(mean_current) >= rest_tol
+     is_constant_current = std_current <= current_std_tol * np.abs(mean_current)
+     valid_mask = is_non_rest & is_constant_current
+
+     if not np.any(valid_mask):
+         return []
+
+     # Compute voltage response for valid steps
+     valid_mean_current = mean_current[valid_mask]
+     valid_delta_v = delta_v[valid_mask]
+     voltage_responses = valid_delta_v / valid_mean_current
+
+     mean_response = np.mean(voltage_responses)
+
+     if mean_response > 0:
+         return [
+             "Current sign convention error: positive current appears to be charge, "
+             "not discharge. Voltage increases when current is positive, but for "
+             "discharge, voltage should decrease. Please flip the sign of the "
+             "current data (multiply by -1)."
+         ]
+
+     return []
+
+
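A minimal sketch of the sign-convention check (not part of the diff), using a hypothetical two-step trace in which voltage rises while current is positive, so the validator flags the convention; the import path is assumed.

import pandas as pd
from ionworks.validators import validate_positive_current_is_discharge  # assumed path

df = pd.DataFrame(
    {
        "Current [A]": [1.0, 1.0, 1.0, -1.0, -1.0, -1.0],
        "Voltage [V]": [3.6, 3.7, 3.8, 3.8, 3.7, 3.6],  # rises under positive current
    }
)
print(validate_positive_current_is_discharge(df))
# ['Current sign convention error: positive current appears to be charge, ...']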
+ def validate_cumulative_values_reset_per_step(
+     df: DataFrame,
+     step_col: str = "Step count",
+     cumulative_cols: list[str] | None = None,
+     tolerance: float = 1e-6,
+ ) -> list[str]:
+     """Validate cumulative values reset to ~0 at each step and only increase.
+
+     Parameters
+     ----------
+     df : DataFrame
+         Time series data (pandas or polars).
+     step_col : str
+         Name of the column containing step numbers.
+     cumulative_cols : list[str], optional
+         List of cumulative column names to validate. If None, checks for common
+         capacity and energy columns.
+     tolerance : float
+         Tolerance for considering a value as "zero" at step start.
+
+     Returns
+     -------
+     list[str]
+         List of validation error messages. Empty if validation passes.
+     """
+     errors = []
+
+     if not _has_column(df, step_col):
+         return []
+
+     if cumulative_cols is None:
+         cumulative_cols = [
+             "Discharge capacity [A.h]",
+             "Charge capacity [A.h]",
+             "Discharge energy [W.h]",
+             "Charge energy [W.h]",
+         ]
+
+     cols_to_check = [col for col in cumulative_cols if _has_column(df, col)]
+     if not cols_to_check:
+         return []
+
+     step_data = _get_column(df, step_col)
+     if len(step_data) == 0:
+         return []
+     step_groups = _get_step_group_indices(step_data)
+
+     # Find step boundaries (first index of each step)
+     step_boundaries = np.where(np.diff(step_groups, prepend=-1) != 0)[0]
+
+     for col in cols_to_check:
+         values = _get_column(df, col)
+
+         # Check 1: Values at step starts should be ~0
+         start_values = values[step_boundaries]
+         non_zero_mask = np.abs(start_values) > tolerance
+         non_zero_steps = np.where(non_zero_mask)[0]
+
+         for step_idx in non_zero_steps:
+             errors.append(
+                 f"Column '{col}' does not reset at start of step {step_idx}: "
+                 f"expected ~0, got {start_values[step_idx]:.6f}. "
+                 f"Cumulative values should reset to 0 at the start of each step."
+             )
+
+         # Check 2: Values should be monotonically non-decreasing within each step
+         # Compute diff and check where it's negative within same step
+         value_diffs = np.diff(values, prepend=values[0])
+         step_diffs = np.diff(step_groups, prepend=step_groups[0])
+
+         # Mask: same step (diff == 0) and value decreased
+         within_step = step_diffs == 0
+         decreased = value_diffs < -tolerance
+
+         # Find first decrease per step
+         problem_indices = np.where(within_step & decreased)[0]
+         if len(problem_indices) > 0:
+             # Group by step and report first decrease per step
+             problem_steps = step_groups[problem_indices]
+             unique_problem_steps = np.unique(problem_steps)
+
+             for step_idx in unique_problem_steps:
+                 # Find first index in this step with decrease
+                 step_problem_indices = problem_indices[problem_steps == step_idx]
+                 first_idx = step_problem_indices[0]
+                 errors.append(
+                     f"Column '{col}' decreases within step {step_idx} at "
+                     f"index {first_idx}: "
+                     f"value went from {values[first_idx - 1]:.6f} to "
+                     f"{values[first_idx]:.6f}. "
+                     f"Cumulative values should only increase within a step."
+                 )
+
+     return errors
+
+
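Sketch of the cumulative-reset check (not part of the diff; import path assumed): the capacity column carries over into the second step instead of restarting near zero, so one message is returned.

import pandas as pd
from ionworks.validators import validate_cumulative_values_reset_per_step  # assumed path

df = pd.DataFrame(
    {
        "Step count": [1, 1, 2, 2],
        "Discharge capacity [A.h]": [0.0, 0.5, 0.5, 0.9],  # should restart at ~0
    }
)
print(validate_cumulative_values_reset_per_step(df))
# ["Column 'Discharge capacity [A.h]' does not reset at start of step 1: ..."]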
+ def validate_minimum_points_per_step(
+     df: DataFrame,
+     step_col: str = "Step count",
+     min_points: int = 2,
+ ) -> list[str]:
+     """
+     Validate that each step has at least a minimum number of data points.
+
+     Parameters
+     ----------
+     df : DataFrame
+         Time series data (pandas or polars).
+     step_col : str
+         Name of the column containing step numbers.
+     min_points : int
+         Minimum number of points required per step.
+
+     Returns
+     -------
+     list[str]
+         List of validation error messages. Empty if validation passes.
+     """
+     if not _has_column(df, step_col):
+         return []
+
+     step_data = _get_column(df, step_col)
+     if len(step_data) == 0:
+         return []
+
+     step_groups = _get_step_group_indices(step_data)
+     num_steps = step_groups[-1] + 1
+
+     # Vectorized count per step
+     step_counts = np.bincount(step_groups, minlength=num_steps)
+
+     # Find steps with insufficient points
+     insufficient_mask = step_counts < min_points
+     insufficient_steps = np.where(insufficient_mask)[0]
+
+     errors = []
+     for step_idx in insufficient_steps:
+         num_points = step_counts[step_idx]
+         errors.append(
+             f"Step {step_idx} has only {num_points} data point(s), "
+             f"but at least {min_points} are required."
+         )
+
+     return errors
+
+
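Sketch of the minimum-points check (not part of the diff; import path assumed); note that step indices in the messages refer to contiguous step groups, counted from 0.

import pandas as pd
from ionworks.validators import validate_minimum_points_per_step  # assumed path

df = pd.DataFrame({"Step count": [1, 1, 2, 3, 3]})
print(validate_minimum_points_per_step(df))
# ['Step 1 has only 1 data point(s), but at least 2 are required.']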
+ def validate_cycle_constant_within_step(
+     df: DataFrame,
+     step_col: str = "Step count",
+     cycle_col: str | None = None,
+ ) -> list[str]:
+     """
+     Validate that cycle number does not change within a step.
+
+     Parameters
+     ----------
+     df : DataFrame
+         Time series data (pandas or polars).
+     step_col : str
+         Name of the column containing step numbers.
+     cycle_col : str, optional
+         Name of the column containing cycle numbers. If None, tries common names.
+
+     Returns
+     -------
+     list[str]
+         List of validation error messages. Empty if validation passes.
+     """
+     if not _has_column(df, step_col):
+         return []
+
+     # Find cycle column
+     if cycle_col is None:
+         for col in ["Cycle count", "Cycle number", "Cycle from cycler"]:
+             if _has_column(df, col):
+                 cycle_col = col
+                 break
+
+     if cycle_col is None or not _has_column(df, cycle_col):
+         return []
+
+     step_data = _get_column(df, step_col)
+     if len(step_data) == 0:
+         return []
+
+     cycle_data = _get_column(df, cycle_col)
+     step_groups = _get_step_group_indices(step_data)
+
+     # Detect cycle changes within steps:
+     # A cycle change within a step occurs when:
+     # - The cycle value differs from the previous row
+     # - AND we're in the same step group
+     cycle_diffs = np.diff(cycle_data, prepend=cycle_data[0])
+     step_diffs = np.diff(step_groups, prepend=step_groups[0])
+
+     # Within-step cycle change: same step (step_diff == 0) but cycle changed
+     within_step_cycle_change = (step_diffs == 0) & (cycle_diffs != 0)
+
+     problem_indices = np.where(within_step_cycle_change)[0]
+     if len(problem_indices) == 0:
+         return []
+
+     # Group by step and report
+     problem_steps = step_groups[problem_indices]
+     unique_problem_steps = np.unique(problem_steps)
+
+     errors = []
+     for step_idx in unique_problem_steps:
+         # Find all unique cycles in this step
+         step_mask = step_groups == step_idx
+         unique_cycles = np.unique(cycle_data[step_mask])
+         errors.append(
+             f"Cycle number changes within step {step_idx}: "
+             f"found cycles {unique_cycles.tolist()}. "
+             f"Each step should belong to a single cycle."
+         )
+
+     return errors
+
+
+ def validate_measurement_data(
+     df: DataFrame,
+     strict: bool = True,
+ ) -> None:
+     """Validate measurement time series data before upload.
+
+     Performs the following checks:
+
+     1. Positive current should correspond to discharge (voltage decreases)
+     2. Cumulative values (capacity, energy) should reset at each step start
+        and only increase within steps
+     3. Each step has at least 2 data points (strict mode only)
+     4. Cycle number does not change within a step (strict mode only)
+
+     Parameters
+     ----------
+     df : DataFrame
+         Time series data to validate (pandas or polars DataFrame).
+     strict : bool
+         If True (default), run additional checks: minimum 2 points per step
+         and cycle number constant within each step.
+
+     Raises
+     ------
+     MeasurementValidationError
+         If any validation checks fail. The exception contains a list of all
+         errors found.
+     """
+     all_errors = []
+
+     # Try different possible step column names
+     step_col = None
+     for col in ["Step count", "Step number", "Step from cycler"]:
+         if _has_column(df, col):
+             step_col = col
+             break
+
+     # Check 1: Positive current should be discharge
+     current_errors = validate_positive_current_is_discharge(df, step_col=step_col)
+     all_errors.extend(current_errors)
+
+     if step_col:
+         # Check 2: Cumulative values should reset at each step
+         cumulative_errors = validate_cumulative_values_reset_per_step(df, step_col)
+         all_errors.extend(cumulative_errors)
+
+         if strict:
+             # Check 3: At least 2 points per step
+             points_errors = validate_minimum_points_per_step(df, step_col)
+             all_errors.extend(points_errors)
+
+             # Check 4: Cycle constant within step
+             cycle_errors = validate_cycle_constant_within_step(df, step_col)
+             all_errors.extend(cycle_errors)
+
+     if all_errors:
+         raise MeasurementValidationError(
+             f"Measurement data validation failed with {len(all_errors)} error(s):\n"
+             + "\n".join(f" - {err}" for err in all_errors),
+             errors=all_errors,
+         )
+
+
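End-to-end sketch of the new entry point (not part of the diff; import paths assumed): a polars frame with a single-point step and a mid-step cycle change fails both strict checks, and the individual messages are available on the exception.

import polars as pl
from ionworks.validators import (  # assumed paths
    MeasurementValidationError,
    validate_measurement_data,
)

df = pl.DataFrame(
    {
        "Step count": [1, 2, 2],
        "Cycle count": [1, 1, 2],        # cycle changes inside step group 1
        "Current [A]": [0.0, 1.0, 1.0],
        "Voltage [V]": [3.7, 3.7, 3.6],  # falls under positive current: OK
    }
)
try:
    validate_measurement_data(df, strict=True)
except MeasurementValidationError as exc:
    print(exc.errors)  # one message per failed check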
  # --- Atomic validators ------------------------------------------------------ #


@@ -21,23 +530,41 @@ def df_to_dict_validator(v: Any) -> Any:
      if isinstance(v, pd.DataFrame):
          # Replace NaN with None for JSON compatibility
          return v.replace(np.nan, None).to_dict(orient="list")
-     elif isinstance(v, pl.DataFrame):
+     if isinstance(v, pl.DataFrame):
          # Replace NaN with None for JSON compatibility, then convert to dict
          return v.fill_nan(None).to_dict(as_series=False)
      return v


- def dict_to_df_validator(v: Any) -> Any:
-     """Convert dict to DataFrame for data processing."""
+ def dict_to_df_validator(v: Any, return_type: str | None = None) -> Any:
+     """Convert dict to DataFrame for data processing.
+
+     Parameters
+     ----------
+     v : Any
+         Value to convert. If dict, converts to DataFrame.
+     return_type : str | None
+         Type of DataFrame to return: "polars" or "pandas".
+         If None, uses the global setting from set_dataframe_backend().
+
+     Returns
+     -------
+     Any
+         DataFrame if input was dict, otherwise unchanged.
+     """
      if isinstance(v, dict):
-         try:
-             return pd.DataFrame(v)
-         except ValueError as e:
-             if "If using all scalar values, you must pass an index" in str(e):
-                 # Handle case where all values are scalars by providing an index
+         backend = return_type if return_type is not None else _dataframe_backend
+         # Check if all values are scalars (not lists/arrays)
+         all_scalars = all(
+             not isinstance(val, list | tuple | np.ndarray) for val in v.values()
+         )
+         if backend == "pandas":
+             if all_scalars:
                  return pd.DataFrame(v, index=[0])
-             else:
-                 raise
+             return pd.DataFrame(v)
+         if all_scalars:
+             return pl.DataFrame({k: [val] for k, val in v.items()})
+         return pl.DataFrame(v)
      return v


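Round-trip sketch for the reworked converters (not part of the diff; import path assumed): an all-scalar dict now becomes a one-row frame directly, instead of relying on catching pandas' scalar-index ValueError, and the backend can be forced per call via return_type.

from ionworks.validators import df_to_dict_validator, dict_to_df_validator  # assumed path

row = {"Voltage [V]": 3.7, "Current [A]": 1.2}

pdf = dict_to_df_validator(row, return_type="pandas")  # 1-row pandas DataFrame
pldf = dict_to_df_validator(row)                       # backend from set_dataframe_backend()
print(df_to_dict_validator(pdf))  # {'Voltage [V]': [3.7], 'Current [A]': [1.2]}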
@@ -49,17 +576,19 @@ def parameter_validator(v: Any) -> Any:


  def float_sanitizer(v: Any) -> Any:
-     """Sanitize float values to JSON-compatible forms. Currently removes NaN and
-     infinity values."""
+     """Sanitize float values to JSON-compatible forms.
+
+     Currently removes NaN and infinity values.
+     """
      if isinstance(v, float):
          if math.isinf(v):
              return "Infinity" if v > 0 else "-Infinity"
-         elif np.isnan(v):
+         if np.isnan(v):
              return None
      elif isinstance(v, np.floating):
          if np.isinf(v):
              return "Infinity" if v > 0 else "-Infinity"
-         elif np.isnan(v):
+         if np.isnan(v):
              return None
      return v

@@ -67,10 +596,6 @@ def float_sanitizer(v: Any) -> Any:
  def bounds_tuple_validator(v: Any) -> Any:
      """Convert bounds 2-tuple to list for JSON serialization.

-     Converts tuples with exactly 2 elements to lists. This is useful for
-     bounds parameters that may be provided as tuples (lower, upper) but
-     need to be serialized as lists.
-
      Parameters
      ----------
      v : Any
@@ -87,27 +612,23 @@ def bounds_tuple_validator(v: Any) -> Any:


  def file_scheme_validator(v: Any) -> Any:
-     """
-     Convert file:// and folder:// scheme paths to serialized dicts.
+     """Convert file:// and folder:// scheme paths to serialized dicts.

-     Handles:
-     - "file:" prefixed paths: loads CSV file as dict (serialized)
-     - "folder:" prefixed paths: loads time_series.csv and steps.csv as dict
-     - All other values: returned unchanged
+     Handles ``file:`` prefixed paths (loads CSV as dict) and ``folder:``
+     prefixed paths (loads time_series.csv and steps.csv as dict).
+     All other values are returned unchanged.

      Raises
      ------
      FileNotFoundError
-         If the file or folder path doesn't exist
-     Exception
-         If reading the CSV file fails for any other reason
+         If the file or folder path doesn't exist.
      """
      if isinstance(v, str) and v.startswith("file:"):
          path = pathlib.Path(v.split(":")[1]).expanduser().resolve()
          if not path.exists() or not path.is_file():
              raise FileNotFoundError(f"CSV file not found: {v}")
          return df_to_dict_validator(pd.read_csv(path))
-     elif isinstance(v, str) and v.startswith("folder:"):
+     if isinstance(v, str) and v.startswith("folder:"):
          path = pathlib.Path(v.split(":")[1]).expanduser().resolve()
          if not path.exists() or not path.is_dir():
              raise FileNotFoundError(f"Folder not found: {v}")