emhass 0.10.6__py3-none-any.whl → 0.15.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emhass/retrieve_hass.py CHANGED
@@ -1,40 +1,57 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import json
1
+ import asyncio
5
2
  import copy
3
+ import logging
6
4
  import os
7
5
  import pathlib
8
- import datetime
9
- import logging
10
- from typing import Optional
6
+ import time
7
+ from datetime import datetime, timezone
8
+ from typing import Any
9
+
10
+ import aiofiles
11
+ import aiohttp
11
12
  import numpy as np
13
+ import orjson
12
14
  import pandas as pd
13
- from requests import get, post
14
15
 
16
+ from emhass.connection_manager import get_websocket_client
15
17
  from emhass.utils import set_df_index_freq
16
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+ header_accept = "application/json"
22
+ header_auth = "Bearer"
23
+ hass_url = "http://supervisor/core/api"
24
+ sensor_prefix = "sensor."
25
+
17
26
 
18
27
  class RetrieveHass:
19
28
  r"""
20
29
  Retrieve data from Home Assistant using the restful API.
21
-
30
+
22
31
  This class allows the user to retrieve data from a Home Assistant instance \
23
32
  using the provided restful API (https://developers.home-assistant.io/docs/api/rest/)
24
-
33
+
25
34
  The methods of this class are:
26
-
35
+
27
36
  - get_data: to retrieve the actual data from hass
28
-
37
+
29
38
  - prepare_data: to apply some data treatment in preparation for the optimization task
30
-
39
+
31
40
  - post_data: Post passed data to hass
32
-
41
+
33
42
  """
34
43
 
35
- def __init__(self, hass_url: str, long_lived_token: str, freq: pd.Timedelta,
36
- time_zone: datetime.timezone, params: str, emhass_conf: dict, logger: logging.Logger,
37
- get_data_from_file: Optional[bool] = False) -> None:
44
+ def __init__(
45
+ self,
46
+ hass_url: str,
47
+ long_lived_token: str,
48
+ freq: pd.Timedelta,
49
+ time_zone: timezone,
50
+ params: str,
51
+ emhass_conf: dict,
52
+ logger: logging.Logger,
53
+ get_data_from_file: bool | None = False,
54
+ ) -> None:
38
55
  """
39
56
  Define constructor for RetrieveHass class.
40
57
 
@@ -62,17 +79,135 @@ class RetrieveHass:
62
79
  self.long_lived_token = long_lived_token
63
80
  self.freq = freq
64
81
  self.time_zone = time_zone
65
- self.params = params
82
+ if (params is None) or (params == "null"):
83
+ self.params = {}
84
+ elif type(params) is dict:
85
+ self.params = params
86
+ else:
87
+ self.params = orjson.loads(params)
66
88
  self.emhass_conf = emhass_conf
67
89
  self.logger = logger
68
90
  self.get_data_from_file = get_data_from_file
91
+ self.var_list = []
92
+ self.use_websocket = self.params.get("retrieve_hass_conf", {}).get("use_websocket", False)
93
+ if self.use_websocket:
94
+ self._client = None
95
+ else:
96
+ self.logger.debug("Websocket integration disabled, using Home Assistant API")
97
+ # Initialize InfluxDB configuration
98
+ self.use_influxdb = self.params.get("retrieve_hass_conf", {}).get("use_influxdb", False)
99
+ if self.use_influxdb:
100
+ influx_conf = self.params.get("retrieve_hass_conf", {})
101
+ self.influxdb_host = influx_conf.get("influxdb_host", "localhost")
102
+ self.influxdb_port = influx_conf.get("influxdb_port", 8086)
103
+ self.influxdb_username = influx_conf.get("influxdb_username", "")
104
+ self.influxdb_password = influx_conf.get("influxdb_password", "")
105
+ self.influxdb_database = influx_conf.get("influxdb_database", "homeassistant")
106
+ self.influxdb_measurement = influx_conf.get("influxdb_measurement", "W")
107
+ self.influxdb_retention_policy = influx_conf.get("influxdb_retention_policy", "autogen")
108
+ self.influxdb_use_ssl = influx_conf.get("influxdb_use_ssl", False)
109
+ self.influxdb_verify_ssl = influx_conf.get("influxdb_verify_ssl", False)
110
+ self.logger.info(
111
+ f"InfluxDB integration enabled: {self.influxdb_host}:{self.influxdb_port}/{self.influxdb_database}"
112
+ )
113
+ else:
114
+ self.logger.debug("InfluxDB integration disabled, using Home Assistant API")
115
+
116
+ async def get_ha_config(self):
117
+ """
118
+ Extract some configuration data from HA.
119
+
120
+ :rtype: bool
121
+ """
122
+ # Initialize empty config immediately for safety
123
+ self.ha_config = {}
124
+
125
+ # Check if variables are None, empty strings, or explicitly set to "empty"
126
+ if (
127
+ not self.hass_url
128
+ or self.hass_url == "empty"
129
+ or not self.long_lived_token
130
+ or self.long_lived_token == "empty"
131
+ ):
132
+ self.logger.info(
133
+ "No Home Assistant URL or Long Lived Token found. Using only local configuration file."
134
+ )
135
+ return True
136
+
137
+ # Use WebSocket if configured
138
+ if self.use_websocket:
139
+ return await self.get_ha_config_websocket()
140
+
141
+ self.logger.info("get HA config from rest api.")
142
+
143
+ # Set up headers
144
+ headers = {
145
+ "Authorization": header_auth + " " + self.long_lived_token,
146
+ "content-type": header_accept,
147
+ }
148
+
149
+ # Construct the URL (incorporating the PR's helpful checks)
150
+ # The Supervisor API sometimes uses a different path structure
151
+ if self.hass_url == hass_url:
152
+ url = self.hass_url + "/config"
153
+ else:
154
+ # Helpful check for users who forget the trailing slash
155
+ if not self.hass_url.endswith("/"):
156
+ self.logger.warning(
157
+ "The defined HA URL is missing a trailing slash </>. Appending it, but please fix your configuration."
158
+ )
159
+ self.hass_url = self.hass_url + "/"
160
+ url = self.hass_url + "api/config"
161
+
162
+ # Attempt the connection
163
+ try:
164
+ async with aiohttp.ClientSession() as session:
165
+ async with session.get(url, headers=headers) as response:
166
+ # Check for HTTP errors (404, 401, 500) before trying to parse JSON
167
+ response.raise_for_status()
168
+ data = await response.read()
169
+ self.ha_config = orjson.loads(data)
170
+ return True
171
+
172
+ except Exception as e:
173
+ # Granular Error Logging
174
+ # We log the specific error 'e' so the user knows if it's a Timeout, Connection Refused, or 401 Auth error
175
+ self.logger.error(f"Unable to obtain configuration from Home Assistant at: {url}")
176
+ self.logger.error(f"Error details: {e}")
177
+
178
+ # Helpful hint for Add-on users without confusing Docker users
179
+ if "supervisor" in self.hass_url:
180
+ self.logger.error(
181
+ "If using the add-on, try setting url and token to 'empty' to force local config."
182
+ )
69
183
 
70
- def get_data(self, days_list: pd.date_range, var_list: list,
71
- minimal_response: Optional[bool] = False, significant_changes_only: Optional[bool] = False,
72
- test_url: Optional[str] = "empty") -> None:
184
+ return False
185
+
186
+ async def get_ha_config_websocket(self) -> dict[str, Any]:
187
+ """Get Home Assistant configuration."""
188
+ try:
189
+ self._client = await get_websocket_client(
190
+ self.hass_url, self.long_lived_token, self.logger
191
+ )
192
+ self.ha_config = await self._client.get_config()
193
+ return self.ha_config
194
+ except Exception as e:
195
+ self.logger.error(
196
+ f"EMHASS was unable to obtain configuration data from Home Assistant through websocket: {e}"
197
+ )
198
+ raise
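# --- Illustrative sketch (editor's addition, not part of the package code in this diff) ---
# A minimal example of driving the new async configuration fetch shown above, assuming
# the constructor signature from this file; the URL, token and params values below are
# placeholders, not EMHASS defaults.
import asyncio
import logging
from datetime import timezone

import pandas as pd

from emhass.retrieve_hass import RetrieveHass


async def _example_get_ha_config() -> dict:
    rh = RetrieveHass(
        hass_url="http://localhost:8123/",         # placeholder Core API URL
        long_lived_token="YOUR_LONG_LIVED_TOKEN",  # placeholder token
        freq=pd.Timedelta("30min"),
        time_zone=timezone.utc,
        params={},                                 # empty params -> REST path, no websocket
        emhass_conf={},                            # not used by get_ha_config
        logger=logging.getLogger("emhass"),
    )
    ok = await rh.get_ha_config()
    return rh.ha_config if ok else {}

# asyncio.run(_example_get_ha_config())  # uncomment to run against a live instance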
199
+
200
+ async def get_data(
201
+ self,
202
+ days_list: pd.date_range,
203
+ var_list: list,
204
+ minimal_response: bool | None = False,
205
+ significant_changes_only: bool | None = False,
206
+ test_url: str | None = "empty",
207
+ ) -> None:
73
208
  r"""
74
209
  Retrieve the actual data from hass.
75
-
210
+
76
211
  :param days_list: A list of days to retrieve. The ISO format should be used \
77
212
  and the timezone is UTC. The frequency of the data_range should be freq='D'
78
213
  :type days_list: pandas.date_range
@@ -88,129 +223,638 @@ class RetrieveHass:
88
223
  :type significant_changes_only: bool, optional
89
224
  :return: The DataFrame populated with the retrieved data from hass
90
225
  :rtype: pandas.DataFrame
91
-
92
- .. warning:: The minimal_response and significant_changes_only options \
93
- are experimental
94
226
  """
227
+ # Use InfluxDB if configured (Prioritize over WebSocket/REST for history)
228
+ if self.use_influxdb:
229
+ return self.get_data_influxdb(days_list, var_list)
230
+
231
+ # Use WebSockets if configured, otherwise use Home Assistant REST API
232
+ if self.use_websocket:
233
+ success = await self.get_data_websocket(days_list, var_list)
234
+ if not success:
235
+ self.logger.warning("WebSocket data retrieval failed, falling back to REST API")
236
+ # Fall back to REST API if websocket fails
237
+ return await self._get_data_rest_api(
238
+ days_list,
239
+ var_list,
240
+ minimal_response,
241
+ significant_changes_only,
242
+ test_url,
243
+ )
244
+ return success
245
+
246
+ self.logger.info("Using REST API for data retrieval")
247
+ return await self._get_data_rest_api(
248
+ days_list, var_list, minimal_response, significant_changes_only, test_url
249
+ )
250
+
251
+ def _build_history_url(
252
+ self,
253
+ day: pd.Timestamp,
254
+ var: str,
255
+ test_url: str,
256
+ minimal_response: bool,
257
+ significant_changes_only: bool,
258
+ ) -> str:
259
+ """Helper to construct the Home Assistant History URL."""
260
+ if test_url != "empty":
261
+ return test_url
262
+ # Check if using supervisor API or Core API
263
+ if self.hass_url == hass_url:
264
+ base_url = f"{self.hass_url}/history/period/{day.isoformat()}"
265
+ else:
266
+ # Ensure trailing slash for Core API
267
+ if self.hass_url[-1] != "/":
268
+ self.logger.warning(
269
+ "Missing slash </> at the end of the defined URL, appending a slash but please fix your URL"
270
+ )
271
+ self.hass_url = self.hass_url + "/"
272
+ base_url = f"{self.hass_url}api/history/period/{day.isoformat()}"
273
+ url = f"{base_url}?filter_entity_id={var}"
274
+ if minimal_response:
275
+ url += "?minimal_response"
276
+ if significant_changes_only:
277
+ url += "?significant_changes_only"
278
+ return url
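# Editor's note (illustration only): for a Core API base such as "http://localhost:8123/",
# day=2024-01-01T00:00:00+00:00 and var="sensor.home_load", the helper above yields
#   http://localhost:8123/api/history/period/2024-01-01T00:00:00+00:00?filter_entity_id=sensor.home_load
# while under the supervisor base it becomes
#   http://supervisor/core/api/history/period/<day>?filter_entity_id=<var>
# The optional flags are appended exactly as in the previous implementation (each
# prefixed with "?"), mirroring the removed code further down in this diff.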
279
+
280
+ async def _fetch_history_data(
281
+ self,
282
+ session: aiohttp.ClientSession,
283
+ url: str,
284
+ headers: dict,
285
+ var: str,
286
+ day: pd.Timestamp,
287
+ is_first_day: bool,
288
+ ) -> list | bool:
289
+ """Helper to execute the HTTP request and return the raw JSON list."""
290
+ try:
291
+ async with session.get(url, headers=headers) as response:
292
+ response.raise_for_status()
293
+ data = await response.read()
294
+ data_list = orjson.loads(data)
295
+ except Exception:
296
+ self.logger.error("Unable to access Home Assistant instance, check URL")
297
+ self.logger.error("If using addon, try setting url and token to 'empty'")
298
+ return False
299
+ if response.status == 401:
300
+ self.logger.error("Unable to access Home Assistant instance, TOKEN/KEY")
301
+ return False
302
+ if response.status > 299:
303
+ self.logger.error(f"Home assistant request GET error: {response.status} for var {var}")
304
+ return False
305
+ try:
306
+ return data_list[0]
307
+ except IndexError:
308
+ if is_first_day:
309
+ self.logger.error(
310
+ f"The retrieved JSON is empty, A sensor: {var} may have 0 days of history, "
311
+ "passed sensor may not be correct, or days to retrieve is set too high."
312
+ )
313
+ else:
314
+ self.logger.error(
315
+ f"The retrieved JSON is empty for day: {day}, days_to_retrieve may be larger "
316
+ f"than the recorded history of sensor: {var}"
317
+ )
318
+ return False
319
+
320
+ def _process_history_dataframe(
321
+ self, data: list, var: str, day: pd.Timestamp, is_first_day: bool, is_last_day: bool
322
+ ) -> pd.DataFrame | bool:
323
+ """Helper to convert raw data to a resampled DataFrame."""
324
+ df_raw = pd.DataFrame.from_dict(data)
325
+ # Check for empty DataFrame
326
+ if len(df_raw) == 0:
327
+ if is_first_day:
328
+ self.logger.error(
329
+ f"The retrieved Dataframe is empty, A sensor: {var} may have 0 days of history."
330
+ )
331
+ else:
332
+ self.logger.error(
333
+ f"Retrieved empty Dataframe for day: {day}, check recorder settings."
334
+ )
335
+ return False
336
+ # Check for data sufficiency (frequency consistency)
337
+ expected_count = (60 / (self.freq.seconds / 60)) * 24
338
+ if len(df_raw) < expected_count and not is_last_day:
339
+ self.logger.debug(
340
+ f"sensor: {var} retrieved Dataframe count: {len(df_raw)}, on day: {day}. "
341
+ f"This is less than freq value passed: {self.freq}"
342
+ )
343
+ # Process and Resample
344
+ df_tp = (
345
+ df_raw.copy()[["state"]]
346
+ .replace(["unknown", "unavailable", ""], np.nan)
347
+ .astype(float)
348
+ .rename(columns={"state": var})
349
+ )
350
+ df_tp.set_index(pd.to_datetime(df_raw["last_changed"], format="ISO8601"), inplace=True)
351
+ df_tp = df_tp.resample(self.freq).mean()
352
+ return df_tp
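# Editor's note (illustration only): with the default 30 min optimization step,
# expected_count above is (60 / 30) * 24 = 48 samples per day; fewer recorded state
# changes than that only triggers a debug message, not a failure. "unknown",
# "unavailable" and empty states are mapped to NaN before the mean-resample, so
# gaps survive into df_tp and are handled later by prepare_data.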
353
+
354
+ async def _get_data_rest_api(
355
+ self,
356
+ days_list: pd.date_range,
357
+ var_list: list,
358
+ minimal_response: bool | None = False,
359
+ significant_changes_only: bool | None = False,
360
+ test_url: str | None = "empty",
361
+ ) -> None:
362
+ """Internal method to handle REST API data retrieval."""
95
363
  self.logger.info("Retrieve hass get data method initiated...")
364
+ headers = {
365
+ "Authorization": header_auth + " " + self.long_lived_token,
366
+ "content-type": header_accept,
367
+ }
368
+ var_list = [var for var in var_list if var != ""]
96
369
  self.df_final = pd.DataFrame()
97
- x = 0 # iterate based on days
98
- # Looping on each day from days list
99
- for day in days_list:
100
- for i, var in enumerate(var_list):
101
- if test_url == "empty":
102
- if (
103
- self.hass_url == "http://supervisor/core/api"
104
- ): # If we are using the supervisor API
105
- url = (
106
- self.hass_url
107
- + "/history/period/"
108
- + day.isoformat()
109
- + "?filter_entity_id="
110
- + var
111
- )
112
- else: # Otherwise the Home Assistant Core API it is
113
- url = (
114
- self.hass_url
115
- + "api/history/period/"
116
- + day.isoformat()
117
- + "?filter_entity_id="
118
- + var
119
- )
120
- if minimal_response: # A support for minimal response
121
- url = url + "?minimal_response"
122
- if (
123
- significant_changes_only
124
- ): # And for signicant changes only (check the HASS restful API for more info)
125
- url = url + "?significant_changes_only"
126
- else:
127
- url = test_url
128
- headers = {
129
- "Authorization": "Bearer " + self.long_lived_token,
130
- "content-type": "application/json",
131
- }
132
- try:
133
- response = get(url, headers=headers)
134
- except Exception:
135
- self.logger.error(
136
- "Unable to access Home Assistance instance, check URL"
370
+
371
+ async with aiohttp.ClientSession() as session:
372
+ for day_idx, day in enumerate(days_list):
373
+ df_day = pd.DataFrame()
374
+ for i, var in enumerate(var_list):
375
+ # Build URL
376
+ url = self._build_history_url(
377
+ day, var, test_url, minimal_response, significant_changes_only
137
378
  )
138
- self.logger.error(
139
- "If using addon, try setting url and token to 'empty'"
379
+ # Fetch Data
380
+ data = await self._fetch_history_data(
381
+ session, url, headers, var, day, is_first_day=(day_idx == 0)
140
382
  )
141
- return False
142
- else:
143
- if response.status_code == 401:
144
- self.logger.error(
145
- "Unable to access Home Assistance instance, TOKEN/KEY"
146
- )
147
- self.logger.error(
148
- "If using addon, try setting url and token to 'empty'"
149
- )
383
+ if data is False:
150
384
  return False
151
- if response.status_code > 299:
152
- return f"Request Get Error: {response.status_code}"
153
- """import bz2 # Uncomment to save a serialized data for tests
154
- import _pickle as cPickle
155
- with bz2.BZ2File("data/test_response_get_data_get_method.pbz2", "w") as f:
156
- cPickle.dump(response, f)"""
157
- try: # Sometimes when there are connection problems we need to catch empty retrieved json
158
- data = response.json()[0]
159
- except IndexError:
160
- if x == 0:
161
- self.logger.error("The retrieved JSON is empty, A sensor:" + var + " may have 0 days of history, passed sensor may not be correct, or days to retrieve is set too heigh")
162
- else:
163
- self.logger.error("The retrieved JSON is empty for day:"+ str(day) +", days_to_retrieve may be larger than the recorded history of sensor:" + var + " (check your recorder settings)")
164
- return False
165
- df_raw = pd.DataFrame.from_dict(data)
166
- # self.logger.info(str(df_raw))
167
- if len(df_raw) == 0:
168
- if x == 0:
169
- self.logger.error(
170
- "The retrieved Dataframe is empty, A sensor:"
171
- + var
172
- + " may have 0 days of history or passed sensor may not be correct"
173
- )
174
- else:
175
- self.logger.error("Retrieved empty Dataframe for day:"+ str(day) +", days_to_retrieve may be larger than the recorded history of sensor:" + var + " (check your recorder settings)")
176
- return False
177
- # self.logger.info(self.freq.seconds)
178
- if len(df_raw) < ((60 / (self.freq.seconds / 60)) * 24) and x != len(days_list) -1: #check if there is enough Dataframes for passed frequency per day (not inc current day)
179
- self.logger.debug("sensor:" + var + " retrieved Dataframe count: " + str(len(df_raw)) + ", on day: " + str(day) + ". This is less than freq value passed: " + str(self.freq))
180
- if i == 0: # Defining the DataFrame container
181
- from_date = pd.to_datetime(df_raw['last_changed'], format="ISO8601").min()
182
- to_date = pd.to_datetime(df_raw['last_changed'], format="ISO8601").max()
183
- ts = pd.to_datetime(pd.date_range(start=from_date, end=to_date, freq=self.freq),
184
- format='%Y-%d-%m %H:%M').round(self.freq, ambiguous='infer', nonexistent='shift_forward')
185
- df_day = pd.DataFrame(index = ts)
186
- # Caution with undefined string data: unknown, unavailable, etc.
187
- df_tp = (
188
- df_raw.copy()[["state"]]
189
- .replace(["unknown", "unavailable", ""], np.nan)
190
- .astype(float)
191
- .rename(columns={"state": var})
192
- )
193
- # Setting index, resampling and concatenation
194
- df_tp.set_index(
195
- pd.to_datetime(df_raw["last_changed"], format="ISO8601"),
196
- inplace=True,
197
- )
198
- df_tp = df_tp.resample(self.freq).mean()
199
- df_day = pd.concat([df_day, df_tp], axis=1)
200
- self.df_final = pd.concat([self.df_final, df_day], axis=0)
201
- x += 1
385
+ # Process Data
386
+ df_resampled = self._process_history_dataframe(
387
+ data,
388
+ var,
389
+ day,
390
+ is_first_day=(day_idx == 0),
391
+ is_last_day=(day_idx == len(days_list) - 1),
392
+ )
393
+ if df_resampled is False:
394
+ return False
395
+ # Merge into daily DataFrame
396
+ # If it's the first variable, we initialize the day's index based on it
397
+ if i == 0:
398
+ df_day = pd.DataFrame(index=df_resampled.index)
399
+ # Ensure the daily index is regularized to the frequency
400
+ # Note: The original logic created a manual range here, but using the
401
+ # resampled index from the first variable is safer and cleaner if
402
+ # _process_history_dataframe handles resampling correctly.
403
+ df_day = pd.concat([df_day, df_resampled], axis=1)
404
+ self.df_final = pd.concat([self.df_final, df_day], axis=0)
405
+
406
+ # Final Cleanup
202
407
  self.df_final = set_df_index_freq(self.df_final)
203
408
  if self.df_final.index.freq != self.freq:
204
- self.logger.error("The inferred freq:" + str(self.df_final.index.freq) + " from data is not equal to the defined freq in passed:" + str(self.freq))
409
+ self.logger.error(
410
+ f"The inferred freq: {self.df_final.index.freq} from data is not equal "
411
+ f"to the defined freq in passed: {self.freq}"
412
+ )
413
+ return False
414
+ self.var_list = var_list
415
+ return True
416
+
417
+ async def get_data_websocket(
418
+ self,
419
+ days_list: pd.date_range,
420
+ var_list: list[str],
421
+ ) -> bool:
422
+ r"""
423
+ Retrieve the actual data from hass.
424
+
425
+ :param days_list: A list of days to retrieve. The ISO format should be used \
426
+ and the timezone is UTC. The frequency of the data_range should be freq='D'
427
+ :type days_list: pandas.date_range
428
+ :param var_list: The list of variables to retrieve from hass. These should \
429
+ be the exact name of the sensor in Home Assistant. \
430
+ For example: ['sensor.home_load', 'sensor.home_pv']
431
+ :type var_list: list
432
+ :return: The DataFrame populated with the retrieved data from hass
433
+ :rtype: pandas.DataFrame
434
+ """
435
+ try:
436
+ self._client = await asyncio.wait_for(
437
+ get_websocket_client(self.hass_url, self.long_lived_token, self.logger),
438
+ timeout=20.0,
439
+ )
440
+ except TimeoutError:
441
+ self.logger.error("WebSocket connection timed out")
442
+ return False
443
+ except Exception as e:
444
+ self.logger.error(f"Websocket connection error: {e}")
445
+ return False
446
+
447
+ self.var_list = var_list
448
+
449
+ # Calculate time range
450
+ start_time = min(days_list).to_pydatetime()
451
+ end_time = datetime.now()
452
+
453
+ # Try to get statistics data (which contains the actual historical data)
454
+ try:
455
+ # Get statistics data with 5-minute period for good resolution
456
+ t0 = time.time()
457
+ stats_data = await asyncio.wait_for(
458
+ self._client.get_statistics(
459
+ start_time=start_time,
460
+ end_time=end_time,
461
+ statistic_ids=var_list,
462
+ period="5minute",
463
+ ),
464
+ timeout=30.0,
465
+ )
466
+
467
+ # Convert statistics data to DataFrame
468
+ self.df_final = self._convert_statistics_to_dataframe(stats_data, var_list)
469
+
470
+ t1 = time.time()
471
+ self.logger.info(f"Statistics data retrieval took {t1 - t0:.2f} seconds")
472
+
473
+ return not self.df_final.empty
474
+
475
+ except Exception as e:
476
+ self.logger.error(f"Failed to get data via WebSocket: {e}")
477
+ return False
478
+
479
+ def get_data_influxdb(
480
+ self,
481
+ days_list: pd.date_range,
482
+ var_list: list,
483
+ ) -> bool:
484
+ """
485
+ Retrieve data from InfluxDB database.
486
+
487
+ This method provides an alternative data source to Home Assistant API,
488
+ enabling longer historical data retention for better machine learning model training.
489
+
490
+ :param days_list: A list of days to retrieve data for
491
+ :type days_list: pandas.date_range
492
+ :param var_list: List of sensor entity IDs to retrieve
493
+ :type var_list: list
494
+ :return: Success status of data retrieval
495
+ :rtype: bool
496
+ """
497
+ self.logger.info("Retrieve InfluxDB get data method initiated...")
498
+
499
+ # Check for empty inputs
500
+ if not days_list.size:
501
+ self.logger.error("Empty days_list provided")
502
+ return False
503
+
504
+ client = self._init_influx_client()
505
+ if not client:
506
+ return False
507
+
508
+ # Convert all timestamps to UTC for comparison, then make naive for InfluxDB
509
+ # This ensures we compare actual instants in time, not wall clock times
510
+ # InfluxDB queries expect naive UTC timestamps (with 'Z' suffix)
511
+
512
+ # Normalize start_time to pd.Timestamp in UTC
513
+ start_time = pd.Timestamp(days_list[0])
514
+ if start_time.tz is not None:
515
+ start_time = start_time.tz_convert("UTC").tz_localize(None)
516
+ # If naive, assume it's already UTC
517
+
518
+ # Get current time in UTC
519
+ now = pd.Timestamp.now(tz="UTC").tz_localize(None)
520
+
521
+ # Normalize requested_end to pd.Timestamp in UTC
522
+ requested_end = pd.Timestamp(days_list[-1]) + pd.Timedelta(days=1)
523
+ if requested_end.tz is not None:
524
+ requested_end = requested_end.tz_convert("UTC").tz_localize(None)
525
+ # If naive, assume it's already UTC
526
+
527
+ # Cap end_time at current time to avoid querying future data
528
+ # This prevents FILL(previous) from creating fake future datapoints
529
+ end_time = min(now, requested_end)
530
+ total_days = (end_time - start_time).days
531
+
532
+ self.logger.info(f"Retrieving {len(var_list)} sensors over {total_days} days from InfluxDB")
533
+ self.logger.debug(f"Time range: {start_time} to {end_time}")
534
+ if end_time < requested_end:
535
+ self.logger.debug(f"End time capped at current time (requested: {requested_end})")
536
+
537
+ # Collect sensor dataframes
538
+ sensor_dfs = []
539
+ global_min_time = None
540
+ global_max_time = None
541
+
542
+ for sensor in filter(None, var_list):
543
+ df_sensor = self._fetch_sensor_data(client, sensor, start_time, end_time)
544
+ if df_sensor is not None:
545
+ sensor_dfs.append(df_sensor)
546
+ # Track global time range
547
+ sensor_min = df_sensor.index.min()
548
+ sensor_max = df_sensor.index.max()
549
+ global_min_time = min(global_min_time or sensor_min, sensor_min)
550
+ global_max_time = max(global_max_time or sensor_max, sensor_max)
551
+
552
+ client.close()
553
+
554
+ if not sensor_dfs:
555
+ self.logger.error("No data retrieved from InfluxDB")
556
+ return False
557
+
558
+ # Create complete time index covering all sensors
559
+ if global_min_time is not None and global_max_time is not None:
560
+ complete_index = pd.date_range(
561
+ start=global_min_time, end=global_max_time, freq=self.freq
562
+ )
563
+ self.df_final = pd.DataFrame(index=complete_index)
564
+
565
+ # Merge all sensor dataframes
566
+ for df_sensor in sensor_dfs:
567
+ self.df_final = pd.concat([self.df_final, df_sensor], axis=1)
568
+
569
+ # Set frequency and validate with error handling
570
+ try:
571
+ self.df_final = set_df_index_freq(self.df_final)
572
+ except Exception as e:
573
+ self.logger.error(f"Exception occurred while setting DataFrame index frequency: {e}")
205
574
  return False
575
+
576
+ if self.df_final.index.freq != self.freq:
577
+ self.logger.warning(
578
+ f"InfluxDB data frequency ({self.df_final.index.freq}) differs from expected ({self.freq})"
579
+ )
580
+
581
+ self.var_list = var_list
582
+ self.logger.info(f"InfluxDB data retrieval completed: {self.df_final.shape}")
206
583
  return True
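# Editor's illustration (assumed layout, not an official configuration reference):
# the InfluxDB path above is driven entirely by keys under retrieve_hass_conf, all of
# which are read in __init__ with the defaults shown there. A params dict enabling it
# might look like:
#
#   params = {
#       "retrieve_hass_conf": {
#           "use_influxdb": True,
#           "influxdb_host": "localhost",
#           "influxdb_port": 8086,
#           "influxdb_database": "homeassistant",
#           "influxdb_measurement": "W",
#           "influxdb_retention_policy": "autogen",
#           "influxdb_use_ssl": False,
#           "influxdb_verify_ssl": False,
#       }
#   }
#
# With use_influxdb set, get_data() above routes straight to get_data_influxdb() and
# never touches the Home Assistant history API.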
207
-
208
-
209
- def prepare_data(self, var_load: str, load_negative: Optional[bool] = False, set_zero_min: Optional[bool] = True,
210
- var_replace_zero: Optional[list] = None, var_interp: Optional[list] = None) -> None:
584
+
585
+ def _init_influx_client(self):
586
+ """Initialize InfluxDB client connection."""
587
+ try:
588
+ from influxdb import InfluxDBClient
589
+ except ImportError:
590
+ self.logger.error("InfluxDB client not installed. Install with: pip install influxdb")
591
+ return None
592
+
593
+ try:
594
+ client = InfluxDBClient(
595
+ host=self.influxdb_host,
596
+ port=self.influxdb_port,
597
+ username=self.influxdb_username or None,
598
+ password=self.influxdb_password or None,
599
+ database=self.influxdb_database,
600
+ ssl=self.influxdb_use_ssl,
601
+ verify_ssl=self.influxdb_verify_ssl,
602
+ )
603
+ # Test connection
604
+ client.ping()
605
+ self.logger.debug(
606
+ f"Successfully connected to InfluxDB at {self.influxdb_host}:{self.influxdb_port}"
607
+ )
608
+
609
+ # Initialize measurement cache
610
+ if not hasattr(self, "_measurement_cache"):
611
+ self._measurement_cache = {}
612
+
613
+ return client
614
+ except Exception as e:
615
+ self.logger.error(f"Failed to connect to InfluxDB: {e}")
616
+ return None
617
+
618
+ def _discover_entity_measurement(self, client, entity_id: str) -> str:
619
+ """Auto-discover which measurement contains the given entity."""
620
+ # Check cache first
621
+ if entity_id in self._measurement_cache:
622
+ return self._measurement_cache[entity_id]
623
+
624
+ try:
625
+ # Get all available measurements
626
+ measurements_query = "SHOW MEASUREMENTS"
627
+ measurements_result = client.query(measurements_query)
628
+ measurements = [m["name"] for m in measurements_result.get_points()]
629
+
630
+ # Priority order: check common sensor types first
631
+ priority_measurements = ["EUR/kWh", "€/kWh", "W", "EUR", "€", "%", "A", "V"]
632
+ all_measurements = priority_measurements + [
633
+ m for m in measurements if m not in priority_measurements
634
+ ]
635
+
636
+ self.logger.debug(
637
+ f"Searching for entity '{entity_id}' across {len(measurements)} measurements"
638
+ )
639
+
640
+ # Search for entity in each measurement
641
+ for measurement in all_measurements:
642
+ if measurement not in measurements:
643
+ continue # Skip if measurement doesn't exist
644
+
645
+ try:
646
+ # Use SHOW TAG VALUES to get all entity_ids in this measurement
647
+ tag_query = f'SHOW TAG VALUES FROM "{measurement}" WITH KEY = "entity_id"'
648
+ self.logger.debug(
649
+ f"Checking measurement '{measurement}' with tag query: {tag_query}"
650
+ )
651
+ result = client.query(tag_query)
652
+ points = list(result.get_points())
653
+
654
+ # Check if our target entity_id is in the tag values
655
+ for point in points:
656
+ if point.get("value") == entity_id:
657
+ self.logger.debug(
658
+ f"Found entity '{entity_id}' in measurement '{measurement}'"
659
+ )
660
+ # Cache the result
661
+ self._measurement_cache[entity_id] = measurement
662
+ return measurement
663
+
664
+ except Exception as query_error:
665
+ self.logger.debug(
666
+ f"Tag query failed for measurement '{measurement}': {query_error}"
667
+ )
668
+ continue
669
+
670
+ except Exception as e:
671
+ self.logger.error(f"Error discovering measurement for entity {entity_id}: {e}")
672
+
673
+ # Fallback to default measurement if not found
674
+ self.logger.warning(
675
+ f"Entity '{entity_id}' not found in any measurement, using default: {self.influxdb_measurement}"
676
+ )
677
+ return self.influxdb_measurement
678
+
679
+ def _build_influx_query_for_measurement(
680
+ self, entity_id: str, measurement: str, start_time, end_time
681
+ ) -> str:
682
+ """Build InfluxQL query for specific measurement and entity."""
683
+ # Convert frequency to InfluxDB interval
684
+ freq_minutes = int(self.freq.total_seconds() / 60)
685
+ interval = f"{freq_minutes}m"
686
+
687
+ # Format times properly for InfluxDB
688
+ start_time_str = start_time.strftime("%Y-%m-%dT%H:%M:%SZ")
689
+ end_time_str = end_time.strftime("%Y-%m-%dT%H:%M:%SZ")
690
+
691
+ # Use FILL(previous) instead of FILL(linear) for compatibility with open-source InfluxDB
692
+ query = f"""
693
+ SELECT mean("value") AS "mean_value"
694
+ FROM "{self.influxdb_database}"."{self.influxdb_retention_policy}"."{measurement}"
695
+ WHERE time >= '{start_time_str}'
696
+ AND time < '{end_time_str}'
697
+ AND "entity_id"='{entity_id}'
698
+ GROUP BY time({interval}) FILL(previous)
699
+ """
700
+ return query
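# Editor's illustration: with a 30 min freq, measurement "W", entity_id "power_load"
# and the default database/retention policy, the template above renders (whitespace
# aside) as:
#   SELECT mean("value") AS "mean_value"
#   FROM "homeassistant"."autogen"."W"
#   WHERE time >= '2024-01-01T00:00:00Z' AND time < '2024-01-02T00:00:00Z'
#     AND "entity_id"='power_load'
#   GROUP BY time(30m) FILL(previous)
# (the dates shown are placeholders for the computed start/end times).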
701
+
702
+ def _build_influx_query(self, sensor: str, start_time, end_time) -> str:
703
+ """Build InfluxQL query for sensor data retrieval (legacy method)."""
704
+ # Convert sensor name: sensor.sec_pac_solar -> sec_pac_solar
705
+ entity_id = (
706
+ sensor.replace(sensor_prefix, "") if sensor.startswith(sensor_prefix) else sensor
707
+ )
708
+
709
+ # Use default measurement (for backward compatibility)
710
+ return self._build_influx_query_for_measurement(
711
+ entity_id, self.influxdb_measurement, start_time, end_time
712
+ )
713
+
714
+ def _fetch_sensor_data(self, client, sensor: str, start_time, end_time):
715
+ """Fetch and process data for a single sensor with auto-discovery."""
716
+ self.logger.debug(f"Retrieving sensor: {sensor}")
717
+
718
+ # Clean sensor name (remove sensor. prefix if present)
719
+ entity_id = (
720
+ sensor.replace(sensor_prefix, "") if sensor.startswith(sensor_prefix) else sensor
721
+ )
722
+
723
+ # Auto-discover which measurement contains this entity
724
+ measurement = self._discover_entity_measurement(client, entity_id)
725
+ if not measurement:
726
+ self.logger.warning(f"Entity '{entity_id}' not found in any InfluxDB measurement")
727
+ return None
728
+
729
+ try:
730
+ query = self._build_influx_query_for_measurement(
731
+ entity_id, measurement, start_time, end_time
732
+ )
733
+ self.logger.debug(f"InfluxDB query: {query}")
734
+
735
+ # Execute query
736
+ result = client.query(query)
737
+ points = list(result.get_points())
738
+
739
+ if not points:
740
+ self.logger.warning(
741
+ f"No data found for entity: {entity_id} in measurement: {measurement}"
742
+ )
743
+ return None
744
+
745
+ self.logger.info(f"Retrieved {len(points)} data points for {sensor}")
746
+
747
+ # Create DataFrame from points
748
+ df_sensor = pd.DataFrame(points)
749
+
750
+ # Convert time column and set as index with timezone awareness
751
+ df_sensor["time"] = pd.to_datetime(df_sensor["time"], utc=True)
752
+ df_sensor.set_index("time", inplace=True)
753
+
754
+ # Rename value column to original sensor name
755
+ if "mean_value" in df_sensor.columns:
756
+ df_sensor = df_sensor[["mean_value"]].rename(columns={"mean_value": sensor})
757
+ else:
758
+ self.logger.error(
759
+ f"Expected 'mean_value' column not found for {sensor} in measurement {measurement}"
760
+ )
761
+ return None
762
+
763
+ # Handle non-numeric data with NaN ratio warning
764
+ df_sensor[sensor] = pd.to_numeric(df_sensor[sensor], errors="coerce")
765
+
766
+ # Check proportion of NaNs and log warning if high
767
+ nan_count = df_sensor[sensor].isna().sum()
768
+ total_count = len(df_sensor[sensor])
769
+ if total_count > 0:
770
+ nan_ratio = nan_count / total_count
771
+ if nan_ratio > 0.2:
772
+ self.logger.warning(
773
+ f"Entity '{entity_id}' has {nan_count}/{total_count} ({nan_ratio:.1%}) non-numeric values coerced to NaN."
774
+ )
775
+
776
+ self.logger.debug(
777
+ f"Successfully retrieved {len(df_sensor)} data points for '{entity_id}' from measurement '{measurement}'"
778
+ )
779
+ return df_sensor
780
+
781
+ except Exception as e:
782
+ self.logger.error(
783
+ f"Failed to query entity {entity_id} from measurement {measurement}: {e}"
784
+ )
785
+ return None
786
+
787
+ def _validate_sensor_list(self, target_list: list, list_name: str) -> list:
788
+ """Helper to validate that config lists only contain known sensors."""
789
+ if not isinstance(target_list, list):
790
+ return []
791
+ valid_items = [item for item in target_list if item in self.var_list]
792
+ removed = set(target_list) - set(valid_items)
793
+ for item in removed:
794
+ self.logger.warning(
795
+ f"Sensor '{item}' in {list_name} not found in self.var_list and has been removed."
796
+ )
797
+ return valid_items
798
+
799
+ def _process_load_column_renaming(
800
+ self, var_load: str, load_negative: bool, skip_renaming: bool
801
+ ) -> bool:
802
+ """Helper to handle the sign flip and renaming of the main load column."""
803
+ if skip_renaming:
804
+ return True
805
+ try:
806
+ # Apply the correct sign to load power
807
+ if load_negative:
808
+ self.df_final[var_load + "_positive"] = -self.df_final[var_load]
809
+ else:
810
+ self.df_final[var_load + "_positive"] = self.df_final[var_load]
811
+ self.df_final.drop([var_load], inplace=True, axis=1)
812
+ # Update var_list to reflect the renamed column
813
+ self.var_list = [var.replace(var_load, var_load + "_positive") for var in self.var_list]
814
+ self.logger.debug(f"prepare_data var_list updated after rename: {self.var_list}")
815
+ return True
816
+ except KeyError as e:
817
+ self.logger.error(
818
+ f"Variable '{var_load}' was not found in DataFrame columns: {list(self.df_final.columns)}. "
819
+ f"This is typically because no data could be retrieved from Home Assistant or InfluxDB. Error: {e}"
820
+ )
821
+ return False
822
+ except ValueError:
823
+ self.logger.error(
824
+ "sensor.power_photovoltaics and sensor.power_load_no_var_loads should not be the same"
825
+ )
826
+ return False
827
+
828
+ def _map_variable_names(
829
+ self, target_list: list, var_load: str, skip_renaming: bool, param_name: str
830
+ ) -> list | None:
831
+ """Helper to map old variable names to new ones (if renaming occurred)."""
832
+ if not target_list:
833
+ self.logger.warning(f"Unable to find all the sensors in {param_name} parameter")
834
+ self.logger.warning(
835
+ f"Confirm sure all sensors in {param_name} are sensor_power_photovoltaics and/or sensor_power_load_no_var_loads"
836
+ )
837
+ return None
838
+ new_list = []
839
+ for string in target_list:
840
+ if not skip_renaming:
841
+ new_list.append(string.replace(var_load, var_load + "_positive"))
842
+ else:
843
+ new_list.append(string)
844
+ return new_list
845
+
846
+ def prepare_data(
847
+ self,
848
+ var_load: str,
849
+ load_negative: bool,
850
+ set_zero_min: bool,
851
+ var_replace_zero: list[str],
852
+ var_interp: list[str],
853
+ skip_renaming: bool = False,
854
+ ) -> bool:
211
855
  r"""
212
856
  Apply some data treatment in preparation for the optimization task.
213
-
857
+
214
858
  :param var_load: The name of the variable for the household load consumption.
215
859
  :type var_load: str
216
860
  :param load_negative: Set to True if the retrieved load variable is \
@@ -228,77 +872,67 @@ class RetrieveHass:
228
872
  :return: The DataFrame populated with the retrieved data from hass and \
229
873
  after the data treatment
230
874
  :rtype: pandas.DataFrame
231
-
875
+
232
876
  """
233
- try:
234
- if load_negative: # Apply the correct sign to load power
235
- self.df_final[var_load + "_positive"] = -self.df_final[var_load]
236
- else:
237
- self.df_final[var_load + "_positive"] = self.df_final[var_load]
238
- self.df_final.drop([var_load], inplace=True, axis=1)
239
- except KeyError:
240
- self.logger.error(
241
- "Variable "
242
- + var_load
243
- + " was not found. This is typically because no data could be retrieved from Home Assistant"
244
- )
245
- return False
246
- except ValueError:
247
- self.logger.error(
248
- "sensor.power_photovoltaics and sensor.power_load_no_var_loads should not be the same"
249
- )
877
+ self.logger.debug("prepare_data self.var_list=%s", self.var_list)
878
+ self.logger.debug("prepare_data var_load=%s", var_load)
879
+ # Validate Input Lists
880
+ var_replace_zero = self._validate_sensor_list(var_replace_zero, "var_replace_zero")
881
+ var_interp = self._validate_sensor_list(var_interp, "var_interp")
882
+ # Rename Load Columns (Handle sign change)
883
+ if not self._process_load_column_renaming(var_load, load_negative, skip_renaming):
250
884
  return False
251
- if set_zero_min: # Apply minimum values
885
+ # Apply Zero Saturation (Min value clipping)
886
+ if set_zero_min:
252
887
  self.df_final.clip(lower=0.0, inplace=True, axis=1)
253
888
  self.df_final.replace(to_replace=0.0, value=np.nan, inplace=True)
254
- new_var_replace_zero = []
255
- new_var_interp = []
256
- # Just changing the names of variables to contain the fact that they are considered positive
257
- if var_replace_zero is not None:
258
- for string in var_replace_zero:
259
- new_string = string.replace(var_load, var_load + "_positive")
260
- new_var_replace_zero.append(new_string)
261
- else:
262
- new_var_replace_zero = None
263
- if var_interp is not None:
264
- for string in var_interp:
265
- new_string = string.replace(var_load, var_load + "_positive")
266
- new_var_interp.append(new_string)
267
- else:
268
- new_var_interp = None
269
- # Treating NaN replacement: either by zeros or by linear interpolation
270
- if new_var_replace_zero is not None:
271
- self.df_final[new_var_replace_zero] = self.df_final[
272
- new_var_replace_zero
273
- ].fillna(0.0)
274
- if new_var_interp is not None:
889
+ # Map Variable Names (Update lists to match new column names)
890
+ new_var_replace_zero = self._map_variable_names(
891
+ var_replace_zero, var_load, skip_renaming, "sensor_replace_zero"
892
+ )
893
+ new_var_interp = self._map_variable_names(
894
+ var_interp, var_load, skip_renaming, "sensor_linear_interp"
895
+ )
896
+ # Apply Data Cleaning (FillNA / Interpolate)
897
+ if new_var_replace_zero:
898
+ self.df_final[new_var_replace_zero] = self.df_final[new_var_replace_zero].fillna(0.0)
899
+ if new_var_interp:
275
900
  self.df_final[new_var_interp] = self.df_final[new_var_interp].interpolate(
276
901
  method="linear", axis=0, limit=None
277
902
  )
278
903
  self.df_final[new_var_interp] = self.df_final[new_var_interp].fillna(0.0)
279
- # Setting the correct time zone on DF index
904
+ # Finalize Index (Timezone and Duplicates)
280
905
  if self.time_zone is not None:
281
906
  self.df_final.index = self.df_final.index.tz_convert(self.time_zone)
282
- # Drop datetimeindex duplicates on final DF
283
907
  self.df_final = self.df_final[~self.df_final.index.duplicated(keep="first")]
284
908
  return True
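# Editor's note (illustration only): unless skip_renaming is set, the load column is
# renamed with a "_positive" suffix, e.g. "sensor.power_load_no_var_loads" becomes
# "sensor.power_load_no_var_loads_positive", and the entries passed in
# var_replace_zero / var_interp are remapped accordingly before NaN treatment.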
285
909
 
286
910
  @staticmethod
287
- def get_attr_data_dict(data_df: pd.DataFrame, idx: int, entity_id: str, unit_of_measurement: str,
288
- friendly_name: str, list_name: str, state: float) -> dict:
911
+ def get_attr_data_dict(
912
+ data_df: pd.DataFrame,
913
+ idx: int,
914
+ entity_id: str,
915
+ device_class: str,
916
+ unit_of_measurement: str,
917
+ friendly_name: str,
918
+ list_name: str,
919
+ state: float,
920
+ decimals: int = 2,
921
+ ) -> dict:
289
922
  list_df = copy.deepcopy(data_df).loc[data_df.index[idx] :].reset_index()
290
923
  list_df.columns = ["timestamps", entity_id]
291
- ts_list = [str(i) for i in list_df["timestamps"].tolist()]
292
- vals_list = [str(np.round(i, 2)) for i in list_df[entity_id].tolist()]
924
+ ts_list = [i.isoformat() for i in list_df["timestamps"].tolist()]
925
+ vals_list = [str(np.round(i, decimals)) for i in list_df[entity_id].tolist()]
293
926
  forecast_list = []
294
927
  for i, ts in enumerate(ts_list):
295
928
  datum = {}
296
929
  datum["date"] = ts
297
- datum[entity_id.split("sensor.")[1]] = vals_list[i]
930
+ datum[entity_id.split(sensor_prefix)[1]] = vals_list[i]
298
931
  forecast_list.append(datum)
299
932
  data = {
300
- "state": "{:.2f}".format(state),
933
+ "state": f"{state:.{decimals}f}",
301
934
  "attributes": {
935
+ "device_class": device_class,
302
936
  "unit_of_measurement": unit_of_measurement,
303
937
  "friendly_name": friendly_name,
304
938
  list_name: forecast_list,
@@ -306,14 +940,27 @@ class RetrieveHass:
306
940
  }
307
941
  return data
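# Editor's note (illustration only): the dict returned above has the shape
#   {"state": "<state rounded to `decimals`>",
#    "attributes": {"device_class": ..., "unit_of_measurement": ..., "friendly_name": ...,
#                   <list_name>: [{"date": "<ISO timestamp>", "<entity suffix>": "<value>"}, ...]}}
# which post_data() below submits to the Home Assistant states API for the
# forecast-type variables.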
308
942
 
309
-
310
- def post_data(self, data_df: pd.DataFrame, idx: int, entity_id: str, unit_of_measurement: str,
311
- friendly_name: str, type_var: str, from_mlforecaster: Optional[bool] = False,
312
- publish_prefix: Optional[str] = "", save_entities: Optional[bool] = False,
313
- logger_levels: Optional[str] = "info", dont_post: Optional[bool] = False) -> None:
943
+ async def post_data(
944
+ self,
945
+ data_df: pd.DataFrame,
946
+ idx: int,
947
+ entity_id: str,
948
+ device_class: str,
949
+ unit_of_measurement: str,
950
+ friendly_name: str,
951
+ type_var: str,
952
+ publish_prefix: str | None = "",
953
+ save_entities: bool | None = False,
954
+ logger_levels: str | None = "info",
955
+ dont_post: bool | None = False,
956
+ ) -> None:
314
957
  r"""
315
- Post passed data to hass.
316
-
958
+ Post passed data to hass using REST API.
959
+
960
+ .. note:: This method ALWAYS uses the REST API for posting data to Home Assistant,
961
+ regardless of the use_websocket setting. WebSocket is only used for
962
+ data retrieval, not for publishing/posting data.
963
+
317
964
  :param data_df: The DataFrame containing the data that will be posted \
318
965
  to hass. This should be a one columns DF or a series.
319
966
  :type data_df: pd.DataFrame
@@ -322,6 +969,8 @@ class RetrieveHass:
322
969
  :type idx: int
323
970
  :param entity_id: The unique entity_id of the sensor in hass.
324
971
  :type entity_id: str
972
+ :param device_class: The HASS device class for the sensor.
973
+ :type device_class: str
325
974
  :param unit_of_measurement: The units of the sensor.
326
975
  :type unit_of_measurement: str
327
976
  :param friendly_name: The friendly name that will be used in the hass frontend.
@@ -332,28 +981,26 @@ class RetrieveHass:
332
981
  :type publish_prefix: str, optional
333
982
  :param save_entities: if entity data should be saved in data_path/entities
334
983
  :type save_entities: bool, optional
335
- :param logger_levels: set logger level, info or debug, to output
984
+ :param logger_levels: set logger level, info or debug, to output
336
985
  :type logger_levels: str, optional
337
986
  :param dont_post: dont post to HA
338
987
  :type dont_post: bool, optional
339
988
 
340
989
  """
341
990
  # Add a possible prefix to the entity ID
342
- entity_id = entity_id.replace("sensor.", "sensor." + publish_prefix)
991
+ entity_id = entity_id.replace(sensor_prefix, sensor_prefix + publish_prefix)
343
992
  # Set the URL
344
- if (
345
- self.hass_url == "http://supervisor/core/api"
346
- ): # If we are using the supervisor API
993
+ if self.hass_url == hass_url: # If we are using the supervisor API
347
994
  url = self.hass_url + "/states/" + entity_id
348
995
  else: # Otherwise the Home Assistant Core API it is
349
996
  url = self.hass_url + "api/states/" + entity_id
350
997
  headers = {
351
- "Authorization": "Bearer " + self.long_lived_token,
352
- "content-type": "application/json",
353
- }
998
+ "Authorization": header_auth + " " + self.long_lived_token,
999
+ "content-type": header_accept,
1000
+ }
354
1001
  # Preparing the data dict to be published
355
1002
  if type_var == "cost_fun":
356
- if isinstance(data_df.iloc[0],pd.Series): #if Series extract
1003
+ if isinstance(data_df.iloc[0], pd.Series): # if Series extract
357
1004
  data_df = data_df.iloc[:, 0]
358
1005
  state = np.round(data_df.sum(), 2)
359
1006
  elif type_var == "unit_load_cost" or type_var == "unit_prod_price":
@@ -361,38 +1008,114 @@ class RetrieveHass:
361
1008
  elif type_var == "optim_status":
362
1009
  state = data_df.loc[data_df.index[idx]]
363
1010
  elif type_var == "mlregressor":
364
- state = data_df[idx]
1011
+ state = float(data_df[idx])
365
1012
  else:
366
1013
  state = np.round(data_df.loc[data_df.index[idx]], 2)
367
1014
  if type_var == "power":
368
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
369
- friendly_name, "forecasts", state)
1015
+ data = RetrieveHass.get_attr_data_dict(
1016
+ data_df,
1017
+ idx,
1018
+ entity_id,
1019
+ device_class,
1020
+ unit_of_measurement,
1021
+ friendly_name,
1022
+ "forecasts",
1023
+ state,
1024
+ )
370
1025
  elif type_var == "deferrable":
371
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
372
- friendly_name, "deferrables_schedule", state)
1026
+ data = RetrieveHass.get_attr_data_dict(
1027
+ data_df,
1028
+ idx,
1029
+ entity_id,
1030
+ device_class,
1031
+ unit_of_measurement,
1032
+ friendly_name,
1033
+ "deferrables_schedule",
1034
+ state,
1035
+ )
373
1036
  elif type_var == "temperature":
374
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
375
- friendly_name, "predicted_temperatures", state)
1037
+ data = RetrieveHass.get_attr_data_dict(
1038
+ data_df,
1039
+ idx,
1040
+ entity_id,
1041
+ device_class,
1042
+ unit_of_measurement,
1043
+ friendly_name,
1044
+ "predicted_temperatures",
1045
+ state,
1046
+ )
376
1047
  elif type_var == "batt":
377
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
378
- friendly_name, "battery_scheduled_power", state)
1048
+ data = RetrieveHass.get_attr_data_dict(
1049
+ data_df,
1050
+ idx,
1051
+ entity_id,
1052
+ device_class,
1053
+ unit_of_measurement,
1054
+ friendly_name,
1055
+ "battery_scheduled_power",
1056
+ state,
1057
+ )
379
1058
  elif type_var == "SOC":
380
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
381
- friendly_name, "battery_scheduled_soc", state)
1059
+ data = RetrieveHass.get_attr_data_dict(
1060
+ data_df,
1061
+ idx,
1062
+ entity_id,
1063
+ device_class,
1064
+ unit_of_measurement,
1065
+ friendly_name,
1066
+ "battery_scheduled_soc",
1067
+ state,
1068
+ )
382
1069
  elif type_var == "unit_load_cost":
383
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
384
- friendly_name, "unit_load_cost_forecasts", state)
1070
+ data = RetrieveHass.get_attr_data_dict(
1071
+ data_df,
1072
+ idx,
1073
+ entity_id,
1074
+ device_class,
1075
+ unit_of_measurement,
1076
+ friendly_name,
1077
+ "unit_load_cost_forecasts",
1078
+ state,
1079
+ decimals=4,
1080
+ )
385
1081
  elif type_var == "unit_prod_price":
386
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
387
- friendly_name, "unit_prod_price_forecasts", state)
1082
+ data = RetrieveHass.get_attr_data_dict(
1083
+ data_df,
1084
+ idx,
1085
+ entity_id,
1086
+ device_class,
1087
+ unit_of_measurement,
1088
+ friendly_name,
1089
+ "unit_prod_price_forecasts",
1090
+ state,
1091
+ decimals=4,
1092
+ )
388
1093
  elif type_var == "mlforecaster":
389
- data = RetrieveHass.get_attr_data_dict(data_df, idx, entity_id, unit_of_measurement,
390
- friendly_name, "scheduled_forecast", state)
1094
+ data = RetrieveHass.get_attr_data_dict(
1095
+ data_df,
1096
+ idx,
1097
+ entity_id,
1098
+ device_class,
1099
+ unit_of_measurement,
1100
+ friendly_name,
1101
+ "scheduled_forecast",
1102
+ state,
1103
+ )
1104
+ elif type_var == "energy":
1105
+ data = RetrieveHass.get_attr_data_dict(
1106
+ data_df,
1107
+ idx,
1108
+ entity_id,
1109
+ device_class,
1110
+ unit_of_measurement,
1111
+ friendly_name,
1112
+ "heating_demand_forecast",
1113
+ state,
1114
+ )
391
1115
  elif type_var == "optim_status":
392
1116
  data = {
393
1117
  "state": state,
394
1118
  "attributes": {
395
- "unit_of_measurement": unit_of_measurement,
396
1119
  "friendly_name": friendly_name,
397
1120
  },
398
1121
  }
@@ -400,68 +1123,201 @@ class RetrieveHass:
400
1123
  data = {
401
1124
  "state": state,
402
1125
  "attributes": {
1126
+ "device_class": device_class,
403
1127
  "unit_of_measurement": unit_of_measurement,
404
1128
  "friendly_name": friendly_name,
405
1129
  },
406
1130
  }
407
1131
  else:
408
1132
  data = {
409
- "state": "{:.2f}".format(state),
1133
+ "state": f"{state:.2f}",
410
1134
  "attributes": {
1135
+ "device_class": device_class,
411
1136
  "unit_of_measurement": unit_of_measurement,
412
1137
  "friendly_name": friendly_name,
413
1138
  },
414
1139
  }
415
1140
  # Actually post the data
416
1141
  if self.get_data_from_file or dont_post:
417
- class response:
418
- pass
419
- response.status_code = 200
420
- response.ok = True
1142
+ # Create mock response for file mode or dont_post mode
1143
+ self.logger.debug(
1144
+ f"Skipping actual POST (get_data_from_file={self.get_data_from_file}, dont_post={dont_post})"
1145
+ )
1146
+ response_ok = True
1147
+ response_status_code = 200
421
1148
  else:
422
- response = post(url, headers=headers, data=json.dumps(data))
1149
+ # Always use REST API for posting data, regardless of use_websocket setting
1150
+ self.logger.debug(f"Posting data to URL: {url}")
1151
+ try:
1152
+ async with aiohttp.ClientSession() as session:
1153
+ async with session.post(
1154
+ url, headers=headers, data=orjson.dumps(data).decode("utf-8")
1155
+ ) as response:
1156
+ # Store response data since we need to access it after the context manager
1157
+ response_ok = response.ok
1158
+ response_status_code = response.status
1159
+ self.logger.debug(
1160
+ f"HTTP POST response: ok={response_ok}, status={response_status_code}"
1161
+ )
1162
+ except Exception as e:
1163
+ self.logger.error(f"Failed to post data to {entity_id}: {e}")
1164
+ response_ok = False
1165
+ response_status_code = 500
423
1166
 
424
1167
  # Treating the response status and posting them on the logger
425
- if response.ok:
426
-
427
- if logger_levels == "DEBUG":
1168
+ if response_ok:
1169
+ if logger_levels == "DEBUG" or dont_post:
428
1170
  self.logger.debug("Successfully posted to " + entity_id + " = " + str(state))
429
1171
  else:
430
1172
  self.logger.info("Successfully posted to " + entity_id + " = " + str(state))
431
1173
 
432
1174
  # If save entities is set, save entity data to /data_path/entities
433
- if (save_entities):
434
- entities_path = self.emhass_conf['data_path'] / "entities"
435
-
1175
+ if save_entities:
1176
+ entities_path = self.emhass_conf["data_path"] / "entities"
1177
+
436
1178
  # Clarify folder exists
437
1179
  pathlib.Path(entities_path).mkdir(parents=True, exist_ok=True)
438
-
1180
+
439
1181
  # Save entity data to json file
440
- result = data_df.to_json(index="timestamp", orient='index', date_unit='s', date_format='iso')
441
- parsed = json.loads(result)
442
- with open(entities_path / (entity_id + ".json"), "w") as file:
443
- json.dump(parsed, file, indent=4)
444
-
1182
+ result = data_df.to_json(
1183
+ index="timestamp", orient="index", date_unit="s", date_format="iso"
1184
+ )
1185
+ parsed = orjson.loads(result)
1186
+ async with aiofiles.open(entities_path / (entity_id + ".json"), "w") as file:
1187
+ await file.write(orjson.dumps(parsed, option=orjson.OPT_INDENT_2).decode())
1188
+
445
1189
  # Save the required metadata to json file
446
- if os.path.isfile(entities_path / "metadata.json"):
447
- with open(entities_path / "metadata.json", "r") as file:
448
- metadata = json.load(file)
1190
+ metadata_path = entities_path / "metadata.json"
1191
+ if os.path.isfile(metadata_path):
1192
+ async with aiofiles.open(metadata_path) as file:
1193
+ content = await file.read()
1194
+ metadata = orjson.loads(content)
449
1195
  else:
450
1196
  metadata = {}
451
- with open(entities_path / "metadata.json", "w") as file:
452
- # Save entity metadata, key = entity_id
453
- metadata[entity_id] = {'name': data_df.name, 'unit_of_measurement': unit_of_measurement,'friendly_name': friendly_name,'type_var': type_var, 'freq': int(self.freq.seconds / 60)}
454
-
1197
+
1198
+ async with aiofiles.open(metadata_path, "w") as file:
1199
+ # Save entity metadata, key = entity_id
1200
+ metadata[entity_id] = {
1201
+ "name": data_df.name,
1202
+ "device_class": device_class,
1203
+ "unit_of_measurement": unit_of_measurement,
1204
+ "friendly_name": friendly_name,
1205
+ "type_var": type_var,
1206
+ "optimization_time_step": int(self.freq.seconds / 60),
1207
+ }
1208
+
455
1209
  # Find lowest frequency to set for continual loop freq
456
- if metadata.get("lowest_freq",None) == None or metadata["lowest_freq"] > int(self.freq.seconds / 60):
457
- metadata["lowest_freq"] = int(self.freq.seconds / 60)
458
- json.dump(metadata,file, indent=4)
1210
+ if metadata.get("lowest_time_step") is None or metadata[
1211
+ "lowest_time_step"
1212
+ ] > int(self.freq.seconds / 60):
1213
+ metadata["lowest_time_step"] = int(self.freq.seconds / 60)
1214
+ await file.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2).decode())
1215
+
1216
+ self.logger.debug("Saved " + entity_id + " to json file")
459
1217
 
460
- self.logger.debug("Saved " + entity_id + " to json file")
461
-
462
1218
  else:
463
1219
  self.logger.warning(
464
- "The status code for received curl command response is: "
465
- + str(response.status_code)
1220
+ f"Failed to post data to {entity_id}. Status code: {response_status_code}"
466
1221
  )
467
- return response, data
1222
+
1223
+ # Create a response object to maintain compatibility
1224
+ class MockResponse:
1225
+ def __init__(self, ok, status_code):
1226
+ self.ok = ok
1227
+ self.status_code = status_code
1228
+
1229
+ mock_response = MockResponse(response_ok, response_status_code)
1230
+ self.logger.debug(f"Completed post_data for {entity_id}")
1231
+ return mock_response, data
1232
+
1233
+ def _convert_statistics_to_dataframe(
1234
+ self, stats_data: dict[str, Any], var_list: list[str]
1235
+ ) -> pd.DataFrame:
1236
+ """Convert WebSocket statistics data to DataFrame."""
1237
+ import pandas as pd
1238
+
1239
+ # Initialize empty DataFrame
1240
+ df_final = pd.DataFrame()
1241
+
1242
+ # The websocket manager already extracts the 'result' portion
1243
+ # so stats_data should be directly the entity data dictionary
1244
+
1245
+ for entity_id in var_list:
1246
+ if entity_id not in stats_data:
1247
+ self.logger.warning(f"No statistics data for {entity_id}")
1248
+ continue
1249
+
1250
+ entity_stats = stats_data[entity_id]
1251
+
1252
+ if not entity_stats:
1253
+ continue
1254
+
1255
+ # Convert statistics to DataFrame
1256
+ entity_data = []
1257
+ for _i, stat in enumerate(entity_stats):
1258
+ try:
1259
+ # Handle timestamp from start time (milliseconds or ISO string)
1260
+ if isinstance(stat["start"], int | float):
1261
+ # Convert from milliseconds to datetime with UTC timezone
1262
+ timestamp = pd.to_datetime(stat["start"], unit="ms", utc=True)
1263
+ else:
1264
+ # Assume ISO string
1265
+ timestamp = pd.to_datetime(stat["start"], utc=True)
1266
+
1267
+ # Use mean, max, min or sum depending on what's available
1268
+ value = None
1269
+ if "mean" in stat and stat["mean"] is not None:
1270
+ value = stat["mean"]
1271
+ elif "sum" in stat and stat["sum"] is not None:
1272
+ value = stat["sum"]
1273
+ elif "max" in stat and stat["max"] is not None:
1274
+ value = stat["max"]
1275
+ elif "min" in stat and stat["min"] is not None:
1276
+ value = stat["min"]
1277
+
1278
+ if value is not None:
1279
+ try:
1280
+ value = float(value)
1281
+ entity_data.append({"timestamp": timestamp, entity_id: value})
1282
+ except (ValueError, TypeError):
1283
+ self.logger.debug(f"Could not convert value to float: {value}")
1284
+
1285
+ except (KeyError, ValueError, TypeError) as e:
1286
+ self.logger.debug(f"Skipping invalid statistic for {entity_id}: {e}")
1287
+ continue
1288
+
1289
+ if entity_data:
1290
+ entity_df = pd.DataFrame(entity_data)
1291
+ entity_df.set_index("timestamp", inplace=True)
1292
+
1293
+ if df_final.empty:
1294
+ df_final = entity_df
1295
+ else:
1296
+ df_final = df_final.join(entity_df, how="outer")
1297
+
1298
+ # Process the final DataFrame
1299
+ if not df_final.empty:
1300
+ # Ensure timezone awareness - timestamps should already be UTC from conversion above
1301
+ if df_final.index.tz is None:
1302
+ # If somehow still naive, localize as UTC first then convert
1303
+ df_final.index = df_final.index.tz_localize("UTC").tz_convert(self.time_zone)
1304
+ else:
1305
+ # Convert from existing timezone to target timezone
1306
+ df_final.index = df_final.index.tz_convert(self.time_zone)
1307
+
1308
+ # Sort by index
1309
+ df_final = df_final.sort_index()
1310
+
1311
+ # Resample to frequency if needed
1312
+ try:
1313
+ df_final = df_final.resample(self.freq).mean()
1314
+ except Exception as e:
1315
+ self.logger.warning(f"Could not resample data to {self.freq}: {e}")
1316
+
1317
+ # Forward fill missing values
1318
+ df_final = df_final.ffill()
1319
+
1320
+ # Set frequency for the DataFrame index
1321
+ df_final = set_df_index_freq(df_final)
1322
+
1323
+ return df_final
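# --- Illustrative sketch (editor's addition, not part of the package code in this diff) ---
# A hedged end-to-end example of the retrieve -> prepare -> publish flow using the async
# API in this file. Sensor names, dates and configuration values are placeholders;
# dont_post=True keeps post_data from writing anything to Home Assistant.
import asyncio
import logging
import pathlib
from datetime import timezone

import pandas as pd

from emhass.retrieve_hass import RetrieveHass


async def _example_flow() -> None:
    rh = RetrieveHass(
        hass_url="http://localhost:8123/",
        long_lived_token="YOUR_LONG_LIVED_TOKEN",
        freq=pd.Timedelta("30min"),
        time_zone=timezone.utc,
        params={},
        emhass_conf={"data_path": pathlib.Path("/tmp")},
        logger=logging.getLogger("emhass"),
    )
    days = pd.date_range(
        end=pd.Timestamp.now(tz="UTC").normalize(), periods=2, freq="D"
    )
    var_load = "sensor.power_load_no_var_loads"  # placeholder entity
    var_pv = "sensor.power_photovoltaics"        # placeholder entity
    if not await rh.get_data(days, [var_load, var_pv]):
        return
    if not rh.prepare_data(
        var_load=var_load,
        load_negative=False,
        set_zero_min=True,
        var_replace_zero=[var_pv],
        var_interp=[var_load],
    ):
        return
    await rh.post_data(
        rh.df_final[var_pv],
        idx=0,
        entity_id="sensor.p_pv_forecast",        # placeholder output entity
        device_class="power",
        unit_of_measurement="W",
        friendly_name="PV Power Forecast",
        type_var="power",
        dont_post=True,                          # log only, no write to HA
    )


# asyncio.run(_example_flow())  # uncomment to run against a live instance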