dsipts 1.1.5 (dsipts-1.1.5-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/__init__.py
ADDED
@@ -0,0 +1,48 @@
+# src/dispts/__init__.py
+
+from .data_management.monash import Monash, get_freq
+from .data_management.public_datasets import read_public_dataset
+from .data_structure.data_structure import TimeSeries, Categorical
+from .data_structure.utils import extend_time_df, beauty_string
+
+from .models.RNN import RNN
+from .models.LinearTS import LinearTS
+from .models.Persistent import Persistent
+from .models.D3VAE import D3VAE
+from .models.DilatedConv import DilatedConv
+from .models.TFT import TFT
+from .models.Informer import Informer
+from .models.VVA import VVA
+from .models.VQVAEA import VQVAEA
+from .models.CrossFormer import CrossFormer
+from .models.Autoformer import Autoformer
+from .models.PatchTST import PatchTST
+from .models.Diffusion import Diffusion
+from .models.DilatedConvED import DilatedConvED
+from .models.TIDE import TIDE
+from .models.ITransformer import ITransformer
+from .models.TimeXER import TimeXER
+from .models.TTM import TTM
+from .models.Samformer import Samformer
+from .models.Duet import Duet
+
+try:
+    import lightning.pytorch as pl
+    from .models.base_v2 import Base
+    OLD_PL = False
+except ImportError:
+    import pytorch_lightning as pl
+    from .models.base import Base
+    OLD_PL = True
+
+__all__ = [
+    # Data Management
+    "Monash", "get_freq", "read_public_dataset",
+    # Data Structure
+    "TimeSeries", "Categorical", "extend_time_df", "beauty_string",
+    # Models
+    "RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
+    "Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
+    "Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
+    "TTM", "Samformer", "Duet", "Base"
+]

dsipts/data_management/__init__.py
File without changes
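Editor's note: the try/except at the end of dsipts/__init__.py above is the package's compatibility shim. It prefers the newer lightning.pytorch API together with models.base_v2.Base and falls back to the legacy pytorch_lightning import and models.base.Base, recording the choice in OLD_PL. The snippet below is a minimal, hedged usage sketch and is not part of the diff or of the package's documented examples; it only relies on the OLD_PL flag and the __all__ list defined above.

# Hypothetical usage sketch: assumes the dsipts wheel is installed together
# with either lightning or pytorch_lightning.
import dsipts

# True when the legacy pytorch_lightning fallback was selected at import time.
print("legacy pytorch_lightning in use:", dsipts.OLD_PL)

# Every symbol listed in __all__ is importable from the top level,
# e.g. dsipts.RNN, dsipts.TFT, dsipts.PatchTST and the selected Base class.
print(sorted(dsipts.__all__))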
dsipts/data_management/monash.py
ADDED
@@ -0,0 +1,338 @@
+import pandas as pd
+import requests
+from bs4 import BeautifulSoup as bs
+import pickle
+import os
+import shutil
+from datetime import datetime
+from distutils.util import strtobool
+from typing import Union
+import logging
+
+# Converts the contents in a .tsf file into a dataframe and returns it along with other meta-data of the dataset: frequency, horizon, whether the dataset contains missing values and whether the series have equal lengths
+#
+# Parameters
+# full_file_path_and_name - complete .tsf file path
+# replace_missing_vals_with - a term to indicate the missing values in series in the returning dataframe
+# value_column_name - Any name that is preferred to have as the name of the column containing series values in the returning dataframe
+def convert_tsf_to_dataframe(
+    full_file_path_and_name:str,
+    replace_missing_vals_with:str="NaN",
+    value_column_name:str="series_value",
+)-> pd.DataFrame:
+    """I copied this function from the repo
+
+
+    Args:
+        full_file_path_and_name (str): path
+        replace_missing_vals_with (str, optional): value used to replace invalid numbers. Defaults to "NaN".
+        value_column_name (str, optional): Defaults to "series_value".
+
+    Raises:
+        Exception: see https://forecastingdata.org/ for more information
+
+
+    Returns:
+        pd.DataFrame: the selected timeseries
+    """
+    col_names = []
+    col_types = []
+    all_data = {}
+    line_count = 0
+    frequency = None
+    forecast_horizon = None
+    contain_missing_values = None
+    contain_equal_length = None
+    found_data_tag = False
+    found_data_section = False
+    started_reading_data_section = False
+
+    with open(full_file_path_and_name, "r", encoding="cp1252") as file:
+        for line in file:
+            # Strip white space from start/end of line
+            line = line.strip()
+
+            if line:
+                if line.startswith("@"):  # Read meta-data
+                    if not line.startswith("@data"):
+                        line_content = line.split(" ")
+                        if line.startswith("@attribute"):
+                            if (
+                                len(line_content) != 3
+                            ):  # Attributes have both name and type
+                                raise Exception("Invalid meta-data specification.")
+
+                            col_names.append(line_content[1])
+                            col_types.append(line_content[2])
+                        else:
+                            if (
+                                len(line_content) != 2
+                            ):  # Other meta-data have only values
+                                raise Exception("Invalid meta-data specification.")
+
+                            if line.startswith("@frequency"):
+                                frequency = line_content[1]
+                            elif line.startswith("@horizon"):
+                                forecast_horizon = int(line_content[1])
+                            elif line.startswith("@missing"):
+                                contain_missing_values = bool(
+                                    strtobool(line_content[1])
+                                )
+                            elif line.startswith("@equallength"):
+                                contain_equal_length = bool(strtobool(line_content[1]))
+
+                    else:
+                        if len(col_names) == 0:
+                            raise Exception(
+                                "Missing attribute section. Attribute section must come before data."
+                            )
+
+                        found_data_tag = True
+                elif not line.startswith("#"):
+                    if len(col_names) == 0:
+                        raise Exception(
+                            "Missing attribute section. Attribute section must come before data."
+                        )
+                    elif not found_data_tag:
+                        raise Exception("Missing @data tag.")
+                    else:
+                        if not started_reading_data_section:
+                            started_reading_data_section = True
+                            found_data_section = True
+                            all_series = []
+
+                            for col in col_names:
+                                all_data[col] = []
+
+                        full_info = line.split(":")
+
+                        if len(full_info) != (len(col_names) + 1):
+                            raise Exception("Missing attributes/values in series.")
+
+                        series = full_info[len(full_info) - 1]
+                        series = series.split(",")
+
+                        if len(series) == 0:
+                            raise Exception(
+                                "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
+                            )
+
+                        numeric_series = []
+
+                        for val in series:
+                            if val == "?":
+                                numeric_series.append(replace_missing_vals_with)
+                            else:
+                                numeric_series.append(float(val))
+
+                        if numeric_series.count(replace_missing_vals_with) == len(
+                            numeric_series
+                        ):
+                            raise Exception(
+                                "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
+                            )
+
+                        all_series.append(pd.Series(numeric_series).array)
+
+                        for i in range(len(col_names)):
+                            att_val = None
+                            if col_types[i] == "numeric":
+                                att_val = int(full_info[i])
+                            elif col_types[i] == "string":
+                                att_val = str(full_info[i])
+                            elif col_types[i] == "date":
+                                att_val = datetime.strptime(
+                                    full_info[i], "%Y-%m-%d %H-%M-%S"
+                                )
+                            else:
+                                raise Exception(
+                                    "Invalid attribute type."
+                                )  # Currently, the code supports only numeric, string and date types. Extend this as required.
+
+                            if att_val is None:
+                                raise Exception("Invalid attribute value.")
+                            else:
+                                all_data[col_names[i]].append(att_val)
+
+                line_count = line_count + 1
+
+        if line_count == 0:
+            raise Exception("Empty file.")
+        if len(col_names) == 0:
+            raise Exception("Missing attribute section.")
+        if not found_data_section:
+            raise Exception("Missing series information under data section.")
+
+        all_data[value_column_name] = all_series
+        loaded_data = pd.DataFrame(all_data)
+
+        return (
+            loaded_data,
+            frequency,
+            forecast_horizon,
+            contain_missing_values,
+            contain_equal_length,
+        )
+
+
+def get_freq(freq)->str:
+    """Get the frequency based on the string reported. I don't think there are all the possibilities here
+
+    Args:
+        freq (str): string coming from
+
+    Returns:
+        str: pandas frequency format
+    """
+    if freq =='10_minutes':
+        return '600s'
+    elif freq == 'hourly':
+        return 'H'
+    else:
+        return 'D'
+
+
+class Monash():
+
+    def __init__(self,filename:str,baseUrl:str ='https://forecastingdata.org/', rebuild:bool =False):
+        """Class for downloading datasets listed here https://forecastingdata.org/
+
+        Args:
+            filename (str): name of the class, used for saving
+            baseUrl (str, optional): url to the source page. Defaults to 'https://forecastingdata.org/'.
+            rebuild (bool, optional): if true, the table will be loaded from the webpage; otherwise it will be loaded from the saved file. Defaults to False.
+        """
+        self.baseUrl = baseUrl
+        self.downloaded = {}
+        if rebuild is False:
+            logging.info(filename)
+            if os.path.exists(filename+'.pkl'):
+                self.load(filename)
+            else:
+                self.get_table(baseUrl)
+                self.save(filename)
+        else:
+            self.get_table(baseUrl)
+            self.save(filename)
+
+    def get_table(self, baseUrl):
+        """ get table
+
+        :meta private:
+        """
+        with requests.Session() as s:
+            r = s.get(baseUrl)
+            soup = bs(r.content)
+            header = []
+            for x in soup.find("table", {"class": "responsive-table sortable"}).find('thead').find_all('th'):
+                header.append(x.text)
+            header
+
+            tot= []
+            for row in soup.find("table", {"class": "responsive-table sortable"}).find('tbody').find_all('tr'):
+                row_data = []
+                links = {}
+                for i,column in enumerate(row.find_all('td')):
+                    tmp_links = column.find_all('a')
+                    if len(tmp_links)>0:
+
+                        for x in tmp_links:
+                            if 'zenodo' in x['href']:
+                                links[x.text] = x['href']
+                                i_links = i
+                    row_data.append(column.text)
+                for dataset in links:
+                    row_to_insert = {}
+                    for j, head in enumerate(header):
+                        if j!=i_links:
+                            row_to_insert[header[j]] = row_data[j]
+                        else:
+                            row_to_insert['freq'] = dataset
+                            row_to_insert[header[j]] = links[dataset]
+                    tot.append(row_to_insert)
+
+        tot = pd.DataFrame(tot)
+        tot['id'] = tot.Download.apply(lambda x:int(x.split('/')[-1]))
+        self.table = tot.copy()
+
+    def save(self, filename:str)-> None:
+        """Save the Monash structure
+
+        Args:
+            filename (str): name of the file to generate
+        """
+        logging.info('Saving')
+        with open(f'{filename}.pkl','wb') as f:
+            params = self.__dict__.copy()
+            #for k in ['data','data_train','data_test','data_validation']:
+            #    if k in params.keys():
+            #        _ = params.pop(k)
+            pickle.dump(params,f)
+    def load(self, filename:str)-> None:
+        """Load a Monash structure
+
+        Args:
+            filename (str): filename to load
+        """
+        logging.info('Loading')
+        with open(filename+'.pkl','rb') as f:
+            params = pickle.load(f)
+        for p in params:
+            setattr(self,p, params[p])
+
+    def download_dataset(self,path: str,id:int ,rebuild=False)->None:
+        """download a specific dataset
+
+        Args:
+            path (str): path in which to save the data
+            id (int): id of the dataset
+            rebuild (bool, optional): if true the dataset will be re-downloaded. Defaults to False.
+        """
+
+        if os.path.exists(path):
+            pass
+        else:
+            os.mkdir(path)
+        if os.path.exists(os.path.join(path,str(id))):
+            if rebuild:
+                file = self._download(url = self.table.Download[self.table.id== id].values[0], path = os.path.join(path,str(id)))
+                self.downloaded[id] = f'{path}/{id}/{file}'
+            else:
+                pass
+        else:
+            file = self._download(url = self.table.Download[self.table.id== id].values[0] , path = os.path.join(path,str(id)))
+            self.downloaded[id] = f'{path}/{id}/{file}'
+
+    def _download(self,url, path)->str:
+        """ get data
+
+        :meta private:
+        """
+        with requests.Session() as s:
+            r = s.get(url)
+            soup = bs(r.content)
+
+            url = soup.find("link", {"type": "application/zip"})['href']
+            logging.info(url)
+            with open(path+'.zip', "wb") as f:
+                f.write(s.get(url).content)
+
+        shutil.unpack_archive(path+'.zip', path)
+        os.remove(path+'.zip')
+        return os.listdir(path)[0]
+
+    def generate_dataset(self, id:int)-> Union[None, pd.DataFrame]:
+        """Parse the id-th dataset in a convenient format and return a pandas dataset
+
+        Args:
+            id (int): id of the dataset
+
+        Returns:
+            None or pd.DataFrame: dataframe
+        """
+
+        if id not in self.downloaded.keys():
+            logging.error('please call first download dataset')
+            return None
+        else:
+            return convert_tsf_to_dataframe(self.downloaded[id])
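Editor's note: putting the pieces above together, Monash scrapes the dataset table at forecastingdata.org (or reloads it from <filename>.pkl), download_dataset fetches and unzips a Zenodo archive under path/<id>, and generate_dataset parses the extracted .tsf file via convert_tsf_to_dataframe. The snippet below is a hedged usage sketch, not part of the diff: the filename, local paths and the chosen record id are placeholders, and note that generate_dataset actually returns the full five-element tuple produced by convert_tsf_to_dataframe, despite its Union[None, pd.DataFrame] annotation.

# Hypothetical usage sketch: paths and the record id are placeholders, and
# network access to forecastingdata.org / zenodo.org is required.
from dsipts import Monash, get_freq

monash = Monash(filename='monash_index', rebuild=False)    # scrape or reload the dataset table
print(monash.table[['freq', 'id', 'Download']].head())     # columns built by get_table above

record_id = int(monash.table.id.iloc[0])                   # pick one Zenodo record id from the table
monash.download_dataset(path='monash_data', id=record_id)  # unzips the .tsf under monash_data/<id>

df, freq, horizon, has_missing, equal_length = monash.generate_dataset(record_id)
print(df.shape, freq, get_freq(freq))                      # get_freq maps e.g. 'hourly' -> 'H'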
dsipts/data_management/public_datasets.py
ADDED
@@ -0,0 +1,162 @@
+import pandas as pd
+import os
+import numpy as np
+from typing import List, Tuple
+import logging
+import requests
+from bs4 import BeautifulSoup as bs
+
+def build_venice(path:str,url='https://www.comune.venezia.it/it/content/archivio-storico-livello-marea-venezia-1')->None:
+
+
+    with requests.Session() as s:
+        r = s.get(url)
+        soup = bs(r.content)
+
+    print('CARE THE STRUCTURE OF THE SITE CAN BE CHANGED')
+
+    def cast_string(x):
+        if np.isfinite(x) is False:
+            return x
+        if x<10:
+            return f'0{int(x)}:00'
+        else:
+            return f'{int(x)}:00'
+
+    def cast_month(x):
+        try:
+            return x.replace('gen','01').replace('feb','02').replace('mar','03').replace('apr','04').replace('mag','05').replace('giu','06').replace('lug','07').replace('ago','08').replace('set','09').replace('ott','10').replace('nov','11').replace('dic','12')
+        except:
+            return x
+
+    def remove_float(table,column):
+        if table[column].dtype in [int,float]:
+            table[column] = table[column].apply(lambda x:cast_string(x))
+        else:
+            pass
+
+    def remove_str(table,column):
+        table[column] = table[column].apply(lambda x:cast_month(x))
+
+    def normalize(table):
+        columns = table.columns
+        table = table[~table.isna().all(axis=1)]
+        if 'Data_ora(solare)' in columns:
+            table['time'] = table['Data_ora(solare)']
+
+        elif 'GIORNO' in columns and 'ORA solare' in columns:
+            remove_float(table,'ORA solare')
+            table['time'] = table['GIORNO'] +' '+ table['ORA solare']
+
+        elif 'data' in columns and 'ora solare' in columns:
+            remove_float(table,'ora solare')
+            table['time'] =table['data'] +' '+ table['ora solare']
+
+        elif 'Data' in columns and 'Ora solare' in columns:
+            remove_str(table,'Data')
+            remove_float(table,'Ora solare')
+            table['time'] = table['Data'] +' '+ table['Ora solare']
+        elif 'GIORNO' in columns and 'ORA' in columns:
+            remove_str(table,'GIORNO')
+            remove_float(table,'ORA')
+            table['time'] = table['GIORNO'] +' '+ table['ORA']
+
+        else:
+            import pdb
+            pdb.set_trace()
+
+        for c in columns:
+            if 'Salute' in c:
+                table['y'] = table[c].values
+                if 'cm' in c:
+                    table['y']/=100
+        res = table[['time','y']].dropna()
+        try:
+            res['time'] = pd.to_datetime(res['time'],format='mixed')
+        except:
+            import pdb
+            pdb.set_trace()
+        return res
+    tot= []
+    for row in soup.find_all("table")[1].find('tbody').find_all('tr'):
+        for i,column in enumerate(row.find_all('td')):
+            tmp_links = column.find_all('a')
+            if len(tmp_links)>0:
+                for x in tmp_links:
+                    if 'orari' in x['href']:
+                        tmp = pd.read_csv('https://www.comune.venezia.it/'+x['href'],sep=';', parse_dates=True)
+                        tot.append(normalize(tmp))
+
+
+    res = pd.concat(tot)
+    res.sort_values(by='time',inplace=True)
+    res.to_csv(f'{path}/venice.csv',index=False)
+
+
+
+def read_public_dataset(path:str,dataset:str)->Tuple[pd.DataFrame,List[str]]:
+    """
+    Returns the public dataset chosen. Please download the dataset from here https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy or ask agobbi@fbk.eu.
+
+
+    Args:
+        path (str): path to data
+        dataset (str): dataset (one of 'electricity','etth1','etth2','ettm1','ettm2','exchange_rate','illness','traffic','weather')
+
+    Returns:
+        Tuple[pd.DataFrame,List[str]]: the dataframe, in which the target variable is *y* and the time index is *time*, and the list of the covariates
+    """
+
+
+
+    if os.path.isdir(path):
+        pass
+    else:
+        logging.info('I will try to create the folder')
+        os.mkdir(path)
+
+    files = os.listdir(path)
+    if 'all_six_datasets' in files:
+        pass
+    else:
+        logging.error('Please dowload the zip file form here and unzip it https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy')
+        return None,None
+
+
+    if dataset not in ['electricity','etth1','etth2','ettm1','ettm2','exchange_rate','illness','traffic','weather','venice']:
+        logging.error(f"Dataset {dataset} not available, use one among ['electricity','etth1','etth2','ettm1','ettm2','exchange_rate','illness','traffic','weather','venice']")
+        return None,None
+
+    if dataset=='electricity':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/electricity/electricity.csv'),sep=',',na_values=-9999)
+    elif dataset=='etth1':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/ETT-small/ETTh1.csv'),sep=',',na_values=-9999)
+    elif dataset=='etth2':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/ETT-small/ETTh2.csv'),sep=',',na_values=-9999)
+    elif dataset=='ettm1':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/ETT-small/ETTm1.csv'),sep=',',na_values=-9999)
+    elif dataset=='ettm2':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/ETT-small/ETTm2.csv'),sep=',',na_values=-9999)
+    elif dataset=='exchange_rate':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/exchange_rate/exchange_rate.csv'),sep=',',na_values=-9999)
+    elif dataset=='illness':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/illness/national_illness.csv'),sep=',',na_values=-9999)
+    elif dataset=='traffic':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/traffic/traffic.csv'),sep=',',na_values=-9999)
+    elif dataset=='weather':
+        dataset = pd.read_csv(os.path.join(path,'all_six_datasets/weather/weather.csv'),sep=',',na_values=-9999)
+    elif dataset=='venice':
+        if os.path.isfile(os.path.join(path,'venice.csv')):
+            dataset = pd.read_csv(os.path.join(path,'venice.csv'))
+        else:
+            logging.info('I WILL TRY TO DOWNLOAD IT, if there are errors please have a look to `build_venice` function')
+            build_venice(path,url='https://www.comune.venezia.it/it/content/archivio-storico-livello-marea-venezia-1')
+            dataset = pd.read_csv(os.path.join(path,'venice.csv'))
+    else:
+        logging.error(f'Dataset {dataset} not found')
+        return None, None
+    dataset.rename(columns={'date':'time','OT':'y'},inplace=True)
+    dataset.time = pd.to_datetime(dataset.time)
+    logging.info(f'Dataset loaded with shape {dataset.shape}')
+
+    return dataset, list(set(dataset.columns).difference(set(['time','y'])))

dsipts/data_structure/__init__.py
File without changes
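Editor's note: read_public_dataset is a thin loader for the standard long-horizon benchmarks. It expects the unzipped all_six_datasets folder under path (or builds venice.csv on the fly via build_venice), renames date to time and OT to y, and returns the frame together with the remaining covariate columns. The snippet below is a minimal sketch, not part of the diff; './data' is a placeholder folder assumed to already contain the extracted archive from the linked Google Drive folder.

# Hypothetical usage sketch: './data' must already contain the unzipped
# 'all_six_datasets' directory referenced in the docstring above.
from dsipts import read_public_dataset

data, covariates = read_public_dataset(path='./data', dataset='etth1')
if data is None:
    print('archive not found, see the error logged by read_public_dataset')
else:
    print(data[['time', 'y']].head())  # 'time' is the parsed datetime, 'y' the renamed OT target
    print(covariates)                  # every remaining column is treated as a covariate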