yearning-cli 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
"""Launch the yearning-cli command-line interface."""

from yearning_cli import cli

cli.main()
yearning_cli/cli.py ADDED
@@ -0,0 +1,459 @@
1
+ """CLI entry point for yearning-cli."""
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import sys
7
+ import tomllib
8
+ from pathlib import Path
9
+
10
+ from yearning_cli.client import TokenExpiredError, YearningClient
11
+ from mysql_proxy.server import serve
12
+
13
# Base URL of the Yearning server. The fallback is a site-specific private
# address; set the YEARNING_URL environment variable to use another server.
BASE_URL = os.getenv("YEARNING_URL", "http://192.168.1.135")
# Per-user settings file made of simple "key: value" lines.
CONFIG_FILE = os.path.expanduser("~/.sqlrc")
15
+
16
+
17
+ def load_config() -> dict:
18
+ """Load config from ~/.sqlrc (key: value format)."""
19
+ cfg = {}
20
+ if os.path.exists(CONFIG_FILE):
21
+ with open(CONFIG_FILE, encoding="utf-8") as f:
22
+ for line in f:
23
+ line = line.strip()
24
+ if not line or line.startswith("#"):
25
+ continue
26
+ if ":" in line:
27
+ key, val = line.split(":", 1)
28
+ cfg[key.strip()] = val.strip()
29
+ return cfg
30
+
31
+
32
# Settings read once, at import time, from the ~/.sqlrc file.
CONFIG = load_config()
DEFAULT_SOURCE = CONFIG.get("source", "")
DEFAULT_DATABASE = CONFIG.get("database", "")

# Credentials: the environment variable takes precedence, then the config
# keys in the order listed; the first non-empty value wins ("" if none).
USERNAME = next(
    (
        candidate
        for candidate in (
            os.environ.get("YEARNING_USER", ""),
            CONFIG.get("yearning_user", ""),
            CONFIG.get("username", ""),
            CONFIG.get("user", ""),
        )
        if candidate
    ),
    "",
)
PASSWORD = next(
    (
        candidate
        for candidate in (
            os.environ.get("YEARNING_PASS", ""),
            CONFIG.get("yearning_pass", ""),
            CONFIG.get("yearning_password", ""),
            CONFIG.get("password", ""),
            CONFIG.get("pass", ""),
        )
        if candidate
    ),
    "",
)

# -v / --verbose is a global flag handled outside argparse: detect it here and
# drop every occurrence from sys.argv so the subcommand parsers never see it.
VERBOSE = any(flag in sys.argv for flag in ("-v", "--verbose"))
if VERBOSE:
    sys.argv = list(filter(lambda arg: arg not in ("-v", "--verbose"), sys.argv))
53
+
54
+
55
+ def get_version() -> str:
56
+ """Return package version from pyproject.toml."""
57
+ pyproject = Path(__file__).resolve().parents[1] / "pyproject.toml"
58
+ if not pyproject.exists():
59
+ raise RuntimeError(f"缺少版本配置文件: {pyproject}")
60
+
61
+ with open(pyproject, "rb") as f:
62
+ data = tomllib.load(f)
63
+
64
+ project_version = data.get("project", {}).get("version", "")
65
+ if not project_version:
66
+ raise RuntimeError("pyproject.toml 缺少 project.version")
67
+
68
+ return project_version
69
+
70
+
71
def _get_client() -> YearningClient:
    """Build a YearningClient from the module settings and authenticate it.

    If authentication raises TokenExpiredError, its message is printed and the
    process exits with status 1.
    """
    settings = {"verbose": VERBOSE, "username": USERNAME, "password": PASSWORD}
    yc = YearningClient(BASE_URL, **settings)
    try:
        yc.ensure_authenticated()
    except TokenExpiredError as exc:
        print(exc)
        raise SystemExit(1)
    return yc
84
+
85
+
86
+ def _format_table(columns: list, rows: list, query_time: int = 0) -> None:
87
+ """Format query results as a table."""
88
+ if not columns:
89
+ if query_time:
90
+ print(f"(空结果集) [{query_time}ms]")
91
+ else:
92
+ print("(空结果集)")
93
+ return
94
+
95
+ # Calculate column widths
96
+ col_widths = [len(str(c)) for c in columns]
97
+ for row in rows:
98
+ for i, val in enumerate(row):
99
+ col_widths[i] = min(max(col_widths[i], len(str(val))), 60)
100
+
101
+ # Print header
102
+ sep = "+" + "+".join("-" * (w + 2) for w in col_widths) + "+"
103
+ print(sep)
104
+ print("|" + "|".join(f" {str(c).ljust(w)} " for c, w in zip(columns, col_widths)) + "|")
105
+ print(sep)
106
+ for row in rows:
107
+ print("|" + "|".join(f" {str(v).ljust(w)} " for v, w in zip(row, col_widths)) + "|")
108
+ print(sep)
109
+ ts = f", {query_time}ms" if query_time else ""
110
+ print(f"({len(rows)} 行{ts})")
111
+
112
+
113
def cmd_login(args: argparse.Namespace) -> None:
    """Log in to Yearning with the configured or prompted credentials.

    USERNAME / PASSWORD (environment or ~/.sqlrc) are used when set; whichever
    is missing is prompted for. The password prompt uses ``getpass`` so the
    password is not echoed to the terminal. Exits with status 1 on failure.
    """
    from getpass import getpass

    client = YearningClient(BASE_URL, verbose=VERBOSE)
    user = USERNAME or input("用户名: ")
    pwd = PASSWORD or getpass("密码: ")
    try:
        # The returned token is not used here. Persisting it is presumably done
        # inside YearningClient.login (the message below says it is saved) — confirm.
        client.login(user, pwd)
    except Exception as e:
        print(f"登录失败: {e}")
        sys.exit(1)
    print("登录成功! Token 已保存")
123
+
124
+
125
def cmd_sources(args: argparse.Namespace) -> None:
    """List data sources visible to the current user, optionally filtered by ``--idc``.

    Errors are not caught here. They propagate to ``main``, which prints
    ``错误: ...`` and exits with status 1, so scripts can detect failures
    (the previous local handler printed the same text but exited 0).
    """
    client = _get_client()
    sources = client.list_sources(idc=args.idc)
    if not sources:
        print("无数据源")
        return

    print(f"\n{'IDC':<15} {'名称':<25} {'Source ID'}")
    print("-" * 80)
    for src in sources:
        if isinstance(src, dict):
            print(
                f"{src.get('idc', ''):<15} "
                f"{src.get('source', ''):<25} "
                f"{src.get('source_id', '')}"
            )
        else:
            # Entries of any other shape are printed verbatim.
            print(f" {src}")
146
+
147
+
148
def cmd_databases(args: argparse.Namespace) -> None:
    """List the databases of a data source (name or id in ``args.source_id``).

    Exits with status 1 when no source is given, with the same message
    ``cmd_query`` uses. Other errors propagate to ``main``, which reports them
    and exits 1 instead of the previous silent exit 0.
    """
    if not args.source_id:
        print("错误: 未指定数据源。请在 ~/.sqlrc 配置 source 或在命令行指定")
        sys.exit(1)

    client = _get_client()
    source_id = _resolve_source_id(client, args.source_id)
    dbs = client.list_databases(source_id)
    if not dbs:
        print("(无数据库)")
        return

    print("\n数据库列表:")
    for db in dbs:
        print(f" {db}")
162
+
163
+
164
def cmd_tables(args: argparse.Namespace) -> None:
    """List the tables of ``args.database`` on data source ``args.source_id``.

    Exits with status 1 when the source or database is missing, with the same
    messages ``cmd_query`` uses. Other errors propagate to ``main``, which
    reports them and exits 1 instead of the previous silent exit 0.
    """
    if not args.source_id:
        print("错误: 未指定数据源。请在 ~/.sqlrc 配置 source 或在命令行指定")
        sys.exit(1)
    if not args.database:
        print("错误: 未指定数据库。请在 ~/.sqlrc 配置 database 或在命令行指定")
        sys.exit(1)

    client = _get_client()
    source_id = _resolve_source_id(client, args.source_id)
    tables = client.list_tables(source_id, args.database)
    if not tables:
        print("(无表)")
        return

    print(f"\n表列表 ({args.database}):")
    for t in tables:
        print(f" {t}")
178
+
179
+
180
+ def _export_result(result: dict, database: str, fmt: str) -> None:
181
+ """Export query result to ~/Downloads/{database}_{timestamp}.{fmt}."""
182
+ from datetime import datetime
183
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
184
+ filename = f"{database}_{ts}.{fmt}"
185
+ filepath = os.path.join(os.path.expanduser("~"), "Downloads", filename)
186
+
187
+ if fmt == "json":
188
+ import json
189
+ data = {
190
+ "columns": result["columns"],
191
+ "rows": [dict(zip(result["columns"], row)) for row in result["rows"]],
192
+ "query_time_ms": result["query_time"],
193
+ "row_count": len(result["rows"]),
194
+ }
195
+ with open(filepath, "w", encoding="utf-8") as f:
196
+ json.dump(data, f, ensure_ascii=False, indent=2)
197
+ elif fmt == "csv":
198
+ import csv
199
+ with open(filepath, "w", newline="", encoding="utf-8-sig") as f:
200
+ writer = csv.writer(f)
201
+ writer.writerow(result["columns"])
202
+ writer.writerows(result["rows"])
203
+ elif fmt == "xlsx":
204
+ try:
205
+ import openpyxl
206
+ except ImportError:
207
+ print("错误: 导出 xlsx 需要 openpyxl,请运行: pip install openpyxl")
208
+ sys.exit(1)
209
+ wb = openpyxl.Workbook()
210
+ ws = wb.active
211
+ ws.append(result["columns"])
212
+ for row in result["rows"]:
213
+ ws.append(row)
214
+ wb.save(filepath)
215
+
216
+ print(f"已导出 {len(result['rows'])} 行到 {filepath}")
217
+
218
+
219
def cmd_query(args: argparse.Namespace) -> None:
    """Execute one SQL statement and print or export the result.

    All arguments (source, database, export format) are validated before any
    network call, so a typo does not cost a query order or a query. Then a
    query order is created, the statement runs, and the result is exported
    (``-o``), printed as JSON (``--json``) or printed as a table.

    Progress messages go to stderr so ``--json`` output on stdout stays
    machine-parseable. Every failure exits with status 1. Unexpected
    exceptions propagate to ``main``, which reports them and exits 1.
    """
    if not args.source_id:
        print("错误: 未指定数据源。请在 ~/.sqlrc 配置 source 或在命令行指定")
        sys.exit(1)
    if not args.database:
        print("错误: 未指定数据库。请在 ~/.sqlrc 配置 database 或在命令行指定")
        sys.exit(1)

    fmt = None
    if args.export is not None:
        fmt = args.export.lstrip(".").lower()
        if fmt not in ("csv", "json", "xlsx"):
            print(f"错误: 不支持的格式 '{fmt}',请使用 csv, json 或 xlsx")
            sys.exit(1)

    client = _get_client()
    source_id = _resolve_source_id(client, args.source_id)

    # Auto-create a query order before executing the statement.
    print("正在创建查询工单...", end="", flush=True, file=sys.stderr)
    if not client.create_query_order(source_id):
        print(" 失败", file=sys.stderr)
        print("错误: 无法创建查询工单,请先在 Web 界面确认查询权限")
        sys.exit(1)
    print(" 成功", file=sys.stderr)

    result = client.execute_query(
        source_id, args.database, args.sql, timeout=args.timeout
    )
    if not result["success"]:
        print(f"错误: {result['error']}")
        sys.exit(1)

    if fmt is not None:
        _export_result(result, args.database, fmt)
        return

    if args.json:
        import json

        output = {
            "columns": result["columns"],
            "rows": [dict(zip(result["columns"], row)) for row in result["rows"]],
            "query_time_ms": result["query_time"],
            "row_count": len(result["rows"]),
        }
        print(json.dumps(output, ensure_ascii=False, indent=2))
    else:
        _format_table(result["columns"], result["rows"], result["query_time"])
270
+
271
+
272
def _resolve_source_id(client: YearningClient, name_or_id: str) -> str:
    """Map a data-source name (or id) to its ``source_id``.

    Input longer than 30 characters that contains a hyphen is treated as an
    id already and returned without contacting the server. Otherwise the
    first dict entry from ``client.list_sources()`` whose ``source`` or
    ``source_id`` equals *name_or_id* supplies the result. If nothing
    matches, the input is returned unchanged.
    """
    already_id = "-" in name_or_id and len(name_or_id) > 30
    if already_id:
        return name_or_id
    hits = (
        entry["source_id"]
        for entry in client.list_sources()
        if isinstance(entry, dict)
        and name_or_id in (entry.get("source"), entry.get("source_id"))
    )
    return next(hits, name_or_id)
283
+
284
+
285
def cmd_shell(args: argparse.Namespace) -> None:
    """Run an interactive SQL shell against one data source.

    Meta commands:
      - ``\\s <db>`` switches the current database.
      - ``\\d`` (or bare ``\\s``) lists databases.
      - ``\\t`` lists tables of the current database.
      - ``exit`` / ``quit`` / ``\\q`` leave the shell.

    Any other input is SQL and is collected until a line ends with ``;``.
    While a statement is being continued, ``exit``/``quit`` or Ctrl-C discards
    it instead of executing the incomplete text, and Ctrl-D leaves the shell.
    Exits with status 1 when no data source is given.
    """
    if not args.source_id:
        print("错误: 未指定数据源。请在 ~/.sqlrc 配置 source 或在命令行指定")
        sys.exit(1)

    client = _get_client()
    source_id = _resolve_source_id(client, args.source_id)
    current_db = args.database

    # Create a query order up front; failure is only a warning here.
    print("正在创建查询工单...")
    if not client.create_query_order(source_id):
        print("警告: 查询工单创建失败,可能无法执行查询")

    print("\nYearning SQL Shell")
    print(f"数据源: {args.source_id} ({source_id[:12]}...) | 数据库: {current_db}")
    print("输入 SQL 执行 | \\s 切换库 | \\d 查看库 | \\t 查看表 | exit 退出\n")

    while True:
        try:
            sql = input(f"sql({current_db})> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            return

        if not sql:
            continue
        command = sql.lower()
        if command in ("exit", "quit", "\\q"):
            print("Bye!")
            return

        if command.startswith("\\s "):
            current_db = sql[3:].strip()
            print(f" 切换到: {current_db}")
            continue

        if command in ("\\d", "\\s"):
            try:
                for db in client.list_databases(source_id):
                    print(f" {db}")
            except Exception as e:
                print(f" 错误: {e}")
            continue

        if command == "\\t":
            try:
                for t in client.list_tables(source_id, current_db):
                    print(f" {t}")
            except Exception as e:
                print(f" 错误: {e}")
            continue

        # Multi-line input: keep reading until the statement ends with ";".
        # Cancelling must discard the buffer, never run a partial statement.
        cancelled = False
        while not sql.endswith(";"):
            try:
                line = input(" ...> ").strip()
            except KeyboardInterrupt:
                print()
                cancelled = True
                break
            except EOFError:
                print("\nBye!")
                return
            if line.lower() in ("exit", "quit"):
                cancelled = True
                break
            sql += " " + line
        if cancelled:
            print(" 已取消")
            continue

        sql = sql.rstrip(";").strip()
        if not sql:
            continue

        try:
            result = client.execute_query(source_id, current_db, sql)
            if not result["success"]:
                print(f"错误: {result['error']}")
                continue
            _format_table(result["columns"], result["rows"], result["query_time"])
        except Exception as e:
            print(f"错误: {e}")
357
+
358
+
359
def cmd_proxy(args: argparse.Namespace) -> None:
    """Run the MySQL-protocol proxy server in the foreground.

    Root logging is configured at DEBUG when -v/--verbose was given and at
    INFO otherwise. The proxy config is built from the command-line options,
    falling back to BASE_URL and the ~/.sqlrc source/database defaults, and
    is then handed to ``serve``.
    """
    # NOTE(review): the default --host is 0.0.0.0 (all interfaces). Whether an
    # empty --user/--password disables client authentication depends on
    # mysql_proxy.server — confirm before exposing the port.
    logging.basicConfig(
        level=logging.DEBUG if VERBOSE else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )
    proxy_config = dict(
        base_url=args.url or BASE_URL,
        source_name=args.source or DEFAULT_SOURCE,
        source_id="",
        database=args.database or DEFAULT_DATABASE,
        user=args.user or "",
        password=args.password or "",
    )
    serve(host=args.host, port=args.port, config=proxy_config)
378
+
379
+
380
class _LazyVersionAction(argparse.Action):
    """``--version`` action that looks the version up only when the flag is used.

    argparse's built-in ``version`` action needs the string when the parser is
    built, which would make every subcommand fail whenever the version lookup
    fails.
    """

    def __init__(self, option_strings, dest=argparse.SUPPRESS, default=argparse.SUPPRESS, help=None):
        super().__init__(option_strings=option_strings, dest=dest, default=default, nargs=0, help=help)

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            version = get_version()
        except RuntimeError as e:
            parser.exit(1, f"错误: {e}\n")
        print(f"{parser.prog} {version}")
        parser.exit()


def main() -> None:
    """Parse the command line and dispatch to the matching ``cmd_*`` handler.

    Exceptions escaping a handler are reported as ``错误: <message>`` with exit
    status 1. Ctrl-C exits with status 130 without a traceback.
    """
    parser = argparse.ArgumentParser(
        prog="sql",
        description="Yearning MySQL Audit Platform CLI",
        # -v/--verbose is stripped from sys.argv at import time (see VERBOSE),
        # so it is documented here rather than registered as an option.
        epilog="全局选项: -v/--verbose 显示详细调试信息",
    )
    parser.add_argument(
        "--version",
        action=_LazyVersionAction,
        help="show program's version number and exit",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # login
    sub.add_parser("login", help="登录 Yearning")

    # sources
    p_src = sub.add_parser("sources", help="列出数据源")
    p_src.add_argument("--idc", default="", help="按 IDC 筛选")

    # databases
    p_db = sub.add_parser("databases", help="列出数据库")
    p_db.add_argument("source_id", nargs="?", default=DEFAULT_SOURCE,
                      help=f"数据源 ID (默认: {DEFAULT_SOURCE or '未配置'})")

    # tables
    p_tbl = sub.add_parser("tables", help="列出表")
    p_tbl.add_argument("source_id", nargs="?", default=DEFAULT_SOURCE,
                       help=f"数据源 ID (默认: {DEFAULT_SOURCE or '未配置'})")
    p_tbl.add_argument("database", nargs="?", default=DEFAULT_DATABASE,
                       help=f"数据库名 (默认: {DEFAULT_DATABASE or '未配置'})")

    # query
    p_q = sub.add_parser("query", help="执行 SQL 查询")
    p_q.add_argument("sql", help="SQL 语句")
    p_q.add_argument("source_id", nargs="?", default=DEFAULT_SOURCE,
                     help=f"数据源 ID 或名称 (默认: {DEFAULT_SOURCE or '未配置'})")
    p_q.add_argument("database", nargs="?", default=DEFAULT_DATABASE,
                     help=f"数据库名 (默认: {DEFAULT_DATABASE or '未配置'})")
    p_q.add_argument("--json", action="store_true", help="以 JSON 格式输出")
    p_q.add_argument("--timeout", type=int, default=30, help="查询超时秒数 (默认: 30)")
    p_q.add_argument("-o", "--export", nargs="?", const="xlsx", metavar="FMT",
                     help="导出结果到 ~/Downloads/ (格式: xlsx/csv/json, 默认 xlsx)")

    # shell
    p_sh = sub.add_parser("shell", help="交互式 SQL Shell")
    p_sh.add_argument("source_id", nargs="?", default=DEFAULT_SOURCE,
                      help=f"数据源 ID (默认: {DEFAULT_SOURCE or '未配置'})")
    p_sh.add_argument("database", nargs="?", default=DEFAULT_DATABASE,
                      help=f"数据库名 (默认: {DEFAULT_DATABASE or '未配置'})")

    # proxy
    p_pr = sub.add_parser("proxy", help="启动 MySQL 协议代理服务器")
    p_pr.add_argument("--host", default="0.0.0.0", help="绑定地址 (默认: 0.0.0.0;IPv6 可用 ::)")
    p_pr.add_argument("--port", type=int, default=3307, help="监听端口 (默认: 3307)")
    p_pr.add_argument("--source", default="", help="数据源名称或 ID")
    p_pr.add_argument("--database", default="", help="默认数据库名")
    p_pr.add_argument("--url", default="", help="Yearning URL")
    p_pr.add_argument("--user", default="", help="要求认证的用户名")
    p_pr.add_argument("--password", default="", help="要求的认证密码")

    args = parser.parse_args()

    dispatch = {
        "login": cmd_login,
        "sources": cmd_sources,
        "databases": cmd_databases,
        "tables": cmd_tables,
        "query": cmd_query,
        "shell": cmd_shell,
        "proxy": cmd_proxy,
    }
    try:
        dispatch[args.command](args)
    except KeyboardInterrupt:
        # Quiet Ctrl-C with the conventional 128 + SIGINT exit status.
        print()
        sys.exit(130)
    except Exception as e:
        print(f"错误: {e}")
        sys.exit(1)
456
+
457
+
458
if __name__ == "__main__":
    # Allow running this module directly as a script; main() returns None,
    # so the exit status stays 0 unless a handler exits explicitly.
    raise SystemExit(main())