tristero 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tristero/client.py CHANGED
@@ -1,30 +1,61 @@
1
1
  import asyncio
2
2
  from dataclasses import dataclass
3
+ from enum import Enum
3
4
  import json
4
5
  import logging
5
- from typing import Any, Optional, TypeVar
6
+ import os
7
+ import ssl
8
+ from typing import Any, Dict, List, Literal, Optional, TypeVar, Union
6
9
 
7
10
  from eth_account.signers.local import LocalAccount
8
11
  from pydantic import BaseModel
9
12
 
10
- from tristero.api import fill_order, fill_order_feather, get_quote, poll_updates, poll_updates_feather
11
- from .permit2 import Permit2Order, create_permit2_order
13
+ from tristero.api import (
14
+ fill_order,
15
+ fill_order_feather,
16
+ get_quote,
17
+ poll_updates,
18
+ poll_updates_feather,
19
+ get_margin_quote,
20
+ submit_margin_order,
21
+ submit_close_position,
22
+ get_margin_positions,
23
+ get_margin_position,
24
+ poll_margin_updates,
25
+ )
26
+ from .permit2 import (
27
+ Permit2Order,
28
+ create_permit2_order,
29
+ sign_margin_order,
30
+ sign_close_position,
31
+ )
12
32
  from .data import ChainID
13
33
  from web3 import AsyncBaseProvider, AsyncWeb3
14
- from web3 import AsyncWeb3
15
34
  from tenacity import (
16
35
  retry,
17
36
  stop_after_attempt,
18
37
  wait_exponential,
19
38
  retry_if_exception_type,
20
39
  )
40
+
41
+ import certifi
42
+
21
43
  logger = logging.getLogger(__name__)
22
44
 
23
45
  P = TypeVar("P", bound=AsyncBaseProvider)
24
46
 
47
+
48
+ class OrderType(str, Enum):
49
+ """Order type enum for unified API."""
50
+ SWAP = "swap"
51
+ FEATHER = "feather"
52
+ MARGIN = "margin"
53
+
54
+
25
55
  class WebSocketClosedError(Exception):
26
56
  pass
27
57
 
58
+
28
59
  class SwapException(Exception):
29
60
  """Base exception for all swap-related errors."""
30
61
  pass
@@ -42,10 +73,70 @@ class OrderFailedException(SwapException):
42
73
  self.order_id = order_id
43
74
  self.details = details
44
75
 
76
+
77
class MarginException(Exception):
    """Exception for margin-related errors.

    Raised when a margin order reports failure or when a close-position
    response carries no usable order id.
    """
80
+
81
+
45
82
  class TokenSpec(BaseModel, frozen=True):
46
83
  chain_id: ChainID
47
84
  token_address: str
48
85
 
86
+
87
def make_async_w3(rpc_url: str) -> AsyncWeb3:
    """Create an ``AsyncWeb3`` client for *rpc_url* with explicit TLS settings.

    Environment variables:
        TRISTERO_INSECURE_SSL: if "1", "true", "yes", "y" or "on"
            (case-insensitive, surrounding whitespace ignored), certificate
            verification is disabled by passing ``ssl=False`` to the provider.
        TRISTERO_SSL_CA_FILE: CA bundle used to verify the RPC endpoint.
            Falls back to ``certifi.where()`` when unset or blank.

    Args:
        rpc_url: HTTP(S) JSON-RPC endpoint.

    Returns:
        An ``AsyncWeb3`` instance backed by ``AsyncWeb3.AsyncHTTPProvider``.

    Raises:
        Whatever ``ssl.create_default_context`` raises for an unreadable or
        invalid CA file (e.g. ``FileNotFoundError`` or ``ssl.SSLError``).
    """
    insecure_ssl = os.getenv("TRISTERO_INSECURE_SSL", "").strip().lower() in {"1", "true", "yes", "y", "on"}

    ssl_option: Union[bool, ssl.SSLContext]
    if insecure_ssl:
        # Explicit opt-in only; verification is on by default.
        ssl_option = False
    else:
        # certifi.where() always returns a bundle path, so a verifying context
        # is always built; there is no unconfigured-provider case to fall into.
        ca_file = os.getenv("TRISTERO_SSL_CA_FILE", "").strip() or certifi.where()
        ssl_option = ssl.create_default_context(cafile=ca_file)

    provider = AsyncWeb3.AsyncHTTPProvider(rpc_url, request_kwargs={"ssl": ssl_option})
    return AsyncWeb3(provider)
104
+
105
+
106
@dataclass
class QuoteResult:
    """Result from get_quote operations.

    Attributes:
        quote_data: Raw quote payload returned by the quoting API.
        order_type: Order kind tag. ``request_margin_quote`` sets ``"MARGIN"``.
            The ``is_*`` checks are case-insensitive, so lowercase values such
            as ``"margin"`` (the ``OrderType`` enum's values) are recognised too.
    """
    quote_data: Dict[str, Any]
    order_type: str

    def _has_type(self, kind: str) -> bool:
        # Normalise case before comparing; non-string tags never match.
        return isinstance(self.order_type, str) and self.order_type.upper() == kind

    @property
    def is_margin(self) -> bool:
        """True when this quote is for a margin order."""
        return self._has_type("MARGIN")

    @property
    def is_feather(self) -> bool:
        """True when this quote is for a feather order."""
        return self._has_type("FEATHER")
119
+
120
+
121
@dataclass
class OrderResult:
    """Outcome of submitting an order.

    Attributes:
        order_id: Identifier extracted from the submission response.
        order_type: Order kind tag (``submit_order`` uses ``"MARGIN"``).
        response: Raw response dict returned by the submit call.
    """

    order_id: str
    order_type: str
    response: Dict[str, Any]
127
+
128
+
129
@dataclass
class MarginPosition:
    """A margin position as reported by the positions API.

    The typed attributes are convenience copies of common record keys;
    ``data`` keeps the full raw record.
    """

    id: str  # position identifier (stringified)
    status: str
    escrow_address: str
    filler_address: str
    taker_token_id: int  # the "position_id" close_margin_position expects
    data: Dict[str, Any]  # full raw position record
138
+
139
+
49
140
  async def wait_for_feather_completion(order_id: str):
50
141
  ws = await poll_updates_feather(order_id)
51
142
  try:
@@ -63,67 +154,159 @@ async def wait_for_feather_completion(order_id: str):
63
154
  )
64
155
  if status in ['Expired']:
65
156
  await ws.close()
66
- raise Exception(f"Swap failed: {ws.close_reason} {msg}")
157
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
67
158
  elif status in ['Finalized']:
68
159
  await ws.close()
69
160
  return msg
70
- # If we exit the loop without completed/failed, raise to retry
71
161
  raise WebSocketClosedError("WebSocket closed without completion status")
72
162
  except Exception as e:
73
- # Close cleanly if still open
74
163
  if not ws.close_code:
75
164
  await ws.close()
76
165
  raise
77
166
 
167
+
78
168
  async def wait_for_permit2_completion(order_id: str):
79
169
  ws = await poll_updates(order_id)
80
170
  try:
81
171
  async for msg in ws:
82
172
  msg = json.loads(msg)
173
+ if isinstance(msg, dict) and ("failed" in msg or "completed" in msg):
174
+ logger.info(
175
+ {
176
+ "message": f"failed={msg.get('failed')} completed={msg.get('completed')}",
177
+ "id": "order_update",
178
+ "payload": msg,
179
+ }
180
+ )
181
+ if msg.get("failed"):
182
+ await ws.close()
183
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
184
+ elif msg.get("completed"):
185
+ await ws.close()
186
+ return msg
187
+ continue
188
+
189
+ if isinstance(msg, dict):
190
+ fill_tx = msg.get("fill_tx") or msg.get("fillTx")
191
+ if fill_tx:
192
+ await ws.close()
193
+ return msg
194
+
195
+ if msg.get("amount_repaid") is not None or msg.get("amount_settled") is not None:
196
+ await ws.close()
197
+ return msg
198
+
199
+ status = msg.get("status") if isinstance(msg, dict) else None
83
200
  logger.info(
84
201
  {
85
- "message": f"failed={msg['failed']} completed={msg['completed']}",
202
+ "message": f"status={status}",
86
203
  "id": "order_update",
87
204
  "payload": msg,
88
205
  }
89
206
  )
90
- if msg["failed"]:
207
+ if status in ["Failed", "Expired", "Rejected"]:
91
208
  await ws.close()
92
- raise Exception(f"Swap failed: {ws.close_reason} {msg}")
93
- elif msg["completed"]:
209
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
210
+ elif status in ["Finalized", "Filled", "Completed"]:
94
211
  await ws.close()
95
212
  return msg
96
213
 
97
- # If we exit the loop without completed/failed, raise to retry
98
214
  raise WebSocketClosedError("WebSocket closed without completion status")
99
215
  except Exception:
100
- # Close cleanly if still open
101
216
  if not ws.close_code:
102
217
  await ws.close()
103
218
  raise
104
219
 
105
- async def wait_for_completion(order_id: str, feather: bool):
106
- if feather:
220
+
221
async def wait_for_margin_completion(order_id: str):
    """Wait for margin order completion via WebSocket.

    Recognised update shapes:

    * ``{"failed": ..., "completed": ...}`` flags: ``failed`` raises,
      ``completed`` returns, and a message with neither set is skipped;
    * ``{"status": ...}``: Failed/Expired/Rejected raise,
      Filled/Finalized/Completed return.

    Args:
        order_id: Margin order to follow.

    Returns:
        The decoded message that signalled completion.

    Raises:
        MarginException: the order reported failure.
        WebSocketClosedError: the stream ended without a terminal update.
    """
    ws = await poll_margin_updates(order_id)
    try:
        async for raw in ws:
            if not raw:
                # Empty frames carry no update.
                continue
            msg = json.loads(raw)

            if isinstance(msg, dict) and ("failed" in msg or "completed" in msg):
                logger.info(
                    {
                        "message": f"failed={msg.get('failed')} completed={msg.get('completed')}",
                        "id": "margin_order_update",
                        "payload": msg,
                    }
                )
                if msg.get("failed"):
                    await ws.close()
                    raise MarginException(f"Margin order failed: {msg}")
                if msg.get("completed"):
                    await ws.close()
                    return msg
                continue

            status = msg.get('status', '') if isinstance(msg, dict) else ''
            logger.info(
                {
                    "message": f"margin status={status}",
                    "id": "margin_order_update",
                    "payload": msg,
                }
            )
            if status in ['Failed', 'Expired', 'Rejected']:
                await ws.close()
                raise MarginException(f"Margin order failed: {msg}")
            elif status in ['Filled', 'Finalized', 'Completed']:
                await ws.close()
                return msg
        raise WebSocketClosedError("WebSocket closed without completion status")
    finally:
        # `finally` rather than `except Exception`: asyncio.CancelledError is a
        # BaseException, and it is what asyncio.wait_for delivers on timeout.
        if not ws.close_code:
            await ws.close()
264
+
265
+
266
async def wait_for_completion(
    order_id: str,
    order_type: Union[str, OrderType, bool] = OrderType.SWAP,
    *,
    feather: Optional[bool] = None,
):
    """
    Wait for order completion.

    Args:
        order_id: Order ID to wait for.
        order_type: Type of order. Pass an ``OrderType`` member or its string
            value ("swap", "feather", "margin"; case-insensitive).
            Unrecognised strings are waited on as swaps. A bool is accepted
            for compatibility with the pre-0.3 positional ``feather`` argument
            (True -> feather, False -> swap).
        feather: Deprecated pre-0.3 keyword. When given, it overrides
            *order_type*.

    Returns:
        The terminal update message returned by the matching waiter.
    """
    if feather is not None:
        order_type = feather

    # Legacy bool flag from the pre-0.3 `feather: bool` signature. Without this
    # branch it would reach `.value` below and raise AttributeError.
    if isinstance(order_type, bool):
        kind = "feather" if order_type else "swap"
    elif isinstance(order_type, str):  # also covers OrderType (str mixin)
        kind = order_type.lower()
    else:
        kind = order_type.value

    if kind == "feather":
        return await wait_for_feather_completion(order_id)
    if kind == "margin":
        return await wait_for_margin_completion(order_id)
    return await wait_for_permit2_completion(order_id)
110
285
 
286
+
111
287
  @retry(
112
288
  stop=stop_after_attempt(3),
113
289
  wait=wait_exponential(multiplier=1, min=4, max=10),
114
290
  retry=retry_if_exception_type(WebSocketClosedError)
115
291
  )
116
- async def wait_for_completion_with_retry(order_id: str, feather: bool = False):
117
- return await wait_for_completion(order_id, feather)
292
+ async def wait_for_completion_with_retry(
293
+ order_id: str,
294
+ order_type: Union[str, OrderType] = OrderType.SWAP
295
+ ):
296
+ """Wait for order completion with automatic retry on WebSocket failures."""
297
+ return await wait_for_completion(order_id, order_type)
298
+
118
299
 
119
300
  @dataclass
120
301
  class FeatherSwapResult:
121
302
  deposit_address: str
122
303
  data: Any
123
304
 
305
+
124
306
  class FeatherException(Exception):
125
307
  pass
126
308
 
309
+
127
310
  async def start_feather_swap(
128
311
  src_t: TokenSpec,
129
312
  dst_t: TokenSpec,
@@ -131,15 +314,16 @@ async def start_feather_swap(
131
314
  raw_amount: int,
132
315
  client_id: str = ''
133
316
  ):
134
- resp = await fill_order_feather(client_id, str(src_t.chain_id.value), str(dst_t.chain_id), dst_addr, raw_amount)
317
+ resp = await fill_order_feather(str(src_t.chain_id.value), str(dst_t.chain_id.value), dst_addr, raw_amount, client_id)
135
318
  if resp['detail']:
136
319
  raise FeatherException(resp)
137
320
  else:
138
321
  return FeatherSwapResult(
139
- deposit_address = resp['deposit_address'],
140
- data = resp
322
+ deposit_address=resp['deposit_address'],
323
+ data=resp
141
324
  )
142
325
 
326
+
143
327
  async def start_permit2_swap(
144
328
  w3: AsyncWeb3[P],
145
329
  account: LocalAccount,
@@ -164,6 +348,7 @@ async def start_permit2_swap(
164
348
  order_id = response['id']
165
349
  return order_id
166
350
 
351
+
167
352
  async def execute_permit2_swap(
168
353
  w3: AsyncWeb3[P],
169
354
  account: LocalAccount,
@@ -181,12 +366,303 @@ async def execute_permit2_swap(
181
366
 
182
367
  try:
183
368
  if timeout is None:
184
- return await waiter(order_id, False)
369
+ return await waiter(order_id, OrderType.SWAP)
185
370
 
186
371
  return await asyncio.wait_for(
187
- waiter(order_id),
372
+ waiter(order_id, OrderType.SWAP),
188
373
  timeout=timeout
189
374
  )
190
375
  except asyncio.TimeoutError as exc:
191
376
  raise StuckException(f"Swap {order_id} timed out after {timeout}s") from exc
192
377
 
378
+
379
async def execute_swap(
    w3: AsyncWeb3[P],
    account: LocalAccount,
    src_t: TokenSpec,
    dst_t: TokenSpec,
    raw_amount: int,
    retry: bool = True,
    timeout: Optional[float] = None,
) -> dict[str, Any]:
    """Execute a permit2 swap; a thin alias of ``execute_permit2_swap``.

    Every argument is forwarded unchanged by keyword; see
    ``execute_permit2_swap`` for their meaning.
    """
    forwarded = dict(
        w3=w3,
        account=account,
        src_t=src_t,
        dst_t=dst_t,
        raw_amount=raw_amount,
        retry=retry,
        timeout=timeout,
    )
    return await execute_permit2_swap(**forwarded)
397
+
398
+
399
async def request_margin_quote(
    chain_id: str,
    wallet_address: str,
    quote_currency: str,
    base_currency: str,
    leverage_ratio: int,
    collateral_amount: str,
) -> QuoteResult:
    """Fetch a quote for opening a margin position.

    Args:
        chain_id: Chain ID (e.g., "42161" for Arbitrum).
        wallet_address: User's wallet address.
        quote_currency: Quote currency token address (e.g., USDC).
        base_currency: Base currency token address (e.g., WETH).
        leverage_ratio: Leverage ratio (e.g., 2 for 2x leverage).
        collateral_amount: Collateral amount in raw units.

    Returns:
        QuoteResult wrapping the raw quote, tagged ``order_type="MARGIN"``.
    """
    quote_request = dict(
        chain_id=chain_id,
        wallet_address=wallet_address,
        quote_currency=quote_currency,
        base_currency=base_currency,
        leverage_ratio=leverage_ratio,
        collateral_amount=collateral_amount,
    )
    raw_quote = await get_margin_quote(**quote_request)
    return QuoteResult(quote_data=raw_quote, order_type="MARGIN")
433
+
434
+
435
def sign_order(
    quote: QuoteResult,
    private_key: str,
) -> Dict[str, Any]:
    """Sign a quote so it can be submitted as an order.

    Only margin quotes are supported at present.

    Args:
        quote: Quote result from request_margin_quote.
        private_key: Private key used for signing.

    Returns:
        The signed order payload produced by ``sign_margin_order``.

    Raises:
        ValueError: the quote is not a margin quote.
    """
    if not quote.is_margin:
        raise ValueError(f"Unsupported order type for signing: {quote.order_type}")
    return sign_margin_order(quote.quote_data, private_key)
453
+
454
+
455
async def submit_order(signed_order: Dict[str, Any]) -> OrderResult:
    """
    Submit a signed margin order.

    Args:
        signed_order: Signed order from sign_order.

    Returns:
        OrderResult holding the order id (``""`` if the response carries
        none), order_type ``"MARGIN"`` and the raw response.
    """
    response = await submit_margin_order(signed_order)
    # Accept the same id spellings close_margin_position does. Chaining with
    # `or` (rather than .get(key, default)) also falls through a present-but-null
    # "id" instead of returning None for a str field.
    order_id = (
        response.get("id")
        or response.get("order_id")
        or response.get("orderId")
        or response.get("orderID")
        or ""
    )
    return OrderResult(
        order_id=order_id,
        order_type="MARGIN",
        response=response,
    )
471
+
472
+
473
async def open_margin_position(
    private_key: str,
    chain_id: str,
    wallet_address: str,
    quote_currency: str,
    base_currency: str,
    leverage_ratio: int,
    collateral_amount: str,
    wait_for_result: bool = True,
    timeout: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Open a margin position in one call (quote + sign + submit + optionally wait).

    Args:
        private_key: Private key for signing.
        chain_id: Chain ID.
        wallet_address: User's wallet address.
        quote_currency: Quote currency token address.
        base_currency: Base currency token address.
        leverage_ratio: Leverage ratio.
        collateral_amount: Collateral amount in raw units.
        wait_for_result: Whether to wait for order completion.
        timeout: Optional timeout in seconds for the wait step. ``None`` waits
            indefinitely; any other value (including 0) is enforced.

    Returns:
        Dict with ``order_id``, ``quote`` (raw quote data) and ``response``
        (submission response). When waiting, it also includes ``completion``
        (the terminal update message).

    Raises:
        MarginException: the order reported failure while waiting.
        StuckException: the wait exceeded *timeout*.
    """
    quote = await request_margin_quote(
        chain_id=chain_id,
        wallet_address=wallet_address,
        quote_currency=quote_currency,
        base_currency=base_currency,
        leverage_ratio=leverage_ratio,
        collateral_amount=collateral_amount,
    )

    signed = sign_order(quote, private_key)
    result = await submit_order(signed)

    logger.info("Margin order submitted: %s", result.order_id)

    outcome: Dict[str, Any] = {
        "order_id": result.order_id,
        "quote": quote.quote_data,
        "response": result.response,
    }
    if not wait_for_result:
        return outcome

    waiter = wait_for_completion_with_retry(result.order_id, OrderType.MARGIN)
    try:
        # `is None` (not truthiness) so timeout=0 is honoured, matching
        # execute_permit2_swap.
        if timeout is None:
            outcome["completion"] = await waiter
        else:
            outcome["completion"] = await asyncio.wait_for(waiter, timeout=timeout)
    except asyncio.TimeoutError as exc:
        raise StuckException(f"Margin order {result.order_id} timed out after {timeout}s") from exc
    return outcome
539
+
540
+
541
async def close_margin_position(
    private_key: str,
    chain_id: str,
    position_id: int,
    escrow_contract: str,
    authorized: str,
    cash_settle: bool = False,
    fraction_bps: int = 10_000,
    deadline_seconds: int = 3600,
    wait_for_result: bool = True,
    timeout: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Close a margin position.

    Args:
        private_key: Private key for signing.
        chain_id: Chain ID (must parse with ``int()``).
        position_id: Position ID (NFT token ID / taker_token_id).
        escrow_contract: Escrow contract address.
        authorized: Authorized filler address.
        cash_settle: Whether to cash settle (True) or swap settle (False).
        fraction_bps: Fraction to close in basis points (10000 = 100%).
        deadline_seconds: Deadline in seconds from now.
        wait_for_result: Whether to wait for completion.
        timeout: Optional timeout in seconds for the wait step. ``None`` waits
            indefinitely; any other value (including 0) is enforced.

    Returns:
        Dict with ``order_id`` and ``response``. When waiting, it also
        includes ``completion`` (the terminal update message).

    Raises:
        ValueError: *chain_id* is not an integer string.
        MarginException: the submission response carries no usable order id.
        SwapException: the close order reported failure while waiting.
        StuckException: the wait exceeded *timeout*.
    """
    signed = sign_close_position(
        chain_id=int(chain_id),
        position_id=position_id,
        private_key=private_key,
        escrow_contract=escrow_contract,
        authorized=authorized,
        cash_settle=cash_settle,
        fraction_bps=fraction_bps,
        deadline_seconds=deadline_seconds,
    )

    response = await submit_close_position(signed)

    # The id may be nested under "close_order" or sit at the top level under
    # one of several spellings.
    close_order = response.get("close_order")
    close_order_id = close_order.get("id") if isinstance(close_order, dict) else None
    order_id = (
        close_order_id
        or response.get("id")
        or response.get("order_id")
        or response.get("orderId")
        or response.get("orderID")
        or ""
    )
    if not isinstance(order_id, str) or not order_id:
        raise MarginException(f"Close position response missing order id: {response}")

    logger.info("Close position order submitted: %s", order_id)

    outcome: Dict[str, Any] = {
        "order_id": order_id,
        "response": response,
    }
    if not wait_for_result:
        return outcome

    # Close orders are followed on the swap (permit2) update stream; that waiter
    # also recognises amount_repaid / amount_settled settlement payloads.
    waiter = wait_for_completion_with_retry(order_id, OrderType.SWAP)
    try:
        # `is None` (not truthiness) so timeout=0 is honoured, matching
        # execute_permit2_swap.
        if timeout is None:
            outcome["completion"] = await waiter
        else:
            outcome["completion"] = await asyncio.wait_for(waiter, timeout=timeout)
    except asyncio.TimeoutError as exc:
        raise StuckException(f"Close position {order_id} timed out after {timeout}s") from exc
    return outcome
623
+
624
+
625
def _margin_position_from_dict(p: Dict[str, Any]) -> MarginPosition:
    """Build a MarginPosition from one raw position record returned by the API.

    Missing *or* null keys fall back to ``""`` (``0`` for ``taker_token_id``).
    ``get(key, default)`` alone would let a null through, and ``str(None)``
    would turn a null ``id`` into the string ``"None"``.
    """
    raw_id = p.get("id")
    return MarginPosition(
        id="" if raw_id is None else str(raw_id),
        status=p.get("status") or "",
        escrow_address=p.get("escrow_address") or "",
        filler_address=p.get("filler_address") or "",
        taker_token_id=p.get("taker_token_id") or 0,
        data=p,
    )


async def list_margin_positions(wallet_address: str) -> List[MarginPosition]:
    """
    Get all margin positions for a wallet.

    Args:
        wallet_address: User's wallet address.

    Returns:
        List of MarginPosition objects, in the order the API returned them.
    """
    positions = await get_margin_positions(wallet_address)
    return [_margin_position_from_dict(p) for p in positions]


async def get_position(position_id: str) -> MarginPosition:
    """
    Get a specific margin position.

    Args:
        position_id: Position ID.

    Returns:
        MarginPosition object.
    """
    record = await get_margin_position(position_id)
    return _margin_position_from_dict(record)
668
+
tristero/config.py CHANGED
@@ -1,19 +1,35 @@
1
1
  from contextvars import ContextVar
2
2
 
3
+
3
4
  class Config:
4
5
  filler_url = "https://api.tristero.com/v2/orders"
5
6
  quoter_url = "https://api.tristero.com/v2/quotes"
6
7
  ws_url = "wss://api.tristero.com/v2/orders"
8
+
9
+ margin_quoter_url = "https://api.tristero.com/v2/quotes"
10
+ margin_filler_url = "https://api.tristero.com/v2/orders"
11
+ wallet_server_url = "https://api.tristero.com/v2/wallets"
12
+
13
+ # filler_url = "http://localhost:8070"
14
+ # quoter_url = "http://localhost:8060"
15
+ # ws_url = "ws://localhost:8070"
16
+
17
+ # margin_quoter_url = "http://localhost:8060"
18
+ # margin_filler_url = "http://localhost:8070"
19
+ # wallet_server_url = "http://localhost:8090"
7
20
 
8
21
  headers = {
9
22
  "User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148"
10
23
  }
11
24
 
25
+
12
26
  config_var = ContextVar("config", default=Config())
13
27
 
28
+
14
29
  def get_config():
15
30
  return config_var.get()
16
31
 
32
+
17
33
  def set_config(new_config: Config):
18
34
  return config_var.set(new_config)
19
35