lidb 2.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lidb has been flagged as potentially problematic.

lidb/dataset.py ADDED
@@ -0,0 +1,696 @@
# Copyright (c) ZhangYundi.
# Licensed under the MIT License.
# Created on 2025/10/27 14:13
# Description:

from __future__ import annotations

import inspect
import shutil
import sys
import warnings
from collections import defaultdict
from enum import Enum
from functools import partial
from typing import Callable, Literal

import logair
import pandas as pd
import polars as pl
import polars.selectors as cs
import xcals
import ygo
from varname import varname

from .database import put, tb_path, scan, DB_PATH
from .parse import parse_hive_partition_structure

DEFAULT_DS_PATH = DB_PATH / "datasets"


class InstrumentType(Enum):
    STOCK = "Stock"  # Stock
    ETF = "ETF"  # Exchange-traded fund
    CB = "ConvertibleBond"  # Convertible bond


def complete_data(fn, date, save_path, partitions):
    """Compute one day of data via fn and persist it under save_path."""
    logger = logair.get_logger(__name__)
    try:
        data = fn()
        if data is None:
            # fn handled saving the data itself
            return
        if not isinstance(data, (pl.DataFrame, pl.LazyFrame)):
            logger.error(f"{save_path}: Result of dataset.fn must be polars.DataFrame or polars.LazyFrame.")
            return
        # Drop columns whose names start with `_`
        data = data.select(~cs.starts_with("_"))
        if isinstance(data, pl.LazyFrame):
            data = data.collect()
        cols = data.columns
        if "date" not in cols:
            data = data.with_columns(pl.lit(date).alias("date")).select("date", *cols)
        else:
            data = data.cast({"date": pl.Utf8})
            data = data.filter(date=date)
        if "time" in data.columns:
            if data["time"].n_unique() < 2:
                # A constant time column carries no information; drop it
                data = data.drop("time")
        put(data, save_path, partitions=partitions)
    except Exception as e:
        logger.error(f"{save_path}: Error while completing data for {date}\n", exc_info=e)


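# --- Editor's illustrative sketch (not part of the released file). ---
# complete_data normalizes frames before writing: a missing `date` column is
# injected as a literal, an existing one is cast to Utf8 and filtered to the
# target date. A minimal pure-polars reproduction of that normalization:
def _demo_complete_data_normalization() -> pl.DataFrame:
    date = "2025-10-27"
    data = pl.DataFrame({"asset": ["000001", "600000"], "close": [10.5, 8.2]})
    cols = data.columns
    if "date" not in cols:
        # Same rule as complete_data: prepend a literal date column
        data = data.with_columns(pl.lit(date).alias("date")).select("date", *cols)
    return data  # columns: ["date", "asset", "close"]

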
class Dataset:

    def __init__(self,
                 *depends: Dataset,
                 fn: Callable[..., pl.DataFrame | pl.LazyFrame],
                 tb: str = "",
                 update_time: str = "",
                 window: str = "1d",
                 partitions: list[str] | None = None,
                 is_hft: bool = False,
                 data_name: str = "",
                 frame: int = 1):
        """
        Parameters
        ----------
        depends: Dataset
            Underlying datasets this dataset depends on.
        fn: Callable
            Function that computes the dataset. To use the underlying
            dependencies, it must explicitly declare a `depend` parameter.
        tb: str
            Table the dataset is saved to. Defaults to
            {lidb.DB_PATH}/datasets/<module> when not specified.
        update_time: str
            Update time. Defaults to "" (real-time), meaning the current
            day's value is available. Only three cases are allowed:
            - 1. Pre-market times, e.g. 08:00:00, 09:00:00, 09:15:00 ...
            - 2. Intraday times: treated as real-time, i.e. the empty string ""
            - 3. Post-market times, e.g. 15:00:00, 16:30:00, 20:00:00 ...
        window: str
            Used together with depends: when reading the dependencies, look
            back over `window`. The smallest unit is `d`; anything below a
            full day is rounded up to `1d`.
        partitions: list[str]
            Partition columns. If None, they are inferred from the parameters
            of fn; to disable partitioning, pass an empty list: [].
        is_hft: bool
            Whether this is high-frequency data; if so, storage is
            additionally partitioned by asset. Defaults to False.
            HFT is defined as a time step < 1min.
        data_name: str
            Dataset name. Empty by default, in which case it is inferred
            automatically; if given, the given name is used.
        frame: int
            Stack-frame offset used for the automatic name inference.
        """
        self._depends = list(depends)
        self._name = ""
        self.fn = fn
        self.fn_params_sig = ygo.fn_signature_params(fn)
        self._is_depend = "depend" in self.fn_params_sig and len(self._depends) > 0
        self._is_hft = is_hft
        self._frame = frame
        self.data_name = data_name
        if not self.data_name:
            try:
                self.data_name = varname(frame, strict=False)
            except Exception:
                pass
        if self.data_name:
            self.data_name = self.data_name.replace('ds_', '')
        fn_params = ygo.fn_params(self.fn)
        self.fn_params = {k: v for (k, v) in fn_params}
        # Propagate same-named parameters to the underlying dependency datasets
        self._update_depends()

        if pd.Timedelta(window).days < 1:
            window = "1d"
        window_td = pd.Timedelta(window)
        self._window = window
        self._days = window_td.days
        if window_td.seconds > 0:
            self._days += 1
        # High-frequency data (time step < 60s) is additionally partitioned by asset
        self._append_partitions = ["asset", "date"] if is_hft else ["date"]
        if partitions is not None:
            partitions = [k for k in partitions if k not in self._append_partitions]
            partitions = [*partitions, *self._append_partitions]
        else:
            exclude_partitions = ["this", "depend"]
            partitions = [k for k in self.fn_params_sig if k not in self._append_partitions and k not in exclude_partitions]
            partitions = [*partitions, *self._append_partitions]
        self.partitions = partitions
        self._type_asset = "asset" in self.fn_params_sig
        mod = inspect.getmodule(fn)
        self._tb = tb
        self.tb = tb if tb else DEFAULT_DS_PATH / mod.__name__ / f"{self.data_name}"
        self.save_path = tb_path(self.tb)
        self.constraints = dict()
        for k in self.partitions[:-len(self._append_partitions)]:
            if k in self.fn_params:
                v = self.fn_params[k]
                if isinstance(v, (list, tuple)) and not isinstance(v, str):
                    v = sorted(v)
                self.constraints[k] = v
                self.save_path = self.save_path / f"{k}={v}"

        if "09:30:00" < update_time < "15:00:00":
            # Intraday update times are treated as real-time
            update_time = ""
        # Adjust update_time according to the underlying dependencies
        if self._depends:
            has_rt = any([not ds.update_time for ds in self._depends])  # a real-time dependency exists
            dep_uts = [ds.update_time for ds in self._depends if ds.update_time]
            max_ut = max(dep_uts) if dep_uts else ""

            if update_time:
                if max_ut:
                    if not has_rt:
                        update_time = max(max_ut, update_time)
                    else:
                        # A real-time dependency exists
                        if max_ut >= "15:00:00":
                            update_time = max(max_ut, update_time)
                        else:
                            # Real-time dependencies exist and none are post-market
                            if update_time <= "09:30:00":
                                # Repair a pre-market update time
                                update_time = ""
                else:
                    # All dependencies are real-time
                    if not has_rt:
                        warnings.warn(f"{self.data_name}:{self.save_path} failed to infer the update time", UserWarning)
                        sys.exit()
                    else:
                        if update_time <= "09:30:00":
                            # Repair a pre-market update time
                            update_time = ""
            else:
                # The top level is real-time data:
                # needs repair when a post-market dependency exists
                if max_ut >= "15:00:00":
                    # Post-market dependency: repair
                    update_time = max_ut

        self.update_time = update_time
        self.fn = ygo.delay(self.fn)(this=self)
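    # Editor's note (illustrative, not part of the released file): the window
    # rounding above relies on pandas.Timedelta. For example,
    # pd.Timedelta("90m").days == 0, so the window is bumped to "1d", while
    # pd.Timedelta("3d12h") has days == 3 and seconds > 0, so _days becomes 4.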

    def _update_depends(self):
        # Rebind same-named parameters onto each dependency dataset
        new_deps = list()
        for dep in self._depends:
            new_dep = dep(**self.fn_params).alias(dep._name)
            new_deps.append(new_dep)
        self._depends = new_deps

    def is_empty(self, path) -> bool:
        return not any(path.rglob("*.parquet"))

    def __call__(self, *fn_args, **fn_kwargs):
        """Binding arguments also propagates same-named parameters to the underlying dependency datasets."""
        if "data_name" in fn_kwargs:
            data_name = fn_kwargs.pop("data_name")
        else:
            data_name = self.data_name
        window = fn_kwargs.get("window", self._window)
        fn = ygo.delay(self.fn)(*fn_args, **fn_kwargs)
        ds = Dataset(*self._depends,
                     fn=fn,
                     tb=self._tb,
                     partitions=self.partitions,
                     update_time=self.update_time,
                     is_hft=self._is_hft,
                     window=window,
                     data_name=data_name,
                     frame=self._frame + 1)
        return ds

    def alias(self, new_name: str):
        self._name = new_name
        return self

    def get_value(self, date, eager: bool = True, **constraints):
        """
        Fetch values for a date. Look-ahead safety is not guaranteed.

        Parameters
        ----------
        date: str
            Date to fetch values for.
        eager: bool
            Collect to a DataFrame (True) or return a LazyFrame (False).
        constraints: dict
            Filter conditions applied when fetching.

        Returns
        -------

        """
        logger = logair.get_logger(f"{__name__}.{self.__class__.__name__}")
        _constraints = {k: v for k, v in constraints.items() if k in self.partitions}
        _limits = {k: v for k, v in constraints.items() if k not in self.partitions}
        search_path = self.save_path
        for k, v in _constraints.items():
            if isinstance(v, (list, tuple)) and not isinstance(v, str):
                v = sorted(v)
            search_path = search_path / f"{k}={v}"
        search_path = search_path / f"date={date}"

        # Clean up empty files
        for file_path in search_path.rglob("*.parquet"):
            if file_path.stat().st_size == 0:
                logger.warning(f"{file_path}: Deleting empty file.")
                file_path.unlink()

        if not self.is_empty(search_path):
            lf = scan(search_path).cast({"date": pl.Utf8})
            try:
                schema = lf.collect_schema()
            except Exception:
                logger.warning(f"{search_path}: Failed to collect schema.")
                # Remove the corrupt directory and retry
                shutil.rmtree(search_path)
                return self.get_value(date=date, eager=eager, **constraints)
            _limits = {k: v for k, v in constraints.items() if schema.get(k) is not None}
            lf = lf.filter(date=date, **_limits)
            if not eager:
                return lf
            data = lf.collect()
            if not data.is_empty():
                return data
        fn = self.fn
        save_path = self.save_path
        if self._is_depend:
            fn = partial(fn, depend=self.get_dependsPIT(date, days=self._days))
        else:
            fn = partial(fn, date=date)
        if self._type_asset:
            if "asset" in _constraints:
                fn = ygo.delay(fn)(asset=_constraints["asset"])
        if len(self.constraints) < len(self.partitions) - len(self._append_partitions):
            # Partition fields not fixed in the Dataset definition must be
            # supplied to get_value
            params = dict()
            for k in self.partitions[:-len(self._append_partitions)]:
                if k not in self.constraints:
                    v = constraints[k]
                    params[k] = v
                    save_path = save_path / f"{k}={v}"
            fn = ygo.delay(fn)(**params)

        today = xcals.today()
        now = xcals.now()
        if (date > today) or (date == today and now < self.update_time):
            logger.warning(f"{self.tb}: {date} is not ready, waiting for {self.update_time}")
            return
        complete_data(fn, date, save_path, self._append_partitions)

        lf = scan(search_path).cast({"date": pl.Utf8})
        schema = lf.collect_schema()
        _limits = {k: v for k, v in constraints.items() if schema.get(k) is not None}
        lf = lf.filter(date=date, **_limits)
        if not eager:
            return lf
        return lf.collect()
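    # Editor's note (illustrative): get_value first serves the request from the
    # hive partitions on disk; on a miss it rebinds fn (depend/date, plus any
    # partition parameters supplied via **constraints), materializes the day
    # through complete_data, and rescans the partition it just wrote.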

    def get_pit(self, date: str, query_time: str, eager: bool = True, **constraints):
        """Fetch a point-in-time value: if query_time is earlier than update_time, return the previous trading day's value."""
        if not self.update_time:
            return self.get_value(date, eager=eager, **constraints)
        val_date = date
        if query_time < self.update_time:
            val_date = xcals.shift_tradeday(date, -1)
        return self.get_value(val_date, eager=eager, **constraints).with_columns(date=pl.lit(date))
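    # Editor's note (illustrative): with update_time == "15:00:00", calling
    # get_pit("2025-10-27", query_time="09:00:00") reads the previous trading
    # day's partition and relabels its date column as 2025-10-27.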

    def get_history(self,
                    dateList: list[str],
                    n_jobs: int = 11,
                    backend: Literal["threading", "multiprocessing", "loky"] = "loky",
                    eager: bool = True,
                    rep_asset: str = "000001",  # representative asset, defaults to 000001
                    **constraints):
        """Fetch historical values. Look-ahead safety is not guaranteed."""
        _constraints = {k: v for k, v in constraints.items() if k in self.partitions}
        search_path = self.save_path
        for k, v in _constraints.items():
            if isinstance(v, (list, tuple)) and not isinstance(v, str):
                v = sorted(v)
            search_path = search_path / f"{k}={v}"
        if self.is_empty(search_path):
            # All dates need to be completed
            missing_dates = dateList
        else:
            if not self._type_asset:
                _search_path = self.save_path
                for k, v in _constraints.items():
                    if k != "asset":
                        _search_path = _search_path / f"{k}={v}"
                    else:
                        _search_path = _search_path / f"asset={rep_asset}"
                hive_info = parse_hive_partition_structure(_search_path)
            else:
                hive_info = parse_hive_partition_structure(search_path)
            exist_dates = hive_info["date"].to_list()
            missing_dates = set(dateList).difference(set(exist_dates))
            missing_dates = sorted(list(missing_dates))
        if missing_dates:
            # First complete the dependencies one by one
            _end_date = max(missing_dates)
            _beg_date = min(missing_dates)
            if self._days > 1:
                _beg_date = xcals.shift_tradeday(_beg_date, -(self._days - 1))
            _depend_dates = xcals.get_tradingdays(_beg_date, _end_date)
            for depend in self._depends:
                depend.get_history(_depend_dates, eager=False)
            fn = self.fn
            save_path = self.save_path

            if self._type_asset:
                if "asset" in _constraints:
                    fn = ygo.delay(fn)(asset=_constraints["asset"])

            if len(self.constraints) < len(self.partitions) - len(self._append_partitions):
                params = dict()
                for k in self.partitions[:-len(self._append_partitions)]:
                    if k not in self.constraints:
                        v = constraints[k]
                        params[k] = v
                        save_path = save_path / f"{k}={v}"
                fn = ygo.delay(fn)(**params)

            with ygo.pool(n_jobs=n_jobs, backend=backend) as go:
                info_path = self.save_path
                try:
                    info_path = info_path.relative_to(DB_PATH)
                except Exception:
                    pass
                if self._is_depend:
                    with ygo.pool(n_jobs=n_jobs, show_progress=False) as _go:
                        for date in missing_dates:
                            _go.submit(self.get_dependsPIT, job_name="preparing depend")(date=date, days=self._days)
                        for (date, depend) in zip(missing_dates, _go.do()):
                            # Bind per-date arguments on a fresh partial instead
                            # of rebinding fn itself, which would stack wrappers
                            job_fn = partial(fn, depend=depend)
                            go.submit(complete_data,
                                      job_name="Completing",
                                      postfix=info_path,
                                      leave=False)(fn=job_fn,
                                                   date=date,
                                                   save_path=save_path,
                                                   partitions=self._append_partitions)
                else:
                    for date in missing_dates:
                        job_fn = partial(fn, date=date)
                        go.submit(complete_data,
                                  job_name="Completing",
                                  postfix=info_path,
                                  leave=False)(fn=job_fn,
                                               date=date,
                                               save_path=save_path,
                                               partitions=self._append_partitions)
                go.do()
        data = scan(search_path).cast({"date": pl.Utf8}).filter(pl.col("date").is_in(dateList), **constraints)
        data = data.sort("date")
        if eager:
            return data.collect()
        return data

    def get_dependsPIT(self, date: str, days: int) -> pl.LazyFrame | None:
        """Fetch the dependency datasets as of the given date."""
        if not self._depends:
            return None
        end_date = date
        beg_date = date
        if days > 1:
            beg_date = xcals.shift_tradeday(beg_date, -(days - 1))
        params = {
            "ds_conf": dict(depend=self._depends),
            "beg_date": beg_date,
            "end_date": end_date,
            "times": [self.update_time, ],
            "show_progress": False,
            "eager": False,
            "n_jobs": 1,
            "process_time": False,  # do not reprocess time
        }
        res = load_ds(**params)
        return res["depend"]


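# --- Editor's illustrative sketch (not part of the released file). ---
# Defining and reading a dataset, assuming a hypothetical `fetch_close` source.
# Everything except the Dataset/get_value/get_pit API shown above is invented.
def _demo_dataset_usage():
    def fetch_close(date: str) -> pl.DataFrame:
        # Hypothetical source: one row per asset for the requested date
        return pl.DataFrame({"asset": ["000001"], "close": [10.5]})

    ds_close = Dataset(fn=fetch_close,
                       update_time="15:00:00",  # post-market dataset
                       partitions=[],           # no extra partition columns
                       data_name="close")
    df = ds_close.get_value("2025-10-27")                        # same-day value
    pit = ds_close.get_pit("2025-10-27", query_time="09:00:00")  # falls back a day
    return df, pit

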
def loader(data_name: str,
           ds: Dataset,
           date_list: list[str],
           prev_date_list: list[str],
           prev_date_mapping: dict[str, str],
           time: str,
           process_time: bool,
           **constraints) -> tuple[str, pl.LazyFrame]:
    """
    Parameters
    ----------
    data_name
    ds
    date_list
    prev_date_list
    prev_date_mapping
    time
    process_time: bool
        Whether to adjust the source data's time column according to the
        `time` argument. Two scenarios:
        Scenario 1: dependency factors are left untouched; whatever the
        underlying data holds is returned.
        Scenario 2: zoo.load uses it to load data at different intraday
        times for testing, so the time should be adjusted.
    constraints

    Returns
    -------

    """
    if time:
        if time < ds.update_time:
            if len(prev_date_list) > 1:
                lf = ds.get_history(prev_date_list, eager=False, **constraints)
            else:
                lf = ds.get_value(prev_date_list[0], eager=False, **constraints)
        else:
            if len(date_list) > 1:
                lf = ds.get_history(date_list, eager=False, **constraints)
            else:
                lf = ds.get_value(date_list[0], eager=False, **constraints)
    else:
        if ds.update_time > "09:30:00":
            # Post-market factor: take the previous day's value
            if len(prev_date_list) > 1:
                lf = ds.get_history(prev_date_list, eager=False, **constraints)
            else:
                lf = ds.get_value(prev_date_list[0], eager=False, **constraints)
        else:
            if len(date_list) > 1:
                lf = ds.get_history(date_list, eager=False, **constraints)
            else:
                lf = ds.get_value(date_list[0], eager=False, **constraints)

    schema = lf.collect_schema()
    include_time = schema.get("time") is not None
    if process_time and time:
        if include_time:
            lf = lf.filter(time=time)
        else:
            lf = lf.with_columns(time=pl.lit(time))
        if time < ds.update_time:
            # Relabel previous-day rows with the requested dates
            lf = lf.with_columns(date=pl.col("date").replace(prev_date_mapping))
    keep = {"date", "time", "asset"}
    if ds._name:
        columns = lf.collect_schema().names()
        rename_cols = set(columns).difference(keep)
        if len(rename_cols) > 1:
            lf = lf.rename({k: f"{ds._name}.{k}" for k in rename_cols})
        else:
            lf = lf.rename({k: ds._name for k in rename_cols})
    return data_name, lf


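# --- Editor's illustrative sketch (not part of the released file). ---
# When a dataset has an alias, loader prefixes its value columns with
# "<alias>." (or renames a lone value column to the alias outright), leaving
# the index columns date/time/asset untouched. A pure-polars reproduction:
def _demo_alias_rename() -> pl.LazyFrame:
    lf = pl.LazyFrame({"date": ["2025-10-27"], "asset": ["000001"],
                       "open": [10.0], "close": [10.5]})
    keep = {"date", "time", "asset"}
    rename_cols = set(lf.collect_schema().names()).difference(keep)
    # Two value columns -> both get the "px." prefix
    return lf.rename({k: f"px.{k}" for k in rename_cols})

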
def load_ds(ds_conf: dict[str, list[Dataset]],
            beg_date: str,
            end_date: str,
            times: list[str],
            n_jobs: int = 11,
            backend: Literal["threading", "multiprocessing", "loky"] = "loky",
            show_progress: bool = True,
            eager: bool = False,
            process_time: bool = True,
            **constraints) -> dict[str, pl.DataFrame | pl.LazyFrame]:
    """
    Load datasets.

    Parameters
    ----------
    ds_conf: dict[str, list[Dataset]]
        Dataset configuration: key is the data_name, value is a list[Dataset].
    beg_date: str
        Start date.
    end_date: str
        End date.
    times: list[str]
        Query times.
    n_jobs: int
        Degree of concurrency.
    backend: str
    show_progress: bool
    eager: bool
        Whether to return DataFrames:
        - True: return DataFrame
        - False: return LazyFrame
    process_time: bool
        Whether to adjust the source data's time column according to the
        `time` argument; see loader for the two scenarios it covers.
    constraints
        Filter conditions, e.g. asset='000001'.

    Returns
    -------
    dict[str, polars.DataFrame | polars.LazyFrame]
        - key: data_name
        - value: polars.DataFrame or polars.LazyFrame

    """
    if beg_date > end_date:
        raise ValueError("beg_date must not be later than end_date")
    date_list = xcals.get_tradingdays(beg_date, end_date)
    beg_date, end_date = date_list[0], date_list[-1]
    prev_date_list = xcals.get_tradingdays(xcals.shift_tradeday(beg_date, -1),
                                           xcals.shift_tradeday(end_date, -1))
    prev_date_mapping = {prev_date: date_list[i] for i, prev_date in enumerate(prev_date_list)}
    results = defaultdict(list)
    index = ("date", "time", "asset")
    _index = ("date", "asset")
    with ygo.pool(n_jobs=n_jobs,
                  backend=backend,
                  show_progress=show_progress) as go:
        for data_name, ds_list in ds_conf.items():
            for ds in ds_list:
                _data_name = f"{data_name}:{ds.tb}"
                if ds._name:
                    _data_name += f".alias({ds._name})"
                for time in times:
                    go.submit(loader,
                              job_name="Loading",
                              postfix=data_name)(data_name=_data_name,
                                                 ds=ds,
                                                 date_list=date_list,
                                                 prev_date_list=prev_date_list,
                                                 prev_date_mapping=prev_date_mapping,
                                                 time=time,
                                                 process_time=process_time,
                                                 **constraints)
        for name, lf in go.do():
            results[name].append(lf)
    # Split the concatenated results by whether they carry a time column
    _LFs_with_time = {}
    _LFs_without_time = {}
    for name, lfList in results.items():
        lf = pl.concat(lfList)
        if "time" not in lf.collect_schema().names():
            _LFs_without_time[name] = lf
        else:
            _LFs_with_time[name] = lf
    # Regroup by data_name (the part before the first ":")
    LFs_with_time = defaultdict(list)
    LFs_without_time = defaultdict(list)
    for name, lf in _LFs_with_time.items():
        dn, _ = name.split(":", 1)
        LFs_with_time[dn].append(lf)
    for name, lf in _LFs_without_time.items():
        dn, _ = name.split(":", 1)
        LFs_without_time[dn].append(lf)
    LFs_with_time = {
        name: (pl.concat(lfList, how="align")
               .sort(index)
               .select(*index,
                       cs.exclude(index))
               )
        for name, lfList in LFs_with_time.items()}
    LFs_without_time = {
        name: (pl.concat(lfList, how="align")
               .sort(_index)
               .select(*_index,
                       cs.exclude(_index))
               )
        for name, lfList in LFs_without_time.items()}
    dns = list(LFs_with_time.keys()) if LFs_with_time else list(LFs_without_time.keys())
    LFs = dict()
    for dn in dns:
        _lf_with_time = LFs_with_time.get(dn)
        _lf_without_time = LFs_without_time.get(dn)
        if _lf_with_time is not None:
            LFs[dn] = _lf_with_time
            if _lf_without_time is not None:
                LFs[dn] = LFs[dn].join(_lf_without_time, on=["date", "asset"], how="left")
        else:
            LFs[dn] = _lf_without_time
    if not eager:
        return LFs
    return {
        name: lf.collect()
        for name, lf in LFs.items()
    }


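# --- Editor's illustrative sketch (not part of the released file). ---
# Loading one bundle of datasets over a date range at a single query time.
# The fetch_close source is hypothetical; load_ds keys the result by the
# ds_conf key ("px" here):
def _demo_load_ds():
    def fetch_close(date: str) -> pl.DataFrame:
        return pl.DataFrame({"asset": ["000001"], "close": [10.5]})

    ds_close = Dataset(fn=fetch_close, update_time="15:00:00",
                       partitions=[], data_name="close")
    out = load_ds(ds_conf={"px": [ds_close]},
                  beg_date="2025-10-20",
                  end_date="2025-10-27",
                  times=["09:30:00"],
                  eager=True)
    return out["px"]

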
class DataLoader:

    def __init__(self, name: str):
        self._name = name
        self._index: tuple[str, ...] = ("date", "time", "asset")
        self._df: pl.LazyFrame | pl.DataFrame | None = None

    def get(self,
            ds_list: list[Dataset],
            beg_date: str,
            end_date: str,
            times: list[str],
            eager: bool = False,
            n_jobs: int = -1,
            backend: Literal["threading", "multiprocessing", "loky"] = "loky",
            **constraints):
        """
        Load datasets into this loader.

        Parameters
        ----------
        ds_list: list[Dataset]
        beg_date: str
        end_date: str
        times: list[str]
            Times to load.
        eager: bool
        n_jobs: int
        backend: str
        constraints

        Returns
        -------

        """
        lf = load_ds(ds_conf={self._name: ds_list},
                     beg_date=beg_date,
                     end_date=end_date,
                     n_jobs=n_jobs,
                     backend=backend,
                     times=times,
                     eager=eager,
                     process_time=True,
                     **constraints)
        self._df = lf[self._name]

    @property
    def name(self) -> str:
        return self._name

    @property
    def data(self) -> pl.DataFrame | None:
        """Return the full data, collecting it on first access."""
        if isinstance(self._df, pl.LazyFrame):
            self._df = self._df.collect()
        return self._df

    def add_data(self, df: pl.DataFrame | pl.LazyFrame):
        """Append a dataframe; the index columns stay the original _df index."""
        if isinstance(df, pl.LazyFrame):
            df = df.collect()
        self._df = pl.concat([self._df, df], how="align").sort(self._index)
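

# --- Editor's illustrative sketch (not part of the released file). ---
# DataLoader bundles load_ds results under one name: get() populates the frame
# and .data collects it lazily on first access. The dataset is hypothetical.
def _demo_dataloader() -> pl.DataFrame:
    def fetch_close(date: str) -> pl.DataFrame:
        return pl.DataFrame({"asset": ["000001"], "close": [10.5]})

    dl = DataLoader("px")
    dl.get(ds_list=[Dataset(fn=fetch_close, update_time="15:00:00",
                            partitions=[], data_name="close")],
           beg_date="2025-10-20",
           end_date="2025-10-27",
           times=["09:30:00"])
    return dl.data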