tristero 0.1.7__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tristero/client.py CHANGED
@@ -1,29 +1,61 @@
1
1
  import asyncio
2
+ from dataclasses import dataclass
3
+ from enum import Enum
2
4
  import json
3
5
  import logging
4
- from typing import Any, Optional, TypeVar
6
+ import os
7
+ import ssl
8
+ from typing import Any, Dict, List, Literal, Optional, TypeVar, Union
5
9
 
6
10
  from eth_account.signers.local import LocalAccount
7
11
  from pydantic import BaseModel
8
12
 
9
- from tristero.api import ChainID, fill_order, poll_updates
10
- from .permit2 import create_order
13
+ from tristero.api import (
14
+ fill_order,
15
+ fill_order_feather,
16
+ get_quote,
17
+ poll_updates,
18
+ poll_updates_feather,
19
+ get_margin_quote,
20
+ submit_margin_order,
21
+ submit_close_position,
22
+ get_margin_positions,
23
+ get_margin_position,
24
+ poll_margin_updates,
25
+ )
26
+ from .permit2 import (
27
+ Permit2Order,
28
+ create_permit2_order,
29
+ sign_margin_order,
30
+ sign_close_position,
31
+ )
32
+ from .data import ChainID
11
33
  from web3 import AsyncBaseProvider, AsyncWeb3
12
- import logging
13
- from web3 import AsyncWeb3
14
34
  from tenacity import (
15
35
  retry,
16
36
  stop_after_attempt,
17
37
  wait_exponential,
18
38
  retry_if_exception_type,
19
39
  )
40
+
41
+ import certifi
42
+
20
43
  logger = logging.getLogger(__name__)
21
44
 
22
45
  P = TypeVar("P", bound=AsyncBaseProvider)
23
46
 
47
+
48
+ class OrderType(str, Enum):
49
+ """Order type enum for unified API."""
50
+ SWAP = "swap"
51
+ FEATHER = "feather"
52
+ MARGIN = "margin"
53
+
54
+
24
55
  class WebSocketClosedError(Exception):
25
56
  pass
26
57
 
58
+
27
59
  class SwapException(Exception):
28
60
  """Base exception for all swap-related errors."""
29
61
  pass
@@ -41,80 +73,283 @@ class OrderFailedException(SwapException):
41
73
  self.order_id = order_id
42
74
  self.details = details
43
75
 
76
+
77
+ class MarginException(Exception):
78
+ """Exception for margin-related errors."""
79
+ pass
80
+
81
+
44
82
  class TokenSpec(BaseModel, frozen=True):
45
83
  chain_id: ChainID
46
84
  token_address: str
47
85
 
48
- async def wait_for_completion(order_id: str):
86
+
87
+ def make_async_w3(rpc_url: str) -> AsyncWeb3:
88
+ insecure_ssl = os.getenv("TRISTERO_INSECURE_SSL", "").strip().lower() in {"1", "true", "yes", "y", "on"}
89
+ if insecure_ssl:
90
+ provider = AsyncWeb3.AsyncHTTPProvider(rpc_url, request_kwargs={"ssl": False})
91
+ return AsyncWeb3(provider)
92
+
93
+ ca_file = os.getenv("TRISTERO_SSL_CA_FILE", "").strip()
94
+ if not ca_file:
95
+ ca_file = certifi.where()
96
+
97
+ if ca_file:
98
+ ssl_context = ssl.create_default_context(cafile=ca_file)
99
+ provider = AsyncWeb3.AsyncHTTPProvider(rpc_url, request_kwargs={"ssl": ssl_context})
100
+ return AsyncWeb3(provider)
101
+
102
+ provider = AsyncWeb3.AsyncHTTPProvider(rpc_url)
103
+ return AsyncWeb3(provider)
104
+
105
+
106
+ @dataclass
107
+ class QuoteResult:
108
+ """Result from get_quote operations."""
109
+ quote_data: Dict[str, Any]
110
+ order_type: str
111
+
112
+ @property
113
+ def is_margin(self) -> bool:
114
+ return self.order_type == "MARGIN"
115
+
116
+ @property
117
+ def is_feather(self) -> bool:
118
+ return self.order_type == "FEATHER"
119
+
120
+
121
+ @dataclass
122
+ class OrderResult:
123
+ """Result from order submission."""
124
+ order_id: str
125
+ order_type: str
126
+ response: Dict[str, Any]
127
+
128
+
129
+ @dataclass
130
+ class MarginPosition:
131
+ """Margin position data."""
132
+ id: str
133
+ status: str
134
+ escrow_address: str
135
+ filler_address: str
136
+ taker_token_id: int
137
+ data: Dict[str, Any]
138
+
139
+
140
+ async def wait_for_feather_completion(order_id: str):
141
+ ws = await poll_updates_feather(order_id)
142
+ try:
143
+ async for msg in ws:
144
+ if not msg:
145
+ continue
146
+ msg = json.loads(msg)
147
+ status = msg['status']
148
+ logger.info(
149
+ {
150
+ "message": f"status={status}",
151
+ "id": "order_update",
152
+ "payload": msg,
153
+ }
154
+ )
155
+ if status in ['Expired']:
156
+ await ws.close()
157
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
158
+ elif status in ['Finalized']:
159
+ await ws.close()
160
+ return msg
161
+ raise WebSocketClosedError("WebSocket closed without completion status")
162
+ except Exception as e:
163
+ if not ws.close_code:
164
+ await ws.close()
165
+ raise
166
+
167
+
168
+ async def wait_for_permit2_completion(order_id: str):
49
169
  ws = await poll_updates(order_id)
50
170
  try:
51
171
  async for msg in ws:
52
172
  msg = json.loads(msg)
173
+ if isinstance(msg, dict) and ("failed" in msg or "completed" in msg):
174
+ logger.info(
175
+ {
176
+ "message": f"failed={msg.get('failed')} completed={msg.get('completed')}",
177
+ "id": "order_update",
178
+ "payload": msg,
179
+ }
180
+ )
181
+ if msg.get("failed"):
182
+ await ws.close()
183
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
184
+ elif msg.get("completed"):
185
+ await ws.close()
186
+ return msg
187
+ continue
188
+
189
+ if isinstance(msg, dict):
190
+ fill_tx = msg.get("fill_tx") or msg.get("fillTx")
191
+ if fill_tx:
192
+ await ws.close()
193
+ return msg
194
+
195
+ if msg.get("amount_repaid") is not None or msg.get("amount_settled") is not None:
196
+ await ws.close()
197
+ return msg
198
+
199
+ status = msg.get("status") if isinstance(msg, dict) else None
53
200
  logger.info(
54
201
  {
55
- "message": f"failed={msg['failed']} completed={msg['completed']}",
202
+ "message": f"status={status}",
56
203
  "id": "order_update",
57
204
  "payload": msg,
58
205
  }
59
206
  )
60
- if msg["failed"]:
207
+ if status in ["Failed", "Expired", "Rejected"]:
61
208
  await ws.close()
62
- raise Exception(f"Swap failed: {ws.close_reason} {msg}")
63
- elif msg["completed"]:
209
+ raise SwapException(f"Swap failed: {ws.close_reason} {msg}")
210
+ elif status in ["Finalized", "Filled", "Completed"]:
64
211
  await ws.close()
65
212
  return msg
66
213
 
67
- # If we exit the loop without completed/failed, raise to retry
68
214
  raise WebSocketClosedError("WebSocket closed without completion status")
69
215
  except Exception:
70
- # Close cleanly if still open
71
216
  if not ws.close_code:
72
217
  await ws.close()
73
218
  raise
74
219
 
220
+
221
+ async def wait_for_margin_completion(order_id: str):
222
+ """Wait for margin order completion via WebSocket."""
223
+ ws = await poll_margin_updates(order_id)
224
+ try:
225
+ async for msg in ws:
226
+ if not msg:
227
+ continue
228
+ msg = json.loads(msg)
229
+ if isinstance(msg, dict) and ("failed" in msg or "completed" in msg):
230
+ logger.info(
231
+ {
232
+ "message": f"failed={msg.get('failed')} completed={msg.get('completed')}",
233
+ "id": "margin_order_update",
234
+ "payload": msg,
235
+ }
236
+ )
237
+ if msg.get("failed"):
238
+ await ws.close()
239
+ raise MarginException(f"Margin order failed: {msg}")
240
+ if msg.get("completed"):
241
+ await ws.close()
242
+ return msg
243
+ continue
244
+
245
+ status = msg.get('status', '') if isinstance(msg, dict) else ''
246
+ logger.info(
247
+ {
248
+ "message": f"margin status={status}",
249
+ "id": "margin_order_update",
250
+ "payload": msg,
251
+ }
252
+ )
253
+ if status in ['Failed', 'Expired', 'Rejected']:
254
+ await ws.close()
255
+ raise MarginException(f"Margin order failed: {msg}")
256
+ elif status in ['Filled', 'Finalized', 'Completed']:
257
+ await ws.close()
258
+ return msg
259
+ raise WebSocketClosedError("WebSocket closed without completion status")
260
+ except Exception as e:
261
+ if not ws.close_code:
262
+ await ws.close()
263
+ raise
264
+
265
+
266
+ async def wait_for_completion(order_id: str, order_type: Union[str, OrderType] = OrderType.SWAP):
267
+ """
268
+ Wait for order completion.
269
+
270
+ Args:
271
+ order_id: Order ID to wait for
272
+ order_type: Type of order (swap, feather, margin)
273
+ """
274
+ if isinstance(order_type, str):
275
+ order_type = order_type.lower()
276
+ else:
277
+ order_type = order_type.value
278
+
279
+ if order_type == "feather":
280
+ return await wait_for_feather_completion(order_id)
281
+ elif order_type == "margin":
282
+ return await wait_for_margin_completion(order_id)
283
+ else:
284
+ return await wait_for_permit2_completion(order_id)
285
+
286
+
75
287
  @retry(
76
288
  stop=stop_after_attempt(3),
77
289
  wait=wait_exponential(multiplier=1, min=4, max=10),
78
290
  retry=retry_if_exception_type(WebSocketClosedError)
79
291
  )
80
- async def wait_for_completion_with_retry(order_id: str):
81
- return await wait_for_completion(order_id)
292
+ async def wait_for_completion_with_retry(
293
+ order_id: str,
294
+ order_type: Union[str, OrderType] = OrderType.SWAP
295
+ ):
296
+ """Wait for order completion with automatic retry on WebSocket failures."""
297
+ return await wait_for_completion(order_id, order_type)
82
298
 
83
- async def start_swap(w3: AsyncWeb3[P], account: LocalAccount, from_t: TokenSpec, to_t: TokenSpec, raw_amount: int) -> str:
84
- """
85
- Execute a token swap operation.
86
299
 
87
- Args:
88
- w3: Web3 provider instance
89
- account: Account to execute swap from
90
- from_t: Source token specification
91
- to_t: Target token specification
92
- raw_amount: Amount in smallest unit (e.g., wei)
300
+ @dataclass
301
+ class FeatherSwapResult:
302
+ deposit_address: str
303
+ data: Any
93
304
 
94
- Returns:
95
- Order ID for tracking the swap
96
305
 
97
- Raises:
98
- Exception: If order creation or submission fails
99
- """
100
- data, sig = await create_order(
306
+ class FeatherException(Exception):
307
+ pass
308
+
309
+
310
+ async def start_feather_swap(
311
+ src_t: TokenSpec,
312
+ dst_t: TokenSpec,
313
+ dst_addr: str,
314
+ raw_amount: int,
315
+ client_id: str = ''
316
+ ):
317
+ resp = await fill_order_feather(str(src_t.chain_id.value), str(dst_t.chain_id.value), dst_addr, raw_amount, client_id)
318
+ if resp['detail']:
319
+ raise FeatherException(resp)
320
+ else:
321
+ return FeatherSwapResult(
322
+ deposit_address=resp['deposit_address'],
323
+ data=resp
324
+ )
325
+
326
+
327
+ async def start_permit2_swap(
328
+ w3: AsyncWeb3[P],
329
+ account: LocalAccount,
330
+ src_t: TokenSpec,
331
+ dst_t: TokenSpec,
332
+ raw_amount: int,
333
+ ):
334
+ order = await create_permit2_order(
101
335
  w3,
102
336
  account,
103
- from_t.chain_id,
104
- from_t.token_address,
105
- to_t.chain_id,
106
- to_t.token_address,
337
+ str(src_t.chain_id.value),
338
+ src_t.token_address,
339
+ str(dst_t.chain_id.value),
340
+ dst_t.token_address,
107
341
  raw_amount,
108
342
  )
109
343
  response = await fill_order(
110
- str(sig.signature.to_0x_hex()),
111
- data.domain.model_dump(by_alias=True, mode="json"),
112
- data.message.model_dump(by_alias=True, mode="json"),
344
+ str(order.sig.signature.to_0x_hex()),
345
+ order.msg.domain.model_dump(by_alias=True, mode="json"),
346
+ order.msg.message.model_dump(by_alias=True, mode="json"),
113
347
  )
348
+ order_id = response['id']
349
+ return order_id
114
350
 
115
- return response['id']
116
351
 
117
- async def execute_swap(
352
+ async def execute_permit2_swap(
118
353
  w3: AsyncWeb3[P],
119
354
  account: LocalAccount,
120
355
  src_t: TokenSpec,
@@ -124,24 +359,310 @@ async def execute_swap(
124
359
  timeout: Optional[float] = None
125
360
  ) -> dict[str, Any]:
126
361
  """Execute and wait for swap completion."""
127
- order_id = await start_swap(
128
- w3,
129
- account,
130
- src_t,
131
- dst_t,
132
- raw_amount
133
- )
362
+ order_id = await start_permit2_swap(w3, account, src_t, dst_t, raw_amount)
134
363
  logger.info(f"Swap order placed: {order_id}")
135
364
 
136
365
  waiter = wait_for_completion_with_retry if retry else wait_for_completion
137
366
 
138
367
  try:
139
368
  if timeout is None:
140
- return await waiter(order_id)
369
+ return await waiter(order_id, OrderType.SWAP)
141
370
 
142
371
  return await asyncio.wait_for(
143
- waiter(order_id),
372
+ waiter(order_id, OrderType.SWAP),
144
373
  timeout=timeout
145
374
  )
146
375
  except asyncio.TimeoutError as exc:
147
376
  raise StuckException(f"Swap {order_id} timed out after {timeout}s") from exc
377
+
378
+
379
+ async def execute_swap(
380
+ w3: AsyncWeb3[P],
381
+ account: LocalAccount,
382
+ src_t: TokenSpec,
383
+ dst_t: TokenSpec,
384
+ raw_amount: int,
385
+ retry: bool = True,
386
+ timeout: Optional[float] = None,
387
+ ) -> dict[str, Any]:
388
+ return await execute_permit2_swap(
389
+ w3=w3,
390
+ account=account,
391
+ src_t=src_t,
392
+ dst_t=dst_t,
393
+ raw_amount=raw_amount,
394
+ retry=retry,
395
+ timeout=timeout,
396
+ )
397
+
398
+
399
+ async def request_margin_quote(
400
+ chain_id: str,
401
+ wallet_address: str,
402
+ quote_currency: str,
403
+ base_currency: str,
404
+ leverage_ratio: int,
405
+ collateral_amount: str,
406
+ ) -> QuoteResult:
407
+ """
408
+ Request a quote for opening a margin position.
409
+
410
+ Args:
411
+ chain_id: Chain ID (e.g., "42161" for Arbitrum)
412
+ wallet_address: User's wallet address
413
+ quote_currency: Quote currency token address (e.g., USDC)
414
+ base_currency: Base currency token address (e.g., WETH)
415
+ leverage_ratio: Leverage ratio (e.g., 2 for 2x leverage)
416
+ collateral_amount: Collateral amount in raw units
417
+
418
+ Returns:
419
+ QuoteResult containing quote data
420
+ """
421
+ quote = await get_margin_quote(
422
+ chain_id=chain_id,
423
+ wallet_address=wallet_address,
424
+ quote_currency=quote_currency,
425
+ base_currency=base_currency,
426
+ leverage_ratio=leverage_ratio,
427
+ collateral_amount=collateral_amount,
428
+ )
429
+ return QuoteResult(
430
+ quote_data=quote,
431
+ order_type="MARGIN",
432
+ )
433
+
434
+
435
+ def sign_order(
436
+ quote: QuoteResult,
437
+ private_key: str,
438
+ ) -> Dict[str, Any]:
439
+ """
440
+ Sign a quote to create a submittable order.
441
+
442
+ Args:
443
+ quote: Quote result from request_margin_quote
444
+ private_key: Private key for signing
445
+
446
+ Returns:
447
+ Signed order payload
448
+ """
449
+ if quote.is_margin:
450
+ return sign_margin_order(quote.quote_data, private_key)
451
+ else:
452
+ raise ValueError(f"Unsupported order type for signing: {quote.order_type}")
453
+
454
+
455
+ async def submit_order(signed_order: Dict[str, Any]) -> OrderResult:
456
+ """
457
+ Submit a signed order.
458
+
459
+ Args:
460
+ signed_order: Signed order from sign_order
461
+
462
+ Returns:
463
+ OrderResult with order_id
464
+ """
465
+ response = await submit_margin_order(signed_order)
466
+ return OrderResult(
467
+ order_id=response.get("id", response.get("order_id", "")),
468
+ order_type="MARGIN",
469
+ response=response,
470
+ )
471
+
472
+
473
+ async def open_margin_position(
474
+ private_key: str,
475
+ chain_id: str,
476
+ wallet_address: str,
477
+ quote_currency: str,
478
+ base_currency: str,
479
+ leverage_ratio: int,
480
+ collateral_amount: str,
481
+ wait_for_result: bool = True,
482
+ timeout: Optional[float] = None,
483
+ ) -> Dict[str, Any]:
484
+ """
485
+ Open a margin position in one call (quote + sign + submit + optionally wait).
486
+
487
+ Args:
488
+ private_key: Private key for signing
489
+ chain_id: Chain ID
490
+ wallet_address: User's wallet address
491
+ quote_currency: Quote currency token address
492
+ base_currency: Base currency token address
493
+ leverage_ratio: Leverage ratio
494
+ collateral_amount: Collateral amount in raw units
495
+ wait_for_result: Whether to wait for order completion
496
+ timeout: Optional timeout in seconds
497
+
498
+ Returns:
499
+ Order result or completion result
500
+ """
501
+ quote = await request_margin_quote(
502
+ chain_id=chain_id,
503
+ wallet_address=wallet_address,
504
+ quote_currency=quote_currency,
505
+ base_currency=base_currency,
506
+ leverage_ratio=leverage_ratio,
507
+ collateral_amount=collateral_amount,
508
+ )
509
+
510
+ signed = sign_order(quote, private_key)
511
+ result = await submit_order(signed)
512
+
513
+ logger.info(f"Margin order submitted: {result.order_id}")
514
+
515
+ if not wait_for_result:
516
+ return {
517
+ "order_id": result.order_id,
518
+ "quote": quote.quote_data,
519
+ "response": result.response,
520
+ }
521
+
522
+ try:
523
+ if timeout:
524
+ completion = await asyncio.wait_for(
525
+ wait_for_completion_with_retry(result.order_id, OrderType.MARGIN),
526
+ timeout=timeout
527
+ )
528
+ else:
529
+ completion = await wait_for_completion_with_retry(result.order_id, OrderType.MARGIN)
530
+
531
+ return {
532
+ "order_id": result.order_id,
533
+ "quote": quote.quote_data,
534
+ "response": result.response,
535
+ "completion": completion,
536
+ }
537
+ except asyncio.TimeoutError as exc:
538
+ raise StuckException(f"Margin order {result.order_id} timed out after {timeout}s") from exc
539
+
540
+
541
+ async def close_margin_position(
542
+ private_key: str,
543
+ chain_id: str,
544
+ position_id: int,
545
+ escrow_contract: str,
546
+ authorized: str,
547
+ cash_settle: bool = False,
548
+ fraction_bps: int = 10_000,
549
+ deadline_seconds: int = 3600,
550
+ wait_for_result: bool = True,
551
+ timeout: Optional[float] = None,
552
+ ) -> Dict[str, Any]:
553
+ """
554
+ Close a margin position.
555
+
556
+ Args:
557
+ private_key: Private key for signing
558
+ chain_id: Chain ID
559
+ position_id: Position ID (NFT token ID / taker_token_id)
560
+ escrow_contract: Escrow contract address
561
+ authorized: Authorized filler address
562
+ cash_settle: Whether to cash settle (True) or swap settle (False)
563
+ fraction_bps: Fraction to close in basis points (10000 = 100%)
564
+ deadline_seconds: Deadline in seconds from now
565
+ wait_for_result: Whether to wait for completion
566
+ timeout: Optional timeout in seconds
567
+
568
+ Returns:
569
+ Close result
570
+ """
571
+ signed = sign_close_position(
572
+ chain_id=int(chain_id),
573
+ position_id=position_id,
574
+ private_key=private_key,
575
+ escrow_contract=escrow_contract,
576
+ authorized=authorized,
577
+ cash_settle=cash_settle,
578
+ fraction_bps=fraction_bps,
579
+ deadline_seconds=deadline_seconds,
580
+ )
581
+
582
+ response = await submit_close_position(signed)
583
+ close_order = response.get("close_order")
584
+ close_order_id = None
585
+ if isinstance(close_order, dict):
586
+ close_order_id = close_order.get("id")
587
+
588
+ order_id = (
589
+ close_order_id
590
+ or response.get("id")
591
+ or response.get("order_id")
592
+ or response.get("orderId")
593
+ or response.get("orderID")
594
+ or ""
595
+ )
596
+ if not isinstance(order_id, str) or not order_id:
597
+ raise MarginException(f"Close position response missing order id: {response}")
598
+
599
+ logger.info(f"Close position order submitted: {order_id}")
600
+
601
+ if not wait_for_result:
602
+ return {
603
+ "order_id": order_id,
604
+ "response": response,
605
+ }
606
+
607
+ try:
608
+ if timeout:
609
+ completion = await asyncio.wait_for(
610
+ wait_for_completion_with_retry(order_id, OrderType.SWAP),
611
+ timeout=timeout
612
+ )
613
+ else:
614
+ completion = await wait_for_completion_with_retry(order_id, OrderType.SWAP)
615
+
616
+ return {
617
+ "order_id": order_id,
618
+ "response": response,
619
+ "completion": completion,
620
+ }
621
+ except asyncio.TimeoutError as exc:
622
+ raise StuckException(f"Close position {order_id} timed out after {timeout}s") from exc
623
+
624
+
625
+ async def list_margin_positions(wallet_address: str) -> List[MarginPosition]:
626
+ """
627
+ Get all margin positions for a wallet.
628
+
629
+ Args:
630
+ wallet_address: User's wallet address
631
+
632
+ Returns:
633
+ List of MarginPosition objects
634
+ """
635
+ positions = await get_margin_positions(wallet_address)
636
+ return [
637
+ MarginPosition(
638
+ id=str(p.get("id", "")),
639
+ status=p.get("status", ""),
640
+ escrow_address=p.get("escrow_address", ""),
641
+ filler_address=p.get("filler_address", ""),
642
+ taker_token_id=p.get("taker_token_id", 0),
643
+ data=p,
644
+ )
645
+ for p in positions
646
+ ]
647
+
648
+
649
+ async def get_position(position_id: str) -> MarginPosition:
650
+ """
651
+ Get a specific margin position.
652
+
653
+ Args:
654
+ position_id: Position ID
655
+
656
+ Returns:
657
+ MarginPosition object
658
+ """
659
+ p = await get_margin_position(position_id)
660
+ return MarginPosition(
661
+ id=str(p.get("id", "")),
662
+ status=p.get("status", ""),
663
+ escrow_address=p.get("escrow_address", ""),
664
+ filler_address=p.get("filler_address", ""),
665
+ taker_token_id=p.get("taker_token_id", 0),
666
+ data=p,
667
+ )
668
+