forex_data_aggregator 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- forex_data/__init__.py +92 -0
- forex_data/config/__init__.py +20 -0
- forex_data/config/config_file.py +89 -0
- forex_data/data_management/__init__.py +84 -0
- forex_data/data_management/common.py +1773 -0
- forex_data/data_management/database.py +1322 -0
- forex_data/data_management/historicaldata.py +1262 -0
- forex_data/data_management/realtimedata.py +993 -0
- forex_data_aggregator-0.1.2.dist-info/LICENSE +21 -0
- forex_data_aggregator-0.1.2.dist-info/METADATA +562 -0
- forex_data_aggregator-0.1.2.dist-info/RECORD +12 -0
- forex_data_aggregator-0.1.2.dist-info/WHEEL +4 -0
@@ -0,0 +1,1322 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 23 00:02:36 2025

@author: fiora
"""

'''
Module to connect to a database instance

Design constraint:

start with only support for polars, prefer lazyframe when possible

read and write using polars dataframe or lazyframe
exec requests using SQL query language
OSS versions for windows required
'''


from loguru import logger
from pathlib import Path
from attrs import (
    define,
    field,
    validate,
    validators
)
from re import (
    fullmatch,
    search,
    IGNORECASE
)
from collections import OrderedDict
from numpy import array
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from pathlib import Path as PathType
from datetime import datetime

# Import polars types directly
from polars import (
    DataFrame as polars_dataframe,
    LazyFrame as polars_lazyframe,
    read_database
)

# Import from adbc_driver_sqlite
from adbc_driver_sqlite import dbapi as sqlite_dbapi

# Import from sqlalchemy for text queries
try:
    from sqlalchemy import text
except ImportError:
    # Fallback if sqlalchemy not available
    def text(s: str) -> str:
        return s

# Import from common module - explicit imports for items not in __all__
from .common import (
    TICK_TIMEFRAME,
    BASE_DATA_COLUMN_NAME,
    DATA_KEY,
    DATA_TYPE,
    SUPPORTED_DATA_FILES,
    SUPPORTED_DATA_ENGINES,
    SUPPORTED_BASE_DATA_COLUMN_NAME,
    SUPPORTED_SQL_COMPARISON_OPERATORS,
    get_attrs_names,
    validator_dir_path,
    is_empty_dataframe,
    list_remove_duplicates,
    POLARS_DTYPE_DICT
)

# Import remaining items via star import
from .common import *

# Import from config module
from ..config import (
    read_config_file,
    read_config_string,
    read_config_folder
)


'''
BASE CONNECTOR
'''


@define(kw_only=True, slots=True)
class DatabaseConnector:

    data_folder: str = field(default='',
                             validator=validators.instance_of(str))

    def __init__(self, **kwargs: Any) -> None:

        pass

    def __attrs_post_init__(self) -> None:

        # create data folder if not exists
        if (
            not Path(self.data_folder).exists()
            or
            not Path(self.data_folder).is_dir()
        ):

            Path(self.data_folder).mkdir(parents=True,
                                         exist_ok=True)

    def connect(self) -> Any:
        """Connect to database - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement connect")

    def check_connection(self) -> bool:
        """Check database connection - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement check_connection")

    def write_data(self, target_table: str, dataframe: Union[polars_dataframe, polars_lazyframe], clean: bool = False) -> None:
        """Write data to database - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement write_data")

    def read_data(self, market: str, ticker: str, timeframe: str, start: datetime, end: datetime) -> polars_lazyframe:
        """Read data from database - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement read_data")

    def exec_sql(self) -> None:
        """Execute SQL - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement exec_sql")

    def _db_key(self, market: str, ticker: str, timeframe: str) -> str:
        """Generate database key - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement _db_key")

    def get_tickers_list(self) -> List[str]:
        """Get list of tickers - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement get_tickers_list")

    def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:
        """Get ticker keys - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement get_ticker_keys")

    def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:
        """Get years list for ticker - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement get_ticker_years_list")

    def clear_database(self, filter: Optional[str] = None) -> None:
        """Clear database - must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement clear_database")

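
# A minimal usage sketch of the interface above (illustrative only; the
# keyword names are inferred from the attrs fields of the concrete
# connectors defined below and may differ from the published API):
#
#     from datetime import datetime
#
#     db = DuckDBConnector(data_folder='./data',
#                          duckdb_filepath='./data/forex.db')
#     lazy = db.read_data(market='forex', ticker='EURUSD', timeframe='1h',
#                         start=datetime(2024, 1, 1), end=datetime(2024, 2, 1))
#     candles = lazy.collect()

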
'''
DUCKDB CONNECTOR:

TABLE TEMPLATE:
<trading field (e.g. forex, stocks)>_ticker_timeframe

'''


@define(kw_only=True, slots=True)
class DuckDBConnector(DatabaseConnector):

    _duckdb_filepath = field(default='',
                             validator=validators.instance_of(str))

    def __init__(self, **kwargs: Any) -> None:

        _class_attributes_name = get_attrs_names(self, **kwargs)
        _not_assigned_attrs_index_mask = [True] * len(_class_attributes_name)

        if 'config' in kwargs.keys():

            if kwargs['config']:

                config_path = Path(kwargs['config'])

                if (
                    config_path.exists() and
                    config_path.is_dir()
                ):

                    config_filepath = read_config_folder(config_path,
                                                         file_pattern='_config.yaml')

                else:

                    config_filepath = Path()

                config_args = {}
                if config_filepath.exists() \
                        and \
                        config_filepath.is_file() \
                        and \
                        config_filepath.suffix == '.yaml':

                    # read parameters from config file
                    # and force keys to lower case
                    config_args = {key.lower(): val for key, val in
                                   read_config_file(str(config_filepath)).items()}

                elif isinstance(kwargs['config'], str):

                    # read parameters from config file
                    # and force keys to lower case
                    config_args = {key.lower(): val for key, val in
                                   read_config_string(kwargs['config']).items()}

                else:

                    logger.critical('invalid config type '
                                    f'{kwargs["config"]}: '
                                    'required str or Path, got '
                                    f'{type(kwargs["config"])}')
                    raise TypeError

                # check consistency of config_args
                if (
                    not isinstance(config_args, dict) or
                    not bool(config_args)
                ):

                    logger.critical(f'config {kwargs["config"]} '
                                    'has no valid yaml formatted data')
                    raise TypeError

                # set args from config file
                attrs_keys_configfile = \
                    set(_class_attributes_name).intersection(config_args.keys())

                for attr_key in attrs_keys_configfile:

                    self.__setattr__(attr_key,
                                     config_args[attr_key])

                    _not_assigned_attrs_index_mask[
                        _class_attributes_name.index(attr_key)
                    ] = False

                # set args from instantiation
                # override if attr already has a value from config
                attrs_keys_input = \
                    set(_class_attributes_name).intersection(kwargs.keys())

                for attr_key in attrs_keys_input:

                    self.__setattr__(attr_key,
                                     kwargs[attr_key])

                    _not_assigned_attrs_index_mask[
                        _class_attributes_name.index(attr_key)
                    ] = False

                # attrs not present in config file or instance inputs
                # --> self.attr leads to KeyError
                # are manually assigned to default value derived
                # from __attrs_attrs__

                for attr_key in array(_class_attributes_name)[
                    _not_assigned_attrs_index_mask
                ]:

                    try:

                        attr = [attr
                                for attr in self.__attrs_attrs__
                                if attr.name == attr_key][0]

                    except KeyError:

                        logger.warning('KeyError: initializing object has no '
                                       f'attribute {attr_key}')
                        raise

                    except IndexError:

                        logger.warning('IndexError: initializing object has no '
                                       f'attribute {attr_key}')
                        raise

                    else:

                        # assign default value
                        # try default and factory subsequently
                        # if neither are present
                        # assign None
                        if hasattr(attr, 'default'):

                            if hasattr(attr.default, 'factory'):
                                self.__setattr__(attr.name,
                                                 attr.default.factory())

                            else:

                                self.__setattr__(attr.name,
                                                 attr.default)

                        else:

                            self.__setattr__(attr.name,
                                             None)

            else:
                logger.trace(f'config {kwargs["config"]} is empty, using default configuration')

        else:

            # no config file is defined
            # call generated init
            self.__attrs_init__(**kwargs)  # type: ignore[attr-defined]

        validate(self)

        self.__attrs_post_init__(**kwargs)

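    # Illustrative sketch of a '_config.yaml' consumed by the constructor above.
    # Keys are lower-cased and matched against the attrs field names returned by
    # get_attrs_names (exact key spelling therefore depends on common.py); the
    # values below are placeholders, not defaults shipped with the package:
    #
    #     data_folder: ./data
    #     duckdb_filepath: ./data/forex.db
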
    def __attrs_post_init__(self, **kwargs: Any) -> None:

        super().__attrs_post_init__()

        # set up log sink for DuckDB
        logger.add(Path(self.data_folder) / 'log' / 'duckdb.log',
                   level="TRACE",
                   rotation="5 MB",
                   filter=lambda record: ('duckdb' == record['extra'].get('target') and
                                          bool(record["extra"].get('target'))))

        # create duck file path if not exists
        if (
            not Path(self.duckdb_filepath).exists() or
            not Path(self.duckdb_filepath).is_file()
        ):

            Path(self.duckdb_filepath).parent.mkdir(parents=True,
                                                    exist_ok=True)
        else:

            logger.bind(target='duckdb').trace(f'DuckDB file {self.duckdb_filepath} already exists')

        # set autovacuum
        conn = self.connect()

        # check auto vacuum property
        cur = conn.cursor()
        cur.execute('PRAGMA main.auto_vacuum')
        cur.execute('PRAGMA main.auto_vacuum = 2')
        cur.close()
        conn.close()

    def connect(self) -> Any:

        try:

            con = sqlite_dbapi.connect(uri=self.duckdb_filepath)

        except Exception as e:

            logger.bind(target='duckdb').error(f'ADBC-SQLITE: connection error: {e}')
            raise

        else:

            return con

    def check_connection(self) -> bool:

        out_check_connection = False

        conn = self.connect()

        try:

            info = read_database(text('SHOW DATABASES'), conn)

        except Exception as e:

            logger.bind(target='duckdb').error(f'Error during connection to {self.duckdb_filepath}')

        else:

            logger.bind(target='duckdb').trace(f'{info}')

            out_check_connection = not is_empty_dataframe(info)

        return out_check_connection

    def _to_duckdb_column_types(self, columns_dict: Dict[str, Any]) -> Dict[str, str]:

        duckdb_columns_dict = {}

        for key, value in columns_dict.items():

            match key:

                case BASE_DATA_COLUMN_NAME.TIMESTAMP:

                    duckdb_columns_dict[BASE_DATA_COLUMN_NAME.TIMESTAMP] = 'TIMESTAMP_MS'

                case BASE_DATA_COLUMN_NAME.ASK \
                        | BASE_DATA_COLUMN_NAME.BID \
                        | BASE_DATA_COLUMN_NAME.OPEN \
                        | BASE_DATA_COLUMN_NAME.HIGH \
                        | BASE_DATA_COLUMN_NAME.LOW \
                        | BASE_DATA_COLUMN_NAME.CLOSE \
                        | BASE_DATA_COLUMN_NAME.VOL \
                        | BASE_DATA_COLUMN_NAME.P_VALUE:

                    duckdb_columns_dict[key] = 'FLOAT'

                case BASE_DATA_COLUMN_NAME.TRANSACTIONS:

                    duckdb_columns_dict[key] = 'UBIGINT'

                case BASE_DATA_COLUMN_NAME.OTC:

                    duckdb_columns_dict[key] = 'FLOAT'

        # force timestamp as first key
        if not list(duckdb_columns_dict.keys())[0] == BASE_DATA_COLUMN_NAME.TIMESTAMP:

            o_dict = OrderedDict(duckdb_columns_dict.items())
            o_dict.move_to_end(BASE_DATA_COLUMN_NAME.TIMESTAMP, last=False)

            duckdb_columns_dict = dict(o_dict)

        else:
            logger.bind(target='duckdb').trace(f'Timestamp is already the first column in {duckdb_columns_dict.keys()}')

        return duckdb_columns_dict

    def _list_tables(self) -> List[str]:

        tables_list: List[str] = []

        conn = self.connect()

        try:

            tables = read_database(query='SELECT * FROM sqlite_master',
                                   connection=conn)

        except Exception as e:

            logger.bind(target='duckdb').error(f'Error list tables for {self.duckdb_filepath}: {e}')

        else:

            tables_list = list(tables['tbl_name'])

        conn.close()

        return tables_list

    def _db_key(self,
                market: str,
                ticker: str,
                timeframe: str
                ) -> str:
        """
        Get a str key of underscore-separated elements.

        key template = market_ticker_timeframe

        Parameters
        ----------
        market : str
            Trading field (e.g. forex, stocks).
        ticker : str
            Instrument ticker.
        timeframe : str
            Timeframe label (polars frequency syntax).

        Returns
        -------
        str
            Lower-case table key.

        """

        # skip checks because they
        # are not meant for polars syntax of timeframes/frequencies
        # tf = check_timeframe_str(timeframe)

        return '_'.join([market.lower(),
                         ticker.lower(),
                         timeframe.lower()])

    def _get_items_from_db_key(self,
                               key
                               ) -> tuple:

        return tuple(key.split('_'))

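    # For example, _db_key('Forex', 'EURUSD', '1h') yields 'forex_eurusd_1h', and
    # _get_items_from_db_key('forex_eurusd_1h') gives ('forex', 'eurusd', '1h');
    # DATA_KEY.TICKER_INDEX / DATA_KEY.TF_INDEX (defined in common.py) are assumed
    # to index into that tuple.
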
    def get_tickers_list(self) -> List[str]:

        tickers_list = []

        for table_name in self._list_tables():

            items = self._get_items_from_db_key(table_name)

            tickers_list.append(items[DATA_KEY.TICKER_INDEX])

        return list_remove_duplicates(tickers_list)

    def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:

        ticker_keys_list = []

        for table_name in self._list_tables():

            items = self._get_items_from_db_key(table_name)

            if items[DATA_KEY.TICKER_INDEX] == ticker.lower():

                if timeframe:

                    if items[DATA_KEY.TF_INDEX] == timeframe.lower():

                        ticker_keys_list.append(table_name)

                else:

                    ticker_keys_list.append(table_name)

        return ticker_keys_list

    def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:

        ticker_years_list = []
        table = ''
        key_found = False

        for table_name in self._list_tables():

            items = self._get_items_from_db_key(table_name)

            if (
                items[DATA_KEY.TICKER_INDEX] == ticker.lower() and
                items[DATA_KEY.TF_INDEX] == timeframe.lower()
            ):

                table = table_name
                key_found = True

                break

        if key_found:

            conn = self.connect()

            try:

                query = f'''SELECT DISTINCT STRFTIME('%Y', CAST({
                    BASE_DATA_COLUMN_NAME.TIMESTAMP} AS TEXT))
                            AS YEAR
                            FROM {table}'''
                read = read_database(query, conn)

            except Exception as e:

                logger.bind(target='duckdb').error(f'Error querying table {table}: {e}')
                raise

            else:

                ticker_years_list = [int(row[0]) for row in read.iter_rows()]

            conn.commit()
            conn.close()

        return ticker_years_list

    def write_data(self, target_table: str, dataframe: Union[polars_dataframe, polars_lazyframe], clean: bool = False) -> None:

        duckdb_cols_dict = {}
        if isinstance(dataframe, polars_lazyframe):

            duckdb_cols_dict = self._to_duckdb_column_types(
                dict(dataframe.collect_schema()))
            dataframe = dataframe.collect()

        else:

            duckdb_cols_dict = self._to_duckdb_column_types(dict(dataframe.schema))

        duckdb_cols_str = ', '.join([f"{key} {duckdb_cols_dict[key]}"
                                     for key in duckdb_cols_dict])

        # open a connection
        conn = self.connect()

        # exec stable creation
        table_list = self._list_tables()

        if_table_exists = 'replace'
        if target_table in table_list:

            # stable_describe = read_database(f'DESCRIBE {target_table}')

            # get existing stable column structure
            # if they match, append data
            # if no match, replace stable

            if_table_exists = 'append'

        target_length = len(dataframe)

        table_write = dataframe.write_database(
            table_name=target_table,
            connection=conn,
            if_table_exists=if_table_exists,
            engine='adbc'
        )

        conn.commit()
        conn.close()

        # clean stage
        if clean:

            conn = self.connect()

            # delete duplicates
            query_clean = f'''DELETE FROM {target_table}
                              WHERE ROWID NOT IN (
                                  SELECT MIN(ROWID)
                                  FROM {target_table}
                                  GROUP BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}
                              );'''

            cur = conn.cursor()
            res = cur.execute(query_clean)

            # Close
            cur.close()
            conn.commit()
            conn.close()

            conn = self.connect()
            cur = conn.cursor()
            vacuum = cur.execute('PRAGMA main.incremental_vacuum')

            # Close
            cur.close()
            conn.commit()
            conn.close()

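    # For illustration, with target_table 'forex_eurusd_tick' (an example key,
    # not a shipped table) and assuming BASE_DATA_COLUMN_NAME.TIMESTAMP ==
    # 'timestamp', the clean stage above issues SQL equivalent to:
    #
    #     DELETE FROM forex_eurusd_tick
    #     WHERE ROWID NOT IN (
    #         SELECT MIN(ROWID)
    #         FROM forex_eurusd_tick
    #         GROUP BY timestamp
    #     );
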
    def read_data(self,
                  market: str,
                  ticker: str,
                  timeframe: str,
                  start: datetime,
                  end: datetime
                  ) -> polars_lazyframe:

        dataframe = polars_lazyframe()

        table = self._db_key(market, ticker, timeframe)
        # check if database is available
        if table in self._list_tables():

            # open a connection
            conn = self.connect()

            try:

                start_str = start.isoformat()
                end_str = end.isoformat()
                '''
                Here you could use also
                WHERE CAST({BASE_DATA_COLUMN_NAME.TIMESTAMP} AS TEXT)
                '''
                query = f'''SELECT * FROM {table}
                            WHERE {BASE_DATA_COLUMN_NAME.TIMESTAMP}
                            BETWEEN '{start_str}' AND '{end_str}'
                            ORDER BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}'''
                dataframe = read_database(query, conn).lazy()

            except Exception as e:

                logger.bind(target='duckdb').error(f'executing query {query} failed: {e}')

            else:

                if timeframe == TICK_TIMEFRAME:

                    # final cast to standard dtypes
                    dataframe = dataframe.cast(cast(Any, POLARS_DTYPE_DICT.TIME_TICK_DTYPE))

                else:

                    # final cast to standard dtypes
                    dataframe = dataframe.cast(cast(Any, POLARS_DTYPE_DICT.TIME_TF_DTYPE))

            # close
            conn.commit()
            conn.close()

        return dataframe

    def clear_database(self, filter: Optional[str] = None) -> None:
        """
        Clear database tables.
        If filter is provided, delete only tables related to that ticker.
        """
        tables = self._list_tables()
        conn = self.connect()
        cur = conn.cursor()

        for table in tables:
            if filter:
                if search(filter, table, IGNORECASE):
                    cur.execute(f"DROP TABLE {table}")
            else:
                cur.execute(f"DROP TABLE {table}")

        cur.execute("VACUUM")
        cur.close()
        conn.commit()
        conn.close()


'''
LOCAL DATA FILES MANAGER

'''


@define(kw_only=True, slots=True)
class LocalDBConnector(DatabaseConnector):

    data_folder: str = field(default='',
                             validator=validators.instance_of(str))
    data_type: str = field(default='parquet',
                           validator=validators.in_(SUPPORTED_DATA_FILES))
    engine: str = field(default='polars_lazy',
                        validator=validators.in_(SUPPORTED_DATA_ENGINES))

    _local_path = field(
        default=Path('.'),
        validator=validator_dir_path(create_if_missing=False))

    def __init__(self, **kwargs: Any) -> None:

        _class_attributes_name = get_attrs_names(self, **kwargs)
        _not_assigned_attrs_index_mask = [True] * len(_class_attributes_name)

        if 'config' in kwargs.keys():

            if kwargs['config']:

                config_path = Path(kwargs['config'])

                if (
                    config_path.exists() and
                    config_path.is_dir()
                ):

                    config_filepath = read_config_folder(config_path,
                                                         file_pattern='_config.yaml')

                else:

                    config_filepath = Path()

                config_args = {}
                if config_filepath.exists() \
                        and \
                        config_filepath.is_file() \
                        and \
                        config_filepath.suffix == '.yaml':

                    # read parameters from config file
                    # and force keys to lower case
                    config_args = {key.lower(): val for key, val in
                                   read_config_file(str(config_filepath)).items()}

                elif isinstance(kwargs['config'], str):

                    # read parameters from config file
                    # and force keys to lower case
                    config_args = {key.lower(): val for key, val in
                                   read_config_string(kwargs['config']).items()}

                else:

                    logger.bind(target='localdb').critical(
                        'invalid config type '
                        f'{kwargs["config"]}: '
                        'required str or Path, got '
                        f'{type(kwargs["config"])}')
                    raise TypeError

                # check consistency of config_args
                if (
                    not isinstance(config_args, dict) or
                    not bool(config_args)
                ):

                    logger.bind(target='localdb').critical(
                        f'config {kwargs["config"]} '
                        'has no valid yaml formatted data')
                    raise TypeError

                # set args from config file
                attrs_keys_configfile = \
                    set(_class_attributes_name).intersection(config_args.keys())

                for attr_key in attrs_keys_configfile:

                    self.__setattr__(attr_key,
                                     config_args[attr_key])

                    _not_assigned_attrs_index_mask[
                        _class_attributes_name.index(attr_key)
                    ] = False

                # set args from instantiation
                # override if attr already has a value from config
                attrs_keys_input = \
                    set(_class_attributes_name).intersection(kwargs.keys())

                for attr_key in attrs_keys_input:

                    self.__setattr__(attr_key,
                                     kwargs[attr_key])

                    _not_assigned_attrs_index_mask[
                        _class_attributes_name.index(attr_key)
                    ] = False

                # attrs not present in config file or instance inputs
                # --> self.attr leads to KeyError
                # are manually assigned to default value derived
                # from __attrs_attrs__

                for attr_key in array(_class_attributes_name)[
                    _not_assigned_attrs_index_mask
                ]:

                    try:

                        attr = [attr
                                for attr in self.__attrs_attrs__
                                if attr.name == attr_key][0]

                    except KeyError:

                        logger.error('KeyError: initializing object has no '
                                     f'attribute {attr_key}')
                        raise

                    except IndexError:

                        logger.error('IndexError: initializing object has no '
                                     f'attribute {attr_key}')
                        raise

                    else:

                        # assign default value
                        # try default and factory subsequently
                        # if neither are present
                        # assign None
                        if hasattr(attr, 'default'):

                            if hasattr(attr.default, 'factory'):

                                self.__setattr__(attr.name,
                                                 attr.default.factory())

                            else:

                                self.__setattr__(attr.name,
                                                 attr.default)

                        else:

                            self.__setattr__(attr.name,
                                             None)

            else:

                logger.trace(f'config {kwargs["config"]} is empty, using default configuration')

        else:

            # no config file is defined
            # call generated init
            self.__attrs_init__(**kwargs)  # type: ignore[attr-defined]

        validate(self)

        self.__attrs_post_init__(**kwargs)

    def __attrs_post_init__(self, **kwargs: Any) -> None:

        super().__attrs_post_init__()

        # set up log sink for LocalDB
        logger.add(Path(self.data_folder) / 'log' / 'localdb.log',
                   level="TRACE",
                   rotation="5 MB",
                   filter=lambda record: ('localdb' == record['extra'].get('target') and
                                          bool(record["extra"].get('target'))))

        self._local_path = Path(self.data_folder)

    def _db_key(self,
                market: str,
                ticker: str,
                timeframe: str
                ) -> str:
        """
        Get a str key of underscore-separated elements.

        key template = market_ticker_timeframe

        Parameters
        ----------
        market : str
            Trading field (e.g. forex, stocks).
        ticker : str
            Instrument ticker.
        timeframe : str
            Timeframe label (polars frequency syntax).

        Returns
        -------
        str
            Lower-case file key.

        """

        # skip checks because they
        # are not meant for polars syntax of timeframes/frequencies
        # tf = check_timeframe_str(timeframe)

        return '_'.join([market.lower(),
                         ticker.lower(),
                         timeframe.lower()])

    def _get_items_from_db_key(self,
                               key
                               ) -> tuple:

        return tuple(key.split('_'))

    def _get_file_details(self, filename: str) -> Tuple[str, str, str]:

        if not (
            isinstance(filename, str)
        ):

            logger.bind(target='localdb').error(f'filename {filename} invalid type: required str')
            raise TypeError(f'filename {filename} invalid type: required str')

        file_items = self._get_items_from_db_key(filename)

        # return each file details
        return file_items

    def _get_filename(self, market: str, ticker: str, tf: str) -> str:

        # based on standard filename template
        return FILENAME_STR.format(market=market.lower(),
                                   ticker=ticker.lower(),
                                   tf=tf.lower(),
                                   file_ext=self.data_type.lower())

    def _list_local_data(self) -> Tuple[List[PathType], List[str]]:

        local_files = []
        local_files_name = []

        # list for all data filetypes supported
        local_files = [file for file in list(self._local_path.rglob('*'))
                       if search(self.data_type + '$', file.suffix)]

        local_files_name = [file.name for file in local_files]

        # check compliance of files to convention (see notes)
        # TODO: warning if not compliant and filter out from files found

        return local_files, local_files_name

    def _list_tables(self) -> List[str]:

        local_files, tables_list = self._list_local_data()

        return tables_list

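    # Illustrative on-disk layout used by write_data/read_data below, assuming
    # the FILENAME_STR template in common.py renders roughly as
    # '{market}_{ticker}_{tf}.{file_ext}' (the actual template lives in
    # common.py; the paths here are examples only):
    #
    #     <data_folder>/
    #         forex/
    #             eurusd/
    #                 forex_eurusd_tick.parquet
    #                 forex_eurusd_1h.parquet
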
    def get_tickers_list(self) -> List[str]:

        tickers_list = []

        local_files, local_files_name = self._list_local_data()

        for filename in local_files_name:

            items = self._get_file_details(filename)
            tickers_list.append(items[DATA_KEY.TICKER_INDEX])

        return list_remove_duplicates(tickers_list)

    def clear_database(self, filter: Optional[str] = None) -> None:

        """
        Clear database files.
        If filter is provided and is a ticker present in database (files present)
        delete only files related to that ticker.
        """

        if filter:

            # in local path search for files having filter in path stem
            # and delete them
            # list all files in local path ending with data_type
            # and use re.search to catch matches
            if isinstance(filter, str):

                data_files = self._local_path.rglob(f'*.{self.data_type}')
                if data_files:
                    for file in data_files:
                        if search(filter, file.stem, IGNORECASE):
                            file.unlink(missing_ok=True)
                else:
                    logger.bind(target='localdb').info(f'No data files found in {self._local_path} with filter {filter}')

            else:
                logger.bind(target='localdb').error(f'Filter {filter} invalid type: required str')

        else:

            # clear all files in local path at
            # folder level using shutil
            shutil.rmtree(self._local_path)

    def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:

        local_files, local_files_name = self._list_local_data()

        if timeframe:

            return [
                key for key in local_files_name
                if search(f'{ticker}',
                          key) and
                self._get_items_from_db_key(key)[DATA_KEY.TF_INDEX] ==
                timeframe
            ]

        else:

            return [
                key for key in local_files_name
                if search(f'{ticker}',
                          key)
            ]

    def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:

        ticker_years_list = []
        table = ''
        key_found = False

        local_files, local_files_name = self._list_local_data()
        ticker_keys = []

        files = [
            key for key in local_files
            if search(f'{ticker.lower()}',
                      str(key.stem)) and
            self._get_items_from_db_key(str(key.stem))[DATA_KEY.TF_INDEX] ==
            timeframe.lower()
        ]

        dataframe = None

        if len(files) == 1:

            if self.data_type == DATA_TYPE.CSV_FILETYPE:

                dataframe = read_csv(self.engine, files[0])

            elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:

                dataframe = read_parquet(self.engine, files[0])

            try:

                query = f'''SELECT DISTINCT STRFTIME({
                    BASE_DATA_COLUMN_NAME.TIMESTAMP}, '%Y')
                            AS YEAR
                            FROM self'''
                read = dataframe.sql(query)

            except Exception as e:

                logger.bind(target='localdb').error(f'Error querying table {table}: {e}')
                raise

            else:

                ticker_years_list = [int(row[0]) for row in read.collect().iter_rows()]

        else:
            logger.bind(target='localdb').warning(f'Expected 1 file for {ticker} - {timeframe}, found {len(files)}')

        return ticker_years_list

    def write_data(
        self,
        target_table: str,
        dataframe: Union[polars_dataframe, polars_lazyframe],
        clean: bool = False
    ) -> None:

        items = self._get_items_from_db_key(target_table)

        filename = self._get_filename(items[DATA_KEY.MARKET],
                                      items[DATA_KEY.TICKER_INDEX],
                                      items[DATA_KEY.TF_INDEX])

        filepath = (self._local_path /
                    items[DATA_KEY.MARKET] /
                    items[DATA_KEY.TICKER_INDEX] /
                    filename)

        if (
            not filepath.exists() or
            not filepath.is_file()
        ):

            filepath.parent.mkdir(parents=True,
                                  exist_ok=True)

        else:

            if self.data_type == DATA_TYPE.CSV_FILETYPE:

                dataframe_ex = read_csv(self.engine, filepath)

            elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:

                dataframe_ex = read_parquet(self.engine, filepath)

            dataframe = concat_data([dataframe, dataframe_ex])
            # clean duplicated timestamps rows, keep first by default
            dataframe = dataframe.unique(
                subset=[
                    BASE_DATA_COLUMN_NAME.TIMESTAMP],
                keep='first').sort(
                BASE_DATA_COLUMN_NAME.TIMESTAMP)

        if self.data_type == DATA_TYPE.CSV_FILETYPE:

            write_csv(dataframe, filepath)

        elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:

            write_parquet(dataframe, filepath)

    def read_data(self,
                  market: str,
                  ticker: str,
                  timeframe: str,
                  start: datetime,
                  end: datetime,
                  comparison_column_name: List[str] | str | None = None,
                  check_level: List[int | float] | int | float | None = None,
                  comparison_operator: List[SUPPORTED_SQL_COMPARISON_OPERATORS] | SUPPORTED_SQL_COMPARISON_OPERATORS | None = None,
                  comparison_aggregation_mode: SUPPORTED_SQL_CONDITION_AGGREGATION_MODES | None = None
                  ) -> polars_lazyframe:

        comparisons_len = 0

        # Validate and normalize condition parameters if provided
        if comparison_column_name is not None or check_level is not None or comparison_operator is not None:
            comparisons_len = len(comparison_column_name)

            if isinstance(comparison_column_name, str):
                comparison_column_name = [comparison_column_name]

            if isinstance(check_level, (int, float)):
                check_level = [check_level]

            if isinstance(comparison_operator, str):
                comparison_operator = [comparison_operator]

            if any([col not in list(SUPPORTED_BASE_DATA_COLUMN_NAME.__args__) for col in comparison_column_name]):
                logger.bind(target='localdb').error(f'comparison_column_name must be a supported column name: {list(SUPPORTED_BASE_DATA_COLUMN_NAME.__args__)}')
                raise ValueError('comparison_column_name must be a supported column name')

            if any([cond not in list(SUPPORTED_SQL_COMPARISON_OPERATORS.__args__) for cond in comparison_operator]):
                logger.bind(target='localdb').error(f'comparison_operator must be a supported SQL comparison operator: {list(SUPPORTED_SQL_COMPARISON_OPERATORS.__args__)}')
                raise ValueError('comparison_operator must be a supported SQL comparison operator')

            if (
                (
                    comparison_aggregation_mode is not None
                    and
                    comparisons_len > 1
                )
                and
                comparison_aggregation_mode not in list(SUPPORTED_SQL_CONDITION_AGGREGATION_MODES.__args__)
            ):
                logger.bind(target='localdb').error(f'comparison_aggregation_mode must be a supported SQL condition aggregation mode: {list(SUPPORTED_SQL_CONDITION_AGGREGATION_MODES.__args__)}')
                raise ValueError('comparison_aggregation_mode must be a supported SQL condition aggregation mode')

            if len(comparison_column_name) != len(check_level) or len(comparison_column_name) != len(comparison_operator):
                logger.bind(target='localdb').error('comparison_column_name, check_level and comparison_operator must have the same length')
                raise ValueError('comparison_column_name, check_level and comparison_operator must have the same length')

            comparisons_len = len(comparison_column_name)

        dataframe = polars_lazyframe()

        filename = self._get_filename(market,
                                      ticker,
                                      timeframe)

        filepath = (self._local_path /
                    market /
                    ticker /
                    filename)

        if self.engine == 'polars':

            dataframe = polars_dataframe()

        elif self.engine == 'polars_lazy':

            dataframe = polars_lazyframe()

        else:

            logger.bind(target='localdb').error(f'Engine {self.engine} or data type {self.data_type} not supported')
            raise ValueError(f'Engine {self.engine} or data type {self.data_type} not supported')

        if (
            filepath.exists() and
            filepath.is_file()
        ):

            if self.data_type == DATA_TYPE.CSV_FILETYPE:

                dataframe = read_csv(self.engine, filepath)

            elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:

                dataframe = read_parquet(self.engine, filepath)

            try:

                start_str = start.isoformat()
                end_str = end.isoformat()

                # Build base query with timestamp filter
                query = f'''SELECT * FROM self
                            WHERE
                            {BASE_DATA_COLUMN_NAME.TIMESTAMP} >= '{start_str}'
                            AND
                            {BASE_DATA_COLUMN_NAME.TIMESTAMP} <= '{end_str}'
                            '''
                # Aggregate conditional filters if provided
                # with the aggregation mode specified
                if comparisons_len > 0:

                    if comparisons_len == 1:
                        # only one condition
                        query += f'''AND
                                     {comparison_column_name[0]} {comparison_operator[0]} {check_level[0]}
                                     '''
                    else:
                        # multiple conditions
                        # wrap all conditions in parentheses with aggregation mode between them
                        query += f'''AND
                                     ({comparison_column_name[0]} {comparison_operator[0]} {check_level[0]}
                                     '''
                        for col, level, cond, index in zip(comparison_column_name[1:], check_level[1:], comparison_operator[1:], range(1, comparisons_len)):

                            if index == comparisons_len - 1:
                                # closing condition needs a closing bracket
                                query += f'''{comparison_aggregation_mode}
                                             {col} {cond} {level})
                                             '''
                            else:
                                # intermediate conditions
                                query += f'''{comparison_aggregation_mode}
                                             {col} {cond} {level}
                                             '''
                # Close query with timestamp ordering
                query += f'ORDER BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}'
                dataframe = dataframe.sql(query)

            except Exception as e:

                logger.error(f'executing query {query} failed: {e}')

            else:

                if timeframe == TICK_TIMEFRAME:

                    # final cast to standard dtypes
                    dataframe = dataframe.cast(POLARS_DTYPE_DICT.TIME_TICK_DTYPE)

                else:

                    # final cast to standard dtypes
                    dataframe = dataframe.cast(POLARS_DTYPE_DICT.TIME_TF_DTYPE)

        else:

            logger.bind(target='localdb').critical(f'file {filepath} not found')
            raise FileNotFoundError(f"file {filepath} not found")

        return dataframe