prediction-market-agent-tooling 0.65.9__py3-none-any.whl → 0.65.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,268 @@
1
+ import os
2
+ from datetime import datetime, timedelta
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import pandas as pd
7
+ import typer
8
+ from langfuse import Langfuse
9
+ from langfuse.client import TraceWithDetails
10
+ from pydantic import BaseModel
11
+
12
+ from prediction_market_agent_tooling.config import APIKeys
13
+ from prediction_market_agent_tooling.gtypes import DatetimeUTC, OutcomeStr, OutcomeToken
14
+ from prediction_market_agent_tooling.loggers import logger
15
+ from prediction_market_agent_tooling.markets.agent_market import AgentMarket
16
+ from prediction_market_agent_tooling.markets.data_models import Resolution
17
+ from prediction_market_agent_tooling.markets.omen.omen import OmenAgentMarket
18
+ from prediction_market_agent_tooling.markets.seer.seer import SeerAgentMarket
19
+ from prediction_market_agent_tooling.markets.seer.seer_subgraph_handler import (
20
+ SeerSubgraphHandler,
21
+ )
22
+ from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
23
+ from prediction_market_agent_tooling.tools.httpx_cached_client import HttpxCachedClient
24
+ from prediction_market_agent_tooling.tools.langfuse_client_utils import (
25
+ get_traces_for_agent,
26
+ )
27
+ from prediction_market_agent_tooling.tools.parallelism import par_map
28
+
29
# Names of Langfuse observations whose output holds the agent's prediction
# (process_trace reads "reasoning", "decision", "p_yes" and "info_utility" from it).
PREDICTION_STATES = [
    "predict_market",
    "_make_prediction_categorical",
    "make_prediction",
]
# Names of Langfuse observations whose output holds the agent's research report.
REPORT_STATES = ["prepare_report"]


def _get_seer_market(market_id: str) -> "SeerAgentMarket":
    """Load a Seer market by id without requiring prices.

    A single SeerSubgraphHandler is created and used both to fetch the market
    model and as the subgraph handler of the resulting SeerAgentMarket.
    """
    seer_subgraph = SeerSubgraphHandler()
    return SeerAgentMarket.from_data_model_with_subgraph(
        model=seer_subgraph.get_market_by_id(HexBytes(market_id)),
        seer_subgraph=seer_subgraph,
        must_have_prices=False,
    )


# Maps a lowercase market-type key (as found in a trace's input args) to a
# callable that loads the market with that id, so its resolution can be read.
MARKET_RESOLUTION_PROVIDERS = {
    "omen": lambda market_id: OmenAgentMarket.get_binary_market(market_id),
    "seer": _get_seer_market,
}
44
+
45
+
46
class TraceResult(BaseModel):
    """One exported row: an agent's prediction on a market plus that market's outcome.

    Rows are serialized with ``model_dump()`` into a pandas DataFrame and written
    to CSV, so the field declaration order below is the CSV column order.
    """

    agent_name: str  # taken from trace.metadata["agent_class"]
    trace_id: str  # Langfuse trace id
    market_id: str
    market_type: str  # a MARKET_RESOLUTION_PROVIDERS key, e.g. "omen" or "seer"
    market_question: str
    market_outcomes: list[str]
    market_outcome_token_pool: dict[OutcomeStr, OutcomeToken] | None
    market_created_time: DatetimeUTC | None
    market_close_time: DatetimeUTC | None
    analysis: str  # output of the first report-stage observation (REPORT_STATES)
    prediction_reasoning: str
    prediction_decision: str  # "YES" if the agent's decision was "y", otherwise "NO"
    prediction_p_yes: float
    prediction_info_utility: float
    market_resolution: str | None  # resolution.outcome; None when no resolution is available
    resolution_is_valid: bool | None  # `not resolution.invalid`; None when no resolution is available
63
+
64
+
65
def get_langfuse_client() -> Langfuse:
    """Create a Langfuse client from the configured APIKeys.

    The client shares the HTTP connection provided by ``HttpxCachedClient``.
    """
    keys = APIKeys()
    secret = keys.langfuse_secret_key.get_secret_value()
    public = keys.langfuse_public_key
    host = keys.langfuse_host
    http_client = HttpxCachedClient().get_client()
    return Langfuse(
        secret_key=secret,
        public_key=public,
        host=host,
        httpx_client=http_client,
    )
73
+
74
+
75
def download_data(
    agent_name: str,
    date_from: DatetimeUTC,
    date_to: DatetimeUTC,
    only_resolved: bool,
    output_folder: str,
) -> None:
    """Export an agent's answered ``process_market`` traces in a date range to CSV.

    The CSV is written to ``output_folder`` (created if missing) as
    ``{agent_name}_{date_from}_{date_to}.csv``. If that file already exists, a
    ``_v{N}`` suffix is appended using the smallest free N, so earlier exports
    are never overwritten. Traces are processed in parallel by ``process_trace``.
    Traces that fail to process are skipped.

    Args:
        agent_name: Agent whose traces are fetched from Langfuse.
        date_from: Start of the trace time window.
        date_to: End of the trace time window.
        only_resolved: Forwarded to ``process_trace``; drop traces whose market
            has no resolution.
        output_folder: Directory to write the CSV into.

    Raises:
        ValueError: If no matching traces are found.
    """
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    default_file_name = f"{agent_name}_{date_from.date()}_{date_to.date()}"
    output_file = os.path.join(output_folder, f"{default_file_name}.csv")
    index = 0
    # Never overwrite a previous export: bump the version suffix until the name is free.
    while os.path.exists(output_file):
        index += 1
        output_file = os.path.join(output_folder, f"{default_file_name}_v{index}.csv")

    traces = get_traces_for_agent(
        agent_name=agent_name,
        trace_name="process_market",
        from_timestamp=date_from,
        to_timestamp=date_to,
        has_output=True,
        client=get_langfuse_client(),
        tags=["answered"],
    )

    if not traces:
        raise ValueError("No traces found for the specified criteria")

    results = par_map(
        items=list(traces),
        func=lambda trace: process_trace(trace, only_resolved),
        max_workers=5,
    )

    # process_trace returns None for traces it could not handle (already logged).
    successful_results = [r for r in results if r is not None]
    if successful_results:
        results_data = [result.model_dump() for result in successful_results]
        pd.DataFrame(results_data).to_csv(output_file, index=False)
        logger.info(f"Saved {len(successful_results)} records to {output_file}")
    else:
        logger.warning("No results to save")
130
+
131
+
132
def process_trace(
    trace: TraceWithDetails,
    only_resolved: bool,
) -> TraceResult | None:
    """Build one ``TraceResult`` row from a single ``process_market`` trace.

    Fetches the trace's observations from Langfuse. It takes the output of the
    first report-stage observation (``REPORT_STATES``) and the first
    prediction-stage observation (``PREDICTION_STATES``), then looks up the
    market's resolution.

    Args:
        trace: Langfuse trace whose ``input["args"]`` holds the market type and
            the market data.
        only_resolved: If True, traces whose market resolution cannot be
            obtained are dropped. If False, they are kept with the resolution
            columns set to None.

    Returns:
        The assembled row, or None if anything went wrong. The error is logged
        and never raised to the caller.
    """
    langfuse_client = get_langfuse_client()
    try:
        observations = langfuse_client.fetch_observations(trace_id=trace.id)

        market_state, market_type = get_agent_market_state(trace.input)

        prepare_report_obs = [
            obs for obs in observations.data if obs.name in REPORT_STATES
        ]
        predict_market_obs = [
            obs for obs in observations.data if obs.name in PREDICTION_STATES
        ]

        if not prepare_report_obs or not predict_market_obs:
            raise ValueError(f"Missing required observations for trace {trace.id}")

        analysis = prepare_report_obs[0].output
        prediction = predict_market_obs[0].output

        # get_market_resolution raises ValueError (it never returns None) when the
        # market is unresolved or cannot be fetched. Handle that here so
        # only_resolved=False actually keeps unresolved markets, instead of the
        # whole trace being dropped by the outer except.
        try:
            resolution: Resolution | None = get_market_resolution(
                market_state.id, market_type
            )
        except ValueError as resolution_error:
            if only_resolved:
                raise
            logger.info(
                f"Keeping trace {trace.id} without resolution: {resolution_error}"
            )
            resolution = None

        return TraceResult(
            agent_name=trace.metadata["agent_class"],
            trace_id=trace.id,
            market_id=market_state.id,
            market_type=market_type,
            market_question=market_state.question,
            market_outcomes=list(market_state.outcomes),
            market_outcome_token_pool=market_state.outcome_token_pool,
            market_created_time=market_state.created_time,
            market_close_time=market_state.close_time,
            analysis=analysis,
            prediction_reasoning=prediction["reasoning"],
            prediction_decision="YES" if prediction["decision"] == "y" else "NO",
            prediction_p_yes=prediction["p_yes"],
            prediction_info_utility=prediction["info_utility"],
            market_resolution=resolution.outcome if resolution else None,
            resolution_is_valid=not resolution.invalid if resolution else None,
        )

    except Exception as e:
        # Best-effort per trace: one bad trace must not abort the whole export.
        logger.exception(f"Error processing trace {trace.id}: {e}")
        return None
184
+
185
+
186
def get_agent_market_state(input_data: dict[str, Any]) -> tuple[AgentMarket, str]:
    """Recover the market and its type from a ``process_market`` trace input.

    ``input_data["args"][0]`` must be a key of MARKET_RESOLUTION_PROVIDERS
    (e.g. "omen", "seer"). ``input_data["args"][1]`` must be a mapping of
    AgentMarket fields; it is turned into an AgentMarket with
    ``model_construct``, i.e. without pydantic validation.

    Returns:
        ``(market, market_type)``.

    Raises:
        ValueError: If the args are missing, too short, or the market type is unknown.
    """
    if not (input_data and "args" in input_data):
        raise ValueError("Invalid input data: missing args")

    call_args = input_data["args"]
    if len(call_args) < 2:
        raise ValueError("Invalid args: expected at least 2 elements")

    market_type, market_fields = call_args[0], call_args[1]
    if market_type not in MARKET_RESOLUTION_PROVIDERS:
        raise ValueError(f"Unknown market type: {market_type}")

    return AgentMarket.model_construct(**market_fields), market_type
203
+
204
+
205
def get_market_resolution(market_id: str, market_type: str) -> Resolution:
    """Fetch a market by id and return its resolution.

    Args:
        market_id: Market id understood by the provider for ``market_type``.
        market_type: Case-insensitive key of MARKET_RESOLUTION_PROVIDERS.

    Returns:
        The market's resolution. This function never returns None.

    Raises:
        ValueError: If the market type is unknown, if fetching the market fails
            (the original error is chained), or if the market has no resolution.
    """
    market_type_lower = market_type.lower()

    if market_type_lower not in MARKET_RESOLUTION_PROVIDERS:
        raise ValueError(f"Unknown market type: {market_type}")

    # Only wrap failures of the provider call itself. Previously the
    # "No resolution found" error below was also caught and re-wrapped as a
    # misleading "Failed to fetch ..." error.
    try:
        market: AgentMarket | None = MARKET_RESOLUTION_PROVIDERS[market_type_lower](
            market_id
        )
    except Exception as e:
        raise ValueError(
            f"Failed to fetch {market_type} market {market_id} resolution: {e}"
        ) from e

    if not market or not market.resolution:
        raise ValueError(f"No resolution found for market: {market_id}")

    return market.resolution
225
+
226
+
227
def parse_date(date_str: str, param_name: str) -> DatetimeUTC:
    """Parse a CLI date argument into DatetimeUTC.

    If parsing raises ValueError, a short usage message is printed and the
    CLI exits with status 1.
    """
    try:
        parsed = DatetimeUTC.to_datetime_utc(date_str)
    except ValueError as err:
        for message in (
            f"Error: Invalid date format for {param_name}: {date_str}",
            "Expected format: YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS",
        ):
            typer.echo(message)
        raise typer.Exit(1) from err
    return parsed
234
+
235
+
236
def main(
    agent_name: str = "DeployablePredictionProphetGPT4oAgent",
    only_resolved: bool = True,
    date_from: str = typer.Option(
        None, help="Start date in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)"
    ),
    date_to: str = typer.Option(
        None, help="End date in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)"
    ),
    output_folder: str = "./agent_trades_output/",
) -> None:
    """CLI entry point: export an agent's answered-market traces to CSV.

    If ``date_from`` / ``date_to`` are omitted, the window defaults to the last
    24 hours ending now, in UTC.
    """
    from datetime import timezone

    # Use one timezone-aware UTC "now". A naive datetime.now() is local time,
    # which DatetimeUTC.from_datetime would mislabel on non-UTC machines. Calling
    # it once also keeps both default bounds exactly 24h apart.
    now = datetime.now(tz=timezone.utc)

    date_from_dt = (
        parse_date(date_from, "date_from")
        if date_from
        else DatetimeUTC.from_datetime(now - timedelta(days=1))
    )
    date_to_dt = (
        parse_date(date_to, "date_to") if date_to else DatetimeUTC.from_datetime(now)
    )

    download_data(
        agent_name=agent_name,
        date_from=date_from_dt,
        date_to=date_to_dt,
        only_resolved=only_resolved,
        output_folder=output_folder,
    )


if __name__ == "__main__":
    typer.run(main)
@@ -10,6 +10,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
10
10
  from pythonjsonlogger import jsonlogger
11
11
  from tenacity import RetryError
12
12
 
13
# Reference to the original builtin ``print``, captured at import time.
# print_using_logger_info falls back to it while loguru's handler lock is held
# (e.g. while ``logger.exception`` formats a traceback), to avoid deadlock or
# recursion. It must be bound before the module-level ``patch_logger()`` call.
# NOTE(review): presumably patch_logger() replaces builtins.print; confirm.
UNPATCHED_PRINT_FN = builtins.print
14
+
13
15
 
14
16
  class LogFormat(str, Enum):
15
17
  DEFAULT = "default"
@@ -153,7 +155,13 @@ def print_using_logger_info(
153
155
  end: str = "\n",
154
156
  **kwargs: t.Any,
155
157
  ) -> None:
156
- logger.info(sep.join(map(str, values)) + end)
158
+ # If `logger.exception` is used, loguru+traceback somehow uses `print` statement to format the error stack,
159
+ # if that happens, without this if condition, it errors out because of deadlock and/or recursion errors.
160
+ # This is hacky, but that's exactly how loguru is checking for it internally..
161
+ if getattr(logger._core.handlers[1]._lock_acquired, "acquired", False): # type: ignore # They use stubs and didn't type this.
162
+ UNPATCHED_PRINT_FN(*values, sep=sep, end=end, **kwargs)
163
+ else:
164
+ logger.info(sep.join(map(str, values)) + end)
157
165
 
158
166
 
159
167
  patch_logger()
@@ -284,6 +284,9 @@ def omen_resolve_market_tx(
284
284
  logger.warning(
285
285
  f"Market {market.url=} not resolved, because `condition not prepared or found`, skipping."
286
286
  )
287
+ elif "payout denominator already set" in str(e):
288
+ # We can just skip, it's been resolved already.
289
+ logger.info(f"Market {market.url=} is already resolved.")
287
290
  else:
288
291
  raise
289
292
 
@@ -1,4 +1,5 @@
1
1
  import typing as t
2
+ from datetime import timedelta
2
3
  from urllib.parse import urljoin
3
4
 
4
5
  from pydantic import BaseModel, ConfigDict, Field
@@ -13,12 +14,14 @@ from prediction_market_agent_tooling.gtypes import (
13
14
  OutcomeStr,
14
15
  OutcomeWei,
15
16
  Web3Wei,
17
+ Wei,
16
18
  )
17
19
  from prediction_market_agent_tooling.markets.seer.subgraph_data_models import (
18
20
  SeerParentMarket,
19
21
  )
20
22
  from prediction_market_agent_tooling.tools.contract import ContractERC20OnGnosisChain
21
23
  from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
24
+ from prediction_market_agent_tooling.tools.utils import utcnow
22
25
 
23
26
 
24
27
  class CreateCategoricalMarketsParams(BaseModel):
@@ -133,3 +136,19 @@ class RedeemParams(BaseModel):
133
136
  market: ChecksumAddress
134
137
  outcome_indices: list[int] = Field(alias="outcomeIndexes")
135
138
  amounts: list[OutcomeWei]
139
+
140
+
141
class ExactInputSingleParams(BaseModel):
    """Arguments for the Swapr router's ``exactInputSingle`` call.

    NOTE: field declaration order matters. SwaprRouterContract.exact_input_single
    passes ``tuple(dict(params).values())`` positionally as the Solidity struct,
    so the order must match the router's struct definition.
    """

    # from https://gnosisscan.io/address/0xffb643e73f280b97809a8b41f7232ab401a04ee1#code
    model_config = ConfigDict(populate_by_name=True)  # accept snake_case names as well as camelCase aliases
    token_in: ChecksumAddress = Field(alias="tokenIn")
    token_out: ChecksumAddress = Field(alias="tokenOut")
    # presumably the address credited with token_out (SwapPoolHandler passes the bettor's address) — confirm
    recipient: ChecksumAddress
    # Unix timestamp in seconds; defaults to 10 minutes after model creation.
    deadline: int = Field(
        default_factory=lambda: int((utcnow() + timedelta(minutes=10)).timestamp())
    )
    amount_in: Wei = Field(alias="amountIn")
    # Minimum acceptable output amount (slippage guard computed by the caller).
    amount_out_minimum: Wei = Field(alias="amountOutMinimum")
    limit_sqrt_price: Wei = Field(
        alias="limitSqrtPrice", default_factory=lambda: Wei(0)
    )  # 0 for convenience, we also don't expect major price shifts
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import typing as t
2
3
  from datetime import timedelta
3
4
 
@@ -15,6 +16,7 @@ from prediction_market_agent_tooling.gtypes import (
15
16
  OutcomeStr,
16
17
  OutcomeToken,
17
18
  OutcomeWei,
19
+ Wei,
18
20
  xDai,
19
21
  )
20
22
  from prediction_market_agent_tooling.loggers import logger
@@ -46,16 +48,21 @@ from prediction_market_agent_tooling.markets.seer.seer_subgraph_handler import (
46
48
  from prediction_market_agent_tooling.markets.seer.subgraph_data_models import (
47
49
  NewMarketEvent,
48
50
  )
51
+ from prediction_market_agent_tooling.markets.seer.swap_pool_handler import (
52
+ SwapPoolHandler,
53
+ )
49
54
  from prediction_market_agent_tooling.tools.contract import (
50
55
  ContractERC20OnGnosisChain,
51
56
  init_collateral_token_contract,
52
57
  to_gnosis_chain_contract,
53
58
  )
54
59
  from prediction_market_agent_tooling.tools.cow.cow_order import (
60
+ cancel_order,
55
61
  get_buy_token_amount_else_raise,
56
62
  get_orders_by_owner,
57
63
  get_trades_by_owner,
58
64
  swap_tokens_waiting,
65
+ wait_for_order_completion,
59
66
  )
60
67
  from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
61
68
  from prediction_market_agent_tooling.tools.tokens.auto_deposit import (
@@ -441,6 +448,60 @@ class SeerAgentMarket(AgentMarket):
441
448
  outcome_idx = self.outcomes.index(outcome)
442
449
  return self.wrapped_tokens[outcome_idx]
443
450
 
451
def _swap_tokens_with_fallback(
    self,
    sell_token: ChecksumAddress,
    buy_token: ChecksumAddress,
    amount_wei: Wei,
    api_keys: APIKeys,
    web3: Web3 | None,
) -> str:
    """Swap tokens via a CoW order, falling back to Swapr pools if it times out.

    A CoW order is placed without blocking and then awaited. If waiting times
    out, the order is cancelled and the same amount is swapped directly on the
    market's Swapr pool through SwapPoolHandler.

    Args:
        sell_token: Address of the token to sell.
        buy_token: Address of the token to buy.
        amount_wei: Amount of ``sell_token`` to swap, in wei.
        api_keys: API keys used to sign the order / transactions.
        web3: Optional Web3 instance.

    Returns:
        The CoW order UID if the order completed. Otherwise, the
        ``.hex()``-encoded transaction hash of the direct pool swap.
    """
    _, order = swap_tokens_waiting(
        amount_wei=amount_wei,
        sell_token=sell_token,
        buy_token=buy_token,
        api_keys=api_keys,
        web3=web3,
        wait_order_complete=False,
    )

    try:
        order_metadata = asyncio.run(wait_for_order_completion(order=order))
    except (TimeoutError, asyncio.TimeoutError):
        # Before Python 3.11, asyncio.TimeoutError is not the builtin
        # TimeoutError, so both are caught to trigger the fallback reliably.
        # Since a timeout occurred, cancel the order before swapping again.
        # NOTE(review): the order could still fill between the timeout and the
        # cancellation; that case is not detected here.
        asyncio.run(cancel_order(order_uids=[order.uid.root], api_keys=api_keys))
        logger.info("TimeoutError. Trying to swap directly on Swapr pools.")

        tx_receipt = SwapPoolHandler(
            api_keys=api_keys,
            market_id=self.id,
            collateral_token_address=self.collateral_token_contract_address_checksummed,
        ).buy_or_sell_outcome_token(
            token_in=sell_token,
            token_out=buy_token,
            amount_wei=amount_wei,
            web3=web3,
        )
        return tx_receipt["transactionHash"].hex()

    logger.debug(
        f"Swapped {sell_token} for {buy_token}. Order details {order_metadata}"
    )
    return order_metadata.uid.root
504
+
444
505
  def place_bet(
445
506
  self,
446
507
  outcome: OutcomeStr,
@@ -449,6 +510,7 @@ class SeerAgentMarket(AgentMarket):
449
510
  web3: Web3 | None = None,
450
511
  api_keys: APIKeys | None = None,
451
512
  ) -> str:
513
+ outcome_token = self.get_wrapped_token_for_outcome(outcome)
452
514
  api_keys = api_keys if api_keys is not None else APIKeys()
453
515
  if not self.can_be_traded():
454
516
  raise ValueError(
@@ -464,27 +526,21 @@ class SeerAgentMarket(AgentMarket):
464
526
  collateral_contract, amount_wei, api_keys, web3
465
527
  )
466
528
 
467
- collateral_balance = collateral_contract.balanceOf(api_keys.bet_from_address)
529
+ collateral_balance = collateral_contract.balanceOf(
530
+ api_keys.bet_from_address, web3=web3
531
+ )
468
532
  if collateral_balance < amount_wei:
469
533
  raise ValueError(
470
534
  f"Balance {collateral_balance} not enough for bet size {amount}"
471
535
  )
472
536
 
473
- outcome_token = self.get_wrapped_token_for_outcome(outcome)
474
-
475
- # Sell sDAI using token address
476
- order_metadata = swap_tokens_waiting(
477
- amount_wei=amount_wei,
537
+ return self._swap_tokens_with_fallback(
478
538
  sell_token=collateral_contract.address,
479
539
  buy_token=outcome_token,
540
+ amount_wei=amount_wei,
480
541
  api_keys=api_keys,
481
542
  web3=web3,
482
543
  )
483
- logger.debug(
484
- f"Purchased {outcome_token} in exchange for {collateral_contract.address}. Order details {order_metadata}"
485
- )
486
-
487
- return order_metadata.uid.root
488
544
 
489
545
  def sell_tokens(
490
546
  self,
@@ -506,21 +562,27 @@ class SeerAgentMarket(AgentMarket):
506
562
  else self.get_in_token(amount).as_wei
507
563
  )
508
564
 
509
- order_metadata = swap_tokens_waiting(
510
- amount_wei=token_amount,
565
+ return self._swap_tokens_with_fallback(
511
566
  sell_token=outcome_token,
512
567
  buy_token=Web3.to_checksum_address(
513
568
  self.collateral_token_contract_address_checksummed
514
569
  ),
570
+ amount_wei=token_amount,
515
571
  api_keys=api_keys,
516
572
  web3=web3,
517
573
  )
518
574
 
519
- logger.debug(
520
- f"Sold {outcome_token} in exchange for {self.collateral_token_contract_address_checksummed}. Order details {order_metadata}"
575
def get_token_balance(
    self, user_id: str, outcome: OutcomeStr, web3: Web3 | None = None
) -> OutcomeToken:
    """Return ``user_id``'s balance of the wrapped ERC20 token backing ``outcome``."""
    outcome_token_address = self.get_wrapped_token_for_outcome(outcome)
    token_contract = ContractERC20OnGnosisChain(address=outcome_token_address)
    holder = Web3.to_checksum_address(user_id)
    balance_in_tokens = token_contract.balance_of_in_tokens(
        for_address=holder, web3=web3
    )
    return OutcomeToken.from_token(balance_in_tokens)
522
-
523
- return order_metadata.uid.root
524
586
 
525
587
 
526
588
  def seer_create_market_tx(
@@ -11,11 +11,15 @@ from prediction_market_agent_tooling.gtypes import (
11
11
  TxReceipt,
12
12
  xDai,
13
13
  )
14
- from prediction_market_agent_tooling.markets.seer.data_models import RedeemParams
14
+ from prediction_market_agent_tooling.markets.seer.data_models import (
15
+ ExactInputSingleParams,
16
+ RedeemParams,
17
+ )
15
18
  from prediction_market_agent_tooling.markets.seer.subgraph_data_models import (
16
19
  CreateCategoricalMarketsParams,
17
20
  )
18
21
  from prediction_market_agent_tooling.tools.contract import (
22
+ ContractERC20OnGnosisChain,
19
23
  ContractOnGnosisChain,
20
24
  abi_field_validator,
21
25
  )
@@ -110,3 +114,38 @@ class GnosisRouter(ContractOnGnosisChain):
110
114
  web3=web3,
111
115
  )
112
116
  return receipt_tx
117
+
118
+
119
class SwaprRouterContract(ContractOnGnosisChain):
    """Swapr router on Gnosis Chain, used to swap directly against Swapr pools.

    Contract page: https://gnosisscan.io/address/0xffb643e73f280b97809a8b41f7232ab401a04ee1#code
    """

    # ABI of the router at `address` below, loaded from the package's
    # abis/swapr_router.abi.json (this is not the Omen market-maker ABI).
    abi: ABI = abi_field_validator(
        os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "../../abis/swapr_router.abi.json",
        )
    )

    address: ChecksumAddress = Web3.to_checksum_address(
        "0xffb643e73f280b97809a8b41f7232ab401a04ee1"
    )

    def exact_input_single(
        self,
        api_keys: APIKeys,
        params: ExactInputSingleParams,
        web3: Web3 | None = None,
    ) -> TxReceipt:
        """Call the router's ``exactInputSingle``, approving ``token_in`` first if needed.

        An approval of exactly ``params.amount_in`` is sent only when the current
        allowance from ``api_keys.bet_from_address`` to this router is lower.

        Returns:
            Receipt of the ``exactInputSingle`` transaction.
        """
        token_in_contract = ContractERC20OnGnosisChain(address=params.token_in)
        owner = api_keys.bet_from_address
        current_allowance = token_in_contract.allowance(owner, self.address, web3=web3)
        if current_allowance < params.amount_in:
            token_in_contract.approve(api_keys, self.address, params.amount_in, web3=web3)

        # The Solidity struct is passed as a single positional tuple, so its
        # values follow ExactInputSingleParams' field declaration order.
        struct_values = tuple(dict(params).values())
        return self.send(
            api_keys=api_keys,
            function_name="exactInputSingle",
            function_params=[struct_values],
            web3=web3,
        )
@@ -0,0 +1,96 @@
1
+ from web3 import Web3
2
+
3
+ from prediction_market_agent_tooling.config import APIKeys
4
+ from prediction_market_agent_tooling.gtypes import (
5
+ ChecksumAddress,
6
+ CollateralToken,
7
+ HexBytes,
8
+ HexStr,
9
+ TxReceipt,
10
+ Wei,
11
+ )
12
+ from prediction_market_agent_tooling.markets.seer.data_models import (
13
+ ExactInputSingleParams,
14
+ )
15
+ from prediction_market_agent_tooling.markets.seer.price_manager import PriceManager
16
+ from prediction_market_agent_tooling.markets.seer.seer_contracts import (
17
+ SwaprRouterContract,
18
+ )
19
+ from prediction_market_agent_tooling.markets.seer.seer_subgraph_handler import (
20
+ SeerSubgraphHandler,
21
+ )
22
+
23
+
24
class SwapPoolHandler:
    """Swap a market's outcome tokens against its collateral token on Swapr pools.

    Only outcome-token <-> collateral-token swaps are supported. The minimum
    output is estimated from the pool price minus a slippage buffer.
    """

    def __init__(
        self,
        api_keys: APIKeys,
        market_id: str,
        collateral_token_address: ChecksumAddress,
        seer_subgraph: SeerSubgraphHandler | None = None,
    ):
        self.api_keys = api_keys
        self.market_id = market_id
        self.collateral_token_address = collateral_token_address
        self.seer_subgraph = seer_subgraph or SeerSubgraphHandler()

    def _calculate_amount_out_minimum(
        self,
        amount_wei: Wei,
        token_in: ChecksumAddress,
        price_outcome_token: CollateralToken,
        buffer_pct: float = 0.05,
    ) -> Wei:
        """Estimate the minimum acceptable output amount, reduced by ``buffer_pct``.

        ``price_outcome_token`` is treated as the price of one outcome token in
        collateral units, as the formulas below imply:
        - buying outcome (token_in is collateral): out = in / price
        - selling outcome: out = in * price

        Raises:
            ValueError: If ``buffer_pct`` is outside [0, 1) or the price is not
                positive (which would divide by zero or give a meaningless minimum).
        """
        if not 0.0 <= buffer_pct < 1.0:
            raise ValueError(f"buffer_pct must be in [0, 1), got {buffer_pct}")
        if price_outcome_token.value <= 0:
            raise ValueError(
                f"Outcome token price must be positive, got {price_outcome_token.value}"
            )

        is_buying_outcome = token_in == self.collateral_token_address

        if is_buying_outcome:
            value = amount_wei.value * (1.0 - buffer_pct) / price_outcome_token.value
        else:
            value = amount_wei.value * price_outcome_token.value * (1.0 - buffer_pct)
        return Wei(int(value))

    def buy_or_sell_outcome_token(
        self,
        amount_wei: Wei,
        token_in: ChecksumAddress,
        token_out: ChecksumAddress,
        web3: Web3 | None = None,
        buffer_pct: float = 0.05,
    ) -> TxReceipt:
        """Buy/sell outcome tokens in exchange for collateral tokens on Swapr.

        Args:
            amount_wei: Amount of ``token_in`` to swap, in wei.
            token_in: Token being sold (collateral or outcome token).
            token_out: Token being bought (the other side of the pair).
            web3: Optional Web3 instance.
            buffer_pct: Slippage tolerance applied to the expected output when
                computing ``amount_out_minimum`` (default 5%).

        Raises:
            ValueError: If neither token is the collateral token, or no usable
                pool price is found for the outcome token.
        """
        if self.collateral_token_address not in [token_in, token_out]:
            raise ValueError(
                f"trading outcome_token for a token different than collateral_token {self.collateral_token_address} is not supported. {token_in=} {token_out=}"
            )

        outcome_token = (
            token_in if token_in != self.collateral_token_address else token_out
        )

        # We could use a quoter contract (https://github.com/SwaprHQ/swapr-sdk/blob/develop/src/entities/trades/swapr-v3/constants.ts#L7), but since there is normally 1 pool per outcome token/collateral pair, it's not necessary.

        price_outcome_token = PriceManager.build(
            HexBytes(HexStr(self.market_id))
        ).get_token_price_from_pools(token=outcome_token)
        if not price_outcome_token:
            raise ValueError(
                f"Could not find price for {outcome_token=} and {self.collateral_token_address}"
            )

        amount_out_minimum = self._calculate_amount_out_minimum(
            amount_wei=amount_wei,
            token_in=token_in,
            price_outcome_token=price_outcome_token,
            buffer_pct=buffer_pct,
        )

        p = ExactInputSingleParams(
            token_in=token_in,
            token_out=token_out,
            recipient=self.api_keys.bet_from_address,
            amount_in=amount_wei,
            amount_out_minimum=amount_out_minimum,
        )

        tx_receipt = SwaprRouterContract().exact_input_single(
            api_keys=self.api_keys, params=p, web3=web3
        )
        return tx_receipt