forex_data_aggregator 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1322 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Feb 23 00:02:36 2025
4
+
5
+ @author: fiora
6
+ """
7
+
8
+ '''
9
+ Module to connect to a database instance
10
+
11
+ Design constraint:
12
+
13
+ start with only support for polars, prefer lazyframe when possible
14
+
15
+ read and write using polars dataframe or lazyframe
16
+ exec requests using SQL query language
17
+ OSS versions for windows required
18
+ '''
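+ 
+ # Overview of what follows (descriptive note):
+ # - DatabaseConnector: abstract base class defining the read/write/SQL interface
+ # - DuckDBConnector: single-file store accessed through the ADBC SQLite driver
+ # - LocalDBConnector: plain parquet/csv files read and written via polars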
19
+
20
+
21
+ import shutil
+ from loguru import logger
22
+ from pathlib import Path
23
+ from attrs import (
24
+ define,
25
+ field,
26
+ validate,
27
+ validators
28
+ )
29
+ from re import (
30
+ fullmatch,
31
+ search,
32
+ IGNORECASE
33
+ )
34
+ from collections import OrderedDict
35
+ from numpy import array
36
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast
37
+ from pathlib import Path as PathType
38
+ from datetime import datetime
39
+
40
+ # Import polars types directly
41
+ from polars import (
42
+ DataFrame as polars_dataframe,
43
+ LazyFrame as polars_lazyframe,
44
+ read_database
45
+ )
46
+
47
+ # Import from adbc_driver_sqlite
48
+ from adbc_driver_sqlite import dbapi as sqlite_dbapi
49
+
50
+ # Import from sqlalchemy for text queries
51
+ try:
52
+ from sqlalchemy import text
53
+ except ImportError:
54
+ # Fallback if sqlalchemy not available
55
+ def text(s: str) -> str:
56
+ return s
57
+
58
+ # Import from common module - explicit imports for items not in __all__
59
+ from .common import (
60
+ TICK_TIMEFRAME,
61
+ BASE_DATA_COLUMN_NAME,
62
+ DATA_KEY,
63
+ DATA_TYPE,
64
+ SUPPORTED_DATA_FILES,
65
+ SUPPORTED_DATA_ENGINES,
66
+ SUPPORTED_BASE_DATA_COLUMN_NAME,
67
+ SUPPORTED_SQL_COMPARISON_OPERATORS,
68
+ get_attrs_names,
69
+ validator_dir_path,
70
+ is_empty_dataframe,
71
+ list_remove_duplicates,
72
+ POLARS_DTYPE_DICT
73
+ )
74
+
75
+ # Import remaining items via star import
76
+ from .common import *
77
+
78
+ # Import from config module
79
+ from ..config import (
80
+ read_config_file,
81
+ read_config_string,
82
+ read_config_folder
83
+ )
84
+
85
+
86
+ '''
87
+ BASE CONNECTOR
88
+ '''
89
+
90
+
91
+ @define(kw_only=True, slots=True)
92
+ class DatabaseConnector:
93
+
94
+ data_folder: str = field(default='',
95
+ validator=validators.instance_of(str))
96
+
97
+ def __init__(self, **kwargs: Any) -> None:
98
+
99
+ pass
100
+
101
+ def __attrs_post_init__(self) -> None:
102
+
103
+ # create data folder if not exists
104
+ if (
105
+ not Path(self.data_folder).exists()
106
+ or
107
+ not Path(self.data_folder).is_dir()
108
+ ):
109
+
110
+ Path(self.data_folder).mkdir(parents=True,
111
+ exist_ok=True)
112
+
113
+ def connect(self) -> Any:
114
+ """Connect to database - must be implemented by subclasses."""
115
+ raise NotImplementedError("Subclasses must implement connect")
116
+
117
+ def check_connection(self) -> bool:
118
+ """Check database connection - must be implemented by subclasses."""
119
+ raise NotImplementedError("Subclasses must implement check_connection")
120
+
121
+ def write_data(self, target_table: str, dataframe: Union[polars_dataframe, polars_lazyframe], clean: bool = False) -> None:
122
+ """Write data to database - must be implemented by subclasses."""
123
+ raise NotImplementedError("Subclasses must implement write_data")
124
+
125
+ def read_data(self, market: str, ticker: str, timeframe: str, start: datetime, end: datetime) -> polars_lazyframe:
126
+ """Read data from database - must be implemented by subclasses."""
127
+ raise NotImplementedError("Subclasses must implement read_data")
128
+
129
+ def exec_sql(self) -> None:
130
+ """Execute SQL - must be implemented by subclasses."""
131
+ raise NotImplementedError("Subclasses must implement exec_sql")
132
+
133
+ def _db_key(self, market: str, ticker: str, timeframe: str) -> str:
134
+ """Generate database key - must be implemented by subclasses."""
135
+ raise NotImplementedError("Subclasses must implement _db_key")
136
+
137
+ def get_tickers_list(self) -> List[str]:
138
+ """Get list of tickers - must be implemented by subclasses."""
139
+ raise NotImplementedError("Subclasses must implement get_tickers_list")
140
+
141
+ def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:
142
+ """Get ticker keys - must be implemented by subclasses."""
143
+ raise NotImplementedError("Subclasses must implement get_ticker_keys")
144
+
145
+ def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:
146
+ """Get years list for ticker - must be implemented by subclasses."""
147
+ raise NotImplementedError("Subclasses must implement get_ticker_years_list")
148
+
149
+ def clear_database(self, filter: Optional[str] = None) -> None:
150
+ """Clear database - must be implemented by subclasses."""
151
+ raise NotImplementedError("Subclasses must implement clear_database")
152
+
153
+
154
+ '''
155
+ DUCKDB CONNECTOR:
156
+
157
+ TABLE TEMPLATE:
158
+ <trading field (e.g. forex, stocks)>_ticker_timeframe (lower case, underscore-joined)
159
+
160
+ '''
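+ 
+ # Illustrative example (values are hypothetical): _db_key('Forex', 'EURUSD', 'tick')
+ # produces the table name 'forex_eurusd_tick'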
161
+
162
+
163
+ @define(kw_only=True, slots=True)
164
+ class DuckDBConnector(DatabaseConnector):
165
+
166
+ duckdb_filepath: str = field(default='',
167
+ validator=validators.instance_of(str))
168
+
169
+ def __init__(self, **kwargs: Any) -> None:
170
+
171
+ _class_attributes_name = get_attrs_names(self, **kwargs)
172
+ _not_assigned_attrs_index_mask = [True] * len(_class_attributes_name)
173
+
174
+ if 'config' in kwargs.keys():
175
+
176
+ if kwargs['config']:
177
+
178
+ config_path = Path(kwargs['config'])
179
+
180
+ if (
181
+ config_path.exists() and
182
+ config_path.is_dir()
183
+ ):
184
+
185
+ config_filepath = read_config_folder(config_path,
186
+ file_pattern='_config.yaml')
187
+
188
+ else:
189
+
190
+ config_filepath = Path()
191
+
192
+ config_args = {}
193
+ if config_filepath.exists() \
194
+ and \
195
+ config_filepath.is_file() \
196
+ and \
197
+ config_filepath.suffix == '.yaml':
198
+
199
+ # read parameters from config file
200
+ # and force keys to lower case
201
+ config_args = {key.lower(): val for key, val in
202
+ read_config_file(str(config_filepath)).items()}
203
+
204
+ elif isinstance(kwargs['config'], str):
205
+
206
+ # read parameters from config file
207
+ # and force keys to lower case
208
+ config_args = {key.lower(): val for key, val in
209
+ read_config_string(kwargs['config']).items()}
210
+
211
+ else:
212
+
213
+ logger.critical('invalid config type '
214
+ f'{kwargs["config"]}: '
215
+ 'required str or Path, got '
216
+ f'{type(kwargs["config"])}')
217
+ raise TypeError
218
+
219
+ # check consistency of config_args
220
+ if (
221
+ not isinstance(config_args, dict) or
222
+ not bool(config_args)
223
+ ):
224
+
225
+ logger.critical(f'config {kwargs["config"]} '
226
+ 'has no valid yaml formatted data')
227
+ raise TypeError
228
+
229
+ # set args from config file
230
+ attrs_keys_configfile = \
231
+ set(_class_attributes_name).intersection(config_args.keys())
232
+
233
+ for attr_key in attrs_keys_configfile:
234
+
235
+ self.__setattr__(attr_key,
236
+ config_args[attr_key])
237
+
238
+ _not_assigned_attrs_index_mask[
239
+ _class_attributes_name.index(attr_key)
240
+ ] = False
241
+
242
+ # set args from instantiation
243
+ # override if attr already has a value from config
244
+ attrs_keys_input = \
245
+ set(_class_attributes_name).intersection(kwargs.keys())
246
+
247
+ for attr_key in attrs_keys_input:
248
+
249
+ self.__setattr__(attr_key,
250
+ kwargs[attr_key])
251
+
252
+ _not_assigned_attrs_index_mask[
253
+ _class_attributes_name.index(attr_key)
254
+ ] = False
255
+
256
+ # attrs not present in config file or instance inputs
257
+ # --> self.attr leads to KeyError
258
+ # are manually assigned to default value derived
259
+ # from __attrs_attrs__
260
+
261
+ for attr_key in array(_class_attributes_name)[
262
+ _not_assigned_attrs_index_mask
263
+ ]:
264
+
265
+ try:
266
+
267
+ attr = [attr
268
+ for attr in self.__attrs_attrs__
269
+ if attr.name == attr_key][0]
270
+
271
+ except KeyError:
272
+
273
+ logger.warning('KeyError: initializing object has no '
274
+ f'attribute {attr_key}')
275
+ raise
276
+
277
+ except IndexError:
278
+
279
+ logger.warning('IndexError: initializing object has no '
280
+ f'attribute {attr_key}')
281
+ raise
282
+
283
+ else:
284
+
285
+ # assign default value
286
+ # try default and factory subsequently
287
+ # if neither are present
288
+ # assign None
289
+ if hasattr(attr, 'default'):
290
+
291
+ if hasattr(attr.default, 'factory'):
292
+ self.__setattr__(attr.name,
293
+ attr.default.factory())
294
+
295
+ else:
296
+
297
+ self.__setattr__(attr.name,
298
+ attr.default)
299
+
300
+ else:
301
+
302
+ self.__setattr__(attr.name,
303
+ None)
304
+
305
+ else:
306
+ logger.trace(f'config {kwargs["config"]} is empty, using default configuration')
307
+
308
+ else:
309
+
310
+ # no config file is defined
311
+ # call generated init
312
+ self.__attrs_init__(**kwargs) # type: ignore[attr-defined]
313
+
314
+ validate(self)
315
+
316
+ self.__attrs_post_init__(**kwargs)
317
+
318
+ def __attrs_post_init__(self, **kwargs: Any) -> None:
319
+
320
+ super().__attrs_post_init__()
321
+
322
+ # set up log sink for DuckDB
323
+ logger.add(Path(self.data_folder) / 'log' / 'duckdb.log',
324
+ level="TRACE",
325
+ rotation="5 MB",
326
+ filter=lambda record: ('duckdb' == record['extra'].get('target') and
327
+ bool(record["extra"].get('target'))))
328
+
329
+ # create duck file path if not exists
330
+ if (
331
+ not Path(self.duckdb_filepath).exists() or
332
+ not Path(self.duckdb_filepath).is_file()
333
+ ):
334
+
335
+ Path(self.duckdb_filepath).parent.mkdir(parents=True,
336
+ exist_ok=True)
337
+ else:
338
+
339
+ logger.bind(target='duckdb').trace(f'DuckDB file {self.duckdb_filepath} already exists')
340
+
341
+ # set autovacuum
342
+ conn = self.connect()
343
+
344
+ # check auto vacuum property
345
+ cur = conn.cursor()
346
+ cur.execute('PRAGMA main.auto_vacuum')
347
+ cur.execute('PRAGMA main.auto_vacuum = 2')
348
+ cur.close()
349
+ conn.close()
350
+
351
+ def connect(self) -> Any:
352
+
353
+ try:
354
+
355
+ con = sqlite_dbapi.connect(uri=self.duckdb_filepath)
356
+
357
+ except Exception as e:
358
+
359
+ logger.bind(target='duckdb').error(f'ADBC-SQLITE: connection error: {e}')
360
+ raise
361
+
362
+ else:
363
+
364
+ return con
365
+
366
+ def check_connection(self) -> bool:
367
+
368
+ out_check_connection = False
369
+
370
+ conn = self.connect()
371
+
372
+ out_check_connection = False
373
+
374
+ try:
375
+
376
+ info = read_database(text('SHOW DATABASES'), conn)
377
+
378
+ except Exception as e:
379
+
380
+ logger.bind(target='duckdb').error(f'Error during connection to {self.duckdb_filepath}: {e}')
381
+
382
+ else:
383
+
384
+ logger.bind(target='duckdb').trace(f'{info}')
385
+
386
+ out_check_connection = not is_empty_dataframe(info)
387
+
388
+ return out_check_connection
389
+
390
+ def _to_duckdb_column_types(self, columns_dict: Dict[str, Any]) -> Dict[str, str]:
391
+
392
+ duckdb_columns_dict = {}
393
+
394
+ for key, value in columns_dict.items():
395
+
396
+ match key:
397
+
398
+ case BASE_DATA_COLUMN_NAME.TIMESTAMP:
399
+
400
+ duckdb_columns_dict[BASE_DATA_COLUMN_NAME.TIMESTAMP] = 'TIMESTAMP_MS'
401
+
402
+ case BASE_DATA_COLUMN_NAME.ASK \
403
+ | BASE_DATA_COLUMN_NAME.BID \
404
+ | BASE_DATA_COLUMN_NAME.OPEN \
405
+ | BASE_DATA_COLUMN_NAME.HIGH \
406
+ | BASE_DATA_COLUMN_NAME.LOW \
407
+ | BASE_DATA_COLUMN_NAME.CLOSE \
408
+ | BASE_DATA_COLUMN_NAME.VOL \
409
+ | BASE_DATA_COLUMN_NAME.P_VALUE:
410
+
411
+ duckdb_columns_dict[key] = 'FLOAT'
412
+
413
+ case BASE_DATA_COLUMN_NAME.TRANSACTIONS:
414
+
415
+ duckdb_columns_dict[key] = 'UBIGINT'
416
+
417
+ case BASE_DATA_COLUMN_NAME.OTC:
418
+
419
+ duckdb_columns_dict[key] = 'FLOAT'
420
+
421
+ # force timestamp as first key
422
+ if not list(duckdb_columns_dict.keys())[0] == BASE_DATA_COLUMN_NAME.TIMESTAMP:
423
+
424
+ o_dict = OrderedDict(duckdb_columns_dict.items())
425
+ o_dict.move_to_end(BASE_DATA_COLUMN_NAME.TIMESTAMP, last=False)
426
+
427
+ duckdb_columns_dict = dict(o_dict)
428
+
429
+ else:
430
+ logger.bind(target='duckdb').trace(f'Timestamp is already the first column in {duckdb_columns_dict.keys()}')
431
+
432
+ return duckdb_columns_dict
433
+
434
+ def _list_tables(self) -> List[str]:
435
+
436
+ tables_list: List[str] = []
437
+
438
+ conn = self.connect()
439
+
440
+ try:
441
+
442
+ tables = read_database(query="SELECT * FROM sqlite_master WHERE type = 'table'",
443
+ connection=conn)
444
+
445
+ except Exception as e:
446
+
447
+ logger.bind(target='duckdb').error(f'Error list tables for {self.duckdb_filepath}: {e}')
448
+
449
+ else:
450
+
451
+ tables_list = list(tables['tbl_name'])
452
+
453
+ conn.close()
454
+
455
+ return tables_list
456
+
457
+ def _db_key(self,
458
+ market: str,
459
+ ticker: str,
460
+ timeframe: str
461
+ ) -> str:
462
+ """
463
+
464
+ Build a db key of underscore-joined, lower-cased elements.
465
+ 
466
+ key template = market_ticker_timeframe
467
+ 
468
+ Parameters
469
+ ----------
470
+ market : str
471
+ Trading field (e.g. forex, stocks).
472
+ ticker : str
473
+ Instrument symbol.
474
+ timeframe : str
475
+ Bar timeframe or tick identifier.
476
+ 
477
+ Returns
478
+ -------
479
+ str
480
+
481
+ """
482
+
483
+ # skip checks because they
484
+ # are not meant for polars syntax of timeframes/frequencies
485
+ # tf = check_timeframe_str(timeframe)
486
+
487
+ return '_'.join([market.lower(),
488
+ ticker.lower(),
489
+ timeframe.lower()])
490
+
491
+ def _get_items_from_db_key(self,
492
+ key
493
+ ) -> tuple:
494
+
495
+ return tuple(key.split('_'))
496
+
497
+ def get_tickers_list(self) -> List[str]:
498
+
499
+ tickers_list = []
500
+
501
+ for table_name in self._list_tables():
502
+
503
+ items = self._get_items_from_db_key(table_name)
504
+
505
+ tickers_list.append(items[DATA_KEY.TICKER_INDEX])
506
+
507
+ return list_remove_duplicates(tickers_list)
508
+
509
+ def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:
510
+
511
+ ticker_keys_list = []
512
+
513
+ for table_name in self._list_tables():
514
+
515
+ items = self._get_items_from_db_key(table_name)
516
+
517
+ if items[DATA_KEY.TICKER_INDEX] == ticker.lower():
518
+
519
+ if timeframe:
520
+
521
+ if items[DATA_KEY.TF_INDEX] == timeframe.lower():
522
+
523
+ ticker_keys_list.append(table_name)
524
+
525
+ else:
526
+
527
+ ticker_keys_list.append(table_name)
528
+
529
+ return ticker_keys_list
530
+
531
+ def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:
532
+
533
+ ticker_years_list = []
534
+ table = ''
535
+ key_found = False
536
+
537
+ for table_name in self._list_tables():
538
+
539
+ items = self._get_items_from_db_key(table_name)
540
+
541
+ if (
542
+ items[DATA_KEY.TICKER_INDEX] == ticker.lower() and
543
+ items[DATA_KEY.TF_INDEX] == timeframe.lower()
544
+ ):
545
+
546
+ table = table_name
547
+ key_found = True
548
+
549
+ break
550
+
551
+ if key_found:
552
+
553
+ conn = self.connect()
554
+
555
+ try:
556
+
557
+ query = f'''SELECT DISTINCT
558
+ STRFTIME('%Y', CAST({BASE_DATA_COLUMN_NAME.TIMESTAMP} AS TEXT))
559
+ AS YEAR
560
+ FROM {table}'''
561
+ read = read_database(query, conn)
562
+
563
+ except Exception as e:
564
+
565
+ logger.bind(target='duckdb').error(f'Error querying table {table}: {e}')
566
+ raise
567
+
568
+ else:
569
+
570
+ ticker_years_list = [int(row[0]) for row in read.iter_rows()]
571
+
572
+ conn.commit()
573
+ conn.close()
574
+
575
+ return ticker_years_list
576
+
577
+ def write_data(self, target_table: str, dataframe: Union[polars_dataframe, polars_lazyframe], clean: bool = False) -> None:
578
+
579
+ duckdb_cols_dict = {}
580
+ if isinstance(dataframe, polars_lazyframe):
581
+
582
+ duckdb_cols_dict = self._to_duckdb_column_types(
583
+ dict(dataframe.collect_schema()))
584
+ dataframe = dataframe.collect()
585
+
586
+ else:
587
+
588
+ duckdb_cols_dict = self._to_duckdb_column_types(dict(dataframe.schema))
589
+
590
+ duckdb_cols_str = ', '.join([f"{key} {duckdb_cols_dict[key]}"
591
+ for key in duckdb_cols_dict])
592
+
593
+ # open a connection
594
+ conn = self.connect()
595
+
596
+ # check existing tables before writing
597
+ table_list = self._list_tables()
598
+
599
+ if_table_exists = 'replace'
600
+ if target_table in table_list:
601
+
602
+ # table_describe = read_database(f'DESCRIBE {target_table}')
603
+
604
+ # get existing table column structure
605
+ # if they match, append data
606
+ # if no match, replace table
607
+
608
+ if_table_exists = 'append'
609
+
610
+ target_length = len(dataframe)
611
+
612
+ table_write = dataframe.write_database(
613
+ table_name=target_table,
614
+ connection=conn,
615
+ if_table_exists=if_table_exists,
616
+ engine='adbc'
617
+ )
618
+
619
+ conn.commit()
620
+ conn.close()
621
+
622
+ # clean stage
623
+ if clean:
624
+
625
+ conn = self.connect()
626
+
627
+ # delete duplicates
628
+ query_clean = f'''DELETE FROM {target_table}
629
+ WHERE ROWID NOT IN (
630
+ SELECT MIN(ROWID)
631
+ FROM {target_table}
632
+ GROUP BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}
633
+ );'''
634
+
635
+ cur = conn.cursor()
636
+ res = cur.execute(query_clean)
637
+
638
+ # Close
639
+ cur.close()
640
+ conn.commit()
641
+ conn.close()
642
+
643
+ conn = self.connect()
644
+ cur = conn.cursor()
645
+ vacuum = cur.execute('PRAGMA main.incremental_vacuum')
646
+
647
+ # Close
648
+ cur.close()
649
+ conn.commit()
650
+ conn.close()
651
+
652
+ def read_data(self,
653
+ market: str,
654
+ ticker: str,
655
+ timeframe: str,
656
+ start: datetime,
657
+ end: datetime
658
+ ) -> polars_lazyframe:
659
+
660
+ dataframe = polars_lazyframe()
661
+
662
+ table = self._db_key(market, ticker, timeframe)
663
+ # check if database is available
664
+ if table in self._list_tables():
665
+
666
+ # open a connection
667
+ conn = self.connect()
668
+
669
+ try:
670
+
671
+ start_str = start.isoformat()
672
+ end_str = end.isoformat()
673
+ '''
674
+ Here you could use also
675
+ WHERE CAST({BASE_DATA_COLUMN_NAME.TIMESTAMP} AS TEXT)
676
+ '''
677
+ query = f'''SELECT * FROM {table}
678
+ WHERE {BASE_DATA_COLUMN_NAME.TIMESTAMP}
679
+ BETWEEN '{start_str}' AND '{end_str}'
680
+ ORDER BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}'''
681
+ dataframe = read_database(query, conn).lazy()
682
+
683
+ except Exception as e:
684
+
685
+ logger.bind(target='duckdb').error(f'executing query {query} failed: {e}')
686
+
687
+ else:
688
+
689
+ if timeframe == TICK_TIMEFRAME:
690
+
691
+ # final cast to standard dtypes
692
+ dataframe = dataframe.cast(cast(Any, POLARS_DTYPE_DICT.TIME_TICK_DTYPE))
693
+
694
+ else:
695
+
696
+ # final cast to standard dtypes
697
+ dataframe = dataframe.cast(cast(Any, POLARS_DTYPE_DICT.TIME_TF_DTYPE))
698
+
699
+ # close
700
+ conn.commit()
701
+ conn.close()
702
+
703
+ return dataframe
704
+
705
+ def clear_database(self, filter: Optional[str] = None) -> None:
706
+ """
707
+ Clear database tables
708
+ If filter is provided, delete only tables related to that ticker
709
+ """
710
+ tables = self._list_tables()
711
+ conn = self.connect()
712
+ cur = conn.cursor()
713
+
714
+ for table in tables:
715
+ if filter:
716
+ if search(filter, table, IGNORECASE):
717
+ cur.execute(f"DROP TABLE {table}")
718
+ else:
719
+ cur.execute(f"DROP TABLE {table}")
720
+
721
+ cur.execute("VACUUM")
722
+ cur.close()
723
+ conn.commit()
724
+ conn.close()
725
+
726
+
727
+ '''
728
+ LOCAL DATA FILES MANAGER
729
+
730
+ '''
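+ 
+ # Directory layout used by LocalDBConnector (descriptive note):
+ # <data_folder>/<market>/<ticker>/<filename>, where <filename> is rendered from
+ # FILENAME_STR using the lower-cased market, ticker, timeframe and the data_type extension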
731
+
732
+
733
+ @define(kw_only=True, slots=True)
734
+ class LocalDBConnector(DatabaseConnector):
735
+
736
+ data_folder: str = field(default='',
737
+ validator=validators.instance_of(str))
738
+ data_type: str = field(default='parquet',
739
+ validator=validators.in_(SUPPORTED_DATA_FILES))
740
+ engine: str = field(default='polars_lazy',
741
+ validator=validators.in_(SUPPORTED_DATA_ENGINES))
742
+
743
+ _local_path: Path = field(
744
+ default=Path('.'),
745
+ validator=validator_dir_path(create_if_missing=False))
746
+
747
+ def __init__(self, **kwargs: Any) -> None:
748
+
749
+ _class_attributes_name = get_attrs_names(self, **kwargs)
750
+ _not_assigned_attrs_index_mask = [True] * len(_class_attributes_name)
751
+
752
+ if 'config' in kwargs.keys():
753
+
754
+ if kwargs['config']:
755
+
756
+ config_path = Path(kwargs['config'])
757
+
758
+ if (
759
+ config_path.exists() and
760
+ config_path.is_dir()
761
+ ):
762
+
763
+ config_filepath = read_config_folder(config_path,
764
+ file_pattern='_config.yaml')
765
+
766
+ else:
767
+
768
+ config_filepath = Path()
769
+
770
+ config_args = {}
771
+ if config_filepath.exists() \
772
+ and \
773
+ config_filepath.is_file() \
774
+ and \
775
+ config_filepath.suffix == '.yaml':
776
+
777
+ # read parameters from config file
778
+ # and force keys to lower case
779
+ config_args = {key.lower(): val for key, val in
780
+ read_config_file(str(config_filepath)).items()}
781
+
782
+ elif isinstance(kwargs['config'], str):
783
+
784
+ # read parameters from config file
785
+ # and force keys to lower case
786
+ config_args = {key.lower(): val for key, val in
787
+ read_config_string(kwargs['config']).items()}
788
+
789
+ else:
790
+
791
+ logger.bind(target='localdb').critical(
792
+ 'invalid config type '
793
+ f'{kwargs["config"]}: '
794
+ 'required str or Path, got '
795
+ f'{type(kwargs["config"])}')
796
+ raise TypeError
797
+
798
+ # check consistency of config_args
799
+ if (
800
+ not isinstance(config_args, dict) or
801
+ not bool(config_args)
802
+ ):
803
+
804
+ logger.bind(target='localdb').critical(
805
+ f'config {kwargs["config"]} '
806
+ 'has no valid yaml formatted data')
807
+ raise TypeError
808
+
809
+ # set args from config file
810
+ attrs_keys_configfile = \
811
+ set(_class_attributes_name).intersection(config_args.keys())
812
+
813
+ for attr_key in attrs_keys_configfile:
814
+
815
+ self.__setattr__(attr_key,
816
+ config_args[attr_key])
817
+
818
+ _not_assigned_attrs_index_mask[
819
+ _class_attributes_name.index(attr_key)
820
+ ] = False
821
+
822
+ # set args from instantiation
823
+ # override if attr already has a value from config
824
+ attrs_keys_input = \
825
+ set(_class_attributes_name).intersection(kwargs.keys())
826
+
827
+ for attr_key in attrs_keys_input:
828
+
829
+ self.__setattr__(attr_key,
830
+ kwargs[attr_key])
831
+
832
+ _not_assigned_attrs_index_mask[
833
+ _class_attributes_name.index(attr_key)
834
+ ] = False
835
+
836
+ # attrs not present in config file or instance inputs
837
+ # --> self.attr leads to KeyError
838
+ # are manually assigned to default value derived
839
+ # from __attrs_attrs__
840
+
841
+ for attr_key in array(_class_attributes_name)[
842
+ _not_assigned_attrs_index_mask
843
+ ]:
844
+
845
+ try:
846
+
847
+ attr = [attr
848
+ for attr in self.__attrs_attrs__
849
+ if attr.name == attr_key][0]
850
+
851
+ except KeyError:
852
+
853
+ logger.error('KeyError: initializing object has no '
854
+ f'attribute {attr_key}')
855
+ raise
856
+
857
+ except IndexError:
858
+
859
+ logger.error('IndexError: initializing object has no '
860
+ f'attribute {attr_key}')
861
+ raise
862
+
863
+ else:
864
+
865
+ # assign default value
866
+ # try default and factory subsequently
867
+ # if neither are present
868
+ # assign None
869
+ if hasattr(attr, 'default'):
870
+
871
+ if hasattr(attr.default, 'factory'):
872
+
873
+ self.__setattr__(attr.name,
874
+ attr.default.factory())
875
+
876
+ else:
877
+
878
+ self.__setattr__(attr.name,
879
+ attr.default)
880
+
881
+ else:
882
+
883
+ self.__setattr__(attr.name,
884
+ None)
885
+
886
+ else:
887
+
888
+ logger.trace(f'config {kwargs["config"]} is empty, using default configuration')
889
+
890
+ else:
891
+
892
+ # no config file is defined
893
+ # call generated init
894
+ self.__attrs_init__(**kwargs) # type: ignore[attr-defined]
895
+
896
+ validate(self)
897
+
898
+ self.__attrs_post_init__(**kwargs)
899
+
900
+ def __attrs_post_init__(self, **kwargs: Any) -> None:
901
+
902
+ super().__attrs_post_init__()
903
+
904
+ # set up log sink for LocalDB
905
+ logger.add(Path(self.data_folder) / 'log' / 'localdb.log',
906
+ level="TRACE",
907
+ rotation="5 MB",
908
+ filter=lambda record: ('localdb' == record['extra'].get('target') and
909
+ bool(record["extra"].get('target'))))
910
+
911
+ self._local_path = Path(self.data_folder)
912
+
913
+ def _db_key(self,
914
+ market: str,
915
+ ticker: str,
916
+ timeframe: str
917
+ ) -> str:
918
+ """
919
+
920
+ Build a db key of underscore-joined, lower-cased elements.
921
+ 
922
+ key template = market_ticker_timeframe
923
+ 
924
+ Parameters
925
+ ----------
926
+ market : str
927
+ Trading field (e.g. forex, stocks).
928
+ ticker : str
929
+ Instrument symbol.
930
+ timeframe : str
931
+ Bar timeframe or tick identifier.
932
+ 
933
+ Returns
934
+ -------
935
+ str
936
+
937
+ """
938
+
939
+ # skip checks because they
940
+ # are not meant for polars syntax of timeframes/frequencies
941
+ # tf = check_timeframe_str(timeframe)
942
+
943
+ return '_'.join([market.lower(),
944
+ ticker.lower(),
945
+ timeframe.lower()])
946
+
947
+ def _get_items_from_db_key(self,
948
+ key
949
+ ) -> tuple:
950
+
951
+ return tuple(key.split('_'))
952
+
953
+ def _get_file_details(self, filename: str) -> Tuple[str, str, str]:
954
+
955
+ if not (
956
+ isinstance(filename, str)
957
+ ):
958
+
959
+ logger.bind(target='localdb').error(f'filename {filename} invalid type: required str')
960
+ raise TypeError(f'filename {filename} invalid type: required str')
961
+
962
+ file_items = self._get_items_from_db_key(filename)
963
+
964
+ # return each file details
965
+ return file_items
966
+
967
+ def _get_filename(self, market: str, ticker: str, tf: str) -> str:
968
+
969
+ # based on standard filename template
970
+ return FILENAME_STR.format(market=market.lower(),
971
+ ticker=ticker.lower(),
972
+ tf=tf.lower(),
973
+ file_ext=self.data_type.lower())
974
+
975
+ def _list_local_data(self) -> Tuple[List[PathType], List[str]]:
976
+
977
+ local_files = []
978
+ local_files_name = []
979
+
980
+ # list for all data filetypes supported
981
+ local_files = [file for file in list(self._local_path.rglob(f'*'))
982
+ if search(self.data_type + '$', file.suffix)]
983
+
984
+ local_files_name = [file.name for file in local_files]
985
+
986
+ # check compliance of files to convention (see notes)
987
+ # TODO: warning if no compliant and filter out from files found
988
+
989
+ return local_files, local_files_name
990
+
991
+ def _list_tables(self) -> List[str]:
992
+
993
+ local_files, tables_list = self._list_local_data()
994
+
995
+ return tables_list
996
+
997
+ def get_tickers_list(self) -> List[str]:
998
+
999
+ tickers_list = []
1000
+
1001
+ local_files, local_files_name = self._list_local_data()
1002
+
1003
+ for filename in local_files_name:
1004
+
1005
+ items = self._get_file_details(filename)
1006
+ tickers_list.append(items[DATA_KEY.TICKER_INDEX])
1007
+
1008
+ return list_remove_duplicates(tickers_list)
1009
+
1010
+ def clear_database(self, filter: Optional[str] = None) -> None:
1011
+
1012
+ """
1013
+ Clear database files
1014
+ If filter is provided and is a ticker present in database (files present)
1015
+ delete only files related to that ticker
1016
+ """
1017
+
1018
+ if filter:
1019
+
1020
+ # in local path search for files having filter in path stem
1021
+ # and delete them
1022
+ # list all files in local path ending with data_type
1023
+ # and use re.search to catch matches
1024
+ if isinstance(filter, str):
1025
+
1026
+ data_files = list(self._local_path.rglob(f'*.{self.data_type}'))
1027
+ if data_files:
1028
+ for file in data_files:
1029
+ if search(filter, file.stem, IGNORECASE):
1030
+ file.unlink(missing_ok=True)
1031
+ else:
1032
+ logger.bind(target='localdb').info(f'No data files found in {self._local_path} with filter {filter}')
1033
+
1034
+ else:
1035
+ logger.bind(target='localdb').error(f'Filter {filter} invalid type: required str')
1036
+
1037
+ else:
1038
+
1039
+ # clear all files in local path at
1040
+ # folder level using shutil
1041
+ shutil.rmtree(self._local_path)
1042
+
1043
+ def get_ticker_keys(self, ticker: str, timeframe: Optional[str] = None) -> List[str]:
1044
+
1045
+ local_files, local_files_name = self._list_local_data()
1046
+
1047
+ if timeframe:
1048
+
1049
+ return [
1050
+ key for key in local_files_name
1051
+ if search(f'{ticker.lower()}',
1052
+ key) and
1053
+ self._get_items_from_db_key(key)[DATA_KEY.TF_INDEX] ==
1054
+ timeframe.lower()
1055
+ ]
1056
+
1057
+ else:
1058
+
1059
+ return [
1060
+ key for key in local_files_name
1061
+ if search(f'{ticker.lower()}',
1062
+ key)
1063
+ ]
1064
+
1065
+ def get_ticker_years_list(self, ticker: str, timeframe: str = TICK_TIMEFRAME) -> List[int]:
1066
+
1067
+ ticker_years_list = []
1068
+ table = ''
1069
+ key_found = False
1070
+
1071
+ local_files, local_files_name = self._list_local_data()
1072
+ ticker_keys = []
1073
+
1074
+ files = [
1075
+ key for key in local_files
1076
+ if search(f'{ticker.lower()}',
1077
+ str(key.stem)) and
1078
+ self._get_items_from_db_key(str(key.stem))[DATA_KEY.TF_INDEX] ==
1079
+ timeframe.lower()
1080
+ ]
1081
+
1082
+ dataframe = None
1083
+
1084
+ if len(files) == 1:
1085
+
1086
+ if self.data_type == DATA_TYPE.CSV_FILETYPE:
1087
+
1088
+ dataframe = read_csv(self.engine, files[0])
1089
+
1090
+ elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:
1091
+
1092
+ dataframe = read_parquet(self.engine, files[0])
1093
+
1094
+ try:
1095
+
1096
+ query = f'''SELECT DISTINCT
1097
+ STRFTIME({BASE_DATA_COLUMN_NAME.TIMESTAMP}, '%Y')
1098
+ AS YEAR
1099
+ FROM self'''
1100
+ read = dataframe.sql(query)
1101
+
1102
+ except Exception as e:
1103
+
1104
+ logger.bind(target='localdb').error(f'Error querying table {table}: {e}')
1105
+ raise
1106
+
1107
+ else:
1108
+
1109
+ ticker_years_list = [int(row[0]) for row in read.collect().iter_rows()]
1110
+
1111
+ else:
1112
+ logger.bind(target='localdb').warning(f'Expected 1 file for {ticker} - {timeframe}, found {len(files)}')
1113
+
1114
+ return ticker_years_list
1115
+
1116
+ def write_data(
1117
+ self,
1118
+ target_table: str,
1119
+ dataframe: Union[polars_dataframe, polars_lazyframe],
1120
+ clean: bool = False
1121
+ ) -> None:
1122
+
1123
+ items = self._get_items_from_db_key(target_table)
1124
+
1125
+ filename = self._get_filename(items[DATA_KEY.MARKET],
1126
+ items[DATA_KEY.TICKER_INDEX],
1127
+ items[DATA_KEY.TF_INDEX])
1128
+
1129
+ filepath = (self._local_path /
1130
+ items[DATA_KEY.MARKET] /
1131
+ items[DATA_KEY.TICKER_INDEX] /
1132
+ filename)
1133
+
1134
+ if (
1135
+ not filepath.exists() or
1136
+ not filepath.is_file()
1137
+ ):
1138
+
1139
+ filepath.parent.mkdir(parents=True,
1140
+ exist_ok=True)
1141
+
1142
+ else:
1143
+
1144
+ if self.data_type == DATA_TYPE.CSV_FILETYPE:
1145
+
1146
+ dataframe_ex = read_csv(self.engine, filepath)
1147
+
1148
+ elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:
1149
+
1150
+ dataframe_ex = read_parquet(self.engine, filepath)
1151
+
1152
+ dataframe = concat_data([dataframe, dataframe_ex])
1153
+ # clean duplicated timestamps rows, keep first by default
1154
+ dataframe = dataframe.unique(
1155
+ subset=[
1156
+ BASE_DATA_COLUMN_NAME.TIMESTAMP],
1157
+ keep='first').sort(
1158
+ BASE_DATA_COLUMN_NAME.TIMESTAMP)
1159
+
1160
+ if self.data_type == DATA_TYPE.CSV_FILETYPE:
1161
+
1162
+ write_csv(dataframe, filepath)
1163
+
1164
+ elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:
1165
+
1166
+ write_parquet(dataframe, filepath)
1167
+
1168
+ def read_data(self,
1169
+ market: str,
1170
+ ticker: str,
1171
+ timeframe: str,
1172
+ start: datetime,
1173
+ end: datetime,
1174
+ comparison_column_name: List[str] | str | None = None,
1175
+ check_level: List[int | float] | int | float | None = None,
1176
+ comparison_operator: List[SUPPORTED_SQL_COMPARISON_OPERATORS] | SUPPORTED_SQL_COMPARISON_OPERATORS | None = None,
1177
+ comparison_aggregation_mode: SUPPORTED_SQL_CONDITION_AGGREGATION_MODES | None = None
1178
+ ) -> polars_lazyframe:
1179
+
1180
+ comparisons_len = 0
1181
+
1182
+ # Validate and normalize condition parameters if provided
1183
+ if comparison_column_name is not None or check_level is not None or comparison_operator is not None:
1184
+ comparisons_len = 1 if isinstance(comparison_column_name, str) else len(comparison_column_name)
1185
+
1186
+ if isinstance(comparison_column_name, str):
1187
+ comparison_column_name = [comparison_column_name]
1188
+
1189
+ if isinstance(check_level, (int, float)):
1190
+ check_level = [check_level]
1191
+
1192
+ if isinstance(comparison_operator, str):
1193
+ comparison_operator = [comparison_operator]
1194
+
1195
+ if any([col not in list(SUPPORTED_BASE_DATA_COLUMN_NAME.__args__) for col in comparison_column_name]):
1196
+ logger.bind(target='localdb').error(f'comparison_column_name must be a supported column name: {list(SUPPORTED_BASE_DATA_COLUMN_NAME.__args__)}')
1197
+ raise ValueError('comparison_column_name must be a supported column name')
1198
+
1199
+ if any([cond not in list(SUPPORTED_SQL_COMPARISON_OPERATORS.__args__) for cond in comparison_operator]):
1200
+ logger.bind(target='localdb').error(f'comparison_operator must be a supported SQL comparison operator: {list(SUPPORTED_SQL_COMPARISON_OPERATORS.__args__)}')
1201
+ raise ValueError('comparison_operator must be a supported SQL comparison operator')
1202
+
1203
+ if (
1204
+ (
1205
+ comparison_aggregation_mode is not None
1206
+ and
1207
+ comparisons_len > 1
1208
+ )
1209
+ and
1210
+ comparison_aggregation_mode not in list(SUPPORTED_SQL_CONDITION_AGGREGATION_MODES.__args__)
1211
+ ):
1212
+ logger.bind(target='localdb').error(f'comparison_aggregation_mode must be a supported SQL condition aggregation mode: {list(SUPPORTED_SQL_CONDITION_AGGREGATION_MODES.__args__)}')
1213
+ raise ValueError('comparison_aggregation_mode must be a supported SQL condition aggregation mode')
1214
+
1215
+ if len(comparison_column_name) != len(check_level) or len(comparison_column_name) != len(comparison_operator):
1216
+ logger.bind(target='localdb').error('comparison_column_name, check_level and comparison_operator must have the same length')
1217
+ raise ValueError('comparison_column_name, check_level and comparison_operator must have the same length')
1218
+
1219
+ comparisons_len = len(comparison_column_name)
1220
+
1221
+ dataframe = polars_lazyframe()
1222
+
1223
+ filename = self._get_filename(market,
1224
+ ticker,
1225
+ timeframe)
1226
+
1227
+ filepath = (self._local_path /
1228
+ market /
1229
+ ticker /
1230
+ filename)
1231
+
1232
+ if self.engine == 'polars':
1233
+
1234
+ dataframe = polars_dataframe()
1235
+
1236
+ elif self.engine == 'polars_lazy':
1237
+
1238
+ dataframe = polars_lazyframe()
1239
+
1240
+ else:
1241
+
1242
+ logger.bind(target='localdb').error(f'Engine {self.engine} or data type {self.data_type} not supported')
1243
+ raise ValueError(f'Engine {self.engine} or data type {self.data_type} not supported')
1244
+
1245
+ if (
1246
+ filepath.exists() and
1247
+ filepath.is_file()
1248
+ ):
1249
+
1250
+ if self.data_type == DATA_TYPE.CSV_FILETYPE:
1251
+
1252
+ dataframe = read_csv(self.engine, filepath)
1253
+
1254
+ elif self.data_type == DATA_TYPE.PARQUET_FILETYPE:
1255
+
1256
+ dataframe = read_parquet(self.engine, filepath)
1257
+
1258
+ try:
1259
+
1260
+ start_str = start.isoformat()
1261
+ end_str = end.isoformat()
1262
+
1263
+ # Build base query with timestamp filter
1264
+ query = f'''SELECT * FROM self
1265
+ WHERE
1266
+ {BASE_DATA_COLUMN_NAME.TIMESTAMP} >= '{start_str}'
1267
+ AND
1268
+ {BASE_DATA_COLUMN_NAME.TIMESTAMP} <= '{end_str}'
1269
+ '''
1270
+ # Aggregate conditional filters if provided
1271
+ # with the aggregation mode specified
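+ # Illustrative only (columns and levels are hypothetical): with
+ # comparison_column_name=['close', 'vol'], comparison_operator=['>', '>='],
+ # check_level=[1.1, 1000] and comparison_aggregation_mode='OR',
+ # the appended clause reads roughly: AND (close > 1.1 OR vol >= 1000)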
1272
+ if comparisons_len > 0:
1273
+
1274
+ if comparisons_len == 1:
1275
+ # only one condition
1276
+ query += f'''AND
1277
+ {comparison_column_name[0]} {comparison_operator[0]} {check_level[0]}
1278
+ '''
1279
+ else:
1280
+ # multiple conditions
1281
+ # wrap all conditions in parentheses with aggregation mode between them
1282
+ query += f'''AND
1283
+ ({comparison_column_name[0]} {comparison_operator[0]} {check_level[0]}
1284
+ '''
1285
+ for col, level, cond, index in zip(comparison_column_name[1:], check_level[1:], comparison_operator[1:], range(1, comparisons_len)):
1286
+
1287
+ if index == comparisons_len - 1:
1288
+ # closing conditions needs closing bracket
1289
+ query += f'''{comparison_aggregation_mode}
1290
+ {col} {cond} {level})
1291
+ '''
1292
+ else:
1293
+ # intermediate conditions
1294
+ query += f'''{comparison_aggregation_mode}
1295
+ {col} {cond} {level}
1296
+ '''
1297
+ # Close query with timestamp ordering
1298
+ query += f'ORDER BY {BASE_DATA_COLUMN_NAME.TIMESTAMP}'
1299
+ dataframe = dataframe.sql(query)
1300
+
1301
+ except Exception as e:
1302
+
1303
+ logger.bind(target='localdb').error(f'executing query {query} failed: {e}')
1304
+
1305
+ else:
1306
+
1307
+ if timeframe == TICK_TIMEFRAME:
1308
+
1309
+ # final cast to standard dtypes
1310
+ dataframe = dataframe.cast(POLARS_DTYPE_DICT.TIME_TICK_DTYPE)
1311
+
1312
+ else:
1313
+
1314
+ # final cast to standard dtypes
1315
+ dataframe = dataframe.cast(POLARS_DTYPE_DICT.TIME_TF_DTYPE)
1316
+
1317
+ else:
1318
+
1319
+ logger.bind(target='localdb').critical(f'file {filepath} not found')
1320
+ raise FileNotFoundError("file {filepath} not found")
1321
+
1322
+ return dataframe
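+ 
+ 
+ # --- Usage sketch (illustrative; folder, market, ticker and date values are placeholders) ---
+ # A minimal example of how a connector is expected to be driven, assuming a
+ # local parquet store with the default 'polars_lazy' engine. Guarded so that
+ # importing the module has no side effects.
+ if __name__ == '__main__':
+ 
+     db = LocalDBConnector(data_folder='./data',
+                           data_type='parquet',
+                           engine='polars_lazy')
+ 
+     # read_data returns a polars LazyFrame over the requested time window;
+     # the corresponding <market>/<ticker> parquet file must already exist
+     lazy_frame = db.read_data(market='forex',
+                               ticker='EURUSD',
+                               timeframe=TICK_TIMEFRAME,
+                               start=datetime(2024, 1, 1),
+                               end=datetime(2024, 12, 31))
+ 
+     print(lazy_frame.collect())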