quantcli 0.1.7__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. {quantcli-0.1.7/quantcli.egg-info → quantcli-0.1.8}/PKG-INFO +1 -1
  2. {quantcli-0.1.7 → quantcli-0.1.8}/pyproject.toml +1 -1
  3. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/cli.py +9 -9
  4. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/mysql.py +197 -138
  5. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/sync/gm_fundamental.py +14 -8
  6. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/__init__.py +13 -0
  7. quantcli-0.1.8/quantcli/utils/env.py +77 -0
  8. {quantcli-0.1.7 → quantcli-0.1.8/quantcli.egg-info}/PKG-INFO +1 -1
  9. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli.egg-info/SOURCES.txt +1 -0
  10. {quantcli-0.1.7 → quantcli-0.1.8}/LICENSE +0 -0
  11. {quantcli-0.1.7 → quantcli-0.1.8}/README.md +0 -0
  12. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/core/__init__.py +0 -0
  13. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/core/backtest.py +0 -0
  14. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/core/data.py +0 -0
  15. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/core/factor.py +0 -0
  16. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/__init__.py +0 -0
  17. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/akshare.py +0 -0
  18. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/baostock.py +0 -0
  19. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/base.py +0 -0
  20. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/cache.py +0 -0
  21. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/fundamentals/__init__.py +0 -0
  22. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/fundamentals/provider.py +0 -0
  23. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/mixed.py +0 -0
  24. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/sync/__init__.py +0 -0
  25. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/sync/akshare.py +0 -0
  26. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/sync/base.py +0 -0
  27. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/datasources/sync/gm.py +0 -0
  28. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/__init__.py +0 -0
  29. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/base.py +0 -0
  30. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/compute.py +0 -0
  31. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/loader.py +0 -0
  32. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/pipeline.py +0 -0
  33. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/ranking.py +0 -0
  34. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/ranking_executor.py +0 -0
  35. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/screening.py +0 -0
  36. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/factors/screening_executor.py +0 -0
  37. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/models/bar.py +0 -0
  38. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/parser/__init__.py +0 -0
  39. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/parser/constants.py +0 -0
  40. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/parser/formula.py +0 -0
  41. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/logger.py +0 -0
  42. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/path.py +0 -0
  43. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/symbol_utils.py +0 -0
  44. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/time.py +0 -0
  45. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli/utils/validate.py +0 -0
  46. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli.egg-info/dependency_links.txt +0 -0
  47. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli.egg-info/entry_points.txt +0 -0
  48. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli.egg-info/requires.txt +0 -0
  49. {quantcli-0.1.7 → quantcli-0.1.8}/quantcli.egg-info/top_level.txt +0 -0
  50. {quantcli-0.1.7 → quantcli-0.1.8}/setup.cfg +0 -0
  51. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_akshare_integration.py +0 -0
  52. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_builtin_factors.py +0 -0
  53. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_cli.py +0 -0
  54. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_datasources.py +0 -0
  55. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_factors.py +0 -0
  56. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_gm_executors.py +0 -0
  57. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_mixed_datasource.py +0 -0
  58. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_multi_factor.py +0 -0
  59. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_pipeline_integration.py +0 -0
  60. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_symbol_utils.py +0 -0
  61. {quantcli-0.1.7 → quantcli-0.1.8}/tests/test_time.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quantcli
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: 面向AI的多因子量化选股策略挖掘工具,AI Agent 友好 CLI
5
5
  Author-email: QuantCLI Team <quantcli@example.com>
6
6
  Project-URL: repository, https://github.com/wumu2013/quantcli
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "quantcli"
7
- version = "0.1.7"
7
+ version = "0.1.8"
8
8
  description = "面向AI的多因子量化选股策略挖掘工具,AI Agent 友好 CLI"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.9"
@@ -1059,11 +1059,11 @@ def backtest_run(ctx, strategy, symbol, start, end, as_of, capital, fee, datasou
1059
1059
  "\nBacktest must use MySQL for efficient batch data loading.", err=True
1060
1060
  )
1061
1061
  click.echo("Please configure MySQL environment variables:", err=True)
1062
- click.echo(" export MYSQL_HOST=localhost", err=True)
1063
- click.echo(" export MYSQL_PORT=3306", err=True)
1064
- click.echo(" export MYSQL_USER=root", err=True)
1065
- click.echo(" export MYSQL_PASSWORD=xxx", err=True)
1066
- click.echo(" export MYSQL_DATABASE=quantcli", err=True)
1062
+ click.echo(" export QUANT_MYSQL_HOST=localhost", err=True)
1063
+ click.echo(" export QUANT_MYSQL_PORT=3306", err=True)
1064
+ click.echo(" export QUANT_MYSQL_USER=root", err=True)
1065
+ click.echo(" export QUANT_MYSQL_PASSWORD=xxx", err=True)
1066
+ click.echo(" export QUANT_MYSQL_DATABASE=quantcli", err=True)
1067
1067
  sys.exit(1)
1068
1068
 
1069
1069
  # 设置时间基线
@@ -1119,10 +1119,10 @@ def _run_yaml_backtest(ctx, strategy_path, symbol, start, end, capital, fee):
1119
1119
  "\nHint: Make sure MySQL is running and environment variables are set:",
1120
1120
  err=True,
1121
1121
  )
1122
- click.echo(" export MYSQL_HOST=localhost", err=True)
1123
- click.echo(" export MYSQL_USER=root", err=True)
1124
- click.echo(" export MYSQL_PASSWORD=xxx", err=True)
1125
- click.echo(" export MYSQL_DATABASE=quantcli", err=True)
1122
+ click.echo(" export QUANT_MYSQL_HOST=localhost", err=True)
1123
+ click.echo(" export QUANT_MYSQL_USER=root", err=True)
1124
+ click.echo(" export QUANT_MYSQL_PASSWORD=xxx", err=True)
1125
+ click.echo(" export QUANT_MYSQL_DATABASE=quantcli", err=True)
1126
1126
  sys.exit(1)
1127
1127
 
1128
1128
  # 显示数据源状态
@@ -29,7 +29,7 @@ from datetime import date, datetime
29
29
  from typing import List, Optional, Dict, Any
30
30
  import pandas as pd
31
31
 
32
- from ..utils import get_logger, format_date
32
+ from ..utils import get_logger, format_date, get_env, get_env_int
33
33
  from ..utils.symbol_utils import to_mysql, normalize
34
34
  from .base import DataSource, DataSourceConfig
35
35
 
@@ -43,12 +43,12 @@ class MySQLDataSource(DataSource):
43
43
 
44
44
  # 默认连接配置
45
45
  DEFAULT_CONFIG = {
46
- "host": os.getenv("MYSQL_HOST", "localhost"),
47
- "port": int(os.getenv("MYSQL_PORT", "3306")),
48
- "user": os.getenv("MYSQL_USER", "root"),
49
- "password": os.getenv("MYSQL_PASSWORD", ""),
50
- "database": os.getenv("MYSQL_DATABASE", "quantcli"),
51
- "table_prefix": os.getenv("MYSQL_TABLE_PREFIX", ""),
46
+ "host": get_env("QUANT_MYSQL_HOST", "localhost"),
47
+ "port": get_env_int("QUANT_MYSQL_PORT", 3306),
48
+ "user": get_env("QUANT_MYSQL_USER", "root"),
49
+ "password": get_env("QUANT_MYSQL_PASSWORD", ""),
50
+ "database": get_env("QUANT_MYSQL_DATABASE", "quantcli"),
51
+ "table_prefix": get_env("QUANT_MYSQL_TABLE_PREFIX", ""),
52
52
  }
53
53
 
54
54
  def __init__(
@@ -96,6 +96,7 @@ class MySQLDataSource(DataSource):
96
96
  """获取数据库连接"""
97
97
  if self._conn is None or not self._conn.open:
98
98
  import pymysql
99
+
99
100
  self._conn = pymysql.connect(
100
101
  host=self._config["host"],
101
102
  port=self._config["port"],
@@ -208,11 +209,7 @@ class MySQLDataSource(DataSource):
208
209
  return [to_mysql(s) for s in symbols]
209
210
 
210
211
  def get_daily(
211
- self,
212
- symbol: str,
213
- start_date,
214
- end_date,
215
- fields: Optional[List[str]] = None
212
+ self, symbol: str, start_date, end_date, fields: Optional[List[str]] = None
216
213
  ) -> pd.DataFrame:
217
214
  """获取日线数据
218
215
 
@@ -234,7 +231,7 @@ class MySQLDataSource(DataSource):
234
231
 
235
232
  sql = f"""
236
233
  SELECT symbol, trade_date, open, high, low, close, volume, amount
237
- FROM {self._table('daily_prices')}
234
+ FROM {self._table("daily_prices")}
238
235
  WHERE symbol = %s AND trade_date BETWEEN %s AND %s
239
236
  ORDER BY trade_date
240
237
  """
@@ -245,24 +242,40 @@ class MySQLDataSource(DataSource):
245
242
  rows = cursor.fetchall()
246
243
 
247
244
  if not rows:
248
- return pd.DataFrame(columns=['symbol', 'date', 'open', 'high', 'low', 'close', 'volume', 'amount'])
245
+ return pd.DataFrame(
246
+ columns=[
247
+ "symbol",
248
+ "date",
249
+ "open",
250
+ "high",
251
+ "low",
252
+ "close",
253
+ "volume",
254
+ "amount",
255
+ ]
256
+ )
249
257
 
250
258
  df = pd.DataFrame(rows)
251
- df = df.rename(columns={
252
- 'trade_date': 'date',
253
- 'amount': 'amount'
254
- })
255
- df['date'] = pd.to_datetime(df['date']).dt.date
259
+ df = df.rename(columns={"trade_date": "date", "amount": "amount"})
260
+ df["date"] = pd.to_datetime(df["date"]).dt.date
256
261
  return df
257
262
  except Exception as e:
258
263
  logger.error(f"Failed to get daily data: {e}")
259
- return pd.DataFrame(columns=['symbol', 'date', 'open', 'high', 'low', 'close', 'volume', 'amount'])
264
+ return pd.DataFrame(
265
+ columns=[
266
+ "symbol",
267
+ "date",
268
+ "open",
269
+ "high",
270
+ "low",
271
+ "close",
272
+ "volume",
273
+ "amount",
274
+ ]
275
+ )
260
276
 
261
277
  def get_multi_daily(
262
- self,
263
- symbols: List[str],
264
- start_date,
265
- end_date
278
+ self, symbols: List[str], start_date, end_date
266
279
  ) -> Dict[str, pd.DataFrame]:
267
280
  """批量获取多只股票的日线数据(回测优化)
268
281
 
@@ -287,7 +300,7 @@ class MySQLDataSource(DataSource):
287
300
  placeholders = ",".join(["%s"] * len(mysql_symbols))
288
301
  sql = f"""
289
302
  SELECT symbol, trade_date, open, high, low, close, volume, amount
290
- FROM {self._table('daily_prices')}
303
+ FROM {self._table("daily_prices")}
291
304
  WHERE symbol IN ({placeholders}) AND trade_date BETWEEN %s AND %s
292
305
  ORDER BY symbol, trade_date
293
306
  """
@@ -301,10 +314,10 @@ class MySQLDataSource(DataSource):
301
314
  result = {}
302
315
  df = pd.DataFrame(rows)
303
316
  if not df.empty:
304
- df = df.rename(columns={'trade_date': 'date'})
305
- df['date'] = pd.to_datetime(df['date']).dt.date
317
+ df = df.rename(columns={"trade_date": "date"})
318
+ df["date"] = pd.to_datetime(df["date"]).dt.date
306
319
  for mysql_symbol in mysql_symbols:
307
- symbol_df = df[df['symbol'] == mysql_symbol].copy()
320
+ symbol_df = df[df["symbol"] == mysql_symbol].copy()
308
321
  if not symbol_df.empty:
309
322
  result[mysql_symbol] = symbol_df
310
323
 
@@ -313,12 +326,7 @@ class MySQLDataSource(DataSource):
313
326
  logger.error(f"Failed to get multi daily data: {e}")
314
327
  return {}
315
328
 
316
- def get_index_daily(
317
- self,
318
- symbol: str,
319
- start_date,
320
- end_date
321
- ) -> pd.DataFrame:
329
+ def get_index_daily(self, symbol: str, start_date, end_date) -> pd.DataFrame:
322
330
  """获取指数日线数据"""
323
331
  # 指数也存储在 daily_prices 表中
324
332
  return self.get_daily(symbol, start_date, end_date)
@@ -330,7 +338,7 @@ class MySQLDataSource(DataSource):
330
338
  symbol: str,
331
339
  start_date: date = None,
332
340
  end_date: date = None,
333
- period: str = "5"
341
+ period: str = "5",
334
342
  ) -> pd.DataFrame:
335
343
  """获取分钟级数据
336
344
 
@@ -350,6 +358,7 @@ class MySQLDataSource(DataSource):
350
358
 
351
359
  # 默认范围:最近 5 个交易日
352
360
  from datetime import timedelta
361
+
353
362
  if end_date is None:
354
363
  end_date = date.today() - timedelta(1)
355
364
  if start_date is None:
@@ -361,7 +370,7 @@ class MySQLDataSource(DataSource):
361
370
  sql = f"""
362
371
  SELECT symbol, trade_date, trade_time, period,
363
372
  open, high, low, close, volume, amount
364
- FROM {self._table('intraday_prices')}
373
+ FROM {self._table("intraday_prices")}
365
374
  WHERE symbol = %s AND trade_date BETWEEN %s AND %s AND period = %s
366
375
  ORDER BY trade_date, trade_time
367
376
  """
@@ -372,25 +381,27 @@ class MySQLDataSource(DataSource):
372
381
  rows = cursor.fetchall()
373
382
 
374
383
  if not rows:
375
- return pd.DataFrame(columns=['date', 'open', 'high', 'low', 'close', 'volume', 'amount'])
384
+ return pd.DataFrame(
385
+ columns=["date", "open", "high", "low", "close", "volume", "amount"]
386
+ )
376
387
 
377
388
  df = pd.DataFrame(rows)
378
389
  # 合并日期和时间
379
- df['datetime'] = pd.to_datetime(df['trade_date'].astype(str) + ' ' + df['trade_time'].astype(str))
380
- df = df.rename(columns={'datetime': 'date'})
381
- df = df.drop(columns=['trade_date', 'trade_time', 'period', 'symbol'])
390
+ df["datetime"] = pd.to_datetime(
391
+ df["trade_date"].astype(str) + " " + df["trade_time"].astype(str)
392
+ )
393
+ df = df.rename(columns={"datetime": "date"})
394
+ df = df.drop(columns=["trade_date", "trade_time", "period", "symbol"])
382
395
 
383
- return df[['date', 'open', 'high', 'low', 'close', 'volume', 'amount']]
396
+ return df[["date", "open", "high", "low", "close", "volume", "amount"]]
384
397
  except Exception as e:
385
398
  logger.error(f"Failed to get intraday data: {e}")
386
- return pd.DataFrame(columns=['date', 'open', 'high', 'low', 'close', 'volume', 'amount'])
399
+ return pd.DataFrame(
400
+ columns=["date", "open", "high", "low", "close", "volume", "amount"]
401
+ )
387
402
 
388
403
  def get_multi_intraday(
389
- self,
390
- symbols: List[str],
391
- start_date: date,
392
- end_date: date,
393
- period: str = "5"
404
+ self, symbols: List[str], start_date: date, end_date: date, period: str = "5"
394
405
  ) -> Dict[str, pd.DataFrame]:
395
406
  """批量获取多只股票的分钟级数据(回测优化)
396
407
 
@@ -417,7 +428,7 @@ class MySQLDataSource(DataSource):
417
428
  sql = f"""
418
429
  SELECT symbol, trade_date, trade_time, period,
419
430
  open, high, low, close, volume, amount
420
- FROM {self._table('intraday_prices')}
431
+ FROM {self._table("intraday_prices")}
421
432
  WHERE symbol IN ({placeholders}) AND trade_date BETWEEN %s AND %s AND period = %s
422
433
  ORDER BY symbol, trade_date, trade_time
423
434
  """
@@ -430,13 +441,19 @@ class MySQLDataSource(DataSource):
430
441
  result = {}
431
442
  df = pd.DataFrame(rows)
432
443
  if not df.empty:
433
- df['datetime'] = pd.to_datetime(df['trade_date'].astype(str) + ' ' + df['trade_time'].astype(str))
434
- df = df.rename(columns={'datetime': 'date'})
444
+ df["datetime"] = pd.to_datetime(
445
+ df["trade_date"].astype(str) + " " + df["trade_time"].astype(str)
446
+ )
447
+ df = df.rename(columns={"datetime": "date"})
435
448
  for mysql_symbol in mysql_symbols:
436
- symbol_df = df[df['symbol'] == mysql_symbol].copy()
449
+ symbol_df = df[df["symbol"] == mysql_symbol].copy()
437
450
  if not symbol_df.empty:
438
- symbol_df = symbol_df.drop(columns=['trade_date', 'trade_time', 'period', 'symbol'])
439
- result[mysql_symbol] = symbol_df[['date', 'open', 'high', 'low', 'close', 'volume', 'amount']]
451
+ symbol_df = symbol_df.drop(
452
+ columns=["trade_date", "trade_time", "period", "symbol"]
453
+ )
454
+ result[mysql_symbol] = symbol_df[
455
+ ["date", "open", "high", "low", "close", "volume", "amount"]
456
+ ]
440
457
 
441
458
  return result
442
459
  except Exception as e:
@@ -444,11 +461,7 @@ class MySQLDataSource(DataSource):
444
461
  return {}
445
462
 
446
463
  def get_pool_minute_data(
447
- self,
448
- symbols: List[str],
449
- start_date: date,
450
- end_date: date,
451
- period: str = "5"
464
+ self, symbols: List[str], start_date: date, end_date: date, period: str = "5"
452
465
  ) -> Dict[str, pd.DataFrame]:
453
466
  """获取股票池的分钟数据(用于 on_bar ranking)
454
467
 
@@ -477,7 +490,7 @@ class MySQLDataSource(DataSource):
477
490
  symbol: str,
478
491
  start_date: date = None,
479
492
  end_date: date = None,
480
- period: str = "5"
493
+ period: str = "5",
481
494
  ):
482
495
  """从 akshare 同步分钟级数据到 MySQL
483
496
 
@@ -510,12 +523,13 @@ class MySQLDataSource(DataSource):
510
523
  with conn.cursor() as cursor:
511
524
  for _, row in df.iterrows():
512
525
  # 解析日期时间
513
- dt = pd.to_datetime(row['date'])
526
+ dt = pd.to_datetime(row["date"])
514
527
  trade_date = dt.date()
515
528
  trade_time = dt.time()
516
529
 
517
- cursor.execute(f"""
518
- INSERT INTO {self._table('intraday_prices')}
530
+ cursor.execute(
531
+ f"""
532
+ INSERT INTO {self._table("intraday_prices")}
519
533
  (symbol, trade_date, trade_time, period, open, high, low, close, volume, amount)
520
534
  VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
521
535
  ON DUPLICATE KEY UPDATE
@@ -525,18 +539,20 @@ class MySQLDataSource(DataSource):
525
539
  close = VALUES(close),
526
540
  volume = VALUES(volume),
527
541
  amount = VALUES(amount)
528
- """, (
529
- mysql_symbol,
530
- trade_date,
531
- trade_time,
532
- period,
533
- row['open'],
534
- row['high'],
535
- row['low'],
536
- row['close'],
537
- row.get('volume', 0),
538
- row.get('amount', 0)
539
- ))
542
+ """,
543
+ (
544
+ mysql_symbol,
545
+ trade_date,
546
+ trade_time,
547
+ period,
548
+ row["open"],
549
+ row["high"],
550
+ row["low"],
551
+ row["close"],
552
+ row.get("volume", 0),
553
+ row.get("amount", 0),
554
+ ),
555
+ )
540
556
 
541
557
  logger.info(f"Synced {len(df)} intraday records for {mysql_symbol}")
542
558
 
@@ -561,18 +577,29 @@ class MySQLDataSource(DataSource):
561
577
 
562
578
  df = pd.DataFrame(rows)
563
579
  if df.empty:
564
- return pd.DataFrame(columns=['symbol', 'name', 'exchange', 'market', 'list_date', 'status'])
580
+ return pd.DataFrame(
581
+ columns=[
582
+ "symbol",
583
+ "name",
584
+ "exchange",
585
+ "market",
586
+ "list_date",
587
+ "status",
588
+ ]
589
+ )
565
590
  return df
566
591
  except Exception as e:
567
592
  logger.error(f"Failed to get stock list: {e}")
568
- return pd.DataFrame(columns=['symbol', 'name', 'exchange', 'market', 'list_date', 'status'])
593
+ return pd.DataFrame(
594
+ columns=["symbol", "name", "exchange", "market", "list_date", "status"]
595
+ )
569
596
 
570
597
  def get_trading_calendar(self, exchange: str = "SSE") -> List[date]:
571
598
  """获取交易日历"""
572
599
  conn = self._get_connection()
573
600
 
574
601
  sql = f"""
575
- SELECT trade_date FROM {self._table('trading_calendar')}
602
+ SELECT trade_date FROM {self._table("trading_calendar")}
576
603
  WHERE exchange = %s AND is_trading_day = 1
577
604
  ORDER BY trade_date
578
605
  """
@@ -582,7 +609,7 @@ class MySQLDataSource(DataSource):
582
609
  cursor.execute(sql, (exchange,))
583
610
  rows = cursor.fetchall()
584
611
 
585
- return [row['trade_date'] for row in rows]
612
+ return [row["trade_date"] for row in rows]
586
613
  except Exception as e:
587
614
  logger.error(f"Failed to get trading calendar: {e}")
588
615
  return []
@@ -590,10 +617,7 @@ class MySQLDataSource(DataSource):
590
617
  # ==================== 基本面数据 ====================
591
618
 
592
619
  def get_fundamental(
593
- self,
594
- symbols: List[str],
595
- date,
596
- indicators: Optional[List[str]] = None
620
+ self, symbols: List[str], date, indicators: Optional[List[str]] = None
597
621
  ) -> pd.DataFrame:
598
622
  """获取基本面数据
599
623
 
@@ -611,7 +635,7 @@ class MySQLDataSource(DataSource):
611
635
  placeholders = ",".join(["%s"] * len(mysql_symbols))
612
636
  sql = f"""
613
637
  SELECT symbol, report_date, roe, netprofitmargin, grossprofitmargin, pe_ttm, pb
614
- FROM {self._table('fundamental_data')}
638
+ FROM {self._table("fundamental_data")}
615
639
  WHERE symbol IN ({placeholders}) AND report_date <= %s
616
640
  ORDER BY symbol, report_date DESC
617
641
  """
@@ -622,19 +646,41 @@ class MySQLDataSource(DataSource):
622
646
  rows = cursor.fetchall()
623
647
 
624
648
  if not rows:
625
- return pd.DataFrame(columns=['symbol', 'report_date', 'roe', 'netprofitmargin', 'grossprofitmargin', 'pe_ttm', 'pb'])
649
+ return pd.DataFrame(
650
+ columns=[
651
+ "symbol",
652
+ "report_date",
653
+ "roe",
654
+ "netprofitmargin",
655
+ "grossprofitmargin",
656
+ "pe_ttm",
657
+ "pb",
658
+ ]
659
+ )
626
660
 
627
661
  # 取每个股票的最新数据
628
662
  df = pd.DataFrame(rows)
629
- df = df.groupby('symbol').first().reset_index()
663
+ df = df.groupby("symbol").first().reset_index()
630
664
  return df
631
665
  except Exception as e:
632
666
  logger.error(f"Failed to get fundamental data: {e}")
633
- return pd.DataFrame(columns=['symbol', 'report_date', 'roe', 'netprofitmargin', 'grossprofitmargin', 'pe_ttm', 'pb'])
667
+ return pd.DataFrame(
668
+ columns=[
669
+ "symbol",
670
+ "report_date",
671
+ "roe",
672
+ "netprofitmargin",
673
+ "grossprofitmargin",
674
+ "pe_ttm",
675
+ "pb",
676
+ ]
677
+ )
634
678
 
635
679
  # ==================== 数据同步 ====================
636
680
 
637
- def sync_from_akshare(self, start_date: date = None, end_date: date = None, symbols: List[str] = None):
681
+ def sync_from_akshare(
682
+ self, start_date: date = None, end_date: date = None, symbols: List[str] = None
683
+ ):
638
684
  """从 akshare 同步日线数据到 MySQL
639
685
 
640
686
  Args:
@@ -648,6 +694,7 @@ class MySQLDataSource(DataSource):
648
694
  start_date = date(2020, 1, 1)
649
695
  if end_date is None:
650
696
  from datetime import timedelta
697
+
651
698
  end_date = date.today() - timedelta(1)
652
699
 
653
700
  akshare = AkshareDataSource(use_cache=True)
@@ -655,7 +702,7 @@ class MySQLDataSource(DataSource):
655
702
  # 获取股票列表
656
703
  if symbols is None:
657
704
  stock_list = akshare.get_stock_list()
658
- symbols = stock_list['symbol'].tolist()[:100] # 限制数量避免超时
705
+ symbols = stock_list["symbol"].tolist()[:100] # 限制数量避免超时
659
706
 
660
707
  conn = self._get_connection()
661
708
  total = len(symbols)
@@ -673,8 +720,9 @@ class MySQLDataSource(DataSource):
673
720
  # 插入数据库
674
721
  with conn.cursor() as cursor:
675
722
  for _, row in df.iterrows():
676
- cursor.execute(f"""
677
- INSERT INTO {self._table('daily_prices')}
723
+ cursor.execute(
724
+ f"""
725
+ INSERT INTO {self._table("daily_prices")}
678
726
  (symbol, trade_date, open, high, low, close, volume, amount)
679
727
  VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
680
728
  ON DUPLICATE KEY UPDATE
@@ -684,16 +732,18 @@ class MySQLDataSource(DataSource):
684
732
  close = VALUES(close),
685
733
  volume = VALUES(volume),
686
734
  amount = VALUES(amount)
687
- """, (
688
- mysql_symbol,
689
- row['date'],
690
- row['open'],
691
- row['high'],
692
- row['low'],
693
- row['close'],
694
- row.get('volume', 0),
695
- row.get('amount', 0)
696
- ))
735
+ """,
736
+ (
737
+ mysql_symbol,
738
+ row["date"],
739
+ row["open"],
740
+ row["high"],
741
+ row["low"],
742
+ row["close"],
743
+ row.get("volume", 0),
744
+ row.get("amount", 0),
745
+ ),
746
+ )
697
747
 
698
748
  synced += 1
699
749
  if synced % 10 == 0:
@@ -713,20 +763,19 @@ class MySQLDataSource(DataSource):
713
763
  conn = self._get_connection()
714
764
  with conn.cursor() as cursor:
715
765
  for day in trading_days:
716
- cursor.execute(f"""
717
- INSERT INTO {self._table('trading_calendar')}
766
+ cursor.execute(
767
+ f"""
768
+ INSERT INTO {self._table("trading_calendar")}
718
769
  (trade_date, exchange, is_trading_day)
719
770
  VALUES (%s, %s, 1)
720
771
  ON DUPLICATE KEY UPDATE is_trading_day = 1
721
- """, (day, exchange))
772
+ """,
773
+ (day, exchange),
774
+ )
722
775
 
723
776
  logger.info(f"Synced {len(trading_days)} trading days")
724
777
 
725
- def sync_fundamentals(
726
- self,
727
- start_year: int = 2020,
728
- end_year: int = 2024
729
- ):
778
+ def sync_fundamentals(self, start_year: int = 2020, end_year: int = 2024):
730
779
  """从 akshare 同步基本面数据到 MySQL(批量获取)
731
780
 
732
781
  使用东方财富数据中心接口,一次获取指定日期的所有股票基本面数据。
@@ -740,7 +789,9 @@ class MySQLDataSource(DataSource):
740
789
 
741
790
  # 生成财报日期列表
742
791
  dates = FundamentalsProvider.generate_report_dates(start_year, end_year)
743
- logger.info(f"准备同步 {start_year}-{end_year} 年财报,共 {len(dates)} 个报告期")
792
+ logger.info(
793
+ f"准备同步 {start_year}-{end_year} 年财报,共 {len(dates)} 个报告期"
794
+ )
744
795
 
745
796
  # 使用 Provider 批量获取数据
746
797
  provider = FundamentalsProvider(use_cache=True)
@@ -778,17 +829,19 @@ class MySQLDataSource(DataSource):
778
829
  quarter_end = report_date
779
830
 
780
831
  # 获取季度日均收盘价
781
- avg_prices = self._get_quarter_avg_close(df['symbol'].tolist(), quarter_start, quarter_end)
832
+ avg_prices = self._get_quarter_avg_close(
833
+ df["symbol"].tolist(), quarter_start, quarter_end
834
+ )
782
835
 
783
836
  # 插入数据库
784
837
  with conn.cursor() as cursor:
785
838
  for _, row in df.iterrows():
786
- symbol = row.get('symbol')
839
+ symbol = row.get("symbol")
787
840
  if not symbol:
788
841
  continue
789
842
 
790
- eps = row.get('eps')
791
- bps = row.get('bps')
843
+ eps = row.get("eps")
844
+ bps = row.get("bps")
792
845
  avg_close = avg_prices.get(symbol)
793
846
 
794
847
  # 转换为 float 避免 Decimal 运算问题
@@ -805,8 +858,9 @@ class MySQLDataSource(DataSource):
805
858
  pb = round(avg_close_val / bps_val, 2)
806
859
 
807
860
  try:
808
- cursor.execute(f"""
809
- INSERT INTO {self._table('fundamental_data')}
861
+ cursor.execute(
862
+ f"""
863
+ INSERT INTO {self._table("fundamental_data")}
810
864
  (symbol, report_date, eps, bps, net_profits, revenue, netprofit_yoy, roe, grossprofitmargin, pe_ttm, pb)
811
865
  VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
812
866
  ON DUPLICATE KEY UPDATE
@@ -819,26 +873,30 @@ class MySQLDataSource(DataSource):
819
873
  grossprofitmargin = COALESCE(VALUES(grossprofitmargin), grossprofitmargin),
820
874
  pe_ttm = COALESCE(VALUES(pe_ttm), pe_ttm),
821
875
  pb = COALESCE(VALUES(pb), pb)
822
- """, (
823
- symbol,
824
- report_date,
825
- eps,
826
- bps,
827
- row.get('net_profits'),
828
- row.get('revenue'),
829
- row.get('netprofit_yoy'),
830
- row.get('roe'),
831
- row.get('grossprofitmargin'),
832
- pe_ttm,
833
- pb,
834
- ))
876
+ """,
877
+ (
878
+ symbol,
879
+ report_date,
880
+ eps,
881
+ bps,
882
+ row.get("net_profits"),
883
+ row.get("revenue"),
884
+ row.get("netprofit_yoy"),
885
+ row.get("roe"),
886
+ row.get("grossprofitmargin"),
887
+ pe_ttm,
888
+ pb,
889
+ ),
890
+ )
835
891
  total_inserted += 1
836
892
  except Exception as e:
837
893
  logger.warning(f"插入失败 {symbol}: {e}")
838
894
 
839
895
  logger.info(f"{report_date} 完成")
840
896
 
841
- def _get_quarter_avg_close(self, symbols: List[str], start_date: str, end_date: str) -> Dict[str, float]:
897
+ def _get_quarter_avg_close(
898
+ self, symbols: List[str], start_date: str, end_date: str
899
+ ) -> Dict[str, float]:
842
900
  """计算股票在指定期间的日均收盘价
843
901
 
844
902
  Args:
@@ -857,7 +915,7 @@ class MySQLDataSource(DataSource):
857
915
 
858
916
  sql = f"""
859
917
  SELECT symbol, AVG(close) as avg_close
860
- FROM {self._table('daily_prices')}
918
+ FROM {self._table("daily_prices")}
861
919
  WHERE symbol IN ({placeholders})
862
920
  AND trade_date BETWEEN %s AND %s
863
921
  GROUP BY symbol
@@ -867,7 +925,7 @@ class MySQLDataSource(DataSource):
867
925
  with conn.cursor() as cursor:
868
926
  cursor.execute(sql, tuple(symbols) + (start_date, end_date))
869
927
  rows = cursor.fetchall()
870
- return {row['symbol']: round(row['avg_close'], 2) for row in rows}
928
+ return {row["symbol"]: round(row["avg_close"], 2) for row in rows}
871
929
  except Exception as e:
872
930
  logger.warning(f"计算季度日均收盘价失败: {e}")
873
931
  return {}
@@ -875,10 +933,7 @@ class MySQLDataSource(DataSource):
875
933
  logger.info(f"基本面同步完成: 共插入/更新 {total_inserted} 条记录")
876
934
 
877
935
  def sync_fundamentals_from_baostock(
878
- self,
879
- symbols: List[str] = None,
880
- start_year: int = 2020,
881
- end_year: int = 2024
936
+ self, symbols: List[str] = None, start_year: int = 2020, end_year: int = 2024
882
937
  ):
883
938
  """从 baostock 同步基本面数据到 MySQL(已废弃,使用 sync_fundamentals)
884
939
 
@@ -887,7 +942,9 @@ class MySQLDataSource(DataSource):
887
942
  start_year: 开始年份
888
943
  end_year: 结束年份
889
944
  """
890
- logger.warning("sync_fundamentals_from_baostock 已废弃,请使用 sync_fundamentals")
945
+ logger.warning(
946
+ "sync_fundamentals_from_baostock 已废弃,请使用 sync_fundamentals"
947
+ )
891
948
  self.sync_fundamentals(start_year, end_year)
892
949
 
893
950
  # ==================== 辅助方法 ====================
@@ -897,9 +954,11 @@ class MySQLDataSource(DataSource):
897
954
  try:
898
955
  conn = self._get_connection()
899
956
  with conn.cursor() as cursor:
900
- cursor.execute(f"SELECT COUNT(*) as cnt FROM {self._table('daily_prices')}")
957
+ cursor.execute(
958
+ f"SELECT COUNT(*) as cnt FROM {self._table('daily_prices')}"
959
+ )
901
960
  result = cursor.fetchone()
902
- daily_count = result['cnt']
961
+ daily_count = result["cnt"]
903
962
 
904
963
  return {
905
964
  "status": "ok",
@@ -17,13 +17,19 @@ from ..mysql import MySQLDataSource
17
17
 
18
18
  logger = get_logger(__name__)
19
19
 
20
- try:
21
- import gm
20
+ GM_AVAILABLE = None
22
21
 
23
- GM_AVAILABLE = True
24
- except ImportError:
25
- GM_AVAILABLE = False
26
- logger.warning("掘金 SDK 未安装,基本面同步将使用降级方案")
22
+
23
def _check_gm():
    """Report whether the GM (掘金) SDK can be imported, probing at most once.

    The first call attempts ``import gm`` and memoizes the outcome in the
    module-level ``GM_AVAILABLE`` flag (``None`` means "not probed yet").
    Subsequent calls return the memoized value without re-importing.

    Returns:
        True if ``import gm`` succeeded, False if it raised ImportError.
    """
    global GM_AVAILABLE
    if GM_AVAILABLE is not None:
        # Already probed; reuse the cached answer.
        return GM_AVAILABLE
    try:
        import gm  # noqa: F401  (imported only to test availability)
    except ImportError:
        GM_AVAILABLE = False
    else:
        GM_AVAILABLE = True
    return GM_AVAILABLE
27
33
 
28
34
 
29
35
  class GmFundamentalSync:
@@ -66,8 +72,8 @@ class GmFundamentalSync:
66
72
  Returns:
67
73
  {symbol: 同步记录数} 字典
68
74
  """
69
- if not GM_AVAILABLE:
70
- logger.warning("掘金 SDK 未安装,跳过基本面同步")
75
+ if not _check_gm():
76
+ logger.warning("掘金 SDK 未安装,基本面同步将使用降级方案")
71
77
  return {}
72
78
 
73
79
  # 初始化掘金
@@ -5,10 +5,12 @@
5
5
  - time: 日期时间处理
6
6
  - path: 路径管理
7
7
  - validate: 数据验证
8
+ - env: 环境变量加载
8
9
 
9
10
  Usage:
10
11
  >>> from quantcli.utils import get_logger, parse_date, project_root
11
12
  >>> from quantcli.utils import validate_schema, check_columns
13
+ >>> from quantcli.utils import get_env
12
14
  """
13
15
 
14
16
  # Logger
@@ -92,6 +94,13 @@ from .validate import (
92
94
  assert_range,
93
95
  )
94
96
 
97
+ # Env
98
+ from .env import (
99
+ load_env,
100
+ get_env,
101
+ get_env_int,
102
+ )
103
+
95
104
  __all__ = [
96
105
  # Logger
97
106
  "setup_logger",
@@ -162,4 +171,8 @@ __all__ = [
162
171
  "assert_columns",
163
172
  "assert_no_null",
164
173
  "assert_range",
174
+ # Env
175
+ "load_env",
176
+ "get_env",
177
+ "get_env_int",
165
178
  ]
@@ -0,0 +1,77 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, Optional
4
+
5
# Parsed contents of the default ``~/.env`` file; None until it is first loaded.
_env_cache: Optional[Dict[str, str]] = None


def _parse_env_file(path: Path) -> Dict[str, str]:
    """Parse a simple KEY=VALUE dotenv file into a dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. An optional leading ``export `` is accepted, so lines copied
    from shell instructions (``export QUANT_MYSQL_HOST=localhost``) work.
    Everything after the first ``=`` is the value, so values may contain
    ``=`` or ``#``. One pair of matching surrounding single or double quotes
    is removed from the value.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    values: Dict[str, str] = {}
    # utf-8-sig reads plain UTF-8 and also drops a leading BOM (e.g. files
    # saved by Windows editors), which would otherwise corrupt the first key.
    with open(path, encoding="utf-8-sig") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()
            if not key:
                continue
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
                value = value[1:-1]
            values[key] = value
    return values


def load_env(file_path: str = "~/.env") -> Dict[str, str]:
    """Load a ``.env`` file and return its KEY=VALUE pairs.

    The default file (``~/.env``) is parsed once and cached for the life of
    the process. Any other path is parsed fresh on each call and never
    touches that cache, so loading a different file cannot change what the
    default-path cache holds.

    A missing file yields an empty dict. For the default path that empty
    result is not cached, so a ``~/.env`` created later is still picked up.

    Args:
        file_path: path to the .env file; ``~`` is expanded. Defaults to ``~/.env``.

    Returns:
        Mapping of variable name to value. For the default path this is the
        shared cached dict, so callers should not mutate it.
    """
    global _env_cache
    path = Path(file_path).expanduser()
    is_default = path == Path("~/.env").expanduser()

    if is_default and _env_cache is not None:
        return _env_cache

    try:
        values = _parse_env_file(path)
    except FileNotFoundError:
        return {}

    if is_default:
        _env_cache = values
    return values
31
+
32
+
33
def get_env(key: str, default: str = None) -> str:
    """Resolve a configuration value, letting ``~/.env`` win over the shell environment.

    Resolution order when ``~/.env`` exists:
        1. the value parsed from ``~/.env`` (via ``load_env()``)
        2. ``os.environ``
        3. ``default``

    When ``~/.env`` does not exist, only ``os.environ`` and then ``default``
    are consulted.

    Note that an entry in ``~/.env`` takes precedence over a variable
    exported in the shell for the same key.

    Args:
        key: environment variable name.
        default: value returned when the key is found nowhere (may be None).

    Returns:
        The resolved string value, or ``default``.
    """
    if Path("~/.env").expanduser().exists():
        dotenv_values = load_env()
        if key in dotenv_values:
            return dotenv_values[key]
    # os.environ never stores None, so this matches "value if set else default".
    return os.getenv(key, default)
70
+
71
+
72
def get_env_int(key: str, default: int = None) -> int:
    """Resolve an integer configuration value through ``get_env``.

    A key that is unset, or set to an empty / whitespace-only string
    (e.g. ``QUANT_MYSQL_PORT=`` in ``~/.env``), falls back to ``default``.
    Such values previously crashed with an opaque ``int('')`` error, which
    can happen at import time when this is used in class-level config.

    Args:
        key: environment variable name.
        default: value returned when the key is unset or blank (may be None).

    Returns:
        The parsed integer, or ``default``.

    Raises:
        ValueError: if the value is set but is not a valid integer; the
            message names the offending variable.
    """
    raw = get_env(key)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(
            f"Environment variable {key} must be an integer, got {raw!r}"
        ) from exc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quantcli
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: 面向AI的多因子量化选股策略挖掘工具,AI Agent 友好 CLI
5
5
  Author-email: QuantCLI Team <quantcli@example.com>
6
6
  Project-URL: repository, https://github.com/wumu2013/quantcli
@@ -40,6 +40,7 @@ quantcli/parser/__init__.py
40
40
  quantcli/parser/constants.py
41
41
  quantcli/parser/formula.py
42
42
  quantcli/utils/__init__.py
43
+ quantcli/utils/env.py
43
44
  quantcli/utils/logger.py
44
45
  quantcli/utils/path.py
45
46
  quantcli/utils/symbol_utils.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes