pgsqldatatool 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pgsqldatatool/__init__.py +8 -0
- pgsqldatatool/data_clean.py +230 -0
- pgsqldatatool/pgsql_connection_async.py +76 -0
- pgsqldatatool/tools.py +24 -0
- pgsqldatatool/until_async.py +678 -0
- pgsqldatatool-1.0.0.dist-info/METADATA +204 -0
- pgsqldatatool-1.0.0.dist-info/RECORD +9 -0
- pgsqldatatool-1.0.0.dist-info/WHEEL +4 -0
- pgsqldatatool-1.0.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
|
|
2
|
+
import re
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from zoneinfo import ZoneInfo
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# 基础清洗 (去掉列名两端空白字符、去掉数据两段空白字符、删除空行)
|
|
13
|
+
def lf_basic_clean(df:pl.LazyFrame|pl.DataFrame)->pl.LazyFrame:
    """
    Apply the basic cleaning steps to a dataset.

    1. Strip leading/trailing whitespace (spaces, newlines, carriage
       returns, ...) from every column name.
    2. Strip whitespace and "-" characters from both ends of every
       string value.
    3. Drop rows in which every column is null.

    Args:
        df (pl.LazyFrame | pl.DataFrame): Dataset to clean. A DataFrame is
            converted to a LazyFrame first.
    Returns:
        pl.LazyFrame: The cleaned dataset (still lazy).
    """
    if not isinstance(df, pl.LazyFrame):
        df = df.lazy()

    # --- column names -------------------------------------------------
    # Resolve the schema once and reuse it for both sides of the mapping.
    pattern = r'^[\s\n\r]+|[\s\n\r]+$'
    old_names = df.collect_schema().names()
    new_names = (
        # Explicit dtype keeps `.str` usable even when there are no columns.
        pl.Series(old_names, dtype=pl.String)
        .str.replace_all(pattern, "")
        .to_list()
    )
    df = df.rename(dict(zip(old_names, new_names)))

    # --- string values ------------------------------------------------
    # Trim only at the two ends. The previous `.str.replace("-", "")`
    # removed the first "-" found anywhere in the value, which corrupted
    # data such as "2024-01-02". The final whitespace pass removes blanks
    # exposed by stripping the dashes (e.g. "- abc -").
    df = df.with_columns(
        pl.col(pl.String)
        .str.strip_chars()
        .str.strip_chars("-")
        .str.strip_chars()
    )

    # --- empty rows ---------------------------------------------------
    # Keep a row as soon as at least one of its columns is non-null.
    df = df.filter(
        pl.any_horizontal(pl.all().is_not_null())
    )

    return df
|
|
49
|
+
|
|
50
|
+
# 根据字典重命名列
|
|
51
|
+
def lf_rename_cols(df:pl.LazyFrame|pl.DataFrame, rename_cols: dict)->pl.LazyFrame:
    """
    Rename columns according to a mapping.

    Mapping entries whose key is not a column of the dataset are ignored,
    so one mapping can be shared between differently shaped inputs.

    Args:
        df (pl.LazyFrame | pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
        rename_cols (dict): Mapping of old column name -> new column name.
            ``None`` leaves the dataset unchanged.
    Returns:
        pl.LazyFrame: The dataset with the applicable renames applied.
    """
    if not isinstance(df, pl.LazyFrame):
        df = df.lazy()

    if rename_cols is not None:
        # Resolve the schema explicitly (as the other helpers in this file
        # do) rather than through `LazyFrame.columns`, and build the lookup
        # set once instead of scanning a list for every key.
        existing = set(df.collect_schema().names())
        df = df.rename({
            old: new for old, new in rename_cols.items() if old in existing
        })

    return df
|
|
70
|
+
|
|
71
|
+
# 移除重复行
|
|
72
|
+
def lf_remove_dup_rows(df:pl.LazyFrame|pl.DataFrame)->pl.LazyFrame:
    """
    Drop fully duplicated rows, keeping the first occurrence of each.

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.

    Returns:
        pl.LazyFrame: The de-duplicated dataset.
    """
    lazy = df.lazy() if isinstance(df, pl.DataFrame) else df
    # subset=None compares all columns; keep="first" retains the first row
    # of every group of identical rows.
    return lazy.unique(subset=None, keep="first")
|
|
89
|
+
|
|
90
|
+
# 删除指定列
|
|
91
|
+
def lf_remove_cols(df, cols:list[str] | None = None)->pl.LazyFrame:
    """
    Drop the given columns from a dataset.

    Names in ``cols`` that are not columns of the dataset are ignored.

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
        cols (list[str] | None): Names of the columns to drop. ``None``
            leaves the dataset unchanged.
    Returns:
        pl.LazyFrame: The dataset without the requested columns.
    """
    if isinstance(df, pl.DataFrame):
        df = df.lazy()

    if cols is not None:
        # Resolve the schema explicitly (as the other helpers in this file
        # do) rather than through `LazyFrame.columns`.
        existing = set(df.collect_schema().names())
        df = df.drop([c for c in cols if c in existing])

    return df
|
|
108
|
+
|
|
109
|
+
# 移除数字千分号
|
|
110
|
+
def lf_remove_per_mille(df, cols: str | list[str] | None = None):
    """
    Remove thousands separators (",") from numeric strings.

    Only values that look like a grouped number (e.g. "1,234,567.89") are
    rewritten; every other value is left untouched. The result is still a
    string column.

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
        cols ("all" | list[str] | None): Columns to process. ``"all"``
            selects every string column, a list selects the named columns,
            and ``None`` (the default) or an empty value leaves the dataset
            unchanged.

    Returns:
        pl.LazyFrame: The processed dataset.

    Raises:
        ValueError: ``cols`` is neither ``"all"``, a list, nor empty.
    """
    # NOTE: the annotation used to be `"all" | list[str] | None`, which is
    # evaluated at definition time and raises TypeError (str does not
    # support `|`), making the whole module un-importable.
    if not isinstance(df, pl.LazyFrame):
        df = df.lazy()

    if not cols:
        return df

    if isinstance(cols, str) and cols == "all":
        target = pl.col(pl.String)
    elif isinstance(cols, list):
        target = pl.col(cols)
    else:
        raise ValueError("per_mille_cols参数错误")

    # The original built the expression but never returned the frame, so
    # callers always received None.
    return df.with_columns(
        pl.when(target.str.contains(r'^\d{1,3}(,\d{3})*(\.\d+)?$'))
        .then(target.str.replace_all(",", ""))
        .otherwise(target)
    )
|
|
144
|
+
|
|
145
|
+
# 移除数字百分号
|
|
146
|
+
def lf_remove_percent(df, cols: str | list[str] | None = None):
    """
    Convert percentage strings such as "12.5%" to their decimal value.

    A value is converted only if, after trimming whitespace and removing
    "," separators, it consists of digits (with an optional fractional
    part) followed by "%". Values that do not match keep their original
    content.

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
        cols ("all" | list[str] | None): Columns to process. ``"all"``
            selects every string column, a list selects the named columns,
            and ``None`` (the default) leaves the dataset unchanged.
    Returns:
        pl.LazyFrame: The processed dataset.

    Raises:
        ValueError: ``cols`` is neither ``"all"``, a list, nor ``None``.
    """
    # NOTE: the annotation used to be `"all" | list[str] | None`, which is
    # evaluated at definition time and raises TypeError (str does not
    # support `|`), making the whole module un-importable.
    if not isinstance(df, pl.LazyFrame):
        df = df.lazy()

    # The default must be usable: previously calling the function without
    # `cols` always raised. None is now a no-op, like lf_remove_per_mille.
    if cols is None:
        return df

    if isinstance(cols, str) and cols == "all":
        target = pl.col(pl.String)
    elif isinstance(cols, list):
        target = pl.col(cols)
    else:
        raise ValueError("percent_cols参数错误")

    df = df.with_columns(
        target
        .str.strip_chars()
        .str.replace_all(",", "")
        # Non-matching values yield null here ...
        .str.extract(r"^(\d+\.?\d*)%$")
        .cast(pl.Float64)
        .truediv(100)
        # ... and are restored from the original column.
        .fill_null(target)
    )
    return df
|
|
177
|
+
|
|
178
|
+
# 添加时间列
|
|
179
|
+
def lf_add_time(df, time_zone='Asia/Shanghai') -> pl.LazyFrame:
    """
    Add a constant timestamp column named '数据写入时间' (data write time).

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
        time_zone (str, optional): IANA time zone used for the timestamp.
            Defaults to 'Asia/Shanghai'.
    Returns:
        pl.LazyFrame: The dataset with the extra timestamp column.
    """
    lazy = df if isinstance(df, pl.LazyFrame) else df.lazy()

    # datetime.now() is ordinary Python and runs right here, when this
    # function is called. The resulting value is embedded as a literal, so
    # the column holds the call time, not the time of a later .collect().
    written_at = datetime.now(ZoneInfo(time_zone))
    return lazy.with_columns(pl.lit(written_at).alias('数据写入时间'))
|
|
197
|
+
|
|
198
|
+
# 删除完全为空的列
|
|
199
|
+
def df_drop_empty_cols(df) -> pl.LazyFrame:
    """
    Drop every column that contains only null values.

    This function is NOT fully lazy: it calls ``collect()`` once on a
    per-column aggregation to find out which columns are entirely null.
    The final ``drop`` is lazy again.

    Note that on a dataset with zero rows every column counts as "entirely
    null" and is therefore dropped.

    Args:
        df (pl.LazyFrame|pl.DataFrame): Dataset to process. A DataFrame is
            converted to a LazyFrame first.
    Returns:
        pl.LazyFrame: The dataset without the all-null columns.
    """
    if not isinstance(df, pl.LazyFrame):
        df = df.lazy()

    names = df.collect_schema().names()
    if not names:
        return df

    # One boolean per COLUMN: True when every value of that column is null.
    # (The previous `pl.all_horizontal(...)` was a row-wise reduction that
    # produced a single column, so only the first column name was ever
    # paired with a flag, and that flag described row 0, not the column.)
    all_null_flags = df.select(pl.all().is_null().all()).collect().row(0)

    cols_to_drop = [
        col_name for col_name, is_all_null in zip(names, all_null_flags)
        if is_all_null
    ]

    if cols_to_drop:
        df = df.drop(cols_to_drop)

    return df
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import asyncpg
|
|
5
|
+
import asyncio
|
|
6
|
+
from contextlib import asynccontextmanager
|
|
7
|
+
|
|
8
|
+
class PoolSingleton:
    """
    Holder for a single shared asyncpg connection pool.

    The pool is created on first use from environment variables:
    PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB, PG_MIN_SIZE and
    PG_MAX_SIZE (each with a development default). All members are class
    level; the class is not meant to be instantiated.
    """

    # The shared pool; None until first use and after close().
    _pool = None
    # Serialises pool creation so concurrent first callers build one pool.
    _lock = asyncio.Lock()

    @classmethod
    async def get(cls):
        """Return the shared pool, creating it if missing or closed."""
        # The unlocked check must also look at the closed flag; otherwise a
        # closed pool would be handed out forever and never rebuilt.
        if cls._pool is None or cls._pool._closed:
            async with cls._lock:
                # Re-check under the lock: another task may have created
                # the pool while this one was waiting.
                if cls._pool is None or cls._pool._closed:
                    pg_host = os.environ.get("PG_HOST", "localhost")
                    # Environment values are strings; pass a real integer.
                    pg_port = int(os.environ.get("PG_PORT", "5432"))
                    pg_user = os.environ.get("PG_USER", "postgres")
                    pg_password = os.environ.get("PG_PASSWORD", "password")
                    pg_database = os.environ.get("PG_DB", "testdb")
                    min_size = int(os.environ.get("PG_MIN_SIZE", "5"))
                    max_size = int(os.environ.get("PG_MAX_SIZE", "20"))

                    cls._pool = await asyncpg.create_pool(
                        user=pg_user,
                        password=pg_password,
                        host=pg_host,
                        port=pg_port,
                        database=pg_database,
                        min_size=min_size,
                        max_size=max_size,
                    )
        return cls._pool

    @classmethod
    async def close(cls):
        """Close the shared pool, if any, and forget it."""
        if cls._pool:
            await cls._pool.close()
            cls._pool = None

    @classmethod
    @asynccontextmanager
    async def acquire(cls):
        """Async context manager yielding a connection from the pool."""
        pool = await cls.get()
        async with pool.acquire() as conn:
            yield conn

    @classmethod
    async def fetch(cls, query, *args):
        """Run ``query`` on a pooled connection and return all rows."""
        async with cls.acquire() as conn:
            return await conn.fetch(query, *args)

    @classmethod
    async def fetchrow(cls, query, *args):
        """Run ``query`` on a pooled connection and return the first row."""
        async with cls.acquire() as conn:
            return await conn.fetchrow(query, *args)

    @classmethod
    async def execute(cls, query, *args):
        """Run ``query`` on a pooled connection and return its status."""
        async with cls.acquire() as conn:
            return await conn.execute(query, *args)
|
pgsqldatatool/tools.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import asyncpg
|
|
2
|
+
import polars as pl
|
|
3
|
+
|
|
4
|
+
def records_to_df(records:list[asyncpg.Record]):
    """
    Convert a list of asyncpg.Record objects into a polars.DataFrame.

    The records are transposed into one Python list per column before the
    frame is built. In theory this is more efficient than
    ``pl.DataFrame([dict(r) for r in records])``.

    Args:
        records (list[asyncpg.Record]): Records to convert.
    Returns:
        pl.DataFrame: The resulting frame; an empty frame (no columns)
        when ``records`` is empty.
    """
    if not records:
        return pl.DataFrame()

    # Column names come from the first record.
    names = list(records[0].keys())
    columns = [[] for _ in names]

    # Iterating a record yields its values in column order.
    for record in records:
        for bucket, value in zip(columns, record):
            bucket.append(value)

    return pl.DataFrame(dict(zip(names, columns)))
|
|
@@ -0,0 +1,678 @@
|
|
|
1
|
+
"""异步 PostgreSQL 写入工具(Polars 专用)。"""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import time
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
7
|
+
|
|
8
|
+
import polars as pl
|
|
9
|
+
import asyncpg
|
|
10
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
11
|
+
|
|
12
|
+
from polars.datatypes import (
|
|
13
|
+
Int8, Int16, Int32, Int64,
|
|
14
|
+
UInt8, UInt16, UInt32, UInt64,
|
|
15
|
+
Float32, Float64,
|
|
16
|
+
Boolean,
|
|
17
|
+
String,
|
|
18
|
+
Datetime,
|
|
19
|
+
Duration,
|
|
20
|
+
List,
|
|
21
|
+
Struct,
|
|
22
|
+
Decimal,
|
|
23
|
+
Date,
|
|
24
|
+
Time,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
import logging
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger("pgtool")
|
|
30
|
+
|
|
31
|
+
# =========================================================
|
|
32
|
+
# 公开 API
|
|
33
|
+
# =========================================================
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def write_pg(
    df: pl.DataFrame,
    schema: str,
    table: str,
    conn: Union[asyncpg.Connection, AsyncSession],
    *,
    create: bool = False,
    update: bool = False,
    add_columns: bool = False,
    only_column: bool = False,
    write: Literal["auto", "copy", "execute", "executemany"] = "auto",
    # Builtin `list[str]` on purpose: in this module the name `List` is
    # rebound to the polars dtype by a later import and is not subscriptable.
    primary_key: Optional[Union[str, list[str]]] = None,
) -> int:
    """
    Asynchronously write a Polars DataFrame into a PostgreSQL table.

    Supports:
    - automatic table creation (types mapped from the Polars dtypes)
    - automatic addition of missing columns (same mapping)
    - column alignment
    - primary-key validation
    - UPSERT (ON CONFLICT)

    The whole write runs inside one transaction on the resolved
    connection.

    Args:
        df (pl.DataFrame):
            Data to write. An empty frame is skipped.
        schema (str):
            Database schema name.
        table (str):
            Target table name.
        conn (Union[asyncpg.Connection, AsyncSession]):
            Asynchronous connection object.
        create (bool, optional):
            Create the table when it does not exist. Defaults to False.
        update (bool, optional):
            Update rows that already exist. Defaults to False.
        add_columns (bool, optional):
            Add columns missing from the table. Defaults to False.
        only_column (bool, optional):
            Write only the columns that already exist in the table.
            Defaults to False.
        write (Literal["auto", "copy", "execute", "executemany"], optional):
            Write mode:
            - auto: choose automatically
            - copy: COPY mode (recommended for large data)
            - execute: single-row INSERT (slow)
            - executemany: multi-row INSERT
        primary_key (Optional[Union[str, list[str]]], optional):
            Primary-key column name(s), used when creating the table.
            - None: no primary key
            - str: single-column key (e.g. "id", "uuid")
            - list[str]: composite key (e.g. ["user_id", "order_id"])
            Only used when create=True and the table does not exist; for
            an existing table the key defined in the database is used.

    Returns:
        int:
            Number of affected rows; 0 when ``df`` is empty.

    Raises:
        ValueError:
            The table does not exist and create=False, the columns do not
            match, there is a primary-key problem, or the asyncpg
            connection cannot be resolved from ``conn``.
        TypeError:
            A Polars dtype cannot be mapped to a PostgreSQL type.
    """
    start_time = time.time()

    if df.is_empty():
        logger.warning("DataFrame 为空,跳过写入")
        return 0

    # Keep the original Polars dtypes (used for CREATE TABLE / ADD COLUMN).
    original_schema = df.schema

    actual_conn = await _get_asyncpg_connection(conn)

    async with actual_conn.transaction():
        count = await _write_data(
            df=df,
            original_schema=original_schema,
            schema=schema,
            table=table,
            conn=actual_conn,
            create=create,
            update=update,
            add_columns=add_columns,
            only_column=only_column,
            write=write,
            start_time=start_time,
            primary_key=primary_key,
        )

    logger.info("总耗时 %.2fs", time.time() - start_time)
    return count
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# =========================================================
|
|
136
|
+
# 连接适配
|
|
137
|
+
# =========================================================
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
async def _get_asyncpg_connection(conn: Any) -> asyncpg.Connection:
    """
    Extract the asyncpg connection from a supported connection object.

    Args:
        conn (Any):
            An asyncpg.Connection or a SQLAlchemy AsyncSession.

    Returns:
        asyncpg.Connection:
            ``conn`` itself, or the driver connection underlying the
            session.

    Raises:
        ValueError:
            Unsupported connection type, or the driver connection could
            not be obtained from the AsyncSession.
    """
    if isinstance(conn, asyncpg.Connection):
        return conn

    if isinstance(conn, AsyncSession):
        try:
            # AsyncSession.connection() yields an AsyncConnection, which
            # does not expose `driver_connection` itself; the driver-level
            # connection is reached through the pooled raw connection.
            sa_conn = await conn.connection()
            raw = await sa_conn.get_raw_connection()
            return raw.driver_connection
        except Exception as exc:
            raise ValueError("无法从 AsyncSession 获取 asyncpg 连接") from exc

    raise ValueError(f"不支持的连接类型: {type(conn)}")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# =========================================================
|
|
170
|
+
# Polars dtype → PostgreSQL 类型映射(唯一入口)
|
|
171
|
+
# =========================================================
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def polars_to_pg_type(dtype) -> str:
    """
    Map a Polars DataType to a PostgreSQL column type.

    This is the single type-mapping entry point of the module; both table
    creation and column addition call it.

    Args:
        dtype:
            A Polars DataType class or instance, e.g. pl.Int64, pl.String
            or pl.Datetime("us").

    Returns:
        str:
            The PostgreSQL column type.

    Raises:
        TypeError:
            The dtype cannot be mapped (List / Struct, or unsupported).
    """
    mapping = {
        Int8: "SMALLINT",
        Int16: "SMALLINT",
        Int32: "INTEGER",
        Int64: "BIGINT",

        # Unsigned types go one size up so the full range fits.
        UInt8: "SMALLINT",
        UInt16: "INTEGER",
        UInt32: "BIGINT",
        UInt64: "NUMERIC(20)",

        Float32: "REAL",
        Float64: "DOUBLE PRECISION",

        Boolean: "BOOLEAN",

        String: "TEXT",

        Datetime: "TIMESTAMP",
        Date: "DATE",
        Time: "TIME",

        Duration: "INTERVAL",

        Decimal: "NUMERIC",
    }

    # Look up by the dtype CLASS. Parameterised instances such as
    # Datetime("us"), Duration("ms") or Decimal(10, 2) do not hash like
    # their class, so using the instance as the dict key always missed.
    base = dtype.base_type() if hasattr(dtype, "base_type") else dtype

    # List / Struct need special handling and have no simple column type.
    if base is List or base is Struct:
        raise TypeError(
            f"Polars dtype {dtype} 无法映射为简单的 PG 列类型"
        )

    # A timezone-aware Datetime needs a timezone-aware column.
    if isinstance(dtype, Datetime) and dtype.time_zone:
        return "TIMESTAMPTZ"

    pg_type = mapping.get(base)
    if pg_type is None:
        raise TypeError(
            f"不支持的 Polars dtype: {dtype}"
        )

    return pg_type
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# =========================================================
|
|
236
|
+
# 主键解析
|
|
237
|
+
# =========================================================
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _resolve_primary_key(
|
|
241
|
+
df_columns: List[str],
|
|
242
|
+
primary_key: Optional[Union[str, List[str]]],
|
|
243
|
+
) -> List[str]:
|
|
244
|
+
"""
|
|
245
|
+
解析并校验用户指定的主键列名。
|
|
246
|
+
|
|
247
|
+
Args:
|
|
248
|
+
df_columns: DataFrame 的所有列名。
|
|
249
|
+
primary_key: 用户显式指定的主键,None 表示不设主键。
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
List[str]: 主键列名列表(primary_key 为 None 时返回空列表)。
|
|
253
|
+
|
|
254
|
+
Raises:
|
|
255
|
+
ValueError: 指定的主键列在 DataFrame 中不存在。
|
|
256
|
+
"""
|
|
257
|
+
if primary_key is None:
|
|
258
|
+
return []
|
|
259
|
+
|
|
260
|
+
if isinstance(primary_key, str):
|
|
261
|
+
pk_list = [primary_key]
|
|
262
|
+
else:
|
|
263
|
+
pk_list = list(primary_key)
|
|
264
|
+
|
|
265
|
+
missing = [c for c in pk_list if c not in df_columns]
|
|
266
|
+
if missing:
|
|
267
|
+
raise ValueError(
|
|
268
|
+
f"指定的主键列 {missing} 在 DataFrame 中不存在"
|
|
269
|
+
)
|
|
270
|
+
return pk_list
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
# =========================================================
|
|
274
|
+
# 核心写入逻辑
|
|
275
|
+
# =========================================================
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
async def _write_data(
    df: pl.DataFrame,
    original_schema,
    schema: str,
    table: str,
    conn: asyncpg.Connection,
    create: bool,
    update: bool,
    add_columns: bool,
    only_column: bool,
    write: str,
    start_time: float,
    # Builtin `list[str]` on purpose: in this module the name `List` is
    # rebound to the polars dtype by a later import and is not subscriptable.
    primary_key: Optional[Union[str, list[str]]] = None,
) -> int:
    """
    Internal implementation of the write workflow.

    Steps: read the target table's columns, create the table if allowed,
    align DataFrame columns with table columns, validate the table's
    primary key against the data, cast values for text columns, pick a
    write mode and delegate to the matching writer.

    Args:
        df: Data to write.
        original_schema: Polars schema of ``df`` (column name -> dtype),
            used for CREATE TABLE / ADD COLUMN.
        schema: Database schema name.
        table: Target table name.
        conn: asyncpg connection to run on.
        create: Create the table when it does not exist.
        update: Update existing rows on primary-key conflict.
        add_columns: Add DataFrame columns missing from the table.
        only_column: Write only columns already present in the table
            (takes precedence over ``add_columns``).
        write: "auto", "copy", "executemany" or "execute".
        start_time: Start timestamp supplied by the caller (not used here).
        primary_key: Key column(s) used only when the table is created.

    Returns:
        int: The row count reported by the selected writer.

    Raises:
        ValueError: Missing table with create=False, extra columns with
            add_columns=False, missing or duplicated primary-key values,
            or an unknown write mode.
        TypeError: A Polars dtype cannot be mapped to a PostgreSQL type.
    """
    # ---- columns currently present in the target table ----
    db_columns_info = await conn.fetch(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = $1 AND table_name = $2
        ORDER BY ordinal_position;
        """,
        schema,
        table,
    )

    db_columns = [r["column_name"] for r in db_columns_info]
    db_column_types: Dict[str, str] = {
        r["column_name"]: r["data_type"] for r in db_columns_info
    }
    # A table without any listed column is treated as non-existent.
    table_exists = bool(db_columns)

    # =========================
    # Create the table (from the original Polars dtypes)
    # =========================
    if not table_exists:
        if not create:
            raise ValueError(f"表 {schema}.{table} 不存在,且 create=False")

        cols_def = []
        for col_name, dtype in original_schema.items():
            pg_type = polars_to_pg_type(dtype)
            cols_def.append(f'"{col_name}" {pg_type}')
            db_column_types[col_name] = pg_type

        resolved_pk = _resolve_primary_key(df.columns, primary_key)

        pk_def = ""
        if resolved_pk:
            pk_cols_sql = ", ".join(f'"{c}"' for c in resolved_pk)
            pk_def = f',\n    PRIMARY KEY ({pk_cols_sql})'

        create_sql = (
            f'CREATE TABLE "{schema}"."{table}" (\n'
            + ",\n".join(f"    {c}" for c in cols_def)
            + pk_def
            + "\n);"
        )
        await conn.execute(create_sql)

        db_columns = list(df.columns)
        logger.info(
            "表 %s.%s 创建成功,字段数: %d%s",
            schema,
            table,
            len(cols_def),
            f",主键: {resolved_pk}" if resolved_pk else "",
        )

    # =========================
    # Column alignment
    # =========================
    df_cols = list(df.columns)

    if only_column:
        # Silently ignore DataFrame columns the table does not have.
        common_cols = [c for c in df_cols if c in db_columns]
        df = df.select(common_cols)
        df_cols = common_cols

    elif add_columns:
        # Add the missing columns (from the original Polars dtypes).
        new_cols = [c for c in df_cols if c not in db_columns]
        for col in new_cols:
            pg_type = polars_to_pg_type(original_schema[col])
            await conn.execute(
                f'ALTER TABLE "{schema}"."{table}" ADD COLUMN "{col}" {pg_type}'
            )
            db_column_types[col] = pg_type
            db_columns.append(col)
            logger.info(
                "新增字段 %s.%s.%s %s",
                schema,
                table,
                col,
                pg_type,
            )

    else:
        extra = set(df_cols) - set(db_columns)
        if extra:
            raise ValueError(
                f"列 {extra} 不存在于目标表,且 add_columns=False"
            )

    # =========================
    # Primary-key validation (the key defined in the database wins)
    # =========================
    # Constraint names are only unique per table, so the join must match
    # the table name as well; otherwise key columns of another table with
    # an identically named constraint could leak into the result.
    pk_rows = await conn.fetch(
        """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
            AND tc.table_schema = $1
            AND tc.table_name = $2
        ORDER BY kcu.ordinal_position;
        """,
        schema,
        table,
    )
    pk_cols = [r["column_name"] for r in pk_rows]

    if pk_cols:
        missing_pk = set(pk_cols) - set(df_cols)
        if missing_pk:
            raise ValueError(f"主键列缺失: {missing_pk}")

        # Duplicate keys inside the batch would break ON CONFLICT handling.
        dup = df.group_by(pk_cols).agg(pl.len()).filter(pl.col("len") > 1)
        if not dup.is_empty():
            raise ValueError(f"主键重复:\n{dup.head(1)}")

    # =========================
    # Type alignment: cast values for text-like database columns
    # =========================
    str_types = {"text", "varchar", "character varying", "char", "character", "name"}
    for col in df_cols:
        db_type = db_column_types.get(col, "").lower()
        if db_type in str_types and not isinstance(df[col].dtype, pl.String):
            df = df.with_columns(pl.col(col).cast(pl.String))
            logger.debug("列 %s 值转换为 String(数据库类型: %s)", col, db_type)

    # =========================
    # Choose the write mode
    # =========================
    nrows, ncols = df.shape
    if write == "auto":
        # - fewer than 100 rows: single-row execute
        # - more than 65535 cells: COPY
        # - otherwise: executemany
        if nrows < 100:
            write = "execute"
        elif nrows * ncols > 65535:
            write = "copy"
        else:
            write = "executemany"

    # =========================
    # Perform the write
    # =========================
    if write == "copy":
        return await _write_copy(
            df, schema, table, conn, df_cols, pk_cols, update
        )
    elif write == "executemany":
        return await _write_executemany(
            df, schema, table, conn, df_cols, pk_cols, update
        )
    elif write == "execute":
        return await _write_execute(
            df, schema, table, conn, df_cols, pk_cols, update
        )
    else:
        raise ValueError(f"不支持的写入模式: {write}")
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
# =========================================================
|
|
461
|
+
# COPY 模式
|
|
462
|
+
# =========================================================
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
async def _write_copy(
    df: pl.DataFrame,
    schema: str,
    table: str,
    conn: asyncpg.Connection,
    # Builtin `list[str]` on purpose: in this module the name `List` is
    # rebound to the polars dtype by a later import and is not subscriptable.
    df_cols: list[str],
    pk_cols: list[str],
    update: bool,
) -> int:
    """
    Write data with the PostgreSQL COPY protocol (for large volumes).

    The data goes through a temporary staging table:
    1. create a temp table with the structure of the target table
    2. bulk-load the rows into it with COPY
    3. INSERT INTO the target from the temp table (with ON CONFLICT when
       the target has a primary key)

    The temp table is created with ON COMMIT DROP, so it disappears when
    the surrounding transaction ends; the function is expected to run
    inside a transaction.

    Args:
        df: Data to write.
        schema: Database schema name.
        table: Target table name.
        conn: asyncpg connection to run on.
        df_cols: Columns to write, in order.
        pk_cols: Primary-key columns of the target table (may be empty).
        update: Update existing rows on conflict instead of skipping them.

    Returns:
        int: Row count reported by the final INSERT statement.
    """
    temp_table = f"_temp_{table}_{datetime.now():%Y%m%d%H%M%S%f}"

    # Without ON COMMIT DROP a temp table lives until the session ends and
    # would pile up on long-lived (pooled) connections.
    await conn.execute(
        f'CREATE TEMP TABLE "{temp_table}" ON COMMIT DROP '
        f'AS SELECT * FROM "{schema}"."{table}" WHERE 1=0'
    )

    # Serialise straight into a bytes buffer. Fields are quoted when they
    # contain the delimiter, a quote or a line break, so such values
    # survive the CSV round trip instead of corrupting the row layout.
    buf = io.BytesIO()
    df.select(df_cols).write_csv(
        file=buf,
        separator="\t",
        null_value=r"\N",
        include_header=False,
        quote_style="necessary",
    )
    buf.seek(0)

    await conn.copy_to_table(
        table_name=temp_table,
        source=buf,
        columns=df_cols,
        format="csv",
        delimiter="\t",
        null=r"\N",
    )

    quoted_cols = ", ".join(f'"{c}"' for c in df_cols)

    if not pk_cols:
        if update:
            logger.warning(
                "表 %s.%s 无主键,update=True 无效,数据可能重复插入",
                schema,
                table,
            )
        # List the columns explicitly: the temp table has every column of
        # the target, and `SELECT *` would insert NULL into the ones that
        # were not loaded, bypassing their defaults.
        result = await conn.execute(
            f'INSERT INTO "{schema}"."{table}" ({quoted_cols}) '
            f'SELECT {quoted_cols} FROM "{temp_table}"'
        )
    else:
        update_set = ", ".join(
            f'"{c}" = EXCLUDED."{c}"'
            for c in df_cols
            if c not in pk_cols
        )
        # With only key columns there is nothing to update, and an empty
        # SET list would be a syntax error.
        if update and update_set:
            conflict_action = f"DO UPDATE SET {update_set}"
        else:
            conflict_action = "DO NOTHING"

        result = await conn.execute(
            f"""
            INSERT INTO "{schema}"."{table}" ({quoted_cols})
            SELECT {quoted_cols} FROM "{temp_table}"
            ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)})
            {conflict_action}
            """
        )

    # asyncpg returns the status "INSERT 0 N"; the last token is the count.
    count = int(result.split()[-1])
    logger.info("COPY 写入完成: %d 行", count)
    return count
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
# =========================================================
|
|
545
|
+
# EXECUTEMANY 模式
|
|
546
|
+
# =========================================================
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
async def _write_executemany(
    df: pl.DataFrame,
    schema: str,
    table: str,
    conn: asyncpg.Connection,
    # Builtin `list[str]` on purpose: in this module the name `List` is
    # rebound to the polars dtype by a later import and is not subscriptable.
    df_cols: list[str],
    pk_cols: list[str],
    update: bool,
) -> int:
    """
    Write data with a multi-row ``executemany`` INSERT.

    The return value is the number of rows that were SENT, not the number
    of rows actually inserted or updated; this function does not read an
    affected-row count back from executemany.

    Args:
        df: Data to write.
        schema: Database schema name.
        table: Target table name.
        conn: asyncpg connection to run on.
        df_cols: Columns to write, in order.
        pk_cols: Primary-key columns of the target table (may be empty).
        update: Update existing rows on conflict instead of skipping them.

    Returns:
        int: Number of rows submitted.
    """
    # iter_rows() already yields one tuple per row.
    records = list(df.select(df_cols).iter_rows())

    placeholders = ", ".join(f"${i+1}" for i in range(len(df_cols)))
    quoted_cols = [f'"{c}"' for c in df_cols]

    if not pk_cols:
        if update:
            logger.warning(
                "表 %s.%s 无主键,update=True 无效,数据可能重复插入",
                schema,
                table,
            )
        await conn.executemany(
            f'INSERT INTO "{schema}"."{table}" '
            f'({", ".join(quoted_cols)}) VALUES ({placeholders})',
            records,
        )
        logger.info("executemany 写入完成: %d 行", len(records))
        return len(records)

    update_set = ", ".join(
        f'"{c}" = EXCLUDED."{c}"'
        for c in df_cols
        if c not in pk_cols
    )
    # With only key columns there is nothing to update, and an empty SET
    # list would be a syntax error.
    if update and update_set:
        conflict_action = f"DO UPDATE SET {update_set}"
    else:
        conflict_action = "DO NOTHING"

    await conn.executemany(
        f"""
        INSERT INTO "{schema}"."{table}" ({", ".join(quoted_cols)})
        VALUES ({placeholders})
        ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)})
        {conflict_action}
        """,
        records,
    )

    logger.info(
        "executemany upsert 完成: 尝试写入 %d 行", len(records)
    )
    return len(records)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
# =========================================================
|
|
612
|
+
# EXECUTE 模式(单行,仅适合小数据量)
|
|
613
|
+
# =========================================================
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
async def _write_execute(
    df: pl.DataFrame,
    schema: str,
    table: str,
    conn: asyncpg.Connection,
    df_cols: List[str],
    pk_cols: List[str],
    update: bool,
) -> int:
    """
    Write rows with one single-row INSERT per row (tiny inputs only, < 100 rows).

    Every row is sent through its own ``conn.execute`` call, awaited one
    after another, which makes this the slowest write mode. Use the copy or
    executemany mode for bulk data.

    Args:
        df: Frame holding the rows to write.
        schema: Target schema name.
        table: Target table name.
        conn: Connection used to run the statements.
        df_cols: Columns of ``df`` to write, in placeholder order.
        pk_cols: Conflict-target columns. When empty, a plain INSERT is
            issued and ``update`` has no effect (a warning is logged).
        update: When ``pk_cols`` is non-empty, ``True`` overwrites the
            non-key columns on conflict and ``False`` skips conflicting rows.

    Returns:
        The sum of the row counts parsed from each statement's status
        string; a row whose status cannot be parsed is counted as 1.
    """

    def _quote(identifier: str) -> str:
        # Double embedded quotes so a name containing '"' stays a single
        # identifier instead of terminating the quoted name early.
        return '"' + identifier.replace('"', '""') + '"'

    placeholders = ", ".join(f"${i}" for i in range(1, len(df_cols) + 1))
    column_list = ", ".join(_quote(c) for c in df_cols)
    sql = (
        f"INSERT INTO {_quote(schema)}.{_quote(table)} "
        f"({column_list}) VALUES ({placeholders})"
    )

    if not pk_cols:
        if update:
            logger.warning(
                "表 %s.%s 无主键,update=True 无效,数据可能重复插入",
                schema,
                table,
            )
    else:
        key_names = set(pk_cols)
        update_set = ", ".join(
            f"{_quote(c)} = EXCLUDED.{_quote(c)}"
            for c in df_cols
            if c not in key_names
        )
        # When every written column is part of the conflict target there is
        # nothing left to update; "DO UPDATE SET" with an empty list is a
        # syntax error, so fall back to skipping the conflicting row.
        if update and update_set:
            conflict_action = f"DO UPDATE SET {update_set}"
        else:
            conflict_action = "DO NOTHING"
        conflict_target = ", ".join(_quote(c) for c in pk_cols)
        sql += f" ON CONFLICT ({conflict_target}) {conflict_action}"

    count = 0
    # Selecting df_cols first yields plain tuples already in placeholder
    # order, avoiding a dict per row.
    for row in df.select(df_cols).iter_rows():
        result = await conn.execute(sql, *row)
        # The status string has the form "INSERT 0 N"; the last token is the
        # number of rows actually inserted/updated (0 when a conflict is skipped).
        try:
            count += int(result.split()[-1])
        except (ValueError, IndexError):
            count += 1

    logger.info("execute 写入完成: %d 行", count)
    return count
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: pgsqldatatool
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: manji
|
|
6
|
+
Author-email: manji <pnsm@qq.com>
|
|
7
|
+
Requires-Dist: asyncpg>=0.31.0
|
|
8
|
+
Requires-Dist: polars>=1.41.2
|
|
9
|
+
Requires-Dist: python-dotenv>=1.2.2
|
|
10
|
+
Requires-Dist: sqlalchemy>=2.0.51
|
|
11
|
+
Requires-Dist: tzdata>=2026.2
|
|
12
|
+
Requires-Python: >=3.14
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
## 使用教程
|
|
16
|
+
|
|
17
|
+
### 安装
|
|
18
|
+
```bash
|
|
19
|
+
pip install pgsqldatatool
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
### data_clean(数据清洗)
|
|
23
|
+
|
|
24
|
+
``` python
|
|
25
|
+
# 基础清洗 (去掉列名两端空白字符、去掉数据两端空白字符、删除空行)
|
|
26
|
+
lf_basic_clean(df)
|
|
27
|
+
|
|
28
|
+
# 根据字典重命名列
|
|
29
|
+
lf_rename_cols
|
|
30
|
+
|
|
31
|
+
# 移除重复行
|
|
32
|
+
lf_remove_dup_rows
|
|
33
|
+
|
|
34
|
+
# 删除指定列
|
|
35
|
+
lf_remove_cols
|
|
36
|
+
|
|
37
|
+
# 移除数字千分号
|
|
38
|
+
lf_remove_per_mille
|
|
39
|
+
|
|
40
|
+
# 移除数字百分号
|
|
41
|
+
lf_remove_percent
|
|
42
|
+
|
|
43
|
+
# 添加时间列
|
|
44
|
+
lf_add_time
|
|
45
|
+
|
|
46
|
+
# 删除完全为空的列 -- 这一步不是惰性计算,可能会降低性能
|
|
47
|
+
df_drop_empty_cols
|
|
48
|
+
```
|
|
49
|
+
示例
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
from pgsqldatatool import data_clean as dc
|
|
53
|
+
df = pl.read_excel(r"D:\manji\Downloads\判断中国域名.xlsx")
|
|
54
|
+
df = dc.lf_basic_clean(df)
|
|
55
|
+
df = dc.lf_remove_cols(df,["域名22"])
|
|
56
|
+
df = dc.lf_remove_dup_rows(df)
|
|
57
|
+
df = dc.df_drop_empty_cols(df)
|
|
58
|
+
df = dc.lf_remove_dup_rows(df)
|
|
59
|
+
df = df.collect()
|
|
60
|
+
print(df)
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
### PoolSingleton(连接池单例模式)
|
|
64
|
+
|
|
65
|
+
归还连接(Release):把连接还给池子,让别的任务继续用。(async with 已经帮你自动做了)
|
|
66
|
+
关闭连接池(Close):彻底断开与数据库的所有连接,销毁这个池子。这通常只在整个程序/服务准备退出时才需要做。
|
|
67
|
+
```python
|
|
68
|
+
from pgsqldatatool import PoolSingleton
|
|
69
|
+
|
|
70
|
+
# 示例1:使用异步连接池
|
|
71
|
+
async def main_test_1():
|
|
72
|
+
|
|
73
|
+
# 调用方式1,
|
|
74
|
+
async with PoolSingleton.acquire() as conn:
|
|
75
|
+
records = await conn.fetch(""" SELECT * FROM public."test20260211" """)
|
|
76
|
+
print(records)
|
|
77
|
+
|
|
78
|
+
# 销毁整个连接池
|
|
79
|
+
await PoolSingleton.close()
|
|
80
|
+
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
|
|
86
|
+
# 示例2:使用静态方法
|
|
87
|
+
async def main_test_2():
|
|
88
|
+
# 调用方式2:使用静态方法
|
|
89
|
+
records2 = await PoolSingleton.fetch(""" SELECT * FROM public."test20260211" """)
|
|
90
|
+
print(records2)
|
|
91
|
+
|
|
92
|
+
# 调用方式3:使用静态方法
|
|
93
|
+
records3 = await PoolSingleton.fetchrow(""" SELECT * FROM public."test20260211" """)
|
|
94
|
+
print(records3)
|
|
95
|
+
|
|
96
|
+
# 调用方式4:使用静态方法
|
|
97
|
+
records4 = await PoolSingleton.execute(""" SELECT * FROM public."test20260211" """)
|
|
98
|
+
print(records4)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
if __name__ == "__main__":
|
|
102
|
+
asyncio.run(main_test_1())
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
## 创建异步数据库链接
|
|
107
|
+
|
|
108
|
+
### 推荐方案:asyncpg(性能最好,原生 asyncio)
|
|
109
|
+
适合:FastAPI / aiohttp / asyncio 项目
|
|
110
|
+
```bash
|
|
111
|
+
pip install asyncpg
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
import asyncpg
|
|
116
|
+
import asyncio
|
|
117
|
+
|
|
118
|
+
async def main():
|
|
119
|
+
# 创建连接
|
|
120
|
+
conn = await asyncpg.connect(
|
|
121
|
+
host="localhost",
|
|
122
|
+
port=5432,
|
|
123
|
+
user="postgres",
|
|
124
|
+
password="password",
|
|
125
|
+
database="testdb"
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# 查询
|
|
129
|
+
row = await conn.fetchrow("SELECT NOW()")
|
|
130
|
+
print(row)
|
|
131
|
+
|
|
132
|
+
# 参数化查询
|
|
133
|
+
rows = await conn.fetch(
|
|
134
|
+
"SELECT * FROM users WHERE age > $1",
|
|
135
|
+
18
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
await conn.close()
|
|
139
|
+
|
|
140
|
+
asyncio.run(main())
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
### 使用连接池(生产必选 ✅)
|
|
144
|
+
|
|
145
|
+
```python
|
|
146
|
+
import asyncpg
|
|
147
|
+
import asyncio
|
|
148
|
+
|
|
149
|
+
async def get_pool():
|
|
150
|
+
return await asyncpg.create_pool(
|
|
151
|
+
dsn="postgresql://postgres:password@localhost:5432/testdb",
|
|
152
|
+
min_size=5,
|
|
153
|
+
max_size=20
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
async def main():
|
|
157
|
+
pool = await get_pool()
|
|
158
|
+
|
|
159
|
+
async with pool.acquire() as conn:
|
|
160
|
+
result = await conn.fetch("SELECT * FROM users")
|
|
161
|
+
|
|
162
|
+
await pool.close()
|
|
163
|
+
|
|
164
|
+
asyncio.run(main())
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
事务
|
|
168
|
+
```python
|
|
169
|
+
async with conn.transaction():
|
|
170
|
+
await conn.execute(
|
|
171
|
+
"INSERT INTO users(name) VALUES($1)",
|
|
172
|
+
"Alice"
|
|
173
|
+
)
|
|
174
|
+
```
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
### SQLAlchemy 2.0 + asyncpg(ORM 场景)
|
|
178
|
+
适合:需要 ORM、多数据库兼容
|
|
179
|
+
|
|
180
|
+
```bash
|
|
181
|
+
pip install sqlalchemy[asyncio] asyncpg
|
|
182
|
+
```
|
|
183
|
+
|
|
184
|
+
```python
|
|
185
|
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
186
|
+
from sqlalchemy.orm import sessionmaker
|
|
187
|
+
from sqlalchemy import select
|
|
188
|
+
|
|
189
|
+
engine = create_async_engine(
|
|
190
|
+
"postgresql+asyncpg://postgres:password@localhost/testdb",
|
|
191
|
+
pool_size=10,
|
|
192
|
+
max_overflow=20
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
AsyncSessionLocal = sessionmaker(
|
|
196
|
+
engine, class_=AsyncSession, expire_on_commit=False
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
async def query_users():
|
|
200
|
+
async with AsyncSessionLocal() as session:
|
|
201
|
+
result = await session.execute(select(User))
|
|
202
|
+
return result.scalars().all()
|
|
203
|
+
```
|
|
204
|
+
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
pgsqldatatool/__init__.py,sha256=Lu4cRHywCzTHUME7b-huEGa23FnCofebSgf9CrJA56Q,217
|
|
2
|
+
pgsqldatatool/data_clean.py,sha256=3ZQoXsoFkZE_GyKcW1CGTDDIscIwQwdeaUHUdtlmJec,7572
|
|
3
|
+
pgsqldatatool/pgsql_connection_async.py,sha256=aQxVR_vPpSFV_ozmnfzWF2-bIF_PB_fKtXK0Dmg-r5A,2609
|
|
4
|
+
pgsqldatatool/tools.py,sha256=e5gqB3wsQoQ910e8JV2yfrLH301sLnh68Q7pz7VIZbk,741
|
|
5
|
+
pgsqldatatool/until_async.py,sha256=pE1gmG1VCVg4jwa7ulu-PvO1OH8sNnNc2pKqrwYS1Rc,20464
|
|
6
|
+
pgsqldatatool-1.0.0.dist-info/WHEEL,sha256=9sjN42GvvIkyGb9JrWAWXnA96E2dxDe0tzHzrLxUlD4,81
|
|
7
|
+
pgsqldatatool-1.0.0.dist-info/entry_points.txt,sha256=58ZLpCEz7tm5xblywFKsSK6i_BeLZOsjQtQLpiOU8rY,54
|
|
8
|
+
pgsqldatatool-1.0.0.dist-info/METADATA,sha256=nj0mKbKmMe_M7uyUTIbFQEZgbqyDOx3YfGTH1zASp38,4657
|
|
9
|
+
pgsqldatatool-1.0.0.dist-info/RECORD,,
|