steer-core 0.1.28__tar.gz → 0.1.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {steer_core-0.1.28 → steer_core-0.1.32}/PKG-INFO +1 -1
- {steer_core-0.1.28/steer_core → steer_core-0.1.32/steer_core/Data}/DataManager.py +76 -1
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Coordinates.py +77 -22
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Plotter.py +112 -17
- steer_core-0.1.32/steer_core/Mixins/Serializer.py +281 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/TypeChecker.py +2 -1
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/__init__.py +1 -4
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/PKG-INFO +1 -1
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/SOURCES.txt +1 -2
- steer_core-0.1.28/steer_core/Data/database.db +0 -0
- steer_core-0.1.28/steer_core/Mixins/Serializer.py +0 -45
- {steer_core-0.1.28 → steer_core-0.1.32}/README.md +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/pyproject.toml +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/setup.cfg +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/Units.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/Universal.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Constants/__init__.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/ContextManagers/ContextManagers.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/ContextManagers/__init__.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Data/__init__.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/Coordinates.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/General.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/Objects.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Decorators/__init__.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Colors.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Data.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Dunder.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/__init__.py +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/dependency_links.txt +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/requires.txt +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/top_level.txt +0 -0
- {steer_core-0.1.28 → steer_core-0.1.32}/test/test_validation_mixin.py +0 -0
{steer_core-0.1.28 → steer_core-0.1.32}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: steer-core
-Version: 0.1.28
+Version: 0.1.32
 Summary: Modelling energy storage from cell to site - STEER OpenCell Design
 Author-email: Nicholas Siemons <nsiemons@stanford.edu>
 Maintainer-email: Nicholas Siemons <nsiemons@stanford.edu>
{steer_core-0.1.28/steer_core → steer_core-0.1.32/steer_core/Data}/DataManager.py

@@ -1,15 +1,20 @@
 import sqlite3 as sql
 from pathlib import Path
+from typing import TypeVar
 import pandas as pd
 import importlib.resources

 from steer_core.Constants.Units import *
+from steer_core.Mixins.Serializer import SerializerMixin
+
+
+T = TypeVar('T', bound='SerializerMixin')


 class DataManager:

     def __init__(self):
-        with importlib.resources.path("
+        with importlib.resources.path("steer_opencell_data", "database.db") as db_path:
             self._db_path = db_path
             self._connection = sql.connect(self._db_path)
             self._cursor = self._connection.cursor()

@@ -356,5 +361,75 @@ class DataManager:
         self._cursor.execute(f"DELETE FROM {table_name} WHERE {condition}")
         self._connection.commit()

+
+    @classmethod
+    def from_database(cls: type[T], name: str, table_name: str = None) -> T:
+        """
+        Pull object from the database by name.
+
+        Subclasses must define a '_table_name' class variable (str or list of str)
+        unless table_name is explicitly provided.
+
+        Parameters
+        ----------
+        name : str
+            Name of the object to retrieve.
+        table_name : str, optional
+            Specific table to search. If provided, '_table_name' is not required.
+            If None, uses class's _table_name.
+
+        Returns
+        -------
+        T
+            Instance of the class.
+
+        Raises
+        ------
+        NotImplementedError
+            If the subclass doesn't define '_table_name' and table_name is not provided.
+        ValueError
+            If the object name is not found in any of the tables.
+        """
+        database = cls()
+
+        # Get list of tables to search
+        if table_name:
+            tables_to_search = [table_name]
+        else:
+            # Only check for _table_name if table_name wasn't provided
+            if not hasattr(cls, '_table_name'):
+                raise NotImplementedError(
+                    f"{cls.__name__} must define a '_table_name' class variable "
+                    "or provide 'table_name' argument"
+                )
+
+            if isinstance(cls._table_name, (list, tuple)):
+                tables_to_search = cls._table_name
+            else:
+                tables_to_search = [cls._table_name]
+
+        # Try each table until found
+        for table in tables_to_search:
+            available_materials = database.get_unique_values(table, "name")
+
+            if name in available_materials:
+                data = database.get_data(table, condition=f"name = '{name}'")
+                serialized_bytes = data["object"].iloc[0]
+                return cls.deserialize(serialized_bytes)
+
+        # Not found in any table
+        all_available = []
+        for table in tables_to_search:
+            all_available.extend(database.get_unique_values(table, "name"))
+
+        raise ValueError(
+            f"'{name}' not found in tables {tables_to_search}. "
+            f"Available: {all_available}"
+        )
+
+
     def __del__(self):
         self._connection.close()
+
+
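The new `from_database` classmethod wraps a lookup that can also be done by hand with the existing `DataManager` query helpers against the database now resolved from the `steer_opencell_data` package. A minimal sketch of that flow, assuming a table called "anodes" and an entry named "Graphite" exist (both names are illustrative, not taken from this diff):

```python
from steer_core.Data.DataManager import DataManager

db = DataManager()

# Each table is expected to have a "name" column and an "object" column
# holding bytes produced by SerializerMixin.serialize().
names = db.get_unique_values("anodes", "name")          # hypothetical table
if "Graphite" in names:                                  # hypothetical entry
    row = db.get_data("anodes", condition="name = 'Graphite'")
    blob = row["object"].iloc[0]
    # A class mixing in SerializerMixin would rebuild the object with:
    # obj = SomeComponent.deserialize(blob)
```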
{steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Coordinates.py

@@ -136,25 +136,39 @@ class CoordinateMixin:
         center: tuple = None
     ) -> np.ndarray:
         """
-        Rotate a
+        Rotate a NumPy array of coordinates around the specified axis.
+        Can handle 2D coordinates (N, 2) for x, y or 3D coordinates (N, 3) for x, y, z.
         Can handle coordinates with None values (preserves None positions).

-        :param coords: NumPy array of shape (N,
-        :param axis: Axis to rotate around ('x', 'y', or 'z')
+        :param coords: NumPy array of shape (N, 2) for 2D or (N, 3) for 3D coordinates
+        :param axis: Axis to rotate around ('x', 'y', or 'z'). For 2D arrays, only 'z' is valid.
         :param angle: Angle in degrees
-        :param center: Point to rotate around
-
+        :param center: Point to rotate around. For 2D: (x, y) tuple. For 3D: (x, y, z) tuple.
+            If None, rotates around origin.
+        :return: Rotated NumPy array with same shape as input
         """
-        if
+        # Check if 2D or 3D coordinates
+        is_2d = coords.shape[1] == 2
+        is_3d = coords.shape[1] == 3
+
+        if not (is_2d or is_3d):
+            raise ValueError(
+                "Input array must have shape (N, 2) for 2D or (N, 3) for 3D coordinates"
+            )
+
+        # For 2D arrays, only z-axis rotation is valid
+        if is_2d and axis != 'z':
             raise ValueError(
-                "
+                "For 2D coordinates (x, y), only 'z' axis rotation is supported"
             )

         # Validate center parameter
         if center is not None:
-
+            expected_len = 2 if is_2d else 3
+            if not isinstance(center, (tuple, list)) or len(center) != expected_len:
+                coord_type = "(x, y)" if is_2d else "(x, y, z)"
                 raise ValueError(
-                    "Center must be a tuple or list of
+                    f"Center must be a tuple or list of {expected_len} coordinates {coord_type}"
                 )
             if not all(isinstance(coord, (int, float)) for coord in center):
                 raise TypeError("All center coordinates must be numbers")

@@ -204,10 +218,11 @@ class CoordinateMixin:
         """
         Rotate coordinates around a specified center point.

-        :param coords: NumPy array of shape (N, 3) with valid coordinates
+        :param coords: NumPy array of shape (N, 2) or (N, 3) with valid coordinates
         :param axis: Axis to rotate around ('x', 'y', or 'z')
         :param angle: Angle in degrees
-        :param center: Center point as np.array
+        :param center: Center point as np.array. Shape (2,) for 2D or (3,) for 3D.
+            If None, rotates around origin.
         :return: Rotated coordinates
         """
         if center is None:

@@ -294,13 +309,37 @@ class CoordinateMixin:
     @staticmethod
     def order_coordinates_clockwise(df: pd.DataFrame, plane="xy") -> pd.DataFrame:

+        df = df.copy()
+
         axis_1 = plane[0]
         axis_2 = plane[1]

-
-
-
-
+        # Find column names that match the axis (case-insensitive, with or without units)
+        def find_column(axis_char: str) -> str:
+            axis_char_lower = axis_char.lower()
+            # First try exact match
+            if axis_char in df.columns:
+                return axis_char
+            # Try lowercase
+            if axis_char_lower in df.columns:
+                return axis_char_lower
+            # Try uppercase
+            if axis_char.upper() in df.columns:
+                return axis_char.upper()
+            # Try with units pattern like "X (mm)", "x (mm)", etc.
+            for col in df.columns:
+                col_stripped = col.split()[0].lower() if ' ' in col else col.lower()
+                if col_stripped == axis_char_lower:
+                    return col
+            raise KeyError(f"Could not find column for axis '{axis_char}' in dataframe columns: {list(df.columns)}")
+
+        axis_1_col = find_column(axis_1)
+        axis_2_col = find_column(axis_2)
+
+        cx = df[axis_1_col].mean()
+        cy = df[axis_2_col].mean()
+
+        angles = np.arctan2(df[axis_2_col] - cy, df[axis_1_col] - cx)

         df["angle"] = angles

@@ -533,19 +572,35 @@ class CoordinateMixin:
     ) -> np.ndarray:
         """
         Rotate coordinates without None values using rotation matrices.
+        Handles both 2D (N, 2) and 3D (N, 3) coordinate arrays.
+
+        :param coords: NumPy array of shape (N, 2) or (N, 3)
+        :param axis: Axis to rotate around ('x', 'y', or 'z')
+        :param angle: Angle in degrees
+        :return: Rotated coordinates with same shape as input
         """
         angle_rad = np.radians(angle)
         cos_a = np.cos(angle_rad)
         sin_a = np.sin(angle_rad)
+
+        is_2d = coords.shape[1] == 2

-        if
-
-
-
-
-        R = np.array([[cos_a, -sin_a
+        if is_2d:
+            # For 2D coordinates, only z-axis rotation applies (rotation in xy plane)
+            if axis != 'z':
+                raise ValueError("For 2D coordinates, only 'z' axis rotation is supported")
+            # 2D rotation matrix
+            R = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
         else:
-
+            # 3D rotation matrices
+            if axis == "x":
+                R = np.array([[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]])
+            elif axis == "y":
+                R = np.array([[cos_a, 0, sin_a], [0, 1, 0], [-sin_a, 0, cos_a]])
+            elif axis == "z":
+                R = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
+            else:
+                raise ValueError("Axis must be 'x', 'y', or 'z'.")

         return coords @ R.T
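As a quick illustration of the new 2D branch above, the same rotation-matrix math can be checked in isolation (a sketch only; the square and the direct matrix use below are illustrative, not part of the package API):

```python
import numpy as np

# (N, 2) coordinates: with the updated helpers, only axis='z' is accepted for 2D input.
square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])

angle_rad = np.radians(90)
cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)

# Same 2D rotation matrix the diff introduces for the is_2d branch.
R = np.array([[cos_a, -sin_a], [sin_a, cos_a]])

rotated = square @ R.T   # shape (4, 2) preserved; rotated 90 degrees about the origin
```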
{steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/Plotter.py

@@ -36,6 +36,7 @@ class PlotterMixin:

     SCHEMATIC_Z_AXIS = dict(
         zeroline=False,
+        scaleanchor="x",
         title="Z (mm)"
     )

@@ -55,7 +56,8 @@ class PlotterMixin:
         line_width,
         color_func,
         unit_conversion_factor,
-        order_clockwise: str = None
+        order_clockwise: str = None,
+        gl: bool = False
     ):
         """
         Create a single trace for a component or group of components with NaN separators.

@@ -115,17 +117,30 @@ class PlotterMixin:
         z_coords = combined_coords[:, 2] * unit_conversion_factor

         # Create trace
-
-
-
-
-
-
-
-
-
-
-
+        if gl:
+            return go.Scattergl(
+                x=y_coords,
+                y=z_coords,
+                mode="lines",
+                name=name,
+                line={'width': line_width, 'color': "black"},
+                fill="toself",
+                fillcolor=color_func(components[0]),
+                legendgroup=name,
+                showlegend=True,
+            )
+        else:
+            return go.Scatter(
+                x=y_coords,
+                y=z_coords,
+                mode="lines",
+                name=name,
+                line={'width': line_width, 'color': "black"},
+                fill="toself",
+                fillcolor=color_func(components[0]),
+                legendgroup=name,
+                showlegend=True,
+            )

     @staticmethod
     def plot_breakdown_sunburst(

@@ -133,6 +148,7 @@ class PlotterMixin:
         title: str = "Breakdown",
         root_label: str = "Total",
         unit: str = "",
+        colorway: List[str] = None,
         **kwargs,
     ) -> go.Figure:
         """

@@ -149,12 +165,22 @@ class PlotterMixin:
             Label for the root node. Defaults to "Total".
         unit : str, optional
             Unit string to display in hover text (e.g., "g", "kg", "%"). Defaults to "".
+        colorway : List[str], optional
+            List of colors to use for the inner ring. If None, uses Plotly's default colorway.
+            Defaults to None.

         Returns
         -------
         go.Figure
             Plotly sunburst figure
         """
+
+        # Default Plotly colorway if none provided
+        if colorway is None:
+            colorway = [
+                '#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
+                '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'
+            ]

         def _flatten_breakdown_values(data: Dict[str, Any]) -> List[float]:
             """Recursively flatten all numeric values from nested breakdown dictionary"""

@@ -177,13 +203,14 @@ class PlotterMixin:
             return total

         def _prepare_sunburst_data(
-            data: Dict[str, Any], parent_id: str = "", current_path: str = ""
-        ) -> Tuple[List[str], List[str], List[str], List[float]]:
+            data: Dict[str, Any], parent_id: str = "", current_path: str = "", depth: int = 1
+        ) -> Tuple[List[str], List[str], List[str], List[float], List[int]]:
             """Recursively prepare data for sunburst plot with proper hierarchy"""
             ids = []
             labels = []
             parents = []
             values = []
+            depths = []

             for key, value in data.items():
                 # Create unique ID for this node

@@ -192,6 +219,7 @@ class PlotterMixin:
                 ids.append(node_id)
                 labels.append(key)
                 parents.append(parent_id)
+                depths.append(depth)

                 if isinstance(value, dict):
                     # This is a nested dictionary - calculate its total value

@@ -204,8 +232,9 @@ class PlotterMixin:
                         nested_labels,
                         nested_parents,
                         nested_values,
+                        nested_depths,
                     ) = _prepare_sunburst_data(
-                        value, parent_id=node_id, current_path=node_id
+                        value, parent_id=node_id, current_path=node_id, depth=depth + 1
                     )

                     # Add nested data to our lists

@@ -213,18 +242,19 @@ class PlotterMixin:
                     labels.extend(nested_labels)
                     parents.extend(nested_parents)
                     values.extend(nested_values)
+                    depths.extend(nested_depths)

                 elif isinstance(value, (int, float)):
                     # This is a leaf node with a numeric value
                     values.append(float(value))

-            return ids, labels, parents, values
+            return ids, labels, parents, values, depths

         # Calculate total value for root node
         total_value = _calculate_subtotal(breakdown_dict)

         # Prepare hierarchical data starting with root
-        ids, labels, parents, values = _prepare_sunburst_data(
+        ids, labels, parents, values, depths = _prepare_sunburst_data(
             breakdown_dict, parent_id=""
         )

@@ -233,12 +263,74 @@ class PlotterMixin:
         labels.insert(0, root_label)
         parents.insert(0, "")
         values.insert(0, total_value)
+        depths.insert(0, 0)

         # Update parent references to point to root
         for i in range(1, len(parents)):
             if parents[i] == "":
                 parents[i] = root_label

+        # Generate colors based on alphabetical ordering and depth
+        def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
+            """Convert hex color to RGB tuple"""
+            hex_color = hex_color.lstrip('#')
+            return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
+
+        def _rgb_to_hex(rgb: Tuple[int, int, int]) -> str:
+            """Convert RGB tuple to hex color"""
+            return '#{:02x}{:02x}{:02x}'.format(int(rgb[0]), int(rgb[1]), int(rgb[2]))
+
+        def _lighten_color(hex_color: str, factor: float) -> str:
+            """Lighten a color by blending with white. Factor 0=original, 1=white"""
+            r, g, b = _hex_to_rgb(hex_color)
+            # Blend with white (255, 255, 255)
+            r = r + (255 - r) * factor
+            g = g + (255 - g) * factor
+            b = b + (255 - b) * factor
+            return _rgb_to_hex((r, g, b))
+
+        # Get first-level keys (children of root) and sort alphabetically
+        first_level_keys = sorted([key for key in breakdown_dict.keys()])
+
+        # Assign base colors to first-level keys
+        key_to_base_color = {}
+        for i, key in enumerate(first_level_keys):
+            key_to_base_color[key] = colorway[i % len(colorway)]
+
+        # Assign colors to all nodes
+        marker_colors = []
+        max_depth = max(depths) if depths else 0
+
+        for i, (node_id, label, parent, depth) in enumerate(zip(ids, labels, parents, depths)):
+            if depth == 0:
+                # Root node - use neutral color
+                marker_colors.append('#CCCCCC')
+            elif depth == 1:
+                # First level - use assigned base color
+                marker_colors.append(key_to_base_color[label])
+            else:
+                # Deeper levels - find the first-level ancestor and lighten its color
+                # Trace back through parents to find first-level ancestor
+                current_parent = parent
+                ancestor_label = None
+
+                for j, (check_id, check_label, check_depth) in enumerate(zip(ids, labels, depths)):
+                    if check_id == current_parent:
+                        if check_depth == 1:
+                            ancestor_label = check_label
+                            break
+                        current_parent = parents[j]
+
+                if ancestor_label and ancestor_label in key_to_base_color:
+                    base_color = key_to_base_color[ancestor_label]
+                    # Lighten based on depth (depth 2 gets 0.3, depth 3 gets 0.5, depth 4 gets 0.7, etc.)
+                    lighten_factor = 0.2 + (depth - 1) * 0.25
+                    lighten_factor = min(lighten_factor, 0.85)  # Cap at 0.85 to avoid too pale
+                    marker_colors.append(_lighten_color(base_color, lighten_factor))
+                else:
+                    # Fallback to neutral color
+                    marker_colors.append('#DDDDDD')
+
         # Create custom hover text with percentages
         hover_text = []
         for i, (label, value) in enumerate(zip(labels, values)):

@@ -262,6 +354,7 @@ class PlotterMixin:
                 branchvalues="total",
                 hovertemplate="%{customdata}<extra></extra>",
                 customdata=hover_text,
+                marker=dict(colors=marker_colors),
             )
         )

@@ -273,3 +366,5 @@



+
+
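A short sketch of how the reworked sunburst colouring might be driven, assuming `plot_breakdown_sunburst` takes the nested breakdown dictionary as its first positional argument (the breakdown values and the three-colour `colorway` below are made up for illustration):

```python
from steer_core.Mixins.Plotter import PlotterMixin

# First-level keys get colorway colours; deeper levels inherit a progressively
# lightened shade of their first-level ancestor's colour.
breakdown = {
    "Cathode": {"Active material": 42.0, "Binder": 3.0, "Carbon": 2.0},
    "Anode": {"Active material": 30.0, "Binder": 2.5},
    "Electrolyte": 15.0,
}

fig = PlotterMixin.plot_breakdown_sunburst(
    breakdown,
    title="Mass breakdown",
    unit="g",
    colorway=["#636EFA", "#EF553B", "#00CC96"],  # optional; defaults to Plotly's palette
)
fig.show()
```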
steer_core-0.1.32/steer_core/Mixins/Serializer.py

@@ -0,0 +1,281 @@
+import msgpack
+import msgpack_numpy as m
+import zlib
+from typing import TypeVar, Any
+from datetime import datetime
+from enum import Enum
+
+T = TypeVar('T', bound='SerializerMixin')
+
+
+class SerializerMixin:
+
+    def serialize(self, compress: bool = True) -> bytes:
+        """
+        Serialize object using MessagePack with numpy support.
+
+        Parameters
+        ----------
+        compress : bool, optional
+            Whether to compress the serialized data (default: True)
+
+        Returns
+        -------
+        bytes
+            The serialized byte representation of the object.
+        """
+        m.patch()  # Enable numpy support
+        data = msgpack.packb(self._to_dict(), use_bin_type=True)
+
+        if compress:
+            # Add marker byte to indicate compression
+            return b'\x01' + zlib.compress(data, level=6)
+        else:
+            return b'\x00' + data
+
+    def _serialize_value(self, value: Any) -> Any:
+        """
+        Recursively serialize a value, handling nested structures.
+
+        Parameters
+        ----------
+        value : Any
+            The value to serialize.
+
+        Returns
+        -------
+        Any
+            The serialized representation.
+        """
+        if hasattr(value, '_to_dict'):
+            # Add marker and class info for object reconstruction
+            return {
+                '__object__': True,
+                '_class': f"{value.__class__.__module__}.{value.__class__.__name__}",
+                **value._to_dict()
+            }
+        elif isinstance(value, datetime):
+            return {'__datetime__': value.isoformat()}
+        elif isinstance(value, Enum):
+            return {
+                '__enum__': True,
+                'class': f"{value.__class__.__module__}.{value.__class__.__name__}",
+                'value': value.value
+            }
+        elif isinstance(value, tuple):
+            # Recursively serialize tuple items
+            return {
+                '__tuple__': True,
+                'items': [self._serialize_value(item) for item in value]
+            }
+        elif isinstance(value, list):
+            # Recursively serialize list items
+            return [self._serialize_value(item) for item in value]
+        elif isinstance(value, dict):
+            # Handle dictionaries with object keys or values
+            has_object_keys = value and any(hasattr(k, '_to_dict') for k in value.keys())
+            has_object_values = value and any(hasattr(v, '_to_dict') for v in value.values())
+
+            if has_object_keys or has_object_values:
+                return {
+                    '__object_dict__': True,
+                    'items': [
+                        {
+                            'key': self._serialize_value(k),
+                            'value': self._serialize_value(v)
+                        }
+                        for k, v in value.items()
+                    ]
+                }
+            else:
+                # Recursively serialize regular dict values
+                return {k: self._serialize_value(v) for k, v in value.items()}
+        else:
+            return value
+
+    def _to_dict(self) -> dict:
+        """
+        Convert object to dictionary for serialization.
+        Override this in subclasses to customize serialization behavior.
+
+        Returns
+        -------
+        dict
+            Dictionary representation of the object.
+        """
+        result = {}
+        for key, value in self.__dict__.items():
+            result[key] = self._serialize_value(value)
+        return result
+
+    @classmethod
+    def deserialize(cls: type[T], data: bytes) -> T:
+        """
+        Deserialize byte data into an object.
+        Automatically detects and decompresses if needed.
+
+        Parameters
+        ----------
+        data : bytes
+            The byte data to deserialize.
+
+        Returns
+        -------
+        T
+            Instance of the class.
+        """
+        m.patch()  # Enable numpy support
+
+        # Check compression marker
+        if data[0:1] == b'\x01':
+            data = zlib.decompress(data[1:])
+        else:
+            data = data[1:]
+
+        obj_dict = msgpack.unpackb(data, raw=False)
+        return cls._from_dict(obj_dict)
+
+    @classmethod
+    def _deserialize_value(cls, value: Any) -> Any:
+        """
+        Recursively deserialize a value, handling nested structures.
+
+        Parameters
+        ----------
+        value : Any
+            The value to deserialize.
+
+        Returns
+        -------
+        Any
+            The deserialized object.
+        """
+        if isinstance(value, dict):
+            if '__datetime__' in value:
+                return datetime.fromisoformat(value['__datetime__'])
+            elif '__enum__' in value:
+                # Reconstruct enum
+                module_name, class_name = value['class'].rsplit('.', 1)
+                import importlib
+                module = importlib.import_module(module_name)
+                enum_class = getattr(module, class_name)
+                return enum_class(value['value'])
+            elif '__tuple__' in value:
+                # Recursively reconstruct tuple items
+                return tuple(cls._deserialize_value(item) for item in value['items'])
+            elif '__object__' in value:
+                # Reconstruct regular object
+                import importlib
+                module_name, class_name = value['_class'].rsplit('.', 1)
+                module = importlib.import_module(module_name)
+                obj_class = getattr(module, class_name)
+                # Remove marker fields before passing to _from_dict
+                obj_data = {k: v for k, v in value.items() if k not in ('__object__', '_class')}
+                return obj_class._from_dict(obj_data)
+            elif '__object_dict__' in value:
+                # Reconstruct dictionary with object keys or values
+                reconstructed_dict = {}
+                for item in value['items']:
+                    key_obj = cls._deserialize_value(item['key'])
+                    value_obj = cls._deserialize_value(item['value'])
+                    reconstructed_dict[key_obj] = value_obj
+                return reconstructed_dict
+            else:
+                # Recursively deserialize regular dict values
+                return {k: cls._deserialize_value(v) for k, v in value.items()}
+        elif isinstance(value, list):
+            # Recursively deserialize list items
+            return [cls._deserialize_value(item) for item in value]
+        else:
+            return value
+
+    @classmethod
+    def _from_dict(cls: type[T], data: dict) -> T:
+        """
+        Reconstruct object from dictionary.
+        Override in subclasses for custom deserialization.
+
+        Parameters
+        ----------
+        data : dict
+            Dictionary representation to reconstruct from.
+
+        Returns
+        -------
+        T
+            Reconstructed object instance.
+        """
+        obj = cls.__new__(cls)
+        reconstructed = {}
+        for key, value in data.items():
+            reconstructed[key] = cls._deserialize_value(value)
+        obj.__dict__.update(reconstructed)
+        return obj
+
+    @classmethod
+    def from_database(cls: type[T], name: str, table_name: str = None) -> T:
+        """
+        Pull object from the database by name.
+
+        Subclasses must define a '_table_name' class variable (str or list of str)
+        unless table_name is explicitly provided.
+
+        Parameters
+        ----------
+        name : str
+            Name of the object to retrieve.
+        table_name : str, optional
+            Specific table to search. If provided, '_table_name' is not required.
+            If None, uses class's _table_name.
+
+        Returns
+        -------
+        T
+            Instance of the class.
+
+        Raises
+        ------
+        NotImplementedError
+            If the subclass doesn't define '_table_name' and table_name is not provided.
+        ValueError
+            If the object name is not found in any of the tables.
+        """
+        from steer_core.Data.DataManager import DataManager
+
+        database = DataManager()
+
+        # Get list of tables to search
+        if table_name:
+            tables_to_search = [table_name]
+        else:
+            # Only check for _table_name if table_name wasn't provided
+            if not hasattr(cls, '_table_name'):
+                raise NotImplementedError(
+                    f"{cls.__name__} must define a '_table_name' class variable "
+                    "or provide 'table_name' argument"
+                )
+
+            if isinstance(cls._table_name, (list, tuple)):
+                tables_to_search = cls._table_name
+            else:
+                tables_to_search = [cls._table_name]
+
+        # Try each table until found
+        for table in tables_to_search:
+            available_materials = database.get_unique_values(table, "name")
+
+            if name in available_materials:
+                data = database.get_data(table, condition=f"name = '{name}'")
+                serialized_bytes = data["object"].iloc[0]
+                return cls.deserialize(serialized_bytes)
+
+        # Not found in any table
+        all_available = []
+        for table in tables_to_search:
+            all_available.extend(database.get_unique_values(table, "name"))
+
+        raise ValueError(
+            f"'{name}' not found in tables {tables_to_search}. "
+            f"Available: {all_available}"
+        )
+
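A minimal round-trip sketch for the new msgpack-based mixin, assuming a small class whose attributes are plain Python values and NumPy arrays (the `Cell` class below is hypothetical, not part of the package):

```python
import numpy as np
from steer_core.Mixins.Serializer import SerializerMixin


class Cell(SerializerMixin):  # hypothetical example class
    def __init__(self, name: str, capacity_ah: float, voltage_curve: np.ndarray):
        self.name = name
        self.capacity_ah = capacity_ah
        self.voltage_curve = voltage_curve


cell = Cell("demo", 5.0, np.linspace(3.0, 4.2, 5))

blob = cell.serialize()               # b'\x01' marker + zlib-compressed msgpack
raw = cell.serialize(compress=False)  # b'\x00' marker + uncompressed msgpack

restored = Cell.deserialize(blob)     # rebuilt via __new__ and __dict__ update
assert restored.name == "demo"
assert np.allclose(restored.voltage_curve, cell.voltage_curve)
```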
{steer_core-0.1.28 → steer_core-0.1.32}/steer_core/Mixins/TypeChecker.py

@@ -208,8 +208,9 @@ class ValidationMixin:
         ValueError
             If the value is not a positive float.
         """
+        value = float(value) # Ensure value is float
         if not isinstance(value, (int, float)):
-            raise ValueError(f"{name} must be a positive float. Provided: {value}.")
+            raise ValueError(f"{name} must be a positive float. Provided: {value} of type {type(value).__name__}.")

     @staticmethod
     def validate_positive_int(value: int, name: str) -> None:
{steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: steer-core
-Version: 0.1.28
+Version: 0.1.32
 Summary: Modelling energy storage from cell to site - STEER OpenCell Design
 Author-email: Nicholas Siemons <nsiemons@stanford.edu>
 Maintainer-email: Nicholas Siemons <nsiemons@stanford.edu>
{steer_core-0.1.28 → steer_core-0.1.32}/steer_core.egg-info/SOURCES.txt

@@ -1,6 +1,5 @@
 README.md
 pyproject.toml
-steer_core/DataManager.py
 steer_core/__init__.py
 steer_core.egg-info/PKG-INFO
 steer_core.egg-info/SOURCES.txt

@@ -12,8 +11,8 @@ steer_core/Constants/Universal.py
 steer_core/Constants/__init__.py
 steer_core/ContextManagers/ContextManagers.py
 steer_core/ContextManagers/__init__.py
+steer_core/Data/DataManager.py
 steer_core/Data/__init__.py
-steer_core/Data/database.db
 steer_core/Decorators/Coordinates.py
 steer_core/Decorators/General.py
 steer_core/Decorators/Objects.py
steer_core-0.1.28/steer_core/Data/database.db

Binary file
steer_core-0.1.28/steer_core/Mixins/Serializer.py

@@ -1,45 +0,0 @@
-import base64
-from pickle import loads, dumps
-from typing import Type
-from copy import deepcopy
-
-
-class SerializerMixin:
-    def serialize(self) -> str:
-        """
-        Serialize an object to a string representation.
-
-        Parameters
-        ----------
-        obj : Type
-            The object to serialize.
-
-        Returns
-        -------
-        str
-            The serialized string representation of the object.
-        """
-        pickled = dumps(self)
-        based = base64.b64encode(pickled).decode("utf-8")
-        return based
-
-    @staticmethod
-    def deserialize(String: str) -> Type:
-        """
-        Deserialize a string representation into an object.
-
-        Parameters
-        ----------
-        String : str
-            The string representation to deserialize.
-
-        Returns
-        -------
-        SerializerMixin
-            The deserialized object.
-        """
-        decoded = base64.b64decode(String.encode("utf-8"))
-        obj = deepcopy(loads(decoded))
-        return obj
-
-