ygo-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ygo might be problematic; see the registry's page for this release for more details.

ycat/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2025/5/14 18:29
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+
10
+ from .client import HOME, CATDB, SETTINGS, sql, put, create_engine_ck, create_engine_mysql, read_mysql, read_ck
11
+
12
# Public API of the package: the names re-exported from `.client` above.
# `from ycat import *` exposes exactly these.
__all__ = [
    "HOME",
    "CATDB",
    "SETTINGS",
    "sql",
    "put",
    "create_engine_ck",
    "create_engine_mysql",
    "read_mysql",
    "read_ck",
]
ycat/client.py ADDED
@@ -0,0 +1,157 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2024/7/1 09:44
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+ import os
10
+ import re
11
+ from typing import Optional
12
+ from .yck import connect, query_polars
13
+
14
+ import duckdb
15
+ import polars as pl
16
+ import ylog
17
+ from dynaconf import Dynaconf
18
+ from sqlalchemy import create_engine
19
+ from functools import partial
20
+
21
+ from .parse import extract_table_names_from_sql
22
+
23
# The configuration file lives at "~/.catdb/settings.toml". If it does not
# exist yet, a commented template is generated when this module is imported.
USERHOME = os.path.expanduser('~')  # the user's home directory
CONFIG_PATH = os.path.join(USERHOME, ".catdb", "settings.toml")
if not os.path.exists(CONFIG_PATH):
    try:
        # exist_ok: the directory may already exist even though the file does not
        os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)
        catdb_path = os.path.join(USERHOME, "catdb")
        # The path is written into a TOML *basic* string, where a backslash starts
        # an escape sequence. Windows paths (eg: C:\Users\...) would otherwise make
        # the file unparsable, so backslashes and double quotes are escaped.
        toml_catdb_path = catdb_path.replace("\\", "\\\\").replace('"', '\\"')
        template_content = f"""[paths]
catdb="{toml_catdb_path}" # 本地数据库,默认家目录

## 数据库配置:
[database]
[database.ck]
# urls=["<host1>:<port1>", "<host2>:<port2>",]
# user="xxx"
# password="xxxxxx"
[database.jy]
# url="<host>:<port>"
# user="xxxx"
# password="xxxxxx"

## 视情况自由增加其他配置
"""
        # The template contains non-ASCII text; write UTF-8 explicitly instead of
        # relying on the platform's locale encoding.
        with open(CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(template_content)
        ylog.info(f"生成配置文件: {CONFIG_PATH}")
    except OSError as e:
        # Best effort: log and continue with defaults rather than failing the import.
        ylog.error(f"配置文件生成失败: {e}")
53
+
54
+
55
def get_settings():
    """
    Load the user settings from ``CONFIG_PATH`` with Dynaconf.

    Returns
    -------
    Dynaconf | None
        The settings object, or None if constructing it raised an exception.
        The exception is logged rather than propagated.
    """
    try:
        return Dynaconf(settings_files=[CONFIG_PATH])
    except Exception as e:
        # Previously a bare `except:` that also swallowed KeyboardInterrupt /
        # SystemExit and hid the cause; keep the best-effort behaviour but log it.
        ylog.warning(f"failed to load settings from {CONFIG_PATH}: {e}")
        return None
60
+
61
+
62
HOME = USERHOME
# Default root of the local catdb database. It can be overridden by
# `paths.catdb` in the settings file.
CATDB = os.path.join(HOME, "catdb")
# Apply the override from the configuration file, if one was loaded.
SETTINGS = get_settings()
if SETTINGS is not None:
    try:
        # A dotted-key `get` falls back to the default when the [paths] table or its
        # `catdb` key is missing; plain attribute access would raise at import time.
        # NOTE(review): Dynaconf appears to parse the file lazily, so a malformed
        # settings file may also surface on this first access, hence the broad except.
        CATDB = os.path.expanduser(SETTINGS.get("paths.catdb", CATDB))
    except Exception as e:
        ylog.warning(f"failed to read paths.catdb from settings, using {CATDB}: {e}")
68
+
69
+
70
# ======================== local database: catdb ========================
def tb_path(tb_name: str) -> str:
    """
    Resolve a table name to its location under the catdb root.

    Parameters
    ----------
    tb_name: str
        Table name, written as a relative path such as a/b/c
    Returns
    -------
    full_abs_path: str
        ``os.path.join(CATDB, tb_name)``, e.g. $HOME/catdb/a/b/c with the
        default root. An absolute `tb_name` replaces the root entirely
        (``os.path.join`` semantics).
    """
    root = CATDB
    full_abs_path = os.path.join(root, tb_name)
    return full_abs_path
84
+
85
def put(df: pl.DataFrame, tb_name: str, partitions: Optional[list[str]] = None):
    """
    Write a polars DataFrame as parquet under the table's directory so `sql` can read it.

    Parameters
    ----------
    df: polars.DataFrame
        Data to write (``DataFrame.write_parquet`` is used, so a polars frame is required)
    tb_name: str
        Table name, written as a path such as a/b/c
    partitions: Optional[List[str]]
        Columns to hive-partition by; no partitioning when None or empty

    Raises
    ------
    ValueError
        If a partition column is missing from `df`.
    """
    tbpath = tb_path(tb_name)
    # the table is always a directory: `sql` scans '<tbpath>/**/*.parquet'
    os.makedirs(tbpath, exist_ok=True)
    if partitions:
        # explicit check instead of `assert`, which is stripped under `python -O`
        missing = [field for field in partitions if field not in df.columns]
        if missing:
            raise ValueError(f'dataframe must have Field `{missing[0]}`')
        df.write_parquet(tbpath, partition_by=partitions)
    else:
        # Writing to `tbpath` itself would target the directory just created and
        # fail; put a single file inside it instead (overwritten by later puts).
        df.write_parquet(os.path.join(tbpath, "data.parquet"))
107
+
108
+
109
def _to_duckdb_query(query: str) -> str:
    """
    Rewrite local catdb table names in `query` into DuckDB parquet scans.

    Each table name found by `extract_table_names_from_sql` is replaced with
    ``read_parquet('<tb_path>/**/*.parquet', hive_partitioning = true)``, but only
    where it directly follows FROM or JOIN and is not itself followed by '('.
    """
    tbs = extract_table_names_from_sql(query)
    if not tbs:
        # nothing to rewrite (an empty alternation regex would match everywhere)
        return query
    convertor = dict()
    for tb in tbs:
        # double single quotes so the path cannot terminate the SQL string literal
        db_path = tb_path(tb).replace("'", "''")
        convertor[tb] = f"read_parquet('{db_path}/**/*.parquet', hive_partitioning = true)"
    # longest names first, so 'a/b/c' is never consumed as 'a/b' followed by '/c'
    names = sorted(convertor, key=len, reverse=True)
    pattern = re.compile(
        r"(\b(?:from|join)\s+)("
        + "|".join(re.escape(name) for name in names)
        + r")(?=[\s),;]|$)",
        re.IGNORECASE,
    )

    def _replace(m):
        # the parser lower-cases the names it extracts, hence .lower() on lookup
        scan = convertor.get(m.group(2).lower())
        return m.group(0) if scan is None else m.group(1) + scan

    return pattern.sub(_replace, query)


def sql(query: str):
    """
    Query the local catdb parquet tables with DuckDB.

    The syntax is DuckDB SQL (close to MySQL for plain selects); for special
    syntax see https://duckdb.org/docs/sql/query_syntax/select
    Parameters
    ----------
    query: str
        Query text; table names are catdb paths such as a/b/c
    Returns
    -------
    result: polars.DataFrame
        Query result
    """
    new_query = _to_duckdb_query(query)
    conn = duckdb.connect()
    try:
        conn.execute("PRAGMA disable_progress_bar;")
        conn.execute("PRAGMA threads=1;")
        res = conn.execute(new_query).fetch_arrow_table()
    finally:
        # always release the in-memory connection, even if the query fails
        conn.close()
    return pl.from_arrow(res)
136
+
137
def create_engine_ck(urls: list[str], user: str, password: str):
    """
    Build a connection factory for the `ck` database (presumably ClickHouse).

    Returns a ``functools.partial`` of ``connect`` (from ``.yck``) with `urls`,
    `user` and `password` bound; calling it opens a connection, which ``read_ck``
    uses as a context manager.
    """
    factory = partial(connect, urls, user, password)
    return factory
139
+
140
def read_ck(sql, eng)->pl.DataFrame:
    """
    Run `sql` on a fresh connection from the factory `eng` and return a polars DataFrame.

    `eng` is a zero-argument callable (see ``create_engine_ck``) whose result is
    used as a context manager; the rows are fetched by ``query_polars``.
    """
    frame = None
    with eng() as connection:
        frame = query_polars(sql, connection)
    return frame
143
+
144
+ def create_engine_mysql(url, user, password, database):
145
+ """
146
+ :param url: <host>:<port>
147
+ :param user:
148
+ :param password:
149
+ :param database:
150
+ :return:
151
+ """
152
+ engine = create_engine(f"mysql+pymysql://{user}:{password}@{url}/{database}")
153
+ return engine
154
+
155
def read_mysql(sql, eng) -> pl.DataFrame:
    """
    Execute `sql` on a connection checked out from the SQLAlchemy engine `eng`.

    The connection is released when the ``with`` block exits; the rows are
    returned as a polars DataFrame built by ``pl.read_database``.
    """
    frame = None
    with eng.connect() as connection:
        frame = pl.read_database(sql, connection)
    return frame
ycat/dtype.py ADDED
@@ -0,0 +1,389 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2024/11/4 下午1:20
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+ import functools
10
+ import re
11
+ from typing import Any
12
+ import pyarrow as pa
13
+ import re # 正则解析 Decimal 类型
14
+
15
+ from polars._typing import PolarsDataType
16
+ from polars.datatypes import (
17
+ Binary,
18
+ Boolean,
19
+ Date,
20
+ Datetime,
21
+ Decimal,
22
+ Duration,
23
+ Float32,
24
+ Float64,
25
+ Int8,
26
+ Int16,
27
+ Int32,
28
+ Int64,
29
+ List,
30
+ Null,
31
+ String,
32
+ Time,
33
+ UInt8,
34
+ UInt16,
35
+ UInt32,
36
+ UInt64,
37
+ )
38
+
39
+
40
@functools.lru_cache(8)
def integer_dtype_from_nbits(
    bits: int,
    *,
    unsigned: bool,
    default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Map a bit width and a signedness flag to the matching Polars integer dtype.

    When no dtype matches, `default` is returned (None unless given). Results are
    memoised by ``functools.lru_cache``.

    Examples
    --------
    >>> integer_dtype_from_nbits(8, unsigned=False)
    Int8
    >>> integer_dtype_from_nbits(32, unsigned=True)
    UInt32
    """
    by_width = {
        8: {False: Int8, True: UInt8},
        16: {False: Int16, True: UInt16},
        32: {False: Int32, True: UInt32},
        64: {False: Int64, True: UInt64},
    }
    found = by_width.get(bits, {}).get(unsigned)
    return default if found is None else found
71
+
72
+
73
+ def timeunit_from_precision(precision: int | str | None) -> str | None:
74
+ """
75
+ Return `time_unit` from integer precision value.
76
+
77
+ Examples
78
+ --------
79
+ >>> timeunit_from_precision(3)
80
+ 'ms'
81
+ >>> timeunit_from_precision(5)
82
+ 'us'
83
+ >>> timeunit_from_precision(7)
84
+ 'ns'
85
+ """
86
+ from math import ceil
87
+
88
+ if not precision:
89
+ return None
90
+ elif isinstance(precision, str):
91
+ if precision.isdigit():
92
+ precision = int(precision)
93
+ elif (precision := precision.lower()) in ("s", "ms", "us", "ns"):
94
+ return "ms" if precision == "s" else precision
95
+ try:
96
+ n = min(max(3, int(ceil(precision / 3)) * 3), 9) # type: ignore[operator]
97
+ return {3: "ms", 6: "us", 9: "ns"}.get(n)
98
+ except TypeError:
99
+ return None
100
+
101
+
102
def infer_dtype_from_database_typename(
    value: str,
    *,
    raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor `type_code` string value.

    ClickHouse-style wrappers ``Nullable(...)`` and ``LowCardinality(...)`` are
    unwrapped (including nested parentheses, eg ``Nullable(Decimal(10, 2))``)
    and the inner type is inferred with the same `raise_unmatched` setting.

    Parameters
    ----------
    value : str
        Database type name.
    raise_unmatched : bool
        Raise ValueError when no dtype can be inferred (otherwise return None).

    Returns
    -------
    PolarsDataType | None
        The inferred dtype; None when unmatched and `raise_unmatched` is False,
        or for timezone-aware timestamps whose zone is unknown.

    Raises
    ------
    ValueError
        If no dtype can be inferred and `raise_unmatched` is True.

    Examples
    --------
    >>> infer_dtype_from_database_typename("INT2")
    Int16
    >>> infer_dtype_from_database_typename("NVARCHAR")
    String
    >>> infer_dtype_from_database_typename("NUMERIC(10,2)")
    Decimal(precision=10, scale=2)
    >>> infer_dtype_from_database_typename("TIMESTAMP WITHOUT TZ")
    Datetime(time_unit='us', time_zone=None)
    >>> infer_dtype_from_database_typename("Nullable(Decimal(10, 2))")
    Decimal(precision=10, scale=2)
    """
    dtype: PolarsDataType | None = None

    # normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().replace("TYPE", "")

    # Wrapper types that do not change the logical dtype. They are unwrapped up
    # front because the modifier regex below cannot match nested parentheses
    # (eg: 'NULLABLE(DECIMAL(10, 2))'), and so that `raise_unmatched` is
    # honoured for the inner type.
    for wrapper in ("NULLABLE(", "LOWCARDINALITY("):
        if value.startswith(wrapper) and value.endswith(")"):
            return infer_dtype_from_database_typename(
                value[len(wrapper):-1],
                raise_unmatched=raise_unmatched,
            )

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        modifier = value[value.find("(") + 1: -1]
        value = value.split("(")[0]
    elif (
        not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        modifier = value[value.find("[") + 1: -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if not value and modifier:
            nested = infer_dtype_from_database_typename(
                value=modifier,
                raise_unmatched=False,
            )
        elif value:
            # guarded: a bare alias (eg: 'ARRAY') leaves an empty string, which
            # previously raised IndexError on value[0]
            if inner_value := infer_dtype_from_database_typename(
                value[1:-1]
                if (value[0], value[-1]) == ("<", ">")
                else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
                raise_unmatched=False,
            ):
                nested = inner_value
            elif modifier:
                nested = infer_dtype_from_database_typename(
                    value=modifier,
                    raise_unmatched=False,
                )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        dtype = (
            Float32
            if value == "FLOAT4"
            or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
        value.startswith(("INT", "UINT", "UNSIGNED"))
        or value.endswith(("INT", "SERIAL"))
        or ("INTEGER" in value)
        or value == "ROWID"
    ):
        sz: Any
        if "LARGE" in value or value.startswith("BIG") or value == "INT8":
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value == "INT2":
            sz = 16
        elif "TINY" in value:
            sz = 8
        else:
            sz = None

        # fall back to an explicit width modifier (eg: 'INT(32)')
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        if (
            ("U" in value and "MEDIUM" not in value)
            or ("UNSIGNED" in value)
            or value == "ROWID"
        ):
            dtype = integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # number types (note: 'number' alone is not that helpful and requires refinement)
    elif "NUMBER" in value and "CARDINAL" in value:
        dtype = UInt64

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
        any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
        or value.startswith(("STR", "CHAR", "BPCHAR", "NCHAR", "UTF"))
        or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # null dtype; odd, but valid
    elif value == "NULL":
        dtype = Null

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        unit = timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]
    else:
        value = re.sub(r"\d", "", value)
        if value in ("INTERVAL", "TIMEDELTA", "DURATION"):
            dtype = Duration
        elif value == "DATE":
            dtype = Date
        elif value == "TIME":
            dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype
263
+
264
# Static ClickHouse type name -> Arrow type lookup. Parameterised types such as
# Decimal(P, S), Array(...), Tuple(...) or Map(...) are resolved dynamically by
# `map_clickhouse_to_arrow`.
CLICKHOUSE_TO_ARROW_TYPE = {
    # integer types
    'Int8': pa.int8(),
    'Int16': pa.int16(),
    'Int32': pa.int32(),
    'Int64': pa.int64(),
    'UInt8': pa.uint8(),
    'UInt16': pa.uint16(),
    'UInt32': pa.uint32(),
    'UInt64': pa.uint64(),

    # floating-point types
    'Float32': pa.float32(),
    'Float64': pa.float64(),

    # string types
    'String': pa.string(),
    'FixedString': pa.string(),  # mapped to Arrow's variable-length string

    # date and time types
    'Date': pa.date32(),  # ClickHouse Date is a day count (16-bit); date32 holds it losslessly
    'Date32': pa.date32(),
    'DateTime': pa.timestamp('s'),  # ClickHouse DateTime has second precision
    'DateTime64': pa.timestamp('ms'),  # default to millisecond precision (adjust as needed)
    'UUID': pa.binary(16),  # a UUID is 16 bytes of binary

    # boolean types ('Bool' is the name ClickHouse reports, 'Boolean' an alias)
    'Bool': pa.bool_(),
    'Boolean': pa.bool_(),

    # array types (nested)
    'Array(Int8)': pa.list_(pa.int8()),
    'Array(Int16)': pa.list_(pa.int16()),
    'Array(Int32)': pa.list_(pa.int32()),
    'Array(Int64)': pa.list_(pa.int64()),
    'Array(UInt8)': pa.list_(pa.uint8()),
    'Array(UInt16)': pa.list_(pa.uint16()),
    'Array(UInt32)': pa.list_(pa.uint32()),
    'Array(UInt64)': pa.list_(pa.uint64()),
    'Array(Float32)': pa.list_(pa.float32()),
    'Array(Float64)': pa.list_(pa.float64()),
    'Array(String)': pa.list_(pa.string()),
    'Array(Date)': pa.list_(pa.date32()),
    'Array(DateTime)': pa.list_(pa.timestamp('s')),

    # nested types (tuples, enums, ...)
    # NOTE: Arrow has no Tuple type; a Tuple is usually represented as a Struct
    'Tuple': pa.struct([]),  # element types must be defined dynamically
    # enum types
    'Enum8': pa.string(),  # usually mapped to string
    'Enum16': pa.string(),

    # Map type
    'Map': pa.map_(pa.string(), pa.string()),  # defaults to string keys/values (adjust as needed)

    # Nullable types (ClickHouse's Nullable wrapper; nullability is implicit in Arrow)
    'Nullable(Int8)': pa.int8(),
    'Nullable(Int16)': pa.int16(),
    'Nullable(Int32)': pa.int32(),
    'Nullable(Int64)': pa.int64(),
    'Nullable(UInt8)': pa.uint8(),
    'Nullable(UInt16)': pa.uint16(),
    'Nullable(UInt32)': pa.uint32(),
    'Nullable(UInt64)': pa.uint64(),
    'Nullable(Float32)': pa.float32(),
    'Nullable(Float64)': pa.float64(),
    'Nullable(String)': pa.string(),
    'Nullable(Date)': pa.date32(),
    'Nullable(DateTime)': pa.timestamp('s'),
    'Nullable(UUID)': pa.binary(16),
    'Nullable(Bool)': pa.bool_(),
}
334
+
335
+ def map_clickhouse_decimal(ch_type: str) -> pa.DataType:
336
+ """
337
+ 映射 ClickHouse 的 Decimal 类型到 Arrow 的 Decimal 类型
338
+ :param ch_type: ClickHouse 的 Decimal 类型描述,例如 'Decimal(10, 2)' 或 'Decimal128(38)'
339
+ :return: 对应的 Arrow Decimal 类型
340
+ """
341
+ # 匹配 ClickHouse 的 Decimal(p, s) 格式
342
+ decimal_match = re.match(r"Decimal(?:32|64|128)?\((\d+),\s*(\d+)\)", ch_type)
343
+ if decimal_match:
344
+ precision, scale = map(int, decimal_match.groups())
345
+ return pa.decimal128(precision, scale)
346
+
347
+ # 匹配 ClickHouse 的 Decimal(p) 格式,默认 scale 为 0
348
+ decimal_match_no_scale = re.match(r"Decimal(?:32|64|128)?\((\d+)\)", ch_type)
349
+ if decimal_match_no_scale:
350
+ precision = int(decimal_match_no_scale.group(1))
351
+ return pa.decimal128(precision, 0)
352
+
353
+ # 如果不匹配,抛出异常
354
+ raise ValueError(f"Unsupported ClickHouse Decimal type: {ch_type}")
355
+
356
+ def map_clickhouse_to_arrow(ch_type: str) -> pa.DataType:
357
+ """
358
+ 动态映射 ClickHouse 类型到 Arrow 类型
359
+ """
360
+ # 基础类型直接映射
361
+ if ch_type in CLICKHOUSE_TO_ARROW_TYPE:
362
+ return CLICKHOUSE_TO_ARROW_TYPE[ch_type]
363
+
364
+ # Decimal 类型处理
365
+ if ch_type.startswith("Decimal"):
366
+ return map_clickhouse_decimal(ch_type)
367
+
368
+ # 动态处理 Array 类型
369
+ if ch_type.startswith('Array('):
370
+ inner_type = ch_type[6:-1] # 提取 Array 内的类型
371
+ return pa.list_(map_clickhouse_to_arrow(inner_type))
372
+
373
+ # 动态处理 Nullable 类型
374
+ if ch_type.startswith('Nullable('):
375
+ inner_type = ch_type[9:-1] # 提取 Nullable 内的类型
376
+ return map_clickhouse_to_arrow(inner_type)
377
+
378
+ # 动态处理 Tuple 类型
379
+ if ch_type.startswith('Tuple('):
380
+ inner_types = ch_type[6:-1].split(',') # 提取 Tuple 内的字段类型
381
+ return pa.struct([('field' + str(i), map_clickhouse_to_arrow(t.strip())) for i, t in enumerate(inner_types)])
382
+
383
+ # 动态处理 Map 类型
384
+ if ch_type.startswith('Map('):
385
+ key_type, value_type = ch_type[4:-1].split(',')
386
+ return pa.map_(map_clickhouse_to_arrow(key_type.strip()), map_clickhouse_to_arrow(value_type.strip()))
387
+
388
+ raise ValueError(f"Unsupported ClickHouse type: {ch_type}")
389
+
ycat/parse.py ADDED
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2024/11/6 下午7:25
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+ import sqlparse
10
+ import re
11
+
12
def format_sql(sql_content):
    """
    Normalise a SQL string: re-indent it and strip its comments.

    Both the input and the return value are plain strings.
    """
    formatted = sqlparse.format(
        sql_content,
        strip_comments=True,
        reindent=True,
    )
    return formatted
16
+
17
def extract_temp_tables(with_clause):
    """
    Extract the CTE (temporary table) names defined in a WITH clause.

    Matches ``name AS (``, ``name(col, ...) AS (`` and
    ``name AS [NOT] MATERIALIZED (``, case-insensitively.
    ``AS`` must be a whole word, so function calls such as ``has(`` or
    ``alias(`` are not mistaken for CTEs.

    Returns
    -------
    list[str]
        The names in order of appearance.
    """
    pattern = (
        r'\b(\w+)'                            # CTE name
        r'(?:\s*\([^()]*\))?'                 # optional column list
        r'\s*\bas'                            # AS as a whole word
        r'(?:\s+(?:not\s+)?materialized)?'    # optional [NOT] MATERIALIZED
        r'\s*\('
    )
    temp_tables = re.findall(pattern, with_clause, re.IGNORECASE)
    return temp_tables
21
+
22
def extract_table_names_from_sql(sql_query):
    """
    Extract the table names referenced after FROM / JOIN in `sql_query`.

    The text is lower-cased. ``substring(...)`` / ``extract(...)`` calls are
    removed first, so their ``FROM`` keyword is not mistaken for a table
    reference. A namespace prefix is dropped (``c.a/b`` -> ``a/b``), and quotes,
    backticks and ``;`` are stripped. CTE names defined by a statement's WITH
    clause are removed from that statement's results.

    Returns
    -------
    list[str]
        The unique table names, sorted.
    """
    # regex for the token following FROM or JOIN
    table_name_pattern = r'\bFROM\s+([^\s\(\)\,]+)|\bJOIN\s+([^\s\(\)\,]+)'
    table_names = set()

    for statement in sqlparse.parse(sql_query):
        statement_str = str(statement).lower()

        # blank out special syntax whose argument contains FROM; `[^)]*` stops at
        # the first ')' like the old lazy `(.|\s)*?`, without its exponential backtracking
        statement_str = re.sub(r'(?:substring|extract)\s*\([^)]*\)', '', statement_str)

        statement_tables = set()
        for match in re.findall(table_name_pattern, statement_str, re.IGNORECASE):
            # only one of the two alternation groups is non-empty per match
            for name in match:
                if not name:
                    continue
                # keep only the last part of a namespaced name
                table_name = name.split('.')[-1]
                # strip quoting characters and statement terminators
                table_name = re.sub(r'("|`|\'|;)', '', table_name)
                if table_name:
                    statement_tables.add(table_name)

        # CTE names are not real tables; remove them per statement so one
        # statement's CTEs never hide another statement's tables
        if re.search(r'\bwith\b', statement_str):
            statement_tables -= set(extract_temp_tables(statement_str))

        table_names |= statement_tables

    return sorted(table_names)
62
+
63
+
64
if __name__ == '__main__':
    # manual smoke check: a namespaced, path-style table name
    demo_query = "select * from c.a/b/c/d"
    print(extract_table_names_from_sql(demo_query))
66
+