emhass 0.11.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emhass/retrieve_hass.py CHANGED
@@ -1,36 +1,44 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
+ import asyncio
  import copy
- import datetime
- import json
  import logging
  import os
  import pathlib
- from typing import Optional
+ import time
+ from datetime import datetime, timezone
+ from typing import Any

+ import aiofiles
+ import aiohttp
  import numpy as np
+ import orjson
  import pandas as pd
- from requests import get, post

+ from emhass.connection_manager import get_websocket_client
  from emhass.utils import set_df_index_freq

+ logger = logging.getLogger(__name__)
+
+ header_accept = "application/json"
+ header_auth = "Bearer"
+ hass_url = "http://supervisor/core/api"
+ sensor_prefix = "sensor."
+

  class RetrieveHass:
  r"""
  Retrieve data from Home Assistant using the restful API.
-
+
  This class allows the user to retrieve data from a Home Assistant instance \
  using the provided restful API (https://developers.home-assistant.io/docs/api/rest/)
-
+
  This class methods are:
-
+
  - get_data: to retrieve the actual data from hass
-
+
  - prepare_data: to apply some data treatment in preparation for the optimization task
-
+
  - post_data: Post passed data to hass
-
+
  """

  def __init__(
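
The hunk above swaps the synchronous `requests`/`json` stack for `aiohttp`/`orjson`. As a minimal sketch of the resulting I/O pattern (the URL and token here are placeholders, not values from this diff):

```python
import asyncio

import aiohttp
import orjson


async def fetch_config(base_url: str, token: str) -> dict:
    # Same pattern used by get_ha_config below: async GET, then orjson-decode raw bytes.
    headers = {"Authorization": f"Bearer {token}", "content-type": "application/json"}
    async with aiohttp.ClientSession() as session:
        async with session.get(base_url + "api/config", headers=headers) as response:
            response.raise_for_status()
            return orjson.loads(await response.read())


# Example (placeholder credentials):
# config = asyncio.run(fetch_config("http://homeassistant.local:8123/", "LONG_LIVED_TOKEN"))
```
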
@@ -38,11 +46,11 @@ class RetrieveHass:
  hass_url: str,
  long_lived_token: str,
  freq: pd.Timedelta,
- time_zone: datetime.timezone,
+ time_zone: timezone,
  params: str,
  emhass_conf: dict,
  logger: logging.Logger,
- get_data_from_file: Optional[bool] = False,
+ get_data_from_file: bool | None = False,
  ) -> None:
  """
  Define constructor for RetrieveHass class.
@@ -71,40 +79,135 @@ class RetrieveHass:
  self.long_lived_token = long_lived_token
  self.freq = freq
  self.time_zone = time_zone
- if (params == None) or (params == "null"):
+ if (params is None) or (params == "null"):
  self.params = {}
  elif type(params) is dict:
  self.params = params
  else:
- self.params = json.loads(params)
+ self.params = orjson.loads(params)
  self.emhass_conf = emhass_conf
  self.logger = logger
  self.get_data_from_file = get_data_from_file
+ self.var_list = []
+ self.use_websocket = self.params.get("retrieve_hass_conf", {}).get("use_websocket", False)
+ if self.use_websocket:
+ self._client = None
+ else:
+ self.logger.debug("Websocket integration disabled, using Home Assistant API")
+ # Initialize InfluxDB configuration
+ self.use_influxdb = self.params.get("retrieve_hass_conf", {}).get("use_influxdb", False)
+ if self.use_influxdb:
+ influx_conf = self.params.get("retrieve_hass_conf", {})
+ self.influxdb_host = influx_conf.get("influxdb_host", "localhost")
+ self.influxdb_port = influx_conf.get("influxdb_port", 8086)
+ self.influxdb_username = influx_conf.get("influxdb_username", "")
+ self.influxdb_password = influx_conf.get("influxdb_password", "")
+ self.influxdb_database = influx_conf.get("influxdb_database", "homeassistant")
+ self.influxdb_measurement = influx_conf.get("influxdb_measurement", "W")
+ self.influxdb_retention_policy = influx_conf.get("influxdb_retention_policy", "autogen")
+ self.influxdb_use_ssl = influx_conf.get("influxdb_use_ssl", False)
+ self.influxdb_verify_ssl = influx_conf.get("influxdb_verify_ssl", False)
+ self.logger.info(
+ f"InfluxDB integration enabled: {self.influxdb_host}:{self.influxdb_port}/{self.influxdb_database}"
+ )
+ else:
+ self.logger.debug("InfluxDB integration disabled, using Home Assistant API")

- def get_ha_config(self):
+ async def get_ha_config(self):
  """
  Extract some configuration data from HA.

+ :rtype: bool
  """
+ # Initialize empty config immediately for safety
+ self.ha_config = {}
+
+ # Check if variables are None, empty strings, or explicitly set to "empty"
+ if (
+ not self.hass_url
+ or self.hass_url == "empty"
+ or not self.long_lived_token
+ or self.long_lived_token == "empty"
+ ):
+ self.logger.info(
+ "No Home Assistant URL or Long Lived Token found. Using only local configuration file."
+ )
+ return True
+
+ # Use WebSocket if configured
+ if self.use_websocket:
+ return await self.get_ha_config_websocket()
+
+ self.logger.info("Get HA config from REST API.")
+
+ # Set up headers
  headers = {
- "Authorization": "Bearer " + self.long_lived_token,
- "content-type": "application/json",
+ "Authorization": header_auth + " " + self.long_lived_token,
+ "content-type": header_accept,
  }
- url = self.hass_url + "api/config"
- response_config = get(url, headers=headers)
- self.ha_config = response_config.json()

- def get_data(
+ # Construct the URL (incorporating the PR's helpful checks)
+ # The Supervisor API sometimes uses a different path structure
+ if self.hass_url == hass_url:
+ url = self.hass_url + "/config"
+ else:
+ # Helpful check for users who forget the trailing slash
+ if not self.hass_url.endswith("/"):
+ self.logger.warning(
+ "The defined HA URL is missing a trailing slash </>. Appending it, but please fix your configuration."
+ )
+ self.hass_url = self.hass_url + "/"
+ url = self.hass_url + "api/config"
+
+ # Attempt the connection
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url, headers=headers) as response:
+ # Check for HTTP errors (404, 401, 500) before trying to parse JSON
+ response.raise_for_status()
+ data = await response.read()
+ self.ha_config = orjson.loads(data)
+ return True
+
+ except Exception as e:
+ # Granular error logging
+ # We log the specific error 'e' so the user knows if it's a Timeout, Connection Refused, or 401 Auth error
+ self.logger.error(f"Unable to obtain configuration from Home Assistant at: {url}")
+ self.logger.error(f"Error details: {e}")
+
+ # Helpful hint for Add-on users without confusing Docker users
+ if "supervisor" in self.hass_url:
+ self.logger.error(
+ "If using the add-on, try setting url and token to 'empty' to force local config."
+ )
+
+ return False
+
+ async def get_ha_config_websocket(self) -> dict[str, Any]:
+ """Get Home Assistant configuration."""
+ try:
+ self._client = await get_websocket_client(
+ self.hass_url, self.long_lived_token, self.logger
+ )
+ self.ha_config = await self._client.get_config()
+ return self.ha_config
+ except Exception as e:
+ self.logger.error(
+ f"EMHASS was unable to obtain configuration data from Home Assistant through websocket: {e}"
+ )
+ raise
+
+ async def get_data(
  self,
  days_list: pd.date_range,
  var_list: list,
- minimal_response: Optional[bool] = False,
- significant_changes_only: Optional[bool] = False,
- test_url: Optional[str] = "empty",
+ minimal_response: bool | None = False,
+ significant_changes_only: bool | None = False,
+ test_url: str | None = "empty",
  ) -> None:
  r"""
  Retrieve the actual data from hass.
-
+
  :param days_list: A list of days to retrieve. The ISO format should be used \
  and the timezone is UTC. The frequency of the data_range should be freq='D'
  :type days_list: pandas.date_range
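
The constructor now reads its transport toggles from `retrieve_hass_conf`. A hypothetical `params` dict exercising only the keys consumed in the hunk above:

```python
# Hypothetical configuration sketch; only keys read by __init__ in this hunk are shown.
params = {
    "retrieve_hass_conf": {
        "use_websocket": False,  # route get_ha_config/get_data through the websocket client
        "use_influxdb": True,  # prefer InfluxDB over WebSocket/REST for history
        "influxdb_host": "localhost",
        "influxdb_port": 8086,
        "influxdb_username": "",
        "influxdb_password": "",
        "influxdb_database": "homeassistant",
        "influxdb_measurement": "W",  # fallback measurement when auto-discovery fails
        "influxdb_retention_policy": "autogen",
        "influxdb_use_ssl": False,
        "influxdb_verify_ssl": False,
    }
}
```
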
@@ -120,171 +223,638 @@ class RetrieveHass:
  :type significant_changes_only: bool, optional
  :return: The DataFrame populated with the retrieved data from hass
  :rtype: pandas.DataFrame
-
- .. warning:: The minimal_response and significant_changes_only options \
- are experimental
  """
+ # Use InfluxDB if configured (prioritized over WebSocket/REST for history)
+ if self.use_influxdb:
+ return self.get_data_influxdb(days_list, var_list)
+
+ # Use WebSockets if configured, otherwise use Home Assistant REST API
+ if self.use_websocket:
+ success = await self.get_data_websocket(days_list, var_list)
+ if not success:
+ self.logger.warning("WebSocket data retrieval failed, falling back to REST API")
+ # Fall back to REST API if websocket fails
+ return await self._get_data_rest_api(
+ days_list,
+ var_list,
+ minimal_response,
+ significant_changes_only,
+ test_url,
+ )
+ return success
+
+ self.logger.info("Using REST API for data retrieval")
+ return await self._get_data_rest_api(
+ days_list, var_list, minimal_response, significant_changes_only, test_url
+ )
+
+ def _build_history_url(
+ self,
+ day: pd.Timestamp,
+ var: str,
+ test_url: str,
+ minimal_response: bool,
+ significant_changes_only: bool,
+ ) -> str:
+ """Helper to construct the Home Assistant History URL."""
+ if test_url != "empty":
+ return test_url
+ # Check if using supervisor API or Core API
+ if self.hass_url == hass_url:
+ base_url = f"{self.hass_url}/history/period/{day.isoformat()}"
+ else:
+ # Ensure trailing slash for Core API
+ if self.hass_url[-1] != "/":
+ self.logger.warning(
+ "Missing slash </> at the end of the defined URL, appending a slash but please fix your URL"
+ )
+ self.hass_url = self.hass_url + "/"
+ base_url = f"{self.hass_url}api/history/period/{day.isoformat()}"
+ url = f"{base_url}?filter_entity_id={var}"
+ if minimal_response:
+ url += "?minimal_response"
+ if significant_changes_only:
+ url += "?significant_changes_only"
+ return url
+
+ async def _fetch_history_data(
+ self,
+ session: aiohttp.ClientSession,
+ url: str,
+ headers: dict,
+ var: str,
+ day: pd.Timestamp,
+ is_first_day: bool,
+ ) -> list | bool:
+ """Helper to execute the HTTP request and return the raw JSON list."""
+ try:
+ async with session.get(url, headers=headers) as response:
+ response.raise_for_status()
+ data = await response.read()
+ data_list = orjson.loads(data)
+ except Exception:
+ self.logger.error("Unable to access Home Assistant instance, check URL")
+ self.logger.error("If using addon, try setting url and token to 'empty'")
+ return False
+ if response.status == 401:
+ self.logger.error("Unable to access Home Assistant instance, TOKEN/KEY")
+ return False
+ if response.status > 299:
+ self.logger.error(f"Home assistant request GET error: {response.status} for var {var}")
+ return False
+ try:
+ return data_list[0]
+ except IndexError:
+ if is_first_day:
+ self.logger.error(
+ f"The retrieved JSON is empty; sensor {var} may have 0 days of history, "
+ "the passed sensor may not be correct, or days to retrieve is set too high."
+ )
+ else:
+ self.logger.error(
+ f"The retrieved JSON is empty for day: {day}, days_to_retrieve may be larger "
+ f"than the recorded history of sensor: {var}"
+ )
+ return False
+
+ def _process_history_dataframe(
+ self, data: list, var: str, day: pd.Timestamp, is_first_day: bool, is_last_day: bool
+ ) -> pd.DataFrame | bool:
+ """Helper to convert raw data to a resampled DataFrame."""
+ df_raw = pd.DataFrame.from_dict(data)
+ # Check for empty DataFrame
+ if len(df_raw) == 0:
+ if is_first_day:
+ self.logger.error(
+ f"The retrieved Dataframe is empty; sensor {var} may have 0 days of history."
+ )
+ else:
+ self.logger.error(
+ f"Retrieved empty Dataframe for day: {day}, check recorder settings."
+ )
+ return False
+ # Check for data sufficiency (frequency consistency)
+ expected_count = (60 / (self.freq.seconds / 60)) * 24
+ if len(df_raw) < expected_count and not is_last_day:
+ self.logger.debug(
+ f"sensor: {var} retrieved Dataframe count: {len(df_raw)}, on day: {day}. "
+ f"This is less than freq value passed: {self.freq}"
+ )
+ # Process and Resample
+ df_tp = (
+ df_raw.copy()[["state"]]
+ .replace(["unknown", "unavailable", ""], np.nan)
+ .astype(float)
+ .rename(columns={"state": var})
+ )
+ df_tp.set_index(pd.to_datetime(df_raw["last_changed"], format="ISO8601"), inplace=True)
+ df_tp = df_tp.resample(self.freq).mean()
+ return df_tp
+
+ async def _get_data_rest_api(
+ self,
+ days_list: pd.date_range,
+ var_list: list,
+ minimal_response: bool | None = False,
+ significant_changes_only: bool | None = False,
+ test_url: str | None = "empty",
+ ) -> None:
+ """Internal method to handle REST API data retrieval."""
  self.logger.info("Retrieve hass get data method initiated...")
  headers = {
- "Authorization": "Bearer " + self.long_lived_token,
- "content-type": "application/json",
+ "Authorization": header_auth + " " + self.long_lived_token,
+ "content-type": header_accept,
  }
- # Looping on each day from days list
+ var_list = [var for var in var_list if var != ""]
  self.df_final = pd.DataFrame()
- x = 0 # iterate based on days
- for day in days_list:
- for i, var in enumerate(var_list):
- if test_url == "empty":
- if (
- self.hass_url == "http://supervisor/core/api"
- ): # If we are using the supervisor API
- url = (
- self.hass_url
- + "/history/period/"
- + day.isoformat()
- + "?filter_entity_id="
- + var
- )
- else: # Otherwise the Home Assistant Core API it is
- url = (
- self.hass_url
- + "api/history/period/"
- + day.isoformat()
- + "?filter_entity_id="
- + var
- )
- if minimal_response: # A support for minimal response
- url = url + "?minimal_response"
- if significant_changes_only: # And for signicant changes only (check the HASS restful API for more info)
- url = url + "?significant_changes_only"
- else:
- url = test_url
- try:
- response = get(url, headers=headers)
- except Exception:
- self.logger.error(
- "Unable to access Home Assistance instance, check URL"
+
+ async with aiohttp.ClientSession() as session:
+ for day_idx, day in enumerate(days_list):
+ df_day = pd.DataFrame()
+ for i, var in enumerate(var_list):
+ # Build URL
+ url = self._build_history_url(
+ day, var, test_url, minimal_response, significant_changes_only
  )
- self.logger.error(
- "If using addon, try setting url and token to 'empty'"
+ # Fetch Data
+ data = await self._fetch_history_data(
+ session, url, headers, var, day, is_first_day=(day_idx == 0)
  )
- return False
- else:
- if response.status_code == 401:
- self.logger.error(
- "Unable to access Home Assistance instance, TOKEN/KEY"
- )
- self.logger.error(
- "If using addon, try setting url and token to 'empty'"
- )
+ if data is False:
  return False
- if response.status_code > 299:
- return f"Request Get Error: {response.status_code}"
- """import bz2 # Uncomment to save a serialized data for tests
- import _pickle as cPickle
- with bz2.BZ2File("data/test_response_get_data_get_method.pbz2", "w") as f:
- cPickle.dump(response, f)"""
- try: # Sometimes when there are connection problems we need to catch empty retrieved json
- data = response.json()[0]
- except IndexError:
- if x == 0:
- self.logger.error(
- "The retrieved JSON is empty, A sensor:"
- + var
- + " may have 0 days of history, passed sensor may not be correct, or days to retrieve is set too heigh"
- )
- else:
- self.logger.error(
- "The retrieved JSON is empty for day:"
- + str(day)
- + ", days_to_retrieve may be larger than the recorded history of sensor:"
- + var
- + " (check your recorder settings)"
- )
- return False
- df_raw = pd.DataFrame.from_dict(data)
- # self.logger.info(str(df_raw))
- if len(df_raw) == 0:
- if x == 0:
- self.logger.error(
- "The retrieved Dataframe is empty, A sensor:"
- + var
- + " may have 0 days of history or passed sensor may not be correct"
- )
- else:
- self.logger.error(
- "Retrieved empty Dataframe for day:"
- + str(day)
- + ", days_to_retrieve may be larger than the recorded history of sensor:"
- + var
- + " (check your recorder settings)"
- )
- return False
- # self.logger.info(self.freq.seconds)
- if (
- len(df_raw) < ((60 / (self.freq.seconds / 60)) * 24)
- and x != len(days_list) - 1
- ): # check if there is enough Dataframes for passed frequency per day (not inc current day)
- self.logger.debug(
- "sensor:"
- + var
- + " retrieved Dataframe count: "
- + str(len(df_raw))
- + ", on day: "
- + str(day)
- + ". This is less than freq value passed: "
- + str(self.freq)
+ # Process Data
+ df_resampled = self._process_history_dataframe(
+ data,
+ var,
+ day,
+ is_first_day=(day_idx == 0),
+ is_last_day=(day_idx == len(days_list) - 1),
  )
- if i == 0: # Defining the DataFrame container
- from_date = pd.to_datetime(
- df_raw["last_changed"], format="ISO8601"
- ).min()
- to_date = pd.to_datetime(
- df_raw["last_changed"], format="ISO8601"
- ).max()
- ts = pd.to_datetime(
- pd.date_range(start=from_date, end=to_date, freq=self.freq),
- format="%Y-%d-%m %H:%M",
- ).round(self.freq, ambiguous="infer", nonexistent="shift_forward")
- df_day = pd.DataFrame(index=ts)
- # Caution with undefined string data: unknown, unavailable, etc.
- df_tp = (
- df_raw.copy()[["state"]]
- .replace(["unknown", "unavailable", ""], np.nan)
- .astype(float)
- .rename(columns={"state": var})
- )
- # Setting index, resampling and concatenation
- df_tp.set_index(
- pd.to_datetime(df_raw["last_changed"], format="ISO8601"),
- inplace=True,
- )
- df_tp = df_tp.resample(self.freq).mean()
- df_day = pd.concat([df_day, df_tp], axis=1)
- self.df_final = pd.concat([self.df_final, df_day], axis=0)
- x += 1
+ if df_resampled is False:
+ return False
+ # Merge into daily DataFrame
+ # If it's the first variable, we initialize the day's index based on it
+ if i == 0:
+ df_day = pd.DataFrame(index=df_resampled.index)
+ # Ensure the daily index is regularized to the frequency
+ # Note: The original logic created a manual range here, but using the
+ # resampled index from the first variable is safer and cleaner if
+ # _process_history_dataframe handles resampling correctly.
+ df_day = pd.concat([df_day, df_resampled], axis=1)
+ self.df_final = pd.concat([self.df_final, df_day], axis=0)
+
+ # Final Cleanup
  self.df_final = set_df_index_freq(self.df_final)
  if self.df_final.index.freq != self.freq:
  self.logger.error(
- "The inferred freq:"
- + str(self.df_final.index.freq)
- + " from data is not equal to the defined freq in passed:"
- + str(self.freq)
+ f"The inferred freq: {self.df_final.index.freq} from data is not equal "
+ f"to the defined freq in passed: {self.freq}"
+ )
+ return False
+ self.var_list = var_list
+ return True
+
+ async def get_data_websocket(
+ self,
+ days_list: pd.date_range,
+ var_list: list[str],
+ ) -> bool:
+ r"""
+ Retrieve the actual data from hass.
+
+ :param days_list: A list of days to retrieve. The ISO format should be used \
+ and the timezone is UTC. The frequency of the data_range should be freq='D'
+ :type days_list: pandas.date_range
+ :param var_list: The list of variables to retrieve from hass. These should \
+ be the exact name of the sensor in Home Assistant. \
+ For example: ['sensor.home_load', 'sensor.home_pv']
+ :type var_list: list
+ :return: The DataFrame populated with the retrieved data from hass
+ :rtype: pandas.DataFrame
+ """
+ try:
+ self._client = await asyncio.wait_for(
+ get_websocket_client(self.hass_url, self.long_lived_token, self.logger),
+ timeout=20.0,
  )
+ except TimeoutError:
+ self.logger.error("WebSocket connection timed out")
+ return False
+ except Exception as e:
+ self.logger.error(f"Websocket connection error: {e}")
+ return False
+
+ self.var_list = var_list
+
+ # Calculate time range
+ start_time = min(days_list).to_pydatetime()
+ end_time = datetime.now()
+
+ # Try to get statistics data (which contains the actual historical data)
+ try:
+ # Get statistics data with 5-minute period for good resolution
+ t0 = time.time()
+ stats_data = await asyncio.wait_for(
+ self._client.get_statistics(
+ start_time=start_time,
+ end_time=end_time,
+ statistic_ids=var_list,
+ period="5minute",
+ ),
+ timeout=30.0,
+ )
+
+ # Convert statistics data to DataFrame
+ self.df_final = self._convert_statistics_to_dataframe(stats_data, var_list)
+
+ t1 = time.time()
+ self.logger.info(f"Statistics data retrieval took {t1 - t0:.2f} seconds")
+
+ return not self.df_final.empty
+
+ except Exception as e:
+ self.logger.error(f"Failed to get data via WebSocket: {e}")
+ return False
+
+ def get_data_influxdb(
+ self,
+ days_list: pd.date_range,
+ var_list: list,
+ ) -> bool:
+ """
+ Retrieve data from InfluxDB database.
+
+ This method provides an alternative data source to Home Assistant API,
+ enabling longer historical data retention for better machine learning model training.
+
+ :param days_list: A list of days to retrieve data for
+ :type days_list: pandas.date_range
+ :param var_list: List of sensor entity IDs to retrieve
+ :type var_list: list
+ :return: Success status of data retrieval
+ :rtype: bool
+ """
+ self.logger.info("Retrieve InfluxDB get data method initiated...")
+
+ # Check for empty inputs
+ if not days_list.size:
+ self.logger.error("Empty days_list provided")
+ return False
+
+ client = self._init_influx_client()
+ if not client:
+ return False
+
+ # Convert all timestamps to UTC for comparison, then make naive for InfluxDB
+ # This ensures we compare actual instants in time, not wall clock times
+ # InfluxDB queries expect naive UTC timestamps (with 'Z' suffix)
+
+ # Normalize start_time to pd.Timestamp in UTC
+ start_time = pd.Timestamp(days_list[0])
+ if start_time.tz is not None:
+ start_time = start_time.tz_convert("UTC").tz_localize(None)
+ # If naive, assume it's already UTC
+
+ # Get current time in UTC
+ now = pd.Timestamp.now(tz="UTC").tz_localize(None)
+
+ # Normalize requested_end to pd.Timestamp in UTC
+ requested_end = pd.Timestamp(days_list[-1]) + pd.Timedelta(days=1)
+ if requested_end.tz is not None:
+ requested_end = requested_end.tz_convert("UTC").tz_localize(None)
+ # If naive, assume it's already UTC
+
+ # Cap end_time at current time to avoid querying future data
+ # This prevents FILL(previous) from creating fake future datapoints
+ end_time = min(now, requested_end)
+ total_days = (end_time - start_time).days
+
+ self.logger.info(f"Retrieving {len(var_list)} sensors over {total_days} days from InfluxDB")
+ self.logger.debug(f"Time range: {start_time} to {end_time}")
+ if end_time < requested_end:
+ self.logger.debug(f"End time capped at current time (requested: {requested_end})")
+
+ # Collect sensor dataframes
+ sensor_dfs = []
+ global_min_time = None
+ global_max_time = None
+
+ for sensor in filter(None, var_list):
+ df_sensor = self._fetch_sensor_data(client, sensor, start_time, end_time)
+ if df_sensor is not None:
+ sensor_dfs.append(df_sensor)
+ # Track global time range
+ sensor_min = df_sensor.index.min()
+ sensor_max = df_sensor.index.max()
+ global_min_time = min(global_min_time or sensor_min, sensor_min)
+ global_max_time = max(global_max_time or sensor_max, sensor_max)
+
+ client.close()
+
+ if not sensor_dfs:
+ self.logger.error("No data retrieved from InfluxDB")
+ return False
+
+ # Create complete time index covering all sensors
+ if global_min_time is not None and global_max_time is not None:
+ complete_index = pd.date_range(
+ start=global_min_time, end=global_max_time, freq=self.freq
+ )
+ self.df_final = pd.DataFrame(index=complete_index)
+
+ # Merge all sensor dataframes
+ for df_sensor in sensor_dfs:
+ self.df_final = pd.concat([self.df_final, df_sensor], axis=1)
+
+ # Set frequency and validate with error handling
+ try:
+ self.df_final = set_df_index_freq(self.df_final)
+ except Exception as e:
+ self.logger.error(f"Exception occurred while setting DataFrame index frequency: {e}")
  return False
+
+ if self.df_final.index.freq != self.freq:
+ self.logger.warning(
+ f"InfluxDB data frequency ({self.df_final.index.freq}) differs from expected ({self.freq})"
+ )
+
+ self.var_list = var_list
+ self.logger.info(f"InfluxDB data retrieval completed: {self.df_final.shape}")
  return True

+ def _init_influx_client(self):
+ """Initialize InfluxDB client connection."""
+ try:
+ from influxdb import InfluxDBClient
+ except ImportError:
+ self.logger.error("InfluxDB client not installed. Install with: pip install influxdb")
+ return None
+
+ try:
+ client = InfluxDBClient(
+ host=self.influxdb_host,
+ port=self.influxdb_port,
+ username=self.influxdb_username or None,
+ password=self.influxdb_password or None,
+ database=self.influxdb_database,
+ ssl=self.influxdb_use_ssl,
+ verify_ssl=self.influxdb_verify_ssl,
+ )
+ # Test connection
+ client.ping()
+ self.logger.debug(
+ f"Successfully connected to InfluxDB at {self.influxdb_host}:{self.influxdb_port}"
+ )
+
+ # Initialize measurement cache
+ if not hasattr(self, "_measurement_cache"):
+ self._measurement_cache = {}
+
+ return client
+ except Exception as e:
+ self.logger.error(f"Failed to connect to InfluxDB: {e}")
+ return None
+
+ def _discover_entity_measurement(self, client, entity_id: str) -> str:
+ """Auto-discover which measurement contains the given entity."""
+ # Check cache first
+ if entity_id in self._measurement_cache:
+ return self._measurement_cache[entity_id]
+
+ try:
+ # Get all available measurements
+ measurements_query = "SHOW MEASUREMENTS"
+ measurements_result = client.query(measurements_query)
+ measurements = [m["name"] for m in measurements_result.get_points()]
+
+ # Priority order: check common sensor types first
+ priority_measurements = ["EUR/kWh", "€/kWh", "W", "EUR", "€", "%", "A", "V"]
+ all_measurements = priority_measurements + [
+ m for m in measurements if m not in priority_measurements
+ ]
+
+ self.logger.debug(
+ f"Searching for entity '{entity_id}' across {len(measurements)} measurements"
+ )
+
+ # Search for entity in each measurement
+ for measurement in all_measurements:
+ if measurement not in measurements:
+ continue # Skip if measurement doesn't exist
+
+ try:
+ # Use SHOW TAG VALUES to get all entity_ids in this measurement
+ tag_query = f'SHOW TAG VALUES FROM "{measurement}" WITH KEY = "entity_id"'
+ self.logger.debug(
+ f"Checking measurement '{measurement}' with tag query: {tag_query}"
+ )
+ result = client.query(tag_query)
+ points = list(result.get_points())
+
+ # Check if our target entity_id is in the tag values
+ for point in points:
+ if point.get("value") == entity_id:
+ self.logger.debug(
+ f"Found entity '{entity_id}' in measurement '{measurement}'"
+ )
+ # Cache the result
+ self._measurement_cache[entity_id] = measurement
+ return measurement
+
+ except Exception as query_error:
+ self.logger.debug(
+ f"Tag query failed for measurement '{measurement}': {query_error}"
+ )
+ continue
+
+ except Exception as e:
+ self.logger.error(f"Error discovering measurement for entity {entity_id}: {e}")
+
+ # Fallback to default measurement if not found
+ self.logger.warning(
+ f"Entity '{entity_id}' not found in any measurement, using default: {self.influxdb_measurement}"
+ )
+ return self.influxdb_measurement
+
+ def _build_influx_query_for_measurement(
+ self, entity_id: str, measurement: str, start_time, end_time
+ ) -> str:
+ """Build InfluxQL query for specific measurement and entity."""
+ # Convert frequency to InfluxDB interval
+ freq_minutes = int(self.freq.total_seconds() / 60)
+ interval = f"{freq_minutes}m"
+
+ # Format times properly for InfluxDB
+ start_time_str = start_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+ end_time_str = end_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+ # Use FILL(previous) instead of FILL(linear) for compatibility with open-source InfluxDB
+ query = f"""
+ SELECT mean("value") AS "mean_value"
+ FROM "{self.influxdb_database}"."{self.influxdb_retention_policy}"."{measurement}"
+ WHERE time >= '{start_time_str}'
+ AND time < '{end_time_str}'
+ AND "entity_id"='{entity_id}'
+ GROUP BY time({interval}) FILL(previous)
+ """
+ return query
+
+ def _build_influx_query(self, sensor: str, start_time, end_time) -> str:
+ """Build InfluxQL query for sensor data retrieval (legacy method)."""
+ # Convert sensor name: sensor.sec_pac_solar -> sec_pac_solar
+ entity_id = (
+ sensor.replace(sensor_prefix, "") if sensor.startswith(sensor_prefix) else sensor
+ )
+
+ # Use default measurement (for backward compatibility)
+ return self._build_influx_query_for_measurement(
+ entity_id, self.influxdb_measurement, start_time, end_time
+ )
+
+ def _fetch_sensor_data(self, client, sensor: str, start_time, end_time):
+ """Fetch and process data for a single sensor with auto-discovery."""
+ self.logger.debug(f"Retrieving sensor: {sensor}")
+
+ # Clean sensor name (remove sensor. prefix if present)
+ entity_id = (
+ sensor.replace(sensor_prefix, "") if sensor.startswith(sensor_prefix) else sensor
+ )
+
+ # Auto-discover which measurement contains this entity
+ measurement = self._discover_entity_measurement(client, entity_id)
+ if not measurement:
+ self.logger.warning(f"Entity '{entity_id}' not found in any InfluxDB measurement")
+ return None
+
+ try:
+ query = self._build_influx_query_for_measurement(
+ entity_id, measurement, start_time, end_time
+ )
+ self.logger.debug(f"InfluxDB query: {query}")
+
+ # Execute query
+ result = client.query(query)
+ points = list(result.get_points())
+
+ if not points:
+ self.logger.warning(
+ f"No data found for entity: {entity_id} in measurement: {measurement}"
+ )
+ return None
+
+ self.logger.info(f"Retrieved {len(points)} data points for {sensor}")
+
+ # Create DataFrame from points
+ df_sensor = pd.DataFrame(points)
+
+ # Convert time column and set as index with timezone awareness
+ df_sensor["time"] = pd.to_datetime(df_sensor["time"], utc=True)
+ df_sensor.set_index("time", inplace=True)
+
+ # Rename value column to original sensor name
+ if "mean_value" in df_sensor.columns:
+ df_sensor = df_sensor[["mean_value"]].rename(columns={"mean_value": sensor})
+ else:
+ self.logger.error(
+ f"Expected 'mean_value' column not found for {sensor} in measurement {measurement}"
+ )
+ return None
+
+ # Handle non-numeric data with NaN ratio warning
+ df_sensor[sensor] = pd.to_numeric(df_sensor[sensor], errors="coerce")
+
+ # Check proportion of NaNs and log warning if high
+ nan_count = df_sensor[sensor].isna().sum()
+ total_count = len(df_sensor[sensor])
+ if total_count > 0:
+ nan_ratio = nan_count / total_count
+ if nan_ratio > 0.2:
+ self.logger.warning(
+ f"Entity '{entity_id}' has {nan_count}/{total_count} ({nan_ratio:.1%}) non-numeric values coerced to NaN."
+ )
+
+ self.logger.debug(
+ f"Successfully retrieved {len(df_sensor)} data points for '{entity_id}' from measurement '{measurement}'"
+ )
+ return df_sensor
+
+ except Exception as e:
+ self.logger.error(
+ f"Failed to query entity {entity_id} from measurement {measurement}: {e}"
+ )
+ return None
+
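
`_build_influx_query_for_measurement` fills a fixed InfluxQL template. For illustration, with entity_id `home_load`, measurement `W`, the default database and retention policy, a 30-minute `freq`, and a one-day window (all assumed values), the rendered query would be roughly:

```python
# Illustrative rendering only; the entity, window, and 30m interval are assumptions.
query = """
SELECT mean("value") AS "mean_value"
FROM "homeassistant"."autogen"."W"
WHERE time >= '2024-01-01T00:00:00Z'
  AND time < '2024-01-02T00:00:00Z'
  AND "entity_id"='home_load'
GROUP BY time(30m) FILL(previous)
"""
```
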
787
+ def _validate_sensor_list(self, target_list: list, list_name: str) -> list:
788
+ """Helper to validate that config lists only contain known sensors."""
789
+ if not isinstance(target_list, list):
790
+ return []
791
+ valid_items = [item for item in target_list if item in self.var_list]
792
+ removed = set(target_list) - set(valid_items)
793
+ for item in removed:
794
+ self.logger.warning(
795
+ f"Sensor '{item}' in {list_name} not found in self.var_list and has been removed."
796
+ )
797
+ return valid_items
798
+
799
+ def _process_load_column_renaming(
800
+ self, var_load: str, load_negative: bool, skip_renaming: bool
801
+ ) -> bool:
802
+ """Helper to handle the sign flip and renaming of the main load column."""
803
+ if skip_renaming:
804
+ return True
805
+ try:
806
+ # Apply the correct sign to load power
807
+ if load_negative:
808
+ self.df_final[var_load + "_positive"] = -self.df_final[var_load]
809
+ else:
810
+ self.df_final[var_load + "_positive"] = self.df_final[var_load]
811
+ self.df_final.drop([var_load], inplace=True, axis=1)
812
+ # Update var_list to reflect the renamed column
813
+ self.var_list = [var.replace(var_load, var_load + "_positive") for var in self.var_list]
814
+ self.logger.debug(f"prepare_data var_list updated after rename: {self.var_list}")
815
+ return True
816
+ except KeyError as e:
817
+ self.logger.error(
818
+ f"Variable '{var_load}' was not found in DataFrame columns: {list(self.df_final.columns)}. "
819
+ f"This is typically because no data could be retrieved from Home Assistant or InfluxDB. Error: {e}"
820
+ )
821
+ return False
822
+ except ValueError:
823
+ self.logger.error(
824
+ "sensor.power_photovoltaics and sensor.power_load_no_var_loads should not be the same"
825
+ )
826
+ return False
827
+
828
+ def _map_variable_names(
829
+ self, target_list: list, var_load: str, skip_renaming: bool, param_name: str
830
+ ) -> list | None:
831
+ """Helper to map old variable names to new ones (if renaming occurred)."""
832
+ if not target_list:
833
+ self.logger.warning(f"Unable to find all the sensors in {param_name} parameter")
834
+ self.logger.warning(
835
+ f"Confirm sure all sensors in {param_name} are sensor_power_photovoltaics and/or sensor_power_load_no_var_loads"
836
+ )
837
+ return None
838
+ new_list = []
839
+ for string in target_list:
840
+ if not skip_renaming:
841
+ new_list.append(string.replace(var_load, var_load + "_positive"))
842
+ else:
843
+ new_list.append(string)
844
+ return new_list
845
+
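
The renaming these helpers implement is a plain string substitution, e.g. (sensor names here are assumptions for illustration):

```python
# Illustrative only: the substitution performed by _map_variable_names
# when skip_renaming is False.
var_load = "sensor.power_load_no_var_loads"
var_interp = ["sensor.power_load_no_var_loads", "sensor.power_photovoltaics"]
new_var_interp = [s.replace(var_load, var_load + "_positive") for s in var_interp]
# -> ['sensor.power_load_no_var_loads_positive', 'sensor.power_photovoltaics']
```
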
  def prepare_data(
  self,
  var_load: str,
- load_negative: Optional[bool] = False,
- set_zero_min: Optional[bool] = True,
- var_replace_zero: Optional[list] = None,
- var_interp: Optional[list] = None,
- ) -> None:
+ load_negative: bool,
+ set_zero_min: bool,
+ var_replace_zero: list[str],
+ var_interp: list[str],
+ skip_renaming: bool = False,
+ ) -> bool:
  r"""
  Apply some data treatment in preparation for the optimization task.
-
+
  :param var_load: The name of the variable for the household load consumption.
  :type var_load: str
  :param load_negative: Set to True if the retrived load variable is \
@@ -302,58 +872,38 @@ class RetrieveHass:
  :return: The DataFrame populated with the retrieved data from hass and \
  after the data treatment
  :rtype: pandas.DataFrame
-
+
  """
- try:
- if load_negative: # Apply the correct sign to load power
- self.df_final[var_load + "_positive"] = -self.df_final[var_load]
- else:
- self.df_final[var_load + "_positive"] = self.df_final[var_load]
- self.df_final.drop([var_load], inplace=True, axis=1)
- except KeyError:
- self.logger.error(
- "Variable "
- + var_load
- + " was not found. This is typically because no data could be retrieved from Home Assistant"
- )
+ self.logger.debug("prepare_data self.var_list=%s", self.var_list)
+ self.logger.debug("prepare_data var_load=%s", var_load)
+ # Validate Input Lists
+ var_replace_zero = self._validate_sensor_list(var_replace_zero, "var_replace_zero")
+ var_interp = self._validate_sensor_list(var_interp, "var_interp")
+ # Rename Load Columns (Handle sign change)
+ if not self._process_load_column_renaming(var_load, load_negative, skip_renaming):
  return False
- except ValueError:
- self.logger.error(
- "sensor.power_photovoltaics and sensor.power_load_no_var_loads should not be the same"
- )
- return False
- if set_zero_min: # Apply minimum values
+ # Apply Zero Saturation (Min value clipping)
+ if set_zero_min:
  self.df_final.clip(lower=0.0, inplace=True, axis=1)
  self.df_final.replace(to_replace=0.0, value=np.nan, inplace=True)
- new_var_replace_zero = []
- new_var_interp = []
- # Just changing the names of variables to contain the fact that they are considered positive
- if var_replace_zero is not None:
- for string in var_replace_zero:
- new_string = string.replace(var_load, var_load + "_positive")
- new_var_replace_zero.append(new_string)
- else:
- new_var_replace_zero = None
- if var_interp is not None:
- for string in var_interp:
- new_string = string.replace(var_load, var_load + "_positive")
- new_var_interp.append(new_string)
- else:
- new_var_interp = None
- # Treating NaN replacement: either by zeros or by linear interpolation
- if new_var_replace_zero is not None:
- self.df_final[new_var_replace_zero] = self.df_final[
- new_var_replace_zero
- ].fillna(0.0)
- if new_var_interp is not None:
+ # Map Variable Names (Update lists to match new column names)
+ new_var_replace_zero = self._map_variable_names(
+ var_replace_zero, var_load, skip_renaming, "sensor_replace_zero"
+ )
+ new_var_interp = self._map_variable_names(
+ var_interp, var_load, skip_renaming, "sensor_linear_interp"
+ )
+ # Apply Data Cleaning (FillNA / Interpolate)
+ if new_var_replace_zero:
+ self.df_final[new_var_replace_zero] = self.df_final[new_var_replace_zero].fillna(0.0)
+ if new_var_interp:
  self.df_final[new_var_interp] = self.df_final[new_var_interp].interpolate(
  method="linear", axis=0, limit=None
  )
  self.df_final[new_var_interp] = self.df_final[new_var_interp].fillna(0.0)
- # Setting the correct time zone on DF index
+ # Finalize Index (Timezone and Duplicates)
  if self.time_zone is not None:
  self.df_final.index = self.df_final.index.tz_convert(self.time_zone)
- # Drop datetimeindex duplicates on final DF
  self.df_final = self.df_final[~self.df_final.index.duplicated(keep="first")]
  return True

@@ -362,24 +912,27 @@ class RetrieveHass:
  data_df: pd.DataFrame,
  idx: int,
  entity_id: str,
+ device_class: str,
  unit_of_measurement: str,
  friendly_name: str,
  list_name: str,
  state: float,
+ decimals: int = 2,
  ) -> dict:
  list_df = copy.deepcopy(data_df).loc[data_df.index[idx] :].reset_index()
  list_df.columns = ["timestamps", entity_id]
- ts_list = [str(i) for i in list_df["timestamps"].tolist()]
- vals_list = [str(np.round(i, 2)) for i in list_df[entity_id].tolist()]
+ ts_list = [i.isoformat() for i in list_df["timestamps"].tolist()]
+ vals_list = [str(np.round(i, decimals)) for i in list_df[entity_id].tolist()]
  forecast_list = []
  for i, ts in enumerate(ts_list):
  datum = {}
  datum["date"] = ts
- datum[entity_id.split("sensor.")[1]] = vals_list[i]
+ datum[entity_id.split(sensor_prefix)[1]] = vals_list[i]
  forecast_list.append(datum)
  data = {
- "state": "{:.2f}".format(state),
+ "state": f"{state:.{decimals}f}",
  "attributes": {
+ "device_class": device_class,
  "unit_of_measurement": unit_of_measurement,
  "friendly_name": friendly_name,
  list_name: forecast_list,
@@ -387,23 +940,27 @@ class RetrieveHass:
  }
  return data

- def post_data(
+ async def post_data(
  self,
  data_df: pd.DataFrame,
  idx: int,
  entity_id: str,
+ device_class: str,
  unit_of_measurement: str,
  friendly_name: str,
  type_var: str,
- from_mlforecaster: Optional[bool] = False,
- publish_prefix: Optional[str] = "",
- save_entities: Optional[bool] = False,
- logger_levels: Optional[str] = "info",
- dont_post: Optional[bool] = False,
+ publish_prefix: str | None = "",
+ save_entities: bool | None = False,
+ logger_levels: str | None = "info",
+ dont_post: bool | None = False,
  ) -> None:
  r"""
- Post passed data to hass.
-
+ Post passed data to hass using REST API.
+
+ .. note:: This method ALWAYS uses the REST API for posting data to Home Assistant,
+ regardless of the use_websocket setting. WebSocket is only used for
+ data retrieval, not for publishing/posting data.
+
  :param data_df: The DataFrame containing the data that will be posted \
  to hass. This should be a one columns DF or a series.
  :type data_df: pd.DataFrame
@@ -412,6 +969,8 @@ class RetrieveHass:
  :type idx: int
  :param entity_id: The unique entity_id of the sensor in hass.
  :type entity_id: str
+ :param device_class: The HASS device class for the sensor.
+ :type device_class: str
  :param unit_of_measurement: The units of the sensor.
  :type unit_of_measurement: str
  :param friendly_name: The friendly name that will be used in the hass frontend.
@@ -422,24 +981,22 @@ class RetrieveHass:
  :type publish_prefix: str, optional
  :param save_entities: if entity data should be saved in data_path/entities
  :type save_entities: bool, optional
- :param logger_levels: set logger level, info or debug, to output
+ :param logger_levels: set logger level, info or debug, to output
  :type logger_levels: str, optional
  :param dont_post: dont post to HA
  :type dont_post: bool, optional

  """
  # Add a possible prefix to the entity ID
- entity_id = entity_id.replace("sensor.", "sensor." + publish_prefix)
+ entity_id = entity_id.replace(sensor_prefix, sensor_prefix + publish_prefix)
  # Set the URL
- if (
- self.hass_url == "http://supervisor/core/api"
- ): # If we are using the supervisor API
+ if self.hass_url == hass_url: # If we are using the supervisor API
  url = self.hass_url + "/states/" + entity_id
  else: # Otherwise the Home Assistant Core API it is
  url = self.hass_url + "api/states/" + entity_id
  headers = {
- "Authorization": "Bearer " + self.long_lived_token,
- "content-type": "application/json",
+ "Authorization": header_auth + " " + self.long_lived_token,
+ "content-type": header_accept,
  }
  # Preparing the data dict to be published
  if type_var == "cost_fun":
@@ -451,7 +1008,7 @@ class RetrieveHass:
  elif type_var == "optim_status":
  state = data_df.loc[data_df.index[idx]]
  elif type_var == "mlregressor":
- state = data_df[idx]
+ state = float(data_df[idx])
  else:
  state = np.round(data_df.loc[data_df.index[idx]], 2)
  if type_var == "power":
@@ -459,6 +1016,7 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "forecasts",
@@ -469,6 +1027,7 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "deferrables_schedule",
@@ -479,6 +1038,7 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "predicted_temperatures",
@@ -489,6 +1049,7 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "battery_scheduled_power",
@@ -499,6 +1060,7 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "battery_scheduled_soc",
@@ -509,36 +1071,51 @@ class RetrieveHass:
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "unit_load_cost_forecasts",
  state,
+ decimals=4,
  )
  elif type_var == "unit_prod_price":
  data = RetrieveHass.get_attr_data_dict(
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "unit_prod_price_forecasts",
  state,
+ decimals=4,
  )
  elif type_var == "mlforecaster":
  data = RetrieveHass.get_attr_data_dict(
  data_df,
  idx,
  entity_id,
+ device_class,
  unit_of_measurement,
  friendly_name,
  "scheduled_forecast",
  state,
  )
+ elif type_var == "energy":
+ data = RetrieveHass.get_attr_data_dict(
+ data_df,
+ idx,
+ entity_id,
+ device_class,
+ unit_of_measurement,
+ friendly_name,
+ "heating_demand_forecast",
+ state,
+ )
  elif type_var == "optim_status":
  data = {
  "state": state,
  "attributes": {
- "unit_of_measurement": unit_of_measurement,
  "friendly_name": friendly_name,
  },
  }
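
All the `get_attr_data_dict` branches above produce the same payload shape. A representative body for `type_var == "power"`, with hypothetical values, would be:

```python
# Hypothetical payload POSTed to {hass_url}api/states/{entity_id}; all values invented.
data = {
    "state": "1234.50",
    "attributes": {
        "device_class": "power",
        "unit_of_measurement": "W",
        "friendly_name": "PV Power Forecast",
        "forecasts": [
            {"date": "2024-01-01T00:00:00+00:00", "p_pv_forecast": "1234.5"},
            {"date": "2024-01-01T00:30:00+00:00", "p_pv_forecast": "1100.0"},
        ],
    },
}
```
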
@@ -546,39 +1123,53 @@ class RetrieveHass:
  data = {
  "state": state,
  "attributes": {
+ "device_class": device_class,
  "unit_of_measurement": unit_of_measurement,
  "friendly_name": friendly_name,
  },
  }
  else:
  data = {
- "state": "{:.2f}".format(state),
+ "state": f"{state:.2f}",
  "attributes": {
+ "device_class": device_class,
  "unit_of_measurement": unit_of_measurement,
  "friendly_name": friendly_name,
  },
  }
  # Actually post the data
  if self.get_data_from_file or dont_post:
-
- class response:
- pass
-
- response.status_code = 200
- response.ok = True
+ # Create mock response for file mode or dont_post mode
+ self.logger.debug(
+ f"Skipping actual POST (get_data_from_file={self.get_data_from_file}, dont_post={dont_post})"
+ )
+ response_ok = True
+ response_status_code = 200
  else:
- response = post(url, headers=headers, data=json.dumps(data))
+ # Always use REST API for posting data, regardless of use_websocket setting
+ self.logger.debug(f"Posting data to URL: {url}")
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ url, headers=headers, data=orjson.dumps(data).decode("utf-8")
+ ) as response:
+ # Store response data since we need to access it after the context manager
+ response_ok = response.ok
+ response_status_code = response.status
+ self.logger.debug(
+ f"HTTP POST response: ok={response_ok}, status={response_status_code}"
+ )
+ except Exception as e:
+ self.logger.error(f"Failed to post data to {entity_id}: {e}")
+ response_ok = False
+ response_status_code = 500

  # Treating the response status and posting them on the logger
- if response.ok:
- if logger_levels == "DEBUG":
- self.logger.debug(
- "Successfully posted to " + entity_id + " = " + str(state)
- )
+ if response_ok:
+ if logger_levels == "DEBUG" or dont_post:
+ self.logger.debug("Successfully posted to " + entity_id + " = " + str(state))
  else:
- self.logger.info(
- "Successfully posted to " + entity_id + " = " + str(state)
- )
+ self.logger.info("Successfully posted to " + entity_id + " = " + str(state))

  # If save entities is set, save entity data to /data_path/entities
  if save_entities:
@@ -591,20 +1182,24 @@ class RetrieveHass:
  result = data_df.to_json(
  index="timestamp", orient="index", date_unit="s", date_format="iso"
  )
- parsed = json.loads(result)
- with open(entities_path / (entity_id + ".json"), "w") as file:
- json.dump(parsed, file, indent=4)
+ parsed = orjson.loads(result)
+ async with aiofiles.open(entities_path / (entity_id + ".json"), "w") as file:
+ await file.write(orjson.dumps(parsed, option=orjson.OPT_INDENT_2).decode())

  # Save the required metadata to json file
- if os.path.isfile(entities_path / "metadata.json"):
- with open(entities_path / "metadata.json", "r") as file:
- metadata = json.load(file)
+ metadata_path = entities_path / "metadata.json"
+ if os.path.isfile(metadata_path):
+ async with aiofiles.open(metadata_path) as file:
+ content = await file.read()
+ metadata = orjson.loads(content)
  else:
  metadata = {}
- with open(entities_path / "metadata.json", "w") as file:
+
+ async with aiofiles.open(metadata_path, "w") as file:
  # Save entity metadata, key = entity_id
  metadata[entity_id] = {
  "name": data_df.name,
+ "device_class": device_class,
  "unit_of_measurement": unit_of_measurement,
  "friendly_name": friendly_name,
  "type_var": type_var,
@@ -612,17 +1207,117 @@ class RetrieveHass:
  }

  # Find lowest frequency to set for continual loop freq
- if metadata.get("lowest_time_step", None) == None or metadata[
+ if metadata.get("lowest_time_step") is None or metadata[
  "lowest_time_step"
  ] > int(self.freq.seconds / 60):
  metadata["lowest_time_step"] = int(self.freq.seconds / 60)
- json.dump(metadata, file, indent=4)
+ await file.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2).decode())

  self.logger.debug("Saved " + entity_id + " to json file")

  else:
  self.logger.warning(
- "The status code for received curl command response is: "
- + str(response.status_code)
+ f"Failed to post data to {entity_id}. Status code: {response_status_code}"
  )
- return response, data
+
+ # Create a response object to maintain compatibility
+ class MockResponse:
+ def __init__(self, ok, status_code):
+ self.ok = ok
+ self.status_code = status_code
+
+ mock_response = MockResponse(response_ok, response_status_code)
+ self.logger.debug(f"Completed post_data for {entity_id}")
+ return mock_response, data
+
+ def _convert_statistics_to_dataframe(
+ self, stats_data: dict[str, Any], var_list: list[str]
+ ) -> pd.DataFrame:
+ """Convert WebSocket statistics data to DataFrame."""
+ import pandas as pd
+
+ # Initialize empty DataFrame
+ df_final = pd.DataFrame()
+
+ # The websocket manager already extracts the 'result' portion
+ # so stats_data should be directly the entity data dictionary
+
+ for entity_id in var_list:
+ if entity_id not in stats_data:
+ self.logger.warning(f"No statistics data for {entity_id}")
+ continue
+
+ entity_stats = stats_data[entity_id]
+
+ if not entity_stats:
+ continue
+
+ # Convert statistics to DataFrame
+ entity_data = []
+ for _i, stat in enumerate(entity_stats):
+ try:
+ # Handle timestamp from start time (milliseconds or ISO string)
+ if isinstance(stat["start"], int | float):
+ # Convert from milliseconds to datetime with UTC timezone
+ timestamp = pd.to_datetime(stat["start"], unit="ms", utc=True)
+ else:
+ # Assume ISO string
+ timestamp = pd.to_datetime(stat["start"], utc=True)
+
+ # Use mean, max, min or sum depending on what's available
+ value = None
+ if "mean" in stat and stat["mean"] is not None:
+ value = stat["mean"]
+ elif "sum" in stat and stat["sum"] is not None:
+ value = stat["sum"]
+ elif "max" in stat and stat["max"] is not None:
+ value = stat["max"]
+ elif "min" in stat and stat["min"] is not None:
+ value = stat["min"]
+
+ if value is not None:
+ try:
+ value = float(value)
+ entity_data.append({"timestamp": timestamp, entity_id: value})
+ except (ValueError, TypeError):
+ self.logger.debug(f"Could not convert value to float: {value}")
+
+ except (KeyError, ValueError, TypeError) as e:
+ self.logger.debug(f"Skipping invalid statistic for {entity_id}: {e}")
+ continue
+
+ if entity_data:
+ entity_df = pd.DataFrame(entity_data)
+ entity_df.set_index("timestamp", inplace=True)
+
+ if df_final.empty:
+ df_final = entity_df
+ else:
+ df_final = df_final.join(entity_df, how="outer")
+
+ # Process the final DataFrame
+ if not df_final.empty:
+ # Ensure timezone awareness - timestamps should already be UTC from conversion above
+ if df_final.index.tz is None:
+ # If somehow still naive, localize as UTC first then convert
+ df_final.index = df_final.index.tz_localize("UTC").tz_convert(self.time_zone)
+ else:
+ # Convert from existing timezone to target timezone
+ df_final.index = df_final.index.tz_convert(self.time_zone)
+
+ # Sort by index
+ df_final = df_final.sort_index()
+
+ # Resample to frequency if needed
+ try:
+ df_final = df_final.resample(self.freq).mean()
+ except Exception as e:
+ self.logger.warning(f"Could not resample data to {self.freq}: {e}")
+
+ # Forward fill missing values
+ df_final = df_final.ffill()
+
+ # Set frequency for the DataFrame index
+ df_final = set_df_index_freq(df_final)
+
+ return df_final
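
Taken together, the retrieval API in the new version is coroutine-based end to end. A usage sketch under assumed sensor names and arguments (none of these values come from the diff itself):

```python
import asyncio
import logging

import pandas as pd

from emhass.retrieve_hass import RetrieveHass


async def main() -> None:
    # All argument values here are illustrative assumptions, not defaults from the diff.
    rh = RetrieveHass(
        hass_url="http://localhost:8123/",
        long_lived_token="LONG_LIVED_TOKEN",
        freq=pd.Timedelta("30min"),
        time_zone=None,  # a tzinfo object in real use
        params="null",  # parsed to {} by __init__, so websocket/InfluxDB stay disabled
        emhass_conf={},
        logger=logging.getLogger("emhass"),
    )
    await rh.get_ha_config()  # async REST call under the hood
    days = pd.date_range(
        start=pd.Timestamp.now(tz="UTC").floor("D") - pd.Timedelta(days=1),
        periods=2,
        freq="D",
    )
    await rh.get_data(days, ["sensor.home_load", "sensor.home_pv"])
    rh.prepare_data(
        "sensor.home_load",
        load_negative=False,
        set_zero_min=True,
        var_replace_zero=["sensor.home_pv"],
        var_interp=["sensor.home_load"],
    )
    print(rh.df_final.head())


if __name__ == "__main__":
    asyncio.run(main())
```
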