steer-core 0.1.28__tar.gz → 0.1.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {steer_core-0.1.28 → steer_core-0.1.32}/PKG-INFO +1 -1
  2. {steer_core-0.1.28/steer_core → steer_core-0.1.32/steer_core/Data}/DataManager.py +76 -1
  3. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Coordinates.py +77 -22
  4. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Plotter.py +112 -17
  5. steer_core-0.1.32/steer_core/Mixins/Serializer.py +281 -0
  6. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/TypeChecker.py +2 -1
  7. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/__init__.py +1 -4
  8. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/PKG-INFO +1 -1
  9. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/SOURCES.txt +1 -2
  10. steer_core-0.1.28/steer_core/Data/database.db +0 -0
  11. steer_core-0.1.28/steer_core/Mixins/Serializer.py +0 -45
  12. {steer_core-0.1.28 → steer_core-0.1.32}/README.md +0 -0
  13. {steer_core-0.1.28 → steer_core-0.1.32}/pyproject.toml +0 -0
  14. {steer_core-0.1.28 → steer_core-0.1.32}/setup.cfg +0 -0
  15. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/Units.py +0 -0
  16. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/Universal.py +0 -0
  17. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/__init__.py +0 -0
  18. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/ContextManagers/ContextManagers.py +0 -0
  19. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/ContextManagers/__init__.py +0 -0
  20. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Data/__init__.py +0 -0
  21. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/Coordinates.py +0 -0
  22. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/General.py +0 -0
  23. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/Objects.py +0 -0
  24. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/__init__.py +0 -0
  25. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Colors.py +0 -0
  26. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Data.py +0 -0
  27. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Dunder.py +0 -0
  28. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/__init__.py +0 -0
  29. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/dependency_links.txt +0 -0
  30. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/requires.txt +0 -0
  31. {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/top_level.txt +0 -0
  32. {steer_core-0.1.28 → steer_core-0.1.32}/test/test_validation_mixin.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: steer-core
3
- Version: 0.1.28
3
+ Version: 0.1.32
4
4
  Summary: Modelling energy storage from cell to site - STEER OpenCell Design
5
5
  Author-email: Nicholas Siemons <nsiemons@stanford.edu>
6
6
  Maintainer-email: Nicholas Siemons <nsiemons@stanford.edu>
@@ -1,15 +1,20 @@
1
1
  import sqlite3 as sql
2
2
  from pathlib import Path
3
+ from typing import TypeVar
3
4
  import pandas as pd
4
5
  import importlib.resources
5
6
 
6
7
  from steer_core.Constants.Units import *
8
+ from steer_core.Mixins.Serializer import SerializerMixin
9
+
10
+
11
+ T = TypeVar('T', bound='SerializerMixin')
7
12
 
8
13
 
9
14
  class DataManager:
10
15
 
11
16
  def __init__(self):
12
- with importlib.resources.path("steer_core.Data", "database.db") as db_path:
17
+ with importlib.resources.path("steer_opencell_data", "database.db") as db_path:
13
18
  self._db_path = db_path
14
19
  self._connection = sql.connect(self._db_path)
15
20
  self._cursor = self._connection.cursor()
@@ -356,5 +361,75 @@ class DataManager:
356
361
  self._cursor.execute(f"DELETE FROM {table_name} WHERE {condition}")
357
362
  self._connection.commit()
358
363
 
364
+
365
+ @classmethod
366
+ def from_database(cls: type[T], name: str, table_name: str = None) -> T:
367
+ """
368
+ Pull object from the database by name.
369
+
370
+ Subclasses must define a '_table_name' class variable (str or list of str)
371
+ unless table_name is explicitly provided.
372
+
373
+ Parameters
374
+ ----------
375
+ name : str
376
+ Name of the object to retrieve.
377
+ table_name : str, optional
378
+ Specific table to search. If provided, '_table_name' is not required.
379
+ If None, uses class's _table_name.
380
+
381
+ Returns
382
+ -------
383
+ T
384
+ Instance of the class.
385
+
386
+ Raises
387
+ ------
388
+ NotImplementedError
389
+ If the subclass doesn't define '_table_name' and table_name is not provided.
390
+ ValueError
391
+ If the object name is not found in any of the tables.
392
+ """
393
+ database = cls()
394
+
395
+ # Get list of tables to search
396
+ if table_name:
397
+ tables_to_search = [table_name]
398
+ else:
399
+ # Only check for _table_name if table_name wasn't provided
400
+ if not hasattr(cls, '_table_name'):
401
+ raise NotImplementedError(
402
+ f"{cls.__name__} must define a '_table_name' class variable "
403
+ "or provide 'table_name' argument"
404
+ )
405
+
406
+ if isinstance(cls._table_name, (list, tuple)):
407
+ tables_to_search = cls._table_name
408
+ else:
409
+ tables_to_search = [cls._table_name]
410
+
411
+ # Try each table until found
412
+ for table in tables_to_search:
413
+ available_materials = database.get_unique_values(table, "name")
414
+
415
+ if name in available_materials:
416
+ data = database.get_data(table, condition=f"name = '{name}'")
417
+ serialized_bytes = data["object"].iloc[0]
418
+ return cls.deserialize(serialized_bytes)
419
+
420
+ # Not found in any table
421
+ all_available = []
422
+ for table in tables_to_search:
423
+ all_available.extend(database.get_unique_values(table, "name"))
424
+
425
+ raise ValueError(
426
+ f"'{name}' not found in tables {tables_to_search}. "
427
+ f"Available: {all_available}"
428
+ )
429
+
430
+
431
+
359
432
  def __del__(self):
360
433
  self._connection.close()
434
+
435
+
@@ -136,25 +136,39 @@ class CoordinateMixin:
136
136
  center: tuple = None
137
137
  ) -> np.ndarray:
138
138
  """
139
- Rotate a (N, 3) NumPy array of 3D coordinates around the specified axis.
139
+ Rotate a NumPy array of coordinates around the specified axis.
140
+ Can handle 2D coordinates (N, 2) for x, y or 3D coordinates (N, 3) for x, y, z.
140
141
  Can handle coordinates with None values (preserves None positions).
141
142
 
142
- :param coords: NumPy array of shape (N, 3), where columns are x, y, z
143
- :param axis: Axis to rotate around ('x', 'y', or 'z')
143
+ :param coords: NumPy array of shape (N, 2) for 2D or (N, 3) for 3D coordinates
144
+ :param axis: Axis to rotate around ('x', 'y', or 'z'). For 2D arrays, only 'z' is valid.
144
145
  :param angle: Angle in degrees
145
- :param center: Point to rotate around as (x, y, z) tuple. If None, rotates around origin.
146
- :return: Rotated NumPy array of shape (N, 3)
146
+ :param center: Point to rotate around. For 2D: (x, y) tuple. For 3D: (x, y, z) tuple.
147
+ If None, rotates around origin.
148
+ :return: Rotated NumPy array with same shape as input
147
149
  """
148
- if coords.shape[1] != 3:
150
+ # Check if 2D or 3D coordinates
151
+ is_2d = coords.shape[1] == 2
152
+ is_3d = coords.shape[1] == 3
153
+
154
+ if not (is_2d or is_3d):
155
+ raise ValueError(
156
+ "Input array must have shape (N, 2) for 2D or (N, 3) for 3D coordinates"
157
+ )
158
+
159
+ # For 2D arrays, only z-axis rotation is valid
160
+ if is_2d and axis != 'z':
149
161
  raise ValueError(
150
- "Input array must have shape (N, 3) for x, y, z coordinates"
162
+ "For 2D coordinates (x, y), only 'z' axis rotation is supported"
151
163
  )
152
164
 
153
165
  # Validate center parameter
154
166
  if center is not None:
155
- if not isinstance(center, (tuple, list)) or len(center) != 3:
167
+ expected_len = 2 if is_2d else 3
168
+ if not isinstance(center, (tuple, list)) or len(center) != expected_len:
169
+ coord_type = "(x, y)" if is_2d else "(x, y, z)"
156
170
  raise ValueError(
157
- "Center must be a tuple or list of 3 coordinates (x, y, z)"
171
+ f"Center must be a tuple or list of {expected_len} coordinates {coord_type}"
158
172
  )
159
173
  if not all(isinstance(coord, (int, float)) for coord in center):
160
174
  raise TypeError("All center coordinates must be numbers")
@@ -204,10 +218,11 @@ class CoordinateMixin:
204
218
  """
205
219
  Rotate coordinates around a specified center point.
206
220
 
207
- :param coords: NumPy array of shape (N, 3) with valid coordinates
221
+ :param coords: NumPy array of shape (N, 2) or (N, 3) with valid coordinates
208
222
  :param axis: Axis to rotate around ('x', 'y', or 'z')
209
223
  :param angle: Angle in degrees
210
- :param center: Center point as np.array of shape (3,). If None, rotates around origin.
224
+ :param center: Center point as np.array. Shape (2,) for 2D or (3,) for 3D.
225
+ If None, rotates around origin.
211
226
  :return: Rotated coordinates
212
227
  """
213
228
  if center is None:
@@ -294,13 +309,37 @@ class CoordinateMixin:
294
309
  @staticmethod
295
310
  def order_coordinates_clockwise(df: pd.DataFrame, plane="xy") -> pd.DataFrame:
296
311
 
312
+ df = df.copy()
313
+
297
314
  axis_1 = plane[0]
298
315
  axis_2 = plane[1]
299
316
 
300
- cx = df[axis_1].mean()
301
- cy = df[axis_2].mean()
302
-
303
- angles = np.arctan2(df[axis_2] - cy, df[axis_1] - cx)
317
+ # Find column names that match the axis (case-insensitive, with or without units)
318
+ def find_column(axis_char: str) -> str:
319
+ axis_char_lower = axis_char.lower()
320
+ # First try exact match
321
+ if axis_char in df.columns:
322
+ return axis_char
323
+ # Try lowercase
324
+ if axis_char_lower in df.columns:
325
+ return axis_char_lower
326
+ # Try uppercase
327
+ if axis_char.upper() in df.columns:
328
+ return axis_char.upper()
329
+ # Try with units pattern like "X (mm)", "x (mm)", etc.
330
+ for col in df.columns:
331
+ col_stripped = col.split()[0].lower() if ' ' in col else col.lower()
332
+ if col_stripped == axis_char_lower:
333
+ return col
334
+ raise KeyError(f"Could not find column for axis '{axis_char}' in dataframe columns: {list(df.columns)}")
335
+
336
+ axis_1_col = find_column(axis_1)
337
+ axis_2_col = find_column(axis_2)
338
+
339
+ cx = df[axis_1_col].mean()
340
+ cy = df[axis_2_col].mean()
341
+
342
+ angles = np.arctan2(df[axis_2_col] - cy, df[axis_1_col] - cx)
304
343
 
305
344
  df["angle"] = angles
306
345
 
@@ -533,19 +572,35 @@ class CoordinateMixin:
533
572
  ) -> np.ndarray:
534
573
  """
535
574
  Rotate coordinates without None values using rotation matrices.
575
+ Handles both 2D (N, 2) and 3D (N, 3) coordinate arrays.
576
+
577
+ :param coords: NumPy array of shape (N, 2) or (N, 3)
578
+ :param axis: Axis to rotate around ('x', 'y', or 'z')
579
+ :param angle: Angle in degrees
580
+ :return: Rotated coordinates with same shape as input
536
581
  """
537
582
  angle_rad = np.radians(angle)
538
583
  cos_a = np.cos(angle_rad)
539
584
  sin_a = np.sin(angle_rad)
585
+
586
+ is_2d = coords.shape[1] == 2
540
587
 
541
- if axis == "x":
542
- R = np.array([[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]])
543
- elif axis == "y":
544
- R = np.array([[cos_a, 0, sin_a], [0, 1, 0], [-sin_a, 0, cos_a]])
545
- elif axis == "z":
546
- R = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
588
+ if is_2d:
589
+ # For 2D coordinates, only z-axis rotation applies (rotation in xy plane)
590
+ if axis != 'z':
591
+ raise ValueError("For 2D coordinates, only 'z' axis rotation is supported")
592
+ # 2D rotation matrix
593
+ R = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
547
594
  else:
548
- raise ValueError("Axis must be 'x', 'y', or 'z'.")
595
+ # 3D rotation matrices
596
+ if axis == "x":
597
+ R = np.array([[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]])
598
+ elif axis == "y":
599
+ R = np.array([[cos_a, 0, sin_a], [0, 1, 0], [-sin_a, 0, cos_a]])
600
+ elif axis == "z":
601
+ R = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
602
+ else:
603
+ raise ValueError("Axis must be 'x', 'y', or 'z'.")
549
604
 
550
605
  return coords @ R.T
551
606
 
@@ -36,6 +36,7 @@ class PlotterMixin:
36
36
 
37
37
  SCHEMATIC_Z_AXIS = dict(
38
38
  zeroline=False,
39
+ scaleanchor="x",
39
40
  title="Z (mm)"
40
41
  )
41
42
 
@@ -55,7 +56,8 @@ class PlotterMixin:
55
56
  line_width,
56
57
  color_func,
57
58
  unit_conversion_factor,
58
- order_clockwise: str = None
59
+ order_clockwise: str = None,
60
+ gl: bool = False
59
61
  ):
60
62
  """
61
63
  Create a single trace for a component or group of components with NaN separators.
@@ -115,17 +117,30 @@ class PlotterMixin:
115
117
  z_coords = combined_coords[:, 2] * unit_conversion_factor
116
118
 
117
119
  # Create trace
118
- return go.Scatter(
119
- x=y_coords,
120
- y=z_coords,
121
- mode="lines",
122
- name=name,
123
- line={'width': line_width, 'color': "black"},
124
- fill="toself",
125
- fillcolor=color_func(components[0]),
126
- legendgroup=name,
127
- showlegend=True,
128
- )
120
+ if gl:
121
+ return go.Scattergl(
122
+ x=y_coords,
123
+ y=z_coords,
124
+ mode="lines",
125
+ name=name,
126
+ line={'width': line_width, 'color': "black"},
127
+ fill="toself",
128
+ fillcolor=color_func(components[0]),
129
+ legendgroup=name,
130
+ showlegend=True,
131
+ )
132
+ else:
133
+ return go.Scatter(
134
+ x=y_coords,
135
+ y=z_coords,
136
+ mode="lines",
137
+ name=name,
138
+ line={'width': line_width, 'color': "black"},
139
+ fill="toself",
140
+ fillcolor=color_func(components[0]),
141
+ legendgroup=name,
142
+ showlegend=True,
143
+ )
129
144
 
130
145
  @staticmethod
131
146
  def plot_breakdown_sunburst(
@@ -133,6 +148,7 @@ class PlotterMixin:
133
148
  title: str = "Breakdown",
134
149
  root_label: str = "Total",
135
150
  unit: str = "",
151
+ colorway: List[str] = None,
136
152
  **kwargs,
137
153
  ) -> go.Figure:
138
154
  """
@@ -149,12 +165,22 @@ class PlotterMixin:
149
165
  Label for the root node. Defaults to "Total".
150
166
  unit : str, optional
151
167
  Unit string to display in hover text (e.g., "g", "kg", "%"). Defaults to "".
168
+ colorway : List[str], optional
169
+ List of colors to use for the inner ring. If None, uses Plotly's default colorway.
170
+ Defaults to None.
152
171
 
153
172
  Returns
154
173
  -------
155
174
  go.Figure
156
175
  Plotly sunburst figure
157
176
  """
177
+
178
+ # Default Plotly colorway if none provided
179
+ if colorway is None:
180
+ colorway = [
181
+ '#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
182
+ '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'
183
+ ]
158
184
 
159
185
  def _flatten_breakdown_values(data: Dict[str, Any]) -> List[float]:
160
186
  """Recursively flatten all numeric values from nested breakdown dictionary"""
@@ -177,13 +203,14 @@ class PlotterMixin:
177
203
  return total
178
204
 
179
205
  def _prepare_sunburst_data(
180
- data: Dict[str, Any], parent_id: str = "", current_path: str = ""
181
- ) -> Tuple[List[str], List[str], List[str], List[float]]:
206
+ data: Dict[str, Any], parent_id: str = "", current_path: str = "", depth: int = 1
207
+ ) -> Tuple[List[str], List[str], List[str], List[float], List[int]]:
182
208
  """Recursively prepare data for sunburst plot with proper hierarchy"""
183
209
  ids = []
184
210
  labels = []
185
211
  parents = []
186
212
  values = []
213
+ depths = []
187
214
 
188
215
  for key, value in data.items():
189
216
  # Create unique ID for this node
@@ -192,6 +219,7 @@ class PlotterMixin:
192
219
  ids.append(node_id)
193
220
  labels.append(key)
194
221
  parents.append(parent_id)
222
+ depths.append(depth)
195
223
 
196
224
  if isinstance(value, dict):
197
225
  # This is a nested dictionary - calculate its total value
@@ -204,8 +232,9 @@ class PlotterMixin:
204
232
  nested_labels,
205
233
  nested_parents,
206
234
  nested_values,
235
+ nested_depths,
207
236
  ) = _prepare_sunburst_data(
208
- value, parent_id=node_id, current_path=node_id
237
+ value, parent_id=node_id, current_path=node_id, depth=depth + 1
209
238
  )
210
239
 
211
240
  # Add nested data to our lists
@@ -213,18 +242,19 @@ class PlotterMixin:
213
242
  labels.extend(nested_labels)
214
243
  parents.extend(nested_parents)
215
244
  values.extend(nested_values)
245
+ depths.extend(nested_depths)
216
246
 
217
247
  elif isinstance(value, (int, float)):
218
248
  # This is a leaf node with a numeric value
219
249
  values.append(float(value))
220
250
 
221
- return ids, labels, parents, values
251
+ return ids, labels, parents, values, depths
222
252
 
223
253
  # Calculate total value for root node
224
254
  total_value = _calculate_subtotal(breakdown_dict)
225
255
 
226
256
  # Prepare hierarchical data starting with root
227
- ids, labels, parents, values = _prepare_sunburst_data(
257
+ ids, labels, parents, values, depths = _prepare_sunburst_data(
228
258
  breakdown_dict, parent_id=""
229
259
  )
230
260
 
@@ -233,12 +263,74 @@ class PlotterMixin:
233
263
  labels.insert(0, root_label)
234
264
  parents.insert(0, "")
235
265
  values.insert(0, total_value)
266
+ depths.insert(0, 0)
236
267
 
237
268
  # Update parent references to point to root
238
269
  for i in range(1, len(parents)):
239
270
  if parents[i] == "":
240
271
  parents[i] = root_label
241
272
 
273
+ # Generate colors based on alphabetical ordering and depth
274
+ def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
275
+ """Convert hex color to RGB tuple"""
276
+ hex_color = hex_color.lstrip('#')
277
+ return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
278
+
279
+ def _rgb_to_hex(rgb: Tuple[int, int, int]) -> str:
280
+ """Convert RGB tuple to hex color"""
281
+ return '#{:02x}{:02x}{:02x}'.format(int(rgb[0]), int(rgb[1]), int(rgb[2]))
282
+
283
+ def _lighten_color(hex_color: str, factor: float) -> str:
284
+ """Lighten a color by blending with white. Factor 0=original, 1=white"""
285
+ r, g, b = _hex_to_rgb(hex_color)
286
+ # Blend with white (255, 255, 255)
287
+ r = r + (255 - r) * factor
288
+ g = g + (255 - g) * factor
289
+ b = b + (255 - b) * factor
290
+ return _rgb_to_hex((r, g, b))
291
+
292
+ # Get first-level keys (children of root) and sort alphabetically
293
+ first_level_keys = sorted([key for key in breakdown_dict.keys()])
294
+
295
+ # Assign base colors to first-level keys
296
+ key_to_base_color = {}
297
+ for i, key in enumerate(first_level_keys):
298
+ key_to_base_color[key] = colorway[i % len(colorway)]
299
+
300
+ # Assign colors to all nodes
301
+ marker_colors = []
302
+ max_depth = max(depths) if depths else 0
303
+
304
+ for i, (node_id, label, parent, depth) in enumerate(zip(ids, labels, parents, depths)):
305
+ if depth == 0:
306
+ # Root node - use neutral color
307
+ marker_colors.append('#CCCCCC')
308
+ elif depth == 1:
309
+ # First level - use assigned base color
310
+ marker_colors.append(key_to_base_color[label])
311
+ else:
312
+ # Deeper levels - find the first-level ancestor and lighten its color
313
+ # Trace back through parents to find first-level ancestor
314
+ current_parent = parent
315
+ ancestor_label = None
316
+
317
+ for j, (check_id, check_label, check_depth) in enumerate(zip(ids, labels, depths)):
318
+ if check_id == current_parent:
319
+ if check_depth == 1:
320
+ ancestor_label = check_label
321
+ break
322
+ current_parent = parents[j]
323
+
324
+ if ancestor_label and ancestor_label in key_to_base_color:
325
+ base_color = key_to_base_color[ancestor_label]
326
+ # Lighten based on depth (depth 2 gets 0.3, depth 3 gets 0.5, depth 4 gets 0.7, etc.)
327
+ lighten_factor = 0.2 + (depth - 1) * 0.25
328
+ lighten_factor = min(lighten_factor, 0.85) # Cap at 0.85 to avoid too pale
329
+ marker_colors.append(_lighten_color(base_color, lighten_factor))
330
+ else:
331
+ # Fallback to neutral color
332
+ marker_colors.append('#DDDDDD')
333
+
242
334
  # Create custom hover text with percentages
243
335
  hover_text = []
244
336
  for i, (label, value) in enumerate(zip(labels, values)):
@@ -262,6 +354,7 @@ class PlotterMixin:
262
354
  branchvalues="total",
263
355
  hovertemplate="%{customdata}<extra></extra>",
264
356
  customdata=hover_text,
357
+ marker=dict(colors=marker_colors),
265
358
  )
266
359
  )
267
360
 
@@ -273,3 +366,5 @@ class PlotterMixin:
273
366
 
274
367
 
275
368
 
369
+
370
+
@@ -0,0 +1,281 @@
1
+ import msgpack
2
+ import msgpack_numpy as m
3
+ import zlib
4
+ from typing import TypeVar, Any
5
+ from datetime import datetime
6
+ from enum import Enum
7
+
8
+ T = TypeVar('T', bound='SerializerMixin')
9
+
10
+
11
+ class SerializerMixin:
12
+
13
+ def serialize(self, compress: bool = True) -> bytes:
14
+ """
15
+ Serialize object using MessagePack with numpy support.
16
+
17
+ Parameters
18
+ ----------
19
+ compress : bool, optional
20
+ Whether to compress the serialized data (default: True)
21
+
22
+ Returns
23
+ -------
24
+ bytes
25
+ The serialized byte representation of the object.
26
+ """
27
+ m.patch() # Enable numpy support
28
+ data = msgpack.packb(self._to_dict(), use_bin_type=True)
29
+
30
+ if compress:
31
+ # Add marker byte to indicate compression
32
+ return b'\x01' + zlib.compress(data, level=6)
33
+ else:
34
+ return b'\x00' + data
35
+
36
+ def _serialize_value(self, value: Any) -> Any:
37
+ """
38
+ Recursively serialize a value, handling nested structures.
39
+
40
+ Parameters
41
+ ----------
42
+ value : Any
43
+ The value to serialize.
44
+
45
+ Returns
46
+ -------
47
+ Any
48
+ The serialized representation.
49
+ """
50
+ if hasattr(value, '_to_dict'):
51
+ # Add marker and class info for object reconstruction
52
+ return {
53
+ '__object__': True,
54
+ '_class': f"{value.__class__.__module__}.{value.__class__.__name__}",
55
+ **value._to_dict()
56
+ }
57
+ elif isinstance(value, datetime):
58
+ return {'__datetime__': value.isoformat()}
59
+ elif isinstance(value, Enum):
60
+ return {
61
+ '__enum__': True,
62
+ 'class': f"{value.__class__.__module__}.{value.__class__.__name__}",
63
+ 'value': value.value
64
+ }
65
+ elif isinstance(value, tuple):
66
+ # Recursively serialize tuple items
67
+ return {
68
+ '__tuple__': True,
69
+ 'items': [self._serialize_value(item) for item in value]
70
+ }
71
+ elif isinstance(value, list):
72
+ # Recursively serialize list items
73
+ return [self._serialize_value(item) for item in value]
74
+ elif isinstance(value, dict):
75
+ # Handle dictionaries with object keys or values
76
+ has_object_keys = value and any(hasattr(k, '_to_dict') for k in value.keys())
77
+ has_object_values = value and any(hasattr(v, '_to_dict') for v in value.values())
78
+
79
+ if has_object_keys or has_object_values:
80
+ return {
81
+ '__object_dict__': True,
82
+ 'items': [
83
+ {
84
+ 'key': self._serialize_value(k),
85
+ 'value': self._serialize_value(v)
86
+ }
87
+ for k, v in value.items()
88
+ ]
89
+ }
90
+ else:
91
+ # Recursively serialize regular dict values
92
+ return {k: self._serialize_value(v) for k, v in value.items()}
93
+ else:
94
+ return value
95
+
96
+ def _to_dict(self) -> dict:
97
+ """
98
+ Convert object to dictionary for serialization.
99
+ Override this in subclasses to customize serialization behavior.
100
+
101
+ Returns
102
+ -------
103
+ dict
104
+ Dictionary representation of the object.
105
+ """
106
+ result = {}
107
+ for key, value in self.__dict__.items():
108
+ result[key] = self._serialize_value(value)
109
+ return result
110
+
111
+ @classmethod
112
+ def deserialize(cls: type[T], data: bytes) -> T:
113
+ """
114
+ Deserialize byte data into an object.
115
+ Automatically detects and decompresses if needed.
116
+
117
+ Parameters
118
+ ----------
119
+ data : bytes
120
+ The byte data to deserialize.
121
+
122
+ Returns
123
+ -------
124
+ T
125
+ Instance of the class.
126
+ """
127
+ m.patch() # Enable numpy support
128
+
129
+ # Check compression marker
130
+ if data[0:1] == b'\x01':
131
+ data = zlib.decompress(data[1:])
132
+ else:
133
+ data = data[1:]
134
+
135
+ obj_dict = msgpack.unpackb(data, raw=False)
136
+ return cls._from_dict(obj_dict)
137
+
138
+ @classmethod
139
+ def _deserialize_value(cls, value: Any) -> Any:
140
+ """
141
+ Recursively deserialize a value, handling nested structures.
142
+
143
+ Parameters
144
+ ----------
145
+ value : Any
146
+ The value to deserialize.
147
+
148
+ Returns
149
+ -------
150
+ Any
151
+ The deserialized object.
152
+ """
153
+ if isinstance(value, dict):
154
+ if '__datetime__' in value:
155
+ return datetime.fromisoformat(value['__datetime__'])
156
+ elif '__enum__' in value:
157
+ # Reconstruct enum
158
+ module_name, class_name = value['class'].rsplit('.', 1)
159
+ import importlib
160
+ module = importlib.import_module(module_name)
161
+ enum_class = getattr(module, class_name)
162
+ return enum_class(value['value'])
163
+ elif '__tuple__' in value:
164
+ # Recursively reconstruct tuple items
165
+ return tuple(cls._deserialize_value(item) for item in value['items'])
166
+ elif '__object__' in value:
167
+ # Reconstruct regular object
168
+ import importlib
169
+ module_name, class_name = value['_class'].rsplit('.', 1)
170
+ module = importlib.import_module(module_name)
171
+ obj_class = getattr(module, class_name)
172
+ # Remove marker fields before passing to _from_dict
173
+ obj_data = {k: v for k, v in value.items() if k not in ('__object__', '_class')}
174
+ return obj_class._from_dict(obj_data)
175
+ elif '__object_dict__' in value:
176
+ # Reconstruct dictionary with object keys or values
177
+ reconstructed_dict = {}
178
+ for item in value['items']:
179
+ key_obj = cls._deserialize_value(item['key'])
180
+ value_obj = cls._deserialize_value(item['value'])
181
+ reconstructed_dict[key_obj] = value_obj
182
+ return reconstructed_dict
183
+ else:
184
+ # Recursively deserialize regular dict values
185
+ return {k: cls._deserialize_value(v) for k, v in value.items()}
186
+ elif isinstance(value, list):
187
+ # Recursively deserialize list items
188
+ return [cls._deserialize_value(item) for item in value]
189
+ else:
190
+ return value
191
+
192
+ @classmethod
193
+ def _from_dict(cls: type[T], data: dict) -> T:
194
+ """
195
+ Reconstruct object from dictionary.
196
+ Override in subclasses for custom deserialization.
197
+
198
+ Parameters
199
+ ----------
200
+ data : dict
201
+ Dictionary representation to reconstruct from.
202
+
203
+ Returns
204
+ -------
205
+ T
206
+ Reconstructed object instance.
207
+ """
208
+ obj = cls.__new__(cls)
209
+ reconstructed = {}
210
+ for key, value in data.items():
211
+ reconstructed[key] = cls._deserialize_value(value)
212
+ obj.__dict__.update(reconstructed)
213
+ return obj
214
+
215
+ @classmethod
216
+ def from_database(cls: type[T], name: str, table_name: str = None) -> T:
217
+ """
218
+ Pull object from the database by name.
219
+
220
+ Subclasses must define a '_table_name' class variable (str or list of str)
221
+ unless table_name is explicitly provided.
222
+
223
+ Parameters
224
+ ----------
225
+ name : str
226
+ Name of the object to retrieve.
227
+ table_name : str, optional
228
+ Specific table to search. If provided, '_table_name' is not required.
229
+ If None, uses class's _table_name.
230
+
231
+ Returns
232
+ -------
233
+ T
234
+ Instance of the class.
235
+
236
+ Raises
237
+ ------
238
+ NotImplementedError
239
+ If the subclass doesn't define '_table_name' and table_name is not provided.
240
+ ValueError
241
+ If the object name is not found in any of the tables.
242
+ """
243
+ from steer_core.Data.DataManager import DataManager
244
+
245
+ database = DataManager()
246
+
247
+ # Get list of tables to search
248
+ if table_name:
249
+ tables_to_search = [table_name]
250
+ else:
251
+ # Only check for _table_name if table_name wasn't provided
252
+ if not hasattr(cls, '_table_name'):
253
+ raise NotImplementedError(
254
+ f"{cls.__name__} must define a '_table_name' class variable "
255
+ "or provide 'table_name' argument"
256
+ )
257
+
258
+ if isinstance(cls._table_name, (list, tuple)):
259
+ tables_to_search = cls._table_name
260
+ else:
261
+ tables_to_search = [cls._table_name]
262
+
263
+ # Try each table until found
264
+ for table in tables_to_search:
265
+ available_materials = database.get_unique_values(table, "name")
266
+
267
+ if name in available_materials:
268
+ data = database.get_data(table, condition=f"name = '{name}'")
269
+ serialized_bytes = data["object"].iloc[0]
270
+ return cls.deserialize(serialized_bytes)
271
+
272
+ # Not found in any table
273
+ all_available = []
274
+ for table in tables_to_search:
275
+ all_available.extend(database.get_unique_values(table, "name"))
276
+
277
+ raise ValueError(
278
+ f"'{name}' not found in tables {tables_to_search}. "
279
+ f"Available: {all_available}"
280
+ )
281
+
@@ -208,8 +208,9 @@ class ValidationMixin:
208
208
  ValueError
209
209
  If the value is not a positive float.
210
210
  """
211
+ value = float(value) # Ensure value is float
211
212
  if not isinstance(value, (int, float)):
212
- raise ValueError(f"{name} must be a positive float. Provided: {value}.")
213
+ raise ValueError(f"{name} must be a positive float. Provided: {value} of type {type(value).__name__}.")
213
214
 
214
215
  @staticmethod
215
216
  def validate_positive_int(value: int, name: str) -> None:
@@ -1,7 +1,4 @@
1
- __version__ = "0.1.28"
2
-
3
- # datamanager import
4
- from .DataManager import DataManager
1
+ __version__ = "0.1.32"
5
2
 
6
3
  from .Mixins.Colors import ColorMixin
7
4
  from .Mixins.Coordinates import CoordinateMixin
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: steer-core
3
- Version: 0.1.28
3
+ Version: 0.1.32
4
4
  Summary: Modelling energy storage from cell to site - STEER OpenCell Design
5
5
  Author-email: Nicholas Siemons <nsiemons@stanford.edu>
6
6
  Maintainer-email: Nicholas Siemons <nsiemons@stanford.edu>
@@ -1,6 +1,5 @@
1
1
  README.md
2
2
  pyproject.toml
3
- steer_core/DataManager.py
4
3
  steer_core/__init__.py
5
4
  steer_core.egg-info/PKG-INFO
6
5
  steer_core.egg-info/SOURCES.txt
@@ -12,8 +11,8 @@ steer_core/Constants/Universal.py
12
11
  steer_core/Constants/__init__.py
13
12
  steer_core/ContextManagers/ContextManagers.py
14
13
  steer_core/ContextManagers/__init__.py
14
+ steer_core/Data/DataManager.py
15
15
  steer_core/Data/__init__.py
16
- steer_core/Data/database.db
17
16
  steer_core/Decorators/Coordinates.py
18
17
  steer_core/Decorators/General.py
19
18
  steer_core/Decorators/Objects.py
@@ -1,45 +0,0 @@
1
- import base64
2
- from pickle import loads, dumps
3
- from typing import Type
4
- from copy import deepcopy
5
-
6
-
7
- class SerializerMixin:
8
- def serialize(self) -> str:
9
- """
10
- Serialize an object to a string representation.
11
-
12
- Parameters
13
- ----------
14
- obj : Type
15
- The object to serialize.
16
-
17
- Returns
18
- -------
19
- str
20
- The serialized string representation of the object.
21
- """
22
- pickled = dumps(self)
23
- based = base64.b64encode(pickled).decode("utf-8")
24
- return based
25
-
26
- @staticmethod
27
- def deserialize(String: str) -> Type:
28
- """
29
- Deserialize a string representation into an object.
30
-
31
- Parameters
32
- ----------
33
- String : str
34
- The string representation to deserialize.
35
-
36
- Returns
37
- -------
38
- SerializerMixin
39
- The deserialized object.
40
- """
41
- decoded = base64.b64decode(String.encode("utf-8"))
42
- obj = deepcopy(loads(decoded))
43
- return obj
44
-
45
-
File without changes
File without changes
File without changes