litequant 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litequant/LiteQuantClient.py +884 -0
- litequant/ParquetManager.py +1379 -0
- litequant/__init__.py +54 -0
- litequant/exceptions.py +168 -0
- litequant/log.py +24 -0
- litequant-3.0.0.dist-info/LICENSE +21 -0
- litequant-3.0.0.dist-info/METADATA +139 -0
- litequant-3.0.0.dist-info/RECORD +10 -0
- litequant-3.0.0.dist-info/WHEEL +5 -0
- litequant-3.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,884 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
LiteQuant 3.0 Client - 独立简洁版
|
|
4
|
+
|
|
5
|
+
完整独立实现,自动管理 Data Ticket 生命周期:
|
|
6
|
+
- 获取/续租/释放 Data Ticket
|
|
7
|
+
- 直连 Redis 下载数据
|
|
8
|
+
- 本地 Parquet 缓存管理
|
|
9
|
+
- 一键更新和读取数据
|
|
10
|
+
|
|
11
|
+
使用示例:
|
|
12
|
+
import litequant as lq
|
|
13
|
+
|
|
14
|
+
client = lq.LiteQuantClient(
|
|
15
|
+
api_token='your_api_token',
|
|
16
|
+
save_path='D:/LiteQuant/'
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
client.UpdateAllCategory(update_method='full')
|
|
20
|
+
df = client.GetCategory("cn_stock_pivot#open")
|
|
21
|
+
|
|
22
|
+
# 使用 with 语句(自动关闭)
|
|
23
|
+
with lq.LiteQuantClient(api_token='xxx', save_path='D:/data/') as client:
|
|
24
|
+
df = client.GetCategory("cn_stock_pivot#open")
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import hashlib
|
|
28
|
+
import io
|
|
29
|
+
import json
|
|
30
|
+
import logging
|
|
31
|
+
import os
|
|
32
|
+
import time
|
|
33
|
+
import threading
|
|
34
|
+
from typing import Any, Dict, List, Optional, Literal
|
|
35
|
+
from dataclasses import dataclass
|
|
36
|
+
from datetime import datetime, timedelta
|
|
37
|
+
|
|
38
|
+
import pandas as pd
|
|
39
|
+
import redis
|
|
40
|
+
import requests
|
|
41
|
+
from tqdm import tqdm
|
|
42
|
+
|
|
43
|
+
from .ParquetManager import ParquetDataManager
|
|
44
|
+
from .exceptions import (
|
|
45
|
+
LiteQuantError,
|
|
46
|
+
AuthError,
|
|
47
|
+
CategoryNotFoundError,
|
|
48
|
+
InvalidCategoryError,
|
|
49
|
+
RemoteDataError,
|
|
50
|
+
SerializationError,
|
|
51
|
+
MetaError,
|
|
52
|
+
ProtocolError,
|
|
53
|
+
TicketExpiredError,
|
|
54
|
+
QuotaExceededError,
|
|
55
|
+
PermissionDeniedError,
|
|
56
|
+
NetworkError,
|
|
57
|
+
DataConnectionError,
|
|
58
|
+
ServiceUnavailableError,
|
|
59
|
+
SyncError,
|
|
60
|
+
APIError,
|
|
61
|
+
)
|
|
62
|
+
from .log import GetLogger
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ============ 数据模型 ============
|
|
66
|
+
|
|
67
|
+
@dataclass(repr=False)
class DataTicket:
    """Short-lived credentials and access scope returned by the ticket API.

    The generated dataclass ``repr`` is disabled and replaced by a fixed
    placeholder so the Redis password and ticket id cannot end up in logs
    or tracebacks.
    """

    ticket_id: str
    session_id: str
    redis_host: str
    redis_port: int
    redis_password: str
    # Naive local timestamp; it is compared against datetime.now() below.
    expires_at: datetime
    datasets: List[str]
    key_patterns: List[str]
    redis_username: Optional[str] = None
    redis_ssl: bool = False

    @property
    def is_expired(self) -> bool:
        """Return True once the local clock has reached ``expires_at``."""
        return datetime.now() >= self.expires_at

    @property
    def ttl_seconds(self) -> float:
        """Return the seconds left before expiry, never below ``0.0``."""
        remaining = (self.expires_at - datetime.now()).total_seconds()
        # Clamp with a float literal so the declared return type always holds
        # (the previous ``max(0, ...)`` handed back an int for expired tickets).
        return max(0.0, remaining)

    def __repr__(self) -> str:
        return "DataTicket(<hidden>)"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# ============ 核心客户端 ============
|
|
95
|
+
|
|
96
|
+
class LiteQuantClient:
|
|
97
|
+
"""
|
|
98
|
+
LiteQuant 3.0 简洁版客户端
|
|
99
|
+
|
|
100
|
+
独立完整实现,自动管理所有底层连接和数据缓存。
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
api_token: API Token(认证凭据)
|
|
104
|
+
save_path: 本地数据存储路径
|
|
105
|
+
api_url: API 地址(可选,默认从环境变量 LITEQUANT_API_URL 读取)
|
|
106
|
+
log_level: 日志级别
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
# 常量配置
|
|
110
|
+
DEFAULT_API_URL = os.environ.get('LITEQUANT_API_URL', 'https://www.litequant.pro')
|
|
111
|
+
PROTOCOL_VERSION = 1
|
|
112
|
+
DATA_FORMAT = "parquet"
|
|
113
|
+
|
|
114
|
+
# Ticket 配置
|
|
115
|
+
HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
|
|
116
|
+
TICKET_REFRESH_THRESHOLD = 30 # 过期前30秒续租
|
|
117
|
+
|
|
118
|
+
# Redis Key 前缀
|
|
119
|
+
META_CATEGORIES_KEY = "litequant:meta:categories"
|
|
120
|
+
META_CATEGORY_KEY_PREFIX = "litequant:meta:category:"
|
|
121
|
+
META_PARTITION_KEY_PREFIX = "litequant:meta:partition:"
|
|
122
|
+
|
|
123
|
+
PUBLIC_ERROR_MESSAGES = {
|
|
124
|
+
"AUTH_INVALID": "API 凭证无效或已过期,请检查后重试",
|
|
125
|
+
"ACCOUNT_UNAVAILABLE": "账号不可用,请联系管理员处理",
|
|
126
|
+
"SUBSCRIPTION_UNAVAILABLE": "套餐不可用或已过期,请续费后重试",
|
|
127
|
+
"PERMISSION_DENIED": "权限不足:当前账号没有该数据权限",
|
|
128
|
+
"CONNECTION_LIMIT": "连接数已达上限,请关闭其他连接后重试",
|
|
129
|
+
"CONNECTION_INTERRUPTED": "连接中断,请重新初始化客户端",
|
|
130
|
+
"REQUEST_INVALID": "请求参数无效,请检查输入后重试",
|
|
131
|
+
"SERVICE_UNAVAILABLE": "服务暂时不可用,请稍后重试",
|
|
132
|
+
"DATA_UNAVAILABLE": "数据暂时不可用,请稍后重试",
|
|
133
|
+
"DATA_VERIFY_FAILED": "数据校验失败,请重新同步",
|
|
134
|
+
}
|
|
135
|
+
LEGACY_ERROR_CODE_MAP = {
|
|
136
|
+
"Unauthorized": "AUTH_INVALID",
|
|
137
|
+
"UserDisabled": "ACCOUNT_UNAVAILABLE",
|
|
138
|
+
"SubscriptionInvalid": "SUBSCRIPTION_UNAVAILABLE",
|
|
139
|
+
"SubscriptionExpired": "SUBSCRIPTION_UNAVAILABLE",
|
|
140
|
+
"NoDatasetAccess": "PERMISSION_DENIED",
|
|
141
|
+
"ConnectionLimitExceeded": "CONNECTION_LIMIT",
|
|
142
|
+
"InvalidRequest": "REQUEST_INVALID",
|
|
143
|
+
"MissingParameters": "REQUEST_INVALID",
|
|
144
|
+
"TicketNotFound": "CONNECTION_INTERRUPTED",
|
|
145
|
+
"TicketExpired": "CONNECTION_INTERRUPTED",
|
|
146
|
+
"SessionNotFound": "CONNECTION_INTERRUPTED",
|
|
147
|
+
"SessionExpired": "CONNECTION_INTERRUPTED",
|
|
148
|
+
"RenewFailed": "CONNECTION_INTERRUPTED",
|
|
149
|
+
"InternalError": "SERVICE_UNAVAILABLE",
|
|
150
|
+
}
|
|
151
|
+
ERROR_CLASS_MAP = {
|
|
152
|
+
"AUTH_INVALID": AuthError,
|
|
153
|
+
"ACCOUNT_UNAVAILABLE": PermissionDeniedError,
|
|
154
|
+
"SUBSCRIPTION_UNAVAILABLE": PermissionDeniedError,
|
|
155
|
+
"PERMISSION_DENIED": PermissionDeniedError,
|
|
156
|
+
"CONNECTION_LIMIT": QuotaExceededError,
|
|
157
|
+
"CONNECTION_INTERRUPTED": DataConnectionError,
|
|
158
|
+
"REQUEST_INVALID": APIError,
|
|
159
|
+
"SERVICE_UNAVAILABLE": ServiceUnavailableError,
|
|
160
|
+
"DATA_UNAVAILABLE": RemoteDataError,
|
|
161
|
+
"DATA_VERIFY_FAILED": SerializationError,
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
def __init__(
    self,
    api_token: str,
    save_path: str,
    api_url: Optional[str] = None,
    log_level: int = logging.INFO,
    display_errors: bool = True,
):
    """Create the client, open the data connection and start the heartbeat.

    Args:
        api_token: API token sent as a Bearer credential.
        save_path: Local directory for the data; created if missing.
        api_url: API base URL. Falls back to the LITEQUANT_API_URL
            environment variable, then to ``DEFAULT_API_URL``.
        log_level: Level passed to the SDK logger.
        display_errors: When True, user-facing error messages are logged.

    Raises:
        LiteQuantError: A privacy-safe clone of any SDK error raised while
            connecting; any other failure becomes ServiceUnavailableError.
    """
    self.api_token = api_token
    self.api_url = (api_url or os.environ.get('LITEQUANT_API_URL', self.DEFAULT_API_URL)).rstrip("/")
    self.save_path = os.path.abspath(save_path)
    self.display_errors = display_errors

    # Make sure the storage directory exists before anything writes to it.
    os.makedirs(self.save_path, exist_ok=True)

    self.logger = GetLogger("litequant", level=log_level)
    self.logger.info("初始化 LiteQuantClient")

    # Local parquet store
    self.LQ_db = ParquetDataManager(self.save_path, logger=self.logger)

    # Ticket and Redis connection
    self._ticket: Optional[DataTicket] = None
    self._redis_client: Optional[redis.Redis] = None

    # Heartbeat thread
    self._stop_heartbeat = threading.Event()
    self._heartbeat_thread: Optional[threading.Thread] = None

    # Cached remote category list
    self._remote_categories: List[str] = []
    self._remote_categories_set = set()

    try:
        self._init_connection()
        self._start_heartbeat()
    except Exception as e:
        # A ticket may already have been issued when a later step fails.
        # Release it (and any Redis client) so the half-built session does
        # not keep occupying a connection slot on the server.
        self._stop_heartbeat.set()
        if self._redis_client is not None:
            try:
                self._redis_client.close()
            except Exception:
                pass
            self._redis_client = None
        self._disconnect_server()  # no-op when no ticket was obtained
        self._ticket = None

        if isinstance(e, LiteQuantError):
            err = self._public_error_clone(e)
        else:
            err = ServiceUnavailableError()
        self._display_error(err)
        raise err from None

    self.logger.info("LiteQuantClient 初始化完成")
|
|
215
|
+
|
|
216
|
+
# ============ 连接管理 ============
|
|
217
|
+
|
|
218
|
+
def _display_error(self, error: LiteQuantError) -> None:
|
|
219
|
+
"""Print a privacy-safe user message to the configured SDK logger."""
|
|
220
|
+
if self.display_errors:
|
|
221
|
+
self.logger.error(error.user_message)
|
|
222
|
+
|
|
223
|
+
def _public_error_clone(self, error: LiteQuantError) -> LiteQuantError:
|
|
224
|
+
"""Return a fresh privacy-safe exception without traceback/cause state."""
|
|
225
|
+
return error.__class__(
|
|
226
|
+
code=error.code,
|
|
227
|
+
user_message=error.user_message,
|
|
228
|
+
retryable=error.retryable,
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
def _as_litequant_error(self, error: Exception) -> LiteQuantError:
    """Map any exception onto a privacy-safe LiteQuantError.

    SDK errors are cloned; everything else is reported as a generic
    ServiceUnavailableError.
    """
    if not isinstance(error, LiteQuantError):
        return ServiceUnavailableError()
    return self._public_error_clone(error)
|
|
235
|
+
|
|
236
|
+
def _api_post_json(self, endpoint: str, payload: dict, timeout: int) -> dict:
    """POST *payload* as JSON to ``api_url + endpoint`` and parse the reply.

    Raises:
        NetworkError: On a timeout or connection error.
        ServiceUnavailableError: On any other ``requests`` failure.
        LiteQuantError: Whatever ``_parse_api_response`` raises.
    """
    request_kwargs = {
        "headers": {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        },
        "json": payload,
        "timeout": timeout,
    }
    try:
        response = requests.post(f"{self.api_url}{endpoint}", **request_kwargs)
    except (requests.exceptions.Timeout, requests.exceptions.ConnectionError):
        raise NetworkError() from None
    except requests.exceptions.RequestException:
        raise ServiceUnavailableError() from None

    return self._parse_api_response(response)
|
|
250
|
+
|
|
251
|
+
def _parse_api_response(self, response: Any) -> dict:
|
|
252
|
+
status_code = getattr(response, "status_code", 200)
|
|
253
|
+
ok = getattr(response, "ok", 200 <= status_code < 300)
|
|
254
|
+
try:
|
|
255
|
+
data = response.json()
|
|
256
|
+
except ValueError:
|
|
257
|
+
public_code = self._public_code_from_status(status_code)
|
|
258
|
+
raise self._build_public_error(public_code) from None
|
|
259
|
+
|
|
260
|
+
if not isinstance(data, dict):
|
|
261
|
+
raise ServiceUnavailableError(detail="invalid response shape")
|
|
262
|
+
|
|
263
|
+
if ok and data.get("success", True):
|
|
264
|
+
return data
|
|
265
|
+
|
|
266
|
+
public_code = self._normalize_error_code(data.get("error"), status_code)
|
|
267
|
+
raise self._build_public_error(public_code) from None
|
|
268
|
+
|
|
269
|
+
def _normalize_error_code(self, raw_code: Any, status_code: int = None) -> str:
|
|
270
|
+
if raw_code:
|
|
271
|
+
raw_code = str(raw_code)
|
|
272
|
+
public_code = self.LEGACY_ERROR_CODE_MAP.get(raw_code, raw_code)
|
|
273
|
+
if public_code in self.PUBLIC_ERROR_MESSAGES:
|
|
274
|
+
return public_code
|
|
275
|
+
return self._public_code_from_status(status_code)
|
|
276
|
+
|
|
277
|
+
def _public_code_from_status(self, status_code: int = None) -> str:
|
|
278
|
+
if status_code == 401:
|
|
279
|
+
return "AUTH_INVALID"
|
|
280
|
+
if status_code == 403:
|
|
281
|
+
return "PERMISSION_DENIED"
|
|
282
|
+
if status_code == 429:
|
|
283
|
+
return "CONNECTION_LIMIT"
|
|
284
|
+
if status_code == 400:
|
|
285
|
+
return "REQUEST_INVALID"
|
|
286
|
+
if status_code and status_code >= 500:
|
|
287
|
+
return "SERVICE_UNAVAILABLE"
|
|
288
|
+
return "SERVICE_UNAVAILABLE"
|
|
289
|
+
|
|
290
|
+
def _build_public_error(self, public_code: str, detail: Any = None) -> LiteQuantError:
    """Instantiate the exception class registered for *public_code*.

    Unknown codes use ServiceUnavailableError and the SERVICE_UNAVAILABLE
    message. ``detail`` is accepted but not used by this method.
    """
    retryable_codes = {
        "CONNECTION_LIMIT",
        "CONNECTION_INTERRUPTED",
        "SERVICE_UNAVAILABLE",
        "DATA_UNAVAILABLE",
        "DATA_VERIFY_FAILED",
    }
    messages = self.PUBLIC_ERROR_MESSAGES
    error_cls = self.ERROR_CLASS_MAP.get(public_code, ServiceUnavailableError)
    message = messages.get(public_code, messages["SERVICE_UNAVAILABLE"])
    return error_cls(
        code=public_code,
        user_message=message,
        retryable=public_code in retryable_codes,
    )
|
|
303
|
+
|
|
304
|
+
def _init_connection(self):
|
|
305
|
+
"""初始化连接(获取 Ticket + 连接 Redis)"""
|
|
306
|
+
self._fetch_ticket()
|
|
307
|
+
self._connect_redis()
|
|
308
|
+
self._refresh_categories()
|
|
309
|
+
|
|
310
|
+
def _fetch_ticket(self) -> DataTicket:
    """Request a new Data Ticket from the API and store it on the client.

    Returns:
        The freshly created ticket (also kept in ``self._ticket``).

    Raises:
        LiteQuantError: Propagated unchanged from the API call.
        ServiceUnavailableError: When the response lacks required fields
            or holds values of the wrong shape.
    """
    try:
        data = self._api_post_json(
            "/api/v1/data/ticket/",
            {"client_version": "3.0.0"},
            timeout=30,
        )

        redis_info = data["redis"]
        # ``or {}`` also covers an explicit JSON null for "scope".
        scope = data.get("scope") or {}

        ticket = DataTicket(
            ticket_id=data["ticket"],
            session_id=data["session_id"],
            redis_host=redis_info["host"],
            redis_port=redis_info["port"],
            redis_password=redis_info["password"],
            expires_at=datetime.now() + timedelta(seconds=data["expires_in"]),
            datasets=scope.get("datasets", []),
            key_patterns=scope.get("key_patterns", []),
            redis_username=redis_info.get("username"),
            redis_ssl=bool(redis_info.get("ssl", False)),
        )
    except LiteQuantError:
        raise
    except (AttributeError, KeyError, TypeError, ValueError):
        # AttributeError: "redis" or "scope" arrived as a non-dict value,
        # so ``.get`` does not exist on it. Report it like any other
        # malformed response instead of leaking a raw AttributeError.
        raise ServiceUnavailableError() from None

    self._ticket = ticket
    self.logger.debug("连接凭证已获取")
    return ticket
|
|
344
|
+
|
|
345
|
+
def _connect_redis(self):
    """Open a Redis connection with the current ticket's credentials.

    An authentication failure is retried (6 attempts in total, 3 seconds
    apart); any other failure aborts immediately.

    Raises:
        DataConnectionError: No ticket is held, ping fails, authentication
            keeps failing, or any unexpected error occurs.
    """
    if not self._ticket:
        raise DataConnectionError()

    max_attempts = 6
    try:
        ticket = self._ticket
        redis_kwargs = {
            "host": ticket.redis_host,
            "port": ticket.redis_port,
            "password": ticket.redis_password,
            "socket_timeout": 20,
            "socket_connect_timeout": 5,
            "decode_responses": False,
        }
        if ticket.redis_username:
            redis_kwargs["username"] = ticket.redis_username
        else:
            # Aliyun Tair can use password formatted as "username:password".
            redis_kwargs["db"] = 0
        # TLS is a property of the endpoint, not of the auth style, so it
        # is honoured whether or not a username is supplied. Previously
        # the flag was dropped for password-only tickets.
        if ticket.redis_ssl:
            redis_kwargs["ssl"] = True

        for attempt in range(max_attempts):
            self._redis_client = redis.Redis(**redis_kwargs)
            try:
                if not self._redis_client.ping():
                    raise DataConnectionError()
                break
            except redis.AuthenticationError:
                try:
                    self._redis_client.close()
                except Exception:
                    pass
                self._redis_client = None
                # The last attempt re-raises, so the loop can only end
                # via ``break`` or an exception (no for/else needed).
                if attempt >= max_attempts - 1:
                    raise
                self.logger.warning(f"数据连接暂未就绪,正在重试 {attempt + 1}/5")
                time.sleep(3)

        self.logger.info("数据连接已建立")

    except redis.AuthenticationError:
        raise DataConnectionError() from None
    except LiteQuantError as e:
        raise self._public_error_clone(e) from None
    except Exception:
        raise DataConnectionError() from None
|
|
393
|
+
|
|
394
|
+
def _disconnect_server(self):
    """Tell the server that this session is being released (best effort).

    Does nothing when no ticket is held. A failure is never raised: it is
    recorded at debug level only, without any exception detail.
    """
    if not self._ticket:
        return

    try:
        requests.post(
            f"{self.api_url}/api/v1/data/disconnect/",
            headers={
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json"
            },
            json={
                "ticket": self._ticket.ticket_id,
                "session_id": self._ticket.session_id
            },
            timeout=10
        )
    except Exception:
        # Deliberately non-fatal, but leave a trace instead of silently
        # swallowing the failure.
        self.logger.debug("disconnect notification failed; ignoring")
|
|
414
|
+
|
|
415
|
+
def _ensure_connection(self):
|
|
416
|
+
"""确保连接有效(过期时自动刷新)"""
|
|
417
|
+
if not self._ticket or self._ticket.is_expired:
|
|
418
|
+
self.logger.info("连接已过期,正在重新连接...")
|
|
419
|
+
self._fetch_ticket()
|
|
420
|
+
self._connect_redis()
|
|
421
|
+
elif self._ticket.ttl_seconds < self.TICKET_REFRESH_THRESHOLD:
|
|
422
|
+
self._renew_ticket()
|
|
423
|
+
|
|
424
|
+
def _renew_ticket(self):
|
|
425
|
+
"""续租 Ticket"""
|
|
426
|
+
if not self._ticket:
|
|
427
|
+
raise DataConnectionError()
|
|
428
|
+
try:
|
|
429
|
+
data = self._api_post_json(
|
|
430
|
+
"/api/v1/data/heartbeat/",
|
|
431
|
+
{
|
|
432
|
+
"ticket": self._ticket.ticket_id,
|
|
433
|
+
"session_id": self._ticket.session_id,
|
|
434
|
+
},
|
|
435
|
+
timeout=10,
|
|
436
|
+
)
|
|
437
|
+
self._ticket.expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 60))
|
|
438
|
+
self.logger.debug("连接已保持")
|
|
439
|
+
except DataConnectionError:
|
|
440
|
+
self.logger.warning("连接已中断,正在重新连接...")
|
|
441
|
+
self._fetch_ticket()
|
|
442
|
+
self._connect_redis()
|
|
443
|
+
except LiteQuantError:
|
|
444
|
+
raise
|
|
445
|
+
except Exception:
|
|
446
|
+
raise ServiceUnavailableError() from None
|
|
447
|
+
|
|
448
|
+
def _start_heartbeat(self):
|
|
449
|
+
"""启动后台心跳线程"""
|
|
450
|
+
self._stop_heartbeat.clear()
|
|
451
|
+
self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
|
|
452
|
+
self._heartbeat_thread.start()
|
|
453
|
+
self.logger.debug("心跳线程已启动")
|
|
454
|
+
|
|
455
|
+
def _heartbeat_loop(self):
|
|
456
|
+
"""心跳循环"""
|
|
457
|
+
while not self._stop_heartbeat.wait(self.HEARTBEAT_INTERVAL):
|
|
458
|
+
try:
|
|
459
|
+
if self._ticket and self._ticket.ttl_seconds < self.TICKET_REFRESH_THRESHOLD:
|
|
460
|
+
self._renew_ticket()
|
|
461
|
+
except LiteQuantError as e:
|
|
462
|
+
self.logger.warning(e.user_message)
|
|
463
|
+
except Exception:
|
|
464
|
+
self.logger.warning(ServiceUnavailableError().user_message)
|
|
465
|
+
|
|
466
|
+
# ============ 核心用户接口 ============
|
|
467
|
+
|
|
468
|
+
def UpdateAllCategory(self, update_method: Literal['full', 'incremental'] = 'full') -> None:
    """Update every authorised data category (public entry point).

    Delegates to ``_update_all_category``. Any failure is shown through
    ``_display_error`` and re-raised as a privacy-safe error.

    Raises:
        LiteQuantError: A clone of the SDK error that occurred, or
            ServiceUnavailableError for any other exception (this
            includes the ValueError raised for a bad ``update_method``).
    """
    try:
        return self._update_all_category(update_method)
    except Exception as exc:
        if isinstance(exc, LiteQuantError):
            err = self._public_error_clone(exc)
        else:
            err = ServiceUnavailableError()
        self._display_error(err)
        raise err from None
|
|
480
|
+
|
|
481
|
+
def _update_all_category(self, update_method: Literal['full', 'incremental'] = 'full') -> None:
|
|
482
|
+
"""
|
|
483
|
+
更新所有授权的数据类别
|
|
484
|
+
|
|
485
|
+
Args:
|
|
486
|
+
update_method: 更新方式
|
|
487
|
+
- 'full' - 全量更新,扫描所有数据(首次使用或需要完整数据时)
|
|
488
|
+
- 'incremental' - 增量更新,只扫描最近两个月的数据(日常使用)
|
|
489
|
+
"""
|
|
490
|
+
if update_method not in ('full', 'incremental'):
|
|
491
|
+
raise ValueError("update_method must be 'full' or 'incremental'")
|
|
492
|
+
|
|
493
|
+
self.logger.info(f"开始更新数据 [方式: {update_method}]")
|
|
494
|
+
|
|
495
|
+
# 计算增量更新的日期范围(最近两个月)
|
|
496
|
+
incremental_start_date = None
|
|
497
|
+
if update_method == 'incremental':
|
|
498
|
+
incremental_start_date = datetime.now() - timedelta(days=60)
|
|
499
|
+
self.logger.info(f"增量更新起始日期: {incremental_start_date.strftime('%Y-%m-%d')}")
|
|
500
|
+
|
|
501
|
+
# 确保连接有效
|
|
502
|
+
self._ensure_connection()
|
|
503
|
+
|
|
504
|
+
# 刷新类别列表
|
|
505
|
+
self._refresh_categories()
|
|
506
|
+
|
|
507
|
+
if not self._remote_categories:
|
|
508
|
+
self.logger.warning("没有可更新的数据类别")
|
|
509
|
+
return
|
|
510
|
+
|
|
511
|
+
self.logger.info(f"发现 {len(self._remote_categories)} 个类别")
|
|
512
|
+
|
|
513
|
+
# 同步所有类别
|
|
514
|
+
failures = []
|
|
515
|
+
for category in self._remote_categories:
|
|
516
|
+
try:
|
|
517
|
+
self.logger.info(f"正在同步: {category}")
|
|
518
|
+
self._sync_category(category, incremental_start_date=incremental_start_date)
|
|
519
|
+
except Exception as e:
|
|
520
|
+
err = self._as_litequant_error(e)
|
|
521
|
+
failures.append((category, err))
|
|
522
|
+
self.logger.error(f"{category}: {err.user_message}")
|
|
523
|
+
|
|
524
|
+
if failures:
|
|
525
|
+
raise SyncError(user_message=f"部分数据更新失败:{len(failures)} 个数据类别未完成")
|
|
526
|
+
|
|
527
|
+
self.logger.info("数据更新完成")
|
|
528
|
+
|
|
529
|
+
def GetCategory(
    self,
    category: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Read one data category (public entry point).

    Delegates to ``_get_category``. Any failure is shown through
    ``_display_error`` and re-raised as a privacy-safe error.

    Raises:
        LiteQuantError: A clone of the SDK error that occurred, or
            ServiceUnavailableError for any other exception.
    """
    try:
        return self._get_category(category, start_date, end_date)
    except Exception as exc:
        if isinstance(exc, LiteQuantError):
            err = self._public_error_clone(exc)
        else:
            err = ServiceUnavailableError()
        self._display_error(err)
        raise err from None
|
|
546
|
+
|
|
547
|
+
def _get_category(
|
|
548
|
+
self,
|
|
549
|
+
category: str,
|
|
550
|
+
start_date: Optional[str] = None,
|
|
551
|
+
end_date: Optional[str] = None,
|
|
552
|
+
) -> pd.DataFrame:
|
|
553
|
+
"""
|
|
554
|
+
读取数据类别
|
|
555
|
+
|
|
556
|
+
Args:
|
|
557
|
+
category: 数据类别名称,如 "cn_stock_pivot#open"
|
|
558
|
+
start_date: 开始日期(可选)
|
|
559
|
+
end_date: 结束日期(可选)
|
|
560
|
+
|
|
561
|
+
Returns:
|
|
562
|
+
DataFrame
|
|
563
|
+
"""
|
|
564
|
+
self.logger.debug(f"读取数据: {category}")
|
|
565
|
+
|
|
566
|
+
# 检查本地是否存在
|
|
567
|
+
local_categories = self.LQ_db._get_all_categories()
|
|
568
|
+
|
|
569
|
+
if category not in local_categories:
|
|
570
|
+
self.logger.info(f"本地不存在 {category},从远程同步...")
|
|
571
|
+
self._ensure_connection()
|
|
572
|
+
|
|
573
|
+
if category not in self._remote_categories_set:
|
|
574
|
+
self._refresh_categories()
|
|
575
|
+
|
|
576
|
+
if category not in self._remote_categories_set:
|
|
577
|
+
raise CategoryNotFoundError(f"数据类别不存在: {category}")
|
|
578
|
+
|
|
579
|
+
self._sync_category(category)
|
|
580
|
+
|
|
581
|
+
# 读取本地数据
|
|
582
|
+
category_type = self._validate_category_name(category)
|
|
583
|
+
|
|
584
|
+
if category_type == "pivot":
|
|
585
|
+
return self.LQ_db.read_pivot_category(category, start_date, end_date)
|
|
586
|
+
else:
|
|
587
|
+
return self.LQ_db.read_unstack_category(category)
|
|
588
|
+
|
|
589
|
+
def ListCategories(self) -> List[str]:
    """List the remote categories available to this account (public entry point).

    Delegates to ``_list_categories``. Any failure is shown through
    ``_display_error`` and re-raised as a privacy-safe error.
    """
    try:
        return self._list_categories()
    except Exception as exc:
        if isinstance(exc, LiteQuantError):
            err = self._public_error_clone(exc)
        else:
            err = ServiceUnavailableError()
        self._display_error(err)
        raise err from None
|
|
601
|
+
|
|
602
|
+
def _list_categories(self) -> List[str]:
|
|
603
|
+
"""列出所有可用的数据类别(远程)"""
|
|
604
|
+
self._ensure_connection()
|
|
605
|
+
self._refresh_categories()
|
|
606
|
+
return list(self._remote_categories)
|
|
607
|
+
|
|
608
|
+
def ListLocalCategories(self) -> List[str]:
    """Return the categories reported by the local parquet store."""
    local_store = self.LQ_db
    return local_store._get_all_categories()
|
|
611
|
+
|
|
612
|
+
def close(self):
    """Release all resources held by the client; safe to call repeatedly.

    Stops the heartbeat thread (waiting up to 5 seconds for it), closes
    the Redis client, then notifies the server and forgets the ticket.
    """
    self.logger.info("关闭 LiteQuantClient...")

    # Stop the heartbeat
    self._stop_heartbeat.set()
    worker = self._heartbeat_thread
    if worker and worker.is_alive():
        worker.join(timeout=5)

    # Close the Redis connection
    client, self._redis_client = self._redis_client, None
    if client:
        try:
            client.close()
        except Exception:
            pass

    # Notify the server, then drop the ticket so that a second close()
    # (for example __exit__ after an explicit close) does not send
    # another disconnect for a session that is already released.
    if self._ticket:
        self._disconnect_server()
        self._ticket = None

    self.logger.info("LiteQuantClient 已关闭")
|
|
634
|
+
|
|
635
|
+
def __enter__(self):
|
|
636
|
+
return self
|
|
637
|
+
|
|
638
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
639
|
+
self.close()
|
|
640
|
+
return False
|
|
641
|
+
|
|
642
|
+
# ============ 内部方法 ============
|
|
643
|
+
|
|
644
|
+
def _refresh_categories(self):
|
|
645
|
+
"""刷新类别列表,只保留授权的数据集"""
|
|
646
|
+
categories = self._json_get_required(self.META_CATEGORIES_KEY)
|
|
647
|
+
if not isinstance(categories, list):
|
|
648
|
+
raise MetaError("categories 必须是列表")
|
|
649
|
+
|
|
650
|
+
all_categories = [str(x) for x in categories]
|
|
651
|
+
|
|
652
|
+
# 过滤:只保留授权数据集的类别
|
|
653
|
+
if self._ticket and self._ticket.datasets:
|
|
654
|
+
authorized_datasets = set(self._ticket.datasets)
|
|
655
|
+
filtered_categories = [
|
|
656
|
+
cat for cat in all_categories
|
|
657
|
+
if any(cat.startswith(f"{ds}#") for ds in authorized_datasets)
|
|
658
|
+
]
|
|
659
|
+
self.logger.debug("过滤类别: %s -> %s", len(all_categories), len(filtered_categories))
|
|
660
|
+
else:
|
|
661
|
+
filtered_categories = all_categories
|
|
662
|
+
|
|
663
|
+
self._remote_categories = sorted(filtered_categories)
|
|
664
|
+
self._remote_categories_set = set(self._remote_categories)
|
|
665
|
+
self.logger.info(f"加载了 {len(self._remote_categories)} 个类别")
|
|
666
|
+
|
|
667
|
+
def _sync_category(self, category: str, incremental_start_date: datetime = None):
|
|
668
|
+
"""同步单个类别"""
|
|
669
|
+
category_type = self._validate_category_name(category)
|
|
670
|
+
|
|
671
|
+
if category_type == "pivot":
|
|
672
|
+
self._sync_pivot_category(category, incremental_start_date)
|
|
673
|
+
else:
|
|
674
|
+
self._sync_unstack_category(category)
|
|
675
|
+
|
|
676
|
+
def _sync_pivot_category(self, category: str, incremental_start_date: datetime = None):
|
|
677
|
+
"""同步 Pivot 类别"""
|
|
678
|
+
if isinstance(incremental_start_date, str):
|
|
679
|
+
incremental_start_date = datetime.strptime(incremental_start_date, "%Y-%m-%d")
|
|
680
|
+
|
|
681
|
+
meta = self._get_category_meta(category)
|
|
682
|
+
self.LQ_db.set_category_meta(category, meta, merge=True)
|
|
683
|
+
partitions = meta.get("partition_keys", [])
|
|
684
|
+
partition_mode = meta.get("partition_mode", "monthly")
|
|
685
|
+
if partition_mode not in ("daily", "monthly"):
|
|
686
|
+
raise ProtocolError()
|
|
687
|
+
|
|
688
|
+
if not partitions:
|
|
689
|
+
self.logger.info(f"{category}: 没有分区")
|
|
690
|
+
return
|
|
691
|
+
|
|
692
|
+
# 如果指定了增量更新日期,过滤分区
|
|
693
|
+
filtered_partitions = partitions
|
|
694
|
+
if incremental_start_date is not None:
|
|
695
|
+
filtered_partitions = self._filter_partitions_by_date(partitions, incremental_start_date)
|
|
696
|
+
self.logger.info(f"{category}: 全量分区 {len(partitions)} 个,增量分区 {len(filtered_partitions)} 个")
|
|
697
|
+
|
|
698
|
+
for partition in tqdm(filtered_partitions, desc=f"同步 {category}"):
|
|
699
|
+
partition_meta = self._get_partition_meta(category, partition)
|
|
700
|
+
remote_hash = partition_meta.get("hash")
|
|
701
|
+
|
|
702
|
+
# 下载数据
|
|
703
|
+
df = self._download_partition(category, partition, partition_meta)
|
|
704
|
+
self.LQ_db.update_pivot_category(
|
|
705
|
+
df=df,
|
|
706
|
+
category=category,
|
|
707
|
+
strf="%Y-%m-%d",
|
|
708
|
+
partition_mode=partition_mode,
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
# 保存 hash
|
|
712
|
+
if remote_hash:
|
|
713
|
+
self._write_local_hash(category, partition, remote_hash)
|
|
714
|
+
|
|
715
|
+
def _sync_unstack_category(self, category: str):
|
|
716
|
+
"""同步 Unstack 类别(unstack类型通常不分区,直接全量同步)"""
|
|
717
|
+
partition = "full"
|
|
718
|
+
category_meta = self._get_category_meta(category)
|
|
719
|
+
self.LQ_db.set_category_meta(category, category_meta, merge=True)
|
|
720
|
+
partition_meta = self._get_partition_meta(category, partition)
|
|
721
|
+
remote_hash = partition_meta.get("hash")
|
|
722
|
+
duplicate_keys = (
|
|
723
|
+
partition_meta.get("duplicate_keys")
|
|
724
|
+
or category_meta.get("duplicate_keys")
|
|
725
|
+
or partition_meta.get("key_columns")
|
|
726
|
+
or partition_meta.get("dedupe_keys")
|
|
727
|
+
or category_meta.get("key_columns")
|
|
728
|
+
or category_meta.get("dedupe_keys")
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
df = self._download_partition(category, partition, partition_meta)
|
|
732
|
+
if not isinstance(duplicate_keys, (list, tuple, str)):
|
|
733
|
+
duplicate_keys = None
|
|
734
|
+
self.LQ_db.update_unstack_category(category=category, sub_df=df, duplicate_keys=duplicate_keys)
|
|
735
|
+
|
|
736
|
+
if remote_hash:
|
|
737
|
+
self._write_local_hash(category, partition, remote_hash)
|
|
738
|
+
|
|
739
|
+
def _download_partition(self, category: str, partition: str, partition_meta: dict) -> pd.DataFrame:
    """Fetch one partition's payload from Redis and parse it as parquet.

    The Redis key comes from ``partition_meta["data_key"]``. When the
    metadata carries a ``hash``, the payload must match it in the form
    ``"sha256:<hexdigest>"``.

    Raises:
        MetaError: The metadata has no ``data_key``.
        RemoteDataError: Redis holds no value for the key.
        SerializationError: Hash mismatch, or the payload is not parquet.
    """
    data_key = partition_meta.get("data_key")
    if not data_key:
        raise MetaError()

    payload = self._redis_get_with_retry(data_key)
    if payload is None:
        raise RemoteDataError()

    # Integrity check, only when the server published a hash.
    expected_hash = partition_meta.get("hash")
    if expected_hash and expected_hash != "sha256:" + hashlib.sha256(payload).hexdigest():
        raise SerializationError()

    try:
        buffer = io.BytesIO(payload)
        return pd.read_parquet(buffer)
    except Exception:
        raise SerializationError() from None
|
|
762
|
+
|
|
763
|
+
def _redis_get_with_retry(self, key: str, max_retries: int = 2) -> bytes:
|
|
764
|
+
"""带重试的 Redis GET"""
|
|
765
|
+
for attempt in range(max_retries + 1):
|
|
766
|
+
try:
|
|
767
|
+
self._ensure_connection()
|
|
768
|
+
return self._redis_client.get(key)
|
|
769
|
+
except redis.AuthenticationError:
|
|
770
|
+
if attempt < max_retries:
|
|
771
|
+
self._fetch_ticket()
|
|
772
|
+
self._connect_redis()
|
|
773
|
+
else:
|
|
774
|
+
raise DataConnectionError()
|
|
775
|
+
except Exception as e:
|
|
776
|
+
if isinstance(e, LiteQuantError):
|
|
777
|
+
raise self._public_error_clone(e) from None
|
|
778
|
+
if attempt < max_retries:
|
|
779
|
+
time.sleep(0.5)
|
|
780
|
+
else:
|
|
781
|
+
raise RemoteDataError() from None
|
|
782
|
+
|
|
783
|
+
def _json_get(self, key: str):
|
|
784
|
+
"""获取 JSON 数据"""
|
|
785
|
+
data = self._redis_get_with_retry(key)
|
|
786
|
+
if data is None:
|
|
787
|
+
return None
|
|
788
|
+
try:
|
|
789
|
+
if isinstance(data, bytes):
|
|
790
|
+
data = data.decode("utf-8")
|
|
791
|
+
return json.loads(data)
|
|
792
|
+
except Exception:
|
|
793
|
+
raise SerializationError() from None
|
|
794
|
+
|
|
795
|
+
def _json_get_required(self, key: str):
|
|
796
|
+
"""获取必需的 JSON 数据"""
|
|
797
|
+
obj = self._json_get(key)
|
|
798
|
+
if obj is None:
|
|
799
|
+
raise MetaError()
|
|
800
|
+
return obj
|
|
801
|
+
|
|
802
|
+
def _get_category_meta(self, category: str) -> dict:
|
|
803
|
+
"""获取类别元数据"""
|
|
804
|
+
key = f"{self.META_CATEGORY_KEY_PREFIX}{category}"
|
|
805
|
+
meta = self._json_get_required(key)
|
|
806
|
+
self._validate_protocol(meta)
|
|
807
|
+
return meta
|
|
808
|
+
|
|
809
|
+
def _get_partition_meta(self, category: str, partition: str) -> dict:
|
|
810
|
+
"""获取分区元数据"""
|
|
811
|
+
key = f"{self.META_PARTITION_KEY_PREFIX}{category}:{partition}"
|
|
812
|
+
meta = self._json_get_required(key)
|
|
813
|
+
self._validate_protocol(meta)
|
|
814
|
+
return meta
|
|
815
|
+
|
|
816
|
+
def _validate_protocol(self, meta: dict):
|
|
817
|
+
"""验证协议版本"""
|
|
818
|
+
if meta.get("protocol_version") != self.PROTOCOL_VERSION:
|
|
819
|
+
raise ProtocolError()
|
|
820
|
+
if meta.get("data_format") != self.DATA_FORMAT:
|
|
821
|
+
raise ProtocolError()
|
|
822
|
+
|
|
823
|
+
def _validate_category_name(self, category: str) -> str:
|
|
824
|
+
"""验证类别名称"""
|
|
825
|
+
if "pivot#" in category:
|
|
826
|
+
return "pivot"
|
|
827
|
+
if "unstack#" in category:
|
|
828
|
+
return "unstack"
|
|
829
|
+
raise InvalidCategoryError()
|
|
830
|
+
|
|
831
|
+
def _get_hash_path(self, category: str, partition: str) -> str:
|
|
832
|
+
"""获取 hash 文件路径"""
|
|
833
|
+
hash_dir = os.path.join(self.save_path, "_litequant_meta_hash")
|
|
834
|
+
os.makedirs(hash_dir, exist_ok=True)
|
|
835
|
+
safe_name = f"{category}__{partition}".replace("/", "__")
|
|
836
|
+
return os.path.join(hash_dir, f"{safe_name}.txt")
|
|
837
|
+
|
|
838
|
+
def _check_local_hash(self, category: str, partition: str, remote_hash: str) -> bool:
|
|
839
|
+
"""检查本地 hash 是否匹配"""
|
|
840
|
+
if not remote_hash:
|
|
841
|
+
return False
|
|
842
|
+
path = self._get_hash_path(category, partition)
|
|
843
|
+
if not os.path.exists(path):
|
|
844
|
+
return False
|
|
845
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
846
|
+
local_hash = f.read().strip()
|
|
847
|
+
return local_hash == remote_hash
|
|
848
|
+
|
|
849
|
+
def _write_local_hash(self, category: str, partition: str, hash_value: str):
|
|
850
|
+
"""写入本地 hash"""
|
|
851
|
+
path = self._get_hash_path(category, partition)
|
|
852
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
853
|
+
f.write(hash_value)
|
|
854
|
+
|
|
855
|
+
def _filter_partitions_by_date(self, partitions: List[str], start_date: datetime) -> List[str]:
|
|
856
|
+
"""
|
|
857
|
+
根据日期过滤分区
|
|
858
|
+
|
|
859
|
+
支持的分区格式:
|
|
860
|
+
- YYYY-MM-DD (日分区)
|
|
861
|
+
- YYYY-MM (月分区)
|
|
862
|
+
- YYYY (年分区)
|
|
863
|
+
"""
|
|
864
|
+
filtered = []
|
|
865
|
+
start_date_str = start_date.strftime("%Y-%m-%d")
|
|
866
|
+
start_month_str = start_date.strftime("%Y-%m")
|
|
867
|
+
start_year_str = start_date.strftime("%Y")
|
|
868
|
+
|
|
869
|
+
for partition in partitions:
|
|
870
|
+
# 尝试匹配不同格式的分区
|
|
871
|
+
if len(partition) == 10 and partition.count('-') == 2: # YYYY-MM-DD
|
|
872
|
+
if partition >= start_date_str:
|
|
873
|
+
filtered.append(partition)
|
|
874
|
+
elif len(partition) == 7 and partition.count('-') == 1: # YYYY-MM
|
|
875
|
+
if partition >= start_month_str:
|
|
876
|
+
filtered.append(partition)
|
|
877
|
+
elif len(partition) == 4 and partition.isdigit(): # YYYY
|
|
878
|
+
if partition >= start_year_str:
|
|
879
|
+
filtered.append(partition)
|
|
880
|
+
else:
|
|
881
|
+
# 无法识别的格式,默认包含
|
|
882
|
+
filtered.append(partition)
|
|
883
|
+
|
|
884
|
+
return filtered
|