emhass 0.11.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
emhass/command_line.py CHANGED
@@ -1,21 +1,19 @@
  #!/usr/bin/env python3
- # -*- coding: utf-8 -*-

  import argparse
+ import asyncio
  import copy
- import json
  import logging
  import os
  import pathlib
  import pickle
- import re
- import time
- from datetime import datetime, timezone
- from distutils.util import strtobool
+ from dataclasses import dataclass
+ from datetime import UTC, datetime, timedelta
  from importlib.metadata import version
- from typing import Optional, Tuple

+ import aiofiles
  import numpy as np
+ import orjson
  import pandas as pd

  from emhass import utils
@@ -25,15 +23,598 @@ from emhass.machine_learning_regressor import MLRegressor
  from emhass.optimization import Optimization
  from emhass.retrieve_hass import RetrieveHass

+ default_csv_filename = "opt_res_latest.csv"
+ default_pkl_suffix = "_mlf.pkl"
+ default_metadata_json = "metadata.json"
+ test_df_literal = "test_df_final.pkl"

- def set_input_data_dict(
+
+ @dataclass
+ class SetupContext:
+     """Context object for optimization preparation helpers."""
+
+     retrieve_hass_conf: dict
+     optim_conf: dict
+     plant_conf: dict
+     emhass_conf: dict
+     params: dict
+     logger: logging.Logger
+     get_data_from_file: bool
+     rh: RetrieveHass
+     fcst: Forecast | None = None
+
+
+ @dataclass
+ class PublishContext:
+     """Context object for data publishing helpers."""
+
+     input_data_dict: dict
+     params: dict
+     idx: int
+     common_kwargs: dict
+     logger: logging.Logger
+
+     @property
+     def rh(self) -> RetrieveHass:
+         return self.input_data_dict["rh"]
+
+     @property
+     def opt(self) -> Optimization:
+         return self.input_data_dict["opt"]
+
+     @property
+     def fcst(self) -> Forecast:
+         return self.input_data_dict["fcst"]
+
+
+ async def _retrieve_from_file(
+     emhass_conf: dict,
+     test_df_literal: str,
+     rh: RetrieveHass,
+     retrieve_hass_conf: dict,
+     optim_conf: dict,
+ ) -> tuple[bool, object]:
+     """Helper to retrieve data from a pickle file and configure variables."""
+     async with aiofiles.open(emhass_conf["data_path"] / test_df_literal, "rb") as inp:
+         content = await inp.read()
+         rh.df_final, days_list, var_list, rh.ha_config = pickle.loads(content)
+     rh.var_list = var_list
+     # Assign variables based on set_type
+     retrieve_hass_conf["sensor_power_load_no_var_loads"] = str(var_list[0])
+     if optim_conf.get("set_use_pv", True):
+         retrieve_hass_conf["sensor_power_photovoltaics"] = str(var_list[1])
+         retrieve_hass_conf["sensor_linear_interp"] = [
+             retrieve_hass_conf["sensor_power_photovoltaics"],
+             retrieve_hass_conf["sensor_power_load_no_var_loads"],
+         ]
+         retrieve_hass_conf["sensor_replace_zero"] = [
+             retrieve_hass_conf["sensor_power_photovoltaics"],
+             var_list[2],
+         ]
+     else:
+         retrieve_hass_conf["sensor_linear_interp"] = [
+             retrieve_hass_conf["sensor_power_load_no_var_loads"]
+         ]
+         retrieve_hass_conf["sensor_replace_zero"] = []
+     return True, days_list
+
+
+ async def _retrieve_from_hass(
+     set_type: str,
+     retrieve_hass_conf: dict,
+     optim_conf: dict,
+     rh: RetrieveHass,
+     logger: logging.Logger | None,
+ ) -> tuple[bool, object]:
+     """Helper to retrieve live data from Home Assistant."""
+     # Determine days_list based on set_type
+     if set_type == "perfect-optim" or set_type == "adjust_pv":
+         days_list = utils.get_days_list(retrieve_hass_conf["historic_days_to_retrieve"])
+     elif set_type == "naive-mpc-optim":
+         days_list = utils.get_days_list(1)
+     else:
+         days_list = None  # Not needed for dayahead
+     var_list = [retrieve_hass_conf["sensor_power_load_no_var_loads"]]
+     if optim_conf.get("set_use_pv", True):
+         var_list.append(retrieve_hass_conf["sensor_power_photovoltaics"])
+     if optim_conf.get("set_use_adjusted_pv", True):
+         var_list.append(retrieve_hass_conf["sensor_power_photovoltaics_forecast"])
+     if logger:
+         logger.debug(f"Variable list for data retrieval: {var_list}")
+     success = await rh.get_data(
+         days_list, var_list, minimal_response=False, significant_changes_only=False
+     )
+     return success, days_list
+
+
+ async def retrieve_home_assistant_data(
+     set_type: str,
+     get_data_from_file: bool,
+     retrieve_hass_conf: dict,
+     optim_conf: dict,
+     rh: RetrieveHass,
+     emhass_conf: dict,
+     test_df_literal: str,
+     logger: logging.Logger | None = None,
+ ) -> tuple[bool, pd.DataFrame | None, list | None]:
+     """Retrieve data from Home Assistant or file and prepare it for optimization."""
+
+     if get_data_from_file:
+         success, days_list = await _retrieve_from_file(
+             emhass_conf, test_df_literal, rh, retrieve_hass_conf, optim_conf
+         )
+     else:
+         success, days_list = await _retrieve_from_hass(
+             set_type, retrieve_hass_conf, optim_conf, rh, logger
+         )
+     if not success:
+         return False, None, days_list
+     rh.prepare_data(
+         retrieve_hass_conf["sensor_power_load_no_var_loads"],
+         load_negative=retrieve_hass_conf["load_negative"],
+         set_zero_min=retrieve_hass_conf["set_zero_min"],
+         var_replace_zero=retrieve_hass_conf["sensor_replace_zero"],
+         var_interp=retrieve_hass_conf["sensor_linear_interp"],
+     )
+     return True, rh.df_final.copy(), days_list
+
+
+ def is_model_outdated(model_path: pathlib.Path, max_age_hours: int, logger: logging.Logger) -> bool:
+     """
+     Check if the saved model file is outdated based on its modification time.
+
+     :param model_path: Path to the saved model file.
+     :type model_path: pathlib.Path
+     :param max_age_hours: Maximum age in hours before model is considered outdated.
+     :type max_age_hours: int
+     :param logger: Logger object for logging information.
+     :type logger: logging.Logger
+     :return: True if model is outdated or doesn't exist, False otherwise.
+     :rtype: bool
+     """
+     if not model_path.exists():
+         logger.info("Adjusted PV model file does not exist, will train new model")
+         return True
+
+     if max_age_hours <= 0:
+         logger.info("adjusted_pv_model_max_age is set to 0, forcing model re-fit")
+         return True
+
+     model_mtime = datetime.fromtimestamp(model_path.stat().st_mtime)
+     model_age = datetime.now() - model_mtime
+     max_age = timedelta(hours=max_age_hours)
+
+     if model_age > max_age:
+         logger.info(
+             f"Adjusted PV model is outdated (age: {model_age.total_seconds() / 3600:.1f}h, "
+             f"max: {max_age_hours}h), will train new model"
+         )
+         return True
+     else:
+         logger.info(
+             f"Using existing adjusted PV model (age: {model_age.total_seconds() / 3600:.1f}h, "
+             f"max: {max_age_hours}h)"
+         )
+         return False
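A quick sanity check of the staleness rule implemented above (illustrative sketch; the path and the 24 h threshold are example values, not emhass defaults):

import logging
import pathlib

logger = logging.getLogger("emhass_example")
model_path = pathlib.Path("/tmp/adjust_pv_regressor.pkl")  # hypothetical location
# True when the pickle is missing, when max_age_hours <= 0, or when the
# file's mtime is older than the allowed age.
if is_model_outdated(model_path, max_age_hours=24, logger=logger):
    print("model would be re-fitted")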
+
+
+ async def _retrieve_and_fit_pv_model(
+     fcst: Forecast,
+     get_data_from_file: bool,
+     retrieve_hass_conf: dict,
+     optim_conf: dict,
+     rh: RetrieveHass,
+     emhass_conf: dict,
+     test_df_literal: pd.DataFrame,
+ ) -> bool:
+     """
+     Helper function to retrieve data and fit the PV adjustment model.
+
+     :param fcst: Forecast object used for PV forecast adjustment.
+     :type fcst: Forecast
+     :param get_data_from_file: Whether to retrieve data from a file instead of Home Assistant.
+     :type get_data_from_file: bool
+     :param retrieve_hass_conf: Configuration dictionary for retrieving data from Home Assistant.
+     :type retrieve_hass_conf: dict
+     :param optim_conf: Configuration dictionary for optimization settings.
+     :type optim_conf: dict
+     :param rh: RetrieveHass object for interacting with Home Assistant.
+     :type rh: RetrieveHass
+     :param emhass_conf: Configuration dictionary for emhass paths and settings.
+     :type emhass_conf: dict
+     :param test_df_literal: DataFrame containing test data for debugging purposes.
+     :type test_df_literal: pd.DataFrame
+     :return: True if successful, False otherwise.
+     :rtype: bool
+     """
+     # Retrieve data from Home Assistant
+     success, df_input_data, _ = await retrieve_home_assistant_data(
+         "adjust_pv",
+         get_data_from_file,
+         retrieve_hass_conf,
+         optim_conf,
+         rh,
+         emhass_conf,
+         test_df_literal,
+     )
+     if not success:
+         return False
+     # Call data preparation method
+     fcst.adjust_pv_forecast_data_prep(df_input_data)
+     # Call the fit method
+     await fcst.adjust_pv_forecast_fit(
+         n_splits=5,
+         regression_model=optim_conf["adjusted_pv_regression_model"],
+     )
+     return True
+
+
+ async def adjust_pv_forecast(
+     logger: logging.Logger,
+     fcst: Forecast,
+     p_pv_forecast: pd.Series,
+     get_data_from_file: bool,
+     retrieve_hass_conf: dict,
+     optim_conf: dict,
+     rh: RetrieveHass,
+     emhass_conf: dict,
+     test_df_literal: pd.DataFrame,
+ ) -> pd.Series:
+     """
+     Adjust the photovoltaic (PV) forecast using historical data and a regression model.
+
+     This method retrieves historical data, prepares it for model fitting, trains a regression
+     model, and adjusts the provided PV forecast based on the trained model.
+
+     :param logger: Logger object for logging information and errors.
+     :type logger: logging.Logger
+     :param fcst: Forecast object used for PV forecast adjustment.
+     :type fcst: Forecast
+     :param p_pv_forecast: The initial PV forecast to be adjusted.
+     :type p_pv_forecast: pd.Series
+     :param get_data_from_file: Whether to retrieve data from a file instead of Home Assistant.
+     :type get_data_from_file: bool
+     :param retrieve_hass_conf: Configuration dictionary for retrieving data from Home Assistant.
+     :type retrieve_hass_conf: dict
+     :param optim_conf: Configuration dictionary for optimization settings.
+     :type optim_conf: dict
+     :param rh: RetrieveHass object for interacting with Home Assistant.
+     :type rh: RetrieveHass
+     :param emhass_conf: Configuration dictionary for emhass paths and settings.
+     :type emhass_conf: dict
+     :param test_df_literal: DataFrame containing test data for debugging purposes.
+     :type test_df_literal: pd.DataFrame
+     :return: The adjusted PV forecast as a pandas Series.
+     :rtype: pd.Series
+     """
+     # Normalize data_path to Path object for safety (handles both str and Path types)
+     data_path = pathlib.Path(emhass_conf["data_path"])
+     model_filename = "adjust_pv_regressor.pkl"
+     model_path = data_path / model_filename
+     max_age_hours = optim_conf.get("adjusted_pv_model_max_age", 24)
+     # Check if model needs to be re-fitted
+     if is_model_outdated(model_path, max_age_hours, logger):
+         logger.info("Adjusting PV forecast, retrieving history data for model fit")
+         success = await _retrieve_and_fit_pv_model(
+             fcst,
+             get_data_from_file,
+             retrieve_hass_conf,
+             optim_conf,
+             rh,
+             emhass_conf,
+             test_df_literal,
+         )
+         if not success:
+             return False
+     else:
+         # Load existing model
+         logger.info("Loading existing adjusted PV model from file")
+         try:
+             async with aiofiles.open(model_path, "rb") as inp:
+                 content = await inp.read()
+                 fcst.model_adjust_pv = pickle.loads(content)
+         except (pickle.UnpicklingError, EOFError, AttributeError, ImportError) as e:
+             logger.error(f"Failed to load existing adjusted PV model: {type(e).__name__}: {str(e)}")
+             logger.warning(
+                 "Model file may be corrupted or incompatible. Falling back to re-fitting the model."
+             )
+             # Use helper function to retrieve data and re-fit model
+             success = await _retrieve_and_fit_pv_model(
+                 fcst,
+                 get_data_from_file,
+                 retrieve_hass_conf,
+                 optim_conf,
+                 rh,
+                 emhass_conf,
+                 test_df_literal,
+             )
+             if not success:
+                 logger.error("Failed to retrieve data for model re-fit after load error")
+                 return False
+             logger.info("Successfully re-fitted model after load failure")
+         except Exception as e:
+             logger.error(
+                 f"Unexpected error loading adjusted PV model: {type(e).__name__}: {str(e)}"
+             )
+             logger.error("Cannot recover from this error")
+             return False
+     # Call the predict method
+     p_pv_forecast = p_pv_forecast.rename("forecast").to_frame()
+     p_pv_forecast = fcst.adjust_pv_forecast_predict(forecasted_pv=p_pv_forecast)
+     # Update the PV forecast
+     return p_pv_forecast["adjusted_forecast"].rename(None)
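The predict step above wraps the Series into a one-column frame named "forecast" and unwraps the "adjusted_forecast" column back into an unnamed Series. The pandas round-trip in isolation (toy data; the multiplier stands in for the fitted regressor):

import pandas as pd

idx = pd.date_range("2024-01-01", periods=4, freq="30min")
p_pv = pd.Series([0.0, 120.5, 340.2, 80.0], index=idx)
frame = p_pv.rename("forecast").to_frame()             # column layout the model expects
frame["adjusted_forecast"] = frame["forecast"] * 1.05  # stand-in for the regressor output
adjusted = frame["adjusted_forecast"].rename(None)     # back to an unnamed Series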
+
+
+ async def _prepare_perfect_optim(ctx: SetupContext):
+     """Helper to prepare data for perfect optimization."""
+     success, df_input_data, days_list = await retrieve_home_assistant_data(
+         "perfect-optim",
+         ctx.get_data_from_file,
+         ctx.retrieve_hass_conf,
+         ctx.optim_conf,
+         ctx.rh,
+         ctx.emhass_conf,
+         test_df_literal,
+         ctx.logger,
+     )
+     if not success:
+         return None
+     return {
+         "df_input_data": df_input_data,
+         "days_list": days_list,
+     }
+
+
+ async def _get_dayahead_pv_forecast(ctx: SetupContext):
+     """Helper to retrieve and optionally adjust PV forecast."""
+     # Check if we should calculate PV forecast
+     if not (
+         ctx.optim_conf["set_use_pv"]
+         or ctx.optim_conf.get("weather_forecast_method", None) == "list"
+     ):
+         return pd.Series(0, index=ctx.fcst.forecast_dates), None
+     # Get weather forecast
+     df_weather = await ctx.fcst.get_weather_forecast(
+         method=ctx.optim_conf["weather_forecast_method"]
+     )
+     if isinstance(df_weather, bool) and not df_weather:
+         return None, None
+     p_pv_forecast = ctx.fcst.get_power_from_weather(df_weather)
+     # Adjust PV forecast if needed
+     if ctx.optim_conf["set_use_adjusted_pv"]:
+         p_pv_forecast = await adjust_pv_forecast(
+             ctx.logger,
+             ctx.fcst,
+             p_pv_forecast,
+             ctx.get_data_from_file,
+             ctx.retrieve_hass_conf,
+             ctx.optim_conf,
+             ctx.rh,
+             ctx.emhass_conf,
+             test_df_literal,
+         )
+     return p_pv_forecast, df_weather
+
+
+ def _apply_df_freq_horizon(
+     df: pd.DataFrame, retrieve_hass_conf: dict, prediction_horizon: int | None
+ ) -> pd.DataFrame:
+     """Helper to apply frequency adjustment and prediction horizon slicing."""
+     # Handle Frequency
+     if retrieve_hass_conf.get("optimization_time_step"):
+         step = retrieve_hass_conf["optimization_time_step"]
+         if not isinstance(step, pd._libs.tslibs.timedeltas.Timedelta):
+             step = pd.to_timedelta(step, "minute")
+         df = df.asfreq(step)
+     else:
+         df = utils.set_df_index_freq(df)
+     # Handle Prediction Horizon
+     if prediction_horizon:
+         # Slice the dataframe up to the horizon
+         df = copy.deepcopy(df)[df.index[0] : df.index[prediction_horizon - 1]]
+     return df
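The horizon slice above is label-based, so `df.index[prediction_horizon - 1]` yields an inclusive window of exactly `prediction_horizon` rows. A self-contained check (pandas only, illustrative values):

import pandas as pd

idx = pd.date_range("2024-01-01", periods=48, freq="30min")
df = pd.DataFrame({"p_load_forecast": range(48)}, index=idx)
sliced = df[df.index[0] : df.index[10 - 1]]  # prediction_horizon = 10
assert len(sliced) == 10  # label slicing on a DatetimeIndex is inclusive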
+
+
+ async def _prepare_dayahead_optim(ctx: SetupContext):
+     """Helper to prepare data for day-ahead optimization."""
+     # Get PV Forecast
+     p_pv_forecast, df_weather = await _get_dayahead_pv_forecast(ctx)
+     if p_pv_forecast is None:
+         return None
+     # Get Load Forecast
+     p_load_forecast = await ctx.fcst.get_load_forecast(
+         days_min_load_forecast=ctx.optim_conf["delta_forecast_daily"].days,
+         method=ctx.optim_conf["load_forecast_method"],
+     )
+     if isinstance(p_load_forecast, bool) and not p_load_forecast:
+         ctx.logger.error("Unable to get load forecast.")
+         return None
+     # Build Input DataFrame
+     df_input_data_dayahead = pd.DataFrame(
+         np.transpose(np.vstack([p_pv_forecast.values, p_load_forecast.values])),
+         index=p_pv_forecast.index,
+         columns=["p_pv_forecast", "p_load_forecast"],
+     )
+     # Apply Frequency and Prediction Horizon
+     # Use explicitly passed horizon, avoiding JSON re-parsing
+     prediction_horizon = ctx.params["passed_data"].get("prediction_horizon")
+     df_input_data_dayahead = _apply_df_freq_horizon(
+         df_input_data_dayahead, ctx.retrieve_hass_conf, prediction_horizon
+     )
+     return {
+         "df_input_data_dayahead": df_input_data_dayahead,
+         "df_weather": df_weather,
+         "p_pv_forecast": p_pv_forecast,
+         "p_load_forecast": p_load_forecast,
+     }
+
+
+ async def _get_naive_mpc_history(ctx: SetupContext):
+     """Helper to retrieve historical data for Naive MPC."""
+     # Check if we need to skip historical data retrieval
+     is_list_forecast = ctx.optim_conf.get("load_forecast_method") == "list"
+     is_list_weather = ctx.optim_conf.get("weather_forecast_method") == "list"
+     no_pv = not ctx.optim_conf["set_use_pv"]
+
+     if (is_list_forecast and is_list_weather) or (is_list_forecast and no_pv):
+         return True, None, None, False  # success, df, days_list, set_mix_forecast
+     # Retrieve data from Home Assistant
+     success, df_input_data, days_list = await retrieve_home_assistant_data(
+         "naive-mpc-optim",
+         ctx.get_data_from_file,
+         ctx.retrieve_hass_conf,
+         ctx.optim_conf,
+         ctx.rh,
+         ctx.emhass_conf,
+         test_df_literal,
+         ctx.logger,
+     )
+     return success, df_input_data, days_list, True
+
+
+ async def _get_naive_mpc_pv_forecast(ctx: SetupContext, set_mix_forecast, df_input_data):
+     """Helper to generate PV forecast for Naive MPC."""
+     # If PV is disabled and no weather list, return zero series
+     if not (
+         ctx.optim_conf["set_use_pv"] or ctx.optim_conf.get("weather_forecast_method") == "list"
+     ):
+         return pd.Series(0, index=ctx.fcst.forecast_dates), None
+     # Get weather forecast
+     df_weather = await ctx.fcst.get_weather_forecast(
+         method=ctx.optim_conf["weather_forecast_method"]
+     )
+     if isinstance(df_weather, bool) and not df_weather:
+         return None, None
+     # Calculate PV power
+     p_pv_forecast = ctx.fcst.get_power_from_weather(
+         df_weather, set_mix_forecast=set_mix_forecast, df_now=df_input_data
+     )
+     # Adjust PV forecast if needed
+     if ctx.optim_conf["set_use_adjusted_pv"]:
+         p_pv_forecast = await adjust_pv_forecast(
+             ctx.logger,
+             ctx.fcst,
+             p_pv_forecast,
+             ctx.get_data_from_file,
+             ctx.retrieve_hass_conf,
+             ctx.optim_conf,
+             ctx.rh,
+             ctx.emhass_conf,
+             test_df_literal,
+         )
+     return p_pv_forecast, df_weather
+
+
+ async def _prepare_naive_mpc_optim(ctx: SetupContext):
+     """Helper to prepare data for Naive MPC optimization."""
+     # Retrieve Historical Data
+     success, df_input_data, days_list, set_mix_forecast = await _get_naive_mpc_history(ctx)
+     if not success:
+         return None
+     # Get PV Forecast
+     p_pv_forecast, df_weather = await _get_naive_mpc_pv_forecast(
+         ctx, set_mix_forecast, df_input_data
+     )
+     if p_pv_forecast is None:
+         return None
+     # Get Load Forecast
+     p_load_forecast = await ctx.fcst.get_load_forecast(
+         days_min_load_forecast=ctx.optim_conf["delta_forecast_daily"].days,
+         method=ctx.optim_conf["load_forecast_method"],
+         set_mix_forecast=set_mix_forecast,
+         df_now=df_input_data,
+     )
+     if isinstance(p_load_forecast, bool) and not p_load_forecast:
+         return None
+     # Build and Format Input DataFrame
+     df_input_data_dayahead = pd.concat([p_pv_forecast, p_load_forecast], axis=1)
+     df_input_data_dayahead.columns = ["p_pv_forecast", "p_load_forecast"]
+     # Reuse freq/horizon helper
+     prediction_horizon = ctx.params["passed_data"].get("prediction_horizon")
+     df_input_data_dayahead = _apply_df_freq_horizon(
+         df_input_data_dayahead, ctx.retrieve_hass_conf, prediction_horizon
+     )
+     return {
+         "df_input_data": df_input_data,
+         "days_list": days_list,
+         "df_input_data_dayahead": df_input_data_dayahead,
+         "df_weather": df_weather,
+         "p_pv_forecast": p_pv_forecast,
+         "p_load_forecast": p_load_forecast,
+     }
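The history-skipping rule in `_get_naive_mpc_history` reduces to a single predicate: skip retrieval only when the load forecast comes as a runtime list and either the weather forecast is also a list or PV is disabled. A standalone restatement of that predicate with its truth table (illustrative, not part of the package):

def skip_history(is_list_forecast: bool, is_list_weather: bool, use_pv: bool) -> bool:
    # Mirrors the condition in _get_naive_mpc_history
    return (is_list_forecast and is_list_weather) or (is_list_forecast and not use_pv)

assert skip_history(True, True, True) is True     # both forecasts passed as lists
assert skip_history(True, False, False) is True   # list load forecast, PV disabled
assert skip_history(True, False, True) is False   # PV still needs history for mixing
assert skip_history(False, True, True) is False   # load forecast needs history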
+
+
+ async def _prepare_ml_fit_predict(ctx: SetupContext):
+     """Helper to prepare data for ML fit/predict/tune."""
+     days_to_retrieve = ctx.params["passed_data"]["historic_days_to_retrieve"]
+     model_type = ctx.params["passed_data"]["model_type"]
+     var_model = ctx.params["passed_data"]["var_model"]
+     if ctx.get_data_from_file:
+         filename = model_type + ".pkl"
+         filename_path = ctx.emhass_conf["data_path"] / filename
+         async with aiofiles.open(filename_path, "rb") as inp:
+             content = await inp.read()
+             df_input_data, _, _, _ = pickle.loads(content)
+         df_input_data = df_input_data[df_input_data.index[-1] - pd.offsets.Day(days_to_retrieve) :]
+         return {"df_input_data": df_input_data}
+     else:
+         days_list = utils.get_days_list(days_to_retrieve)
+         var_list = [var_model]
+         if not await ctx.rh.get_data(days_list, var_list):
+             return None
+         ctx.rh.prepare_data(
+             var_model,
+             load_negative=ctx.retrieve_hass_conf.get("load_negative", False),
+             set_zero_min=ctx.retrieve_hass_conf.get("set_zero_min", True),
+             var_replace_zero=ctx.retrieve_hass_conf.get("sensor_replace_zero", []),
+             var_interp=ctx.retrieve_hass_conf.get("sensor_linear_interp", []),
+             skip_renaming=True,
+         )
+         return {"df_input_data": ctx.rh.df_final.copy()}
+
+
+ def _prepare_regressor_fit(ctx: SetupContext):
+     """Helper to prepare data for Regressor fit/predict."""
+     csv_file = ctx.params["passed_data"].get("csv_file", None)
+     if not csv_file:
+         ctx.logger.error("csv_file is required for regressor actions but was not provided.")
+         return None
+     if ctx.get_data_from_file:
+         base_path = ctx.emhass_conf["data_path"]
+         filename_path = pathlib.Path(base_path) / csv_file
+     else:
+         filename_path = ctx.emhass_conf["data_path"] / csv_file
+     if filename_path.is_file():
+         df_input_data = pd.read_csv(filename_path, parse_dates=True)
+     else:
+         ctx.logger.error(
+             f"The CSV file {csv_file} was not found in path: {ctx.emhass_conf['data_path']}"
+         )
+         return None
+     # Validate columns
+     required_columns = []
+     if "features" in ctx.params["passed_data"]:
+         required_columns.extend(ctx.params["passed_data"]["features"])
+     if "target" in ctx.params["passed_data"]:
+         required_columns.append(ctx.params["passed_data"]["target"])
+     if "timestamp" in ctx.params["passed_data"]:
+         required_columns.append(ctx.params["passed_data"]["timestamp"])
+     if not set(required_columns).issubset(df_input_data.columns):
+         ctx.logger.error(
+             f"The csv file does not contain the required columns: {', '.join(required_columns)}"
+         )
+         return None
+     return {"df_input_data": df_input_data}
+
+
+ async def set_input_data_dict(
      emhass_conf: dict,
      costfun: str,
      params: str,
      runtimeparams: str,
      set_type: str,
      logger: logging.Logger,
-     get_data_from_file: Optional[bool] = False,
+     get_data_from_file: bool | None = False,
  ) -> dict:
      """
      Set up some of the data needed for the different actions.
@@ -57,21 +638,21 @@ def set_input_data_dict(

      """
      logger.info("Setting up needed data")
-
-     # check if passed params is a dict
-     if (params != None) and (params != "null"):
-         if type(params) is str:
-             params = json.loads(params)
+     # Parse Parameters
+     if (params is not None) and (params != "null"):
+         if isinstance(params, str):
+             params = dict(orjson.loads(params))
      else:
          params = {}
-
-     # Parsing yaml
      retrieve_hass_conf, optim_conf, plant_conf = utils.get_yaml_parse(params, logger)
      if type(retrieve_hass_conf) is bool:
          return False
-
-     # Treat runtimeparams
-     params, retrieve_hass_conf, optim_conf, plant_conf = utils.treat_runtimeparams(
+     (
+         params,
+         retrieve_hass_conf,
+         optim_conf,
+         plant_conf,
+     ) = await utils.treat_runtimeparams(
          runtimeparams,
          params,
          retrieve_hass_conf,
@@ -81,7 +662,9 @@ def set_input_data_dict(
          logger,
          emhass_conf,
      )
-     # Define main objects
+     if isinstance(params, str):
+         params = dict(orjson.loads(params))
+     # Initialize Core Objects
      rh = RetrieveHass(
          retrieve_hass_conf["hass_url"],
          retrieve_hass_conf["long_lived_token"],
@@ -92,6 +675,21 @@ def set_input_data_dict(
          logger,
          get_data_from_file=get_data_from_file,
      )
+     # Retrieve HA config
+     if get_data_from_file:
+         async with aiofiles.open(emhass_conf["data_path"] / test_df_literal, "rb") as inp:
+             content = await inp.read()
+             _, _, _, rh.ha_config = pickle.loads(content)
+     elif not await rh.get_ha_config():
+         return False
+     if isinstance(params, dict):
+         params_str = orjson.dumps(params).decode("utf-8")
+         params = utils.update_params_with_ha_config(params_str, rh.ha_config)
+     else:
+         params = utils.update_params_with_ha_config(params, rh.ha_config)
+     if isinstance(params, str):
+         params = dict(orjson.loads(params))
+     costfun = optim_conf.get("costfun", costfun)
      fcst = Forecast(
          retrieve_hass_conf,
          optim_conf,
@@ -111,281 +709,62 @@ def set_input_data_dict(
          emhass_conf,
          logger,
      )
-     # Perform setup based on type of action
+     # Create SetupContext
+     ctx = SetupContext(
+         retrieve_hass_conf=retrieve_hass_conf,
+         optim_conf=optim_conf,
+         plant_conf=plant_conf,
+         emhass_conf=emhass_conf,
+         params=params,
+         logger=logger,
+         get_data_from_file=get_data_from_file,
+         rh=rh,
+         fcst=fcst,
+     )
+     # Initialize Default Return Data
+     data_results = {
+         "df_input_data": None,
+         "df_input_data_dayahead": None,
+         "df_weather": None,
+         "p_pv_forecast": None,
+         "p_load_forecast": None,
+         "days_list": None,
+     }
+     # Delegate to Helpers based on set_type
+     result = None
      if set_type == "perfect-optim":
-         # Retrieve data from hass
-         if get_data_from_file:
-             with open(emhass_conf["data_path"] / "test_df_final.pkl", "rb") as inp:
-                 rh.df_final, days_list, var_list = pickle.load(inp)
-             retrieve_hass_conf["sensor_power_load_no_var_loads"] = str(var_list[0])
-             retrieve_hass_conf["sensor_power_photovoltaics"] = str(var_list[1])
-             retrieve_hass_conf["sensor_linear_interp"] = [
-                 retrieve_hass_conf["sensor_power_photovoltaics"],
-                 retrieve_hass_conf["sensor_power_load_no_var_loads"],
-             ]
-             retrieve_hass_conf["sensor_replace_zero"] = [
-                 retrieve_hass_conf["sensor_power_photovoltaics"]
-             ]
-         else:
-             days_list = utils.get_days_list(
-                 retrieve_hass_conf["historic_days_to_retrieve"]
-             )
-             var_list = [
-                 retrieve_hass_conf["sensor_power_load_no_var_loads"],
-                 retrieve_hass_conf["sensor_power_photovoltaics"],
-             ]
-             if not rh.get_data(
-                 days_list,
-                 var_list,
-                 minimal_response=False,
-                 significant_changes_only=False,
-             ):
-                 return False
-         if not rh.prepare_data(
-             retrieve_hass_conf["sensor_power_load_no_var_loads"],
-             load_negative=retrieve_hass_conf["load_negative"],
-             set_zero_min=retrieve_hass_conf["set_zero_min"],
-             var_replace_zero=retrieve_hass_conf["sensor_replace_zero"],
-             var_interp=retrieve_hass_conf["sensor_linear_interp"],
-         ):
-             return False
-         df_input_data = rh.df_final.copy()
-         # What we don't need for this type of action
-         P_PV_forecast, P_load_forecast, df_input_data_dayahead = None, None, None
+         result = await _prepare_perfect_optim(ctx)
      elif set_type == "dayahead-optim":
-         # Get PV and load forecasts
-         df_weather = fcst.get_weather_forecast(
-             method=optim_conf["weather_forecast_method"]
-         )
-         if isinstance(df_weather, bool) and not df_weather:
-             return False
-         P_PV_forecast = fcst.get_power_from_weather(df_weather)
-         P_load_forecast = fcst.get_load_forecast(
-             method=optim_conf["load_forecast_method"]
-         )
-         if isinstance(P_load_forecast, bool) and not P_load_forecast:
-             logger.error(
-                 "Unable to get sensor power photovoltaics, or sensor power load no var loads. Check HA sensors and their daily data"
-             )
-             return False
-         df_input_data_dayahead = pd.DataFrame(
-             np.transpose(np.vstack([P_PV_forecast.values, P_load_forecast.values])),
-             index=P_PV_forecast.index,
-             columns=["P_PV_forecast", "P_load_forecast"],
-         )
-         if (
-             "optimization_time_step" in retrieve_hass_conf
-             and retrieve_hass_conf["optimization_time_step"]
-         ):
-             if not isinstance(
-                 retrieve_hass_conf["optimization_time_step"],
-                 pd._libs.tslibs.timedeltas.Timedelta,
-             ):
-                 optimization_time_step = pd.to_timedelta(
-                     retrieve_hass_conf["optimization_time_step"], "minute"
-                 )
-             else:
-                 optimization_time_step = retrieve_hass_conf["optimization_time_step"]
-             df_input_data_dayahead = df_input_data_dayahead.asfreq(
-                 optimization_time_step
-             )
-         else:
-             df_input_data_dayahead = utils.set_df_index_freq(df_input_data_dayahead)
-         params = json.loads(params)
-         if (
-             "prediction_horizon" in params["passed_data"]
-             and params["passed_data"]["prediction_horizon"] is not None
-         ):
-             prediction_horizon = params["passed_data"]["prediction_horizon"]
-             df_input_data_dayahead = copy.deepcopy(df_input_data_dayahead)[
-                 df_input_data_dayahead.index[0] : df_input_data_dayahead.index[
-                     prediction_horizon - 1
-                 ]
-             ]
-         # What we don't need for this type of action
-         df_input_data, days_list = None, None
+         result = await _prepare_dayahead_optim(ctx)
      elif set_type == "naive-mpc-optim":
-         # Retrieve data from hass
-         if get_data_from_file:
-             with open(emhass_conf["data_path"] / "test_df_final.pkl", "rb") as inp:
-                 rh.df_final, days_list, var_list = pickle.load(inp)
-             retrieve_hass_conf["sensor_power_load_no_var_loads"] = str(var_list[0])
-             retrieve_hass_conf["sensor_power_photovoltaics"] = str(var_list[1])
-             retrieve_hass_conf["sensor_linear_interp"] = [
-                 retrieve_hass_conf["sensor_power_photovoltaics"],
-                 retrieve_hass_conf["sensor_power_load_no_var_loads"],
-             ]
-             retrieve_hass_conf["sensor_replace_zero"] = [
-                 retrieve_hass_conf["sensor_power_photovoltaics"]
-             ]
-         else:
-             days_list = utils.get_days_list(1)
-             var_list = [
-                 retrieve_hass_conf["sensor_power_load_no_var_loads"],
-                 retrieve_hass_conf["sensor_power_photovoltaics"],
-             ]
-             if not rh.get_data(
-                 days_list,
-                 var_list,
-                 minimal_response=False,
-                 significant_changes_only=False,
-             ):
-                 return False
-         if not rh.prepare_data(
-             retrieve_hass_conf["sensor_power_load_no_var_loads"],
-             load_negative=retrieve_hass_conf["load_negative"],
-             set_zero_min=retrieve_hass_conf["set_zero_min"],
-             var_replace_zero=retrieve_hass_conf["sensor_replace_zero"],
-             var_interp=retrieve_hass_conf["sensor_linear_interp"],
-         ):
-             return False
-         df_input_data = rh.df_final.copy()
-         # Get PV and load forecasts
-         df_weather = fcst.get_weather_forecast(
-             method=optim_conf["weather_forecast_method"]
-         )
-         if isinstance(df_weather, bool) and not df_weather:
-             return False
-         P_PV_forecast = fcst.get_power_from_weather(
-             df_weather, set_mix_forecast=True, df_now=df_input_data
-         )
-         P_load_forecast = fcst.get_load_forecast(
-             method=optim_conf["load_forecast_method"],
-             set_mix_forecast=True,
-             df_now=df_input_data,
-         )
-         if isinstance(P_load_forecast, bool) and not P_load_forecast:
-             logger.error(
-                 "Unable to get sensor power photovoltaics, or sensor power load no var loads. Check HA sensors and their daily data"
-             )
-             return False
-         df_input_data_dayahead = pd.concat([P_PV_forecast, P_load_forecast], axis=1)
-         if (
-             "optimization_time_step" in retrieve_hass_conf
-             and retrieve_hass_conf["optimization_time_step"]
-         ):
-             if not isinstance(
-                 retrieve_hass_conf["optimization_time_step"],
-                 pd._libs.tslibs.timedeltas.Timedelta,
-             ):
-                 optimization_time_step = pd.to_timedelta(
-                     retrieve_hass_conf["optimization_time_step"], "minute"
-                 )
-             else:
-                 optimization_time_step = retrieve_hass_conf["optimization_time_step"]
-             df_input_data_dayahead = df_input_data_dayahead.asfreq(
-                 optimization_time_step
-             )
-         else:
-             df_input_data_dayahead = utils.set_df_index_freq(df_input_data_dayahead)
-         df_input_data_dayahead.columns = ["P_PV_forecast", "P_load_forecast"]
-         params = json.loads(params)
-         if (
-             "prediction_horizon" in params["passed_data"]
-             and params["passed_data"]["prediction_horizon"] is not None
-         ):
-             prediction_horizon = params["passed_data"]["prediction_horizon"]
-             df_input_data_dayahead = copy.deepcopy(df_input_data_dayahead)[
-                 df_input_data_dayahead.index[0] : df_input_data_dayahead.index[
-                     prediction_horizon - 1
-                 ]
-             ]
-     elif (
-         set_type == "forecast-model-fit"
-         or set_type == "forecast-model-predict"
-         or set_type == "forecast-model-tune"
-     ):
-         df_input_data_dayahead = None
-         P_PV_forecast, P_load_forecast = None, None
-         params = json.loads(params)
-         # Retrieve data from hass
-         days_to_retrieve = params["passed_data"]["historic_days_to_retrieve"]
-         model_type = params["passed_data"]["model_type"]
-         var_model = params["passed_data"]["var_model"]
-         if get_data_from_file:
-             days_list = None
-             filename = "data_train_" + model_type + ".pkl"
-             filename_path = emhass_conf["data_path"] / filename
-             with open(filename_path, "rb") as inp:
-                 df_input_data, _ = pickle.load(inp)
-             df_input_data = df_input_data[
-                 df_input_data.index[-1] - pd.offsets.Day(days_to_retrieve) :
-             ]
-         else:
-             days_list = utils.get_days_list(days_to_retrieve)
-             var_list = [var_model]
-             if not rh.get_data(days_list, var_list):
-                 return False
-             df_input_data = rh.df_final.copy()
-     elif set_type == "regressor-model-fit" or set_type == "regressor-model-predict":
-         df_input_data, df_input_data_dayahead = None, None
-         P_PV_forecast, P_load_forecast = None, None
-         params = json.loads(params)
-         days_list = None
-         csv_file = params["passed_data"].get("csv_file", None)
-         if "features" in params["passed_data"]:
-             features = params["passed_data"]["features"]
-         if "target" in params["passed_data"]:
-             target = params["passed_data"]["target"]
-         if "timestamp" in params["passed_data"]:
-             timestamp = params["passed_data"]["timestamp"]
-         if csv_file:
-             if get_data_from_file:
-                 base_path = emhass_conf["data_path"] # + "/data"
-                 filename_path = pathlib.Path(base_path) / csv_file
-             else:
-                 filename_path = emhass_conf["data_path"] / csv_file
-             if filename_path.is_file():
-                 df_input_data = pd.read_csv(filename_path, parse_dates=True)
-             else:
-                 logger.error(
-                     "The CSV file "
-                     + csv_file
-                     + " was not found in path: "
-                     + str(emhass_conf["data_path"])
-                 )
-                 return False
-             # raise ValueError("The CSV file " + csv_file + " was not found.")
-             required_columns = []
-             required_columns.extend(features)
-             required_columns.append(target)
-             if timestamp is not None:
-                 required_columns.append(timestamp)
-             if not set(required_columns).issubset(df_input_data.columns):
-                 logger.error("The cvs file does not contain the required columns.")
-                 msg = f"CSV file should contain the following columns: {', '.join(required_columns)}"
-                 logger.error(msg)
-                 return False
-     elif set_type == "publish-data":
-         df_input_data, df_input_data_dayahead = None, None
-         P_PV_forecast, P_load_forecast = None, None
-         days_list = None
+         result = await _prepare_naive_mpc_optim(ctx)
+     elif set_type in ["forecast-model-fit", "forecast-model-predict", "forecast-model-tune"]:
+         result = await _prepare_ml_fit_predict(ctx)
+     elif set_type in ["regressor-model-fit", "regressor-model-predict"]:
+         result = _prepare_regressor_fit(ctx)
+     elif set_type == "publish-data" or set_type == "export-influxdb-to-csv":
+         result = {}
      else:
-         logger.error(
-             "The passed action argument and hence the set_type parameter for setup is not valid",
-         )
-         df_input_data, df_input_data_dayahead = None, None
-         P_PV_forecast, P_load_forecast = None, None
-         days_list = None
-     # The input data dictionary to return
+         logger.error(f"The passed action set_type parameter '{set_type}' is not valid")
+         result = {}
+     if result is None:
+         return False
+     data_results.update(result)
+     # Build Final Dictionary
      input_data_dict = {
          "emhass_conf": emhass_conf,
          "retrieve_hass_conf": retrieve_hass_conf,
          "rh": rh,
          "opt": opt,
          "fcst": fcst,
-         "df_input_data": df_input_data,
-         "df_input_data_dayahead": df_input_data_dayahead,
-         "P_PV_forecast": P_PV_forecast,
-         "P_load_forecast": P_load_forecast,
          "costfun": costfun,
          "params": params,
-         "days_list": days_list,
+         **data_results,
      }
      return input_data_dict
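Since `set_input_data_dict` is now a coroutine, callers have to drive it from an event loop. A hedged usage sketch (the configuration objects are placeholders assumed to be built elsewhere; the web server wires this up differently in practice):

import asyncio

async def run_setup(emhass_conf, params, logger):
    # emhass_conf: dict of paths; params: JSON string; both prepared by the caller
    input_data_dict = await set_input_data_dict(
        emhass_conf, "profit", params, None, "dayahead-optim", logger
    )
    if not input_data_dict:
        raise RuntimeError("setup failed")
    return input_data_dict

# asyncio.run(run_setup(emhass_conf, params, logger))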


- def weather_forecast_cache(
+ async def weather_forecast_cache(
      emhass_conf: dict, params: str, runtimeparams: str, logger: logging.Logger
  ) -> bool:
      """
@@ -403,12 +782,15 @@ def weather_forecast_cache(
      :rtype: bool

      """
-
      # Parsing yaml
      retrieve_hass_conf, optim_conf, plant_conf = utils.get_yaml_parse(params, logger)
-
      # Treat runtimeparams
-     params, retrieve_hass_conf, optim_conf, plant_conf = utils.treat_runtimeparams(
+     (
+         params,
+         retrieve_hass_conf,
+         optim_conf,
+         plant_conf,
+     ) = await utils.treat_runtimeparams(
          runtimeparams,
          params,
          retrieve_hass_conf,
@@ -418,32 +800,27 @@ def weather_forecast_cache(
          logger,
          emhass_conf,
      )
-
      # Make sure weather_forecast_cache is true
-     if (params != None) and (params != "null"):
-         params = json.loads(params)
+     if (params is not None) and (params != "null"):
+         params = orjson.loads(params)
      else:
          params = {}
      params["passed_data"]["weather_forecast_cache"] = True
-     params = json.dumps(params)
-
+     params = orjson.dumps(params).decode("utf-8")
      # Create Forecast object
-     fcst = Forecast(
-         retrieve_hass_conf, optim_conf, plant_conf, params, emhass_conf, logger
-     )
-
-     result = fcst.get_weather_forecast(optim_conf["weather_forecast_method"])
+     fcst = Forecast(retrieve_hass_conf, optim_conf, plant_conf, params, emhass_conf, logger)
+     result = await fcst.get_weather_forecast(optim_conf["weather_forecast_method"])
      if isinstance(result, bool) and not result:
          return False

      return True
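The json-to-orjson migration visible here follows one pattern throughout the module: orjson.loads accepts either str or bytes, while orjson.dumps returns bytes and must be decoded before being handed on as a str. The round-trip in isolation:

import orjson

params = {"passed_data": {"weather_forecast_cache": True}}
raw = orjson.dumps(params)           # bytes, unlike json.dumps
text = raw.decode("utf-8")           # str form passed to Forecast(...)
assert orjson.loads(text) == params  # loads takes str or bytes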


- def perfect_forecast_optim(
+ async def perfect_forecast_optim(
      input_data_dict: dict,
      logger: logging.Logger,
-     save_data_to_file: Optional[bool] = True,
-     debug: Optional[bool] = False,
+     save_data_to_file: bool | None = True,
+     debug: bool | None = False,
  ) -> pd.DataFrame:
      """
      Perform a call to the perfect forecast optimization routine.
@@ -483,14 +860,14 @@ def perfect_forecast_optim(
      if save_data_to_file:
          filename = "opt_res_perfect_optim_" + input_data_dict["costfun"] + ".csv"
      else: # Just save the latest optimization results
-         filename = "opt_res_latest.csv"
+         filename = default_csv_filename
      if not debug:
          opt_res.to_csv(
              input_data_dict["emhass_conf"]["data_path"] / filename,
              index_label="timestamp",
          )
      if not isinstance(input_data_dict["params"], dict):
-         params = json.loads(input_data_dict["params"])
+         params = orjson.loads(input_data_dict["params"])
      else:
          params = input_data_dict["params"]

@@ -499,16 +876,93 @@ def perfect_forecast_optim(
          "passed_data"
      ].get("entity_save", False):
          # Trigger the publish function, save entity data and not post to HA
-         publish_data(input_data_dict, logger, entity_save=True, dont_post=True)
+         await publish_data(input_data_dict, logger, entity_save=True, dont_post=True)

      return opt_res


- def dayahead_forecast_optim(
+ def prepare_forecast_and_weather_data(
      input_data_dict: dict,
      logger: logging.Logger,
-     save_data_to_file: Optional[bool] = False,
-     debug: Optional[bool] = False,
+     warn_on_resolution: bool = False,
+ ) -> pd.DataFrame | bool:
+     """
+     Prepare forecast data with load costs, production prices, outdoor temperature, and GHI.
+
+     This helper function eliminates duplication between dayahead_forecast_optim and naive_mpc_optim.
+
+     :param input_data_dict: Dictionary with forecast and input data
+     :type input_data_dict: dict
+     :param logger: Logger object
+     :type logger: logging.Logger
+     :param warn_on_resolution: Whether to warn about GHI resolution mismatch
+     :type warn_on_resolution: bool
+     :return: Prepared DataFrame or False on error
+     :rtype: pd.DataFrame | bool
+     """
+     # Get load cost forecast
+     df_input_data_dayahead = input_data_dict["fcst"].get_load_cost_forecast(
+         input_data_dict["df_input_data_dayahead"],
+         method=input_data_dict["fcst"].optim_conf["load_cost_forecast_method"],
+     )
+     if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
+         return False
+
+     # Get production price forecast
+     df_input_data_dayahead = input_data_dict["fcst"].get_prod_price_forecast(
+         df_input_data_dayahead,
+         method=input_data_dict["fcst"].optim_conf["production_price_forecast_method"],
+     )
+     if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
+         return False
+
+     # Add outdoor temperature if provided
+     if "outdoor_temperature_forecast" in input_data_dict["params"]["passed_data"]:
+         df_input_data_dayahead["outdoor_temperature_forecast"] = input_data_dict["params"][
+             "passed_data"
+         ]["outdoor_temperature_forecast"]
+
+     # Merge GHI (Global Horizontal Irradiance) from weather forecast if available
+     if input_data_dict["df_weather"] is not None and "ghi" in input_data_dict["df_weather"].columns:
+         dayahead_index = df_input_data_dayahead.index
+
+         # Check time resolution if requested
+         if (
+             warn_on_resolution
+             and len(input_data_dict["df_weather"].index) > 1
+             and len(dayahead_index) > 1
+         ):
+             weather_index = input_data_dict["df_weather"].index
+             weather_freq = (weather_index[1] - weather_index[0]).total_seconds()
+             dayahead_freq = (dayahead_index[1] - dayahead_index[0]).total_seconds()
+             if weather_freq > 2 * dayahead_freq:
+                 logger.warning(
+                     "Weather data time resolution (%.0fs) is much coarser than dayahead index (%.0fs). "
+                     "Step changes in GHI may occur.",
+                     weather_freq,
+                     dayahead_freq,
+                 )

+         # Align GHI data to dayahead index using interpolation
+         df_input_data_dayahead["ghi"] = (
+             input_data_dict["df_weather"]["ghi"]
+             .reindex(dayahead_index)
+             .interpolate(method="time", limit_direction="both")
+         )
+         logger.debug(
+             "Merged GHI data into optimization input: mean=%.1f W/m², max=%.1f W/m²",
+             df_input_data_dayahead["ghi"].mean(),
+             df_input_data_dayahead["ghi"].max(),
+         )

+     return df_input_data_dayahead
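The GHI merge above reindexes the (often hourly) weather frame onto the optimization index and time-interpolates, instead of repeating each hourly value. The same three pandas calls on toy data:

import pandas as pd

weather_idx = pd.date_range("2024-06-01 06:00", periods=3, freq="1h")
ghi = pd.Series([0.0, 200.0, 400.0], index=weather_idx)
opt_idx = pd.date_range("2024-06-01 06:00", periods=5, freq="30min")
aligned = ghi.reindex(opt_idx).interpolate(method="time", limit_direction="both")
# 06:30 becomes 100.0 and 07:30 becomes 300.0 rather than a repeated step value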
+
+
+ async def dayahead_forecast_optim(
+     input_data_dict: dict,
+     logger: logging.Logger,
+     save_data_to_file: bool | None = False,
+     debug: bool | None = False,
  ) -> pd.DataFrame:
      """
      Perform a call to the day-ahead optimization routine.
@@ -526,36 +980,23 @@ def dayahead_forecast_optim(
      """
      logger.info("Performing day-ahead forecast optimization")
-     # Load cost and prod price forecast
-     df_input_data_dayahead = input_data_dict["fcst"].get_load_cost_forecast(
-         input_data_dict["df_input_data_dayahead"],
-         method=input_data_dict["fcst"].optim_conf["load_cost_forecast_method"],
+     # Prepare forecast data with costs, prices, outdoor temp, and GHI
+     df_input_data_dayahead = prepare_forecast_and_weather_data(
+         input_data_dict, logger, warn_on_resolution=False
      )
      if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
          return False
-     df_input_data_dayahead = input_data_dict["fcst"].get_prod_price_forecast(
-         df_input_data_dayahead,
-         method=input_data_dict["fcst"].optim_conf["production_price_forecast_method"],
-     )
-     if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
-         return False
-     if "outdoor_temperature_forecast" in input_data_dict["params"]["passed_data"]:
-         df_input_data_dayahead["outdoor_temperature_forecast"] = input_data_dict[
-             "params"
-         ]["passed_data"]["outdoor_temperature_forecast"]
      opt_res_dayahead = input_data_dict["opt"].perform_dayahead_forecast_optim(
          df_input_data_dayahead,
-         input_data_dict["P_PV_forecast"],
-         input_data_dict["P_load_forecast"],
+         input_data_dict["p_pv_forecast"],
+         input_data_dict["p_load_forecast"],
      )
      # Save CSV file for publish_data
      if save_data_to_file:
-         today = datetime.now(timezone.utc).replace(
-             hour=0, minute=0, second=0, microsecond=0
-         )
+         today = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
          filename = "opt_res_dayahead_" + today.strftime("%Y_%m_%d") + ".csv"
      else: # Just save the latest optimization results
-         filename = "opt_res_latest.csv"
+         filename = default_csv_filename
      if not debug:
          opt_res_dayahead.to_csv(
              input_data_dict["emhass_conf"]["data_path"] / filename,
              index_label="timestamp",
@@ -563,7 +1004,7 @@ def dayahead_forecast_optim(
          )

      if not isinstance(input_data_dict["params"], dict):
-         params = json.loads(input_data_dict["params"])
+         params = orjson.loads(input_data_dict["params"])
      else:
          params = input_data_dict["params"]

@@ -572,16 +1013,16 @@ def dayahead_forecast_optim(
          "passed_data"
      ].get("entity_save", False):
          # Trigger the publish function, save entity data and not post to HA
-         publish_data(input_data_dict, logger, entity_save=True, dont_post=True)
+         await publish_data(input_data_dict, logger, entity_save=True, dont_post=True)

      return opt_res_dayahead


- def naive_mpc_optim(
+ async def naive_mpc_optim(
      input_data_dict: dict,
      logger: logging.Logger,
-     save_data_to_file: Optional[bool] = False,
-     debug: Optional[bool] = False,
+     save_data_to_file: bool | None = False,
+     debug: bool | None = False,
  ) -> pd.DataFrame:
      """
      Perform a call to the naive Model Predictive Controller optimization routine.
@@ -599,30 +1040,22 @@ def naive_mpc_optim(
      """
      logger.info("Performing naive MPC optimization")
-     # Load cost and prod price forecast
-     df_input_data_dayahead = input_data_dict["fcst"].get_load_cost_forecast(
-         input_data_dict["df_input_data_dayahead"],
-         method=input_data_dict["fcst"].optim_conf["load_cost_forecast_method"],
+     # Prepare forecast data with costs, prices, outdoor temp, and GHI (with resolution warning)
+     df_input_data_dayahead = prepare_forecast_and_weather_data(
+         input_data_dict, logger, warn_on_resolution=True
      )
      if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
          return False
-     df_input_data_dayahead = input_data_dict["fcst"].get_prod_price_forecast(
-         df_input_data_dayahead,
-         method=input_data_dict["fcst"].optim_conf["production_price_forecast_method"],
-     )
-     if isinstance(df_input_data_dayahead, bool) and not df_input_data_dayahead:
-         return False
-     if "outdoor_temperature_forecast" in input_data_dict["params"]["passed_data"]:
-         df_input_data_dayahead["outdoor_temperature_forecast"] = input_data_dict[
-             "params"
-         ]["passed_data"]["outdoor_temperature_forecast"]
      # The specifics params for the MPC at runtime
      prediction_horizon = input_data_dict["params"]["passed_data"]["prediction_horizon"]
      soc_init = input_data_dict["params"]["passed_data"]["soc_init"]
      soc_final = input_data_dict["params"]["passed_data"]["soc_final"]
-     def_total_hours = input_data_dict["params"]["optim_conf"][
-         "operating_hours_of_each_deferrable_load"
-     ]
+     def_total_hours = input_data_dict["params"]["optim_conf"].get(
+         "operating_hours_of_each_deferrable_load", None
+     )
+     def_total_timestep = input_data_dict["params"]["optim_conf"].get(
+         "operating_timesteps_of_each_deferrable_load", None
+     )
      def_start_timestep = input_data_dict["params"]["optim_conf"][
          "start_timesteps_of_each_deferrable_load"
      ]
      def_end_timestep = input_data_dict["params"]["optim_conf"][
@@ -631,23 +1064,22 @@ def naive_mpc_optim(
      ]
      opt_res_naive_mpc = input_data_dict["opt"].perform_naive_mpc_optim(
          df_input_data_dayahead,
-         input_data_dict["P_PV_forecast"],
-         input_data_dict["P_load_forecast"],
+         input_data_dict["p_pv_forecast"],
+         input_data_dict["p_load_forecast"],
          prediction_horizon,
          soc_init,
          soc_final,
          def_total_hours,
+         def_total_timestep,
          def_start_timestep,
          def_end_timestep,
      )
      # Save CSV file for publish_data
      if save_data_to_file:
-         today = datetime.now(timezone.utc).replace(
-             hour=0, minute=0, second=0, microsecond=0
-         )
+         today = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
          filename = "opt_res_naive_mpc_" + today.strftime("%Y_%m_%d") + ".csv"
      else: # Just save the latest optimization results
-         filename = "opt_res_latest.csv"
+         filename = default_csv_filename
      if not debug:
          opt_res_naive_mpc.to_csv(
              input_data_dict["emhass_conf"]["data_path"] / filename,
@@ -655,7 +1087,7 @@ def naive_mpc_optim(
          )

      if not isinstance(input_data_dict["params"], dict):
-         params = json.loads(input_data_dict["params"])
+         params = orjson.loads(input_data_dict["params"])
      else:
          params = input_data_dict["params"]

@@ -664,14 +1096,14 @@ def naive_mpc_optim(
          "passed_data"
      ].get("entity_save", False):
          # Trigger the publish function, save entity data and not post to HA
-         publish_data(input_data_dict, logger, entity_save=True, dont_post=True)
+         await publish_data(input_data_dict, logger, entity_save=True, dont_post=True)

      return opt_res_naive_mpc
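The switch from indexing to .get(...) above makes both duration encodings optional: an MPC call may pass operating hours, operating timesteps, or neither, without raising KeyError. The lookup behaviour on plain dicts (illustrative values):

optim_conf = {"operating_timesteps_of_each_deferrable_load": [6, 2]}
def_total_hours = optim_conf.get("operating_hours_of_each_deferrable_load", None)         # None
def_total_timestep = optim_conf.get("operating_timesteps_of_each_deferrable_load", None)  # [6, 2]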


- def forecast_model_fit(
-     input_data_dict: dict, logger: logging.Logger, debug: Optional[bool] = False
- ) -> Tuple[pd.DataFrame, pd.DataFrame, MLForecaster]:
+ async def forecast_model_fit(
+     input_data_dict: dict, logger: logging.Logger, debug: bool | None = False
+ ) -> tuple[pd.DataFrame, pd.DataFrame, MLForecaster]:
      """Perform a forecast model fit from training data retrieved from Home Assistant.

      :param input_data_dict: A dictionary with multiple data used by the action functions
@@ -701,24 +1133,25 @@ def forecast_model_fit(
          logger,
      )
      # Fit the ML model
-     df_pred, df_pred_backtest = mlf.fit(
+     df_pred, df_pred_backtest = await mlf.fit(
          split_date_delta=split_date_delta, perform_backtest=perform_backtest
      )
      # Save model
      if not debug:
-         filename = model_type + "_mlf.pkl"
+         filename = model_type + default_pkl_suffix
          filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
-         with open(filename_path, "wb") as outp:
-             pickle.dump(mlf, outp, pickle.HIGHEST_PROTOCOL)
+         async with aiofiles.open(filename_path, "wb") as outp:
+             await outp.write(pickle.dumps(mlf, pickle.HIGHEST_PROTOCOL))
+         logger.debug("saved model to " + str(filename_path))
      return df_pred, df_pred_backtest, mlf
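The model persistence pattern used above: pickle has no async API, so the bytes are produced with pickle.dumps/pickle.loads and only the file I/O goes through aiofiles. A self-contained round-trip sketch (temporary path, toy object):

import asyncio
import pathlib
import pickle
import tempfile

import aiofiles

async def roundtrip(obj):
    path = pathlib.Path(tempfile.gettempdir()) / "demo_mlf.pkl"  # hypothetical file
    async with aiofiles.open(path, "wb") as outp:
        await outp.write(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))
    async with aiofiles.open(path, "rb") as inp:
        return pickle.loads(await inp.read())

assert asyncio.run(roundtrip({"model": [1, 2, 3]})) == {"model": [1, 2, 3]}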
714
1147
 
715
1148
 
716
- def forecast_model_predict(
1149
+ async def forecast_model_predict(
717
1150
  input_data_dict: dict,
718
1151
  logger: logging.Logger,
719
- use_last_window: Optional[bool] = True,
720
- debug: Optional[bool] = False,
721
- mlf: Optional[MLForecaster] = None,
1152
+ use_last_window: bool | None = True,
1153
+ debug: bool | None = False,
1154
+ mlf: MLForecaster | None = None,
722
1155
  ) -> pd.DataFrame:
723
1156
  r"""Perform a forecast model predict using a previously trained skforecast model.
724
1157
 
@@ -742,15 +1175,19 @@ def forecast_model_predict(
742
1175
  """
743
1176
  # Load model
744
1177
  model_type = input_data_dict["params"]["passed_data"]["model_type"]
745
- filename = model_type + "_mlf.pkl"
1178
+ filename = model_type + default_pkl_suffix
746
1179
  filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
747
1180
  if not debug:
748
1181
  if filename_path.is_file():
749
- with open(filename_path, "rb") as inp:
750
- mlf = pickle.load(inp)
1182
+ async with aiofiles.open(filename_path, "rb") as inp:
1183
+ content = await inp.read()
1184
+ mlf = pickle.loads(content)
1185
+ logger.debug("loaded saved model from " + str(filename_path))
751
1186
  else:
752
1187
  logger.error(
753
- "The ML forecaster file was not found, please run a model fit method before this predict method",
1188
+ "The ML forecaster file ("
1189
+ + str(filename_path)
1190
+ + ") was not found, please run a model fit method before this predict method",
754
1191
  )
755
1192
  return
756
1193
  # Make predictions
@@ -758,13 +1195,12 @@ def forecast_model_predict(
758
1195
  data_last_window = copy.deepcopy(input_data_dict["df_input_data"])
759
1196
  else:
760
1197
  data_last_window = None
761
- predictions = mlf.predict(data_last_window)
1198
+ predictions = await mlf.predict(data_last_window)
762
1199
  # Publish data to a Home Assistant sensor
763
- model_predict_publish = input_data_dict["params"]["passed_data"][
764
- "model_predict_publish"
765
- ]
766
- model_predict_entity_id = input_data_dict["params"]["passed_data"][
767
- "model_predict_entity_id"
1200
+ model_predict_publish = input_data_dict["params"]["passed_data"]["model_predict_publish"]
1201
+ model_predict_entity_id = input_data_dict["params"]["passed_data"]["model_predict_entity_id"]
1202
+ model_predict_device_class = input_data_dict["params"]["passed_data"][
1203
+ "model_predict_device_class"
768
1204
  ]
769
1205
  model_predict_unit_of_measurement = input_data_dict["params"]["passed_data"][
770
1206
  "model_predict_unit_of_measurement"
@@ -775,30 +1211,23 @@ def forecast_model_predict(
775
1211
  publish_prefix = input_data_dict["params"]["passed_data"]["publish_prefix"]
776
1212
  if model_predict_publish is True:
777
1213
  # Estimate the current index
778
- now_precise = datetime.now(
779
- input_data_dict["retrieve_hass_conf"]["time_zone"]
780
- ).replace(second=0, microsecond=0)
1214
+ now_precise = datetime.now(input_data_dict["retrieve_hass_conf"]["time_zone"]).replace(
1215
+ second=0, microsecond=0
1216
+ )
781
1217
  if input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "nearest":
782
- idx_closest = predictions.index.get_indexer(
783
- [now_precise], method="nearest"
784
- )[0]
1218
+ idx_closest = predictions.index.get_indexer([now_precise], method="nearest")[0]
785
1219
  elif input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "first":
786
- idx_closest = predictions.index.get_indexer([now_precise], method="ffill")[
787
- 0
788
- ]
1220
+ idx_closest = predictions.index.get_indexer([now_precise], method="ffill")[0]
789
1221
  elif input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "last":
790
- idx_closest = predictions.index.get_indexer([now_precise], method="bfill")[
791
- 0
792
- ]
1222
+ idx_closest = predictions.index.get_indexer([now_precise], method="bfill")[0]
793
1223
  if idx_closest == -1:
794
- idx_closest = predictions.index.get_indexer(
795
- [now_precise], method="nearest"
796
- )[0]
1224
+ idx_closest = predictions.index.get_indexer([now_precise], method="nearest")[0]
797
1225
  # Publish Load forecast
798
- input_data_dict["rh"].post_data(
1226
+ await input_data_dict["rh"].post_data(
799
1227
  predictions,
800
1228
  idx_closest,
801
1229
  model_predict_entity_id,
1230
+ model_predict_device_class,
802
1231
  model_predict_unit_of_measurement,
803
1232
  model_predict_friendly_name,
804
1233
  type_var="mlforecaster",
@@ -807,12 +1236,12 @@ def forecast_model_predict(
807
1236
  return predictions
808
1237
 
809
1238
 
810
- def forecast_model_tune(
1239
+ async def forecast_model_tune(
811
1240
  input_data_dict: dict,
812
1241
  logger: logging.Logger,
813
- debug: Optional[bool] = False,
814
- mlf: Optional[MLForecaster] = None,
815
- ) -> Tuple[pd.DataFrame, MLForecaster]:
1242
+ debug: bool | None = False,
1243
+ mlf: MLForecaster | None = None,
1244
+ ) -> tuple[pd.DataFrame, MLForecaster]:
816
1245
  """Tune a forecast model hyperparameters using bayesian optimization.
817
1246
 
818
1247
  :param input_data_dict: A dictionary of data used by the action functions
@@ -829,30 +1258,42 @@ def forecast_model_tune(
829
1258
  """
830
1259
  # Load model
831
1260
  model_type = input_data_dict["params"]["passed_data"]["model_type"]
832
- filename = model_type + "_mlf.pkl"
1261
+ filename = model_type + default_pkl_suffix
833
1262
  filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
834
1263
  if not debug:
835
1264
  if filename_path.is_file():
836
- with open(filename_path, "rb") as inp:
837
- mlf = pickle.load(inp)
1265
+ async with aiofiles.open(filename_path, "rb") as inp:
1266
+ content = await inp.read()
1267
+ mlf = pickle.loads(content)
1268
+ logger.debug("loaded saved model from " + str(filename_path))
838
1269
  else:
839
1270
  logger.error(
840
- "The ML forecaster file was not found, please run a model fit method before this tune method",
1271
+ "The ML forecaster file ("
1272
+ + str(filename_path)
1273
+ + ") was not found, please run a model fit method before this tune method",
841
1274
  )
842
1275
  return None, None
843
1276
  # Tune the model
844
- df_pred_optim = mlf.tune(debug=debug)
1277
+ split_date_delta = input_data_dict["params"]["passed_data"]["split_date_delta"]
1278
+ if debug:
1279
+ n_trials = 5
1280
+ else:
1281
+ n_trials = input_data_dict["params"]["passed_data"]["n_trials"]
1282
+ df_pred_optim = await mlf.tune(
1283
+ split_date_delta=split_date_delta, n_trials=n_trials, debug=debug
1284
+ )
845
1285
  # Save model
846
1286
  if not debug:
847
- filename = model_type + "_mlf.pkl"
1287
+ filename = model_type + default_pkl_suffix
848
1288
  filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
849
- with open(filename_path, "wb") as outp:
850
- pickle.dump(mlf, outp, pickle.HIGHEST_PROTOCOL)
1289
+ async with aiofiles.open(filename_path, "wb") as outp:
1290
+ await outp.write(pickle.dumps(mlf, pickle.HIGHEST_PROTOCOL))
1291
+ logger.debug("Saved model to " + str(filename_path))
851
1292
  return df_pred_optim, mlf
852
1293
 
853
1294
 
854
- def regressor_model_fit(
855
- input_data_dict: dict, logger: logging.Logger, debug: Optional[bool] = False
1295
+ async def regressor_model_fit(
1296
+ input_data_dict: dict, logger: logging.Logger, debug: bool | None = False
856
1297
  ) -> MLRegressor:
857
1298
  """Perform a forecast model fit from training data retrieved from Home Assistant.
858
1299
 
@@ -895,27 +1336,25 @@ def regressor_model_fit(
895
1336
  logger.error("parameter: 'date_features' not passed")
896
1337
  return False
897
1338
  # The MLRegressor object
898
- mlr = MLRegressor(
899
- data, model_type, regression_model, features, target, timestamp, logger
900
- )
1339
+ mlr = MLRegressor(data, model_type, regression_model, features, target, timestamp, logger)
901
1340
  # Fit the ML model
902
- fit = mlr.fit(date_features=date_features)
1341
+ fit = await mlr.fit(date_features=date_features)
903
1342
  if not fit:
904
1343
  return False
905
1344
  # Save model
906
1345
  if not debug:
907
1346
  filename = model_type + "_mlr.pkl"
908
1347
  filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
909
- with open(filename_path, "wb") as outp:
910
- pickle.dump(mlr, outp, pickle.HIGHEST_PROTOCOL)
1348
+ async with aiofiles.open(filename_path, "wb") as outp:
1349
+ await outp.write(pickle.dumps(mlr, pickle.HIGHEST_PROTOCOL))
911
1350
  return mlr
912
1351
 
913
1352
 
914
- def regressor_model_predict(
1353
+ async def regressor_model_predict(
915
1354
  input_data_dict: dict,
916
1355
  logger: logging.Logger,
917
- debug: Optional[bool] = False,
918
- mlr: Optional[MLRegressor] = None,
1356
+ debug: bool | None = False,
1357
+ mlr: MLRegressor | None = None,
919
1358
  ) -> np.ndarray:
920
1359
  """Perform a prediction from csv file.
921
1360
 
@@ -935,8 +1374,9 @@ def regressor_model_predict(
935
1374
  filename_path = input_data_dict["emhass_conf"]["data_path"] / filename
936
1375
  if not debug:
937
1376
  if filename_path.is_file():
938
- with open(filename_path, "rb") as inp:
939
- mlr = pickle.load(inp)
1377
+ async with aiofiles.open(filename_path, "rb") as inp:
1378
+ content = await inp.read()
1379
+ mlr = pickle.loads(content)
940
1380
  else:
941
1381
  logger.error(
942
1382
  "The ML forecaster file was not found, please run a model fit method before this predict method",
@@ -948,12 +1388,15 @@ def regressor_model_predict(
948
1388
  logger.error("parameter: 'new_values' not passed")
949
1389
  return False
950
1390
  # Predict from csv file
951
- prediction = mlr.predict(new_values)
1391
+ prediction = await mlr.predict(new_values)
952
1392
  mlr_predict_entity_id = input_data_dict["params"]["passed_data"].get(
953
1393
  "mlr_predict_entity_id", "sensor.mlr_predict"
954
1394
  )
1395
+ mlr_predict_device_class = input_data_dict["params"]["passed_data"].get(
1396
+ "mlr_predict_device_class", "power"
1397
+ )
955
1398
  mlr_predict_unit_of_measurement = input_data_dict["params"]["passed_data"].get(
956
- "mlr_predict_unit_of_measurement", "h"
1399
+ "mlr_predict_unit_of_measurement", "W"
957
1400
  )
958
1401
  mlr_predict_friendly_name = input_data_dict["params"]["passed_data"].get(
959
1402
  "mlr_predict_friendly_name", "mlr predictor"
@@ -961,10 +1404,11 @@ def regressor_model_predict(
961
1404
  # Publish prediction
962
1405
  idx = 0
963
1406
  if not debug:
964
- input_data_dict["rh"].post_data(
1407
+ await input_data_dict["rh"].post_data(
965
1408
  prediction,
966
1409
  idx,
967
1410
  mlr_predict_entity_id,
1411
+ mlr_predict_device_class,
968
1412
  mlr_predict_unit_of_measurement,
969
1413
  mlr_predict_friendly_name,
970
1414
  type_var="mlregressor",
@@ -972,341 +1416,554 @@ def regressor_model_predict(
972
1416
  return prediction
973
1417
 
974
1418
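Note: post_data now takes a device_class argument positionally between the entity id and the unit of measurement, for both the forecaster and regressor publish paths. A hedged sketch of the regressor call site under that assumed positional order (the values are the defaults visible in this diff):

# Sketch only: must run inside a coroutine; argument order assumed from this diff
await input_data_dict["rh"].post_data(
    prediction,            # predicted values from mlr.predict
    0,                     # index of the value exposed as the sensor state
    "sensor.mlr_predict",  # default mlr_predict_entity_id
    "power",               # default mlr_predict_device_class (new argument)
    "W",                   # default unit, changed from "h" to "W"
    "mlr predictor",       # default friendly name
    type_var="mlregressor",
)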
 
975
- def publish_data(
976
- input_data_dict: dict,
1419
+ async def export_influxdb_to_csv(
1420
+ input_data_dict: dict | None,
977
1421
  logger: logging.Logger,
978
- save_data_to_file: Optional[bool] = False,
979
- opt_res_latest: Optional[pd.DataFrame] = None,
980
- entity_save: Optional[bool] = False,
981
- dont_post: Optional[bool] = False,
982
- ) -> pd.DataFrame:
983
- """
984
- Publish the data obtained from the optimization results.
1422
+ emhass_conf: dict | None = None,
1423
+ params: str | None = None,
1424
+ runtimeparams: str | None = None,
1425
+ ) -> bool:
1426
+ """Export data from InfluxDB to CSV file.
985
1427
 
986
- :param input_data_dict: A dictionnary with multiple data used by the action functions
987
- :type input_data_dict: dict
988
- :param logger: The passed logger object
989
- :type logger: logging object
990
- :param save_data_to_file: If True we will read data from optimization results in dayahead CSV file
991
- :type save_data_to_file: bool, optional
992
- :return: The output data of the optimization readed from a CSV file in the data folder
993
- :rtype: pd.DataFrame
994
- :param entity_save: Save built entities to data_path/entities
995
- :type entity_save: bool, optional
996
- :param dont_post: Do not post to Home Assistant. Works with entity_save
997
- :type dont_post: bool, optional
1428
+ This function can be called in two ways:
1429
+ 1. With input_data_dict (from web_server via set_input_data_dict)
1430
+ 2. Without input_data_dict (direct call from command line or web_server before set_input_data_dict)
998
1431
 
1432
+ :param input_data_dict: Dictionary containing configuration and parameters (optional)
1433
+ :type input_data_dict: dict | None
1434
+ :param logger: Logger object
1435
+ :type logger: logging.Logger
1436
+ :param emhass_conf: Dictionary containing EMHASS configuration paths (used when input_data_dict is None)
1437
+ :type emhass_conf: dict | None
1438
+ :param params: JSON string of params (used when input_data_dict is None)
1439
+ :type params: str | None
1440
+ :param runtimeparams: JSON string of runtime parameters (used when input_data_dict is None)
1441
+ :type runtimeparams: str | None
1442
+ :return: Success status
1443
+ :rtype: bool
999
1444
  """
1000
- logger.info("Publishing data to HASS instance")
1445
+ # Handle two calling modes
1446
+ if input_data_dict is None:
1447
+ # Direct mode: parse params and create RetrieveHass
1448
+ if emhass_conf is None or params is None:
1449
+ logger.error("emhass_conf and params are required when input_data_dict is None")
1450
+ return False
1451
+ # Parse params
1452
+ if isinstance(params, str):
1453
+ params = orjson.loads(params)
1454
+ if isinstance(runtimeparams, str):
1455
+ runtimeparams = orjson.loads(runtimeparams)
1456
+ # Get configuration
1457
+ retrieve_hass_conf, optim_conf, plant_conf = utils.get_yaml_parse(params, logger)
1458
+ if isinstance(retrieve_hass_conf, bool):
1459
+ return False
1460
+ # Treat runtime params
1461
+ (
1462
+ params,
1463
+ retrieve_hass_conf,
1464
+ optim_conf,
1465
+ plant_conf,
1466
+ ) = await utils.treat_runtimeparams(
1467
+ orjson.dumps(runtimeparams).decode("utf-8") if runtimeparams else "{}",
1468
+ params,
1469
+ retrieve_hass_conf,
1470
+ optim_conf,
1471
+ plant_conf,
1472
+ "export-influxdb-to-csv",
1473
+ logger,
1474
+ emhass_conf,
1475
+ )
1476
+ # Parse params again if it's a string
1477
+ if isinstance(params, str):
1478
+ params = orjson.loads(params)
1479
+ # Create RetrieveHass object
1480
+ rh = RetrieveHass(
1481
+ retrieve_hass_conf["hass_url"],
1482
+ retrieve_hass_conf["long_lived_token"],
1483
+ retrieve_hass_conf["optimization_time_step"],
1484
+ retrieve_hass_conf["time_zone"],
1485
+ params,
1486
+ emhass_conf,
1487
+ logger,
1488
+ )
1489
+ time_zone = rh.time_zone
1490
+ data_path = emhass_conf["data_path"]
1491
+ else:
1492
+ # Standard mode: use input_data_dict
1493
+ params = input_data_dict["params"]
1494
+ if isinstance(params, str):
1495
+ params = orjson.loads(params)
1496
+ rh = input_data_dict["rh"]
1497
+ time_zone = rh.time_zone
1498
+ data_path = input_data_dict["emhass_conf"]["data_path"]
1499
+ # Extract parameters from passed_data
1500
+ if "sensor_list" not in params.get("passed_data", {}):
1501
+ logger.error("parameter: 'sensor_list' not passed")
1502
+ return False
1503
+ sensor_list = params["passed_data"]["sensor_list"]
1504
+ if "csv_filename" not in params.get("passed_data", {}):
1505
+ logger.error("parameter: 'csv_filename' not passed")
1506
+ return False
1507
+ csv_filename = params["passed_data"]["csv_filename"]
1508
+ if "start_time" not in params.get("passed_data", {}):
1509
+ logger.error("parameter: 'start_time' not passed")
1510
+ return False
1511
+ start_time = params["passed_data"]["start_time"]
1512
+ # Optional parameters with defaults
1513
+ end_time = params["passed_data"].get("end_time", None)
1514
+ resample_freq = params["passed_data"].get("resample_freq", "1h")
1515
+ timestamp_col = params["passed_data"].get("timestamp_col_name", "timestamp")
1516
+ decimal_places = params["passed_data"].get("decimal_places", 2)
1517
+ handle_nan = params["passed_data"].get("handle_nan", "keep")
1518
+ # Check if InfluxDB is enabled
1519
+ if not rh.use_influxdb:
1520
+ logger.error(
1521
+ "InfluxDB is not enabled in configuration. Set use_influxdb: true in config.json"
1522
+ )
1523
+ return False
1524
+ # Parse time range
1525
+ start_dt, end_dt = utils.parse_export_time_range(start_time, end_time, time_zone, logger)
1526
+ if start_dt is False:
1527
+ return False
1528
+ # Create days list for data retrieval
1529
+ days_list = pd.date_range(start=start_dt.date(), end=end_dt.date(), freq="D", tz=time_zone)
1530
+ if len(days_list) == 0:
1531
+ logger.error("No days to retrieve. Check start_time and end_time.")
1532
+ return False
1533
+ logger.info(
1534
+ f"Retrieving {len(sensor_list)} sensors from {start_dt} to {end_dt} ({len(days_list)} days)"
1535
+ )
1536
+ logger.info(f"Sensors: {sensor_list}")
1537
+ # Retrieve data from InfluxDB
1538
+ success = rh.get_data(days_list, sensor_list)
1539
+ if not success or rh.df_final is None or rh.df_final.empty:
1540
+ logger.error("Failed to retrieve data from InfluxDB")
1541
+ return False
1542
+ # Filter and resample data
1543
+ df_export = utils.resample_and_filter_data(rh.df_final, start_dt, end_dt, resample_freq, logger)
1544
+ if df_export is False:
1545
+ return False
1546
+ # Reset index to make timestamp a column
1547
+ # Handle custom index names by renaming the index first
1548
+ df_export = df_export.rename_axis(timestamp_col).reset_index()
1549
+ # Clean column names
1550
+ df_export = utils.clean_sensor_column_names(df_export, timestamp_col)
1551
+ # Handle NaN values
1552
+ df_export = utils.handle_nan_values(df_export, handle_nan, timestamp_col, logger)
1553
+ # Round numeric columns to specified decimal places
1554
+ numeric_cols = df_export.select_dtypes(include=[np.number]).columns
1555
+ df_export[numeric_cols] = df_export[numeric_cols].round(decimal_places)
1556
+ # Save to CSV
1557
+ csv_path = pathlib.Path(data_path) / csv_filename
1558
+ df_export.to_csv(csv_path, index=False)
1559
+ logger.info(f"✓ Successfully exported to {csv_filename}")
1560
+ logger.info(f" Rows: {df_export.shape[0]}")
1561
+ logger.info(f" Columns: {list(df_export.columns)}")
1562
+ logger.info(
1563
+ f" Time range: {df_export[timestamp_col].min()} to {df_export[timestamp_col].max()}"
1564
+ )
1565
+ logger.info(f" File location: {csv_path}")
1566
+ return True
1567
+
1568
+
1569
+ def _get_params(input_data_dict: dict) -> dict:
1570
+ """Helper to extract params from input_data_dict."""
1001
1571
  if input_data_dict:
1002
1572
  if not isinstance(input_data_dict.get("params", {}), dict):
1003
- params = json.loads(input_data_dict["params"])
1004
- else:
1005
- params = input_data_dict.get("params", {})
1573
+ return orjson.loads(input_data_dict["params"])
1574
+ return input_data_dict.get("params", {})
1575
+ return {}
1576
+
1577
+
1578
+ async def _publish_from_saved_entities(
1579
+ input_data_dict: dict, logger: logging.Logger, params: dict
1580
+ ) -> pd.DataFrame | None:
1581
+ """
1582
+ Helper to publish data from saved entity JSON files if publish_prefix is set.
1583
+ Returns DataFrame if successful, None if fallback to CSV is needed.
1584
+ """
1585
+ publish_prefix = params["passed_data"].get("publish_prefix", "")
1586
+ entity_path = input_data_dict["emhass_conf"]["data_path"] / "entities"
1587
+ if not entity_path.exists() or not os.listdir(entity_path):
1588
+ logger.warning(f"No saved entity json files in path: {entity_path}")
1589
+ logger.warning("Falling back to opt_res_latest")
1590
+ return None
1591
+ entity_path_contents = os.listdir(entity_path)
1592
+ matches_prefix = any(publish_prefix in entity for entity in entity_path_contents)
1593
+ if not (matches_prefix or publish_prefix == "all"):
1594
+ logger.warning(f"No saved entity json files that match prefix: {publish_prefix}")
1595
+ logger.warning("Falling back to opt_res_latest")
1596
+ return None
1597
+ opt_res_list = []
1598
+ opt_res_list_names = []
1599
+ for entity in entity_path_contents:
1600
+ if entity == default_metadata_json:
1601
+ continue
1602
+ if publish_prefix == "all" or publish_prefix in entity:
1603
+ entity_data = await publish_json(entity, input_data_dict, entity_path, logger)
1604
+ if isinstance(entity_data, bool):
1605
+ return None # Error occurred
1606
+ opt_res_list.append(entity_data)
1607
+ opt_res_list_names.append(entity.replace(".json", ""))
1608
+ opt_res = pd.concat(opt_res_list, axis=1)
1609
+ opt_res.columns = opt_res_list_names
1610
+ return opt_res
1006
1611
 
1007
- # Check if a day ahead optimization has been performed (read CSV file)
1612
+
1613
+ def _load_opt_res_latest(
1614
+ input_data_dict: dict, logger: logging.Logger, save_data_to_file: bool
1615
+ ) -> pd.DataFrame | None:
1616
+ """Helper to load the optimization results DataFrame from CSV."""
1008
1617
  if save_data_to_file:
1009
- today = datetime.now(timezone.utc).replace(
1010
- hour=0, minute=0, second=0, microsecond=0
1011
- )
1618
+ today = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
1012
1619
  filename = "opt_res_dayahead_" + today.strftime("%Y_%m_%d") + ".csv"
1013
- # If publish_prefix is passed, check if there is saved entities in data_path/entities with prefix, publish to results
1014
- elif params["passed_data"].get("publish_prefix", "") != "" and not dont_post:
1015
- opt_res_list = []
1016
- opt_res_list_names = []
1017
- publish_prefix = params["passed_data"]["publish_prefix"]
1018
- entity_path = input_data_dict["emhass_conf"]["data_path"] / "entities"
1019
- # Check if items in entity_path
1020
- if os.path.exists(entity_path) and len(os.listdir(entity_path)) > 0:
1021
- # Obtain all files in entity_path
1022
- entity_path_contents = os.listdir(entity_path)
1023
- # Confirm the entity path contains at least one file containing publish prefix or publish_prefix='all'
1024
- if (
1025
- any(publish_prefix in entity for entity in entity_path_contents)
1026
- or publish_prefix == "all"
1027
- ):
1028
- # Loop through all items in entity path
1029
- for entity in entity_path_contents:
1030
- # If publish_prefix is "all" publish all saved entities to Home Assistant
1031
- # If publish_prefix matches the prefix from saved entities, publish to Home Assistant
1032
- if entity != "metadata.json" and (
1033
- publish_prefix in entity or publish_prefix == "all"
1034
- ):
1035
- entity_data = publish_json(
1036
- entity, input_data_dict, entity_path, logger
1037
- )
1038
- if not isinstance(entity_data, bool):
1039
- opt_res_list.append(entity_data)
1040
- opt_res_list_names.append(entity.replace(".json", ""))
1041
- else:
1042
- return False
1043
- # Build a DataFrame with published entities
1044
- opt_res = pd.concat(opt_res_list, axis=1)
1045
- opt_res.columns = opt_res_list_names
1046
- return opt_res
1047
- else:
1048
- logger.warning(
1049
- "No saved entity json files that match prefix: "
1050
- + str(publish_prefix)
1051
- )
1052
- logger.warning("Falling back to opt_res_latest")
1053
- else:
1054
- logger.warning("No saved entity json files in path:" + str(entity_path))
1055
- logger.warning("Falling back to opt_res_latest")
1056
- filename = "opt_res_latest.csv"
1057
1620
  else:
1058
- filename = "opt_res_latest.csv"
1059
- if opt_res_latest is None:
1060
- if not os.path.isfile(input_data_dict["emhass_conf"]["data_path"] / filename):
1061
- logger.error("File not found error, run an optimization task first.")
1062
- return
1063
- else:
1064
- opt_res_latest = pd.read_csv(
1065
- input_data_dict["emhass_conf"]["data_path"] / filename,
1066
- index_col="timestamp",
1067
- )
1068
- opt_res_latest.index = pd.to_datetime(opt_res_latest.index)
1069
- opt_res_latest.index.freq = input_data_dict["retrieve_hass_conf"][
1070
- "optimization_time_step"
1071
- ]
1072
- # Estimate the current index
1073
- now_precise = datetime.now(
1074
- input_data_dict["retrieve_hass_conf"]["time_zone"]
1075
- ).replace(second=0, microsecond=0)
1076
- if input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "nearest":
1077
- idx_closest = opt_res_latest.index.get_indexer([now_precise], method="nearest")[
1078
- 0
1079
- ]
1080
- elif input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "first":
1081
- idx_closest = opt_res_latest.index.get_indexer([now_precise], method="ffill")[0]
1082
- elif input_data_dict["retrieve_hass_conf"]["method_ts_round"] == "last":
1083
- idx_closest = opt_res_latest.index.get_indexer([now_precise], method="bfill")[0]
1084
- if idx_closest == -1:
1085
- idx_closest = opt_res_latest.index.get_indexer([now_precise], method="nearest")[
1086
- 0
1087
- ]
1088
- # Publish the data
1089
- publish_prefix = params["passed_data"]["publish_prefix"]
1090
- # Publish PV forecast
1091
- custom_pv_forecast_id = params["passed_data"]["custom_pv_forecast_id"]
1092
- input_data_dict["rh"].post_data(
1621
+ filename = default_csv_filename
1622
+ file_path = input_data_dict["emhass_conf"]["data_path"] / filename
1623
+ if not file_path.exists():
1624
+ logger.error("File not found error, run an optimization task first.")
1625
+ return None
1626
+ opt_res_latest = pd.read_csv(file_path, index_col="timestamp")
1627
+ opt_res_latest.index = pd.to_datetime(opt_res_latest.index)
1628
+ opt_res_latest.index.freq = input_data_dict["retrieve_hass_conf"]["optimization_time_step"]
1629
+ return opt_res_latest
1630
+
1631
+
1632
+ def _get_closest_index(retrieve_hass_conf: dict, index: pd.DatetimeIndex) -> int:
1633
+ """Helper to find the closest index in the DataFrame to the current time."""
1634
+ now_precise = datetime.now(retrieve_hass_conf["time_zone"]).replace(second=0, microsecond=0)
1635
+ method = retrieve_hass_conf["method_ts_round"]
1636
+ if method == "nearest":
1637
+ return index.get_indexer([now_precise], method="nearest")[0]
1638
+ elif method == "first":
1639
+ return index.get_indexer([now_precise], method="ffill")[0]
1640
+ elif method == "last":
1641
+ return index.get_indexer([now_precise], method="bfill")[0]
1642
+ return index.get_indexer([now_precise], method="nearest")[0]
1643
+
1644
+
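Note: _get_closest_index maps method_ts_round onto pandas get_indexer methods: "nearest" snaps to the closest timestamp, "first" (ffill) to the last index at or before now, and "last" (bfill) to the next index at or after now, with "nearest" as the fallback when get_indexer returns -1. A small self-contained illustration:

import pandas as pd

index = pd.date_range("2024-01-01 00:00", periods=4, freq="30min")
now = pd.Timestamp("2024-01-01 00:40")

print(index.get_indexer([now], method="nearest")[0])  # 1 -> 00:30 (closest)
print(index.get_indexer([now], method="ffill")[0])    # 1 -> 00:30 (at or before)
print(index.get_indexer([now], method="bfill")[0])    # 2 -> 01:00 (at or after)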
1645
+ async def _publish_standard_forecasts(
1646
+ ctx: PublishContext, opt_res_latest: pd.DataFrame
1647
+ ) -> list[str]:
1648
+ """Publish PV, Load, Curtailment, and Hybrid Inverter data."""
1649
+ cols = []
1650
+ # PV Forecast
1651
+ custom_pv = ctx.params["passed_data"]["custom_pv_forecast_id"]
1652
+ await ctx.rh.post_data(
1093
1653
  opt_res_latest["P_PV"],
1094
- idx_closest,
1095
- custom_pv_forecast_id["entity_id"],
1096
- custom_pv_forecast_id["unit_of_measurement"],
1097
- custom_pv_forecast_id["friendly_name"],
1654
+ ctx.idx,
1655
+ custom_pv["entity_id"],
1656
+ "power",
1657
+ custom_pv["unit_of_measurement"],
1658
+ custom_pv["friendly_name"],
1098
1659
  type_var="power",
1099
- publish_prefix=publish_prefix,
1100
- save_entities=entity_save,
1101
- dont_post=dont_post,
1660
+ **ctx.common_kwargs,
1102
1661
  )
1103
- # Publish Load forecast
1104
- custom_load_forecast_id = params["passed_data"]["custom_load_forecast_id"]
1105
- input_data_dict["rh"].post_data(
1662
+ cols.append("P_PV")
1663
+ # Load Forecast
1664
+ custom_load = ctx.params["passed_data"]["custom_load_forecast_id"]
1665
+ await ctx.rh.post_data(
1106
1666
  opt_res_latest["P_Load"],
1107
- idx_closest,
1108
- custom_load_forecast_id["entity_id"],
1109
- custom_load_forecast_id["unit_of_measurement"],
1110
- custom_load_forecast_id["friendly_name"],
1667
+ ctx.idx,
1668
+ custom_load["entity_id"],
1669
+ "power",
1670
+ custom_load["unit_of_measurement"],
1671
+ custom_load["friendly_name"],
1111
1672
  type_var="power",
1112
- publish_prefix=publish_prefix,
1113
- save_entities=entity_save,
1114
- dont_post=dont_post,
1115
- )
1116
- cols_published = ["P_PV", "P_Load"]
1117
- # Publish PV curtailment
1118
- if input_data_dict["fcst"].plant_conf["compute_curtailment"]:
1119
- custom_pv_curtailment_id = params["passed_data"]["custom_pv_curtailment_id"]
1120
- input_data_dict["rh"].post_data(
1673
+ **ctx.common_kwargs,
1674
+ )
1675
+ cols.append("P_Load")
1676
+ # Curtailment
1677
+ if ctx.fcst.plant_conf["compute_curtailment"]:
1678
+ custom_curt = ctx.params["passed_data"]["custom_pv_curtailment_id"]
1679
+ await ctx.rh.post_data(
1121
1680
  opt_res_latest["P_PV_curtailment"],
1122
- idx_closest,
1123
- custom_pv_curtailment_id["entity_id"],
1124
- custom_pv_curtailment_id["unit_of_measurement"],
1125
- custom_pv_curtailment_id["friendly_name"],
1681
+ ctx.idx,
1682
+ custom_curt["entity_id"],
1683
+ "power",
1684
+ custom_curt["unit_of_measurement"],
1685
+ custom_curt["friendly_name"],
1126
1686
  type_var="power",
1127
- publish_prefix=publish_prefix,
1128
- save_entities=entity_save,
1129
- dont_post=dont_post,
1687
+ **ctx.common_kwargs,
1130
1688
  )
1131
- cols_published = cols_published + ["P_PV_curtailment"]
1132
- # Publish P_hybrid_inverter
1133
- if input_data_dict["fcst"].plant_conf["inverter_is_hybrid"]:
1134
- custom_hybrid_inverter_id = params["passed_data"]["custom_hybrid_inverter_id"]
1135
- input_data_dict["rh"].post_data(
1689
+ cols.append("P_PV_curtailment")
1690
+ # Hybrid Inverter
1691
+ if ctx.fcst.plant_conf["inverter_is_hybrid"]:
1692
+ custom_inv = ctx.params["passed_data"]["custom_hybrid_inverter_id"]
1693
+ await ctx.rh.post_data(
1136
1694
  opt_res_latest["P_hybrid_inverter"],
1137
- idx_closest,
1138
- custom_hybrid_inverter_id["entity_id"],
1139
- custom_hybrid_inverter_id["unit_of_measurement"],
1140
- custom_hybrid_inverter_id["friendly_name"],
1695
+ ctx.idx,
1696
+ custom_inv["entity_id"],
1697
+ "power",
1698
+ custom_inv["unit_of_measurement"],
1699
+ custom_inv["friendly_name"],
1141
1700
  type_var="power",
1142
- publish_prefix=publish_prefix,
1143
- save_entities=entity_save,
1144
- dont_post=dont_post,
1701
+ **ctx.common_kwargs,
1145
1702
  )
1146
- cols_published = cols_published + ["P_hybrid_inverter"]
1147
- # Publish deferrable loads
1148
- custom_deferrable_forecast_id = params["passed_data"][
1149
- "custom_deferrable_forecast_id"
1150
- ]
1151
- for k in range(input_data_dict["opt"].optim_conf["number_of_deferrable_loads"]):
1152
- if "P_deferrable{}".format(k) not in opt_res_latest.columns:
1153
- logger.error(
1154
- "P_deferrable{}".format(k)
1155
- + " was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution.",
1156
- )
1157
- else:
1158
- input_data_dict["rh"].post_data(
1159
- opt_res_latest["P_deferrable{}".format(k)],
1160
- idx_closest,
1161
- custom_deferrable_forecast_id[k]["entity_id"],
1162
- custom_deferrable_forecast_id[k]["unit_of_measurement"],
1163
- custom_deferrable_forecast_id[k]["friendly_name"],
1164
- type_var="deferrable",
1165
- publish_prefix=publish_prefix,
1166
- save_entities=entity_save,
1167
- dont_post=dont_post,
1168
- )
1169
- cols_published = cols_published + ["P_deferrable{}".format(k)]
1170
- # Publish thermal model data (predicted temperature)
1171
- custom_predicted_temperature_id = params["passed_data"][
1172
- "custom_predicted_temperature_id"
1173
- ]
1174
- for k in range(input_data_dict["opt"].optim_conf["number_of_deferrable_loads"]):
1175
- if "def_load_config" in input_data_dict["opt"].optim_conf.keys():
1176
- if (
1177
- "thermal_config"
1178
- in input_data_dict["opt"].optim_conf["def_load_config"][k]
1179
- ):
1180
- input_data_dict["rh"].post_data(
1181
- opt_res_latest["predicted_temp_heater{}".format(k)],
1182
- idx_closest,
1183
- custom_predicted_temperature_id[k]["entity_id"],
1184
- custom_predicted_temperature_id[k]["unit_of_measurement"],
1185
- custom_predicted_temperature_id[k]["friendly_name"],
1186
- type_var="temperature",
1187
- publish_prefix=publish_prefix,
1188
- save_entities=entity_save,
1189
- dont_post=dont_post,
1190
- )
1191
- cols_published = cols_published + ["predicted_temp_heater{}".format(k)]
1192
- # Publish battery power
1193
- if input_data_dict["opt"].optim_conf["set_use_battery"]:
1194
- if "P_batt" not in opt_res_latest.columns:
1195
- logger.error(
1196
- "P_batt was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution.",
1197
- )
1198
- else:
1199
- custom_batt_forecast_id = params["passed_data"]["custom_batt_forecast_id"]
1200
- input_data_dict["rh"].post_data(
1201
- opt_res_latest["P_batt"],
1202
- idx_closest,
1203
- custom_batt_forecast_id["entity_id"],
1204
- custom_batt_forecast_id["unit_of_measurement"],
1205
- custom_batt_forecast_id["friendly_name"],
1206
- type_var="batt",
1207
- publish_prefix=publish_prefix,
1208
- save_entities=entity_save,
1209
- dont_post=dont_post,
1210
- )
1211
- cols_published = cols_published + ["P_batt"]
1212
- custom_batt_soc_forecast_id = params["passed_data"][
1213
- "custom_batt_soc_forecast_id"
1214
- ]
1215
- input_data_dict["rh"].post_data(
1216
- opt_res_latest["SOC_opt"] * 100,
1217
- idx_closest,
1218
- custom_batt_soc_forecast_id["entity_id"],
1219
- custom_batt_soc_forecast_id["unit_of_measurement"],
1220
- custom_batt_soc_forecast_id["friendly_name"],
1221
- type_var="SOC",
1222
- publish_prefix=publish_prefix,
1223
- save_entities=entity_save,
1224
- dont_post=dont_post,
1703
+ cols.append("P_hybrid_inverter")
1704
+ return cols
1705
+
1706
+
1707
+ async def _publish_deferrable_loads(ctx: PublishContext, opt_res_latest: pd.DataFrame) -> list[str]:
1708
+ """Publish data for all deferrable loads."""
1709
+ cols = []
1710
+ custom_def = ctx.params["passed_data"]["custom_deferrable_forecast_id"]
1711
+ for k in range(ctx.opt.optim_conf["number_of_deferrable_loads"]):
1712
+ col_name = f"P_deferrable{k}"
1713
+ if col_name not in opt_res_latest.columns:
1714
+ ctx.logger.error(f"{col_name} was not found in results DataFrame.")
1715
+ continue
1716
+ await ctx.rh.post_data(
1717
+ opt_res_latest[col_name],
1718
+ ctx.idx,
1719
+ custom_def[k]["entity_id"],
1720
+ "power",
1721
+ custom_def[k]["unit_of_measurement"],
1722
+ custom_def[k]["friendly_name"],
1723
+ type_var="deferrable",
1724
+ **ctx.common_kwargs,
1725
+ )
1726
+ cols.append(col_name)
1727
+ return cols
1728
+
1729
+
1730
+ async def _publish_thermal_variable(
1731
+ rh, opt_res_latest, idx, k, custom_ids, col_prefix, type_var, unit_type, kwargs
1732
+ ) -> str | None:
1733
+ """Helper to publish a single thermal variable if valid."""
1734
+ if custom_ids and k < len(custom_ids):
1735
+ col_name = f"{col_prefix}{k}"
1736
+ if col_name in opt_res_latest.columns:
1737
+ entity_conf = custom_ids[k]
1738
+ await rh.post_data(
1739
+ opt_res_latest[col_name],
1740
+ idx,
1741
+ entity_conf["entity_id"],
1742
+ unit_type,
1743
+ entity_conf["unit_of_measurement"],
1744
+ entity_conf["friendly_name"],
1745
+ type_var=type_var,
1746
+ **kwargs,
1225
1747
  )
1226
- cols_published = cols_published + ["SOC_opt"]
1227
- # Publish grid power
1228
- custom_grid_forecast_id = params["passed_data"]["custom_grid_forecast_id"]
1229
- input_data_dict["rh"].post_data(
1748
+ return col_name
1749
+ return None
1750
+
1751
+
1752
+ async def _publish_thermal_loads(ctx: PublishContext, opt_res_latest: pd.DataFrame) -> list[str]:
1753
+ """Publish predicted temperature and heating demand for thermal loads."""
1754
+ cols = []
1755
+ if "custom_predicted_temperature_id" not in ctx.params["passed_data"]:
1756
+ return cols
1757
+ custom_temp = ctx.params["passed_data"]["custom_predicted_temperature_id"]
1758
+ custom_heat = ctx.params["passed_data"].get("custom_heating_demand_id")
1759
+ def_load_config = ctx.opt.optim_conf.get("def_load_config", [])
1760
+ if not isinstance(def_load_config, list):
1761
+ def_load_config = []
1762
+ for k in range(ctx.opt.optim_conf["number_of_deferrable_loads"]):
1763
+ if k >= len(def_load_config):
1764
+ continue
1765
+ load_cfg = def_load_config[k]
1766
+ if "thermal_config" not in load_cfg and "thermal_battery" not in load_cfg:
1767
+ continue
1768
+ col_t = await _publish_thermal_variable(
1769
+ ctx.rh,
1770
+ opt_res_latest,
1771
+ ctx.idx,
1772
+ k,
1773
+ custom_temp,
1774
+ "predicted_temp_heater",
1775
+ "temperature",
1776
+ "temperature",
1777
+ ctx.common_kwargs,
1778
+ )
1779
+ if col_t:
1780
+ cols.append(col_t)
1781
+ col_h = await _publish_thermal_variable(
1782
+ ctx.rh,
1783
+ opt_res_latest,
1784
+ ctx.idx,
1785
+ k,
1786
+ custom_heat,
1787
+ "heating_demand_heater",
1788
+ "energy",
1789
+ "energy",
1790
+ ctx.common_kwargs,
1791
+ )
1792
+ if col_h:
1793
+ cols.append(col_h)
1794
+ return cols
1795
+
1796
+
1797
+ async def _publish_battery_data(ctx: PublishContext, opt_res_latest: pd.DataFrame) -> list[str]:
1798
+ """Publish Battery Power and SOC."""
1799
+ cols = []
1800
+ if not ctx.opt.optim_conf["set_use_battery"]:
1801
+ return cols
1802
+ if "P_batt" not in opt_res_latest.columns:
1803
+ ctx.logger.error("P_batt was not found in results DataFrame.")
1804
+ return cols
1805
+ # Power
1806
+ custom_batt = ctx.params["passed_data"]["custom_batt_forecast_id"]
1807
+ await ctx.rh.post_data(
1808
+ opt_res_latest["P_batt"],
1809
+ ctx.idx,
1810
+ custom_batt["entity_id"],
1811
+ "power",
1812
+ custom_batt["unit_of_measurement"],
1813
+ custom_batt["friendly_name"],
1814
+ type_var="batt",
1815
+ **ctx.common_kwargs,
1816
+ )
1817
+ cols.append("P_batt")
1818
+ # SOC
1819
+ custom_soc = ctx.params["passed_data"]["custom_batt_soc_forecast_id"]
1820
+ await ctx.rh.post_data(
1821
+ opt_res_latest["SOC_opt"] * 100,
1822
+ ctx.idx,
1823
+ custom_soc["entity_id"],
1824
+ "battery",
1825
+ custom_soc["unit_of_measurement"],
1826
+ custom_soc["friendly_name"],
1827
+ type_var="SOC",
1828
+ **ctx.common_kwargs,
1829
+ )
1830
+ cols.append("SOC_opt")
1831
+ return cols
1832
+
1833
+
1834
+ async def _publish_grid_and_costs(ctx: PublishContext, opt_res_latest: pd.DataFrame) -> list[str]:
1835
+ """Publish Grid Power, Costs, and Optimization Status."""
1836
+ cols = []
1837
+ # Grid
1838
+ custom_grid = ctx.params["passed_data"]["custom_grid_forecast_id"]
1839
+ await ctx.rh.post_data(
1230
1840
  opt_res_latest["P_grid"],
1231
- idx_closest,
1232
- custom_grid_forecast_id["entity_id"],
1233
- custom_grid_forecast_id["unit_of_measurement"],
1234
- custom_grid_forecast_id["friendly_name"],
1841
+ ctx.idx,
1842
+ custom_grid["entity_id"],
1843
+ "power",
1844
+ custom_grid["unit_of_measurement"],
1845
+ custom_grid["friendly_name"],
1235
1846
  type_var="power",
1236
- publish_prefix=publish_prefix,
1237
- save_entities=entity_save,
1238
- dont_post=dont_post,
1847
+ **ctx.common_kwargs,
1239
1848
  )
1240
- cols_published = cols_published + ["P_grid"]
1241
- # Publish total value of cost function
1242
- custom_cost_fun_id = params["passed_data"]["custom_cost_fun_id"]
1849
+ cols.append("P_grid")
1850
+ # Cost Function
1851
+ custom_cost = ctx.params["passed_data"]["custom_cost_fun_id"]
1243
1852
  col_cost_fun = [i for i in opt_res_latest.columns if "cost_fun_" in i]
1244
- input_data_dict["rh"].post_data(
1853
+ await ctx.rh.post_data(
1245
1854
  opt_res_latest[col_cost_fun],
1246
- idx_closest,
1247
- custom_cost_fun_id["entity_id"],
1248
- custom_cost_fun_id["unit_of_measurement"],
1249
- custom_cost_fun_id["friendly_name"],
1855
+ ctx.idx,
1856
+ custom_cost["entity_id"],
1857
+ "monetary",
1858
+ custom_cost["unit_of_measurement"],
1859
+ custom_cost["friendly_name"],
1250
1860
  type_var="cost_fun",
1251
- publish_prefix=publish_prefix,
1252
- save_entities=entity_save,
1253
- dont_post=dont_post,
1861
+ **ctx.common_kwargs,
1254
1862
  )
1255
- # cols_published = cols_published + col_cost_fun
1256
- # Publish the optimization status
1257
- custom_cost_fun_id = params["passed_data"]["custom_optim_status_id"]
1863
+ # Optim Status
1864
+ custom_status = ctx.params["passed_data"]["custom_optim_status_id"]
1258
1865
  if "optim_status" not in opt_res_latest:
1259
1866
  opt_res_latest["optim_status"] = "Optimal"
1260
- logger.warning(
1261
- "no optim_status in opt_res_latest, run an optimization task first",
1262
- )
1263
- else:
1264
- input_data_dict["rh"].post_data(
1265
- opt_res_latest["optim_status"],
1266
- idx_closest,
1267
- custom_cost_fun_id["entity_id"],
1268
- custom_cost_fun_id["unit_of_measurement"],
1269
- custom_cost_fun_id["friendly_name"],
1270
- type_var="optim_status",
1271
- publish_prefix=publish_prefix,
1272
- save_entities=entity_save,
1273
- dont_post=dont_post,
1867
+ ctx.logger.warning("no optim_status in opt_res_latest")
1868
+ status_val = opt_res_latest["optim_status"]
1869
+ await ctx.rh.post_data(
1870
+ status_val,
1871
+ ctx.idx,
1872
+ custom_status["entity_id"],
1873
+ "",
1874
+ "",
1875
+ custom_status["friendly_name"],
1876
+ type_var="optim_status",
1877
+ **ctx.common_kwargs,
1878
+ )
1879
+ cols.append("optim_status")
1880
+ # Unit Costs
1881
+ for key, var_name in [
1882
+ ("custom_unit_load_cost_id", "unit_load_cost"),
1883
+ ("custom_unit_prod_price_id", "unit_prod_price"),
1884
+ ]:
1885
+ custom_id = ctx.params["passed_data"][key]
1886
+ await ctx.rh.post_data(
1887
+ opt_res_latest[var_name],
1888
+ ctx.idx,
1889
+ custom_id["entity_id"],
1890
+ "monetary",
1891
+ custom_id["unit_of_measurement"],
1892
+ custom_id["friendly_name"],
1893
+ type_var=var_name,
1894
+ **ctx.common_kwargs,
1274
1895
  )
1275
- cols_published = cols_published + ["optim_status"]
1276
- # Publish unit_load_cost
1277
- custom_unit_load_cost_id = params["passed_data"]["custom_unit_load_cost_id"]
1278
- input_data_dict["rh"].post_data(
1279
- opt_res_latest["unit_load_cost"],
1280
- idx_closest,
1281
- custom_unit_load_cost_id["entity_id"],
1282
- custom_unit_load_cost_id["unit_of_measurement"],
1283
- custom_unit_load_cost_id["friendly_name"],
1284
- type_var="unit_load_cost",
1285
- publish_prefix=publish_prefix,
1286
- save_entities=entity_save,
1287
- dont_post=dont_post,
1288
- )
1289
- cols_published = cols_published + ["unit_load_cost"]
1290
- # Publish unit_prod_price
1291
- custom_unit_prod_price_id = params["passed_data"]["custom_unit_prod_price_id"]
1292
- input_data_dict["rh"].post_data(
1293
- opt_res_latest["unit_prod_price"],
1294
- idx_closest,
1295
- custom_unit_prod_price_id["entity_id"],
1296
- custom_unit_prod_price_id["unit_of_measurement"],
1297
- custom_unit_prod_price_id["friendly_name"],
1298
- type_var="unit_prod_price",
1299
- publish_prefix=publish_prefix,
1300
- save_entities=entity_save,
1301
- dont_post=dont_post,
1302
- )
1303
- cols_published = cols_published + ["unit_prod_price"]
1304
- # Create a DF resuming what has been published
1896
+ cols.append(var_name)
1897
+ return cols
1898
+
1899
+
1900
+ async def publish_data(
1901
+ input_data_dict: dict,
1902
+ logger: logging.Logger,
1903
+ save_data_to_file: bool | None = False,
1904
+ opt_res_latest: pd.DataFrame | None = None,
1905
+ entity_save: bool | None = False,
1906
+ dont_post: bool | None = False,
1907
+ ) -> pd.DataFrame:
1908
+ """
1909
+ Publish the data obtained from the optimization results.
1910
+
1911
+ :param input_data_dict: A dictionary of data used by the action functions
1912
+ :type input_data_dict: dict
1913
+ :param logger: The passed logger object
1914
+ :type logger: logging object
1915
+ :param save_data_to_file: If True, read the optimization results from the day-ahead CSV file
1916
+ :type save_data_to_file: bool, optional
1917
+ :return: The output data of the optimization, read from a CSV file in the data folder
1918
+ :rtype: pd.DataFrame
1919
+ :param entity_save: Save built entities to data_path/entities
1920
+ :type entity_save: bool, optional
1921
+ :param dont_post: Do not post to Home Assistant. Works with entity_save
1922
+ :type dont_post: bool, optional
1923
+
1924
+ """
1925
+ logger.info("Publishing data to HASS instance")
1926
+ # Parse Parameters
1927
+ params = _get_params(input_data_dict)
1928
+ # Check for Entity Publishing (Prefix mode)
1929
+ publish_prefix = params["passed_data"].get("publish_prefix", "")
1930
+ if not save_data_to_file and publish_prefix != "" and not dont_post:
1931
+ opt_res = await _publish_from_saved_entities(input_data_dict, logger, params)
1932
+ if opt_res is not None:
1933
+ return opt_res
1934
+ # Load Optimization Results (if not passed)
1935
+ if opt_res_latest is None:
1936
+ opt_res_latest = _load_opt_res_latest(input_data_dict, logger, save_data_to_file)
1937
+ if opt_res_latest is None:
1938
+ return None
1939
+ # Determine Closest Index
1940
+ idx_closest = _get_closest_index(input_data_dict["retrieve_hass_conf"], opt_res_latest.index)
1941
+ # Create Context
1942
+ common_kwargs = {
1943
+ "publish_prefix": publish_prefix,
1944
+ "save_entities": entity_save,
1945
+ "dont_post": dont_post,
1946
+ }
1947
+ ctx = PublishContext(
1948
+ input_data_dict=input_data_dict,
1949
+ params=params,
1950
+ idx=idx_closest,
1951
+ common_kwargs=common_kwargs,
1952
+ logger=logger,
1953
+ )
1954
+ # Publish Data Components
1955
+ cols_published = []
1956
+ cols_published.extend(await _publish_standard_forecasts(ctx, opt_res_latest))
1957
+ cols_published.extend(await _publish_deferrable_loads(ctx, opt_res_latest))
1958
+ cols_published.extend(await _publish_thermal_loads(ctx, opt_res_latest))
1959
+ cols_published.extend(await _publish_battery_data(ctx, opt_res_latest))
1960
+ cols_published.extend(await _publish_grid_and_costs(ctx, opt_res_latest))
1961
+ # Return Summary DataFrame
1305
1962
  opt_res = opt_res_latest[cols_published].loc[[opt_res_latest.index[idx_closest]]]
1306
1963
  return opt_res
1307
1964
 
1308
1965
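Note: publish_data is now a thin orchestrator: it resolves params, optionally short-circuits through saved entity JSON files when publish_prefix is set, loads opt_res_latest, computes the closest index, and delegates to the _publish_* helpers while collecting the published column names. A hedged usage sketch, assuming input_data_dict was produced by set_input_data_dict:

# Sketch only: must run inside a coroutine
opt_res = await publish_data(
    input_data_dict,
    logger,
    entity_save=True,  # write built entities to data_path/entities
    dont_post=True,    # skip the actual POST to Home Assistant
)
if opt_res is not None:
    print(opt_res.columns.tolist())  # e.g. ['P_PV', 'P_Load', ..., 'unit_prod_price']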
 
1309
- def continual_publish(
1966
+ async def continual_publish(
1310
1967
  input_data_dict: dict, entity_path: pathlib.Path, logger: logging.Logger
1311
1968
  ):
1312
1969
  """
@@ -1318,58 +1975,62 @@ def continual_publish(
1318
1975
  :type entity_path: Path
1319
1976
  :param logger: The passed logger object
1320
1977
  :type logger: logging.Logger
1321
-
1322
1978
  """
1323
1979
  logger.info("Continual publish thread service started")
1324
1980
  freq = input_data_dict["retrieve_hass_conf"].get(
1325
1981
  "optimization_time_step", pd.to_timedelta(1, "minutes")
1326
1982
  )
1327
- entity_path_contents = []
1328
1983
  while True:
1329
1984
  # Sleep for x seconds (using current time as a reference for time left)
1330
- time.sleep(
1331
- max(
1332
- 0,
1333
- freq.total_seconds()
1334
- - (
1335
- datetime.now(
1336
- input_data_dict["retrieve_hass_conf"]["time_zone"]
1337
- ).timestamp()
1338
- % 60
1339
- ),
1340
- )
1341
- )
1342
- # Loop through all saved entity files
1343
- if os.path.exists(entity_path) and len(os.listdir(entity_path)) > 0:
1344
- entity_path_contents = os.listdir(entity_path)
1345
- for entity in entity_path_contents:
1346
- if entity != "metadata.json":
1347
- # Call publish_json with entity file, build entity, and publish
1348
- publish_json(
1349
- entity,
1350
- input_data_dict,
1351
- entity_path,
1352
- logger,
1353
- "continual_publish",
1354
- )
1355
- # Retrieve entity metadata from file
1356
- if os.path.isfile(entity_path / "metadata.json"):
1357
- with open(entity_path / "metadata.json", "r") as file:
1358
- metadata = json.load(file)
1359
- # Check if freq should be shorter
1360
- if not metadata.get("lowest_time_step", None) == None:
1361
- freq = pd.to_timedelta(metadata["lowest_time_step"], "minutes")
1362
- pass
1363
- # This function should never return
1985
+ time_zone = input_data_dict["retrieve_hass_conf"]["time_zone"]
1986
+ timestamp_diff = freq.total_seconds() - (datetime.now(time_zone).timestamp() % 60)
1987
+ sleep_seconds = max(0.0, min(timestamp_diff, 60.0))
1988
+ await asyncio.sleep(sleep_seconds)
1989
+ # Delegate processing to helper function to reduce complexity
1990
+ freq = await _publish_and_update_freq(input_data_dict, entity_path, logger, freq)
1364
1991
  return False
1365
1992
 
1366
1993
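Note: the rewritten loop aligns each wake-up with the top of the minute by subtracting the seconds already elapsed in the current minute from the time step and clamping to [0, 60]; timestamp() % 60 is timezone-independent, so UTC is used below purely for the sketch:

from datetime import UTC, datetime

freq_seconds = 60.0  # optimization_time_step expressed in seconds
elapsed = datetime.now(UTC).timestamp() % 60
sleep_seconds = max(0.0, min(freq_seconds - elapsed, 60.0))
# e.g. at hh:mm:12 this yields ~48s, so the next publish lands near the minute mark
print(round(sleep_seconds, 1))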
 
1367
- def publish_json(
1994
+ async def _publish_and_update_freq(input_data_dict, entity_path, logger, current_freq):
1995
+ """
1996
+ Helper to process entity publishing and frequency updates.
1997
+ Returns the (potentially updated) frequency.
1998
+ """
1999
+ # Guard clause: if path doesn't exist, do nothing and return current freq
2000
+ if not os.path.exists(entity_path):
2001
+ return current_freq
2002
+ entity_path_contents = os.listdir(entity_path)
2003
+ # Guard clause: if directory is empty, do nothing
2004
+ if not entity_path_contents:
2005
+ return current_freq
2006
+ # Loop through all saved entity files
2007
+ for entity in entity_path_contents:
2008
+ if entity != default_metadata_json:
2009
+ await publish_json(
2010
+ entity,
2011
+ input_data_dict,
2012
+ entity_path,
2013
+ logger,
2014
+ "continual_publish",
2015
+ )
2016
+ # Retrieve entity metadata from file
2017
+ metadata_file = entity_path / default_metadata_json
2018
+ if os.path.isfile(metadata_file):
2019
+ async with aiofiles.open(metadata_file) as file:
2020
+ content = await file.read()
2021
+ metadata = orjson.loads(content)
2022
+ # Check if freq should be shorter
2023
+ if metadata.get("lowest_time_step") is not None:
2024
+ return pd.to_timedelta(metadata["lowest_time_step"], "minutes")
2025
+ return current_freq
2026
+
2027
+
2028
+ async def publish_json(
1368
2029
  entity: dict,
1369
2030
  input_data_dict: dict,
1370
2031
  entity_path: pathlib.Path,
1371
2032
  logger: logging.Logger,
1372
- reference: Optional[str] = "",
2033
+ reference: str | None = "",
1373
2034
  ):
1374
2035
  """
1375
2036
  Extract saved entity data from .json files (in data_path/entities), build the entity, and post the results via post_data
@@ -1387,16 +2048,17 @@ def publish_json(
1387
2048
 
1388
2049
  """
1389
2050
  # Retrieve entity metadata from file
1390
- if os.path.isfile(entity_path / "metadata.json"):
1391
- with open(entity_path / "metadata.json", "r") as file:
1392
- metadata = json.load(file)
2051
+ if os.path.isfile(entity_path / default_metadata_json):
2052
+ async with aiofiles.open(entity_path / default_metadata_json) as file:
2053
+ content = await file.read()
2054
+ metadata = orjson.loads(content)
1393
2055
  else:
1394
2056
  logger.error("unable to located metadata.json in:" + entity_path)
1395
2057
  return False
1396
2058
  # Round current timecode (now)
1397
- now_precise = datetime.now(
1398
- input_data_dict["retrieve_hass_conf"]["time_zone"]
1399
- ).replace(second=0, microsecond=0)
2059
+ now_precise = datetime.now(input_data_dict["retrieve_hass_conf"]["time_zone"]).replace(
2060
+ second=0, microsecond=0
2061
+ )
1400
2062
  # Retrieve entity data from file
1401
2063
  entity_data = pd.read_json(entity_path / entity, orient="index")
1402
2064
  # Remove ".json" from string for entity_id
@@ -1426,10 +2088,11 @@ def publish_json(
1426
2088
  else:
1427
2089
  logger_levels = "INFO"
1428
2090
  # post/save entity
1429
- input_data_dict["rh"].post_data(
2091
+ await input_data_dict["rh"].post_data(
1430
2092
  data_df=entity_data[metadata[entity_id]["name"]],
1431
2093
  idx=idx_closest,
1432
2094
  entity_id=entity_id,
2095
+ device_class=metadata[entity_id].get("device_class"),
1433
2096
  unit_of_measurement=metadata[entity_id]["unit_of_measurement"],
1434
2097
  friendly_name=metadata[entity_id]["friendly_name"],
1435
2098
  type_var=metadata[entity_id].get("type_var", ""),
@@ -1439,7 +2102,7 @@ def publish_json(
1439
2102
  return entity_data[metadata[entity_id]["name"]]
1440
2103
 
1441
2104
 
1442
- def main():
2105
+ async def main():
1443
2106
  r"""Define the main command line entry function.
1444
2107
 
1445
2108
  This function may take several arguments as inputs. You can type `emhass --help` to see the list of options:
@@ -1477,9 +2140,7 @@ def main():
1477
2140
  default=None,
1478
2141
  help="String of configuration parameters passed",
1479
2142
  )
1480
- parser.add_argument(
1481
- "--data", type=str, help="Define path to the Data files (.csv & .pkl)"
1482
- )
2143
+ parser.add_argument("--data", type=str, help="Define path to the Data files (.csv & .pkl)")
1483
2144
  parser.add_argument("--root", type=str, help="Define path emhass root")
1484
2145
  parser.add_argument(
1485
2146
  "--costfun",
@@ -1489,8 +2150,8 @@ def main():
1489
2150
  )
1490
2151
  parser.add_argument(
1491
2152
  "--log2file",
1492
- type=strtobool,
1493
- default="False",
2153
+ type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
2154
+ default=False,
1494
2155
  help="Define if we should log to a file or not",
1495
2156
  )
1496
2157
  parser.add_argument(
@@ -1506,7 +2167,10 @@ def main():
1506
2167
  help="Pass runtime optimization parameters as dictionnary",
1507
2168
  )
1508
2169
  parser.add_argument(
1509
- "--debug", type=strtobool, default="False", help="Use True for testing purposes"
2170
+ "--debug",
2171
+ type=lambda s: str(s).lower() in ("true", "1", "yes"),  # type=bool treats any non-empty string as True
2172
+ default=False,
2173
+ help="Use True for testing purposes",
1510
2174
  )
1511
2175
  args = parser.parse_args()
1512
2176
 
@@ -1514,9 +2178,7 @@ def main():
1514
2178
  if args.config is not None:
1515
2179
  config_path = pathlib.Path(args.config)
1516
2180
  else:
1517
- config_path = pathlib.Path(
1518
- str(utils.get_root(__file__, num_parent=3) / "config.json")
1519
- )
2181
+ config_path = pathlib.Path(str(utils.get_root(__file__, num_parent=3) / "config.json"))
1520
2182
  if args.data is not None:
1521
2183
  data_path = pathlib.Path(args.data)
1522
2184
  else:
@@ -1540,18 +2202,14 @@ def main():
1540
2202
  emhass_conf["associations_path"] = associations_path
1541
2203
  emhass_conf["defaults_path"] = defaults_path
1542
2204
  # create logger
1543
- logger, ch = utils.get_logger(
1544
- __name__, emhass_conf, save_to_file=bool(args.log2file)
1545
- )
2205
+ logger, ch = utils.get_logger(__name__, emhass_conf, save_to_file=bool(args.log2file))
1546
2206
 
1547
2207
  # Check paths
1548
2208
  logger.debug("config path: " + str(config_path))
1549
2209
  logger.debug("data path: " + str(data_path))
1550
2210
  logger.debug("root path: " + str(root_path))
1551
2211
  if not associations_path.exists():
1552
- logger.error(
1553
- "Could not find associations.csv file in: " + str(associations_path)
1554
- )
2212
+ logger.error("Could not find associations.csv file in: " + str(associations_path))
1555
2213
  logger.error("Try setting config file path with --associations")
1556
2214
  return False
1557
2215
  if not config_path.exists():
@@ -1586,49 +2244,53 @@ def main():
1586
2244
  config = {}
1587
2245
  # Check if passed config file is yaml of json, build config accordingly
1588
2246
  if config_path.exists():
1589
- config_file_ending = re.findall("(?<=\.).*$", str(config_path))
1590
- if len(config_file_ending) > 0:
1591
- match config_file_ending[0]:
2247
+ # Safe: Use pathlib's suffix instead of regex to avoid ReDoS
2248
+ file_extension = config_path.suffix.lstrip(".").lower()
2249
+
2250
+ if file_extension:
2251
+ match file_extension:
1592
2252
  case "json":
1593
- config = utils.build_config(
2253
+ config = await utils.build_config(
1594
2254
  emhass_conf, logger, defaults_path, config_path
1595
2255
  )
1596
- case "yaml":
1597
- config = utils.build_config(
2256
+ case "yaml" | "yml":
2257
+ config = await utils.build_config(
1598
2258
  emhass_conf, logger, defaults_path, config_path=config_path
1599
2259
  )
1600
- case "yml":
1601
- config = utils.build_config(
1602
- emhass_conf, logger, defaults_path, config_path=config_path
2260
+ case _:
2261
+ logger.warning(
2262
+ f"Unsupported config file format: .{file_extension}, building parameters with only defaults"
1603
2263
  )
1604
- # If unable to find config file, use only defaults_config.json
2264
+ config = await utils.build_config(emhass_conf, logger, defaults_path)
2265
+ else:
2266
+ logger.warning("Config file has no extension, building parameters with only defaults")
2267
+ config = await utils.build_config(emhass_conf, logger, defaults_path)
1605
2268
  else:
1606
- logger.warning(
1607
- "Unable to obtain config.json file, building parameters with only defaults"
1608
- )
1609
- config = utils.build_config(emhass_conf, logger, defaults_path)
2269
+ # If unable to find config file, use only defaults_config.json
2270
+ logger.warning("Unable to obtain config.json file, building parameters with only defaults")
2271
+ config = await utils.build_config(emhass_conf, logger, defaults_path)
1610
2272
  if type(config) is bool and not config:
1611
2273
  raise Exception("Failed to find default config")
1612
2274
 
1613
2275
  # Obtain secrets from secrets_emhass.yaml?
1614
2276
  params_secrets = {}
1615
- emhass_conf, built_secrets = utils.build_secrets(
2277
+ emhass_conf, built_secrets = await utils.build_secrets(
1616
2278
  emhass_conf, logger, secrets_path=secrets_path
1617
2279
  )
1618
2280
  params_secrets.update(built_secrets)
1619
2281
 
1620
2282
  # Build params
1621
- params = utils.build_params(emhass_conf, params_secrets, config, logger)
2283
+ params = await utils.build_params(emhass_conf, params_secrets, config, logger)
1622
2284
  if type(params) is bool:
1623
2285
  raise Exception("A error has occurred while building parameters")
1624
2286
  # Add any passed params from args to params
1625
2287
  if args.params:
1626
- params.update(json.loads(args.params))
2288
+ params.update(orjson.loads(args.params))
1627
2289
 
1628
- input_data_dict = set_input_data_dict(
2290
+ input_data_dict = await set_input_data_dict(
1629
2291
  emhass_conf,
1630
2292
  args.costfun,
1631
- json.dumps(params),
2293
+ orjson.dumps(params).decode("utf-8"),
1632
2294
  args.runtimeparams,
1633
2295
  args.action,
1634
2296
  logger,
@@ -1639,52 +2301,53 @@ def main():
1639
2301
 
1640
2302
  # Perform selected action
1641
2303
  if args.action == "perfect-optim":
1642
- opt_res = perfect_forecast_optim(input_data_dict, logger, debug=args.debug)
2304
+ opt_res = await perfect_forecast_optim(input_data_dict, logger, debug=args.debug)
1643
2305
  elif args.action == "dayahead-optim":
1644
- opt_res = dayahead_forecast_optim(input_data_dict, logger, debug=args.debug)
2306
+ opt_res = await dayahead_forecast_optim(input_data_dict, logger, debug=args.debug)
1645
2307
  elif args.action == "naive-mpc-optim":
1646
- opt_res = naive_mpc_optim(input_data_dict, logger, debug=args.debug)
2308
+ opt_res = await naive_mpc_optim(input_data_dict, logger, debug=args.debug)
1647
2309
  elif args.action == "forecast-model-fit":
1648
- df_fit_pred, df_fit_pred_backtest, mlf = forecast_model_fit(
2310
+ df_fit_pred, df_fit_pred_backtest, mlf = await forecast_model_fit(
1649
2311
  input_data_dict, logger, debug=args.debug
1650
2312
  )
1651
2313
  opt_res = None
1652
2314
  elif args.action == "forecast-model-predict":
1653
2315
  if args.debug:
1654
- _, _, mlf = forecast_model_fit(input_data_dict, logger, debug=args.debug)
2316
+ _, _, mlf = await forecast_model_fit(input_data_dict, logger, debug=args.debug)
1655
2317
  else:
1656
2318
  mlf = None
1657
- df_pred = forecast_model_predict(
1658
- input_data_dict, logger, debug=args.debug, mlf=mlf
1659
- )
2319
+ df_pred = await forecast_model_predict(input_data_dict, logger, debug=args.debug, mlf=mlf)
1660
2320
  opt_res = None
1661
2321
  elif args.action == "forecast-model-tune":
1662
2322
  if args.debug:
1663
- _, _, mlf = forecast_model_fit(input_data_dict, logger, debug=args.debug)
2323
+ _, _, mlf = await forecast_model_fit(input_data_dict, logger, debug=args.debug)
1664
2324
  else:
1665
2325
  mlf = None
1666
- df_pred_optim, mlf = forecast_model_tune(
2326
+ df_pred_optim, mlf = await forecast_model_tune(
1667
2327
  input_data_dict, logger, debug=args.debug, mlf=mlf
1668
2328
  )
1669
2329
  opt_res = None
1670
2330
  elif args.action == "regressor-model-fit":
1671
- mlr = regressor_model_fit(input_data_dict, logger, debug=args.debug)
2331
+ mlr = await regressor_model_fit(input_data_dict, logger, debug=args.debug)
1672
2332
  opt_res = None
1673
2333
  elif args.action == "regressor-model-predict":
1674
2334
  if args.debug:
1675
- mlr = regressor_model_fit(input_data_dict, logger, debug=args.debug)
2335
+ mlr = await regressor_model_fit(input_data_dict, logger, debug=args.debug)
1676
2336
  else:
1677
2337
  mlr = None
1678
- prediction = regressor_model_predict(
2338
+ prediction = await regressor_model_predict(
1679
2339
  input_data_dict, logger, debug=args.debug, mlr=mlr
1680
2340
  )
1681
2341
  opt_res = None
2342
+ elif args.action == "export-influxdb-to-csv":
2343
+ success = await export_influxdb_to_csv(input_data_dict, logger)
2344
+ opt_res = None
1682
2345
  elif args.action == "publish-data":
1683
- opt_res = publish_data(input_data_dict, logger)
2346
+ opt_res = await publish_data(input_data_dict, logger)
1684
2347
  else:
1685
2348
  logger.error("The passed action argument is not valid")
1686
2349
  logger.error(
1687
- "Try setting --action: perfect-optim, dayahead-optim, naive-mpc-optim, forecast-model-fit, forecast-model-predict, forecast-model-tune or publish-data"
2350
+ "Try setting --action: perfect-optim, dayahead-optim, naive-mpc-optim, forecast-model-fit, forecast-model-predict, forecast-model-tune, export-influxdb-to-csv or publish-data"
1688
2351
  )
1689
2352
  opt_res = None
1690
2353
  logger.info(opt_res)
@@ -1706,11 +2369,18 @@ def main():
1706
2369
  return mlr
1707
2370
  elif args.action == "regressor-model-predict":
1708
2371
  return prediction
2372
+ elif args.action == "export-influxdb-to-csv":
2373
+ return success
1709
2374
  elif args.action == "forecast-model-tune":
1710
2375
  return df_pred_optim, mlf
1711
2376
  else:
1712
2377
  return opt_res
1713
2378
 
1714
2379
 
2380
+ def main_sync():
2381
+ """Sync wrapper for async main function - used as CLI entry point."""
2382
+ asyncio.run(main())
2383
+
2384
+
1715
2385
  if __name__ == "__main__":
1716
- main()
2386
+ main_sync()