mdbq 4.0.1__py3-none-any.whl → 4.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdbq/mysql/s_query.py CHANGED
@@ -6,11 +6,14 @@ import pandas as pd
6
6
  from decimal import Decimal
7
7
  from contextlib import closing
8
8
  from mdbq.log import mylogger
9
+ import os
10
+ from mdbq.config import config
11
+ from typing import Optional, Dict, List, Set, Tuple, Union, Any, Literal
12
+ from dbutils.pooled_db import PooledDB
13
+ import time
14
+ from functools import wraps
9
15
 
10
16
  warnings.filterwarnings('ignore')
11
- """
12
- 程序专门用来下载数据库数据, 并返回 df, 不做清洗数据操作;
13
- """
14
17
  logger = mylogger.MyLogger(
15
18
  logging_mode='file',
16
19
  log_level='info',
@@ -26,117 +29,622 @@ logger = mylogger.MyLogger(
26
29
 
27
30
  class QueryDatas:
28
31
  """
29
- 数据库查询工具类。
30
- 用于连接MySQL数据库,支持表结构检查、条件查询、数据导出为DataFrame、列名和类型获取等功能。
32
+ 专门用来查询数据库, 不做清洗数据操作。
33
+ 支持表结构检查、条件查询、数据导出为DataFrame、列名和类型获取等。
34
+ 支持分页查询和上下文管理。
31
35
  """
32
36
 
33
- def __init__(self, username: str, password: str, host: str, port: int, charset: str = 'utf8mb4'):
37
+ def __init__(self, username: str, password: str, host: str, port: int, charset: str = 'utf8mb4',
38
+ maxconnections: int = 20, mincached: int = 2, maxcached: int = 5,
39
+ connect_timeout: int = 10, read_timeout: int = 30, write_timeout: int = 30,
40
+ max_retries: int = 3, retry_waiting_time: int = 5, collation: str = 'utf8mb4_0900_ai_ci') -> None:
34
41
  """
35
- 初始化数据库连接配置。
36
- :param username: 数据库用户名
37
- :param password: 数据库密码
38
- :param host: 数据库主机
39
- :param port: 数据库端口
40
- :param charset: 字符集,默认utf8mb4
42
+ 初始化数据库连接配置和连接池。
43
+
44
+ Args:
45
+ username: 数据库用户名
46
+ password: 数据库密码
47
+ host: 数据库主机
48
+ port: 数据库端口
49
+ charset: 字符集,默认utf8mb4
50
+ maxconnections: 最大连接数,默认20
51
+ mincached: 最小缓存连接数,默认2
52
+ maxcached: 最大缓存连接数,默认5
53
+ connect_timeout: 连接超时时间,默认10秒
54
+ read_timeout: 读取超时时间,默认30秒
55
+ write_timeout: 写入超时时间,默认30秒
56
+ max_retries: 最大重试次数,默认3次
57
+ retry_waiting_time: 重试等待时间,默认5秒
58
+ collation: 排序规则,默认utf8mb4_0900_ai_ci
41
59
  """
42
60
  self.username = username
43
61
  self.password = password
44
62
  self.host = host
45
63
  self.port = port
46
- self.config = {
64
+ self.charset = charset
65
+ self.collation = collation
66
+ self.max_retries = max_retries
67
+ self.retry_waiting_time = retry_waiting_time
68
+ self.connect_timeout = connect_timeout
69
+ self.read_timeout = read_timeout
70
+ self.write_timeout = write_timeout
71
+
72
+ # 连接池状态监控
73
+ self._pool_stats = {
74
+ 'last_health_check': None,
75
+ 'health_check_interval': 300, # 5分钟检查一次
76
+ 'consecutive_failures': 0, # 连续失败次数
77
+ 'max_consecutive_failures': 3 # 最大连续失败次数
78
+ }
79
+
80
+ self.base_config = {
47
81
  'host': self.host,
48
82
  'port': int(self.port),
49
83
  'user': self.username,
50
84
  'password': self.password,
51
- 'charset': charset, # utf8mb4 支持存储四字节的UTF-8字符集
85
+ 'charset': charset,
86
+ 'collation': self.collation,
52
87
  'cursorclass': pymysql.cursors.DictCursor,
88
+ 'connect_timeout': connect_timeout,
89
+ 'read_timeout': read_timeout,
90
+ 'write_timeout': write_timeout,
91
+ 'autocommit': True
53
92
  }
93
+
94
+ # 创建连接池
95
+ self.pool = self._create_connection_pool(maxconnections, mincached, maxcached)
54
96
 
55
- def check_condition(self, db_name, table_name, condition, columns='更新时间'):
97
+ def _create_connection_pool(self, maxconnections: int, mincached: int, maxcached: int) -> PooledDB:
56
98
  """
57
- 按指定条件查询数据库表,返回满足条件的指定字段数据。
58
- :param db_name: 数据库名
59
- :param table_name: 表名
60
- :param condition: SQL条件字符串(不含WHERE)
61
- :param columns: 查询字段字符串或以逗号分隔的字段名,默认'更新时间'
62
- :return: 查询结果列表或None
99
+ 创建数据库连接池
100
+
101
+ Args:
102
+ maxconnections: 最大连接数
103
+ mincached: 最小缓存连接数
104
+ maxcached: 最大缓存连接数
105
+
106
+ Returns:
107
+ PooledDB连接池实例
108
+
109
+ Raises:
110
+ ConnectionError: 当连接池创建失败时抛出
63
111
  """
64
- if not self.check_infos(db_name, table_name):
65
- return None
66
- self.config.update({'database': db_name})
112
+ if hasattr(self, 'pool') and self.pool is not None and self._check_pool_health():
113
+ return self.pool
114
+
115
+ self.pool = None
116
+
117
+ # 连接参数 - 这些参数会传递给底层的连接创建函数
118
+ connection_params = {
119
+ 'host': self.host,
120
+ 'port': int(self.port),
121
+ 'user': self.username,
122
+ 'password': self.password,
123
+ 'charset': self.charset,
124
+ 'collation': self.collation,
125
+ 'cursorclass': pymysql.cursors.DictCursor,
126
+ 'connect_timeout': self.connect_timeout,
127
+ 'read_timeout': self.read_timeout,
128
+ 'write_timeout': self.write_timeout,
129
+ 'autocommit': True
130
+ }
131
+
132
+ # 连接池参数
133
+ pool_params = {
134
+ 'creator': pymysql,
135
+ 'maxconnections': maxconnections,
136
+ 'mincached': mincached,
137
+ 'maxcached': maxcached,
138
+ 'blocking': True,
139
+ 'maxusage': 2000, # 每个连接最多使用次数
140
+ 'setsession': [],
141
+ 'ping': 7
142
+ }
143
+
67
144
  try:
68
- with closing(pymysql.connect(**self.config)) as connection:
69
- with closing(connection.cursor()) as cursor:
70
- sql = f"SELECT {columns} FROM `{table_name}` WHERE {condition}"
71
- logger.debug(f"check_condition SQL: {sql}")
72
- cursor.execute(sql)
73
- result = cursor.fetchall()
145
+ # 创建连接池,将连接参数作为kwargs传递
146
+ pool = PooledDB(**pool_params, **connection_params)
147
+ logger.debug('连接池创建成功', {
148
+ '连接池大小': maxconnections,
149
+ '最小缓存': mincached,
150
+ '最大缓存': maxcached,
151
+ '主机': self.host,
152
+ '端口': self.port
153
+ })
154
+ return pool
155
+ except Exception as e:
156
+ self.pool = None
157
+ logger.error('连接池创建失败', {
158
+ '错误': str(e),
159
+ '主机': self.host,
160
+ '端口': self.port
161
+ })
162
+ raise ConnectionError(f'连接池创建失败: {str(e)}')
163
+
164
+ def _check_pool_health(self) -> bool:
165
+ """
166
+ 检查连接池健康状态
167
+
168
+ Returns:
169
+ bool: 连接池是否健康
170
+ """
171
+ if not self.pool:
172
+ return False
173
+
174
+ current_time = time.time()
175
+ # 检查是否需要执行健康检查
176
+ if (self._pool_stats['last_health_check'] is None or
177
+ current_time - self._pool_stats['last_health_check'] > self._pool_stats['health_check_interval']):
178
+
179
+ try:
180
+ # 更新健康检查时间
181
+ self._pool_stats['last_health_check'] = current_time
182
+
183
+ # 检查连接是否可用
184
+ with self.pool.connection() as conn:
185
+ with conn.cursor() as cursor:
186
+ cursor.execute('SELECT 1')
187
+ result = cursor.fetchone()
188
+ if not result or result.get('1') != 1:
189
+ self._pool_stats['consecutive_failures'] += 1
190
+ if self._pool_stats['consecutive_failures'] >= self._pool_stats['max_consecutive_failures']:
191
+ logger.error('连接池健康检查连续失败', {
192
+ '连续失败次数': self._pool_stats['consecutive_failures']
193
+ })
194
+ return False
195
+
196
+ # 重置连续失败计数
197
+ self._pool_stats['consecutive_failures'] = 0
198
+ logger.debug('连接池健康检查通过')
199
+ return True
200
+
201
+ except Exception as e:
202
+ self._pool_stats['consecutive_failures'] += 1
203
+ if self._pool_stats['consecutive_failures'] >= self._pool_stats['max_consecutive_failures']:
204
+ logger.error('连接池健康检查失败', {
205
+ '错误类型': type(e).__name__,
206
+ '错误信息': str(e),
207
+ '连续失败次数': self._pool_stats['consecutive_failures']
208
+ })
209
+ return False
210
+
211
+ return True
212
+
213
+ @staticmethod
214
+ def _execute_with_retry(func):
215
+ """
216
+ 带重试机制的装饰器,用于数据库操作
217
+
218
+ Args:
219
+ func: 被装饰的函数
220
+
221
+ Returns:
222
+ 装饰后的函数
223
+ """
224
+ @wraps(func)
225
+ def wrapper(self, *args, **kwargs):
226
+ last_exception = None
227
+ operation = func.__name__
228
+
229
+ for attempt in range(self.max_retries):
230
+ try:
231
+ result = func(self, *args, **kwargs)
232
+ if attempt > 0:
233
+ logger.info('操作成功(重试后)', {
234
+ '操作': operation,
235
+ '重试次数': attempt + 1
236
+ })
74
237
  return result
238
+ except (pymysql.OperationalError, pymysql.err.MySQLError) as e:
239
+ last_exception = e
240
+ error_details = {
241
+ '操作': operation,
242
+ '错误代码': e.args[0] if e.args else None,
243
+ '错误信息': e.args[1] if len(e.args) > 1 else None,
244
+ '尝试次数': attempt + 1,
245
+ '最大重试次数': self.max_retries
246
+ }
247
+
248
+ if attempt < self.max_retries - 1:
249
+ wait_time = self.retry_waiting_time * (attempt + 1)
250
+ error_details['等待时间'] = wait_time
251
+ logger.warning('数据库操作失败,准备重试', error_details)
252
+ time.sleep(wait_time)
253
+ try:
254
+ self.pool = self._create_connection_pool(
255
+ maxconnections=self.pool._maxconnections,
256
+ mincached=self.pool._mincached,
257
+ maxcached=self.pool._maxcached
258
+ )
259
+ logger.info('成功重新建立数据库连接')
260
+ except Exception as reconnect_error:
261
+ logger.error('重连失败', {'错误': str(reconnect_error)})
262
+ else:
263
+ logger.error('操作最终失败', error_details)
264
+ except Exception as e:
265
+ last_exception = e
266
+ logger.error('发生意外错误', {
267
+ '操作': operation,
268
+ '错误类型': type(e).__name__,
269
+ '错误信息': str(e)
270
+ })
271
+ break
272
+
273
+ raise last_exception if last_exception else Exception('发生未知错误')
274
+ return wrapper
275
+
276
+ # @_execute_with_retry
277
+ def _get_connection(self, db_name: Optional[str] = None) -> pymysql.connections.Connection:
278
+ """
279
+ 从连接池获取数据库连接
280
+
281
+ Args:
282
+ db_name: 可选的数据库名,如果提供则会在连接后选择该数据库
283
+
284
+ Returns:
285
+ 数据库连接对象
286
+
287
+ Raises:
288
+ ConnectionError: 当获取连接失败时抛出
289
+ """
290
+ try:
291
+ # 只在连续失败次数达到阈值时检查健康状态
292
+ if self._pool_stats['consecutive_failures'] >= self._pool_stats['max_consecutive_failures']:
293
+ if not self._check_pool_health():
294
+ logger.warning('连接池不健康,尝试重新创建')
295
+ # 使用默认值重新创建连接池
296
+ self.pool = self._create_connection_pool(10, 2, 5)
297
+ # 重置连续失败计数
298
+ self._pool_stats['consecutive_failures'] = 0
299
+
300
+ conn = self.pool.connection()
301
+ if db_name:
302
+ # 使用原生pymysql连接来选择数据库
303
+ with conn.cursor() as cursor:
304
+ cursor.execute(f"USE `{db_name}`")
305
+ return conn
306
+ except pymysql.OperationalError as e:
307
+ error_code = e.args[0] if e.args else None
308
+ if error_code in (2003, 2006, 2013): # 连接相关错误
309
+ logger.error('数据库连接错误', {
310
+ '错误代码': error_code,
311
+ '错误信息': str(e),
312
+ '数据库': db_name
313
+ })
314
+ # 使用默认值重新创建连接池
315
+ self.pool = self._create_connection_pool(10, 2, 5)
316
+ # 重置连续失败计数
317
+ self._pool_stats['consecutive_failures'] = 0
318
+ raise ConnectionError(f'数据库连接错误: {str(e)}')
319
+ else:
320
+ raise
75
321
  except Exception as e:
76
- logger.error(f"check_condition error: {e}")
77
- return None
322
+ logger.error('从连接池获取数据库连接失败', {
323
+ '错误': str(e),
324
+ '数据库': db_name
325
+ })
326
+ raise ConnectionError(f'连接数据库失败: {str(e)}')
78
327
 
79
- def data_to_df(self, db_name, table_name, start_date, end_date, projection: dict = None, limit: int = None):
328
+ # @_execute_with_retry
329
+ def _execute_query(self, sql: str, params: tuple = None, db_name: str = None) -> Optional[List[Dict[str, Any]]]:
330
+ """
331
+ 执行SQL查询的通用方法。
332
+
333
+ Args:
334
+ sql: SQL查询语句
335
+ params: 查询参数
336
+ db_name: 数据库名
337
+
338
+ Returns:
339
+ 查询结果列表,如果查询失败返回None
80
340
  """
81
- 从数据库表获取数据到DataFrame,支持列筛选、日期范围过滤和行数限制。
82
- :param db_name: 数据库名
83
- :param table_name: 表名
84
- :param start_date: 起始日期(包含)
85
- :param end_date: 结束日期(包含)
86
- :param projection: 列筛选字典,e.g. {'日期': 1, '场景名字': 1}
87
- :param limit: 限制返回的最大行数
88
- :return: 查询结果的DataFrame
89
- """
90
- projection = projection or {}
91
- df = pd.DataFrame()
92
341
  try:
93
- start_date = pd.to_datetime(start_date or '1970-01-01').strftime('%Y-%m-%d')
94
- end_date = pd.to_datetime(end_date or datetime.datetime.today()).strftime('%Y-%m-%d')
342
+ with closing(self._get_connection(db_name)) as connection:
343
+ with closing(connection.cursor()) as cursor:
344
+ cursor.execute(sql, params)
345
+ return cursor.fetchall()
95
346
  except Exception as e:
96
- logger.error(f"日期格式错误: {e}")
97
- return df
347
+ logger.error('执行SQL查询失败', {
348
+ 'SQL': sql,
349
+ '参数': params,
350
+ '数据库': db_name,
351
+ '错误类型': type(e).__name__,
352
+ '错误信息': str(e)
353
+ })
354
+ return None
355
+
356
+ def check_condition(self, db_name: str, table_name: str, condition: str, columns: str = '更新时间') -> Optional[List[Dict[str, Any]]]:
357
+ """
358
+ 按指定条件查询数据库表,返回满足条件的指定字段数据。
359
+
360
+ Args:
361
+ db_name: 数据库名
362
+ table_name: 表名
363
+ condition: SQL条件字符串(不含WHERE)
364
+ columns: 查询字段字符串或以逗号分隔的字段名,默认'更新时间'
365
+
366
+ Returns:
367
+ 查询结果列表,如果查询失败返回None
368
+ """
98
369
  if not self.check_infos(db_name, table_name):
99
- return df
100
- self.config['database'] = db_name
370
+ return None
371
+
372
+ sql = f"SELECT {columns} FROM `{table_name}` WHERE {condition}"
373
+ logger.debug('执行SQL查询', {'库': db_name, '表': table_name, 'SQL': sql})
374
+ return self._execute_query(sql, db_name=db_name)
375
+
376
+ def validate_and_format_date(self, date_str: Optional[str], default_date: str) -> str:
377
+ """
378
+ 验证并格式化日期字符串。
379
+
380
+ Args:
381
+ date_str: 日期字符串,支持多种格式
382
+ default_date: 默认日期,当date_str无效时使用
383
+
384
+ Returns:
385
+ 格式化后的日期字符串 'YYYY-MM-DD'
386
+
387
+ Raises:
388
+ ValueError: 当日期格式无法解析时
389
+ """
390
+ if not date_str:
391
+ return default_date
392
+
393
+ # 记录尝试的日期格式
394
+ attempted_formats = []
101
395
  try:
102
- with closing(pymysql.connect(**self.config)) as connection:
103
- with closing(connection.cursor()) as cursor:
104
- cursor.execute(
105
- """SELECT COLUMN_NAME FROM information_schema.columns WHERE table_schema = %s AND table_name = %s""",
106
- (db_name, table_name)
107
- )
108
- cols_exist = {col['COLUMN_NAME'] for col in cursor.fetchall()} - {'id'}
109
- if projection:
110
- selected_columns = [k for k, v in projection.items() if v and k in cols_exist]
111
- if not selected_columns:
112
- logger.info("Warning: Projection 参数不匹配任何数据库字段")
113
- return df
114
- else:
115
- selected_columns = list(cols_exist)
116
- if not selected_columns:
117
- logger.info("未找到可用字段")
118
- return df
119
- quoted_columns = [f'`{col}`' for col in selected_columns]
120
- base_sql = f"SELECT {', '.join(quoted_columns)} FROM `{db_name}`.`{table_name}`"
121
- params = []
122
- if '日期' in cols_exist:
123
- base_sql += f" WHERE 日期 BETWEEN %s AND %s"
124
- params.extend([start_date, end_date])
125
- if limit is not None and isinstance(limit, int) and limit > 0:
126
- base_sql += f" LIMIT %s"
127
- params.append(limit)
128
- logger.debug(f"data_to_df SQL: {base_sql}, params: {params}")
129
- cursor.execute(base_sql, tuple(params))
130
- result = cursor.fetchall()
131
- if result:
132
- df = pd.DataFrame(result)
133
- for col in df.columns:
134
- if df[col].apply(lambda x: isinstance(x, Decimal)).any():
135
- df[col] = df[col].astype(float)
396
+ # 尝试多种日期格式
397
+ for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%Y%m%d', '%Y.%m.%d']:
398
+ try:
399
+ attempted_formats.append(fmt)
400
+ return pd.to_datetime(date_str, format=fmt).strftime('%Y-%m-%d')
401
+ except ValueError:
402
+ continue
403
+
404
+ # 如果所有格式都失败,使用pandas的自动解析
405
+ attempted_formats.append('auto')
406
+ return pd.to_datetime(date_str).strftime('%Y-%m-%d')
407
+
408
+ except Exception as e:
409
+ logger.warning('日期格式转换失败', {
410
+ '输入日期': date_str,
411
+ '尝试的格式': attempted_formats,
412
+ '错误信息': str(e),
413
+ '使用默认日期': default_date
414
+ })
415
+ return default_date
416
+
417
+ def _validate_date_range(self, start_date: Optional[str], end_date: Optional[str],
418
+ db_name: str, table_name: str) -> Tuple[Optional[str], Optional[str]]:
419
+ """
420
+ 验证并处理日期范围。
421
+
422
+ Args:
423
+ start_date: 开始日期
424
+ end_date: 结束日期
425
+ db_name: 数据库名
426
+ table_name: 表名
427
+
428
+ Returns:
429
+ 处理后的日期范围元组 (start_date, end_date),如果处理失败返回 (None, None)
430
+ """
431
+ try:
432
+ # 如果两个日期都未提供,返回None表示不进行日期过滤
433
+ if start_date is None and end_date is None:
434
+ return None, None
435
+
436
+ # 如果只提供了开始日期,结束日期设为今天
437
+ if start_date is not None and end_date is None:
438
+ end_date = datetime.datetime.today().strftime('%Y-%m-%d')
439
+ logger.debug('未提供结束日期,使用当前日期', {'库': db_name, '表': table_name, '结束日期': end_date})
440
+
441
+ # 如果只提供了结束日期,开始日期设为1970年
442
+ if start_date is None and end_date is not None:
443
+ start_date = '1970-01-01'
444
+ logger.debug('未提供开始日期,使用默认日期', {'库': db_name, '表': table_name, '开始日期': start_date})
445
+
446
+ # 格式化日期
447
+ original_start = start_date
448
+ original_end = end_date
449
+ start_date = self.validate_and_format_date(start_date, '1970-01-01')
450
+ end_date = self.validate_and_format_date(end_date, datetime.datetime.today().strftime('%Y-%m-%d'))
451
+
452
+ # 如果日期格式被修改,记录日志
453
+ if original_start != start_date:
454
+ logger.debug('开始日期格式已调整', {
455
+ '库': db_name,
456
+ '表': table_name,
457
+ '原始日期': original_start,
458
+ '调整后日期': start_date
459
+ })
460
+ if original_end != end_date:
461
+ logger.debug('结束日期格式已调整', {
462
+ '库': db_name,
463
+ '表': table_name,
464
+ '原始日期': original_end,
465
+ '调整后日期': end_date
466
+ })
467
+
468
+ # 检查日期顺序
469
+ start_dt = pd.to_datetime(start_date)
470
+ end_dt = pd.to_datetime(end_date)
471
+ if start_dt > end_dt:
472
+ logger.debug('日期范围调整', {'库': db_name, '表': table_name, '原开始日期': start_date, '原结束日期': end_date})
473
+ start_date, end_date = end_date, start_date
474
+
475
+ # # 只在两个日期都是用户明确提供的情况下检查日期范围
476
+ # if original_start != '1970-01-01' and original_end != datetime.datetime.today().strftime('%Y-%m-%d'):
477
+ # if (end_dt - start_dt).days > 365 * 10:
478
+ # logger.debug('日期范围过大,已限制为10年', {'库': db_name, '表': table_name, '开始日期': start_date, '结束日期': end_date})
479
+ # end_date = (start_dt + pd.Timedelta(days=365*10)).strftime('%Y-%m-%d')
480
+
481
+ return start_date, end_date
482
+
136
483
  except Exception as e:
137
- logger.error(f"data_to_df error: {e}")
484
+ logger.error('日期处理失败', {
485
+ '库': db_name,
486
+ '表': table_name,
487
+ '开始日期': start_date,
488
+ '结束日期': end_date,
489
+ '错误': str(e)
490
+ })
491
+ return None, None
492
+
493
+ def _detect_date_field(self, cols_exist: Set[str], date_field: Optional[str] = None) -> Optional[str]:
494
+ """
495
+ 检测或验证日期字段。
496
+
497
+ Args:
498
+ cols_exist: 存在的列名集合
499
+ date_field: 用户指定的日期字段名
500
+
501
+ Returns:
502
+ 有效的日期字段名,如果未找到则返回None
503
+ """
504
+ if date_field:
505
+ if date_field not in cols_exist:
506
+ logger.debug('指定的日期字段不存在', {
507
+ '指定的日期字段': date_field,
508
+ '可用的列': list(cols_exist)
509
+ })
510
+ return None
511
+ logger.debug('使用指定的日期字段', {'日期字段': date_field})
512
+ return date_field
513
+
514
+ # 自动检测可能的日期字段
515
+ possible_date_fields = {'日期', 'date', 'create_time', 'update_time', 'created_at', 'updated_at', '更新时间', '创建时间'}
516
+ detected_field = next((field for field in possible_date_fields if field in cols_exist), None)
517
+ if detected_field:
518
+ logger.debug('自动检测到日期字段', {
519
+ '检测到的日期字段': detected_field,
520
+ '可用的列': list(cols_exist),
521
+ '尝试匹配的字段': list(possible_date_fields)
522
+ })
523
+ else:
524
+ logger.debug('未检测到日期字段', {
525
+ '可用的列': list(cols_exist),
526
+ '尝试匹配的字段': list(possible_date_fields)
527
+ })
528
+ return detected_field
529
+
530
+ def _get_selected_columns(self, cols_exist: Set[str], projection: Optional[Dict[str, int]] = None) -> List[str]:
531
+ """
532
+ 获取要查询的列名列表。
533
+
534
+ Args:
535
+ cols_exist: 存在的列名集合
536
+ projection: 列筛选字典,key为列名,value为1表示选中
537
+ - 如果为None、空字典{}或空列表[],则返回所有列
538
+ - 如果为字典,则根据value值筛选列
539
+
540
+ Returns:
541
+ 选中的列名列表
542
+ """
543
+ if not cols_exist:
544
+ logger.warning('表没有可用列')
545
+ return []
546
+
547
+ # 如果 projection 为 None、空字典或空列表,返回所有列
548
+ if projection is None or projection == {} or projection == []:
549
+ return list(cols_exist)
550
+
551
+ # 验证列名是否包含特殊字符
552
+ invalid_chars = set('`\'"\\')
553
+ selected_columns = []
554
+ for col in projection:
555
+ if any(char in col for char in invalid_chars):
556
+ logger.warning('列名包含特殊字符,已跳过', {'列名': col})
557
+ continue
558
+ if col in cols_exist and projection[col]:
559
+ selected_columns.append(col)
560
+
561
+ if not selected_columns:
562
+ logger.info('参数不匹配,返回所有列', {'参数': projection})
563
+ return list(cols_exist)
564
+
565
+ return selected_columns
566
+
567
+ def _build_query_sql(self, db_name: str, table_name: str, selected_columns: List[str],
568
+ date_field: Optional[str] = None, start_date: Optional[str] = None,
569
+ end_date: Optional[str] = None, limit: Optional[int] = None) -> Tuple[str, List[Any]]:
570
+ """
571
+ 构建SQL查询语句和参数。
572
+
573
+ Args:
574
+ db_name: 数据库名
575
+ table_name: 表名
576
+ selected_columns: 选中的列名列表
577
+ date_field: 日期字段名
578
+ start_date: 开始日期
579
+ end_date: 结束日期
580
+ limit: 限制返回行数,None表示不限制
581
+
582
+ Returns:
583
+ SQL语句和参数列表的元组
584
+
585
+ Raises:
586
+ ValueError: 当参数无效时
587
+ """
588
+ if not selected_columns:
589
+ raise ValueError("没有可查询的列")
590
+
591
+ # 验证数据库名和表名
592
+ if not db_name or not table_name:
593
+ raise ValueError("数据库名和表名不能为空")
594
+
595
+ # 验证列名
596
+ invalid_chars = set('`\'"\\')
597
+ for col in selected_columns:
598
+ if any(char in col for char in invalid_chars):
599
+ raise ValueError(f"列名包含特殊字符: {col}")
600
+
601
+ # 使用参数化查询防止SQL注入
602
+ quoted_columns = [f'`{col}`' for col in selected_columns]
603
+ base_sql = f"SELECT {', '.join(quoted_columns)} FROM `{db_name}`.`{table_name}`"
604
+ params = []
605
+ param_names = [] # 用于记录参数名称
606
+
607
+ # 如果有日期字段,添加日期过滤条件
608
+ if date_field:
609
+ conditions = []
610
+ if start_date is not None:
611
+ conditions.append(f"`{date_field}` >= %s")
612
+ params.append(start_date)
613
+ param_names.append('开始日期')
614
+ if end_date is not None:
615
+ conditions.append(f"`{date_field}` <= %s")
616
+ params.append(end_date)
617
+ param_names.append('结束日期')
618
+
619
+ if conditions:
620
+ base_sql += " WHERE " + " AND ".join(conditions)
621
+
622
+ # 只在显式指定limit时添加限制
623
+ if limit is not None:
624
+ if not isinstance(limit, int) or limit <= 0:
625
+ raise ValueError("limit必须是正整数")
626
+ base_sql += f" LIMIT %s"
627
+ params.append(limit)
628
+ param_names.append('限制行数')
629
+
630
+ return base_sql, params, param_names
631
+
632
+ def _convert_decimal_columns(self, df: pd.DataFrame) -> pd.DataFrame:
633
+ """
634
+ 将DataFrame中的Decimal类型列转换为float类型。
635
+
636
+ Args:
637
+ df: 原始DataFrame
638
+
639
+ Returns:
640
+ 转换后的DataFrame
641
+ """
642
+ for col in df.columns:
643
+ if df[col].apply(lambda x: isinstance(x, Decimal)).any():
644
+ df[col] = df[col].astype(float)
138
645
  return df
139
646
 
647
+ # @_execute_with_retry
140
648
  def columns_to_list(self, db_name, table_name, columns_name, where: str = None) -> list:
141
649
  """
142
650
  获取数据表的指定列, 支持where条件筛选, 返回列表字典。
@@ -148,29 +656,30 @@ class QueryDatas:
148
656
  """
149
657
  if not self.check_infos(db_name, table_name):
150
658
  return []
151
- self.config.update({'database': db_name})
659
+
152
660
  try:
153
- with closing(pymysql.connect(**self.config)) as connection:
661
+ with closing(self._get_connection(db_name)) as connection:
154
662
  with closing(connection.cursor()) as cursor:
155
663
  sql = 'SELECT COLUMN_NAME FROM information_schema.columns WHERE table_schema = %s AND table_name = %s'
156
664
  cursor.execute(sql, (db_name, table_name))
157
665
  cols_exist = [col['COLUMN_NAME'] for col in cursor.fetchall()]
158
666
  columns_name = [item for item in columns_name if item in cols_exist]
159
667
  if not columns_name:
160
- logger.info("columns_to_list: 未找到匹配的列名")
668
+ logger.info('未找到匹配的列名', {'库': db_name, '表': table_name, '请求列': columns_name})
161
669
  return []
162
670
  columns_in = ', '.join([f'`{col}`' for col in columns_name])
163
671
  sql = f"SELECT {columns_in} FROM `{db_name}`.`{table_name}`"
164
672
  if where:
165
673
  sql += f" WHERE {where}"
166
- logger.debug(f"columns_to_list SQL: {sql}")
674
+ logger.debug('执行列查询', {'库': db_name, '表': table_name, 'SQL': sql})
167
675
  cursor.execute(sql)
168
676
  column_values = cursor.fetchall()
169
677
  return column_values
170
678
  except Exception as e:
171
- logger.error(f"columns_to_list error: {e}")
679
+ logger.error('列查询失败', {'库': db_name, '表': table_name, '列': columns_name, '错误': str(e)})
172
680
  return []
173
681
 
682
+ # @_execute_with_retry
174
683
  def dtypes_to_list(self, db_name, table_name, columns_name=None) -> list:
175
684
  """
176
685
  获取数据表的列名和类型, 支持只返回部分字段类型。
@@ -181,9 +690,9 @@ class QueryDatas:
181
690
  """
182
691
  if not self.check_infos(db_name, table_name):
183
692
  return []
184
- self.config.update({'database': db_name})
693
+
185
694
  try:
186
- with closing(pymysql.connect(**self.config)) as connection:
695
+ with closing(self._get_connection(db_name)) as connection:
187
696
  with closing(connection.cursor()) as cursor:
188
697
  sql = 'SELECT COLUMN_NAME, COLUMN_TYPE FROM information_schema.columns WHERE table_schema = %s AND table_name = %s'
189
698
  cursor.execute(sql, (db_name, table_name))
@@ -193,9 +702,10 @@ class QueryDatas:
193
702
  column_name_and_type = [row for row in column_name_and_type if row['COLUMN_NAME'] in columns_name]
194
703
  return column_name_and_type
195
704
  except Exception as e:
196
- logger.error(f"dtypes_to_list error: {e}")
705
+ logger.error('获取列类型失败', {'库': db_name, '表': table_name, '列': columns_name, '错误': str(e)})
197
706
  return []
198
707
 
708
+ # @_execute_with_retry
199
709
  def check_infos(self, db_name, table_name) -> bool:
200
710
  """
201
711
  检查数据库和数据表是否存在。
@@ -204,29 +714,239 @@ class QueryDatas:
204
714
  :return: 存在返回True,否则False
205
715
  """
206
716
  try:
207
- with closing(pymysql.connect(**self.config)) as connection:
208
- with closing(connection.cursor()) as cursor:
209
- cursor.execute(f"SHOW DATABASES LIKE %s", (db_name,))
210
- database_exists = cursor.fetchone()
211
- if not database_exists:
212
- logger.info(f"Database <{db_name}>: 数据库不存在")
213
- return False
717
+ # 检查数据库是否存在
718
+ result = self._execute_query("SHOW DATABASES LIKE %s", (db_name,))
719
+ if not result:
720
+ logger.info('数据库不存在', {'库': db_name})
721
+ return False
722
+
723
+ # 检查表是否存在
724
+ result = self._execute_query("SHOW TABLES LIKE %s", (table_name,), db_name=db_name)
725
+ if not result:
726
+ logger.info('表不存在', {'库': db_name, '表': table_name})
727
+ return False
728
+ return True
729
+
214
730
  except Exception as e:
215
- logger.error(f"check_infos-db error: {e}")
731
+ logger.error('检查数据库或表失败', {
732
+ '库': db_name,
733
+ '表': table_name,
734
+ '错误类型': type(e).__name__,
735
+ '错误信息': str(e)
736
+ })
216
737
  return False
217
- self.config.update({'database': db_name})
738
+
739
+ def __enter__(self):
740
+ """上下文管理器入口"""
741
+ return self
742
+
743
+ def __exit__(self, exc_type, exc_val, exc_tb):
744
+ """上下文管理器退出,确保资源被正确释放"""
745
+ self.close()
746
+
747
+ def close(self):
748
+ """显式关闭连接池,释放资源"""
749
+ if hasattr(self, 'pool') and self.pool is not None:
750
+ try:
751
+ self.pool.close()
752
+ logger.info('连接池已关闭', {
753
+ '主机': self.host,
754
+ '端口': self.port
755
+ })
756
+ except Exception as e:
757
+ logger.error('关闭连接池失败', {
758
+ '错误': str(e),
759
+ '主机': self.host,
760
+ '端口': self.port
761
+ })
762
+ finally:
763
+ self.pool = None
764
+
765
+ def data_to_df(
766
+ self,
767
+ db_name: str,
768
+ table_name: str,
769
+ start_date: Optional[str] = None,
770
+ end_date: Optional[str] = None,
771
+ projection: Optional[Dict[str, int]] = None,
772
+ limit: Optional[int] = None,
773
+ page_size: Optional[int] = None,
774
+ date_field: Optional[str] = None,
775
+ return_format: Literal['df', 'list_dict'] = 'df'
776
+ ) -> Union[pd.DataFrame, List[Dict[str, Any]]]:
777
+ """
778
+ 从数据库表获取数据,支持列筛选、日期范围过滤和行数限制。
779
+ 支持两种查询模式:
780
+ 1. 使用limit参数进行简单查询
781
+ 2. 使用page_size参数进行分页查询
782
+
783
+ Args:
784
+ db_name: 数据库名
785
+ table_name: 表名
786
+ start_date: 起始日期(包含),支持多种日期格式,如'YYYY-MM-DD'、'YYYY/MM/DD'等
787
+ end_date: 结束日期(包含),支持多种日期格式,如'YYYY-MM-DD'、'YYYY/MM/DD'等
788
+ projection: 列筛选字典,用于指定要查询的列
789
+ - 键为列名(字符串)
790
+ - 值为1表示选中该列,0表示不选中该列
791
+ - 例如:{'日期': 1, '场景名字': 1} 表示只查询这两列
792
+ - 如果为None、空字典{}或空列表[],则查询所有列
793
+ limit: 限制返回的最大行数,None表示不限制
794
+ page_size: 分页查询时每页的数据量,None表示不使用分页
795
+ date_field: 日期字段名,如果为None则使用默认的"日期"字段
796
+ return_format: 返回数据格式
797
+ - 'df': 返回pandas DataFrame(默认)
798
+ - 'list_dict': 返回列表字典格式 [{列1:值, 列2:值, ...}, ...]
799
+
800
+ Returns:
801
+ 根据return_format参数返回不同格式的数据:
802
+ - 当return_format='df'时,返回DataFrame
803
+ - 当return_format='list_dict'时,返回列表字典
804
+ - 如果查询失败,返回空的DataFrame或空列表
805
+ """
806
+ if not db_name or not table_name:
807
+ logger.error('数据库名和表名不能为空', {'库': db_name, '表': table_name})
808
+ return [] if return_format == 'list_dict' else pd.DataFrame()
809
+
810
+ # 验证return_format参数
811
+ valid_formats = {'df', 'list_dict'}
812
+ if return_format not in valid_formats:
813
+ logger.error('无效的return_format值', {'库': db_name, '表': table_name, '指定返回数据格式, 有效值应为: ': ', '.join(valid_formats)})
814
+ return [] if return_format == 'list_dict' else pd.DataFrame()
815
+
816
+ # 验证日期范围
817
+ start_date, end_date = self._validate_date_range(start_date, end_date, db_name, table_name)
818
+
819
+ # 检查数据库和表是否存在
820
+ if not self.check_infos(db_name, table_name):
821
+ return [] if return_format == 'list_dict' else pd.DataFrame()
218
822
  try:
219
- with closing(pymysql.connect(**self.config)) as connection:
823
+ with closing(self._get_connection(db_name)) as connection:
220
824
  with closing(connection.cursor()) as cursor:
221
- cursor.execute(f"SHOW TABLES LIKE %s", (table_name,))
222
- if not cursor.fetchone():
223
- logger.info(f'{db_name} -> <{table_name}>: 表不存在')
224
- return False
225
- return True
825
+ # 获取表的所有列
826
+ cursor.execute(
827
+ """SELECT COLUMN_NAME FROM information_schema.columns WHERE table_schema = %s AND table_name = %s""",
828
+ (db_name, table_name)
829
+ )
830
+ cols_exist = {col['COLUMN_NAME'] for col in cursor.fetchall()} - {'id'}
831
+
832
+ # 设置日期字段
833
+ if start_date is not None and end_date is not None:
834
+ # 如果未指定日期字段,使用默认的"日期"字段
835
+ if date_field is None:
836
+ date_field = "日期"
837
+
838
+ # 检查指定的日期字段是否存在
839
+ if date_field not in cols_exist:
840
+ logger.warning('指定的日期字段不存在,将返回所有数据', {
841
+ '库': db_name,
842
+ '表': table_name,
843
+ '指定日期字段': date_field
844
+ })
845
+ start_date = None
846
+ end_date = None
847
+ date_field = None
848
+
849
+ # 获取选中的列
850
+ selected_columns = self._get_selected_columns(cols_exist, projection)
851
+ if not selected_columns:
852
+ logger.info('未找到可用字段', {'库': db_name, '表': table_name, '字段': selected_columns})
853
+ return [] if return_format == 'list_dict' else pd.DataFrame()
854
+
855
+ # 构建基础SQL
856
+ base_sql, params, param_names = self._build_query_sql(
857
+ db_name, table_name, selected_columns,
858
+ date_field, start_date, end_date, None
859
+ )
860
+
861
+ # 如果指定了limit且没有指定page_size,使用简单查询
862
+ if limit is not None and page_size is None:
863
+ sql = f"{base_sql} LIMIT %s"
864
+ params = list(params) + [limit]
865
+ cursor.execute(sql, tuple(params))
866
+ result = cursor.fetchall()
867
+
868
+ if result:
869
+ if return_format == 'list_dict':
870
+ return result
871
+ else:
872
+ df = pd.DataFrame(result)
873
+ df = self._convert_decimal_columns(df)
874
+ return df
875
+ return [] if return_format == 'list_dict' else pd.DataFrame()
876
+
877
+ # 使用分页查询
878
+ # 获取总记录数
879
+ count_sql = f"SELECT COUNT(*) as total FROM ({base_sql}) as t"
880
+ cursor.execute(count_sql, tuple(params))
881
+ total_count = cursor.fetchone()['total']
882
+
883
+ if total_count == 0:
884
+ return [] if return_format == 'list_dict' else pd.DataFrame()
885
+
886
+ # 设置默认分页大小
887
+ if page_size is None:
888
+ page_size = 1000
889
+
890
+ # 分页查询
891
+ offset = 0
892
+ all_results = []
893
+
894
+ while offset < total_count:
895
+ # 添加分页参数
896
+ page_sql = f"{base_sql} LIMIT %s OFFSET %s"
897
+ page_params = list(params) + [page_size, offset]
898
+
899
+ cursor.execute(page_sql, tuple(page_params))
900
+ page_results = cursor.fetchall()
901
+
902
+ if not page_results:
903
+ break
904
+
905
+ if return_format == 'list_dict':
906
+ all_results.extend(page_results)
907
+ else:
908
+ if len(all_results) == 0:
909
+ all_results = pd.DataFrame(page_results)
910
+ else:
911
+ all_results = pd.concat([all_results, pd.DataFrame(page_results)], ignore_index=True)
912
+
913
+ offset += page_size
914
+ logger.debug('分页查询进度', {
915
+ '库': db_name,
916
+ '表': table_name,
917
+ '当前偏移量': offset,
918
+ '总记录数': total_count,
919
+ '已获取记录数': len(all_results) if return_format == 'list_dict' else len(all_results.index)
920
+ })
921
+
922
+ if return_format == 'df' and isinstance(all_results, pd.DataFrame) and not all_results.empty:
923
+ all_results = self._convert_decimal_columns(all_results)
924
+ return all_results
925
+
226
926
  except Exception as e:
227
- logger.error(f"check_infos-table error: {e}")
228
- return False
927
+ logger.error('数据查询失败', {
928
+ '库': db_name,
929
+ '表': table_name,
930
+ '错误类型': type(e).__name__,
931
+ '错误信息': str(e)
932
+ })
933
+ return [] if return_format == 'list_dict' else pd.DataFrame()
934
+
935
+
936
+ def main():
937
+ dir_path = os.path.expanduser("~")
938
+ my_cont = config.read_config(file_path=os.path.join(dir_path, 'spd.txt'))
939
+ username, password, host, port = my_cont['username'], my_cont['password'], my_cont['host'], int(my_cont['port'])
940
+ host = 'localhost'
941
+
942
+ # 创建QueryDatas实例
943
+ qd = QueryDatas(username=username, password=password, host=host, port=port)
944
+
945
+ # 执行查询
946
+ df = qd.data_to_df('聚合数据', '店铺流量来源构成', limit=10)
947
+ print(df)
229
948
 
230
949
 
231
950
  if __name__ == '__main__':
951
+ main()
232
952
  pass