vnpy_okx 2025.6.17__py3-none-any.whl → 2025.10.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vnpy_okx/okx_gateway.py CHANGED
@@ -9,7 +9,7 @@ from urllib.parse import urlencode
9
9
  from types import TracebackType
10
10
  from collections.abc import Callable
11
11
 
12
- from vnpy.event import EventEngine
12
+ from vnpy.event import EventEngine, Event, EVENT_TIMER
13
13
  from vnpy.trader.constant import (
14
14
  Direction,
15
15
  Exchange,
@@ -55,9 +55,9 @@ AWS_BUSINESS_HOST: str = "wss://wsaws.okx.com:8443/ws/v5/business"
55
55
 
56
56
  # Demo server hosts
57
57
  DEMO_REST_HOST: str = "https://www.okx.com"
58
- DEMO_PUBLIC_HOST: str = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999"
59
- DEMO_PRIVATE_HOST: str = "wss://wspap.okx.com:8443/ws/v5/private?brokerId=9999"
60
- DEMO_BUSINESS_HOST: str = "wss://wspap.okx.com:8443/ws/v5/business?brokerId=9999"
58
+ DEMO_PUBLIC_HOST: str = "wss://wspap.okx.com:8443/ws/v5/public"
59
+ DEMO_PRIVATE_HOST: str = "wss://wspap.okx.com:8443/ws/v5/private"
60
+ DEMO_BUSINESS_HOST: str = "wss://wspap.okx.com:8443/ws/v5/business"
61
61
 
62
62
  # Order status map
63
63
  STATUS_OKX2VT: dict[str, Status] = {
@@ -70,6 +70,7 @@ STATUS_OKX2VT: dict[str, Status] = {
70
70
 
71
71
  # Order type map
72
72
  ORDERTYPE_OKX2VT: dict[str, OrderType] = {
73
+ "market": OrderType.MARKET,
73
74
  "limit": OrderType.LIMIT,
74
75
  "fok": OrderType.FOK,
75
76
  "ioc": OrderType.FAK
@@ -115,6 +116,7 @@ class OkxGateway(BaseGateway):
115
116
  "Server": ["REAL", "AWS", "DEMO"],
116
117
  "Proxy Host": "",
117
118
  "Proxy Port": 0,
119
+ "Spread Trading": ["False", "True"],
118
120
  }
119
121
 
120
122
  exchanges: Exchange = [Exchange.GLOBAL]
@@ -134,6 +136,7 @@ class OkxGateway(BaseGateway):
134
136
  self.server: str = ""
135
137
  self.proxy_host: str = ""
136
138
  self.proxy_port: int = 0
139
+ self.spread_trading: bool = False
137
140
 
138
141
  self.orders: dict[str, OrderData] = {}
139
142
  self.local_orderids: set[str] = set()
@@ -144,6 +147,10 @@ class OkxGateway(BaseGateway):
144
147
  self.rest_api: RestApi = RestApi(self)
145
148
  self.public_api: PublicApi = PublicApi(self)
146
149
  self.private_api: PrivateApi = PrivateApi(self)
150
+ self.business_api: BusinessApi = BusinessApi(self)
151
+
152
+ self.ping_count: int = 0
153
+ self.ping_interval: int = 20
147
154
 
148
155
  def connect(self, setting: dict) -> None:
149
156
  """
@@ -154,7 +161,7 @@ class OkxGateway(BaseGateway):
154
161
 
155
162
  Parameters:
156
163
  setting: A dictionary containing connection parameters including
157
- API credentials, server selection, and proxy configuration
164
+ API credentials, server selection, and proxy configuration.
158
165
  """
159
166
  self.key = setting["API Key"]
160
167
  self.secret = setting["Secret Key"]
@@ -162,6 +169,7 @@ class OkxGateway(BaseGateway):
162
169
  self.server = setting["Server"]
163
170
  self.proxy_host = setting["Proxy Host"]
164
171
  self.proxy_port = setting["Proxy Port"]
172
+ self.spread_trading = setting["Spread Trading"] == "True"
165
173
 
166
174
  self.rest_api.connect(
167
175
  self.key,
@@ -169,7 +177,8 @@ class OkxGateway(BaseGateway):
169
177
  self.passphrase,
170
178
  self.server,
171
179
  self.proxy_host,
172
- self.proxy_port
180
+ self.proxy_port,
181
+ self.spread_trading
173
182
  )
174
183
 
175
184
  def connect_ws_api(self) -> None:
@@ -190,6 +199,18 @@ class OkxGateway(BaseGateway):
190
199
  self.proxy_port,
191
200
  )
192
201
 
202
+ if self.spread_trading:
203
+ self.business_api.connect(
204
+ self.key,
205
+ self.secret,
206
+ self.passphrase,
207
+ self.server,
208
+ self.proxy_host,
209
+ self.proxy_port,
210
+ )
211
+
212
+ self.event_engine.register(EVENT_TIMER, self.process_timer_event)
213
+
193
214
  def subscribe(self, req: SubscribeRequest) -> None:
194
215
  """
195
216
  Subscribe to market data.
@@ -197,7 +218,15 @@ class OkxGateway(BaseGateway):
197
218
  Parameters:
198
219
  req: Subscription request object containing symbol information
199
220
  """
200
- self.public_api.subscribe(req)
221
+ contract: ContractData | None = self.symbol_contract_map.get(req.symbol, None)
222
+ if not contract:
223
+ self.write_log(f"Failed to subscribe data, symbol not found: {req.symbol}")
224
+ return
225
+
226
+ if contract.product == Product.SPREAD:
227
+ self.business_api.subscribe(req)
228
+ else:
229
+ self.public_api.subscribe(req)
201
230
 
202
231
  def send_order(self, req: OrderRequest) -> str:
203
232
  """
@@ -212,7 +241,15 @@ class OkxGateway(BaseGateway):
212
241
  Returns:
213
242
  str: The VeighNa order ID if successful, empty string otherwise
214
243
  """
215
- return self.private_api.send_order(req)
244
+ contract: ContractData | None = self.symbol_contract_map.get(req.symbol, None)
245
+ if not contract:
246
+ self.write_log(f"Failed to send order, symbol not found: {req.symbol}")
247
+ return ""
248
+
249
+ if contract.product == Product.SPREAD:
250
+ return self.business_api.send_order(req)
251
+ else:
252
+ return self.private_api.send_order(req)
216
253
 
217
254
  def cancel_order(self, req: CancelRequest) -> None:
218
255
  """
@@ -224,17 +261,31 @@ class OkxGateway(BaseGateway):
224
261
  Parameters:
225
262
  req: Cancel request object containing order details
226
263
  """
227
- self.private_api.cancel_order(req)
264
+ contract: ContractData | None = self.symbol_contract_map.get(req.symbol, None)
265
+ if not contract:
266
+ self.write_log(f"Failed to cancel order, symbol not found: {req.symbol}")
267
+ return
268
+
269
+ if contract.product == Product.SPREAD:
270
+ self.business_api.cancel_order(req)
271
+ else:
272
+ self.private_api.cancel_order(req)
228
273
 
229
274
  def query_account(self) -> None:
230
275
  """
231
- Not required since OKX provides websocket update for account balances.
276
+ Query account balance.
277
+
278
+ This method is not implemented because OKX provides account balance
279
+ updates through the websocket API.
232
280
  """
233
281
  pass
234
282
 
235
283
  def query_position(self) -> None:
236
284
  """
237
- Not required since OKX provides websocket update for positions.
285
+ Query asset positions.
286
+
287
+ This method is not implemented because OKX provides position updates
288
+ through the websocket API.
238
289
  """
239
290
  pass
240
291
 
@@ -248,7 +299,15 @@ class OkxGateway(BaseGateway):
248
299
  Returns:
249
300
  list[BarData]: List of historical kline data bars
250
301
  """
251
- return self.rest_api.query_history(req)
302
+ contract: ContractData | None = self.symbol_contract_map.get(req.symbol, None)
303
+ if not contract:
304
+ self.write_log(f"Failed to query history, symbol not found: {req.symbol}")
305
+ return []
306
+
307
+ if contract.product == Product.SPREAD:
308
+ return self.rest_api.query_spread_history(req)
309
+ else:
310
+ return self.rest_api.query_history(req)
252
311
 
253
312
  def close(self) -> None:
254
313
  """
@@ -259,10 +318,11 @@ class OkxGateway(BaseGateway):
259
318
  self.rest_api.stop()
260
319
  self.public_api.stop()
261
320
  self.private_api.stop()
321
+ self.business_api.stop()
262
322
 
263
323
  def on_order(self, order: OrderData) -> None:
264
324
  """
265
- Save a copy of order and then push to event engine.
325
+ Cache order data and push an order event.
266
326
 
267
327
  Parameters:
268
328
  order: Order data object
@@ -284,7 +344,7 @@ class OkxGateway(BaseGateway):
284
344
 
285
345
  def on_contract(self, contract: ContractData) -> None:
286
346
  """
287
- Save a copy of contract and then push to event engine.
347
+ Cache contract data and push a contract event.
288
348
 
289
349
  Parameters:
290
350
  contract: Contract data object
@@ -356,6 +416,57 @@ class OkxGateway(BaseGateway):
356
416
  )
357
417
  return order
358
418
 
419
+ def parse_spread_order_data(self, data: dict, gateway_name: str) -> OrderData:
420
+ """
421
+ Parse dict to spread order data.
422
+
423
+ This function converts OKX order data into a VeighNa OrderData object.
424
+ It extracts and maps all relevant fields from the exchange response.
425
+
426
+ Parameters:
427
+ data: Order data from OKX
428
+ gateway_name: Gateway name for identification
429
+
430
+ Returns:
431
+ OrderData: VeighNa order object
432
+ """
433
+ contract: ContractData = self.get_contract_by_name(data["sprdId"])
434
+
435
+ order_id: str = data["clOrdId"]
436
+ if order_id:
437
+ self.local_orderids.add(order_id)
438
+ else:
439
+ order_id = data["ordId"]
440
+
441
+ order: OrderData = OrderData(
442
+ symbol=contract.symbol,
443
+ exchange=Exchange.GLOBAL,
444
+ type=ORDERTYPE_OKX2VT[data["ordType"]],
445
+ orderid=order_id,
446
+ direction=DIRECTION_OKX2VT[data["side"]],
447
+ offset=Offset.NONE,
448
+ traded=float(data["accFillSz"]),
449
+ price=float(data["px"]),
450
+ volume=float(data["sz"]),
451
+ datetime=parse_timestamp(data["cTime"]),
452
+ status=STATUS_OKX2VT[data["state"]],
453
+ gateway_name=gateway_name,
454
+ )
455
+ return order
456
+
457
+ def process_timer_event(self, event: Event) -> None:
458
+ """
459
+ Process timer events for sending heartbeat messages.
460
+ """
461
+ self.ping_count += 1
462
+ if self.ping_count < self.ping_interval:
463
+ return
464
+ self.ping_count = 0
465
+
466
+ self.private_api.send_ping()
467
+ self.public_api.send_ping()
468
+ self.business_api.send_ping()
469
+
359
470
 
360
471
  class RestApi(RestClient):
361
472
  """The REST API of OkxGateway"""
@@ -434,6 +545,7 @@ class RestApi(RestClient):
434
545
  server: str,
435
546
  proxy_host: str,
436
547
  proxy_port: int,
548
+ spread_trading: bool
437
549
  ) -> None:
438
550
  """
439
551
  Start server connection.
@@ -448,6 +560,7 @@ class RestApi(RestClient):
448
560
  server: Server type ("REAL", "AWS", or "DEMO")
449
561
  proxy_host: Proxy server hostname or IP
450
562
  proxy_port: Proxy server port
563
+ spread_trading: Whether to enable spread trading
451
564
  """
452
565
  self.key = key
453
566
  self.secret = secret.encode()
@@ -473,6 +586,9 @@ class RestApi(RestClient):
473
586
  self.query_time()
474
587
  self.query_contract()
475
588
 
589
+ if spread_trading:
590
+ self.query_spread()
591
+
476
592
  def query_time(self) -> None:
477
593
  """
478
594
  Query server time.
@@ -514,6 +630,34 @@ class RestApi(RestClient):
514
630
  params={"instType": inst_type}
515
631
  )
516
632
 
633
+ def query_spread(self) -> None:
634
+ """
635
+ Query available spreads.
636
+
637
+ This function sends a request to get all available spread contracts.
638
+ """
639
+ self.add_request(
640
+ "GET",
641
+ "/api/v5/sprd/spreads",
642
+ callback=self.on_query_spread
643
+ )
644
+
645
+ def query_spread_order(self) -> None:
646
+ """
647
+ Query open spread orders.
648
+
649
+ This function sends a request to get all active orders
650
+ that have not been fully filled or cancelled.
651
+ """
652
+ if not self.key:
653
+ return
654
+
655
+ self.add_request(
656
+ "GET",
657
+ "/api/v5/sprd/orders-pending",
658
+ callback=self.on_query_spread_order,
659
+ )
660
+
517
661
  def on_query_time(self, packet: dict, request: Request) -> None:
518
662
  """
519
663
  Callback of server time query.
@@ -565,6 +709,9 @@ class RestApi(RestClient):
565
709
  """
566
710
  data: list = packet["data"]
567
711
 
712
+ if not data:
713
+ return
714
+
568
715
  for d in data:
569
716
  name: str = d["instId"]
570
717
  product: Product = PRODUCT_OKX2VT[d["instType"]]
@@ -610,16 +757,76 @@ class RestApi(RestClient):
610
757
 
611
758
  self.gateway.on_contract(contract)
612
759
 
613
- self.gateway.write_log(f"{d['instType']} contract data received")
760
+ inst_type: str = request.params["instType"]
761
+ self.gateway.write_log(f"{inst_type} contract data received")
614
762
 
615
763
  # Connect to websocket API after all contract data received
616
- self.product_ready.add(contract.product)
764
+ self.product_ready.add(PRODUCT_OKX2VT[inst_type])
617
765
 
618
766
  if len(self.product_ready) == len(PRODUCT_OKX2VT):
619
767
  self.query_order()
620
768
 
621
769
  self.gateway.connect_ws_api()
622
770
 
771
+ def on_query_spread(self, packet: dict, request: Request) -> None:
772
+ """
773
+ Callback of available contracts query.
774
+
775
+ This function processes the exchange info response and
776
+ creates ContractData objects for each spread contract.
777
+
778
+ Parameters:
779
+ packet: Response data from the server
780
+ request: Original request object
781
+ """
782
+ data: list = packet["data"]
783
+
784
+ for d in data:
785
+ leg_symbols: list[str] = []
786
+ for leg in d["legs"]:
787
+ leg_name: str = leg["instId"]
788
+ leg_contract: ContractData = self.gateway.get_contract_by_name(leg_name)
789
+ leg_symbols.append(leg_contract.symbol)
790
+
791
+ contract: ContractData = ContractData(
792
+ symbol="-".join(leg_symbols),
793
+ exchange=Exchange.GLOBAL,
794
+ name=d["sprdId"],
795
+ product=Product.SPREAD,
796
+ size=float(d["lotSz"]),
797
+ pricetick=float(d["tickSz"]),
798
+ min_volume=float(d["minSz"]),
799
+ history_data=True,
800
+ net_position=True,
801
+ gateway_name=self.gateway_name,
802
+ )
803
+
804
+ self.gateway.on_contract(contract)
805
+
806
+ self.gateway.write_log("Spread contract data received")
807
+
808
+ self.query_spread_order()
809
+
810
+ def on_query_spread_order(self, packet: dict, request: Request) -> None:
811
+ """
812
+ Callback of open spread orders query.
813
+
814
+ This function processes the open orders response and
815
+ creates OrderData objects for each active spread order.
816
+
817
+ Parameters:
818
+ packet: Response data from the server
819
+ request: Original request object
820
+ """
821
+ for order_info in packet["data"]:
822
+ order: OrderData = self.gateway.parse_spread_order_data(
823
+ order_info,
824
+ self.gateway_name
825
+ )
826
+ self.gateway.on_order(order)
827
+
828
+ self.gateway.write_log("Spread order data received")
829
+
623
830
  def on_error(
624
831
  self,
625
832
  exc: type,
@@ -644,18 +851,15 @@ class RestApi(RestClient):
644
851
  msg: str = f"Exception catched by REST API: {detail}"
645
852
  self.gateway.write_log(msg)
646
853
 
647
- def query_history(self, req: HistoryRequest) -> list[BarData]:
854
+ def _query_history(
855
+ self,
856
+ req: HistoryRequest,
857
+ path: str,
858
+ id_key: str,
859
+ bar_parser: Callable[[list, HistoryRequest], BarData]
860
+ ) -> list[BarData]:
648
861
  """
649
- Query kline history data.
650
-
651
- This function sends requests to get historical kline data
652
- for a specific trading instrument and time period.
653
-
654
- Parameters:
655
- req: History request object containing query parameters
656
-
657
- Returns:
658
- list[BarData]: List of historical kline data bars
862
+ Generic helper to query kline history data.
659
863
  """
660
864
  # Validate symbol exists in contract map
661
865
  contract: ContractData | None = self.gateway.get_contract_by_symbol(req.symbol)
@@ -665,8 +869,6 @@ class RestApi(RestClient):
665
869
 
666
870
  # Initialize buffer for storing bars
667
871
  buf: dict[datetime, BarData] = {}
668
-
669
- path: str = "/api/v5/market/history-candles"
670
872
  limit: str = "100"
671
873
 
672
874
  if not req.end:
@@ -678,7 +880,7 @@ class RestApi(RestClient):
678
880
  while True:
679
881
  # Create query params
680
882
  params: dict = {
681
- "instId": contract.name,
883
+ id_key: contract.name,
682
884
  "bar": INTERVAL_VT2OKX[req.interval],
683
885
  "limit": limit,
684
886
  "after": after
@@ -701,28 +903,13 @@ class RestApi(RestClient):
701
903
  bar_data: list = data.get("data", None)
702
904
 
703
905
  if not bar_data:
704
- msg: str = data["msg"]
906
+ msg: str = data.get("msg", "No data returned.")
705
907
  log_msg = f"No kline history data received, {msg}"
908
+ self.gateway.write_log(log_msg)
706
909
  break
707
910
 
708
911
  for row in bar_data:
709
- ts, op, hp, lp, cp, volume, turnover, _, _ = row
710
-
711
- dt: datetime = parse_timestamp(ts)
712
-
713
- bar: BarData = BarData(
714
- symbol=req.symbol,
715
- exchange=req.exchange,
716
- datetime=dt,
717
- interval=req.interval,
718
- volume=float(volume),
719
- turnover=float(turnover),
720
- open_price=float(op),
721
- high_price=float(hp),
722
- low_price=float(lp),
723
- close_price=float(cp),
724
- gateway_name=self.gateway_name
725
- )
912
+ bar: BarData = bar_parser(row, req)
726
913
  buf[bar.datetime] = bar
727
914
 
728
915
  begin: str = bar_data[-1][0]
@@ -746,31 +933,100 @@ class RestApi(RestClient):
746
933
  history: list[BarData] = [buf[i] for i in index]
747
934
  return history
748
935
 
936
+ def query_history(self, req: HistoryRequest) -> list[BarData]:
937
+ """
938
+ Query kline history data.
939
+
940
+ This function sends requests to get historical kline data
941
+ for a specific trading instrument and time period. It queries
942
+ data iteratively until the start time is reached.
749
943
 
750
- class PublicApi(WebsocketClient):
751
- """The public websocket API of OkxGateway"""
944
+ Parameters:
945
+ req: History request object containing query parameters
752
946
 
753
- def __init__(self, gateway: OkxGateway) -> None:
947
+ Returns:
948
+ list[BarData]: List of historical kline data bars
754
949
  """
755
- The init method of the api.
950
+ def parse_bar(row: list, req: HistoryRequest) -> BarData:
951
+ ts, op, hp, lp, cp, volume, turnover, _, _ = row
952
+ dt: datetime = parse_timestamp(ts)
953
+ return BarData(
954
+ symbol=req.symbol,
955
+ exchange=req.exchange,
956
+ datetime=dt,
957
+ interval=req.interval,
958
+ volume=float(volume),
959
+ turnover=float(turnover),
960
+ open_price=float(op),
961
+ high_price=float(hp),
962
+ low_price=float(lp),
963
+ close_price=float(cp),
964
+ gateway_name=self.gateway_name
965
+ )
966
+
967
+ return self._query_history(
968
+ req=req,
969
+ path="/api/v5/market/history-candles",
970
+ id_key="instId",
971
+ bar_parser=parse_bar
972
+ )
973
+
974
+ def query_spread_history(self, req: HistoryRequest) -> list[BarData]:
975
+ """
976
+ Query kline history data for spread contracts.
977
+
978
+ This function sends requests to get historical kline data
979
+ for a specific spread contract and time period. It queries
980
+ data iteratively until the start time is reached.
756
981
 
757
982
  Parameters:
758
- gateway: the parent gateway object for pushing callback data.
983
+ req: History request object containing query parameters
984
+
985
+ Returns:
986
+ list[BarData]: List of historical kline data bars
987
+ """
988
+ def parse_spread_bar(row: list, req: HistoryRequest) -> BarData:
989
+ ts, op, hp, lp, cp, volume, _ = row
990
+ dt: datetime = parse_timestamp(ts)
991
+ return BarData(
992
+ symbol=req.symbol,
993
+ exchange=req.exchange,
994
+ datetime=dt,
995
+ interval=req.interval,
996
+ volume=float(volume),
997
+ open_price=float(op),
998
+ high_price=float(hp),
999
+ low_price=float(lp),
1000
+ close_price=float(cp),
1001
+ gateway_name=self.gateway_name
1002
+ )
1003
+
1004
+ return self._query_history(
1005
+ req=req,
1006
+ path="/api/v5/market/sprd-history-candles",
1007
+ id_key="sprdId",
1008
+ bar_parser=parse_spread_bar
1009
+ )
1010
+
1011
+
1012
+ class WebsocketApi(WebsocketClient):
1013
+ """The base websocket API of OkxGateway"""
1014
+
1015
+ def __init__(self, gateway: OkxGateway, name: str) -> None:
1016
+ """
1017
+ The init method of the api.
759
1018
  """
760
1019
  super().__init__()
761
1020
 
1021
+ self.name: str = name
762
1022
  self.gateway: OkxGateway = gateway
763
1023
  self.gateway_name: str = gateway.gateway_name
764
1024
 
765
- self.subscribed: dict[str, SubscribeRequest] = {}
766
- self.ticks: dict[str, TickData] = {}
767
-
768
- self.callbacks: dict[str, Callable] = {
769
- "tickers": self.on_ticker,
770
- "books5": self.on_depth
771
- }
1025
+ self.connected: bool = False
1026
+ self.callbacks: dict[str, Callable] = {}
1027
+ self.server_hosts: dict[str, str] = {}
772
1028
 
773
- def connect(
1029
+ def connect_(
774
1030
  self,
775
1031
  server: str,
776
1032
  proxy_host: str,
@@ -778,24 +1034,94 @@ class PublicApi(WebsocketClient):
778
1034
  ) -> None:
779
1035
  """
780
1036
  Start server connection.
1037
+ """
1038
+ host: str = self.server_hosts[server]
1039
+ self.init(host, proxy_host, proxy_port, 20)
1040
+ self.start()
1041
+
1042
+ def on_connected(self) -> None:
1043
+ """
1044
+ Callback when server is connected.
1045
+ """
1046
+ self.connected = True
1047
+ self.gateway.write_log(f"{self.name} connected")
1048
+
1049
+ def on_disconnected(self) -> None:
1050
+ """
1051
+ Callback when server is disconnected.
1052
+ """
1053
+ self.connected = False
1054
+ self.gateway.write_log(f"{self.name} disconnected")
1055
+
1056
+ def on_message(self, message: str) -> None:
1057
+ """
1058
+ Callback when websocket app receives new message.
1059
+ """
1060
+ if message == "pong":
1061
+ return
1062
+ self.on_packet(json.loads(message))
1063
+
1064
+ def on_packet(self, packet: dict) -> None:
1065
+ """
1066
+ Callback of data update.
1067
+ """
1068
+ if "event" in packet:
1069
+ cb_name: str = packet["event"]
1070
+ elif "op" in packet:
1071
+ cb_name = packet["op"]
1072
+ elif "arg" in packet and "channel" in packet["arg"]:
1073
+ cb_name = packet["arg"]["channel"]
1074
+ else:
1075
+ return
1076
+
1077
+ callback: Callable | None = self.callbacks.get(cb_name, None)
1078
+ if callback:
1079
+ callback(packet)
1080
+
1081
+ def on_error(self, value: Exception) -> None:
1082
+ """
1083
+ General error callback.
1084
+ """
1085
+ self.gateway.write_log(f"Exception catched by {self.name}: {value}")
1086
+
1087
+ def send_ping(self) -> None:
1088
+ """Send heartbeat ping to server"""
1089
+ if self.connected:
1090
+ self.wsapp.send("ping")
781
1091
 
782
- This method establishes a websocket connection to OKX public data stream.
1092
+
1093
+ class PublicApi(WebsocketApi):
1094
+ """The public websocket API of OkxGateway"""
1095
+
1096
+ def __init__(self, gateway: OkxGateway) -> None:
1097
+ """
1098
+ The init method of the api.
783
1099
 
784
1100
  Parameters:
785
- server: Server type ("REAL", "AWS", or "DEMO")
786
- proxy_host: Proxy server hostname or IP
787
- proxy_port: Proxy server port
1101
+ gateway: the parent gateway object for pushing callback data.
788
1102
  """
789
- server_hosts: dict[str, str] = {
1103
+ super().__init__(gateway, "Public API")
1104
+
1105
+ self.subscribed: dict[str, SubscribeRequest] = {}
1106
+ self.ticks: dict[str, TickData] = {}
1107
+
1108
+ self.callbacks: dict[str, Callable] = {
1109
+ "tickers": self.on_ticker,
1110
+ "books5": self.on_depth,
1111
+ "error": self.on_api_error
1112
+ }
1113
+
1114
+ self.server_hosts: dict[str, str] = {
790
1115
  "REAL": REAL_PUBLIC_HOST,
791
1116
  "AWS": AWS_PUBLIC_HOST,
792
1117
  "DEMO": DEMO_PUBLIC_HOST,
793
1118
  }
794
1119
 
795
- host: str = server_hosts[server]
796
- self.init(host, proxy_host, proxy_port, 20)
797
-
798
- self.start()
1120
+ def connect(self, server: str, proxy_host: str, proxy_port: int) -> None:
1121
+ """
1122
+ Start server connection.
1123
+ """
1124
+ self.connect_(server, proxy_host, proxy_port)
799
1125
 
800
1126
  def subscribe(self, req: SubscribeRequest) -> None:
801
1127
  """
@@ -848,76 +1174,31 @@ class PublicApi(WebsocketClient):
848
1174
  is successfully established. It logs the connection status and
849
1175
  resubscribes to previously subscribed market data channels.
850
1176
  """
851
- self.gateway.write_log("Public API connected")
1177
+ super().on_connected()
852
1178
 
853
1179
  for req in list(self.subscribed.values()):
854
1180
  self.subscribe(req)
855
1181
 
856
- def on_disconnected(self) -> None:
1182
+ def on_api_error(self, packet: dict) -> None:
857
1183
  """
858
- Callback when server is disconnected.
859
-
860
- This function is called when the websocket connection is closed.
861
- It logs the disconnection status.
1184
+ Callback of API error.
862
1185
  """
863
- self.gateway.write_log("Public API disconnected")
1186
+ code: str = packet["code"]
1187
+ msg: str = packet["msg"]
1188
+ self.gateway.write_log(f"{self.name} request failed, status code: {code}, message: {msg}")
864
1189
 
865
- def on_packet(self, packet: dict) -> None:
1190
+ def on_ticker(self, packet: dict) -> None:
866
1191
  """
867
- Callback of data update.
1192
+ Callback of ticker update.
868
1193
 
869
- This function processes different types of market data updates,
870
- including ticker and depth data. It routes the data to the
871
- appropriate callback function based on the channel.
1194
+ This function processes the ticker data updates and
1195
+ updates the corresponding TickData objects.
872
1196
 
873
1197
  Parameters:
874
- packet: JSON data received from websocket
1198
+ packet: Ticker data from websocket
875
1199
  """
876
- if "event" in packet:
877
- event: str = packet["event"]
878
- if event == "subscribe":
879
- return
880
- elif event == "error":
881
- code: str = packet["code"]
882
- msg: str = packet["msg"]
883
- self.gateway.write_log(f"Public API request failed, status code: {code}, message: {msg}")
884
- else:
885
- channel: str = packet["arg"]["channel"]
886
- callback: Callable | None = self.callbacks.get(channel, None)
887
-
888
- if callback:
889
- data: list = packet["data"]
890
- callback(data)
891
-
892
- def on_error(self, exc: type, value: Exception, tb: TracebackType) -> None:
893
- """
894
- General error callback.
895
-
896
- This function is called when an exception occurs in the websocket connection.
897
- It logs the exception details for troubleshooting.
898
-
899
- Parameters:
900
- exc: Type of the exception
901
- value: Exception instance
902
- tb: Traceback object
903
- """
904
- detail: str = self.exception_detail(exc, value, tb)
905
-
906
- msg: str = f"Exception catched by Public API: {detail}"
907
- self.gateway.write_log(msg)
908
-
909
- def on_ticker(self, data: list) -> None:
910
- """
911
- Callback of ticker update.
912
-
913
- This function processes the ticker data updates and
914
- updates the corresponding TickData objects.
915
-
916
- Parameters:
917
- data: Ticker data from websocket
918
- """
919
- for d in data:
920
- tick: TickData = self.ticks[d["instId"]]
1200
+ for d in packet["data"]:
1201
+ tick: TickData = self.ticks[d["instId"]]
921
1202
 
922
1203
  tick.last_price = float(d["last"])
923
1204
  tick.open_price = float(d["open24h"])
@@ -929,7 +1210,7 @@ class PublicApi(WebsocketClient):
929
1210
  tick.datetime = parse_timestamp(d["ts"])
930
1211
  self.gateway.on_tick(copy(tick))
931
1212
 
932
- def on_depth(self, data: list) -> None:
1213
+ def on_depth(self, packet: dict) -> None:
933
1214
  """
934
1215
  Callback of depth update.
935
1216
 
@@ -937,9 +1218,9 @@ class PublicApi(WebsocketClient):
937
1218
  and updates the corresponding TickData objects.
938
1219
 
939
1220
  Parameters:
940
- data: Depth data from websocket
1221
+ packet: Depth data from websocket
941
1222
  """
942
- for d in data:
1223
+ for d in packet["data"]:
943
1224
  tick: TickData = self.ticks[d["instId"]]
944
1225
  bids: list = d["bids"]
945
1226
  asks: list = d["asks"]
@@ -958,7 +1239,7 @@ class PublicApi(WebsocketClient):
958
1239
  self.gateway.on_tick(copy(tick))
959
1240
 
960
1241
 
961
- class PrivateApi(WebsocketClient):
1242
+ class PrivateApi(WebsocketApi):
962
1243
  """The private websocket API of OkxGateway"""
963
1244
 
964
1245
  def __init__(self, gateway: OkxGateway) -> None:
@@ -968,10 +1249,8 @@ class PrivateApi(WebsocketClient):
968
1249
  Parameters:
969
1250
  gateway: the parent gateway object for pushing callback data.
970
1251
  """
971
- super().__init__()
1252
+ super().__init__(gateway, "Private API")
972
1253
 
973
- self.gateway: OkxGateway = gateway
974
- self.gateway_name: str = gateway.gateway_name
975
1254
  self.local_orderids: set[str] = gateway.local_orderids
976
1255
 
977
1256
  self.key: str = ""
@@ -992,6 +1271,12 @@ class PrivateApi(WebsocketClient):
992
1271
  "error": self.on_api_error
993
1272
  }
994
1273
 
1274
+ self.server_hosts: dict[str, str] = {
1275
+ "REAL": REAL_PRIVATE_HOST,
1276
+ "AWS": AWS_PRIVATE_HOST,
1277
+ "DEMO": DEMO_PRIVATE_HOST,
1278
+ }
1279
+
995
1280
  self.reqid_order_map: dict[str, OrderData] = {}
996
1281
 
997
1282
  def connect(
@@ -1022,16 +1307,7 @@ class PrivateApi(WebsocketClient):
1022
1307
 
1023
1308
  self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
1024
1309
 
1025
- server_hosts: dict[str, str] = {
1026
- "REAL": REAL_PRIVATE_HOST,
1027
- "AWS": AWS_PRIVATE_HOST,
1028
- "DEMO": DEMO_PRIVATE_HOST,
1029
- }
1030
-
1031
- host: str = server_hosts[server]
1032
- self.init(host, proxy_host, proxy_port, 20)
1033
-
1034
- self.start()
1310
+ self.connect_(server, proxy_host, proxy_port)
1035
1311
 
1036
1312
  def on_connected(self) -> None:
1037
1313
  """
@@ -1041,53 +1317,9 @@ class PrivateApi(WebsocketClient):
1041
1317
  is successfully established. It logs the connection status and
1042
1318
  initiates the login process.
1043
1319
  """
1044
- self.gateway.write_log("Private websocket API connected")
1320
+ super().on_connected()
1045
1321
  self.login()
1046
1322
 
1047
- def on_disconnected(self) -> None:
1048
- """
1049
- Callback when server is disconnected.
1050
-
1051
- This function is called when the websocket connection is closed.
1052
- It logs the disconnection status.
1053
- """
1054
- self.gateway.write_log("Private API disconnected")
1055
-
1056
- def on_packet(self, packet: dict) -> None:
1057
- """
1058
- Callback of data update.
1059
-
1060
- This function processes different types of private data updates,
1061
- including orders, account balance, and positions. It routes the data
1062
- to the appropriate callback function.
1063
-
1064
- Parameters:
1065
- packet: JSON data received from websocket
1066
- """
1067
- if "event" in packet:
1068
- cb_name: str = packet["event"]
1069
- elif "op" in packet:
1070
- cb_name = packet["op"]
1071
- else:
1072
- cb_name = packet["arg"]["channel"]
1073
-
1074
- callback: Callable | None = self.callbacks.get(cb_name, None)
1075
- if callback:
1076
- callback(packet)
1077
-
1078
- def on_error(self, e: Exception) -> None:
1079
- """
1080
- General error callback.
1081
-
1082
- This function is called when an exception occurs in the websocket connection.
1083
- It logs the exception details for troubleshooting.
1084
-
1085
- Parameters:
1086
- e: The exception that was raised
1087
- """
1088
- msg: str = f"Private channel exception triggered: {e}"
1089
- self.gateway.write_log(msg)
1090
-
1091
1323
  def on_api_error(self, packet: dict) -> None:
1092
1324
  """
1093
1325
  Callback of API error.
@@ -1103,7 +1335,7 @@ class PrivateApi(WebsocketClient):
1103
1335
  msg: str = packet["msg"]
1104
1336
 
1105
1337
  # Log the error with details for debugging
1106
- self.gateway.write_log(f"Private API request failed, status code: {code}, message: {msg}")
1338
+ self.gateway.write_log(f"{self.name} request failed, status code: {code}, message: {msg}")
1107
1339
 
1108
1340
  def on_login(self, packet: dict) -> None:
1109
1341
  """
@@ -1116,10 +1348,10 @@ class PrivateApi(WebsocketClient):
1116
1348
  packet: Login response data from websocket
1117
1349
  """
1118
1350
  if packet["code"] == '0':
1119
- self.gateway.write_log("Private API login successful")
1351
+ self.gateway.write_log(f"{self.name} login successful")
1120
1352
  self.subscribe_topic()
1121
1353
  else:
1122
- self.gateway.write_log("Private API login failed")
1354
+ self.gateway.write_log(f"{self.name} login failed")
1123
1355
 
1124
1356
  def on_order(self, packet: dict) -> None:
1125
1357
  """
@@ -1288,6 +1520,9 @@ class PrivateApi(WebsocketClient):
1288
1520
  This function prepares and sends a login request to authenticate
1289
1521
  with the websocket API using API credentials.
1290
1522
  """
1523
+ if not self.key:
1524
+ return
1525
+
1291
1526
  timestamp: str = str(time.time())
1292
1527
  msg: str = timestamp + "GET" + "/users/self/verify"
1293
1528
  signature: bytes = generate_signature(msg, self.secret)
@@ -1361,7 +1596,7 @@ class PrivateApi(WebsocketClient):
1361
1596
  orderid = f"{self.connect_time}{count_str}"
1362
1597
 
1363
1598
  # Prepare order parameters for OKX API
1364
- args: dict = {
1599
+ arg: dict = {
1365
1600
  "instId": contract.name,
1366
1601
  "clOrdId": orderid,
1367
1602
  "side": DIRECTION_VT2OKX[req.direction],
@@ -1373,16 +1608,16 @@ class PrivateApi(WebsocketClient):
1373
1608
  # Set trading mode based on product type
1374
1609
  # "cash" for spot trading, "cross" for futures/swap with cross margin
1375
1610
  if contract.product == Product.SPOT:
1376
- args["tdMode"] = "cash"
1611
+ arg["tdMode"] = "cash"
1377
1612
  else:
1378
- args["tdMode"] = "cross"
1613
+ arg["tdMode"] = "cross"
1379
1614
 
1380
1615
  # Create websocket request with unique request ID
1381
1616
  self.reqid += 1
1382
1617
  packet: dict = {
1383
1618
  "id": str(self.reqid),
1384
1619
  "op": "order",
1385
- "args": [args]
1620
+ "args": [arg]
1386
1621
  }
1387
1622
  self.send_packet(packet)
1388
1623
 
@@ -1409,24 +1644,510 @@ class PrivateApi(WebsocketClient):
1409
1644
  self.gateway.write_log(f"Cancel order failed, symbol not found: {req.symbol}")
1410
1645
  return
1411
1646
 
1412
- # Initialize cancel parameters with instrument ID
1413
- args: dict = {"instId": contract.name}
1647
+ # Initialize cancel parameters
1648
+ arg: dict = {}
1414
1649
 
1415
1650
  # Determine the type of order ID to use for cancellation
1416
1651
  # OKX supports both client order ID and exchange order ID for cancellation
1417
1652
  if req.orderid in self.local_orderids:
1418
1653
  # Use client order ID if it was created by this gateway instance
1419
- args["clOrdId"] = req.orderid
1654
+ arg["clOrdId"] = req.orderid
1420
1655
  else:
1421
1656
  # Use exchange order ID if it came from another source
1422
- args["ordId"] = req.orderid
1657
+ arg["ordId"] = req.orderid
1423
1658
 
1424
1659
  # Create websocket request with unique request ID
1425
1660
  self.reqid += 1
1426
1661
  packet: dict = {
1427
1662
  "id": str(self.reqid),
1428
1663
  "op": "cancel-order",
1429
- "args": [args]
1664
+ "args": [arg]
1665
+ }
1666
+
1667
+ # Send the cancellation request
1668
+ self.send_packet(packet)
1669
+
1670
+
1671
+ class BusinessApi(WebsocketApi):
1672
+ """The business websocket API of OkxGateway"""
1673
+
1674
+ def __init__(self, gateway: OkxGateway) -> None:
1675
+ """
1676
+ The init method of the api.
1677
+
1678
+ Parameters:
1679
+ gateway: the parent gateway object for pushing callback data.
1680
+ """
1681
+ super().__init__(gateway, "Business API")
1682
+
1683
+ self.local_orderids: set[str] = gateway.local_orderids
1684
+ self.subscribed: dict[str, SubscribeRequest] = {}
1685
+ self.ticks: dict[str, TickData] = {}
1686
+
1687
+ self.key: str = ""
1688
+ self.secret: bytes = b""
1689
+ self.passphrase: str = ""
1690
+
1691
+ self.reqid: int = 0
1692
+ self.order_count: int = 0
1693
+ self.connect_time: int = 0
1694
+
1695
+ self.callbacks: dict[str, Callable] = {
1696
+ "login": self.on_login,
1697
+ "sprd-orders": self.on_order,
1698
+ "sprd-trades": self.on_trade,
1699
+ "sprd-tickers": self.on_ticker,
1700
+ "sprd-books5": self.on_depth,
1701
+ "order": self.on_send_order,
1702
+ "cancel-order": self.on_cancel_order,
1703
+ "error": self.on_api_error
1704
+ }
1705
+
1706
+ self.server_hosts: dict[str, str] = {
1707
+ "REAL": REAL_BUSINESS_HOST,
1708
+ "AWS": AWS_BUSINESS_HOST,
1709
+ "DEMO": DEMO_BUSINESS_HOST,
1710
+ }
1711
+
1712
+ self.reqid_order_map: dict[str, OrderData] = {}
1713
+
1714
+ def connect(
1715
+ self,
1716
+ key: str,
1717
+ secret: str,
1718
+ passphrase: str,
1719
+ server: str,
1720
+ proxy_host: str,
1721
+ proxy_port: int,
1722
+ ) -> None:
1723
+ """
1724
+ Start server connection.
1725
+
1726
+ This method establishes a websocket connection to OKX private data stream.
1727
+
1728
+ Parameters:
1729
+ key: API Key for authentication
1730
+ secret: API Secret for request signing
1731
+ passphrase: API Passphrase for authentication
1732
+ server: Server type ("REAL", "AWS", or "DEMO")
1733
+ proxy_host: Proxy server hostname or IP
1734
+ proxy_port: Proxy server port
1735
+ """
1736
+ self.key = key
1737
+ self.secret = secret.encode()
1738
+ self.passphrase = passphrase
1739
+
1740
+ self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
1741
+
1742
+ self.connect_(server, proxy_host, proxy_port)
1743
+
1744
+ def subscribe(self, req: SubscribeRequest) -> None:
1745
+ """
1746
+ Subscribe to market data.
1747
+
1748
+ This function sends subscription requests for ticker and depth data
1749
+ for the specified trading instrument.
1750
+
1751
+ Parameters:
1752
+ req: Subscription request object containing symbol information
1753
+ """
1754
+ # Get contract by VeighNa symbol
1755
+ contract: ContractData | None = self.gateway.get_contract_by_symbol(req.symbol)
1756
+ if not contract:
1757
+ self.gateway.write_log(f"Failed to subscribe data, symbol not found: {req.symbol}")
1758
+ return
1759
+
1760
+ # Add subscribe record
1761
+ self.subscribed[req.vt_symbol] = req
1762
+
1763
+ # Create tick object
1764
+ tick: TickData = TickData(
1765
+ symbol=req.symbol,
1766
+ exchange=req.exchange,
1767
+ name=contract.name,
1768
+ datetime=datetime.now(CHINA_TZ),
1769
+ gateway_name=self.gateway_name,
1770
+ )
1771
+ self.ticks[contract.name] = tick
1772
+
1773
+ # Send request to subscribe
1774
+ args: list = []
1775
+ for channel in ["sprd-tickers", "sprd-books5"]:
1776
+ args.append({
1777
+ "channel": channel,
1778
+ "sprdId": contract.name
1779
+ })
1780
+
1781
+ packet: dict = {
1782
+ "op": "subscribe",
1783
+ "args": args
1784
+ }
1785
+
1786
+ self.send_packet(packet)
1787
+
1788
+ def on_connected(self) -> None:
1789
+ """
1790
+ Callback when server is connected.
1791
+
1792
+ This function is called when the websocket connection to the server
1793
+ is successfully established. It logs the connection status and
1794
+ initiates the login process.
1795
+ """
1796
+ super().on_connected()
1797
+
1798
+ for req in list(self.subscribed.values()):
1799
+ self.subscribe(req)
1800
+
1801
+ self.login()
1802
+
1803
+ def on_api_error(self, packet: dict) -> None:
1804
+ """
1805
+ Callback of API error.
1806
+
1807
+ This function processes error responses from the websocket API.
1808
+ It logs the error details for troubleshooting.
1809
+
1810
+ Parameters:
1811
+ packet: Error data from websocket
1812
+ """
1813
+ # Extract error code and message from the response
1814
+ code: str = packet["code"]
1815
+ msg: str = packet["msg"]
1816
+
1817
+ # Log the error with details for debugging
1818
+ self.gateway.write_log(f"{self.name} request failed, status code: {code}, message: {msg}")
1819
+
1820
+ def on_login(self, packet: dict) -> None:
1821
+ """
1822
+ Callback of user login.
1823
+
1824
+ This function processes the login response and subscribes to
1825
+ private data channels if login is successful.
1826
+
1827
+ Parameters:
1828
+ packet: Login response data from websocket
1829
+ """
1830
+ if packet["code"] == '0':
1831
+ self.gateway.write_log(f"{self.name} login successful")
1832
+ self.subscribe_topic()
1833
+ else:
1834
+ self.gateway.write_log(f"{self.name} login failed")
1835
+
1836
+ def on_order(self, packet: dict) -> None:
1837
+ """
1838
+ Callback of order update.
1839
+
1840
+ This function processes order updates and trade executions.
1841
+ It creates OrderData and TradeData objects and pushes them to the gateway.
1842
+
1843
+ Parameters:
1844
+ packet: Order update data from websocket
1845
+ """
1846
+ # Extract order data from packet
1847
+ data: list = packet["data"]
1848
+ for d in data:
1849
+ # Create order object from data
1850
+ order: OrderData = self.gateway.parse_spread_order_data(d, self.gateway_name)
1851
+ self.gateway.on_order(order)
1852
+
1853
+ def on_trade(self, packet: dict) -> None:
1854
+ """
1855
+ Callback of trade update.
1856
+
1857
+ This function processes trade updates and creates TradeData objects.
1858
+
1859
+ Parameters:
1860
+ packet: Order update data from websocket
1861
+ """
1862
+ # Extract trade data from packet
1863
+ for d in packet["data"]:
1864
+ # Get order id
1865
+ if d["clOrdId"]:
1866
+ order_id: str = d["clOrdId"]
1867
+ else:
1868
+ order_id = d["ordId"]
1869
+
1870
+ dt: datetime = parse_timestamp(d["ts"])
1871
+
1872
+ for leg in d["legs"]:
1873
+ name: str = leg["instId"]
1874
+ contract: ContractData | None = self.gateway.get_contract_by_name(name)
1875
+ if not contract:
1876
+ self.gateway.write_log(f"Failed to parse trade data, contract not found: {name}")
1877
+ continue
1878
+
1879
+ trade: TradeData = TradeData(
1880
+ symbol=contract.symbol,
1881
+ exchange=Exchange.GLOBAL,
1882
+ orderid=order_id,
1883
+ tradeid=leg["tradeId"],
1884
+ direction=DIRECTION_OKX2VT[leg["side"]],
1885
+ price=float(leg["px"]),
1886
+ volume=float(leg["sz"]),
1887
+ datetime=dt,
1888
+ gateway_name=self.gateway_name,
1889
+ )
1890
+ self.gateway.on_trade(trade)
1891
+
1892
+ def on_send_order(self, packet: dict) -> None:
1893
+ """
1894
+ Callback of send_order.
1895
+
1896
+ This function processes the response to an order placement request.
1897
+ It handles errors and rejection cases.
1898
+
1899
+ Parameters:
1900
+ packet: Order response data from websocket
1901
+ """
1902
+ data: list = packet["data"]
1903
+
1904
+ # Wrong parameters
1905
+ if packet["code"] != "0":
1906
+ if not data:
1907
+ order: OrderData | None = self.reqid_order_map.get(packet["id"], None)
1908
+ if order:
1909
+ order.status = Status.REJECTED
1910
+ self.gateway.on_order(order)
1911
+
1912
+ return
1913
+
1914
+ # Failed to process
1915
+ for d in data:
1916
+ code: str = d["sCode"]
1917
+ if code == "0":
1918
+ return
1919
+
1920
+ orderid: str = d["clOrdId"]
1921
+ order = self.gateway.get_order(orderid)
1922
+ if not order:
1923
+ return
1924
+
1925
+ order.status = Status.REJECTED
1926
+ self.gateway.on_order(copy(order))
1927
+
1928
+ msg: str = d["sMsg"]
1929
+ self.gateway.write_log(f"Send order failed, status code: {code}, message: {msg}")
1930
+
1931
+ def on_cancel_order(self, packet: dict) -> None:
1932
+ """
1933
+ Callback of cancel_order.
1934
+
1935
+ This function processes the response to an order cancellation request.
1936
+ It handles errors and logs appropriate messages.
1937
+
1938
+ Parameters:
1939
+ packet: Cancel response data from websocket
1940
+ """
1941
+ # Wrong parameters
1942
+ if packet["code"] != "0":
1943
+ code: str = packet["code"]
1944
+ msg: str = packet["msg"]
1945
+ self.gateway.write_log(f"Cancel order failed, status code: {code}, message: {msg}")
1946
+ return
1947
+
1948
+ # Failed to process
1949
+ data: list = packet["data"]
1950
+ for d in data:
1951
+ code = d["sCode"]
1952
+ if code == "0":
1953
+ return
1954
+
1955
+ msg = d["sMsg"]
1956
+ self.gateway.write_log(f"Cancel order failed, status code: {code}, message: {msg}")
1957
+
1958
+ def on_ticker(self, packet: dict) -> None:
1959
+ """
1960
+ Callback of ticker update.
1961
+
1962
+ This function processes the ticker data updates and
1963
+ updates the corresponding TickData objects.
1964
+
1965
+ Parameters:
1966
+ packet: Ticker data from websocket
1967
+ """
1968
+ for d in packet["data"]:
1969
+ if not d["last"]:
1970
+ return
1971
+
1972
+ tick: TickData = self.ticks[d["sprdId"]]
1973
+
1974
+ tick.last_price = float(d["last"])
1975
+ tick.last_volume = float(d["lastSz"])
1976
+ tick.open_price = float(d["open24h"])
1977
+ tick.high_price = float(d["high24h"])
1978
+ tick.low_price = float(d["low24h"])
1979
+ tick.volume = float(d["vol24h"])
1980
+ tick.datetime = parse_timestamp(d["ts"])
1981
+
1982
+ self.gateway.on_tick(copy(tick))
1983
+
1984
+ def on_depth(self, packet: dict) -> None:
1985
+ """
1986
+ Callback of depth update.
1987
+
1988
+ This function processes the order book depth data updates
1989
+ and updates the corresponding TickData objects.
1990
+
1991
+ Parameters:
1992
+ packet: Depth data from websocket
1993
+ """
1994
+ name: str = packet["arg"]["sprdId"]
1995
+ tick: TickData = self.ticks[name]
1996
+
1997
+ for d in packet["data"]:
1998
+ bids: list = d["bids"]
1999
+ asks: list = d["asks"]
2000
+
2001
+ for n in range(min(5, len(bids))):
2002
+ price, volume, _ = bids[n]
2003
+ tick.__setattr__("bid_price_%s" % (n + 1), float(price))
2004
+ tick.__setattr__("bid_volume_%s" % (n + 1), float(volume))
2005
+
2006
+ for n in range(min(5, len(asks))):
2007
+ price, volume, _ = asks[n]
2008
+ tick.__setattr__("ask_price_%s" % (n + 1), float(price))
2009
+ tick.__setattr__("ask_volume_%s" % (n + 1), float(volume))
2010
+
2011
+ tick.datetime = parse_timestamp(d["ts"])
2012
+ self.gateway.on_tick(copy(tick))
2013
+
2014
+ def login(self) -> None:
2015
+ """
2016
+ User login.
2017
+
2018
+ This function prepares and sends a login request to authenticate
2019
+ with the websocket API using API credentials.
2020
+ """
2021
+ if not self.key:
2022
+ return
2023
+
2024
+ timestamp: str = str(time.time())
2025
+ msg: str = timestamp + "GET" + "/users/self/verify"
2026
+ signature: bytes = generate_signature(msg, self.secret)
2027
+
2028
+ packet: dict = {
2029
+ "op": "login",
2030
+ "args":
2031
+ [
2032
+ {
2033
+ "apiKey": self.key,
2034
+ "passphrase": self.passphrase,
2035
+ "timestamp": timestamp,
2036
+ "sign": signature.decode("utf-8")
2037
+ }
2038
+ ]
2039
+ }
2040
+ self.send_packet(packet)
2041
+
2042
+ def subscribe_topic(self) -> None:
2043
+ """
2044
+ Subscribe to private data channels.
2045
+
2046
+ This function sends subscription requests for order, account, and
2047
+ position updates after successful login.
2048
+ """
2049
+ packet: dict = {
2050
+ "op": "subscribe",
2051
+ "args": [
2052
+ {
2053
+ "channel": "sprd-orders"
2054
+ },
2055
+ {
2056
+ "channel": "sprd-trades"
2057
+ }
2058
+ ]
2059
+ }
2060
+ self.send_packet(packet)
2061
+
2062
+ def send_order(self, req: OrderRequest) -> str:
2063
+ """
2064
+ Send new order to OKX.
2065
+
2066
+ This function creates and sends a new order request to the exchange.
2067
+ It handles different order types and trading modes.
2068
+
2069
+ Parameters:
2070
+ req: Order request object containing order details
2071
+
2072
+ Returns:
2073
+ str: The VeighNa order ID if successful, empty string otherwise
2074
+ """
2075
+ # Validate order type is supported by OKX
2076
+ if req.type not in ORDERTYPE_VT2OKX:
2077
+ self.gateway.write_log(f"Send order failed, order type not supported: {req.type.value}")
2078
+ return ""
2079
+
2080
+ # Validate symbol exists in contract map
2081
+ contract: ContractData | None = self.gateway.get_contract_by_symbol(req.symbol)
2082
+ if not contract:
2083
+ self.gateway.write_log(f"Send order failed, symbol not found: {req.symbol}")
2084
+ return ""
2085
+
2086
+ # Generate unique local order ID
2087
+ self.order_count += 1
2088
+ count_str = str(self.order_count).rjust(6, "0")
2089
+ orderid = f"{self.connect_time}{count_str}"
2090
+
2091
+ # Prepare order parameters for OKX API
2092
+ arg: dict = {
2093
+ "sprdId": contract.name,
2094
+ "clOrdId": orderid,
2095
+ "side": DIRECTION_VT2OKX[req.direction],
2096
+ "ordType": ORDERTYPE_VT2OKX[req.type],
2097
+ "px": str(req.price),
2098
+ "sz": str(req.volume)
2099
+ }
2100
+
2101
+ # Create websocket request with unique request ID
2102
+ self.reqid += 1
2103
+ packet: dict = {
2104
+ "id": str(self.reqid),
2105
+ "op": "sprd-order",
2106
+ "args": [arg]
2107
+ }
2108
+ self.send_packet(packet)
2109
+
2110
+ # Create order data object and push to gateway
2111
+ order: OrderData = req.create_order_data(orderid, self.gateway_name)
2112
+ self.gateway.on_order(order)
2113
+
2114
+ # Return VeighNa order ID (gateway_name.orderid)
2115
+ return str(order.vt_orderid)
2116
+
2117
+ def cancel_order(self, req: CancelRequest) -> None:
2118
+ """
2119
+ Cancel existing order on OKX.
2120
+
2121
+ This function sends a request to cancel an existing order on the exchange.
2122
+ It determines whether to use client order ID or exchange order ID.
2123
+
2124
+ Parameters:
2125
+ req: Cancel request object containing order details
2126
+ """
2127
+ # Validate symbol exists in contract map
2128
+ contract: ContractData | None = self.gateway.get_contract_by_symbol(req.symbol)
2129
+ if not contract:
2130
+ self.gateway.write_log(f"Cancel order failed, symbol not found: {req.symbol}")
2131
+ return
2132
+
2133
+ # Initialize cancel parameters
2134
+ arg: dict = {}
2135
+
2136
+ # Determine the type of order ID to use for cancellation
2137
+ # OKX supports both client order ID and exchange order ID for cancellation
2138
+ if req.orderid in self.local_orderids:
2139
+ # Use client order ID if it was created by this gateway instance
2140
+ arg["clOrdId"] = req.orderid
2141
+ else:
2142
+ # Use exchange order ID if it came from another source
2143
+ arg["ordId"] = req.orderid
2144
+
2145
+ # Create websocket request with unique request ID
2146
+ self.reqid += 1
2147
+ packet: dict = {
2148
+ "id": str(self.reqid),
2149
+ "op": "sprd-cancel-order",
2150
+ "args": [arg]
1430
2151
  }
1431
2152
 
1432
2153
  # Send the cancellation request