ygo 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ygo might be problematic. Click here for more details.
- ycat/__init__.py +22 -0
- ycat/client.py +157 -0
- ycat/dtype.py +389 -0
- ycat/parse.py +66 -0
- ycat/yck.py +87 -0
- ygo/__init__.py +10 -0
- ygo/exceptions.py +13 -0
- ygo/ygo.py +372 -0
- ygo-1.0.1.dist-info/METADATA +95 -0
- ygo-1.0.1.dist-info/RECORD +15 -0
- ygo-1.0.1.dist-info/WHEEL +5 -0
- ygo-1.0.1.dist-info/licenses/LICENSE +21 -0
- ygo-1.0.1.dist-info/top_level.txt +3 -0
- ylog/__init__.py +20 -0
- ylog/core.py +226 -0
ycat/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
---------------------------------------------
|
|
4
|
+
Created on 2025/5/14 18:29
|
|
5
|
+
@author: ZhangYundi
|
|
6
|
+
@email: yundi.xxii@outlook.com
|
|
7
|
+
---------------------------------------------
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .client import HOME, CATDB, SETTINGS, sql, put, create_engine_ck, create_engine_mysql, read_mysql, read_ck
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"HOME",
|
|
14
|
+
"CATDB",
|
|
15
|
+
"SETTINGS",
|
|
16
|
+
"sql",
|
|
17
|
+
"put",
|
|
18
|
+
"create_engine_ck",
|
|
19
|
+
"create_engine_mysql",
|
|
20
|
+
"read_mysql",
|
|
21
|
+
"read_ck",
|
|
22
|
+
]
|
ycat/client.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
---------------------------------------------
|
|
4
|
+
Created on 2024/7/1 09:44
|
|
5
|
+
@author: ZhangYundi
|
|
6
|
+
@email: yundi.xxii@outlook.com
|
|
7
|
+
---------------------------------------------
|
|
8
|
+
"""
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
from typing import Optional
|
|
12
|
+
from .yck import connect, query_polars
|
|
13
|
+
|
|
14
|
+
import duckdb
|
|
15
|
+
import polars as pl
|
|
16
|
+
import ylog
|
|
17
|
+
from dynaconf import Dynaconf
|
|
18
|
+
from sqlalchemy import create_engine
|
|
19
|
+
from functools import partial
|
|
20
|
+
|
|
21
|
+
from .parse import extract_table_names_from_sql
|
|
22
|
+
|
|
23
|
+
# 配置文件在 “~/.catdb/setting.toml”
|
|
24
|
+
USERHOME = os.path.expanduser('~')  # the current user's home directory
CONFIG_PATH = os.path.join(USERHOME, ".catdb", "settings.toml")
# First import on this machine: create ~/.catdb/settings.toml from a template.
if not os.path.exists(CONFIG_PATH):
    catdb_path = os.path.join(USERHOME, "catdb")
    # The path is embedded in a TOML *basic* string, where a backslash starts an
    # escape sequence; escape it (and any double quote) so Windows paths such as
    # C:\Users\... produce a valid file that parses back to the same path.
    catdb_toml = catdb_path.replace("\\", "\\\\").replace('"', '\\"')
    template_content = f"""[paths]
catdb="{catdb_toml}" # 本地数据库,默认家目录

## 数据库配置:
[database]
[database.ck]
# urls=["<host1>:<port1>", "<host2>:<port2>",]
# user="xxx"
# password="xxxxxx"
[database.jy]
# url="<host>:<port>"
# user="xxxx"
# password="xxxxxx"

## 视情况自由增加其他配置
"""
    try:
        # exist_ok avoids the race between the existence check and the mkdir.
        os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)
        # TOML files must be UTF-8; the template contains non-ASCII comments,
        # so never rely on the platform's default encoding.
        with open(CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(template_content)
    except OSError as e:
        # Best effort: a read-only home directory must not break `import ycat`.
        ylog.error(f"配置文件生成失败: {e}")
    else:
        ylog.info(f"生成配置文件: {CONFIG_PATH}")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_settings():
    """
    Load the user settings stored at ``CONFIG_PATH``.

    Returns
    -------
    settings: Dynaconf | None
        A ``Dynaconf`` object built from ``CONFIG_PATH``, or ``None`` when it
        could not be constructed (the failure is logged through ``ylog``).
    """
    try:
        return Dynaconf(settings_files=[CONFIG_PATH])
    except Exception as e:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate; callers treat None as "no settings".
        ylog.error(f"配置文件读取失败: {e}")
        return None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
HOME = USERHOME
# Default location of the local database; may be overridden by the settings file.
CATDB = os.path.join(HOME, "catdb")
# Load the settings file and let it override the default.
SETTINGS = get_settings()
if SETTINGS is not None:
    try:
        CATDB = SETTINGS.paths.catdb
    except (AttributeError, KeyError):
        # The settings file has no `[paths] catdb` entry (e.g. the user removed
        # it): keep the default instead of failing at import time.
        pass
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ======================== 本地数据库 catdb ========================
|
|
71
|
+
def tb_path(tb_name: str) -> str:
    """
    Return the full local path of a table.

    Parameters
    ----------
    tb_name: str
        Table name written as a relative path, e.g. ``a/b/c``.

    Returns
    -------
    full_abs_path: str
        ``tb_name`` joined onto the ``CATDB`` root, e.g. ``$HOME/catdb/a/b/c``.
        Note that ``os.path.join`` discards ``CATDB`` when ``tb_name`` is itself
        an absolute path.
    """
    full_abs_path = os.path.join(CATDB, tb_name)
    return full_abs_path
|
|
84
|
+
|
|
85
|
+
def put(df: pl.DataFrame, tb_name: str, partitions: Optional[list[str]] = None):
    """
    Write a DataFrame as parquet under the table's directory so that
    ``sql`` (which globs ``<table>/**/*.parquet``) can read it back.

    Parameters
    ----------
    df: polars.DataFrame
        The data to write.
    tb_name: str
        Table name; path notation such as ``a/b/c`` is supported.
    partitions: Optional[List[str]]
        Columns to partition by. ``None`` (the default) or an empty list
        writes a single unpartitioned file ``data.parquet``.

    Raises
    ------
    AssertionError
        If a partition column is not present in ``df``.
    """
    tbpath = tb_path(tb_name)
    # exist_ok replaces the exists()/makedirs()/FileExistsError dance.
    os.makedirs(tbpath, exist_ok=True)
    if partitions:
        for field in partitions:
            if field not in df.columns:
                # Explicit raise so the check survives `python -O`; the
                # exception type is kept for backward compatibility.
                raise AssertionError(f'dataframe must have Field `{field}`')
        # Partitioned dataset: polars expects a directory here.
        df.write_parquet(tbpath, partition_by=partitions)
    else:
        # Unpartitioned: the target must be a file, not the table directory.
        df.write_parquet(os.path.join(tbpath, "data.parquet"))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def sql(query: str):
    """
    Run a query against the local parquet tables through duckdb.

    Table names found by ``extract_table_names_from_sql`` are replaced, by
    plain text substitution over the whole query, with a
    ``read_parquet('<table path>/**/*.parquet', hive_partitioning = true)``
    call. See https://duckdb.org/docs/sql/query_syntax/select for the
    supported syntax.

    Parameters
    ----------
    query: str
        The query text.

    Returns
    -------
    result: polars.DataFrame
        The query result, converted from the Arrow table duckdb returns.
    """
    tbs = extract_table_names_from_sql(query)
    convertor = {
        tb: f"read_parquet('{tb_path(tb)}/**/*.parquet', hive_partitioning = true)"
        for tb in tbs
    }
    new_query = query
    # With no tables (e.g. "select 1") the joined pattern would be empty and
    # match everywhere, so only substitute when there is something to replace.
    if convertor:
        # Longest names first: regex alternation takes the first alternative
        # that matches, so "a/b" must not shadow "a/b/c".
        ordered = sorted(convertor, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(k) for k in ordered))
        new_query = pattern.sub(lambda m: convertor[m.group(0)], query)
    conn = duckdb.connect()
    try:
        conn.execute("PRAGMA disable_progress_bar;")
        conn.execute("PRAGMA threads=1;")
        arrow_table = conn.execute(new_query).fetch_arrow_table()
        return pl.from_arrow(arrow_table)
    finally:
        # Release the connection even when the query fails.
        conn.close()
|
|
136
|
+
|
|
137
|
+
def create_engine_ck(urls: list[str], user: str, password: str):
    """
    Build a zero-argument connection factory for ClickHouse.

    The returned callable is ``connect`` with ``urls``, ``user`` and
    ``password`` already bound; ``read_ck`` calls it to open a connection.
    """
    connection_factory = partial(connect, urls, user, password)
    return connection_factory
|
|
139
|
+
|
|
140
|
+
def read_ck(sql, eng)->pl.DataFrame:
    """
    Run ``sql`` on a connection obtained from ``eng`` and return the result.

    ``eng`` is called with no arguments and its result is used as a context
    manager (see ``create_engine_ck``); the query itself is delegated to
    ``query_polars``.
    """
    with eng() as conn:
        frame = query_polars(sql, conn)
    return frame
|
|
143
|
+
|
|
144
|
+
def create_engine_mysql(url, user, password, database):
    """
    Create a SQLAlchemy engine for MySQL using the ``pymysql`` driver.

    :param url: ``<host>:<port>``
    :param user: user name
    :param password: password
    :param database: database (schema) name
    :return: the engine returned by ``sqlalchemy.create_engine``

    All four values are interpolated into the connection URL verbatim, so any
    character that is special in a URL (``@``, ``:``, ``/`` ...) has to be
    percent-encoded by the caller.
    """
    connection_url = f"mysql+pymysql://{user}:{password}@{url}/{database}"
    return create_engine(connection_url)
|
|
154
|
+
|
|
155
|
+
def read_mysql(sql, eng) -> pl.DataFrame:
    """
    Run ``sql`` on a connection opened from ``eng`` and return the result.

    ``eng.connect()`` is used as a context manager and the query is delegated
    to ``polars.read_database``.
    """
    with eng.connect() as conn:
        frame = pl.read_database(sql, conn)
    return frame
|
ycat/dtype.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
---------------------------------------------
|
|
4
|
+
Created on 2024/11/4 下午1:20
|
|
5
|
+
@author: ZhangYundi
|
|
6
|
+
@email: yundi.xxii@outlook.com
|
|
7
|
+
---------------------------------------------
|
|
8
|
+
"""
|
|
9
|
+
import functools
|
|
10
|
+
import re
|
|
11
|
+
from typing import Any
|
|
12
|
+
import pyarrow as pa
|
|
13
|
+
import re # 正则解析 Decimal 类型
|
|
14
|
+
|
|
15
|
+
from polars._typing import PolarsDataType
|
|
16
|
+
from polars.datatypes import (
|
|
17
|
+
Binary,
|
|
18
|
+
Boolean,
|
|
19
|
+
Date,
|
|
20
|
+
Datetime,
|
|
21
|
+
Decimal,
|
|
22
|
+
Duration,
|
|
23
|
+
Float32,
|
|
24
|
+
Float64,
|
|
25
|
+
Int8,
|
|
26
|
+
Int16,
|
|
27
|
+
Int32,
|
|
28
|
+
Int64,
|
|
29
|
+
List,
|
|
30
|
+
Null,
|
|
31
|
+
String,
|
|
32
|
+
Time,
|
|
33
|
+
UInt8,
|
|
34
|
+
UInt16,
|
|
35
|
+
UInt32,
|
|
36
|
+
UInt64,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@functools.lru_cache(8)
def integer_dtype_from_nbits(
        bits: int,
        *,
        unsigned: bool,
        default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Look up the Polars integer dtype for a bit width and signedness.

    Returns ``default`` when no integer dtype has that width/signedness.

    Examples
    --------
    >>> integer_dtype_from_nbits(8, unsigned=False)
    Int8
    >>> integer_dtype_from_nbits(32, unsigned=True)
    UInt32
    """
    signed_types = (Int8, Int16, Int32, Int64)
    unsigned_types = (UInt8, UInt16, UInt32, UInt64)

    # Keyed by (width, is_unsigned), one signed and one unsigned entry per width.
    lookup = {}
    for width, signed_type, unsigned_type in zip(
            (8, 16, 32, 64), signed_types, unsigned_types
    ):
        lookup[(width, False)] = signed_type
        lookup[(width, True)] = unsigned_type

    found = lookup.get((bits, unsigned))
    if found is None:
        return default
    return found
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def timeunit_from_precision(precision: int | str | None) -> str | None:
    """
    Translate a fractional-second precision into a `time_unit` string.

    A numeric precision (or a string of digits) is rounded up to the next
    multiple of three and clamped to the range 3..9; a unit name
    ("s", "ms", "us", "ns", any case) is returned as-is, except that "s"
    becomes "ms". Falsy or unrecognised input gives None.

    Examples
    --------
    >>> timeunit_from_precision(3)
    'ms'
    >>> timeunit_from_precision(5)
    'us'
    >>> timeunit_from_precision(7)
    'ns'
    """
    from math import ceil

    if not precision:
        return None

    if isinstance(precision, str):
        if precision.isdigit():
            precision = int(precision)
        else:
            precision = precision.lower()
            if precision in ("s", "ms", "us", "ns"):
                return "ms" if precision == "s" else precision

    try:
        # a non-numeric string ends up here and fails the division
        rounded_up = int(ceil(precision / 3)) * 3  # type: ignore[operator]
    except TypeError:
        return None

    clamped = min(max(3, rounded_up), 9)
    return {3: "ms", 6: "us", 9: "ns"}.get(clamped)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def infer_dtype_from_database_typename(
        value: str,
        *,
        raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor `type_code` string value.

    ``Nullable(...)`` and ``LowCardinality(...)`` wrappers are removed first
    and the wrapped type name is inferred instead.

    Parameters
    ----------
    value
        The database type name, e.g. ``"VARCHAR(64)"``.
    raise_unmatched
        When True (the default) raise ``ValueError`` if no dtype can be
        inferred; when False return ``None`` instead.

    Examples
    --------
    >>> infer_dtype_from_database_typename("INT2")
    Int16
    >>> infer_dtype_from_database_typename("NVARCHAR")
    String
    >>> infer_dtype_from_database_typename("NUMERIC(10,2)")
    Decimal(precision=10, scale=2)
    >>> infer_dtype_from_database_typename("Nullable(Decimal(10, 2))")
    Decimal(precision=10, scale=2)
    >>> infer_dtype_from_database_typename("TIMESTAMP WITHOUT TZ")
    Datetime(time_unit='us', time_zone=None)
    """
    dtype: PolarsDataType | None = None

    # normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().replace("TYPE", "")

    # Unwrap wrapper types before looking for a modifier: the modifier regex
    # below cannot see through nested parentheses such as
    # 'NULLABLE(DECIMAL(10, 2))'. The caller's `raise_unmatched` is forwarded
    # so that a lenient caller is never surprised by an exception.
    wrapped = re.fullmatch(r"(?:NULLABLE|LOWCARDINALITY)\((.+)\)", value)
    if wrapped:
        return infer_dtype_from_database_typename(
            wrapped.group(1),
            raise_unmatched=raise_unmatched,
        )

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        modifier = value[value.find("(") + 1: -1]
        value = value.split("(")[0]
    elif (
            not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        modifier = value[value.find("[") + 1: -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if not value and modifier:
            nested = infer_dtype_from_database_typename(
                value=modifier,
                raise_unmatched=False,
            )
        else:
            if inner_value := infer_dtype_from_database_typename(
                    value[1:-1]
                    if (value[0], value[-1]) == ("<", ">")
                    else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
                    raise_unmatched=False,
            ):
                nested = inner_value
            elif modifier:
                nested = infer_dtype_from_database_typename(
                    value=modifier,
                    raise_unmatched=False,
                )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        dtype = (
            Float32
            if value == "FLOAT4"
               or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
            value.startswith(("INT", "UINT", "UNSIGNED"))
            or value.endswith(("INT", "SERIAL"))
            or ("INTEGER" in value)
            or value == "ROWID"
    ):
        sz: Any
        if "LARGE" in value or value.startswith("BIG") or value == "INT8":
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value == "INT2":
            sz = 16
        elif "TINY" in value:
            sz = 8
        else:
            sz = None

        # fall back to an explicit size modifier, eg: 'INT(32)'
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        if (
                ("U" in value and "MEDIUM" not in value)
                or ("UNSIGNED" in value)
                or value == "ROWID"
        ):
            dtype = integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # number types (note: 'number' alone is not that helpful and requires refinement)
    elif "NUMBER" in value and "CARDINAL" in value:
        dtype = UInt64

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
            any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
            or value.startswith(("STR", "CHAR", "BPCHAR", "NCHAR", "UTF"))
            or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # null dtype; odd, but valid
    elif value == "NULL":
        dtype = Null

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        unit = timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]
    else:
        # strip digits so that eg: 'DATE32' is recognised as 'DATE'
        value = re.sub(r"\d", "", value)
        if value in ("INTERVAL", "TIMEDELTA", "DURATION"):
            dtype = Duration
        elif value == "DATE":
            dtype = Date
        elif value == "TIME":
            dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype
|
|
263
|
+
|
|
264
|
+
# Static lookup from an exact ClickHouse type name to its Arrow type.
# Parameterised or nested names that are not listed here are resolved
# dynamically by `map_clickhouse_to_arrow`.
CLICKHOUSE_TO_ARROW_TYPE = {
    # integer types
    'Int8': pa.int8(),
    'Int16': pa.int16(),
    'Int32': pa.int32(),
    'Int64': pa.int64(),
    'UInt8': pa.uint8(),
    'UInt16': pa.uint16(),
    'UInt32': pa.uint32(),
    'UInt64': pa.uint64(),

    # floating point types
    'Float32': pa.float32(),
    'Float64': pa.float64(),

    # string types
    'String': pa.string(),
    'FixedString': pa.string(),  # no distinction is made between fixed and variable length

    # date and time types
    'Date': pa.date32(),  # a day count; date32 is Arrow's day-resolution date
    'Date32': pa.date32(),
    'DateTime': pa.timestamp('s'),  # ClickHouse DateTime has second resolution
    'DateTime64': pa.timestamp('ms'),  # default to milliseconds (adjust as needed)
    'UUID': pa.binary(16),  # a UUID is 16 bytes of binary data

    # boolean types ('Bool' is the name ClickHouse reports; 'Boolean' is an alias)
    'Bool': pa.bool_(),
    'Boolean': pa.bool_(),

    # array (nested) types
    'Array(Int8)': pa.list_(pa.int8()),
    'Array(Int16)': pa.list_(pa.int16()),
    'Array(Int32)': pa.list_(pa.int32()),
    'Array(Int64)': pa.list_(pa.int64()),
    'Array(UInt8)': pa.list_(pa.uint8()),
    'Array(UInt16)': pa.list_(pa.uint16()),
    'Array(UInt32)': pa.list_(pa.uint32()),
    'Array(UInt64)': pa.list_(pa.uint64()),
    'Array(Float32)': pa.list_(pa.float32()),
    'Array(Float64)': pa.list_(pa.float64()),
    'Array(String)': pa.list_(pa.string()),
    'Array(Date)': pa.list_(pa.date32()),
    'Array(DateTime)': pa.list_(pa.timestamp('s')),

    # composite types (tuples, enums, ...)
    # NOTE: Arrow has no Tuple type; tuples are represented as structs.
    'Tuple': pa.struct([]),  # placeholder: the field types have to be defined dynamically
    # enum types
    'Enum8': pa.string(),  # enums are mapped to strings
    'Enum16': pa.string(),

    # Map type
    'Map': pa.map_(pa.string(), pa.string()),  # default to string keys and values (adjust as needed)

    # Nullable types (ClickHouse's Nullable wrapper); Arrow types are nullable already
    'Nullable(Int8)': pa.int8(),
    'Nullable(Int16)': pa.int16(),
    'Nullable(Int32)': pa.int32(),
    'Nullable(Int64)': pa.int64(),
    'Nullable(UInt8)': pa.uint8(),
    'Nullable(UInt16)': pa.uint16(),
    'Nullable(UInt32)': pa.uint32(),
    'Nullable(UInt64)': pa.uint64(),
    'Nullable(Float32)': pa.float32(),
    'Nullable(Float64)': pa.float64(),
    'Nullable(String)': pa.string(),
    'Nullable(Date)': pa.date32(),
    'Nullable(DateTime)': pa.timestamp('s'),
    'Nullable(UUID)': pa.binary(16),
}
|
|
334
|
+
|
|
335
|
+
def map_clickhouse_decimal(ch_type: str) -> pa.DataType:
    """
    Map a ClickHouse Decimal type to the matching Arrow decimal type.

    Supported spellings:

    - ``Decimal(P, S)`` -> precision ``P``, scale ``S``
    - ``Decimal(P)`` -> precision ``P``, scale 0
    - ``Decimal32(S)`` / ``Decimal64(S)`` / ``Decimal128(S)`` /
      ``Decimal256(S)`` -> the single argument is the *scale*; the precision
      is fixed by the width (9, 18, 38 and 76 digits respectively)
    - a width-suffixed name with two arguments is read as ``(P, S)``, as before

    :param ch_type: ClickHouse Decimal type description, e.g. 'Decimal(10, 2)'
        or 'Decimal64(4)'
    :return: ``pa.decimal128`` for up to 38 digits of precision, otherwise
        ``pa.decimal256``
    :raises ValueError: if ``ch_type`` is not a recognised Decimal spelling
    """
    # Maximum precision of each fixed-width alias.
    precision_by_width = {"32": 9, "64": 18, "128": 38, "256": 76}

    matched = re.match(
        r"Decimal(32|64|128|256)?\(\s*(\d+)\s*(?:,\s*(\d+)\s*)?\)", ch_type
    )
    if not matched:
        raise ValueError(f"Unsupported ClickHouse Decimal type: {ch_type}")

    width, first, second = matched.groups()
    if second is not None:
        precision, scale = int(first), int(second)
    elif width is None:
        # Decimal(P): the scale defaults to 0
        precision, scale = int(first), 0
    else:
        # DecimalN(S): the only argument is the scale
        precision, scale = precision_by_width[width], int(first)

    # decimal128 holds at most 38 digits; wider values need decimal256.
    if precision > 38:
        return pa.decimal256(precision, scale)
    return pa.decimal128(precision, scale)
|
|
355
|
+
|
|
356
|
+
def _split_top_level_args(args: str) -> list[str]:
    """Split a type-argument string on the commas that are outside any parentheses or single quotes."""
    parts = []
    current = []
    depth = 0
    in_quote = False
    previous = ""
    for char in args:
        if char == "'" and previous != "\\":
            in_quote = not in_quote
        elif not in_quote:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
            elif char == "," and depth == 0:
                parts.append("".join(current).strip())
                current = []
                previous = char
                continue
        current.append(char)
        previous = char
    tail = "".join(current).strip()
    if tail or parts:
        parts.append(tail)
    return parts


def map_clickhouse_to_arrow(ch_type: str) -> pa.DataType:
    """
    Map a ClickHouse type name to an Arrow type.

    Exact names are looked up in ``CLICKHOUSE_TO_ARROW_TYPE``; Decimal names go
    through ``map_clickhouse_decimal``; everything else is resolved
    structurally:

    - ``Array(T)`` -> list of ``T``
    - ``Nullable(T)`` / ``LowCardinality(T)`` -> ``T``
    - ``Tuple(T1, T2, ...)`` -> struct with fields ``field0``, ``field1``, ...
    - ``Map(K, V)`` -> map from ``K`` to ``V``
    - ``DateTime64(p[, tz])`` -> timestamp whose unit covers ``p`` digits
    - ``DateTime(tz)`` -> timestamp in seconds
    - ``FixedString(N)``, ``Enum8(...)``, ``Enum16(...)`` -> string

    Time zones are ignored, so timestamps are zone-less like the ``DateTime``
    entry of the static table.

    :param ch_type: the ClickHouse type name
    :return: the corresponding Arrow type
    :raises ValueError: if the type (or one nested inside it) is not supported
    """
    ch_type = ch_type.strip()

    # exact names map directly
    if ch_type in CLICKHOUSE_TO_ARROW_TYPE:
        return CLICKHOUSE_TO_ARROW_TYPE[ch_type]

    # Decimal types
    if ch_type.startswith("Decimal"):
        return map_clickhouse_decimal(ch_type)

    if ch_type.endswith(")"):
        name, _, rest = ch_type.partition("(")
        inner = rest[:-1]  # text between the outermost parentheses

        if name == "Array":
            return pa.list_(map_clickhouse_to_arrow(inner))

        # wrappers that do not change the logical type
        if name in ("Nullable", "LowCardinality"):
            return map_clickhouse_to_arrow(inner)

        # Arguments are split on top-level commas only, so nested types such
        # as Tuple(Decimal(10, 2), String) keep their own commas.
        if name == "Tuple":
            return pa.struct([
                ('field' + str(i), map_clickhouse_to_arrow(t))
                for i, t in enumerate(_split_top_level_args(inner))
            ])

        if name == "Map":
            key_and_value = _split_top_level_args(inner)
            if len(key_and_value) != 2:
                raise ValueError(f"Unsupported ClickHouse type: {ch_type}")
            key_type, value_type = key_and_value
            return pa.map_(map_clickhouse_to_arrow(key_type), map_clickhouse_to_arrow(value_type))

        if name == "DateTime64":
            arguments = _split_top_level_args(inner)
            if not arguments or not arguments[0].isdigit():
                raise ValueError(f"Unsupported ClickHouse type: {ch_type}")
            precision = int(arguments[0])
            # pick the coarsest Arrow unit that still holds `precision` digits
            if precision == 0:
                unit = 's'
            elif precision <= 3:
                unit = 'ms'
            elif precision <= 6:
                unit = 'us'
            else:
                unit = 'ns'
            return pa.timestamp(unit)

        if name == "DateTime":
            return pa.timestamp('s')

        # parameterised forms of types the static table maps to strings
        if name in ("FixedString", "Enum8", "Enum16"):
            return pa.string()

    raise ValueError(f"Unsupported ClickHouse type: {ch_type}")
|
|
389
|
+
|
ycat/parse.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
---------------------------------------------
|
|
4
|
+
Created on 2024/11/6 下午7:25
|
|
5
|
+
@author: ZhangYundi
|
|
6
|
+
@email: yundi.xxii@outlook.com
|
|
7
|
+
---------------------------------------------
|
|
8
|
+
"""
|
|
9
|
+
import sqlparse
|
|
10
|
+
import re
|
|
11
|
+
|
|
12
|
+
def format_sql(sql_content):
    """Normalize a SQL statement and strip its comments; takes and returns a string."""
    formatting_options = {"reindent": True, "strip_comments": True}
    return sqlparse.format(sql_content, **formatting_options)
|
|
16
|
+
|
|
17
|
+
def extract_temp_tables(with_clause):
    """
    Extract the temporary (CTE) table names from a WITH clause.

    A name is any word followed by ``AS (``, optionally with
    ``[NOT] MATERIALIZED`` between ``AS`` and the parenthesis; matching is
    case-insensitive. Returns the names as a list, in order of appearance.
    """
    # Whitespace is required before AS: with `\s*` a call such as `has(` was
    # split into the "name" `h` followed by `as(`.
    cte_pattern = r'\b(\w+)\s+as(?:\s+(?:not\s+)?materialized)?\s*\('
    return re.findall(cte_pattern, with_clause, re.IGNORECASE)
|
|
21
|
+
|
|
22
|
+
def extract_table_names_from_sql(sql_query):
    """
    Extract the table names referenced by a SQL text.

    A table name is the token following ``FROM`` or ``JOIN``. Anything before
    the last ``.`` (a namespace) is dropped, as are quote characters and
    ``;``. Names defined by a WITH clause (CTEs) are excluded.

    The names keep the case they have in ``sql_query``, so they can be
    located in the original text again; CTE names are compared
    case-insensitively.

    Returns a sorted list of unique names.
    """
    # pattern for the token following FROM / JOIN
    table_name_pattern = r'\bFROM\s+([^\s\(\)\,]+)|\bJOIN\s+([^\s\(\)\,]+)'

    table_names = set()
    # lower-cased CTE names collected from every statement
    cte_names = set()

    for statement in sqlparse.parse(sql_query):
        statement_str = str(statement)

        # Blank out substring(... from ...) / extract(... from ...) calls: the
        # FROM inside them is not a table reference.
        statement_str = re.sub(
            r'(?:substring|extract)\s*\([\s\S]*?\)', '', statement_str,
            flags=re.IGNORECASE,
        )

        for match in re.findall(table_name_pattern, statement_str, re.IGNORECASE):
            # only one of the two alternatives is non-empty per match
            for name in match:
                if not name:
                    continue
                # keep only the last part of a namespaced name
                table_name = name.split('.')[-1]
                # drop quoting characters and statement terminators
                table_name = re.sub(r'("|`|\'|;)', '', table_name)
                if table_name:
                    table_names.add(table_name)

        # WITH as a whole word only, so "width" or "without" do not trigger it
        if re.search(r'\bwith\b', statement_str, re.IGNORECASE):
            cte_names.update(name.lower() for name in extract_temp_tables(statement_str))

    return sorted(name for name in table_names if name.lower() not in cte_names)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
if __name__ == '__main__':
    # Manual smoke check: print the table names found in a sample query.
    sample_query = "select * from c.a/b/c/d"
    print(extract_table_names_from_sql(sample_query))
|
|
66
|
+
|