openfund-core 0.0.4__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- core/Exchange.py +276 -0
- core/main.py +23 -0
- core/smc/SMCBase.py +130 -0
- core/smc/SMCFVG.py +86 -0
- core/smc/SMCLiquidity.py +7 -0
- core/smc/SMCOrderBlock.py +288 -0
- core/smc/SMCPDArray.py +77 -0
- core/smc/SMCStruct.py +290 -0
- core/smc/__init__.py +0 -0
- core/utils/OPTools.py +30 -0
- openfund_core-1.0.1.dist-info/METADATA +48 -0
- openfund_core-1.0.1.dist-info/RECORD +15 -0
- {openfund_core-0.0.4.dist-info → openfund_core-1.0.1.dist-info}/WHEEL +1 -1
- openfund_core-1.0.1.dist-info/entry_points.txt +3 -0
- openfund/core/__init__.py +0 -14
- openfund/core/api_tools/__init__.py +0 -16
- openfund/core/api_tools/binance_futures_tools.py +0 -23
- openfund/core/api_tools/binance_tools.py +0 -26
- openfund/core/api_tools/enums.py +0 -539
- openfund/core/base_collector.py +0 -72
- openfund/core/base_tool.py +0 -58
- openfund/core/factory.py +0 -97
- openfund/core/openfund_old/continuous_klines.py +0 -153
- openfund/core/openfund_old/depth.py +0 -92
- openfund/core/openfund_old/historical_trades.py +0 -123
- openfund/core/openfund_old/index_info.py +0 -67
- openfund/core/openfund_old/index_price_kline.py +0 -118
- openfund/core/openfund_old/klines.py +0 -95
- openfund/core/openfund_old/klines_qrr.py +0 -103
- openfund/core/openfund_old/mark_price.py +0 -121
- openfund/core/openfund_old/mark_price_klines.py +0 -122
- openfund/core/openfund_old/ticker_24hr_price_change.py +0 -99
- openfund/core/pyopenfund.py +0 -85
- openfund/core/services/um_futures_collector.py +0 -142
- openfund/core/sycu_exam/__init__.py +0 -1
- openfund/core/sycu_exam/exam.py +0 -19
- openfund/core/sycu_exam/random_grade_cplus.py +0 -440
- openfund/core/sycu_exam/random_grade_web.py +0 -404
- openfund/core/utils/time_tools.py +0 -25
- openfund_core-0.0.4.dist-info/LICENSE +0 -201
- openfund_core-0.0.4.dist-info/METADATA +0 -67
- openfund_core-0.0.4.dist-info/RECORD +0 -30
- {openfund/core/openfund_old → core}/__init__.py +0 -0
core/Exchange.py
ADDED
@@ -0,0 +1,276 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
import ccxt
|
4
|
+
import pandas as pd
|
5
|
+
|
6
|
+
|
7
|
+
from decimal import Decimal
|
8
|
+
from core.utils.OPTools import OPTools
|
9
|
+
from ccxt.base.exchange import ConstructorArgs
|
10
|
+
|
11
|
+
|
12
|
+
class Exchange:
    """Thin wrapper around a ccxt exchange client (OKX by default).

    Provides market-metadata helpers (tick size, price formatting, contract
    conversion) plus order / position / K-line helpers with simple retry and
    logging around the underlying ccxt calls.
    """

    def __init__(self, config: ConstructorArgs, exchangeKey: str = "okx"):
        """Create the underlying ccxt client.

        Args:
            config: ccxt constructor arguments (API credentials, options, ...).
            exchangeKey: attribute name of the ccxt exchange class to
                instantiate (looked up on the ``ccxt`` module).
        """
        self.exchange = getattr(ccxt, exchangeKey)(config)
        self.logger = logging.getLogger(__name__)

    def getMarket(self, symbol: str):
        """Return the ccxt market structure for ``symbol``.

        ``load_markets()`` is invoked first so the lookup also works on a
        freshly created client.
        """
        self.exchange.load_markets()
        return self.exchange.market(symbol)

    def get_tick_size(self, symbol) -> Decimal:
        """Return the price tick size of ``symbol`` as a Decimal.

        Raises:
            ValueError: if the market data has no ``precision['price']`` entry.
        """
        market = self.getMarket(symbol)
        if market and 'precision' in market and 'price' in market['precision']:
            return OPTools.toDecimal(market['precision']['price'])
        raise ValueError(f"{symbol}: 无法从市场数据中获取价格精度")

    def amount_to_precision(self, symbol, contract_size):
        """Round ``contract_size`` to the market's amount precision via ccxt."""
        return self.exchange.amount_to_precision(symbol, contract_size)

    def get_position_mode(self):
        """Return the account position mode reported by the exchange.

        Reads ``posMode`` from the first entry of the ``data`` list returned
        by ``private_get_account_config()`` (presumably OKX's account-config
        endpoint), falling back to ``'single'`` when nothing is reported.

        Raises:
            Exception: wrapping any error raised while querying the exchange.
        """
        try:
            response = self.exchange.private_get_account_config()
            data = response.get('data', [])
            if data and isinstance(data, list):
                # The first element is expected to be a dict carrying 'posMode'.
                return data[0].get('posMode', 'single')
            return 'single'
        except Exception as e:
            error_message = f"Error fetching position mode: {e}"
            self.logger.error(error_message)
            raise Exception(error_message) from e

    def set_leverage(self, symbol, leverage, mgnMode='isolated', posSide=None):
        """Set the leverage for ``symbol``.

        Args:
            symbol: unified ccxt symbol.
            leverage: leverage multiple.
            mgnMode: margin mode, forwarded as ``marginMode`` in params.
            posSide: optional position side, forwarded as ``side`` in params.

        Raises:
            Exception: wrapping any error raised by the exchange call.
        """
        try:
            params = {
                'leverage': leverage,
                'marginMode': mgnMode,
            }
            if posSide:
                # NOTE(review): ccxt's okx set_leverage appears to read
                # 'posSide' rather than 'side' from params — confirm intent.
                params['side'] = posSide

            self.exchange.set_leverage(leverage, symbol=symbol, params=params)
            self.logger.info(f"{symbol} Successfully set leverage to {leverage}x")
        except Exception as e:
            error_message = f"{symbol} Error setting leverage: {e}"
            self.logger.error(error_message)
            raise Exception(error_message) from e

    def get_precision_length(self, symbol) -> int:
        """Return the number of decimal places in the price tick size of ``symbol``."""
        tick_size = self.get_tick_size(symbol)
        return len(f"{tick_size:.15f}".rstrip('0').split('.')[1]) if '.' in f"{tick_size:.15f}" else 0

    def format_price(self, symbol, price: Decimal) -> str:
        """Format ``price`` with as many decimals as the symbol's tick size."""
        precision = self.get_precision_length(symbol)
        return f"{price:.{precision}f}"

    def convert_contract(self, symbol, amount, price: Decimal, direction='cost_to_contract'):
        """Convert a coin quantity or a quote-currency cost into a contract count.

        Args:
            symbol: unified symbol, e.g. 'BTC/USDT:USDT'.
            amount: coin quantity ('amount_to_contract') or quote-currency
                cost ('cost_to_contract').
            price: price used to turn a cost into a coin quantity; only used
                for 'cost_to_contract' and expected to be a Decimal.
            direction: 'amount_to_contract' or 'cost_to_contract'.

        Returns:
            The contract count, rounded with ``amount_to_precision``.

        Raises:
            Exception: if ``direction`` is not one of the two supported values.
        """
        market_contractSize = OPTools.toDecimal(self.getMarket(symbol)['contractSize'])
        amount = OPTools.toDecimal(amount)
        if direction == 'amount_to_contract':
            contract_size = amount / market_contractSize
        elif direction == 'cost_to_contract':
            contract_size = amount / price / market_contractSize
        else:
            raise Exception(f"{symbol}:{direction} 是无效的转换方向,请输入 'amount_to_contract' 或 'cost_to_contract'。")

        return self.amount_to_precision(symbol, contract_size)

    def cancel_all_orders(self, symbol):
        """Cancel every open order on ``symbol``, retrying up to 3 times.

        Returns:
            True once the open orders were cancelled (or none existed).

        Raises:
            Exception: after the third consecutive failure.
        """
        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            try:
                params = {}
                open_orders = self.exchange.fetch_open_orders(symbol=symbol, params=params)

                if open_orders:
                    order_ids = [order['id'] for order in open_orders]
                    self.exchange.cancel_orders(order_ids, symbol, params=params)
                    self.logger.debug(f"{symbol}: {order_ids} 挂单取消成功.")
                else:
                    self.logger.debug(f"{symbol}: 无挂单.")
                return True

            except Exception as e:
                retry_count += 1
                if retry_count == max_retries:
                    error_message = f"{symbol} 取消挂单失败(重试{retry_count}次): {str(e)}"
                    self.logger.error(error_message)
                    raise Exception(error_message) from e
                self.logger.warning(f"{symbol} 取消挂单失败,正在进行第{retry_count}次重试: {str(e)}")
                time.sleep(0.1)  # brief pause before retrying

    def place_order(self, symbol, price: Decimal, amount_usdt, side, leverage=20, order_type='limit'):
        """Place an isolated-margin order sized by a quote-currency amount.

        Sets the leverage first, converts ``amount_usdt`` at the formatted
        price into a contract count, then submits the order through ccxt.

        Args:
            symbol: unified symbol.
            price: order price; formatted to the market's tick precision.
            amount_usdt: order value in quote currency; nothing is placed
                when it is not positive.
            side: 'buy' (uses the long position side) or any other value
                (uses the short position side).
            leverage: leverage applied before the order is sent.
            order_type: ccxt order type; also forwarded as ``ordType``.

        Returns:
            The order structure returned by ``create_order``, or None when
            ``amount_usdt`` is not positive.

        Raises:
            Exception: wrapping any error raised by ``create_order``; errors
                from ``set_leverage`` propagate as raised there.
        """
        adjusted_price = self.format_price(symbol, price)
        if amount_usdt <= 0:
            return None

        pos_side = 'long' if side == 'buy' else 'short'
        self.set_leverage(symbol=symbol, leverage=leverage, mgnMode='isolated', posSide=pos_side)
        # SWAP contracts: convert the quote amount into a contract count.
        contract_size = self.convert_contract(symbol=symbol, price=OPTools.toDecimal(adjusted_price), amount=amount_usdt)

        params = {
            "tdMode": 'isolated',
            "side": side,
            "ordType": order_type,
            "sz": contract_size,
            "px": adjusted_price,
        }

        try:
            order = {
                'symbol': symbol,
                'side': side,
                # Pass the requested type to ccxt instead of a hard-coded
                # 'limit', so ccxt's own type handling matches order_type.
                'type': order_type,
                'amount': contract_size,
                'price': adjusted_price,
                'params': params,
            }
            self.logger.debug(f"Pre Order placed: {order} ")
            order_result = self.exchange.create_order(**order)
        except Exception as e:
            error_message = f"{symbol} Failed to place order: {e}"
            self.logger.error(error_message)
            raise Exception(error_message) from e

        self.logger.debug(f"--------- ++ {symbol} Order placed done! --------")
        return order_result

    def fetch_position(self, symbol):
        """Return the open position on ``symbol``, or None when there is none.

        Retries up to 3 times on errors.

        Args:
            symbol: unified symbol.

        Returns:
            The ccxt position dict when it holds a positive contract count,
            otherwise None.

        Raises:
            Exception: after the third consecutive failure.
        """
        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            try:
                position = self.exchange.fetch_position(symbol=symbol)
                # 'contracts' may be None for an empty position; treat that as
                # flat instead of letting `None > 0` raise and be retried.
                contracts = position.get('contracts') if position else None
                if contracts and contracts > 0:
                    self.logger.debug(f"{symbol} 有持仓合约数: {contracts}")
                    return position
                return None
            except Exception as e:
                retry_count += 1
                if retry_count == max_retries:
                    error_message = f"!!{symbol} 获取持仓失败(重试{retry_count}次): {str(e)}"
                    self.logger.error(error_message)
                    raise Exception(error_message) from e

                self.logger.warning(f"{symbol} 检查持仓失败,正在进行第{retry_count}次重试: {str(e)}")
                time.sleep(0.1)  # brief pause before retrying

    def get_historical_klines(self, symbol, bar='15m', limit=300, after: str = None, params=None):
        """Fetch historical OHLCV candles.

        Args:
            symbol: unified symbol.
            bar: candle timeframe.
            limit: number of candles; ignored when ``after`` is given.
            after: start time, e.g. "2025-05-21 23:00:00+08:00". When set,
                ccxt pagination is enabled from that time.
            params: extra exchange params; copied, never mutated.

        Returns:
            The non-empty list of candles returned by ``fetch_ohlcv``.

        Raises:
            Exception: if no candles are returned.
        """
        params = dict(params or {})

        since = None
        if after:
            since = self.exchange.parse8601(after)
            limit = None
        if since:
            params['paginate'] = True

        klines = self.exchange.fetch_ohlcv(symbol, timeframe=bar, since=since, limit=limit, params=params)
        if klines:
            return klines
        raise Exception(f"{symbol} : Unexpected response structure or missing candlestick data")

    def get_historical_klines_df(self, symbol, bar='15m', limit=300, after: str = None, params=None) -> pd.DataFrame:
        """Fetch historical candles and return them as a formatted DataFrame."""
        klines = self.get_historical_klines(symbol, bar=bar, limit=limit, after=after, params=params)
        return self.format_klines(klines)

    def format_klines(self, klines) -> pd.DataFrame:
        """Convert raw OHLCV rows into a DataFrame.

        Args:
            klines: rows of [timestamp(ms), open, high, low, close, volume].

        Returns:
            DataFrame with those columns; 'timestamp' is converted from UTC
            milliseconds to Asia/Shanghai local time.
        """
        klines_df = pd.DataFrame(klines, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        klines_df['timestamp'] = pd.to_datetime(klines_df['timestamp'], unit='ms').dt.tz_localize('UTC').dt.tz_convert('Asia/Shanghai')
        return klines_df
|
core/main.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
import logging
|
2
|
+
from pyfiglet import Figlet
|
3
|
+
|
4
|
+
def main():
    """Log the "OpenFund Core" ASCII-art banner to the console."""
    logger = logging.getLogger(__name__)
    # Only attach a console handler once, so repeated calls don't duplicate output.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

    figlet = Figlet(font="standard")  # other fonts, e.g. "block" or "bubble", also work
    # Render outside the f-string: reusing the outer quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    banner = figlet.renderText("OpenFund Core")
    logger.info(f"\n{banner}")


if __name__ == "__main__":
    main()
|
core/smc/SMCBase.py
ADDED
@@ -0,0 +1,130 @@
|
|
1
|
+
from decimal import Decimal
|
2
|
+
import logging
|
3
|
+
import pandas as pd
|
4
|
+
import numpy as np
|
5
|
+
from core.utils.OPTools import OPTools
|
6
|
+
|
7
|
+
class SMCBase(object):
    """Shared constants and indicator helpers for the SMC analysis classes."""

    # Column names expected in K-line DataFrames.
    HIGH_COL = "high"
    LOW_COL = "low"
    CLOSE_COL = "close"
    OPEN_COL = "open"
    VOLUME_COL = "volume"
    AMOUNT_COL = "amount"
    TIMESTAMP_COL = "timestamp"
    ATR_COL = "atr"

    # Trade-direction labels.
    BUY_SIDE = "buy"
    SELL_SIDE = "sell"

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def check_columns(df: pd.DataFrame, required_columns: list) -> bool:
        """Verify that ``df`` contains every column in ``required_columns``.

        Args:
            df: DataFrame to check.
            required_columns: column names that must be present.

        Returns:
            True when all columns are present.

        Raises:
            ValueError: if any required column is missing.
        """
        has_pass = all(col in df.columns for col in required_columns)
        if not has_pass:
            raise ValueError(f"DataFrame必须包含列: {required_columns}")
        return has_pass

    @staticmethod
    def toDecimal(value, precision: int = None) -> Decimal:
        """Convert ``value`` to Decimal via ``OPTools.toDecimal``."""
        return OPTools.toDecimal(value, precision)

    @staticmethod
    def get_precision_length(value) -> int:
        """Return the number of significant decimal places of ``value`` (up to 15)."""
        return len(f"{value:.15f}".rstrip('0').split('.')[1]) if '.' in f"{value:.15f}" else 0

    @staticmethod
    def calculate_atr(df, period=14, multiplier=2):
        """Compute an "enhanced" ATR: rolling max of ATR, times ``multiplier``.

        Modeled on Pine Script's ``ta.highest(ta.atr(200), 200) * 2``, but the
        ATR here is a simple moving average of the true range (min_periods=1),
        whereas Pine's ``ta.atr`` uses RMA smoothing, so values are similar but
        not identical.

        Args:
            df: DataFrame with 'high', 'low' and 'close' columns.
            period: window used both for the ATR average and for the rolling
                max (default 14).
            multiplier: scale factor applied to the result (default 2).

        Returns:
            pd.Series of enhanced ATR values aligned with ``df``.
        """
        high = df[SMCBase.HIGH_COL]
        low = df[SMCBase.LOW_COL]
        close = df[SMCBase.CLOSE_COL]

        # True range: the largest of the bar range and the gaps to the previous close.
        close_prev = close.shift(1)
        tr = pd.DataFrame({
            'tr1': high - low,
            'tr2': abs(high - close_prev),
            'tr3': abs(low - close_prev)
        }).max(axis=1)

        atr = tr.rolling(window=period, min_periods=1).mean()
        max_atr = atr.rolling(window=period, min_periods=1).max()
        enhanced_atr = max_atr * multiplier

        return enhanced_atr

    @staticmethod
    def calculate_atr_with_smoothing(df, length=14, smoothing='RMA'):
        """Compute the Average True Range with a selectable smoothing method.

        Args:
            df: DataFrame with 'high', 'low' and 'close' columns.
            length: smoothing period (default 14).
            smoothing: 'RMA', 'SMA', 'EMA' or 'WMA' (default 'RMA').

        Returns:
            pd.Series of ATR values aligned with ``df``; SMA and WMA are NaN
            for the first ``length - 1`` rows.

        Raises:
            ValueError: if a required column is missing or ``smoothing`` is unknown.
        """
        required_columns = [SMCBase.HIGH_COL, SMCBase.LOW_COL, SMCBase.CLOSE_COL]
        SMCBase.check_columns(df, required_columns)

        # True range components.
        high_low = df[SMCBase.HIGH_COL] - df[SMCBase.LOW_COL]
        high_close = (df[SMCBase.HIGH_COL] - df[SMCBase.CLOSE_COL].shift()).abs()
        low_close = (df[SMCBase.LOW_COL] - df[SMCBase.CLOSE_COL].shift()).abs()

        # TR is the row-wise max; NaNs (first row) are skipped.
        tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)

        if smoothing == 'RMA':
            # RMA is approximately an EMA with alpha = 1/length.
            atr = tr.ewm(alpha=1/length, adjust=False).mean()
        elif smoothing == 'SMA':
            atr = tr.rolling(window=length).mean()
        elif smoothing == 'EMA':
            atr = tr.ewm(span=length, adjust=False).mean()
        elif smoothing == 'WMA':
            # Linear weights 1..length, newest bar heaviest. raw=True passes each
            # window as a plain ndarray so weights apply by position; with a
            # Series window the product aligns on index labels and only the
            # first window would line up with the weights.
            weights = np.arange(1, length + 1, dtype=float)
            weight_sum = weights.sum()
            atr = tr.rolling(window=length).apply(
                lambda window: np.dot(window, weights) / weight_sum, raw=True
            )
        else:
            raise ValueError("Invalid smoothing method. Use 'RMA', 'SMA', 'EMA', or 'WMA'")

        return atr


# Also exposed at module level so both `SMCBase.calculate_atr_with_smoothing(...)`
# and a plain module-level `calculate_atr_with_smoothing(...)` call keep working.
calculate_atr_with_smoothing = SMCBase.calculate_atr_with_smoothing
|
130
|
+
|
core/smc/SMCFVG.py
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
import logging
|
2
|
+
import pandas as pd
|
3
|
+
|
4
|
+
from core.smc.SMCStruct import SMCStruct
|
5
|
+
|
6
|
+
|
7
|
+
class SMCFVG(SMCStruct):
    """Fair value gap (FVG) detection built on the SMC structure helpers."""

    # Output columns added by find_FVGs.
    FVG_TOP = "fvg_top"
    FVG_BOT = "fvg_bot"
    FVG_MID = "fvg_mid"
    FVG_SIDE = "fvg_side"
    FVG_WAS_BALANCED = "fvg_was_balanced"

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger(__name__)

    def find_FVGs(
        self, struct: pd.DataFrame, side, check_balanced=True, start_index=-1
    ) -> pd.DataFrame:
        """Find fair value gaps (three-candle price imbalances).

        A gap is flagged on the middle candle of a three-candle window:
        - buy side: previous high < next low; the gap spans
          [previous high, next low];
        - sell side: previous low > next high; the gap spans
          [next high, previous low].

        Args:
            struct: K-line DataFrame with at least 'high' and 'low' columns.
                NOTE(review): the balanced check looks ahead by index label
                (``label + 2``), which assumes a consecutive integer index —
                confirm callers pass one.
            side: trade direction, 'buy' or 'sell'.
            check_balanced: when True, drop gaps already filled by a later
                candle (buy: a later low <= gap bottom; sell: a later high >=
                gap top), looking at candles from label ``label + 2`` onwards.
            start_index: positional row to start searching from; -1 searches
                the whole frame. One extra preceding candle is kept because
                detecting a gap needs the candle before it.

        Returns:
            A copy of the rows holding a matching gap, with FVG_SIDE, FVG_TOP,
            FVG_BOT, FVG_MID (and FVG_WAS_BALANCED when ``check_balanced``)
            columns added.
        """
        if start_index == -1:
            df = struct.copy()
        else:
            # bug2.2.5_1: computing an FVG needs the candle before start_index too.
            df = struct.iloc[max(0, start_index - 1):].copy()

        self.check_columns(df, [self.HIGH_COL, self.LOW_COL])

        # Vectorized detection: compare the neighbours of every candle at once.
        if side == self.BUY_SIDE:
            condition = df[self.HIGH_COL].shift(1) < df[self.LOW_COL].shift(-1)
            side_value = "Bullish"
            price_top = df[self.LOW_COL].shift(-1)
            price_bot = df[self.HIGH_COL].shift(1)
        else:
            condition = df[self.LOW_COL].shift(1) > df[self.HIGH_COL].shift(-1)
            side_value = "Bearish"
            price_top = df[self.LOW_COL].shift(1)
            price_bot = df[self.HIGH_COL].shift(-1)

        df.loc[:, self.FVG_SIDE] = pd.Series(
            [side_value if x else None for x in condition], index=df.index
        )
        df.loc[:, self.FVG_TOP] = price_top.where(condition, 0)
        df.loc[:, self.FVG_BOT] = price_bot.where(condition, 0)
        df.loc[:, self.FVG_MID] = (df[self.FVG_TOP] + df[self.FVG_BOT]) / 2

        fvg_df = df[df[self.FVG_SIDE] == side_value].copy()

        if check_balanced:
            if side == self.BUY_SIDE:
                was_balanced = [
                    bool((df.loc[label + 2 :, self.LOW_COL] <= bot).any())
                    for label, bot in fvg_df[self.FVG_BOT].items()
                ]
            else:
                was_balanced = [
                    bool((df.loc[label + 2 :, self.HIGH_COL] >= top).any())
                    for label, top in fvg_df[self.FVG_TOP].items()
                ]
            # An explicit bool Series keeps `~` valid even when no gap was
            # found (an empty DataFrame.apply result has no bool dtype).
            fvg_df[self.FVG_WAS_BALANCED] = pd.Series(
                was_balanced, index=fvg_df.index, dtype=bool
            )
            fvg_df = fvg_df[~fvg_df[self.FVG_WAS_BALANCED]]

        return fvg_df
|
86
|
+
|