mdbq 4.2.25__py3-none-any.whl → 4.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mdbq might be problematic. Click here for more details.

@@ -0,0 +1,1439 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ MySQL数据写入模块 - 专注于爬虫数据入库
5
+
6
+ 功能特性:
7
+ - 高性能批量插入
8
+ - 智能去重(基于唯一键)
9
+ - UPSERT操作(插入或更新)
10
+ - 自动建表和字段扩展
11
+ - 数据类型自动推断
12
+ - 进度监控和统计
13
+ - 失败重试机制
14
+ - 支持多种数据格式(字典、列表、DataFrame)
15
+ - 事务保证
16
+
17
+ 依赖:
18
+ pip install pymysql DBUtils pandas
19
+
20
+ 适用场景:
21
+ - 爬虫数据批量入库
22
+ - 数据ETL导入
23
+ - 日志数据存储
24
+ - 实时数据流写入
25
+ """
26
+
27
+ import logging
28
+ import time
29
+ from pathlib import Path
30
+ from typing import List, Dict, Any, Optional, Union, Tuple, Set
31
+ from datetime import datetime, date
32
+ from decimal import Decimal
33
+ from contextlib import contextmanager
34
+ import pymysql
35
+ from pymysql.cursors import DictCursor
36
+ from dbutils.pooled_db import PooledDB
37
+
38
+
39
+ # ==================== 异常类 ====================
40
+
41
+
42
+ class MySQLWriterError(Exception):
43
+ """MySQL写入异常基类"""
44
+ pass
45
+
46
+
47
+ class DataValidationError(MySQLWriterError):
48
+ """数据验证错误"""
49
+ pass
50
+
51
+
52
+ class TableCreationError(MySQLWriterError):
53
+ """建表错误"""
54
+ pass
55
+
56
+
57
+ class InsertError(MySQLWriterError):
58
+ """插入错误"""
59
+ pass
60
+
61
+
62
+ # ==================== 工具函数 ====================
63
+
64
+
65
+ def infer_mysql_type(value: Any, max_length: int = 255, for_index: bool = False) -> str:
66
+ """
67
+ 根据Python值推断MySQL数据类型
68
+
69
+ 参数:
70
+ value: Python值
71
+ max_length: VARCHAR最大长度(默认255)
72
+ for_index: 是否用于索引字段(会限制长度以避免超过索引限制)
73
+
74
+ 返回:
75
+ MySQL数据类型字符串
76
+
77
+ 注意:
78
+ - MySQL InnoDB索引长度限制:767字节(默认)或3072字节
79
+ - utf8mb4编码下,VARCHAR(191)最多占用764字节,安全范围内
80
+ - 如果for_index=True,VARCHAR长度会限制为191
81
+ """
82
+ if value is None:
83
+ # 如果用于索引,限制长度为191(767字节/4字节=191.75)
84
+ safe_length = 191 if for_index else max_length
85
+ return f"VARCHAR({safe_length})"
86
+
87
+ if isinstance(value, bool):
88
+ return "TINYINT(1)"
89
+
90
+ if isinstance(value, int):
91
+ if -128 <= value <= 127:
92
+ return "TINYINT"
93
+ elif -32768 <= value <= 32767:
94
+ return "SMALLINT"
95
+ elif -8388608 <= value <= 8388607:
96
+ return "MEDIUMINT"
97
+ elif -2147483648 <= value <= 2147483647:
98
+ return "INT"
99
+ else:
100
+ return "BIGINT"
101
+
102
+ if isinstance(value, float):
103
+ return "DOUBLE"
104
+
105
+ if isinstance(value, Decimal):
106
+ return "DECIMAL(20,6)"
107
+
108
+ if isinstance(value, (datetime, date)):
109
+ return "DATETIME"
110
+
111
+ if isinstance(value, str):
112
+ length = len(value)
113
+ if length == 0:
114
+ # 空字符串,使用默认长度
115
+ safe_length = 191 if for_index else max_length
116
+ return f"VARCHAR({safe_length})"
117
+ elif length <= 191:
118
+ # 短字符串,直接使用
119
+ return f"VARCHAR({length * 2})"
120
+ elif length <= 255 and not for_index:
121
+ # 中等长度,仅在非索引字段使用
122
+ return f"VARCHAR({min(length * 2, max_length)})"
123
+ elif length <= 65535 and not for_index:
124
+ # 长字符串,使用TEXT(TEXT类型不能作为唯一键)
125
+ return "TEXT"
126
+ else:
127
+ # 超长字符串或索引字段的长字符串
128
+ if for_index:
129
+ # 索引字段限制为191
130
+ return "VARCHAR(191)"
131
+ else:
132
+ return "LONGTEXT"
133
+
134
+ if isinstance(value, (list, dict)):
135
+ # JSON类型不能作为唯一键
136
+ return "JSON"
137
+
138
+ if isinstance(value, bytes):
139
+ return "BLOB"
140
+
141
+ # 默认类型
142
+ safe_length = 191 if for_index else max_length
143
+ return f"VARCHAR({safe_length})"
144
+
145
+
146
+ def sanitize_name(name: str, name_type: str = 'field') -> str:
147
+ """
148
+ 清理名称(库名、表名、字段名),使其符合MySQL命名规范
149
+
150
+ 规则:
151
+ - 强制转为小写
152
+ - 只保留字母、数字、下划线
153
+ - 数字开头添加前缀
154
+ - 截断长度限制为64字符
155
+
156
+ 参数:
157
+ name: 原始名称
158
+ name_type: 名称类型 ('database'/'table'/'field')
159
+
160
+ 返回:
161
+ 清理后的名称
162
+ """
163
+ import re
164
+ # 转小写
165
+ name = name.lower()
166
+ # 移除特殊字符,只保留字母、数字、下划线
167
+ name = re.sub(r'[^a-z0-9_]', '_', name)
168
+ # 移除连续的下划线
169
+ name = re.sub(r'_+', '_', name)
170
+ # 移除首尾下划线
171
+ name = name.strip('_')
172
+ # 如果为空或以数字开头,添加前缀
173
+ if not name or name[0].isdigit():
174
+ prefix = {'database': 'db_', 'table': 'tb_', 'field': 'f_'}.get(name_type, 'x_')
175
+ name = prefix + name
176
+ # 截断长度
177
+ return name[:64]
178
+
179
+
180
+ # ==================== 主写入类 ====================
181
+
182
+
183
+ class MYSQLWriter:
184
+ """
185
+ MySQL数据写入器 - 专为爬虫数据入库设计
186
+
187
+ 功能特性:
188
+ - 自动建库(数据库不存在时自动创建)
189
+ - 自动建表(推断字段类型)
190
+ - 高性能批量插入
191
+ - 智能去重(UPSERT)
192
+ - 唯一约束管理
193
+ - 索引管理
194
+ - 自动ID和时间戳
195
+ - 进度监控
196
+
197
+ 数据类型推断(自动 + 手动):
198
+ 自动推断规则:
199
+ - bool → TINYINT(1)
200
+ - int → TINYINT/SMALLINT/INT/BIGINT(根据值范围)
201
+ - float → DOUBLE
202
+ - Decimal → DECIMAL(20,6)
203
+ - datetime/date → DATETIME
204
+ - str → VARCHAR/TEXT(根据长度)
205
+ - list/dict → JSON
206
+ - None → VARCHAR(255)
207
+
208
+ 手动指定类型(优先级更高):
209
+ - 使用 field_types 参数精确控制字段类型
210
+ - 支持所有MySQL数据类型(DECIMAL、ENUM、SET等)
211
+ - 未指定的字段仍使用自动推断
212
+
213
+ 唯一约束长度限制(重要):
214
+ - MySQL InnoDB索引限制:767字节(utf8mb4: 191字符)
215
+ - 唯一约束字段自动限制为VARCHAR(191)
216
+ - TEXT/JSON类型不能作为唯一键
217
+
218
+ 示例:
219
+ # 自动建库建表(推荐)
220
+ writer = MYSQLWriter(
221
+ host='localhost',
222
+ user='root',
223
+ password='pwd',
224
+ database='spider_data', # 默认数据库(不存在会自动创建)
225
+ auto_create=True, # 自动建库建表(默认True)
226
+ auto_add_id=True, # 自动添加自增ID(默认True)
227
+ auto_add_timestamps=True # 自动添加时间戳(默认True)
228
+ )
229
+
230
+ # 方式1:使用默认数据库
231
+ data = [
232
+ {'url': 'http://example.com/1', 'title': '标题1', 'price': 99.9},
233
+ {'url': 'http://example.com/2', 'title': '标题2', 'price': 199.9}
234
+ ]
235
+ writer.insert_many('products', data) # 插入到 spider_data.products
236
+
237
+ # 方式2:指定其他数据库(支持 "库名.表名" 格式)
238
+ writer.insert_many('db2.products', data) # 自动创建db2库和products表
239
+
240
+ # 带唯一约束
241
+ writer.insert_many(
242
+ 'products',
243
+ data,
244
+ unique_key='url', # URL重复时更新
245
+ unique_constraints=['url'] # 创建唯一索引
246
+ )
247
+
248
+ # 手动指定字段类型
249
+ writer.insert_many(
250
+ 'products',
251
+ data,
252
+ field_types={
253
+ 'price': 'DECIMAL(10,2)', # 精确小数
254
+ 'status': 'ENUM("active","sold")' # 枚举类型
255
+ }
256
+ )
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ host: str = 'localhost',
262
+ port: int = 3306,
263
+ user: str = 'root',
264
+ password: str = '',
265
+ database: str = None,
266
+ charset: str = 'utf8mb4',
267
+ pool_size: int = 3,
268
+ auto_create: bool = True,
269
+ auto_add_id: bool = True,
270
+ auto_add_timestamps: bool = True,
271
+ log_config: Optional[Dict[str, Any]] = None
272
+ ):
273
+ """
274
+ 初始化MySQL写入器
275
+
276
+ 参数:
277
+ host: 数据库主机
278
+ port: 端口
279
+ user: 用户名
280
+ password: 密码
281
+ database: 默认数据库名(可选,insert时可用 "库名.表名" 格式)
282
+ charset: 字符集(默认utf8mb4)
283
+ pool_size: 连接池大小(默认3)
284
+ auto_create: 自动建库建表(默认True,库/表不存在时自动创建)
285
+ auto_add_id: 是否自动添加自增ID字段(默认True)
286
+ auto_add_timestamps: 是否自动添加created_at和updated_at字段(默认True)
287
+ log_config: 日志配置字典(键名忽略大小写),例如: {
288
+ 'enable': True, # 是否启用日志(默认True)
289
+ # 设置为False时,不会输出任何日志,忽略其他配置
290
+ 'level': 'INFO', # 日志级别 debug/info/warning/error(默认INFO)
291
+ 'output': 'console', # 输出位置(默认console):
292
+ # - 'console' 或 'terminal': 仅输出到终端
293
+ # - 'file': 仅输出到文件
294
+ # - 'both': 同时输出到终端和文件
295
+ 'file_path': 'mysql_writer.log' # 日志文件路径(可选)
296
+ # 相对路径:存储到用户home目录
297
+ # 绝对路径:存储到指定路径
298
+ }
299
+ """
300
+ self.host = host
301
+ self.port = port
302
+ self.user = user
303
+ self.password = password
304
+ self.charset = charset
305
+ self.auto_create = auto_create
306
+ self.auto_add_id = auto_add_id
307
+ self.auto_add_timestamps = auto_add_timestamps
308
+ self._closed = False
309
+
310
+ # 统计信息
311
+ self._stats = {
312
+ 'total_inserted': 0,
313
+ 'total_updated': 0,
314
+ 'total_failed': 0,
315
+ 'total_time': 0.0
316
+ }
317
+
318
+ self._setup_logger(log_config or {})
319
+
320
+ # 清理并保存数据库名
321
+ self.database = sanitize_name(database, 'database') if database else None
322
+
323
+ # 如果指定了默认数据库且开启自动建库,先确保数据库存在
324
+ if self.database and auto_create:
325
+ self._ensure_database_exists(self.database)
326
+ self._init_pool(pool_size)
327
+
328
+ if self.logger:
329
+ self.logger.debug(f"MySQL写入器初始化: {user}@{host}:{port}/{database}")
330
+
331
+ def _setup_logger(self, log_config: Dict[str, Any]):
332
+ """配置日志"""
333
+ config = {k.lower(): v for k, v in log_config.items()}
334
+
335
+ enable = config.get('enable', True)
336
+
337
+ if not enable:
338
+ self.logger = logging.getLogger(f'MYSQLWriter_{id(self)}')
339
+ self.logger.addHandler(logging.NullHandler())
340
+ self.logger.propagate = False
341
+ return
342
+
343
+ level = str(config.get('level', 'INFO')).upper()
344
+ output = str(config.get('output', 'console')).lower()
345
+ file_path = config.get('file_path')
346
+
347
+ self.logger = logging.getLogger(f'MYSQLWriter_{id(self)}')
348
+ self.logger.setLevel(getattr(logging, level))
349
+ self.logger.propagate = False
350
+ self.logger.handlers.clear()
351
+
352
+ formatter = logging.Formatter(
353
+ '%(asctime)s - %(levelname)s - %(message)s',
354
+ datefmt='%Y-%m-%d %H:%M:%S'
355
+ )
356
+
357
+ # 输出到终端
358
+ if output in ('console', 'terminal', 'both'):
359
+ console_handler = logging.StreamHandler()
360
+ console_handler.setLevel(logging.DEBUG)
361
+ console_handler.setFormatter(formatter)
362
+ self.logger.addHandler(console_handler)
363
+
364
+ # 输出到文件
365
+ if output in ('file', 'both'):
366
+ if not file_path:
367
+ file_path = Path.home() / 'mysql_writer.log'
368
+ else:
369
+ file_path = Path(file_path)
370
+ if not file_path.is_absolute():
371
+ file_path = Path.home() / file_path
372
+
373
+ file_path.parent.mkdir(parents=True, exist_ok=True)
374
+
375
+ file_handler = logging.FileHandler(file_path, encoding='utf-8')
376
+ file_handler.setLevel(logging.DEBUG)
377
+ file_handler.setFormatter(formatter)
378
+ self.logger.addHandler(file_handler)
379
+
380
+ if output == 'both':
381
+ self.logger.debug(f"日志文件: {file_path}")
382
+
383
+ def _ensure_database_exists(self, database: str):
384
+ """
385
+ 确保数据库存在,如果不存在则自动创建
386
+
387
+ 参数:
388
+ database: 数据库名
389
+
390
+ 异常:
391
+ ConnectionError: 连接失败
392
+ TableCreationError: 创建数据库失败
393
+ """
394
+ if not database or not self.auto_create:
395
+ return
396
+
397
+ conn = None
398
+ try:
399
+ # 先连接到MySQL(不指定数据库)
400
+ conn = pymysql.connect(
401
+ host=self.host,
402
+ port=self.port,
403
+ user=self.user,
404
+ password=self.password,
405
+ charset=self.charset,
406
+ cursorclass=DictCursor
407
+ )
408
+
409
+ with conn.cursor() as cursor:
410
+ # 检查数据库是否存在
411
+ cursor.execute(
412
+ "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = %s",
413
+ (database,)
414
+ )
415
+ result = cursor.fetchone()
416
+
417
+ if not result:
418
+ # 数据库不存在,创建它
419
+ # 使用utf8mb4字符集和utf8mb4_0900_ai_ci排序规则
420
+ safe_db_name = database.replace('`', '``')
421
+ create_sql = f"""
422
+ CREATE DATABASE `{safe_db_name}`
423
+ CHARACTER SET utf8mb4
424
+ COLLATE utf8mb4_0900_ai_ci
425
+ """
426
+ cursor.execute(create_sql)
427
+ conn.commit()
428
+
429
+ if self.logger:
430
+ self.logger.info(f"自动创建数据库: {database}")
431
+ else:
432
+ if self.logger:
433
+ self.logger.debug(f"数据库已存在: {database}")
434
+
435
+ except pymysql.Error as e:
436
+ error_msg = f"数据库检查/创建失败: {str(e)}"
437
+ if self.logger:
438
+ self.logger.error(error_msg)
439
+ raise TableCreationError(error_msg) from e
440
+ finally:
441
+ if conn:
442
+ conn.close()
443
+
444
+ def _init_pool(self, pool_size: int):
445
+ """初始化连接池"""
446
+ try:
447
+ connection_kwargs = {
448
+ 'host': self.host,
449
+ 'port': self.port,
450
+ 'user': self.user,
451
+ 'password': self.password,
452
+ 'charset': self.charset,
453
+ 'cursorclass': DictCursor,
454
+ 'autocommit': False, # 写入时使用事务
455
+ }
456
+
457
+ if self.database:
458
+ connection_kwargs['database'] = self.database
459
+
460
+ self.pool = PooledDB(
461
+ creator=pymysql,
462
+ maxconnections=pool_size,
463
+ mincached=1,
464
+ maxcached=pool_size,
465
+ blocking=True,
466
+ maxusage=10000,
467
+ ping=1,
468
+ **connection_kwargs
469
+ )
470
+
471
+ if self.logger:
472
+ self.logger.debug(f"连接池初始化成功,大小: {pool_size}")
473
+ except Exception as e:
474
+ if self.logger:
475
+ self.logger.error(f"连接池初始化失败: {str(e)}")
476
+ raise
477
+
478
+ @contextmanager
479
+ def _get_connection(self):
480
+ """获取数据库连接"""
481
+ conn = None
482
+ try:
483
+ conn = self.pool.connection()
484
+ yield conn
485
+ except Exception as e:
486
+ if self.logger:
487
+ self.logger.error(f"数据库连接错误: {str(e)}")
488
+ raise
489
+ finally:
490
+ if conn:
491
+ conn.close()
492
+
493
+ def _get_table_indexes(self, table: str) -> List[Dict[str, Any]]:
494
+ """
495
+ 获取表的所有索引信息
496
+
497
+ 参数:
498
+ table: 表名
499
+
500
+ 返回:
501
+ 索引信息列表,每个索引包含:
502
+ - name: 索引名
503
+ - fields: 字段列表
504
+ - is_unique: 是否唯一索引
505
+ """
506
+ try:
507
+ with self._get_connection() as conn:
508
+ with conn.cursor() as cursor:
509
+ cursor.execute(f"SHOW INDEX FROM `{table}`")
510
+ rows = cursor.fetchall()
511
+
512
+ # 按索引名分组
513
+ indexes_dict = {}
514
+ for row in rows:
515
+ idx_name = row['Key_name']
516
+ if idx_name == 'PRIMARY': # 跳过主键
517
+ continue
518
+
519
+ if idx_name not in indexes_dict:
520
+ indexes_dict[idx_name] = {
521
+ 'name': idx_name,
522
+ 'fields': [],
523
+ 'is_unique': row['Non_unique'] == 0
524
+ }
525
+
526
+ indexes_dict[idx_name]['fields'].append(row['Column_name'])
527
+
528
+ return list(indexes_dict.values())
529
+ except Exception as e:
530
+ if self.logger:
531
+ self.logger.debug(f"获取索引信息失败: {str(e)}")
532
+ return []
533
+
534
+ def _partition_data_by_period(
535
+ self,
536
+ data_list: List[Dict[str, Any]],
537
+ date_field: Optional[str] = None,
538
+ mode: str = 'year'
539
+ ) -> Dict[str, List[Dict[str, Any]]]:
540
+ """
541
+ 按时间周期分组数据
542
+
543
+ 参数:
544
+ data_list: 数据列表
545
+ date_field: 日期字段名(None表示自动识别)
546
+ mode: 分表模式 'year'(按年) 或 'month'(按年月)
547
+
548
+ 返回:
549
+ {表名后缀: [数据列表]} 字典
550
+ - 'year' 模式: {'2024': [...], '2025': [...]}
551
+ - 'month' 模式: {'2024_01': [...], '2024_12': [...], '2025_03': [...]}
552
+
553
+ 异常:
554
+ DataValidationError: 找不到日期字段或无法解析日期
555
+ """
556
+ from datetime import datetime, date
557
+
558
+ if not data_list:
559
+ return {}
560
+
561
+ # 确定日期字段
562
+ sample = data_list[0]
563
+ target_date_field = None
564
+
565
+ if date_field:
566
+ # 用户指定了日期字段
567
+ if date_field not in sample:
568
+ raise DataValidationError(f"指定的日期字段 '{date_field}' 不存在")
569
+ target_date_field = date_field
570
+ else:
571
+ # 自动识别日期字段(优先"日期")
572
+ if '日期' in sample:
573
+ target_date_field = '日期'
574
+ else:
575
+ # 查找其他日期类型字段
576
+ for key, value in sample.items():
577
+ if isinstance(value, (datetime, date)):
578
+ target_date_field = key
579
+ break
580
+ elif isinstance(value, str):
581
+ # 尝试解析字符串日期
582
+ try:
583
+ datetime.strptime(value[:10], '%Y-%m-%d')
584
+ target_date_field = key
585
+ break
586
+ except:
587
+ continue
588
+
589
+ if not target_date_field:
590
+ raise DataValidationError(
591
+ "无法自动识别日期字段,请使用 partition_date_field 参数指定日期字段"
592
+ )
593
+
594
+ # 按时间周期分组数据
595
+ partitioned = {}
596
+ for row in data_list:
597
+ date_value = row.get(target_date_field)
598
+ if date_value is None:
599
+ if self.logger:
600
+ self.logger.warning(f"跳过日期字段为空的数据: {row}")
601
+ continue
602
+
603
+ # 提取年份和月份
604
+ year = None
605
+ month = None
606
+ dt_obj = None
607
+
608
+ if isinstance(date_value, datetime):
609
+ dt_obj = date_value
610
+ elif isinstance(date_value, date):
611
+ dt_obj = datetime(date_value.year, date_value.month, date_value.day)
612
+ elif isinstance(date_value, str):
613
+ try:
614
+ # 尝试多种日期格式
615
+ for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%Y%m%d', '%Y-%m-%d %H:%M:%S']:
616
+ try:
617
+ dt_obj = datetime.strptime(date_value[:10] if len(date_value) >= 10 else date_value, fmt)
618
+ break
619
+ except:
620
+ continue
621
+
622
+ if dt_obj is None:
623
+ # 尝试直接提取年份和月份(例如 "2024-01-15")
624
+ year = int(date_value[:4])
625
+ if mode == 'month' and len(date_value) >= 7:
626
+ month = int(date_value[5:7])
627
+ except Exception as e:
628
+ if self.logger:
629
+ self.logger.warning(f"无法解析日期 '{date_value}': {e}")
630
+ continue
631
+
632
+ # 从 datetime 对象提取年月
633
+ if dt_obj:
634
+ year = dt_obj.year
635
+ month = dt_obj.month
636
+
637
+ # 生成分区键(表名后缀)
638
+ if year:
639
+ if mode == 'month' and month:
640
+ # 按年月分表: 2024_01, 2024_12
641
+ partition_key = f"{year}_{month:02d}"
642
+ else:
643
+ # 按年分表: 2024, 2025
644
+ partition_key = str(year)
645
+
646
+ if partition_key not in partitioned:
647
+ partitioned[partition_key] = []
648
+ partitioned[partition_key].append(row)
649
+
650
+ if not partitioned:
651
+ raise DataValidationError("所有数据的日期字段都无法解析")
652
+
653
+ if self.logger:
654
+ partition_keys = sorted(partitioned.keys())
655
+ counts = {k: len(partitioned[k]) for k in partition_keys}
656
+ mode_text = "按年月" if mode == 'month' else "按年份"
657
+ self.logger.info(f"数据{mode_text}分组: {counts} (使用字段: {target_date_field})")
658
+
659
+ return partitioned
660
+
661
+ def _ensure_table_exists(
662
+ self,
663
+ table: str,
664
+ sample_data: Dict[str, Any],
665
+ unique_key: Optional[Union[str, List[str]]] = None,
666
+ field_types: Optional[Dict[str, str]] = None,
667
+ allow_null: bool = False
668
+ ):
669
+ """
670
+ 确保表存在,如果不存在则创建
671
+
672
+ 参数:
673
+ table: 表名(支持 "库名.表名" 格式)
674
+ sample_data: 样本数据(用于推断字段类型)
675
+ unique_key: 唯一键(自动创建唯一索引)
676
+ - 单字段: 'product_id'
677
+ - 组合字段: ['shop_id', 'product_id']
678
+ field_types: 手动指定字段类型(可选),格式: {'字段名': 'MySQL类型'}
679
+ 例如: {'price': 'DECIMAL(10,2)', 'status': 'ENUM("active","inactive")'}
680
+ 指定的类型优先级高于自动推断
681
+ allow_null: 是否允许字段为NULL(默认False)
682
+ """
683
+ if not self.auto_create:
684
+ return
685
+
686
+ # 解析表名(支持 "库名.表名" 格式)并清理名称
687
+ if '.' in table:
688
+ db_part, table_part = table.split('.', 1)
689
+ target_db = sanitize_name(db_part, 'database')
690
+ target_table = sanitize_name(table_part, 'table')
691
+ else:
692
+ target_db = sanitize_name(self.database, 'database') if self.database else None
693
+ target_table = sanitize_name(table, 'table')
694
+
695
+ # 确保目标数据库存在
696
+ if target_db:
697
+ self._ensure_database_exists(target_db)
698
+
699
+ try:
700
+ with self._get_connection() as conn:
701
+ with conn.cursor() as cursor:
702
+ # 检查表是否存在
703
+ cursor.execute(
704
+ "SELECT 1 FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s",
705
+ (target_db, target_table)
706
+ )
707
+
708
+ if cursor.fetchone():
709
+ if self.logger:
710
+ self.logger.debug(f"表 {table} 已存在")
711
+ return
712
+
713
+ # 收集唯一键字段(用于判断是否需要限制长度)
714
+ unique_field_names = set()
715
+ if unique_key:
716
+ if isinstance(unique_key, str):
717
+ unique_field_names.add(sanitize_name(unique_key, 'field'))
718
+ elif isinstance(unique_key, list):
719
+ for field in unique_key:
720
+ unique_field_names.add(sanitize_name(field, 'field'))
721
+
722
+ # 创建表字段
723
+ fields = []
724
+ for key, value in sample_data.items():
725
+ field_name = sanitize_name(key, 'field')
726
+
727
+ # 优先使用用户指定的类型,否则自动推断
728
+ if field_types and key in field_types:
729
+ field_type = field_types[key]
730
+ else:
731
+ # 如果字段用于唯一约束,限制VARCHAR长度为191
732
+ is_for_index = field_name in unique_field_names
733
+ field_type = infer_mysql_type(value, for_index=is_for_index)
734
+
735
+ # 添加 NULL/NOT NULL 约束
736
+ null_constraint = "NULL" if allow_null else "NOT NULL"
737
+ fields.append(f"`{field_name}` {field_type} {null_constraint}")
738
+
739
+ # 构建CREATE TABLE语句
740
+ create_parts = []
741
+
742
+ # 1. 自增ID(可选)
743
+ if self.auto_add_id:
744
+ create_parts.append("`id` BIGINT AUTO_INCREMENT PRIMARY KEY")
745
+
746
+ # 2. 数据字段
747
+ create_parts.extend(fields)
748
+
749
+ # 3. 时间戳字段(可选)
750
+ if self.auto_add_timestamps:
751
+ create_parts.append("`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP")
752
+ create_parts.append("`updated_at` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
753
+
754
+ # 4. 唯一索引(从 unique_key 自动创建)
755
+ if unique_key:
756
+ if isinstance(unique_key, str):
757
+ # 单字段唯一索引
758
+ field_name = sanitize_name(unique_key, 'field')
759
+ create_parts.append(f"UNIQUE KEY `uk_{field_name}` (`{field_name}`)")
760
+ elif isinstance(unique_key, list):
761
+ # 多字段组合唯一索引
762
+ field_names = [sanitize_name(f, 'field') for f in unique_key]
763
+ key_name = f"uk_{'_'.join(field_names)}"
764
+ fields_str = ', '.join(f'`{f}`' for f in field_names)
765
+ create_parts.append(f"UNIQUE KEY `{key_name}` ({fields_str})")
766
+
767
+ # 构建完整的表名(支持跨库创建)
768
+ safe_db = target_db.replace('`', '``')
769
+ safe_table = target_table.replace('`', '``')
770
+ full_table_name = f"`{safe_db}`.`{safe_table}`"
771
+
772
+ create_sql = f"""
773
+ CREATE TABLE {full_table_name} (
774
+ {', '.join(create_parts)}
775
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci
776
+ """
777
+
778
+ cursor.execute(create_sql)
779
+ conn.commit()
780
+
781
+ if self.logger:
782
+ unique_key_info = "是" if unique_key else "否"
783
+ self.logger.info(
784
+ f"表 {table} 创建成功 | "
785
+ f"字段数: {len(fields)} | "
786
+ f"自增ID: {self.auto_add_id} | "
787
+ f"时间戳: {self.auto_add_timestamps} | "
788
+ f"唯一索引: {unique_key_info}"
789
+ )
790
+
791
+ except Exception as e:
792
+ if self.logger:
793
+ self.logger.error(f"创建表失败: {str(e)}")
794
+ raise TableCreationError(f"创建表失败: {str(e)}") from e
795
+
796
+ def insert_one(
797
+ self,
798
+ table: str,
799
+ data: Dict[str, Any],
800
+ unique_key: Optional[Union[str, List[str]]] = None,
801
+ on_duplicate: str = 'ignore',
802
+ field_types: Optional[Dict[str, str]] = None,
803
+ allow_null: bool = False
804
+ ) -> bool:
805
+ """
806
+ 插入单条数据
807
+
808
+ 参数:
809
+ table: 表名
810
+ data: 数据字典
811
+ unique_key: 唯一键(用于去重)
812
+ on_duplicate: 重复时的操作 'ignore'/'update'
813
+ field_types: 手动指定字段类型(可选)
814
+ allow_null: 是否允许字段为NULL(默认False)
815
+
816
+ 返回:
817
+ 是否成功
818
+ """
819
+ return self.insert_many(
820
+ table,
821
+ [data],
822
+ unique_key=unique_key,
823
+ on_duplicate=on_duplicate,
824
+ field_types=field_types,
825
+ allow_null=allow_null
826
+ ) > 0
827
+
828
+ def insert_many(
829
+ self,
830
+ table: str,
831
+ data_list: List[Dict[str, Any]],
832
+ unique_key: Optional[Union[str, List[str]]] = None,
833
+ on_duplicate: str = 'ignore',
834
+ batch_size: int = 1000,
835
+ field_types: Optional[Dict[str, str]] = None,
836
+ allow_null: bool = False,
837
+ auto_partition_by_year: Union[bool, str] = False,
838
+ partition_date_field: Optional[str] = None
839
+ ) -> int:
840
+ """
841
+ 批量插入数据
842
+
843
+ 参数:
844
+ table: 表名(支持 "库名.表名" 格式,自动创建库和表)
845
+ data_list: 数据列表
846
+ unique_key: 唯一键(自动创建唯一索引 + 用于去重)
847
+ - 单字段: 'product_id'
848
+ - 组合字段: ['shop_id', 'product_id']
849
+ - None: 不创建唯一索引,允许重复数据
850
+ on_duplicate: 遇到重复数据时的操作(需要 unique_key)
851
+ - 'ignore': 忽略重复数据
852
+ - 'update': 更新重复数据
853
+ batch_size: 批次大小(默认1000)
854
+ field_types: 手动指定字段类型(可选),格式: {'字段名': 'MySQL类型'}
855
+ 例如: {'price': 'DECIMAL(10,2)', 'status': 'ENUM("active","inactive")'}
856
+ 未指定的字段会自动推断类型
857
+ allow_null: 是否允许字段为NULL(默认False)
858
+ - False: 字段设置为 NOT NULL(推荐)
859
+ - True: 字段设置为 NULL
860
+ 注意:自增ID和时间戳字段不受此参数影响
861
+ auto_partition_by_year: 自动分表模式(默认False)
862
+ - False: 不分表,所有数据插入同一张表(默认)
863
+ - True 或 'year': 按年分表,表名格式: table_2024, table_2025
864
+ - 'month': 按年月分表,表名格式: table_2024_01, table_2024_12
865
+ partition_date_field: 指定用于分表的日期字段名(可选)
866
+ - None: 自动识别(优先"日期",其次其他日期类型字段)
867
+ - '字段名': 使用指定字段进行分表
868
+
869
+ 返回:
870
+ 成功插入的行数
871
+
872
+ 数据类型说明:
873
+ - 自动根据第一条数据推断字段类型
874
+ - 唯一键字段自动限制为VARCHAR(191),避免索引长度超限
875
+ - 普通字符串字段:VARCHAR(动态长度) 或 TEXT
876
+ - 数值/时间等类型:自动精确匹配
877
+
878
+ 唯一键注意事项:
879
+ - 唯一键字段必须是VARCHAR/数值/时间类型
880
+ - TEXT/JSON/BLOB类型不能作为唯一键
881
+ - 组合唯一键的所有字段总长度不能超过767字节(utf8mb4)
882
+
883
+ 示例:
884
+ # 基础插入(自动推断类型)
885
+ data = [{'url': 'http://example.com', 'title': '标题', 'price': 99.9}]
886
+ writer.insert_many('products', data)
887
+
888
+ # 跨库插入(自动创建db2库和products表)
889
+ writer.insert_many('db2.products', data)
890
+
891
+ # 唯一键去重(基于url,自动创建唯一索引)
892
+ writer.insert_many('products', data, unique_key='url')
893
+
894
+ # UPSERT(基于url,重复时更新)
895
+ writer.insert_many('products', data, unique_key='url', on_duplicate='update')
896
+
897
+ # 组合唯一键(多字段组合)
898
+ writer.insert_many(
899
+ 'products',
900
+ data,
901
+ unique_key=['shop_id', 'product_id'] # 自动创建组合唯一索引
902
+ )
903
+
904
+ # 手动指定字段类型(精确控制)
905
+ writer.insert_many(
906
+ 'products',
907
+ data,
908
+ field_types={
909
+ 'price': 'DECIMAL(10,2)', # 精确的价格类型
910
+ 'status': 'ENUM("active","inactive")', # 枚举类型
911
+ 'description': 'TEXT' # 文本类型
912
+ }
913
+ # 其他字段仍然自动推断
914
+ )
915
+
916
+ # 允许NULL值(默认不允许)
917
+ writer.insert_many('products', data, allow_null=True)
918
+
919
+ # 按年分表(自动识别日期字段)
920
+ data_with_date = [
921
+ {'日期': '2024-01-15', 'sales': 1000},
922
+ {'日期': '2024-06-20', 'sales': 2000},
923
+ {'日期': '2025-03-10', 'sales': 1500}
924
+ ]
925
+ # 会自动创建 sales_2024 和 sales_2025 两张表
926
+ writer.insert_many(
927
+ 'sales',
928
+ data_with_date,
929
+ auto_partition_by_year=True # 或 'year'
930
+ )
931
+
932
+ # 按年月分表
933
+ # 会自动创建 logs_2024_01, logs_2024_06, logs_2025_03 等表
934
+ writer.insert_many(
935
+ 'logs',
936
+ data_with_date,
937
+ auto_partition_by_year='month' # 按年月分表
938
+ )
939
+
940
+ # 指定分表字段
941
+ writer.insert_many(
942
+ 'orders',
943
+ data,
944
+ auto_partition_by_year='year',
945
+ partition_date_field='order_date' # 使用 order_date 字段分表
946
+ )
947
+ """
948
+ if not data_list:
949
+ if self.logger:
950
+ self.logger.warning("数据列表为空")
951
+ return 0
952
+
953
+ start_time = time.time()
954
+ total_inserted = 0
955
+
956
+ try:
957
+ # 如果启用自动分表,按时间分组数据
958
+ if auto_partition_by_year:
959
+ # 确定分表模式
960
+ partition_mode = auto_partition_by_year
961
+ if partition_mode is True:
962
+ partition_mode = 'year' # True 等同于 'year'
963
+
964
+ # 按时间分组数据
965
+ partitioned_data = self._partition_data_by_period(
966
+ data_list,
967
+ partition_date_field,
968
+ partition_mode
969
+ )
970
+
971
+ # 为每个分区的表插入数据
972
+ for period_suffix, period_data in partitioned_data.items():
973
+ period_table = f"{table}_{period_suffix}"
974
+
975
+ # 确保分区表存在
976
+ self._ensure_table_exists(period_table, period_data[0], unique_key, field_types, allow_null)
977
+
978
+ # 分批处理分区数据
979
+ for i in range(0, len(period_data), batch_size):
980
+ batch = period_data[i:i + batch_size]
981
+ inserted = self._insert_batch(period_table, batch, unique_key, on_duplicate)
982
+ total_inserted += inserted
983
+ else:
984
+ # 不分表,正常插入
985
+ # 确保表存在(传入唯一键、字段类型和allow_null)
986
+ self._ensure_table_exists(table, data_list[0], unique_key, field_types, allow_null)
987
+
988
+ # 分批处理
989
+ for i in range(0, len(data_list), batch_size):
990
+ batch = data_list[i:i + batch_size]
991
+ inserted = self._insert_batch(table, batch, unique_key, on_duplicate)
992
+ total_inserted += inserted
993
+
994
+ elapsed = time.time() - start_time
995
+ self._stats['total_inserted'] += total_inserted
996
+ self._stats['total_time'] += elapsed
997
+
998
+ if self.logger:
999
+ if on_duplicate == 'update' and unique_key:
1000
+ self.logger.info(
1001
+ f"批量处理完成: {total_inserted}条数据(含新增/更新), 耗时{elapsed:.2f}秒"
1002
+ )
1003
+ else:
1004
+ self.logger.info(
1005
+ f"批量插入完成: {total_inserted}/{len(data_list)}行, 耗时{elapsed:.2f}秒"
1006
+ )
1007
+
1008
+ return total_inserted
1009
+
1010
+ except Exception as e:
1011
+ self._stats['total_failed'] += len(data_list) - total_inserted
1012
+ if self.logger:
1013
+ self.logger.error(f"批量插入失败: {str(e)}")
1014
+ raise InsertError(f"批量插入失败: {str(e)}") from e
1015
+
1016
+ def _insert_batch(
1017
+ self,
1018
+ table: str,
1019
+ batch: List[Dict[str, Any]],
1020
+ unique_key: Optional[Union[str, List[str]]],
1021
+ on_duplicate: str
1022
+ ) -> int:
1023
+ """插入单个批次"""
1024
+ if not batch:
1025
+ return 0
1026
+
1027
+ # 统一字段
1028
+ all_fields = set()
1029
+ for item in batch:
1030
+ all_fields.update(item.keys())
1031
+
1032
+ fields = sorted(all_fields)
1033
+ field_names = [sanitize_name(f, 'field') for f in fields]
1034
+
1035
+ # 构建完整的表名(支持 "库名.表名" 格式)并清理名称
1036
+ if '.' in table:
1037
+ db_name, table_name = table.split('.', 1)
1038
+ safe_db = sanitize_name(db_name, 'database')
1039
+ safe_table = sanitize_name(table_name, 'table')
1040
+ full_table_name = f"`{safe_db}`.`{safe_table}`"
1041
+ else:
1042
+ safe_table = sanitize_name(table, 'table')
1043
+ full_table_name = f"`{safe_table}`"
1044
+
1045
+ # 构建SQL
1046
+ placeholders = ', '.join(['%s'] * len(fields))
1047
+ sql = f"INSERT INTO {full_table_name} ({', '.join(f'`{f}`' for f in field_names)}) VALUES ({placeholders})"
1048
+
1049
+ # 处理重复键
1050
+ if unique_key:
1051
+ if on_duplicate == 'ignore':
1052
+ sql = sql.replace('INSERT', 'INSERT IGNORE')
1053
+ elif on_duplicate == 'update':
1054
+ update_fields = [f for f in field_names if f not in (
1055
+ [unique_key] if isinstance(unique_key, str) else unique_key
1056
+ )]
1057
+ if update_fields:
1058
+ updates = ', '.join(f"`{f}`=VALUES(`{f}`)" for f in update_fields)
1059
+ sql += f" ON DUPLICATE KEY UPDATE {updates}"
1060
+
1061
+ # 准备数据
1062
+ values = []
1063
+ for item in batch:
1064
+ row = []
1065
+ for field in fields:
1066
+ value = item.get(field)
1067
+ # 处理特殊类型
1068
+ if isinstance(value, (dict, list)):
1069
+ import json
1070
+ value = json.dumps(value, ensure_ascii=False)
1071
+ elif isinstance(value, datetime):
1072
+ value = value.strftime('%Y-%m-%d %H:%M:%S')
1073
+ elif isinstance(value, date):
1074
+ value = value.strftime('%Y-%m-%d')
1075
+ row.append(value)
1076
+ values.append(tuple(row))
1077
+
1078
+ # 执行插入
1079
+ try:
1080
+ with self._get_connection() as conn:
1081
+ with conn.cursor() as cursor:
1082
+ cursor.executemany(sql, values)
1083
+ affected = cursor.rowcount
1084
+ conn.commit()
1085
+
1086
+ # 修正计数:MySQL的ON DUPLICATE KEY UPDATE
1087
+ # - 新插入: rowcount = 1
1088
+ # - 更新: rowcount = 2 (删除旧行+插入新行)
1089
+ # 为了准确反映实际影响的行数,将更新的2算作1
1090
+ if on_duplicate == 'update' and unique_key:
1091
+ # 计算实际影响的行数(将UPDATE的2折半)
1092
+ # affected = 新插入数 + 更新数*2
1093
+ # 实际行数 = 新插入数 + 更新数 = (affected + 更新数) / 2
1094
+ # 简化:实际行数约等于数据条数
1095
+ return len(values)
1096
+
1097
+ return affected
1098
+ except Exception as e:
1099
+ if self.logger:
1100
+ self.logger.error(f"批次插入失败: {str(e)}")
1101
+ self.logger.debug(f"SQL: {sql}")
1102
+ raise
1103
+
1104
+ def upsert(
1105
+ self,
1106
+ table: str,
1107
+ data_list: List[Dict[str, Any]],
1108
+ unique_key: Union[str, List[str]],
1109
+ batch_size: int = 1000,
1110
+ field_types: Optional[Dict[str, str]] = None
1111
+ ) -> int:
1112
+ """
1113
+ UPSERT操作(插入或更新)
1114
+
1115
+ 参数:
1116
+ table: 表名
1117
+ data_list: 数据列表
1118
+ unique_key: 唯一键
1119
+ batch_size: 批次大小
1120
+ field_types: 手动指定字段类型(可选)
1121
+
1122
+ 返回:
1123
+ 影响的行数
1124
+ """
1125
+ return self.insert_many(
1126
+ table,
1127
+ data_list,
1128
+ unique_key=unique_key,
1129
+ on_duplicate='update',
1130
+ batch_size=batch_size,
1131
+ field_types=field_types
1132
+ )
1133
+
1134
+ def insert_dataframe(
1135
+ self,
1136
+ table: str,
1137
+ df,
1138
+ unique_key: Optional[Union[str, List[str]]] = None,
1139
+ on_duplicate: str = 'ignore',
1140
+ batch_size: int = 1000,
1141
+ field_types: Optional[Dict[str, str]] = None
1142
+ ) -> int:
1143
+ """
1144
+ 从DataFrame插入数据
1145
+
1146
+ 参数:
1147
+ table: 表名
1148
+ df: pandas DataFrame
1149
+ unique_key: 唯一键
1150
+ on_duplicate: 重复处理方式
1151
+ batch_size: 批次大小
1152
+ field_types: 手动指定字段类型(可选)
1153
+
1154
+ 返回:
1155
+ 插入的行数
1156
+ """
1157
+ try:
1158
+ import pandas as pd
1159
+
1160
+ if not isinstance(df, pd.DataFrame):
1161
+ raise DataValidationError("输入必须是pandas DataFrame")
1162
+
1163
+ # 转换为字典列表
1164
+ data_list = df.to_dict('records')
1165
+
1166
+ return self.insert_many(
1167
+ table,
1168
+ data_list,
1169
+ unique_key=unique_key,
1170
+ on_duplicate=on_duplicate,
1171
+ batch_size=batch_size,
1172
+ field_types=field_types
1173
+ )
1174
+
1175
+ except ImportError:
1176
+ raise ImportError("需要安装pandas: pip install pandas")
1177
+
1178
+ def create_index(
1179
+ self,
1180
+ table: str,
1181
+ fields: Union[str, List[str]],
1182
+ unique: bool = False,
1183
+ index_name: Optional[str] = None
1184
+ ) -> bool:
1185
+ """
1186
+ 创建索引(支持单字段和组合索引)
1187
+
1188
+ 参数:
1189
+ table: 表名
1190
+ fields: 字段名或字段列表(支持组合索引)
1191
+ unique: 是否唯一索引
1192
+ index_name: 索引名称(可选,不指定则自动生成)
1193
+
1194
+ 返回:
1195
+ 是否成功创建(如果已存在返回False)
1196
+
1197
+ 示例:
1198
+ # 单字段索引
1199
+ writer.create_index('products', 'category')
1200
+
1201
+ # 唯一索引
1202
+ writer.create_index('products', 'url', unique=True)
1203
+
1204
+ # 组合索引
1205
+ writer.create_index('products', ['shop_id', 'product_id'])
1206
+
1207
+ # 组合唯一索引
1208
+ writer.create_index('products', ['shop_id', 'product_id'], unique=True)
1209
+ """
1210
+ if isinstance(fields, str):
1211
+ fields = [fields]
1212
+
1213
+ field_names = [sanitize_name(f, 'field') for f in fields]
1214
+
1215
+ # 检查是否已存在相同字段组合的索引
1216
+ existing_indexes = self._get_table_indexes(table)
1217
+ fields_set = set(field_names)
1218
+
1219
+ for idx_info in existing_indexes:
1220
+ idx_fields_set = set(idx_info['fields'])
1221
+ if idx_fields_set == fields_set:
1222
+ # 已存在相同字段组合的索引
1223
+ if idx_info['is_unique']:
1224
+ if unique:
1225
+ if self.logger:
1226
+ self.logger.debug(f"唯一索引已存在: {idx_info['name']} on {table}({', '.join(field_names)})")
1227
+ return False
1228
+ else:
1229
+ if self.logger:
1230
+ self.logger.warning(
1231
+ f"跳过创建普通索引:字段 {', '.join(field_names)} 已有唯一索引 {idx_info['name']},"
1232
+ f"唯一索引包含普通索引的全部功能"
1233
+ )
1234
+ return False
1235
+ else:
1236
+ # 已存在普通索引
1237
+ if unique:
1238
+ # 想创建唯一索引,但已有普通索引(需要先删除普通索引)
1239
+ if self.logger:
1240
+ self.logger.warning(f"字段 {', '.join(field_names)} 已有普通索引 {idx_info['name']}")
1241
+ return False
1242
+ else:
1243
+ # 普通索引已存在
1244
+ if self.logger:
1245
+ self.logger.debug(f"普通索引已存在: {idx_info['name']}")
1246
+ return False
1247
+
1248
+ if not index_name:
1249
+ prefix = 'uk' if unique else 'idx'
1250
+ index_name = f"{prefix}_{'_'.join(field_names)}"
1251
+
1252
+ index_type = 'UNIQUE' if unique else ''
1253
+ fields_str = ', '.join(f'`{f}`' for f in field_names)
1254
+
1255
+ sql = f"CREATE {index_type} INDEX `{index_name}` ON `{table}` ({fields_str})"
1256
+
1257
+ try:
1258
+ with self._get_connection() as conn:
1259
+ with conn.cursor() as cursor:
1260
+ cursor.execute(sql)
1261
+ conn.commit()
1262
+
1263
+ if self.logger:
1264
+ index_type_str = "组合唯一索引" if unique and len(fields) > 1 else \
1265
+ "唯一索引" if unique else \
1266
+ "组合索引" if len(fields) > 1 else "索引"
1267
+ self.logger.info(f"{index_type_str}创建成功: {index_name} on {table}({fields_str})")
1268
+ return True
1269
+ except pymysql.err.OperationalError as e:
1270
+ if 'Duplicate key name' in str(e):
1271
+ if self.logger:
1272
+ self.logger.debug(f"索引名 {index_name} 已存在")
1273
+ return False
1274
+ else:
1275
+ raise
1276
+
1277
+ def create_indexes(
1278
+ self,
1279
+ table: str,
1280
+ indexes: List[Dict[str, Any]]
1281
+ ) -> int:
1282
+ """
1283
+ 批量创建索引
1284
+
1285
+ 参数:
1286
+ table: 表名
1287
+ indexes: 索引配置列表,每个元素是字典,包含:
1288
+ - fields: 字段名或字段列表
1289
+ - unique: 是否唯一索引(可选,默认False)
1290
+ - name: 索引名称(可选)
1291
+
1292
+ 返回:
1293
+ 成功创建的索引数量
1294
+
1295
+ 示例:
1296
+ writer.create_indexes('products', [
1297
+ {'fields': 'url', 'unique': True},
1298
+ {'fields': 'category'},
1299
+ {'fields': ['shop_id', 'product_id'], 'unique': True},
1300
+ {'fields': ['created_at', 'status']}
1301
+ ])
1302
+ """
1303
+ count = 0
1304
+ for idx_config in indexes:
1305
+ fields = idx_config.get('fields')
1306
+ unique = idx_config.get('unique', False)
1307
+ name = idx_config.get('name')
1308
+
1309
+ if fields:
1310
+ if self.create_index(table, fields, unique, name):
1311
+ count += 1
1312
+
1313
+ if self.logger:
1314
+ self.logger.info(f"批量创建索引完成: {count}/{len(indexes)} 个索引创建成功")
1315
+
1316
+ return count
1317
+
1318
+ def get_stats(self) -> Dict[str, Any]:
1319
+ """
1320
+ 获取统计信息
1321
+
1322
+ 返回:
1323
+ 统计信息字典
1324
+ """
1325
+ stats = self._stats.copy()
1326
+ if stats['total_inserted'] > 0:
1327
+ stats['avg_speed'] = stats['total_inserted'] / stats['total_time'] if stats['total_time'] > 0 else 0
1328
+ else:
1329
+ stats['avg_speed'] = 0
1330
+ return stats
1331
+
1332
+ def reset_stats(self):
1333
+ """重置统计信息"""
1334
+ self._stats = {
1335
+ 'total_inserted': 0,
1336
+ 'total_updated': 0,
1337
+ 'total_failed': 0,
1338
+ 'total_time': 0.0
1339
+ }
1340
+
1341
+ def __enter__(self):
1342
+ """上下文管理器入口"""
1343
+ return self
1344
+
1345
+ def __exit__(self, exc_type, exc_val, exc_tb):
1346
+ """上下文管理器出口"""
1347
+ self.close()
1348
+ return False
1349
+
1350
+ def close(self):
1351
+ """关闭连接池"""
1352
+ if self._closed:
1353
+ return
1354
+
1355
+ try:
1356
+ if hasattr(self, 'pool'):
1357
+ self.pool.close()
1358
+ if self.logger:
1359
+ self.logger.debug("连接池已关闭")
1360
+ # 输出最终统计
1361
+ stats = self.get_stats()
1362
+ self.logger.info(f"总计处理: {stats['total_inserted']}条数据")
1363
+ except Exception as e:
1364
+ if self.logger:
1365
+ self.logger.error(f"关闭连接失败: {str(e)}")
1366
+ finally:
1367
+ self._closed = True
1368
+
1369
+
1370
+ # ==================== 测试代码 ====================
1371
+
1372
+
1373
+ def test():
1374
+ """演示功能"""
1375
+
1376
+ try:
1377
+ with MYSQLWriter(
1378
+ host='localhost',
1379
+ user='user',
1380
+ password='password',
1381
+ auto_create=True,
1382
+ auto_add_id=True,
1383
+ auto_add_timestamps=True,
1384
+ log_config={'enable': True, 'level': 'INFO', 'output': 'both', 'file_path': 'mysql_writer.log'}
1385
+ ) as writer:
1386
+ custom_type_data = [
1387
+ {
1388
+ 'product_id': 123456,
1389
+ 'product_name': '商品名称1',
1390
+ 'price': 99.99,
1391
+ 'stock': 100,
1392
+ 'status': 'active',
1393
+ 'weight': 1.5,
1394
+ 'rating': 4.4
1395
+ },
1396
+ {
1397
+ 'product_id': 22456,
1398
+ 'product_name': '商品名称2',
1399
+ 'price': 22,
1400
+ 'stock': 2,
1401
+ 'status': 'active',
1402
+ 'weight': 0.2,
1403
+ 'rating': 2.22
1404
+ },
1405
+ {
1406
+ 'product_id': 123456,
1407
+ 'product_name': '商品名称1',
1408
+ 'price': 3,
1409
+ 'stock': 0.33,
1410
+ 'status': 'active',
1411
+ 'weight': 0.333,
1412
+ 'rating': 3.33
1413
+ }
1414
+ ]
1415
+ field_types={
1416
+ 'product_id': 'INT UNSIGNED',
1417
+ 'price': 'DECIMAL(10,2)',
1418
+ 'stock': 'INT UNSIGNED',
1419
+ 'status': 'ENUM("active","inactive","sold")',
1420
+ 'weight': 'FLOAT(5,2)',
1421
+ 'rating': 'DECIMAL(3,1)'
1422
+ }
1423
+
1424
+ count = writer.insert_many(
1425
+ table='test_db.custom_products',
1426
+ data_list=custom_type_data,
1427
+ unique_key=['product_id', 'product_name'], # 自动创建唯一索引
1428
+ on_duplicate='update',
1429
+ field_types=field_types
1430
+ )
1431
+ print(f"✓ 处理成功: {count}条数据(自动创建唯一索引 + UPSERT)")
1432
+
1433
+ except Exception as e:
1434
+ print(f"\n✗ 测试失败: {e}")
1435
+
1436
+
1437
+ if __name__ == '__main__':
1438
+ test()
1439
+