pyconvexity 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyconvexity might be problematic. Click here for more details.

@@ -0,0 +1,383 @@
1
+ """
2
+ Attribute management operations for PyConvexity.
3
+
4
+ Provides operations for setting, getting, and managing component attributes
5
+ with support for both static values and timeseries data.
6
+ """
7
+
8
+ import sqlite3
9
+ import json
10
+ import logging
11
+ from typing import Dict, Any, Optional, List
12
+ import pandas as pd
13
+ from io import BytesIO
14
+ import pyarrow as pa
15
+ import pyarrow.parquet as pq
16
+
17
+ from pyconvexity.core.types import (
18
+ StaticValue, TimeseriesPoint, AttributeValue, TimePeriod
19
+ )
20
+ from pyconvexity.core.errors import (
21
+ ComponentNotFound, AttributeNotFound, ValidationError, TimeseriesError
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
def set_static_attribute(
    conn: sqlite3.Connection,
    component_id: int,
    attribute_name: str,
    value: StaticValue,
    scenario_id: Optional[int] = None
) -> None:
    """
    Set a static attribute value for a component in a specific scenario.

    Args:
        conn: Database connection
        component_id: Component ID
        attribute_name: Name of the attribute
        value: Static value to set
        scenario_id: Scenario ID (uses master scenario if None)

    Raises:
        ComponentNotFound: If component doesn't exist
        ValidationError: If attribute doesn't allow static values or validation fails
    """
    from pyconvexity.models.components import get_component_type
    from pyconvexity.validation.rules import get_validation_rule, validate_static_value

    # Look up the component's type and the validation rule for this attribute.
    component_type = get_component_type(conn, component_id)
    rule = get_validation_rule(conn, component_type, attribute_name)

    if not rule.allows_static:
        raise ValidationError(f"Attribute '{attribute_name}' for {component_type} does not allow static values")

    # Type-check the value against the rule before touching the database.
    validate_static_value(value, rule)

    # None means "write to the network's master scenario".
    target_scenario = resolve_scenario_id(conn, component_id, scenario_id)

    cursor = conn.cursor()
    # Replace semantics: drop any existing row for this (component, attribute, scenario).
    cursor.execute(
        "DELETE FROM component_attributes WHERE component_id = ? AND attribute_name = ? AND scenario_id = ?",
        (component_id, attribute_name, target_scenario),
    )

    # Static values are stored as JSON text in the static_value column.
    cursor.execute(
        """INSERT INTO component_attributes
           (component_id, attribute_name, scenario_id, storage_type, static_value, data_type, unit, is_input)
           VALUES (?, ?, ?, 'static', ?, ?, ?, ?)""",
        (component_id, attribute_name, target_scenario, value.to_json(),
         rule.data_type, rule.unit, rule.is_input),
    )
83
+
84
+
85
def set_timeseries_attribute(
    conn: sqlite3.Connection,
    component_id: int,
    attribute_name: str,
    timeseries: List[TimeseriesPoint],
    scenario_id: Optional[int] = None
) -> None:
    """
    Set a timeseries attribute value for a component in a specific scenario.

    Args:
        conn: Database connection
        component_id: Component ID
        attribute_name: Name of the attribute
        timeseries: List of timeseries points
        scenario_id: Scenario ID (uses master scenario if None)

    Raises:
        ComponentNotFound: If component doesn't exist
        ValidationError: If attribute doesn't allow timeseries values
        TimeseriesError: If timeseries serialization fails
    """
    from pyconvexity.models.components import get_component_type
    from pyconvexity.validation.rules import get_validation_rule

    # Look up the component's type and the validation rule for this attribute.
    component_type = get_component_type(conn, component_id)
    rule = get_validation_rule(conn, component_type, attribute_name)

    if not rule.allows_timeseries:
        raise ValidationError(f"Attribute '{attribute_name}' for {component_type} does not allow timeseries values")

    # Serialize before touching the database so a serialization failure
    # leaves any existing stored attribute intact.
    parquet_data = serialize_timeseries_to_parquet(timeseries)

    # None means "write to the network's master scenario".
    target_scenario = resolve_scenario_id(conn, component_id, scenario_id)

    cursor = conn.cursor()
    # Replace semantics: drop any existing row for this (component, attribute, scenario).
    cursor.execute(
        "DELETE FROM component_attributes WHERE component_id = ? AND attribute_name = ? AND scenario_id = ?",
        (component_id, attribute_name, target_scenario),
    )

    cursor.execute(
        """INSERT INTO component_attributes
           (component_id, attribute_name, scenario_id, storage_type, timeseries_data, data_type, unit, is_input)
           VALUES (?, ?, ?, 'timeseries', ?, ?, ?, ?)""",
        (component_id, attribute_name, target_scenario, parquet_data,
         rule.data_type, rule.unit, rule.is_input),
    )
140
+
141
+
142
def get_attribute(
    conn: sqlite3.Connection,
    component_id: int,
    attribute_name: str,
    scenario_id: Optional[int] = None
) -> AttributeValue:
    """
    Get an attribute value with scenario fallback logic.

    Looks in the requested scenario first; if the attribute is not defined
    there, falls back to the network's master scenario.

    Args:
        conn: Database connection
        component_id: Component ID
        attribute_name: Name of the attribute
        scenario_id: Scenario ID (uses master scenario if None)

    Returns:
        AttributeValue containing either static or timeseries data

    Raises:
        ComponentNotFound: If component doesn't exist
        AttributeNotFound: If attribute doesn't exist
        ValidationError: If the stored value is missing or malformed
    """
    # Get network_id from component to find the master scenario.
    cursor = conn.cursor()
    cursor.execute("SELECT network_id FROM components WHERE id = ?", (component_id,))
    result = cursor.fetchone()
    if not result:
        raise ComponentNotFound(component_id)
    network_id = result[0]

    master_scenario_id = get_master_scenario_id(conn, network_id)

    # Determine which scenario to check first.
    current_scenario_id = scenario_id if scenario_id is not None else master_scenario_id

    query = """SELECT storage_type, static_value, timeseries_data, data_type, unit
               FROM component_attributes
               WHERE component_id = ? AND attribute_name = ? AND scenario_id = ?"""

    cursor.execute(query, (component_id, attribute_name, current_scenario_id))
    result = cursor.fetchone()

    # Fallback to the master scenario when the attribute isn't overridden here.
    if not result and current_scenario_id != master_scenario_id:
        cursor.execute(query, (component_id, attribute_name, master_scenario_id))
        result = cursor.fetchone()

    if not result:
        raise AttributeNotFound(component_id, attribute_name)

    storage_type, static_value_json, timeseries_data, data_type, unit = result

    if storage_type == "static":
        if not static_value_json:
            raise ValidationError("Static attribute missing value")
        return AttributeValue.static(_parse_static_value(static_value_json, data_type))

    elif storage_type == "timeseries":
        if not timeseries_data:
            raise ValidationError("Timeseries attribute missing data")

        # Reuse the network_id fetched above (previously re-queried here).
        network_time_periods = None
        try:
            from pyconvexity.models.network import get_network_time_periods
            network_time_periods = get_network_time_periods(conn, network_id)
        except Exception as e:
            # Best effort: fall back to period_index-based timestamps below.
            logger.warning(f"Failed to load network time periods for timestamp computation: {e}")

        # Deserialize from Parquet with proper timestamp computation.
        timeseries_points = deserialize_timeseries_from_parquet(timeseries_data, network_time_periods)
        return AttributeValue.timeseries(timeseries_points)

    else:
        raise ValidationError(f"Unknown storage type: {storage_type}")


def _parse_static_value(static_value_json: str, data_type: str) -> StaticValue:
    """Parse a JSON-encoded static value and coerce it to the declared data_type.

    JSON booleans are explicitly rejected for 'float'/'int': bool is an int
    subclass in Python, so a bare isinstance check would silently coerce
    true -> 1.0 / 1 instead of failing validation.

    Raises:
        ValidationError: On a type mismatch or unknown data_type.
    """
    json_value = json.loads(static_value_json)

    if data_type == "float":
        if isinstance(json_value, (int, float)) and not isinstance(json_value, bool):
            return StaticValue(float(json_value))
        raise ValidationError("Expected float value")
    elif data_type == "int":
        if isinstance(json_value, (int, float)) and not isinstance(json_value, bool):
            return StaticValue(int(json_value))
        raise ValidationError("Expected integer value")
    elif data_type == "boolean":
        if isinstance(json_value, bool):
            return StaticValue(json_value)
        raise ValidationError("Expected boolean value")
    elif data_type == "string":
        if isinstance(json_value, str):
            return StaticValue(json_value)
        raise ValidationError("Expected string value")
    raise ValidationError(f"Unknown data type: {data_type}")
261
+
262
+
263
def delete_attribute(
    conn: sqlite3.Connection,
    component_id: int,
    attribute_name: str,
    scenario_id: Optional[int] = None
) -> None:
    """
    Delete an attribute from a specific scenario.

    Args:
        conn: Database connection
        component_id: Component ID
        attribute_name: Name of the attribute
        scenario_id: Scenario ID (uses master scenario if None)

    Raises:
        AttributeNotFound: If attribute doesn't exist
    """
    # None means "delete from the network's master scenario".
    target_scenario = resolve_scenario_id(conn, component_id, scenario_id)

    cursor = conn.cursor()
    cursor.execute(
        "DELETE FROM component_attributes WHERE component_id = ? AND attribute_name = ? AND scenario_id = ?",
        (component_id, attribute_name, target_scenario),
    )

    # rowcount of zero means nothing matched, i.e. the attribute was absent.
    if not cursor.rowcount:
        raise AttributeNotFound(component_id, attribute_name)
292
+
293
+
294
+ # Helper functions
295
+
296
def resolve_scenario_id(conn: sqlite3.Connection, component_id: int, scenario_id: Optional[int]) -> int:
    """Resolve scenario ID - if None, get master scenario ID.

    Raises:
        ComponentNotFound: If the component does not exist (only checked
            when the master scenario must be looked up).
    """
    # An explicit scenario wins; no database access needed.
    if scenario_id is not None:
        return scenario_id

    # Otherwise map component -> network -> master scenario.
    row = conn.execute(
        "SELECT network_id FROM components WHERE id = ?", (component_id,)
    ).fetchone()
    if row is None:
        raise ComponentNotFound(component_id)
    return get_master_scenario_id(conn, row[0])
310
+
311
+
312
def get_master_scenario_id(conn: sqlite3.Connection, network_id: int) -> int:
    """Return the master scenario ID for a network.

    Raises:
        ValidationError: If the network has no master scenario.
    """
    row = conn.execute(
        "SELECT id FROM scenarios WHERE network_id = ? AND is_master = TRUE",
        (network_id,),
    ).fetchone()
    if row is None:
        raise ValidationError(f"No master scenario found for network {network_id}")
    return row[0]
323
+
324
+
325
+ # Timeseries serialization functions
326
+
327
def serialize_timeseries_to_parquet(timeseries: List[TimeseriesPoint]) -> bytes:
    """Serialize timeseries to Parquet format - EXACT MATCH WITH RUST SCHEMA."""
    # The schema must exactly match the Rust reader: int32 period_index, float64 value.
    schema = pa.schema([
        ('period_index', pa.int32()),
        ('value', pa.float64())
    ])

    if timeseries:
        # Build columns with explicit Arrow types so period_index stays Int32.
        columns = [
            pa.array([point.period_index for point in timeseries], type=pa.int32()),
            pa.array([point.value for point in timeseries], type=pa.float64()),
        ]
        table = pa.table(columns, schema=schema)
    else:
        # Empty input still yields a valid Parquet file carrying the schema.
        table = pa.table([], schema=schema)

    # SNAPPY compression matches the Rust writer.
    out = BytesIO()
    pq.write_table(table, out, compression='snappy')
    return out.getvalue()
353
+
354
+
355
def deserialize_timeseries_from_parquet(data: bytes, network_time_periods: Optional[List[TimePeriod]] = None) -> List[TimeseriesPoint]:
    """Deserialize timeseries from Parquet format - EXACT MATCH WITH RUST.

    Args:
        data: Parquet bytes with columns 'period_index' (int32) and 'value' (float64).
        network_time_periods: Optional time periods used to map each
            period_index to a real timestamp.

    Returns:
        List of TimeseriesPoint. When a period_index has no matching time
        period, the period_index itself is used as the timestamp (legacy
        fallback, matching previous behavior for compatibility).
    """
    if not data:
        return []

    table = pq.read_table(BytesIO(data))

    # Read the Arrow columns directly instead of round-tripping through
    # pandas + iterrows(): faster, and avoids the pandas dependency here.
    period_indices = table.column('period_index').to_pylist()
    values = table.column('value').to_pylist()

    n_periods = len(network_time_periods) if network_time_periods else 0

    points = []
    for raw_index, raw_value in zip(period_indices, values):
        period_index = int(raw_index)

        # Compute timestamp from period_index using network time periods if available.
        if network_time_periods and 0 <= period_index < n_periods:
            timestamp = network_time_periods[period_index].timestamp
        else:
            # Fallback: use period_index as timestamp (compatibility behavior).
            timestamp = period_index

        points.append(TimeseriesPoint(
            timestamp=timestamp,
            value=float(raw_value),
            period_index=period_index
        ))

    return points