mdbq 3.9.6__py3-none-any.whl → 3.9.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdbq/__version__.py +1 -1
- mdbq/log/mylogger.py +57 -7
- mdbq/mysql/deduplicator.py +595 -0
- mdbq/mysql/mysql.py +146 -431
- mdbq/mysql/uploader.py +1151 -0
- {mdbq-3.9.6.dist-info → mdbq-3.9.7.dist-info}/METADATA +1 -1
- {mdbq-3.9.6.dist-info → mdbq-3.9.7.dist-info}/RECORD +9 -7
- {mdbq-3.9.6.dist-info → mdbq-3.9.7.dist-info}/WHEEL +0 -0
- {mdbq-3.9.6.dist-info → mdbq-3.9.7.dist-info}/top_level.txt +0 -0
mdbq/mysql/uploader.py
ADDED
@@ -0,0 +1,1151 @@
|
|
1
|
+
# -*- coding:utf-8 -*-
|
2
|
+
import datetime
|
3
|
+
import re
|
4
|
+
import time
|
5
|
+
from functools import wraps
|
6
|
+
import warnings
|
7
|
+
import pymysql
|
8
|
+
import pandas as pd
|
9
|
+
import os
|
10
|
+
import logging
|
11
|
+
from mdbq.log import mylogger
|
12
|
+
from typing import Union, List, Dict, Optional, Any, Tuple, Set
|
13
|
+
from dbutils.pooled_db import PooledDB
|
14
|
+
import json
|
15
|
+
from collections import OrderedDict
|
16
|
+
|
17
|
+
# Ignore every Python warning (no category/module filter, so this is process-wide).
warnings.filterwarnings('ignore')

# Module-level logger used by the uploader. All options are forwarded unchanged
# to the project's MyLogger; see mdbq.log.mylogger for their exact semantics.
logger = mylogger.MyLogger(
    name='uploader',
    log_file='uploader.log',
    log_format='json',
    log_level='error',
    logging_mode='none',
    max_log_size=50,
    backup_count=5,
    enable_async=False,  # asynchronous logging disabled
    sample_rate=0.5,  # per the original note: sample 50% of DEBUG/INFO records
    sensitive_fields=[],  # no fields are masked
)
|
30
|
+
|
31
|
+
|
32
|
+
def count_decimal_places(num_str):
    """
    Count the integer and fractional digits of a numeric string.

    Plain decimals ("123.45") and scientific notation with a fractional
    mantissa ("1.2345e2", "1.23e-2") are supported. Leading zeros of the
    integer part are not counted.

    :param num_str: value to inspect; converted with ``str()`` first
    :return: ``(count_int, count_float)``. ``(0, 0)`` is returned for
        non-numeric input, numbers without a decimal point, and positive
        exponents that shift every fractional digit into the integer part.
    """
    text = str(num_str)
    if not re.match(r'^[-+]?\d+(\.\d+)?([eE][-+]?\d+)?$', text):
        return 0, 0

    # Scientific notation with a fractional mantissa.
    sci = re.search(r'(\d+)\.(\d+)[eE]([-+]?)(\d+)$', text)
    if sci:
        int_part, frac_part, sign, exp_digits = sci.groups()
        exponent = int(exp_digits)
        int_digits = len(int_part.lstrip('0'))
        if sign == '-':
            # A negative exponent moves the point left, turning integer digits
            # into fractional ones: 1.23e-2 == 0.0123 -> (0, 4).
            return max(int_digits - exponent, 0), len(frac_part) + exponent
        if exponent < len(frac_part):
            return int_digits + exponent, len(frac_part) - exponent
        # The exponent consumes every fractional digit: the value is an
        # integer and is reported as (0, 0).
        return 0, 0

    # Plain decimal.
    plain = re.search(r'(\d+)\.(\d+)$', text)
    if plain:
        int_part, frac_part = plain.groups()
        return len(int_part.lstrip('0')), len(frac_part)
    return 0, 0
|
53
|
+
|
54
|
+
|
55
|
+
class StatementCache(OrderedDict):
    """
    Bounded mapping that evicts its oldest entry once ``maxsize`` is exceeded.

    Writing a key (new or existing) marks it as the most recently used entry,
    so an entry that is re-inserted is not the next one to be evicted.
    Reads do not change the order.
    """

    def __init__(self, maxsize=100):
        super().__init__()
        self.maxsize = maxsize

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        # Refresh recency on overwrite; otherwise an updated key keeps its old
        # (possibly oldest) position and is the first one evicted.
        self.move_to_end(key)
        if len(self) > self.maxsize:
            self.popitem(last=False)
|
65
|
+
|
66
|
+
|
67
|
+
class MySQLUploader:
|
68
|
+
def __init__(
        self,
        username: str,
        password: str,
        host: str = 'localhost',
        port: int = 3306,
        charset: str = 'utf8mb4',
        collation: str = 'utf8mb4_0900_ai_ci',  # *_ai_ci is case-insensitive; utf8mb4_0900_as_cs / utf8mb4_bin are case-sensitive
        max_retries: int = 10,
        retry_interval: int = 10,
        pool_size: int = 5,
        connect_timeout: int = 10,
        read_timeout: int = 30,
        write_timeout: int = 30,
        ssl: Optional[Dict] = None
):
    """
    Store the connection settings and build the connection pool.

    :param username: database user name
    :param password: database password
    :param host: database host, default 'localhost'
    :param port: database port, default 3306
    :param charset: connection / table charset, default 'utf8mb4'
    :param collation: collation used for created databases and tables,
        default 'utf8mb4_0900_ai_ci'
    :param max_retries: attempts for operations wrapped by
        ``_execute_with_retry``; values below 1 become 1 (default 10)
    :param retry_interval: base retry wait in seconds; values below 1
        become 1 (default 10)
    :param pool_size: maximum pooled connections; values below 1 become 1
        (default 5)
    :param connect_timeout: connect timeout in seconds, default 10
    :param read_timeout: read timeout in seconds, default 30
    :param write_timeout: write timeout in seconds, default 30
    :param ssl: optional SSL settings dict (must contain ca, cert and key)
    """
    # Connection target and credentials.
    self.username, self.password = username, password
    self.host, self.port = host, port
    self.charset, self.collation = charset, collation

    # Retry / pool sizing, each forced to be at least 1.
    self.max_retries = max(max_retries, 1)
    self.retry_interval = max(retry_interval, 1)
    self.pool_size = max(pool_size, 1)

    # Socket timeouts and optional SSL settings.
    self.connect_timeout, self.read_timeout, self.write_timeout = (
        connect_timeout, read_timeout, write_timeout
    )
    self.ssl = ssl

    # Caches: bounded statement cache plus table-metadata cache with a TTL.
    self._max_cached_statements = 100
    self._prepared_statements = StatementCache(maxsize=self._max_cached_statements)
    self._table_metadata_cache = {}
    self.metadata_cache_ttl = 300  # seconds (5 minutes)

    # Built last: it reads the settings assigned above.
    self.pool = self._create_connection_pool()
|
120
|
+
|
121
|
+
def _create_connection_pool(self) -> PooledDB:
    """
    Return a usable connection pool, building a new PooledDB when needed.

    An existing ``self.pool`` is returned when ``self._check_pool_health()``
    reports it healthy. Otherwise ``self.pool`` is reset to None and a new
    pool is built from the stored settings (DictCursor rows, autocommit off).
    Callers in this class assign the returned pool to ``self.pool``.

    :return: the connection pool
    :raises ValueError: if ``self.ssl`` is set but lacks ca, cert or key
    :raises ConnectionError: if constructing the pool fails
    """
    if getattr(self, 'pool', None) is not None and self._check_pool_health():
        return self.pool

    began = time.time()
    self.pool = None

    params = dict(
        creator=pymysql,
        host=self.host,
        port=self.port,
        user=self.username,
        password=self.password,
        charset=self.charset,
        cursorclass=pymysql.cursors.DictCursor,
        maxconnections=self.pool_size,
        ping=7,
        connect_timeout=self.connect_timeout,
        read_timeout=self.read_timeout,
        write_timeout=self.write_timeout,
        autocommit=False,
    )

    if self.ssl:
        ssl_keys = ('ca', 'cert', 'key')
        if any(key not in self.ssl for key in ssl_keys):
            error_msg = "SSL配置必须包含ca、cert和key"
            logger.error(error_msg)
            raise ValueError(error_msg)
        ssl_options = {key: self.ssl[key] for key in ssl_keys}
        ssl_options['check_hostname'] = self.ssl.get('check_hostname', False)
        params['ssl'] = ssl_options

    try:
        new_pool = PooledDB(**params)
        logger.info("连接池创建成功", {
            'pool_size': self.pool_size,
            'time_elapsed': time.time() - began
        })
        return new_pool
    except Exception as exc:
        spent = time.time() - began
        self.pool = None
        logger.error("连接池创建失败", {
            'error': str(exc),
            'time_elapsed': spent
        })
        raise ConnectionError(f"连接池创建失败: {str(exc)}")
|
174
|
+
|
175
|
+
def _execute_with_retry(self, func):
    """
    Wrap ``func`` so that MySQL errors are retried with a growing back-off.

    Behaviour of the returned wrapper:
    - ``pymysql.IntegrityError`` (constraint violation) is logged and
      re-raised immediately, because retrying cannot make it succeed.
    - Other ``pymysql`` errors (``OperationalError`` / ``MySQLError``) are
      retried for up to ``self.max_retries`` attempts in total. Between
      attempts the wrapper sleeps ``retry_interval * attempt_number`` seconds
      and tries to rebuild the pool via ``_create_connection_pool``. After the
      final attempt the last error is re-raised.
    - Any other exception is logged and re-raised without retrying.

    :param func: callable to wrap
    :return: wrapper with the same call signature as ``func``
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        last_exception = None
        start_time = time.time()
        operation = func.__name__

        logger.debug(f"开始执行操作: {operation}", {
            'attempt': 1,
            'max_retries': self.max_retries
        })

        for attempt in range(self.max_retries):
            try:
                result = func(*args, **kwargs)
                elapsed = time.time() - start_time

                if attempt > 0:
                    logger.info("操作成功(重试后)", {
                        'operation': operation,
                        'attempts': attempt + 1,
                        'time_elapsed': elapsed
                    })
                else:
                    logger.debug("操作成功", {
                        'operation': operation,
                        'time_elapsed': elapsed
                    })

                return result

            # This clause must come before the MySQLError clause:
            # IntegrityError is a subclass of MySQLError, so placed after it
            # this handler was unreachable and constraint violations were
            # retried with back-off sleeps.
            except pymysql.IntegrityError as e:
                elapsed = time.time() - start_time
                logger.error("完整性约束错误", {
                    'operation': operation,
                    'time_elapsed': elapsed,
                    'error_code': e.args[0] if e.args else None,
                    'error_message': e.args[1] if len(e.args) > 1 else None
                })
                raise

            except (pymysql.OperationalError, pymysql.err.MySQLError) as e:
                last_exception = e

                # Detailed MySQL error information for the log.
                error_details = {
                    'operation': operation,
                    'error_code': e.args[0] if e.args else None,
                    'error_message': e.args[1] if len(e.args) > 1 else None,
                    'attempt': attempt + 1,
                    'max_retries': self.max_retries
                }

                if attempt < self.max_retries - 1:
                    wait_time = self.retry_interval * (attempt + 1)
                    error_details['wait_time'] = wait_time
                    logger.warning(f"数据库操作失败,准备重试 {error_details}", )
                    time.sleep(wait_time)

                    # Try to rebuild the pool before the next attempt.
                    try:
                        self.pool = self._create_connection_pool()
                        logger.info("成功重新建立数据库连接")
                    except Exception as reconnect_error:
                        logger.error("重连失败", {
                            'error': str(reconnect_error)
                        })
                else:
                    elapsed = time.time() - start_time
                    error_details['time_elapsed'] = elapsed
                    logger.error(f"操作最终失败 {error_details}")

            except Exception as e:
                last_exception = e
                elapsed = time.time() - start_time
                logger.error("发生意外错误", {
                    'operation': operation,
                    'time_elapsed': elapsed,
                    'error_type': type(e).__name__,
                    'error_message': str(e),
                    'error_args': e.args if hasattr(e, 'args') else None
                })
                break

        raise last_exception if last_exception else Exception("发生未知错误")

    return wrapper
|
262
|
+
|
263
|
+
def _get_connection(self):
    """
    Borrow a connection from ``self.pool``.

    :return: the connection object returned by ``self.pool.connection()``
    :raises ConnectionError: wrapping any error raised while obtaining it
    """
    try:
        connection = self.pool.connection()
        logger.debug("获取数据库连接")
        return connection
    except Exception as exc:
        logger.error(str(exc))
        raise ConnectionError(f"连接数据库失败: {str(exc)}")
|
272
|
+
|
273
|
+
def _check_database_exists(self, db_name: str) -> bool:
    """
    Report whether a schema named ``db_name`` exists.

    The name is cleaned with ``_validate_identifier`` and looked up in
    INFORMATION_SCHEMA.SCHEMATA through a parameterized query.

    :param db_name: database name
    :return: True if the schema exists, otherwise False
    :raises: query errors are logged and re-raised
    """
    schema = self._validate_identifier(db_name)
    query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = %s"

    try:
        with self._get_connection() as conn, conn.cursor() as cursor:
            cursor.execute(query, (schema,))
            found = cursor.fetchone() is not None
            logger.debug(f"{schema} 数据库已存在: {found}")
            return found
    except Exception as exc:
        logger.error(f"检查数据库是否存在时出错: {str(exc)}")
        raise
|
288
|
+
|
289
|
+
def _create_database(self, db_name: str):
    """
    Create database ``db_name`` if it does not already exist.

    The configured ``self.charset`` and ``self.collation`` are used.

    :param db_name: database name; cleaned via ``_validate_identifier``
    :raises: connection or execution errors are logged and re-raised
    """
    db_name = self._validate_identifier(db_name)
    sql = f"CREATE DATABASE IF NOT EXISTS `{db_name}` CHARACTER SET {self.charset} COLLATE {self.collation}"

    try:
        with self._get_connection() as conn:
            try:
                with conn.cursor() as cursor:
                    cursor.execute(sql)
                conn.commit()
            except Exception:
                # Roll back while the connection is still held. A rollback in
                # the outer handler would run on an unbound `conn` (when
                # _get_connection failed) or on a connection already released
                # by the `with`; the error raised there would hide the real one.
                conn.rollback()
                raise
        logger.info(f"{db_name} 数据库已创建")
    except Exception as e:
        logger.error(f"{db_name}: 无法创建数据库 {str(e)}")
        raise
|
304
|
+
|
305
|
+
def _get_partition_table_name(self, table_name: str, date_value: str, partition_by: str) -> str:
    """
    Build the name of the yearly or monthly partition table for a date.

    :param table_name: base table name
    :param date_value: date string in any format accepted by ``_validate_datetime``
    :param partition_by: 'year' -> ``<table>_YYYY``; 'month' -> ``<table>_YYYY_MM``
    :return: partition table name
    :raises ValueError: if the date cannot be parsed or partition_by is invalid
    """
    try:
        # A single attempt is enough: parsing is deterministic, so repeating
        # the same call after a failure cannot succeed.
        date_obj = self._validate_datetime(date_value, True)
    except ValueError as e:
        error_msg = f"无效的日期格式1: {date_value}"
        logger.error(error_msg)
        raise ValueError(error_msg) from e

    if partition_by == 'year':
        return f"{table_name}_{date_obj.year}"
    if partition_by == 'month':
        return f"{table_name}_{date_obj.year}_{date_obj.month:02d}"

    error_msg = "partition_by must be 'year' or 'month'"
    logger.error(error_msg)
    raise ValueError(error_msg)
|
335
|
+
|
336
|
+
def _validate_identifier(self, identifier: str) -> str:
|
337
|
+
"""
|
338
|
+
验证并清理数据库标识符(数据库名、表名、列名)
|
339
|
+
防止SQL注入和非法字符
|
340
|
+
|
341
|
+
:param identifier: 要验证的标识符
|
342
|
+
:return: 清理后的安全标识符
|
343
|
+
:raises ValueError: 如果标识符无效
|
344
|
+
"""
|
345
|
+
if not identifier or not isinstance(identifier, str):
|
346
|
+
error_msg = f"无效的标识符: {identifier}"
|
347
|
+
logger.error(error_msg)
|
348
|
+
raise ValueError(error_msg)
|
349
|
+
|
350
|
+
# 移除非法字符,只保留字母、数字、下划线和美元符号
|
351
|
+
cleaned = re.sub(r'[^\w\u4e00-\u9fff$]', '', identifier)
|
352
|
+
if not cleaned:
|
353
|
+
error_msg = f"无法清理异常标识符: {identifier}"
|
354
|
+
logger.error(error_msg)
|
355
|
+
raise ValueError(error_msg)
|
356
|
+
|
357
|
+
# 检查是否为MySQL保留字
|
358
|
+
mysql_keywords = {
|
359
|
+
'select', 'insert', 'update', 'delete', 'from', 'where', 'and', 'or',
|
360
|
+
'not', 'like', 'in', 'is', 'null', 'true', 'false', 'between'
|
361
|
+
}
|
362
|
+
if cleaned.lower() in mysql_keywords:
|
363
|
+
logger.debug(f"存在MySQL保留字: {cleaned}")
|
364
|
+
return f"`{cleaned}`"
|
365
|
+
|
366
|
+
return cleaned
|
367
|
+
|
368
|
+
def _check_table_exists(self, db_name: str, table_name: str) -> bool:
|
369
|
+
"""检查表是否存在"""
|
370
|
+
cache_key = f"{db_name}.{table_name}"
|
371
|
+
if cache_key in self._table_metadata_cache:
|
372
|
+
cached_time, result = self._table_metadata_cache[cache_key]
|
373
|
+
if time.time() - cached_time < self.metadata_cache_ttl:
|
374
|
+
return result
|
375
|
+
|
376
|
+
db_name = self._validate_identifier(db_name)
|
377
|
+
table_name = self._validate_identifier(table_name)
|
378
|
+
sql = """
|
379
|
+
SELECT TABLE_NAME
|
380
|
+
FROM INFORMATION_SCHEMA.TABLES
|
381
|
+
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
|
382
|
+
"""
|
383
|
+
|
384
|
+
try:
|
385
|
+
with self._get_connection() as conn:
|
386
|
+
with conn.cursor() as cursor:
|
387
|
+
cursor.execute(sql, (db_name, table_name))
|
388
|
+
result = bool(cursor.fetchone())
|
389
|
+
except Exception as e:
|
390
|
+
logger.error(f"检查数据表是否存在时发生未知错误: {e}", )
|
391
|
+
raise
|
392
|
+
|
393
|
+
# 执行查询并缓存结果
|
394
|
+
self._table_metadata_cache[cache_key] = (time.time(), result)
|
395
|
+
return result
|
396
|
+
|
397
|
+
def _create_table(
        self,
        db_name: str,
        table_name: str,
        set_typ: Dict[str, str],
        primary_keys: Optional[List[str]] = None,
        date_column: Optional[str] = None,
        indexes: Optional[List[str]] = None,
        allow_null: bool = False
):
    """
    Create a table, if it does not exist, with an AUTO_INCREMENT ``id`` column.

    Secondary indexes are declared inside the CREATE TABLE statement. The
    whole statement is therefore a no-op when the table already exists.
    Separate ALTER TABLE ... ADD INDEX statements would instead fail against
    an existing table. A column is indexed at most once, even if it is both
    ``date_column`` and listed in ``indexes``.

    :param db_name: database name
    :param table_name: table name
    :param set_typ: {column name: MySQL column type}
    :param primary_keys: primary-key columns; ``id`` is prepended when absent,
        and ``id`` alone is used when none are given
    :param date_column: date column; indexed when present in ``set_typ``
    :param indexes: extra columns to index (only those present in ``set_typ``)
    :param allow_null: when False, every non-JSON column gets NOT NULL
    :raises ValueError: if ``set_typ`` is empty or an identifier is invalid
    """
    db_name = self._validate_identifier(db_name)
    table_name = self._validate_identifier(table_name)

    if not set_typ:
        error_msg = "No columns specified for table creation"
        logger.error(error_msg)
        raise ValueError(error_msg)

    # Column definitions: `id` first, then the declared columns (except id).
    column_defs = ["`id` INT NOT NULL AUTO_INCREMENT"]
    for col_name, col_type in set_typ.items():
        if col_name.lower() == 'id':
            continue
        safe_col_name = self._validate_identifier(col_name)
        col_def = f"`{safe_col_name}` {col_type}"
        if not allow_null and not col_type.lower().startswith('json'):
            col_def += " NOT NULL"
        column_defs.append(col_def)

    # Primary key: always includes id.
    if primary_keys:
        if 'id' not in [pk.lower() for pk in primary_keys]:
            primary_keys = ['id'] + primary_keys
    else:
        primary_keys = ['id']
    safe_primary_keys = [self._validate_identifier(pk) for pk in primary_keys]
    column_defs.append(f"PRIMARY KEY (`{'`,`'.join(safe_primary_keys)}`)")

    # Secondary indexes: the date column first, then the requested ones, each
    # column at most once (a repeated name would be a duplicate key name).
    index_candidates = ([date_column] if date_column else []) + list(indexes or [])
    indexed_columns = []
    for idx_col in index_candidates:
        if idx_col not in set_typ:
            continue
        safe_idx_col = self._validate_identifier(idx_col)
        if safe_idx_col in indexed_columns:
            continue
        indexed_columns.append(safe_idx_col)
        column_defs.append(f"INDEX `idx_{safe_idx_col}` (`{safe_idx_col}`)")

    sql = f"""
    CREATE TABLE IF NOT EXISTS `{db_name}`.`{table_name}` (
        {','.join(column_defs)}
    ) ENGINE=InnoDB DEFAULT CHARSET={self.charset} COLLATE={self.collation}
    """

    try:
        with self._get_connection() as conn:
            try:
                with conn.cursor() as cursor:
                    cursor.execute(sql)
                conn.commit()
            except Exception:
                # Roll back while the connection is still held; outside the
                # `with` the connection may be unbound or already released.
                conn.rollback()
                raise
    except Exception as e:
        logger.error(f"{db_name}.{table_name}: 建表失败: {str(e)}")
        raise

    logger.info(f"{db_name}.{table_name}: 数据表已创建")
    if indexed_columns:
        logger.info(f"{db_name}.{table_name}: 索引已添加")
|
502
|
+
|
503
|
+
def _validate_datetime(self, value, date_type=False):
|
504
|
+
"""date_type: 返回字符串类型或者日期类型"""
|
505
|
+
formats = [
|
506
|
+
'%Y-%m-%d %H:%M:%S',
|
507
|
+
'%Y-%m-%d',
|
508
|
+
'%Y/%m/%d %H:%M:%S',
|
509
|
+
'%Y/%m/%d',
|
510
|
+
'%Y%m%d',
|
511
|
+
'%Y-%m-%dT%H:%M:%S',
|
512
|
+
'%Y-%m-%d %H:%M:%S.%f',
|
513
|
+
'%Y/%-m/%-d', # 2023/1/8
|
514
|
+
'%Y-%m-%-d', # 2023-01-8
|
515
|
+
'%Y-%-m-%-d' # 2023-1-8
|
516
|
+
]
|
517
|
+
for fmt in formats:
|
518
|
+
try:
|
519
|
+
if date_type:
|
520
|
+
return pd.to_datetime(datetime.datetime.strptime(value, fmt).strftime('%Y-%m-%d'))
|
521
|
+
else:
|
522
|
+
return datetime.datetime.strptime(value, fmt).strftime('%Y-%m-%d %H:%M:%S')
|
523
|
+
except ValueError:
|
524
|
+
continue
|
525
|
+
raise ValueError(f"无效的日期格式2: {value}")
|
526
|
+
|
527
|
+
def _validate_value(self, value: Any, column_type: str) -> Any:
|
528
|
+
"""
|
529
|
+
验证并清理数据值,根据列类型进行适当转换
|
530
|
+
|
531
|
+
:param value: 要验证的值
|
532
|
+
:param column_type: 列的数据类型
|
533
|
+
:return: 清理后的值
|
534
|
+
:raises ValueError: 如果值转换失败
|
535
|
+
"""
|
536
|
+
if value is None:
|
537
|
+
return None
|
538
|
+
|
539
|
+
try:
|
540
|
+
column_type_lower = column_type.lower()
|
541
|
+
|
542
|
+
if 'int' in column_type_lower:
|
543
|
+
if isinstance(value, (str, bytes)) and not value.strip().isdigit():
|
544
|
+
raise ValueError("非数字字符串无法转换为整数")
|
545
|
+
return int(value)
|
546
|
+
elif any(t in column_type_lower for t in ['float', 'double', 'decimal']):
|
547
|
+
return float(value) if value is not None else None
|
548
|
+
elif '日期' in column_type_lower or 'time' in column_type_lower:
|
549
|
+
if isinstance(value, (datetime.datetime, pd.Timestamp)):
|
550
|
+
return value.strftime('%Y-%m-%d %H:%M:%S')
|
551
|
+
elif isinstance(value, str):
|
552
|
+
try:
|
553
|
+
return self._validate_datetime(value) # 使用专门的日期验证方法
|
554
|
+
except ValueError as e:
|
555
|
+
raise ValueError(f"无效日期格式: {value} - {str(e)}")
|
556
|
+
return str(value)
|
557
|
+
elif 'char' in column_type_lower or 'text' in column_type_lower:
|
558
|
+
# 防止SQL注入
|
559
|
+
if isinstance(value, str):
|
560
|
+
return value.replace('\\', '\\\\').replace("'", "\\'")
|
561
|
+
return str(value)
|
562
|
+
elif 'json' in column_type_lower:
|
563
|
+
import json
|
564
|
+
return json.dumps(value) if value is not None else None
|
565
|
+
else:
|
566
|
+
return value
|
567
|
+
except (ValueError, TypeError) as e:
|
568
|
+
error_msg = f"数据类型转换异常 {value} to type {column_type}: {str(e)}"
|
569
|
+
logger.error(error_msg)
|
570
|
+
raise ValueError(error_msg)
|
571
|
+
|
572
|
+
def _get_table_columns(self, db_name: str, table_name: str) -> Dict[str, str]:
    """
    Read the columns of ``db_name.table_name`` from INFORMATION_SCHEMA.

    :param db_name: database name (cleaned via ``_validate_identifier``)
    :param table_name: table name (cleaned via ``_validate_identifier``)
    :return: {column name: DATA_TYPE} in ORDINAL_POSITION order; empty when
        the query returns no rows
    :raises: query errors are logged and re-raised
    """
    schema = self._validate_identifier(db_name)
    table = self._validate_identifier(table_name)
    query = """
    SELECT COLUMN_NAME, DATA_TYPE
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
    ORDER BY ORDINAL_POSITION
    """

    try:
        with self._get_connection() as conn, conn.cursor() as cursor:
            cursor.execute(query, (schema, table))
            columns = {}
            for record in cursor.fetchall():
                columns[record['COLUMN_NAME']] = record['DATA_TYPE']
            logger.debug(f"{schema}.{table}: 获取表的列信息: {columns}")
            return columns
    except Exception as exc:
        logger.error(f"无法获取表列信息: {str(exc)}")
        raise
|
593
|
+
|
594
|
+
def _upload_to_table(
        self,
        db_name: str,
        table_name: str,
        data: List[Dict],
        set_typ: Dict[str, str],
        primary_keys: Optional[List[str]],
        check_duplicate: bool,
        duplicate_columns: Optional[List[str]],
        allow_null: bool,
        auto_create: bool,
        date_column: Optional[str],
        indexes: Optional[List[str]],
        batch_id: Optional[str] = None
):
    """
    Upload already-prepared rows into a single table.

    The steps are:
    1. Create the table when it is missing and ``auto_create`` is set.
    2. Check that every column in ``set_typ`` exists in the table.
    3. Insert the rows via ``_insert_data``, passing ``batch_id`` through for
       log tracing.

    Column names are compared case-insensitively, as MySQL treats them.
    This lets a table whose columns were declared in mixed case still accept
    the lower-cased keys produced by ``_prepare_data``.

    :raises ValueError: if the table is missing and ``auto_create`` is False,
        its columns cannot be read, or a data column is absent from the table
    """
    if not self._check_table_exists(db_name, table_name):
        if auto_create:
            self._create_table(db_name, table_name, set_typ, primary_keys, date_column, indexes,
                               allow_null=allow_null)
        else:
            error_msg = f"数据表不存在: '{db_name}.{table_name}'"
            logger.error(error_msg)
            raise ValueError(error_msg)

    # Read the table structure and validate the data columns against it.
    table_columns = self._get_table_columns(db_name, table_name)
    if not table_columns:
        error_msg = f"获取列失败 '{db_name}.{table_name}'"
        logger.error(error_msg)
        raise ValueError(error_msg)

    existing_columns = {str(col).lower() for col in table_columns}
    for col in set_typ:
        if col.lower() not in existing_columns:
            error_msg = f"列不存在: '{col}' -> '{db_name}.{table_name}'"
            logger.error(error_msg)
            raise ValueError(error_msg)

    self._insert_data(
        db_name, table_name, data, set_typ,
        check_duplicate, duplicate_columns,
        batch_id=batch_id
    )
|
639
|
+
|
640
|
+
def _infer_data_type(self, value: Any) -> str:
|
641
|
+
"""
|
642
|
+
根据值推断合适的数据类型
|
643
|
+
|
644
|
+
:param value: 要推断的值
|
645
|
+
:return: MySQL数据类型字符串
|
646
|
+
"""
|
647
|
+
if value is None:
|
648
|
+
return 'VARCHAR(255)' # 默认字符串类型
|
649
|
+
|
650
|
+
if isinstance(value, bool):
|
651
|
+
return 'TINYINT(1)'
|
652
|
+
elif isinstance(value, int):
|
653
|
+
# if -128 <= value <= 127:
|
654
|
+
# return 'TINYINT'
|
655
|
+
# elif -32768 <= value <= 32767:
|
656
|
+
# return 'SMALLINT'
|
657
|
+
# elif -8388608 <= value <= 8388607:
|
658
|
+
# return 'MEDIUMINT'
|
659
|
+
if -2147483648 <= value <= 2147483647:
|
660
|
+
return 'INT'
|
661
|
+
else:
|
662
|
+
return 'BIGINT'
|
663
|
+
elif isinstance(value, float):
|
664
|
+
return 'DECIMAL(10,2)'
|
665
|
+
elif isinstance(value, (datetime.datetime, pd.Timestamp)):
|
666
|
+
return 'DATETIME'
|
667
|
+
elif isinstance(value, datetime.date):
|
668
|
+
return 'DATE'
|
669
|
+
elif isinstance(value, (list, dict)):
|
670
|
+
return 'JSON'
|
671
|
+
elif isinstance(value, str):
|
672
|
+
# 尝试判断是否是日期时间
|
673
|
+
try:
|
674
|
+
self._validate_datetime(value)
|
675
|
+
return 'DATETIME'
|
676
|
+
except ValueError:
|
677
|
+
pass
|
678
|
+
|
679
|
+
# 根据字符串长度选择合适类型
|
680
|
+
length = len(value)
|
681
|
+
if length <= 255:
|
682
|
+
return 'VARCHAR(255)'
|
683
|
+
elif length <= 65535:
|
684
|
+
return 'TEXT'
|
685
|
+
elif length <= 16777215:
|
686
|
+
return 'MEDIUMTEXT'
|
687
|
+
else:
|
688
|
+
return 'LONGTEXT'
|
689
|
+
else:
|
690
|
+
return 'VARCHAR(255)'
|
691
|
+
|
692
|
+
def _prepare_data(
        self,
        data: Union[Dict, List[Dict], pd.DataFrame],
        set_typ: Dict[str, str],
        allow_null: bool = False
) -> Tuple[List[Dict], Dict[str, str]]:
    """
    Normalise input rows and convert their values to the column types.

    The input becomes a list of dicts with lower-cased keys. The columns used
    are those of the first row, in that row's key order. Each column takes
    its type from ``set_typ`` (keys compared lower-cased). If it is absent
    there, the type is inferred from the first non-None value of that column,
    or VARCHAR(255) when there is none. An ``id`` column is never copied into
    the prepared rows.

    :param data: a dict (one row), a list of dicts, or a DataFrame; a
        DataFrame is left unmodified
    :param set_typ: {column name: MySQL type}
    :param allow_null: when False, a row lacking one of the columns raises
    :return: (prepared rows, {column: type} for the columns used)
    :raises ValueError: on unsupported input or a value that fails conversion
    """
    # Normalise to a list of dicts with lower-cased keys.
    if isinstance(data, pd.DataFrame):
        try:
            # rename() returns a copy; assigning data.columns would mutate
            # the caller's DataFrame.
            frame = data.rename(columns=lambda col: col.lower())
            data = frame.replace({pd.NA: None}).to_dict('records')
        except Exception as e:
            logger.error(f"数据转字典时发生错误: {e}", )
            raise ValueError(f"数据转字典时发生错误: {e}")
    elif isinstance(data, dict):
        data = [{k.lower(): v for k, v in data.items()}]
    elif isinstance(data, list) and all(isinstance(item, dict) for item in data):
        data = [{k.lower(): v for k, v in item.items()} for item in data]
    else:
        error_msg = "数据结构必须是字典、列表、字典列表或dataframe"
        logger.error(error_msg)
        raise ValueError(error_msg)

    set_typ = {k.lower(): v for k, v in set_typ.items()}

    # Use the first row's key order. Iterating a set made the column order,
    # and thus the CREATE TABLE / INSERT column order, vary between runs.
    data_columns = list(data[0].keys()) if data else []

    # Keep only columns present in the data; infer types for undeclared ones.
    filtered_set_typ = {}
    for col in data_columns:
        if col in set_typ:
            filtered_set_typ[col] = set_typ[col]
            continue
        sample = next((row[col] for row in data if row.get(col) is not None), None)
        if sample is not None:
            inferred_type = self._infer_data_type(sample)
            filtered_set_typ[col] = inferred_type
            logger.debug(f"自动推断列'{col}'的数据类型为: {inferred_type}")
        else:
            filtered_set_typ[col] = 'VARCHAR(255)'
            logger.debug(f"为列'{col}'使用默认数据类型: VARCHAR(255)")

    prepared_data = []
    for row_idx, row in enumerate(data, 1):
        prepared_row = {}
        for col_name in filtered_set_typ:
            # id is generated by the table; never take it from the input.
            if col_name.lower() == 'id':
                continue

            if col_name not in row:
                if not allow_null:
                    error_msg = f"Row {row_idx}: Missing required column '{col_name}' in data"
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                prepared_row[col_name] = None
            else:
                try:
                    prepared_row[col_name] = self._validate_value(row[col_name], filtered_set_typ[col_name])
                except ValueError as e:
                    error_msg = f"Row {row_idx}, column '{col_name}': {str(e)}"
                    logger.error(error_msg)
                    raise ValueError(error_msg)
        prepared_data.append(prepared_row)

    logger.debug(f"已准备 {len(prepared_data)} 行数据")
    return prepared_data, filtered_set_typ
|
776
|
+
|
777
|
+
def upload_data(
        self,
        db_name: str,
        table_name: str,
        data: Union[Dict, List[Dict], pd.DataFrame],
        set_typ: Dict[str, str],
        primary_keys: Optional[List[str]] = None,
        check_duplicate: bool = False,
        duplicate_columns: Optional[List[str]] = None,
        allow_null: bool = False,
        partition_by: Optional[str] = None,
        partition_date_column: str = '日期',
        auto_create: bool = True,
        indexes: Optional[List[str]] = None
):
    """
    Upload data to a MySQL table, optionally split into year/month tables.

    :param db_name: target database (created when missing and auto_create is set)
    :param table_name: target table, or base name of the partition tables
    :param data: a dict (one row), a list of dicts, or a DataFrame
    :param set_typ: {column name: MySQL type}; must not be empty
    :param primary_keys: primary-key columns used when creating tables
    :param check_duplicate: passed to ``_insert_data``
    :param duplicate_columns: columns used by the duplicate check
    :param allow_null: allow missing values / nullable columns
    :param partition_by: None, 'year' or 'month'
    :param partition_date_column: date column that selects the partition table
        (also indexed when tables are created)
    :param auto_create: create a missing database / table
    :param indexes: extra columns to index when creating tables

    Column-name arguments are matched case-insensitively because
    ``_prepare_data`` lower-cases every data key.

    Errors are logged, not raised. A final log record reports ``success``,
    which is True only when every row reached its table.
    """
    upload_start = time.time()
    # A dict is one row; len() of a dict would count its keys instead.
    if isinstance(data, dict):
        initial_row_count = 1
    else:
        initial_row_count = len(data) if hasattr(data, '__len__') else 1

    batch_id = f"batch_{int(time.time() * 1000)}"
    success_flag = False

    logger.info("开始上传数据", {
        'batch_id': batch_id,
        'database': db_name,
        'table': table_name,
        'partition_by': partition_by,
        'check_duplicate': check_duplicate,
        'row_count': initial_row_count,
        'auto_create': auto_create
    })

    try:
        # Validate arguments.
        if not set_typ:
            error_msg = "列的数据类型缺失"
            logger.error(error_msg)
            raise ValueError(error_msg)

        if partition_by and partition_by not in ['year', 'month']:
            error_msg = "分表方式必须是 'year' 或 'month'"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Lower-case column-name arguments to match the lower-cased data keys.
        if partition_date_column:
            partition_date_column = partition_date_column.lower()
        if duplicate_columns:
            duplicate_columns = [col.lower() for col in duplicate_columns]
        if indexes:
            indexes = [col.lower() for col in indexes]

        prepared_data, set_typ = self._prepare_data(data, set_typ, allow_null)

        # Ensure the database exists.
        if not self._check_database_exists(db_name):
            if auto_create:
                self._create_database(db_name)
            else:
                error_msg = f"数据库不存在: '{db_name}'"
                logger.error(error_msg)
                raise ValueError(error_msg)

        if partition_by:
            # Group rows by partition table; unroutable rows are logged and skipped.
            partitioned_data = {}
            skipped_rows = 0
            for row in prepared_data:
                try:
                    if partition_date_column not in row:
                        error_msg = f"异常缺失列 '{partition_date_column}'"
                        logger.error(error_msg)
                        skipped_rows += 1
                        continue

                    part_table = self._get_partition_table_name(
                        table_name,
                        str(row[partition_date_column]),
                        partition_by
                    )
                    partitioned_data.setdefault(part_table, []).append(row)
                except Exception as e:
                    logger.error("分表处理失败", {
                        'row_data': row,
                        'error': str(e)
                    })
                    skipped_rows += 1

            # Upload each partition; a failed partition does not stop the others.
            failed_partitions = []
            for part_table, part_data in partitioned_data.items():
                try:
                    self._upload_to_table(
                        db_name, part_table, part_data, set_typ,
                        primary_keys, check_duplicate, duplicate_columns,
                        allow_null, auto_create, partition_date_column,
                        indexes, batch_id
                    )
                except Exception as e:
                    logger.error("分表上传失败", {
                        'partition_table': part_table,
                        'error': str(e)
                    })
                    failed_partitions.append(part_table)

            success_flag = not failed_partitions and skipped_rows == 0
        else:
            self._upload_to_table(
                db_name, table_name, prepared_data, set_typ,
                primary_keys, check_duplicate, duplicate_columns,
                allow_null, auto_create, partition_date_column,
                indexes, batch_id
            )
            success_flag = True

    except Exception as e:
        logger.error("上传过程中发生全局错误", {
            'error': str(e),
            'error_type': type(e).__name__
        })
    finally:
        elapsed = time.time() - upload_start
        logger.info("上传处理完成", {
            'batch_id': batch_id,
            'success': success_flag,
            'time_elapsed': elapsed,
            'initial_row_count': initial_row_count
        })
|
899
|
+
|
900
|
+
def _insert_data(
|
901
|
+
self,
|
902
|
+
db_name: str,
|
903
|
+
table_name: str,
|
904
|
+
data: List[Dict],
|
905
|
+
set_typ: Dict[str, str],
|
906
|
+
check_duplicate: bool = False,
|
907
|
+
duplicate_columns: Optional[List[str]] = None,
|
908
|
+
batch_size: int = 1000,
|
909
|
+
batch_id: Optional[str] = None
|
910
|
+
):
|
911
|
+
"""
|
912
|
+
插入数据到表中
|
913
|
+
|
914
|
+
参数:
|
915
|
+
db_name: 数据库名
|
916
|
+
table_name: 表名
|
917
|
+
data: 要插入的数据列表
|
918
|
+
set_typ: 列名和数据类型字典 {列名: 数据类型}
|
919
|
+
check_duplicate: 是否检查重复
|
920
|
+
duplicate_columns: 用于检查重复的列(为空时检查所有列)
|
921
|
+
batch_size: 批量插入大小
|
922
|
+
batch_id: 批次ID用于日志追踪
|
923
|
+
"""
|
924
|
+
if not data:
|
925
|
+
return
|
926
|
+
|
927
|
+
# 获取所有列名(排除id列)
|
928
|
+
all_columns = [col for col in set_typ.keys() if col.lower() != 'id']
|
929
|
+
safe_columns = [self._validate_identifier(col) for col in all_columns]
|
930
|
+
placeholders = ','.join(['%s'] * len(safe_columns))
|
931
|
+
|
932
|
+
# 构建基础SQL语句
|
933
|
+
if check_duplicate:
|
934
|
+
if not duplicate_columns:
|
935
|
+
duplicate_columns = all_columns
|
936
|
+
else:
|
937
|
+
duplicate_columns = [col for col in duplicate_columns if col != 'id']
|
938
|
+
|
939
|
+
conditions = []
|
940
|
+
for col in duplicate_columns:
|
941
|
+
col_type = set_typ.get(col, '').lower()
|
942
|
+
|
943
|
+
# 处理DECIMAL类型,使用ROUND确保精度一致
|
944
|
+
if col_type.startswith('decimal'):
|
945
|
+
# 提取小数位数,如DECIMAL(10,2)提取2
|
946
|
+
scale_match = re.search(r'decimal\(\d+,(\d+)\)', col_type)
|
947
|
+
scale = int(scale_match.group(1)) if scale_match else 2
|
948
|
+
conditions.append(f"ROUND(`{self._validate_identifier(col)}`, {scale}) = ROUND(%s, {scale})")
|
949
|
+
else:
|
950
|
+
conditions.append(f"`{self._validate_identifier(col)}` = %s")
|
951
|
+
|
952
|
+
where_clause = " AND ".join(conditions)
|
953
|
+
|
954
|
+
sql = f"""
|
955
|
+
INSERT INTO `{db_name}`.`{table_name}`
|
956
|
+
(`{'`,`'.join(safe_columns)}`)
|
957
|
+
SELECT {placeholders}
|
958
|
+
FROM DUAL
|
959
|
+
WHERE NOT EXISTS (
|
960
|
+
SELECT 1 FROM `{db_name}`.`{table_name}`
|
961
|
+
WHERE {where_clause}
|
962
|
+
)
|
963
|
+
"""
|
964
|
+
else:
|
965
|
+
sql = f"""
|
966
|
+
INSERT INTO `{db_name}`.`{table_name}`
|
967
|
+
(`{'`,`'.join(safe_columns)}`)
|
968
|
+
VALUES ({placeholders})
|
969
|
+
"""
|
970
|
+
|
971
|
+
total_inserted = 0
|
972
|
+
total_skipped = 0
|
973
|
+
total_failed = 0 # 失败计数器
|
974
|
+
|
975
|
+
# 分批插入数据
|
976
|
+
with self._get_connection() as conn:
|
977
|
+
with conn.cursor() as cursor:
|
978
|
+
for i in range(0, len(data), batch_size):
|
979
|
+
batch_start = time.time()
|
980
|
+
batch = data[i:i + batch_size]
|
981
|
+
successful_rows = 0 # 当前批次成功数
|
982
|
+
|
983
|
+
for row in batch:
|
984
|
+
try:
|
985
|
+
# 准备参数
|
986
|
+
row_values = [row.get(col) for col in all_columns]
|
987
|
+
# 如果是排重检查,添加排重列值
|
988
|
+
if check_duplicate:
|
989
|
+
row_values += [row.get(col) for col in duplicate_columns]
|
990
|
+
|
991
|
+
cursor.execute(sql, row_values)
|
992
|
+
successful_rows += 1
|
993
|
+
conn.commit() # 每次成功插入后提交
|
994
|
+
|
995
|
+
except Exception as e:
|
996
|
+
conn.rollback() # 回滚当前行的事务
|
997
|
+
total_failed += 1
|
998
|
+
|
999
|
+
# 记录失败行详细信息
|
1000
|
+
error_details = {
|
1001
|
+
'batch_id': batch_id,
|
1002
|
+
'database': db_name,
|
1003
|
+
'table': table_name,
|
1004
|
+
'error_type': type(e).__name__,
|
1005
|
+
'error_message': str(e),
|
1006
|
+
'column_types': set_typ,
|
1007
|
+
'duplicate_check': check_duplicate,
|
1008
|
+
'duplicate_columns': duplicate_columns
|
1009
|
+
}
|
1010
|
+
logger.error(f"单行插入失败: {error_details}")
|
1011
|
+
continue # 跳过当前行,继续处理下一行
|
1012
|
+
|
1013
|
+
# 更新统计信息
|
1014
|
+
if check_duplicate:
|
1015
|
+
cursor.execute("SELECT ROW_COUNT()")
|
1016
|
+
affected_rows = cursor.rowcount
|
1017
|
+
total_inserted += affected_rows
|
1018
|
+
total_skipped += len(batch) - affected_rows - (len(batch) - successful_rows)
|
1019
|
+
else:
|
1020
|
+
total_inserted += successful_rows
|
1021
|
+
|
1022
|
+
batch_elapsed = time.time() - batch_start
|
1023
|
+
batch_info = {
|
1024
|
+
'batch_id': batch_id,
|
1025
|
+
'batch_index': i // batch_size + 1,
|
1026
|
+
'total_batches': (len(data) + batch_size - 1) // batch_size,
|
1027
|
+
'batch_size': len(batch),
|
1028
|
+
'successful_rows': successful_rows,
|
1029
|
+
'failed_rows': len(batch) - successful_rows,
|
1030
|
+
'time_elapsed': batch_elapsed,
|
1031
|
+
'rows_per_second': successful_rows / batch_elapsed if batch_elapsed > 0 else 0
|
1032
|
+
}
|
1033
|
+
logger.debug(f"批次处理完成 {batch_info}")
|
1034
|
+
|
1035
|
+
logger.info("数据插入完成", {
|
1036
|
+
'total_rows': len(data),
|
1037
|
+
'inserted_rows': total_inserted,
|
1038
|
+
'skipped_rows': total_skipped,
|
1039
|
+
'failed_rows': total_failed
|
1040
|
+
})
|
1041
|
+
|
1042
|
+
def close(self):
|
1043
|
+
"""关闭连接池并记录最终指标"""
|
1044
|
+
close_start = time.time()
|
1045
|
+
|
1046
|
+
try:
|
1047
|
+
if hasattr(self, 'pool') and self.pool is not None:
|
1048
|
+
# 更安全的关闭方式
|
1049
|
+
try:
|
1050
|
+
self.pool.close()
|
1051
|
+
except Exception as e:
|
1052
|
+
logger.warning("关闭连接池时出错", {
|
1053
|
+
'error': str(e)
|
1054
|
+
})
|
1055
|
+
|
1056
|
+
self.pool = None
|
1057
|
+
|
1058
|
+
elapsed = round(time.time() - close_start, 2)
|
1059
|
+
logger.info("连接池已关闭", {
|
1060
|
+
'close_time_elapsed': elapsed
|
1061
|
+
})
|
1062
|
+
except Exception as e:
|
1063
|
+
elapsed = round(time.time() - close_start, 2)
|
1064
|
+
logger.error("关闭连接池失败", {
|
1065
|
+
'error': str(e),
|
1066
|
+
'close_time_elapsed': elapsed
|
1067
|
+
})
|
1068
|
+
raise
|
1069
|
+
|
1070
|
+
def _check_pool_health(self):
|
1071
|
+
"""定期检查连接池健康状态"""
|
1072
|
+
try:
|
1073
|
+
conn = self.pool.connection()
|
1074
|
+
conn.ping(reconnect=True)
|
1075
|
+
conn.close()
|
1076
|
+
return True
|
1077
|
+
except Exception as e:
|
1078
|
+
logger.warning("连接池健康检查失败", {
|
1079
|
+
'error': str(e)
|
1080
|
+
})
|
1081
|
+
return False
|
1082
|
+
|
1083
|
+
def retry_on_failure(max_retries=3, delay=1):
|
1084
|
+
def decorator(func):
|
1085
|
+
@wraps(func)
|
1086
|
+
def wrapper(*args, **kwargs):
|
1087
|
+
last_exception = None
|
1088
|
+
for attempt in range(max_retries):
|
1089
|
+
try:
|
1090
|
+
return func(*args, **kwargs)
|
1091
|
+
except (pymysql.OperationalError, pymysql.InterfaceError) as e:
|
1092
|
+
last_exception = e
|
1093
|
+
if attempt < max_retries - 1:
|
1094
|
+
time.sleep(delay * (attempt + 1))
|
1095
|
+
continue
|
1096
|
+
raise MySQLUploaderError(f"操作重试{max_retries}次后失败") from e
|
1097
|
+
except Exception as e:
|
1098
|
+
raise MySQLUploaderError(f"操作失败: {str(e)}") from e
|
1099
|
+
raise last_exception if last_exception else MySQLUploaderError("未知错误")
|
1100
|
+
|
1101
|
+
return wrapper
|
1102
|
+
|
1103
|
+
return decorator
|
1104
|
+
|
1105
|
+
|
1106
|
+
def main():
|
1107
|
+
uploader = MySQLUploader(
|
1108
|
+
username='root',
|
1109
|
+
password='pw',
|
1110
|
+
host='localhost',
|
1111
|
+
port=3306,
|
1112
|
+
)
|
1113
|
+
|
1114
|
+
# 定义列和数据类型
|
1115
|
+
set_typ = {
|
1116
|
+
'name': 'VARCHAR(255)',
|
1117
|
+
'age': 'INT',
|
1118
|
+
'salary': 'DECIMAL(10,2)',
|
1119
|
+
'日期': 'DATE',
|
1120
|
+
'shop': None,
|
1121
|
+
}
|
1122
|
+
|
1123
|
+
# 准备数据
|
1124
|
+
data = [
|
1125
|
+
{'日期': '2023-01-8', 'name': 'JACk', 'AGE': '24', 'salary': 555.1545},
|
1126
|
+
{'日期': '2023-01-15', 'name': 'Alice', 'AGE': 35, 'salary': 100},
|
1127
|
+
{'日期': '2023-01-15', 'name': 'Alice', 'AGE': 30, 'salary': 0.0},
|
1128
|
+
{'日期': '2023-02-20', 'name': 'Bob', 'AGE': 25, 'salary': 45000.75}
|
1129
|
+
]
|
1130
|
+
|
1131
|
+
# 上传数据
|
1132
|
+
uploader.upload_data(
|
1133
|
+
db_name='测试库',
|
1134
|
+
table_name='测试表',
|
1135
|
+
data=data,
|
1136
|
+
set_typ=set_typ, # 定义列和数据类型
|
1137
|
+
primary_keys=[], # 创建唯一主键
|
1138
|
+
check_duplicate=False, # 检查重复数据
|
1139
|
+
duplicate_columns=[], # 指定排重的组合键
|
1140
|
+
allow_null=False, # 允许插入空值
|
1141
|
+
partition_by='year', # 按月分表
|
1142
|
+
partition_date_column='日期', # 用于分表的日期列名,默认为'日期'
|
1143
|
+
auto_create=True, # 表不存在时自动创建, 默认参数不要更改
|
1144
|
+
indexes=[], # 指定索引列
|
1145
|
+
)
|
1146
|
+
|
1147
|
+
uploader.close()
|
1148
|
+
|
1149
|
+
|
1150
|
+
if __name__ == '__main__':
|
1151
|
+
main()
|