emhass 0.10.6__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- emhass/command_line.py +178 -85
- emhass/data/associations.csv +61 -0
- emhass/data/config_defaults.json +117 -0
- emhass/forecast.py +38 -36
- emhass/machine_learning_forecaster.py +2 -1
- emhass/machine_learning_regressor.py +7 -2
- emhass/optimization.py +62 -62
- emhass/retrieve_hass.py +9 -4
- emhass/static/advanced.html +2 -1
- emhass/static/basic.html +4 -2
- emhass/static/configuration_list.html +44 -0
- emhass/static/configuration_script.js +871 -0
- emhass/static/data/param_definitions.json +424 -0
- emhass/static/script.js +345 -322
- emhass/static/style.css +267 -8
- emhass/templates/configuration.html +75 -0
- emhass/templates/index.html +15 -8
- emhass/utils.py +620 -303
- emhass/web_server.py +323 -213
- {emhass-0.10.6.dist-info → emhass-0.11.0.dist-info}/METADATA +207 -169
- emhass-0.11.0.dist-info/RECORD +32 -0
- {emhass-0.10.6.dist-info → emhass-0.11.0.dist-info}/WHEEL +1 -1
- emhass-0.10.6.dist-info/RECORD +0 -26
- {emhass-0.10.6.dist-info → emhass-0.11.0.dist-info}/LICENSE +0 -0
- {emhass-0.10.6.dist-info → emhass-0.11.0.dist-info}/entry_points.txt +0 -0
- {emhass-0.10.6.dist-info → emhass-0.11.0.dist-info}/top_level.txt +0 -0
emhass/data/config_defaults.json
ADDED
```diff
@@ -0,0 +1,117 @@
+{
+    "logging_level": "INFO",
+    "costfun": "profit",
+    "optimization_time_step": 30,
+    "historic_days_to_retrieve": 2,
+    "method_ts_round": "nearest",
+    "continual_publish": false,
+    "data_path": "default",
+    "set_total_pv_sell": false,
+    "lp_solver": "default",
+    "lp_solver_path": "empty",
+    "set_nocharge_from_grid": false,
+    "set_nodischarge_to_grid": true,
+    "set_battery_dynamic": false,
+    "battery_dynamic_max": 0.9,
+    "battery_dynamic_min": -0.9,
+    "weight_battery_discharge": 1.0,
+    "weight_battery_charge": 1.0,
+    "sensor_power_photovoltaics": "sensor.power_photovoltaics",
+    "sensor_power_load_no_var_loads": "sensor.power_load_no_var_loads",
+    "sensor_replace_zero": [
+        "sensor.power_photovoltaics",
+        "sensor.power_load_no_var_loads"
+    ],
+    "sensor_linear_interp": [
+        "sensor.power_photovoltaics",
+        "sensor.power_load_no_var_loads"
+    ],
+    "load_negative": false,
+    "set_zero_min": true,
+    "number_of_deferrable_loads": 2,
+    "nominal_power_of_deferrable_loads": [
+        3000.0,
+        750.0
+    ],
+    "operating_hours_of_each_deferrable_load": [
+        4,
+        0
+    ],
+    "weather_forecast_method": "scrapper",
+    "load_forecast_method": "naive",
+    "delta_forecast_daily": 1,
+    "load_cost_forecast_method": "hp_hc_periods",
+    "start_timesteps_of_each_deferrable_load": [
+        0,
+        0
+    ],
+    "end_timesteps_of_each_deferrable_load": [
+        0,
+        0
+    ],
+    "load_peak_hour_periods": {
+        "period_hp_1": [
+            {
+                "start": "02:54"
+            },
+            {
+                "end": "15:24"
+            }
+        ],
+        "period_hp_2": [
+            {
+                "start": "17:24"
+            },
+            {
+                "end": "20:24"
+            }
+        ]
+    },
+    "treat_deferrable_load_as_semi_cont": [
+        true,
+        true
+    ],
+    "set_deferrable_load_single_constant": [
+        false,
+        false
+    ],
+    "set_deferrable_startup_penalty": [
+        0.0,
+        0.0
+    ],
+    "load_peak_hours_cost": 0.1907,
+    "load_offpeak_hours_cost": 0.1419,
+    "production_price_forecast_method": "constant",
+    "photovoltaic_production_sell_price": 0.1419,
+    "maximum_power_from_grid": 9000,
+    "maximum_power_to_grid": 9000,
+    "pv_module_model": [
+        "CSUN_Eurasia_Energy_Systems_Industry_and_Trade_CSUN295_60M"
+    ],
+    "pv_inverter_model": [
+        "Fronius_International_GmbH__Fronius_Primo_5_0_1_208_240__240V_"
+    ],
+    "surface_tilt": [
+        30
+    ],
+    "surface_azimuth": [
+        205
+    ],
+    "modules_per_string": [
+        16
+    ],
+    "strings_per_inverter": [
+        1
+    ],
+    "inverter_is_hybrid": false,
+    "compute_curtailment": false,
+    "set_use_battery": false,
+    "battery_discharge_power_max": 1000,
+    "battery_charge_power_max": 1000,
+    "battery_discharge_efficiency": 0.95,
+    "battery_charge_efficiency": 0.95,
+    "battery_nominal_energy_capacity": 5000,
+    "battery_minimum_state_of_charge": 0.3,
+    "battery_maximum_state_of_charge": 0.9,
+    "battery_target_state_of_charge": 0.6
+}
```
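The new config_defaults.json ships a working default for every parameter, so a user configuration only needs to state the values it overrides. A minimal sketch of that layering, assuming the packaged file path shown above (the merge helper is illustrative; emhass's actual loading logic lives in emhass/utils.py):

```python
import json
from pathlib import Path

# Load the packaged defaults, then overlay user-supplied values on top.
defaults = json.loads(Path("emhass/data/config_defaults.json").read_text())
user_config = {"optimization_time_step": 60, "set_use_battery": True}

config = {**defaults, **user_config}
assert config["costfun"] == "profit"            # untouched default
assert config["optimization_time_step"] == 60   # user override
```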
emhass/forecast.py
CHANGED
```diff
@@ -132,22 +132,24 @@ class Forecast(object):
         self.retrieve_hass_conf = retrieve_hass_conf
         self.optim_conf = optim_conf
         self.plant_conf = plant_conf
-        self.freq = self.retrieve_hass_conf['
+        self.freq = self.retrieve_hass_conf['optimization_time_step']
         self.time_zone = self.retrieve_hass_conf['time_zone']
         self.method_ts_round = self.retrieve_hass_conf['method_ts_round']
         self.timeStep = self.freq.seconds/3600 # in hours
         self.time_delta = pd.to_timedelta(opt_time_delta, "hours")
-        self.var_PV = self.retrieve_hass_conf['
-        self.var_load = self.retrieve_hass_conf['
+        self.var_PV = self.retrieve_hass_conf['sensor_power_photovoltaics']
+        self.var_load = self.retrieve_hass_conf['sensor_power_load_no_var_loads']
         self.var_load_new = self.var_load+'_positive'
-        self.lat = self.retrieve_hass_conf['
-        self.lon = self.retrieve_hass_conf['
+        self.lat = self.retrieve_hass_conf['Latitude']
+        self.lon = self.retrieve_hass_conf['Longitude']
         self.emhass_conf = emhass_conf
         self.logger = logger
         self.get_data_from_file = get_data_from_file
         self.var_load_cost = 'unit_load_cost'
         self.var_prod_price = 'unit_prod_price'
-        if params
+        if (params == None) or (params == "null"):
+            self.params = {}
+        elif type(params) is dict:
             self.params = params
         else:
             self.params = json.loads(params)
@@ -159,10 +161,10 @@ class Forecast(object):
             self.start_forecast = pd.Timestamp(datetime.now(), tz=self.time_zone).replace(microsecond=0).ceil(freq=self.freq)
         else:
             self.logger.error("Wrong method_ts_round passed parameter")
-        self.end_forecast = (self.start_forecast + self.optim_conf['
+        self.end_forecast = (self.start_forecast + self.optim_conf['delta_forecast_daily']).replace(microsecond=0)
         self.forecast_dates = pd.date_range(start=self.start_forecast,
                                             end=self.end_forecast-self.freq,
-                                            freq=self.freq).round(self.freq, ambiguous='infer', nonexistent='shift_forward')
+                                            freq=self.freq, tz=self.time_zone).tz_convert('utc').round(self.freq, ambiguous='infer', nonexistent='shift_forward').tz_convert(self.time_zone)
         if params is not None:
             if 'prediction_horizon' in list(self.params['passed_data'].keys()):
                 if self.params['passed_data']['prediction_horizon'] is not None:
@@ -190,7 +192,7 @@ class Forecast(object):
             freq_scrap = pd.to_timedelta(60, "minutes") # The scrapping time step is 60min on clearoutside
             forecast_dates_scrap = pd.date_range(start=self.start_forecast,
                                                  end=self.end_forecast-freq_scrap,
-                                                 freq=freq_scrap).round(freq_scrap, ambiguous='infer', nonexistent='shift_forward')
+                                                 freq=freq_scrap, tz=self.time_zone).tz_convert('utc').round(freq_scrap, ambiguous='infer', nonexistent='shift_forward').tz_convert(self.time_zone)
             # Using the clearoutside webpage
             response = get("https://clearoutside.com/forecast/"+str(round(self.lat, 2))+"/"+str(round(self.lon, 2))+"?desktop=true")
             '''import bz2 # Uncomment to save a serialized data for tests
@@ -226,9 +228,9 @@ class Forecast(object):
                 data['temp_air'], data['relative_humidity'])
         elif method == 'solcast': # using Solcast API
             # Check if weather_forecast_cache is true or if forecast_data file does not exist
-            if
+            if not os.path.isfile(w_forecast_cache_path):
                 # Check if weather_forecast_cache_only is true, if so produce error for not finding cache file
-                if not self.params["passed_data"]
+                if not self.params["passed_data"].get("weather_forecast_cache_only",False):
                     # Retrieve data from the Solcast API
                     if 'solcast_api_key' not in self.retrieve_hass_conf:
                         self.logger.error("The solcast_api_key parameter was not defined")
@@ -243,7 +245,7 @@ class Forecast(object):
                     }
                    days_solcast = int(len(self.forecast_dates)*self.freq.seconds/3600)
                    # If weather_forecast_cache, set request days as twice as long to avoid length issues (add a buffer)
-                    if self.params["passed_data"]
+                    if self.params["passed_data"].get("weather_forecast_cache",False):
                        days_solcast = min((days_solcast * 2), 336)
                    url = "https://api.solcast.com.au/rooftop_sites/"+self.retrieve_hass_conf['solcast_rooftop_id']+"/forecasts?hours="+str(days_solcast)
                    response = get(url, headers=headers)
@@ -269,7 +271,7 @@ class Forecast(object):
                        self.logger.error("Not enough data retried from Solcast service, try increasing the time step or use MPC.")
                    else:
                        # If runtime weather_forecast_cache is true save forecast result to file as cache
-                        if self.params["passed_data"]
+                        if self.params["passed_data"].get("weather_forecast_cache",False):
                            # Add x2 forecast periods for cached results. This adds a extra delta_forecast amount of days for a buffer
                            cached_forecast_dates = self.forecast_dates.union(pd.date_range(self.forecast_dates[-1], periods=(len(self.forecast_dates) +1), freq=self.freq)[1:])
                            cache_data_list = data_list[0:len(cached_forecast_dates)]
@@ -289,11 +291,11 @@ class Forecast(object):
                    data = pd.DataFrame.from_dict(data_dict)
                    # Define index
                    data.set_index('ts', inplace=True)
-
+                # Else, notify user to update cache
                else:
                    self.logger.error("Unable to obtain Solcast cache file.")
                    self.logger.error("Try running optimization again with 'weather_forecast_cache_only': false")
-                    self.logger.error("Optionally, obtain new Solcast cache with runtime parameter 'weather_forecast_cache': true in an optimization, or run the `forecast-cache` action, to pull new data from Solcast and cache.")
+                    self.logger.error("Optionally, obtain new Solcast cache with runtime parameter 'weather_forecast_cache': true in an optimization, or run the `weather-forecast-cache` action, to pull new data from Solcast and cache.")
                    return False
            # Else, open stored weather_forecast_data.pkl file for previous forecast data (cached data)
            else:
@@ -301,7 +303,7 @@ class Forecast(object):
                    data = cPickle.load(file)
                if not isinstance(data, pd.DataFrame) or len(data) < len(self.forecast_dates):
                    self.logger.error("There has been a error obtaining cached Solcast forecast data.")
-                    self.logger.error("Try running optimization again with 'weather_forecast_cache': true, or run action `forecast-cache`, to pull new data from Solcast and cache.")
+                    self.logger.error("Try running optimization again with 'weather_forecast_cache': true, or run action `weather-forecast-cache`, to pull new data from Solcast and cache.")
                    self.logger.warning("Removing old Solcast cache file. Next optimization will pull data from Solcast, unless 'weather_forecast_cache_only': true")
                    os.remove(w_forecast_cache_path)
                    return False
@@ -323,17 +325,17 @@ class Forecast(object):
            if self.retrieve_hass_conf['solar_forecast_kwp'] == 0:
                self.logger.warning("The solar_forecast_kwp parameter is set to zero, setting to default 5")
                self.retrieve_hass_conf['solar_forecast_kwp'] = 5
-            if self.optim_conf['
+            if self.optim_conf['delta_forecast_daily'].days > 1:
                self.logger.warning("The free public tier for solar.forecast only provides one day forecasts")
                self.logger.warning("Continuing with just the first day of data, the other days are filled with 0.0.")
-                self.logger.warning("Use the other available methods for
+                self.logger.warning("Use the other available methods for delta_forecast_daily > 1")
            headers = {
                "Accept": "application/json"
            }
            data = pd.DataFrame()
-            for i in range(len(self.plant_conf['
+            for i in range(len(self.plant_conf['pv_module_model'])):
                url = "https://api.forecast.solar/estimate/"+str(round(self.lat, 2))+"/"+str(round(self.lon, 2))+\
-                    "/"+str(self.plant_conf[
+                    "/"+str(self.plant_conf['surface_tilt'][i])+"/"+str(self.plant_conf['surface_azimuth'][i]-180)+\
                    "/"+str(self.retrieve_hass_conf["solar_forecast_kwp"])
                response = get(url, headers=headers)
                '''import bz2 # Uncomment to save a serialized data for tests
@@ -485,12 +487,12 @@ class Forecast(object):
        cec_modules = cPickle.load(cec_modules)
        cec_inverters = bz2.BZ2File(self.emhass_conf['root_path'] / 'data' / 'cec_inverters.pbz2', "rb")
        cec_inverters = cPickle.load(cec_inverters)
-        if type(self.plant_conf['
+        if type(self.plant_conf['pv_module_model']) == list:
            P_PV_forecast = pd.Series(0, index=df_weather.index)
-            for i in range(len(self.plant_conf['
+            for i in range(len(self.plant_conf['pv_module_model'])):
                # Selecting correct module and inverter
-                module = cec_modules[self.plant_conf['
-                inverter = cec_inverters[self.plant_conf['
+                module = cec_modules[self.plant_conf['pv_module_model'][i]]
+                inverter = cec_inverters[self.plant_conf['pv_inverter_model'][i]]
                # Building the PV system in PVLib
                system = PVSystem(surface_tilt=self.plant_conf['surface_tilt'][i],
                                  surface_azimuth=self.plant_conf['surface_azimuth'][i],
@@ -506,8 +508,8 @@ class Forecast(object):
                P_PV_forecast = P_PV_forecast + mc.results.ac
        else:
            # Selecting correct module and inverter
-            module = cec_modules[self.plant_conf['
-            inverter = cec_inverters[self.plant_conf['
+            module = cec_modules[self.plant_conf['pv_module_model']]
+            inverter = cec_inverters[self.plant_conf['pv_inverter_model']]
            # Building the PV system in PVLib
            system = PVSystem(surface_tilt=self.plant_conf['surface_tilt'],
                              surface_azimuth=self.plant_conf['surface_azimuth'],
@@ -544,10 +546,10 @@ class Forecast(object):
            start_forecast_csv = pd.Timestamp(datetime.now(), tz=self.time_zone).replace(microsecond=0).ceil(freq=self.freq)
        else:
            self.logger.error("Wrong method_ts_round passed parameter")
-        end_forecast_csv = (start_forecast_csv + self.optim_conf['
+        end_forecast_csv = (start_forecast_csv + self.optim_conf['delta_forecast_daily']).replace(microsecond=0)
        forecast_dates_csv = pd.date_range(start=start_forecast_csv,
                                           end=end_forecast_csv+timedelta(days=timedelta_days)-self.freq,
-                                           freq=self.freq).round(self.freq, ambiguous='infer', nonexistent='shift_forward')
+                                           freq=self.freq, tz=self.time_zone).tz_convert('utc').round(self.freq, ambiguous='infer', nonexistent='shift_forward').tz_convert(self.time_zone)
        if self.params is not None:
            if 'prediction_horizon' in list(self.params['passed_data'].keys()):
                if self.params['passed_data']['prediction_horizon'] is not None:
@@ -561,7 +563,7 @@ class Forecast(object):
        Get the forecast data as a DataFrame from a CSV file.

        The data contained in the CSV file should be a 24h forecast with the same frequency as
-        the main '
+        the main 'optimization_time_step' parameter in the configuration file. The timestamp will not be used and
        a new DateTimeIndex is generated to fit the timestamp index of the input data in 'df_final'.

        :param df_final: The DataFrame containing the input data.
@@ -695,7 +697,7 @@ class Forecast(object):
        with open(filename_path, 'rb') as inp:
            rh.df_final, days_list, var_list = pickle.load(inp)
            self.var_load = var_list[0]
-            self.retrieve_hass_conf['
+            self.retrieve_hass_conf['sensor_power_load_no_var_loads'] = self.var_load
            var_interp = [var_list[0]]
            self.var_list = [var_list[0]]
            self.var_load_new = self.var_load+'_positive'
@@ -704,13 +706,13 @@ class Forecast(object):
            if not rh.get_data(days_list, var_list):
                return False
            if not rh.prepare_data(
-                self.retrieve_hass_conf['
+                self.retrieve_hass_conf['sensor_power_load_no_var_loads'], load_negative = self.retrieve_hass_conf['load_negative'],
                set_zero_min = self.retrieve_hass_conf['set_zero_min'],
                var_replace_zero = var_replace_zero, var_interp = var_interp):
                return False
        df = rh.df_final.copy()[[self.var_load_new]]
        if method == 'naive': # using a naive approach
-            mask_forecast_out = (df.index > days_list[-1] - self.optim_conf['
+            mask_forecast_out = (df.index > days_list[-1] - self.optim_conf['delta_forecast_daily'])
            forecast_out = df.copy().loc[mask_forecast_out]
            forecast_out = forecast_out.rename(columns={self.var_load_new: 'yhat'})
            # Force forecast_out length to avoid mismatches
@@ -812,13 +814,13 @@ class Forecast(object):
        """
        csv_path = self.emhass_conf['data_path'] / csv_path
        if method == 'hp_hc_periods':
-            df_final[self.var_load_cost] = self.optim_conf['
+            df_final[self.var_load_cost] = self.optim_conf['load_offpeak_hours_cost']
            list_df_hp = []
-            for key, period_hp in self.optim_conf['
+            for key, period_hp in self.optim_conf['load_peak_hour_periods'].items():
                list_df_hp.append(df_final[self.var_load_cost].between_time(
                    period_hp[0]['start'], period_hp[1]['end']))
            for df_hp in list_df_hp:
-                df_final.loc[df_hp.index, self.var_load_cost] = self.optim_conf['
+                df_final.loc[df_hp.index, self.var_load_cost] = self.optim_conf['load_peak_hours_cost']
        elif method == 'csv':
            forecast_dates_csv = self.get_forecast_days_csv(timedelta_days=0)
            forecast_out = self.get_forecast_out_from_csv_or_list(
```
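The hp_hc_periods hunk just above fills the unit load cost with the off-peak price, then overwrites the peak windows selected with between_time. A self-contained sketch of that fill, using the prices and windows shipped in config_defaults.json:

```python
import pandas as pd

# Peak/off-peak pricing fill as in the 'hp_hc_periods' branch; the prices and
# windows below are the shipped defaults from config_defaults.json.
idx = pd.date_range("2024-01-01", periods=48, freq="30min", tz="utc")
unit_load_cost = pd.Series(0.1419, index=idx)  # load_offpeak_hours_cost
peak_periods = {
    "period_hp_1": [{"start": "02:54"}, {"end": "15:24"}],
    "period_hp_2": [{"start": "17:24"}, {"end": "20:24"}],
}
for key, period_hp in peak_periods.items():
    hp = unit_load_cost.between_time(period_hp[0]["start"], period_hp[1]["end"])
    unit_load_cost.loc[hp.index] = 0.1907  # load_peak_hours_cost
```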
```diff
@@ -871,7 +873,7 @@ class Forecast(object):
        """
        csv_path = self.emhass_conf['data_path'] / csv_path
        if method == 'constant':
-            df_final[self.var_prod_price] = self.optim_conf['
+            df_final[self.var_prod_price] = self.optim_conf['photovoltaic_production_sell_price']
        elif method == 'csv':
            forecast_dates_csv = self.get_forecast_days_csv(timedelta_days=0)
            forecast_out = self.get_forecast_out_from_csv_or_list(
```
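A change repeated across forecast.py (the constructor, the clearoutside scraper, and get_forecast_days_csv) is that forecast-date ranges are now built timezone-aware, rounded in UTC, then converted back to local time. A standalone sketch of why this matters, with an illustrative DST-observing timezone and frequency:

```python
import pandas as pd

freq = pd.to_timedelta(30, "minutes")
tz = "Europe/Paris"  # illustrative DST-observing timezone
start = pd.Timestamp("2024-10-27 00:07", tz=tz).ceil(freq=freq)  # clocks fall back this day
end = start + pd.Timedelta(days=1)

# Rounding local times directly can raise AmbiguousTimeError/NonExistentTimeError
# around a DST transition; rounding in UTC is always unambiguous.
forecast_dates = (
    pd.date_range(start=start, end=end - freq, freq=freq, tz=tz)
    .tz_convert("utc")
    .round(freq, ambiguous="infer", nonexistent="shift_forward")
    .tz_convert(tz)
)
print(forecast_dates[0], forecast_dates[-1])
```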
emhass/machine_learning_forecaster.py
CHANGED
```diff
@@ -141,7 +141,8 @@ class MLForecaster:
        elif self.sklearn_model == 'KNeighborsRegressor':
            base_model = KNeighborsRegressor()
        else:
-            self.logger.error("Passed sklearn model "+self.sklearn_model+" is not valid")
+            self.logger.error("Passed sklearn model "+self.sklearn_model+" is not valid. Defaulting to KNeighborsRegressor")
+            base_model = KNeighborsRegressor()
        # Define the forecaster object
        self.forecaster = ForecasterAutoreg(
            regressor = base_model,
```
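machine_learning_forecaster.py now logs the invalid model name and falls back to KNeighborsRegressor instead of leaving base_model undefined. A generic sketch of the pattern (the registry and model set here are illustrative, not emhass's full list):

```python
import logging
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor

def select_base_model(name: str, logger: logging.Logger):
    # Map of accepted model names; fall back instead of leaving the model unset.
    models = {
        "LinearRegression": LinearRegression,
        "KNeighborsRegressor": KNeighborsRegressor,
    }
    if name not in models:
        logger.error("Passed sklearn model %s is not valid. Defaulting to KNeighborsRegressor", name)
        name = "KNeighborsRegressor"
    return models[name]()
```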
emhass/machine_learning_regressor.py
CHANGED
```diff
@@ -176,15 +176,17 @@ class MLRegressor:
                "Passed model %s is not valid",
                self.regression_model,
            )
-            return None
+            return None, None
        return base_model, param_grid

-    def fit(self: MLRegressor, date_features: list | None = None) ->
+    def fit(self: MLRegressor, date_features: list | None = None) -> bool:
        r"""Fit the model using the provided data.

        :param date_features: A list of 'date_features' to take into account when \
            fitting the model.
        :type data: list
+        :return: bool if successful
+        :rtype: bool
        """
        self.logger.info("Performing a MLRegressor fit for %s", self.model_type)
        self.data_exo = pd.DataFrame(self.data)
@@ -217,6 +219,8 @@ class MLRegressor:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        self.steps = len(X_test)
        base_model, param_grid = self.get_regression_model()
+        if base_model is None:
+            return False
        self.model = make_pipeline(StandardScaler(), base_model)
        # Create a grid search object
        self.grid_search = GridSearchCV(self.model, param_grid, cv=5, scoring="neg_mean_squared_error",
@@ -235,6 +239,7 @@ class MLRegressor:
            "Prediction R2 score of fitted model on test data: %s",
            pred_metric,
        )
+        return True

    def predict(self: MLRegressor, new_values: list) -> np.ndarray:
        """Predict a new value.
```
|