prediction-market-agent-tooling 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. prediction_market_agent_tooling/abis/erc20.abi.json +315 -0
  2. prediction_market_agent_tooling/benchmark/agents.py +7 -1
  3. prediction_market_agent_tooling/benchmark/benchmark.py +22 -24
  4. prediction_market_agent_tooling/config.py +27 -4
  5. prediction_market_agent_tooling/deploy/agent.py +3 -3
  6. prediction_market_agent_tooling/markets/agent_market.py +22 -10
  7. prediction_market_agent_tooling/markets/manifold/manifold.py +9 -1
  8. prediction_market_agent_tooling/markets/omen/data_models.py +42 -11
  9. prediction_market_agent_tooling/markets/omen/omen.py +135 -52
  10. prediction_market_agent_tooling/markets/omen/omen_contracts.py +36 -34
  11. prediction_market_agent_tooling/markets/omen/omen_replicate.py +11 -16
  12. prediction_market_agent_tooling/markets/omen/omen_resolve_replicated.py +32 -25
  13. prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py +46 -13
  14. prediction_market_agent_tooling/markets/polymarket/polymarket.py +1 -1
  15. prediction_market_agent_tooling/monitor/markets/omen.py +5 -3
  16. prediction_market_agent_tooling/monitor/markets/polymarket.py +3 -2
  17. prediction_market_agent_tooling/monitor/monitor.py +26 -20
  18. prediction_market_agent_tooling/tools/betting_strategies/minimum_bet_to_win.py +1 -1
  19. prediction_market_agent_tooling/tools/contract.py +32 -17
  20. prediction_market_agent_tooling/tools/costs.py +31 -0
  21. prediction_market_agent_tooling/tools/parallelism.py +16 -1
  22. prediction_market_agent_tooling/tools/safe.py +130 -0
  23. prediction_market_agent_tooling/tools/web3_utils.py +100 -15
  24. {prediction_market_agent_tooling-0.14.0.dist-info → prediction_market_agent_tooling-0.15.0.dist-info}/METADATA +13 -1
  25. {prediction_market_agent_tooling-0.14.0.dist-info → prediction_market_agent_tooling-0.15.0.dist-info}/RECORD +28 -25
  26. {prediction_market_agent_tooling-0.14.0.dist-info → prediction_market_agent_tooling-0.15.0.dist-info}/LICENSE +0 -0
  27. {prediction_market_agent_tooling-0.14.0.dist-info → prediction_market_agent_tooling-0.15.0.dist-info}/WHEEL +0 -0
  28. {prediction_market_agent_tooling-0.14.0.dist-info → prediction_market_agent_tooling-0.15.0.dist-info}/entry_points.txt +0 -0
@@ -2,7 +2,9 @@ import sys
2
2
  import typing as t
3
3
  from datetime import datetime
4
4
 
5
+ import tenacity
5
6
  from eth_typing import ChecksumAddress
7
+ from loguru import logger
6
8
  from subgrounds import FieldPath, Subgrounds
7
9
 
8
10
  from prediction_market_agent_tooling.gtypes import HexAddress, HexBytes, Wei, wei_type
@@ -39,6 +41,14 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
39
41
 
40
42
  def __init__(self) -> None:
41
43
  self.sg = Subgrounds()
44
+
45
+ # Patch the query_json method to retry on failure.
46
+ self.sg.query_json = tenacity.retry(
47
+ stop=tenacity.stop_after_attempt(3),
48
+ wait=tenacity.wait_fixed(1),
49
+ after=lambda x: logger.debug(f"query_json failed, {x.attempt_number=}."),
50
+ )(self.sg.query_json)
51
+
42
52
  # Load the subgraph
43
53
  self.trades_subgraph = self.sg.load_subgraph(self.OMEN_TRADES_SUBGRAPH)
44
54
  self.conditional_tokens_subgraph = self.sg.load_subgraph(
@@ -121,6 +131,8 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
121
131
  markets_field.outcomes,
122
132
  markets_field.outcomeTokenAmounts,
123
133
  markets_field.outcomeTokenMarginalPrices,
134
+ markets_field.lastActiveDay,
135
+ markets_field.lastActiveHour,
124
136
  markets_field.fee,
125
137
  markets_field.answerFinalizedTimestamp,
126
138
  markets_field.resolutionTimestamp,
@@ -419,26 +431,30 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
419
431
  items = self._parse_items_from_json(result)
420
432
  return [OmenUserPosition.model_validate(i) for i in items]
421
433
 
422
- def get_bets(
434
+ def get_trades(
423
435
  self,
424
- better_address: ChecksumAddress,
425
- start_time: datetime,
436
+ better_address: ChecksumAddress | None = None,
437
+ start_time: datetime | None = None,
426
438
  end_time: t.Optional[datetime] = None,
427
- market_id: t.Optional[str] = None,
439
+ market_id: t.Optional[ChecksumAddress] = None,
428
440
  filter_by_answer_finalized_not_null: bool = False,
441
+ type_: t.Literal["Buy", "Sell"] | None = None,
429
442
  ) -> list[OmenBet]:
430
443
  if not end_time:
431
444
  end_time = utcnow()
432
445
 
433
446
  trade = self.trades_subgraph.FpmmTrade
434
- where_stms = [
435
- trade.type == "Buy",
436
- trade.creator == better_address.lower(),
437
- trade.creationTimestamp >= to_int_timestamp(start_time),
438
- trade.creationTimestamp <= to_int_timestamp(end_time),
439
- ]
447
+ where_stms = []
448
+ if start_time:
449
+ where_stms.append(trade.creationTimestamp >= to_int_timestamp(start_time))
450
+ if end_time:
451
+ where_stms.append(trade.creationTimestamp <= to_int_timestamp(end_time))
452
+ if type_:
453
+ where_stms.append(trade.type == type_)
454
+ if better_address:
455
+ where_stms.append(trade.creator == better_address.lower())
440
456
  if market_id:
441
- where_stms.append(trade.fpmm == market_id)
457
+ where_stms.append(trade.fpmm == market_id.lower())
442
458
  if filter_by_answer_finalized_not_null:
443
459
  where_stms.append(trade.fpmm.answerFinalizedTimestamp != None)
444
460
 
@@ -450,12 +466,29 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
450
466
  items = self._parse_items_from_json(result)
451
467
  return [OmenBet.model_validate(i) for i in items]
452
468
 
469
+ def get_bets(
470
+ self,
471
+ better_address: ChecksumAddress | None = None,
472
+ start_time: datetime | None = None,
473
+ end_time: t.Optional[datetime] = None,
474
+ market_id: t.Optional[ChecksumAddress] = None,
475
+ filter_by_answer_finalized_not_null: bool = False,
476
+ ) -> list[OmenBet]:
477
+ return self.get_trades(
478
+ better_address=better_address,
479
+ start_time=start_time,
480
+ end_time=end_time,
481
+ market_id=market_id,
482
+ filter_by_answer_finalized_not_null=filter_by_answer_finalized_not_null,
483
+ type_="Buy", # We consider `bet` to be only the `Buy` trade types.
484
+ )
485
+
453
486
  def get_resolved_bets(
454
487
  self,
455
488
  better_address: ChecksumAddress,
456
489
  start_time: datetime,
457
490
  end_time: t.Optional[datetime] = None,
458
- market_id: t.Optional[str] = None,
491
+ market_id: t.Optional[ChecksumAddress] = None,
459
492
  ) -> list[OmenBet]:
460
493
  omen_bets = self.get_bets(
461
494
  better_address=better_address,
@@ -471,7 +504,7 @@ class OmenSubgraphHandler(metaclass=SingletonMeta):
471
504
  better_address: ChecksumAddress,
472
505
  start_time: datetime,
473
506
  end_time: t.Optional[datetime] = None,
474
- market_id: t.Optional[str] = None,
507
+ market_id: t.Optional[ChecksumAddress] = None,
475
508
  ) -> list[OmenBet]:
476
509
  bets = self.get_resolved_bets(
477
510
  better_address=better_address,
@@ -33,7 +33,7 @@ class PolymarketAgentMarket(AgentMarket):
33
33
  question=model.question,
34
34
  outcomes=[x.outcome for x in model.tokens],
35
35
  resolution=model.resolution,
36
- p_yes=model.p_yes,
36
+ current_p_yes=model.p_yes,
37
37
  created_time=None,
38
38
  close_time=model.end_date_iso,
39
39
  url=model.url,
@@ -2,7 +2,7 @@ import typing as t
2
2
 
3
3
  from google.cloud.functions_v2.types.functions import Function
4
4
 
5
- from prediction_market_agent_tooling.config import APIKeys
5
+ from prediction_market_agent_tooling.config import APIKeys, PrivateCredentials
6
6
  from prediction_market_agent_tooling.deploy.constants import MARKET_TYPE_KEY
7
7
  from prediction_market_agent_tooling.gtypes import ChecksumAddress, DatetimeWithTimezone
8
8
  from prediction_market_agent_tooling.markets.data_models import ResolvedBet
@@ -49,7 +49,8 @@ class DeployedOmenAgent(DeployedAgent):
49
49
  and api_keys.BET_FROM_PRIVATE_KEY
50
50
  != APIKeys().BET_FROM_PRIVATE_KEY # Check that it didn't get it from the default env.
51
51
  ):
52
- env_vars["omen_public_key"] = api_keys.bet_from_address
52
+ private_credentials = PrivateCredentials.from_api_keys(api_keys)
53
+ env_vars["omen_public_key"] = private_credentials.public_key
53
54
  return super().from_env_vars_without_prefix(
54
55
  env_vars=env_vars, extra_vars=extra_vars
55
56
  )
@@ -60,10 +61,11 @@ class DeployedOmenAgent(DeployedAgent):
60
61
  start_time: DatetimeWithTimezone,
61
62
  api_keys: APIKeys,
62
63
  ) -> "DeployedOmenAgent":
64
+ private_credentials = PrivateCredentials.from_api_keys(api_keys)
63
65
  return DeployedOmenAgent(
64
66
  name=name,
65
67
  start_time=start_time,
66
- omen_public_key=api_keys.bet_from_address,
68
+ omen_public_key=private_credentials.public_key,
67
69
  )
68
70
 
69
71
  @classmethod
@@ -2,7 +2,7 @@ import typing as t
2
2
 
3
3
  from google.cloud.functions_v2.types.functions import Function
4
4
 
5
- from prediction_market_agent_tooling.config import APIKeys
5
+ from prediction_market_agent_tooling.config import APIKeys, PrivateCredentials
6
6
  from prediction_market_agent_tooling.deploy.constants import MARKET_TYPE_KEY
7
7
  from prediction_market_agent_tooling.gtypes import ChecksumAddress, DatetimeWithTimezone
8
8
  from prediction_market_agent_tooling.markets.data_models import ResolvedBet
@@ -28,10 +28,11 @@ class DeployedPolymarketAgent(DeployedAgent):
28
28
  start_time: DatetimeWithTimezone,
29
29
  api_keys: APIKeys,
30
30
  ) -> "DeployedPolymarketAgent":
31
+ private_credentials = PrivateCredentials.from_api_keys(api_keys)
31
32
  return DeployedPolymarketAgent(
32
33
  name=name,
33
34
  start_time=start_time,
34
- polymarket_public_key=api_keys.bet_from_address,
35
+ polymarket_public_key=private_credentials.public_key,
35
36
  )
36
37
 
37
38
  @classmethod
@@ -24,6 +24,7 @@ from prediction_market_agent_tooling.deploy.gcp.utils import (
24
24
  )
25
25
  from prediction_market_agent_tooling.markets.agent_market import AgentMarket
26
26
  from prediction_market_agent_tooling.markets.data_models import Resolution, ResolvedBet
27
+ from prediction_market_agent_tooling.tools.parallelism import par_map
27
28
  from prediction_market_agent_tooling.tools.utils import (
28
29
  DatetimeWithTimezone,
29
30
  add_utc_timezone_validator,
@@ -259,16 +260,32 @@ def monitor_brier_score(resolved_markets: t.Sequence[AgentMarket]) -> None:
259
260
  """
260
261
  st.subheader("Brier Score (0-1, lower is better)")
261
262
 
262
- markets_to_squared_error = {
263
- m.created_time: m.get_squared_error() for m in resolved_markets
264
- }
263
+ # We need to use `get_last_trade_p_yes` instead of `current_p_yes` because, for resolved markets, the probabilities can be fixed to 0 and 1 (for example, on Omen).
264
+ # And for the brier score, we need the true market prediction, not its resolution after the outcome is known.
265
+ # If no trades were made, skip the market because the platform didn't provide any valuable information.
266
+ created_time_and_squared_errors = par_map(
267
+ list(resolved_markets),
268
+ lambda m: (
269
+ m.created_time,
270
+ (
271
+ (p_yes - m.boolean_outcome) ** 2
272
+ if (p_yes := m.get_last_trade_p_yes()) is not None
273
+ else None
274
+ ),
275
+ ),
276
+ )
277
+ created_time_and_squared_errors_with_trades = [
278
+ x for x in created_time_and_squared_errors if x[1] is not None
279
+ ]
265
280
  df = pd.DataFrame(
266
- markets_to_squared_error.items(), columns=["Date", "Squared Error"]
281
+ created_time_and_squared_errors_with_trades, columns=["Date", "Squared Error"]
267
282
  ).sort_values(by="Date")
268
283
 
269
284
  # Compute rolling mean squared error for last 30 markets
270
285
  df["Rolling Mean Squared Error"] = df["Squared Error"].rolling(window=30).mean()
271
286
 
287
+ st.write(f"Based on {len(df)} markets with at least one trade.")
288
+
272
289
  col1, col2 = st.columns(2)
273
290
  col1.metric(label="Overall", value=f"{df['Squared Error'].mean():.3f}")
274
291
  col2.metric(
@@ -293,27 +310,14 @@ def monitor_market_outcome_bias(
293
310
  st.subheader("Market Outcome Bias")
294
311
 
295
312
  date_to_open_yes_proportion = {
296
- d: np.mean([int(m.p_yes > 0.5) for m in markets])
313
+ d: np.mean([int(m.current_p_yes > 0.5) for m in markets])
297
314
  for d, markets in groupby(
298
315
  open_markets,
299
316
  lambda x: check_not_none(x.created_time, "Only markets with created time can be used here.").date(), # type: ignore # Bug, it says `Never has no attribute "date" [attr-defined]` with Mypy, but in VSCode it works correctly.
300
317
  )
301
318
  }
302
319
  date_to_resolved_yes_proportion = {
303
- d: np.mean(
304
- [
305
- (
306
- 1
307
- if m.resolution == Resolution.YES
308
- else (
309
- 0
310
- if m.resolution == Resolution.NO
311
- else should_not_happen(f"Unexpected resolution: {m.resolution}")
312
- )
313
- )
314
- for m in markets
315
- ]
316
- )
320
+ d: np.mean([int(m.boolean_outcome) for m in markets])
317
321
  for d, markets in groupby(
318
322
  resolved_markets,
319
323
  lambda x: check_not_none(x.created_time, "Only markets with created time can be used here.").date(), # type: ignore # Bug, it says `Never has no attribute "date" [attr-defined]` with Mypy, but in VSCode it works correctly.
@@ -349,7 +353,9 @@ def monitor_market_outcome_bias(
349
353
  use_container_width=True,
350
354
  )
351
355
 
352
- all_open_markets_yes_mean = np.mean([int(m.p_yes > 0.5) for m in open_markets])
356
+ all_open_markets_yes_mean = np.mean(
357
+ [int(m.current_p_yes > 0.5) for m in open_markets]
358
+ )
353
359
  all_resolved_markets_yes_mean = np.mean(
354
360
  [
355
361
  (
@@ -7,6 +7,6 @@ def minimum_bet_to_win(
7
7
  """
8
8
  Estimates the minimum bet amount to win the given amount based on the current market price.
9
9
  """
10
- share_price = market.p_yes if answer else market.p_no
10
+ share_price = market.current_p_yes if answer else market.current_p_no
11
11
  bet_amount = amount_to_win / (1 / share_price - 1)
12
12
  return bet_amount
@@ -6,12 +6,12 @@ from contextlib import contextmanager
6
6
  from pydantic import BaseModel, field_validator
7
7
  from web3 import Web3
8
8
 
9
+ from prediction_market_agent_tooling.config import PrivateCredentials
9
10
  from prediction_market_agent_tooling.gtypes import (
10
11
  ABI,
11
12
  ChainID,
12
13
  ChecksumAddress,
13
14
  Nonce,
14
- PrivateKey,
15
15
  TxParams,
16
16
  TxReceipt,
17
17
  Wei,
@@ -22,8 +22,8 @@ from prediction_market_agent_tooling.tools.gnosis_rpc import (
22
22
  )
23
23
  from prediction_market_agent_tooling.tools.web3_utils import (
24
24
  call_function_on_contract,
25
- private_key_to_public_key,
26
25
  send_function_on_contract_tx,
26
+ send_function_on_contract_tx_using_safe,
27
27
  )
28
28
 
29
29
 
@@ -86,7 +86,7 @@ class ContractBaseClass(BaseModel):
86
86
 
87
87
  def send(
88
88
  self,
89
- from_private_key: PrivateKey,
89
+ private_credentials: PrivateCredentials,
90
90
  function_name: str,
91
91
  function_params: t.Optional[list[t.Any] | dict[str, t.Any]] = None,
92
92
  tx_params: t.Optional[TxParams] = None,
@@ -96,22 +96,33 @@ class ContractBaseClass(BaseModel):
96
96
  """
97
97
  Used for changing a state (writing) to the contract.
98
98
  """
99
- with wait_until_nonce_changed(private_key_to_public_key(from_private_key)):
100
- receipt = send_function_on_contract_tx(
99
+
100
+ if private_credentials.safe_address is not None:
101
+ return send_function_on_contract_tx_using_safe(
101
102
  web3=web3 or self.get_web3(),
102
103
  contract_address=self.address,
103
104
  contract_abi=self.abi,
104
- from_private_key=from_private_key,
105
+ from_private_key=private_credentials.private_key,
106
+ safe_address=private_credentials.safe_address,
105
107
  function_name=function_name,
106
108
  function_params=function_params,
107
109
  tx_params=tx_params,
108
110
  timeout=timeout,
109
111
  )
110
- return receipt
112
+ return send_function_on_contract_tx(
113
+ web3=web3 or self.get_web3(),
114
+ contract_address=self.address,
115
+ contract_abi=self.abi,
116
+ from_private_key=private_credentials.private_key,
117
+ function_name=function_name,
118
+ function_params=function_params,
119
+ tx_params=tx_params,
120
+ timeout=timeout,
121
+ )
111
122
 
112
123
  def send_with_value(
113
124
  self,
114
- from_private_key: PrivateKey,
125
+ private_credentials: PrivateCredentials,
115
126
  function_name: str,
116
127
  amount_wei: Wei,
117
128
  function_params: t.Optional[list[t.Any] | dict[str, t.Any]] = None,
@@ -123,7 +134,7 @@ class ContractBaseClass(BaseModel):
123
134
  Used for changing a state (writing) to the contract, including sending chain's native currency.
124
135
  """
125
136
  return self.send(
126
- from_private_key=from_private_key,
137
+ private_credentials=private_credentials,
127
138
  function_name=function_name,
128
139
  function_params=function_params,
129
140
  tx_params={"value": amount_wei, **(tx_params or {})},
@@ -147,51 +158,55 @@ class ContractERC20BaseClass(ContractBaseClass):
147
158
 
148
159
  def approve(
149
160
  self,
161
+ private_credentials: PrivateCredentials,
150
162
  for_address: ChecksumAddress,
151
163
  amount_wei: Wei,
152
- from_private_key: PrivateKey,
153
164
  tx_params: t.Optional[TxParams] = None,
165
+ web3: Web3 | None = None,
154
166
  ) -> TxReceipt:
155
167
  return self.send(
156
- from_private_key=from_private_key,
168
+ private_credentials=private_credentials,
157
169
  function_name="approve",
158
170
  function_params=[
159
171
  for_address,
160
172
  amount_wei,
161
173
  ],
162
174
  tx_params=tx_params,
175
+ web3=web3,
163
176
  )
164
177
 
165
178
  def deposit(
166
179
  self,
180
+ private_credentials: PrivateCredentials,
167
181
  amount_wei: Wei,
168
- from_private_key: PrivateKey,
169
182
  tx_params: t.Optional[TxParams] = None,
183
+ web3: Web3 | None = None,
170
184
  ) -> TxReceipt:
171
185
  return self.send_with_value(
172
- from_private_key=from_private_key,
186
+ private_credentials=private_credentials,
173
187
  function_name="deposit",
174
188
  amount_wei=amount_wei,
175
189
  tx_params=tx_params,
190
+ web3=web3,
176
191
  )
177
192
 
178
193
  def withdraw(
179
194
  self,
195
+ private_credentials: PrivateCredentials,
180
196
  amount_wei: Wei,
181
- from_private_key: PrivateKey,
182
197
  tx_params: t.Optional[TxParams] = None,
183
198
  web3: Web3 | None = None,
184
199
  ) -> TxReceipt:
185
200
  return self.send(
186
- from_private_key=from_private_key,
201
+ private_credentials=private_credentials,
187
202
  function_name="withdraw",
188
203
  function_params=[amount_wei],
189
204
  tx_params=tx_params,
190
205
  web3=web3,
191
206
  )
192
207
 
193
- def balanceOf(self, for_address: ChecksumAddress) -> Wei:
194
- balance: Wei = self.call("balanceOf", [for_address])
208
+ def balanceOf(self, for_address: ChecksumAddress, web3: Web3 | None = None) -> Wei:
209
+ balance: Wei = self.call("balanceOf", [for_address], web3=web3)
195
210
  return balance
196
211
 
197
212
 
@@ -0,0 +1,31 @@
1
+ import typing as t
2
+ from contextlib import contextmanager
3
+ from time import time
4
+
5
+ from langchain_community.callbacks import get_openai_callback
6
+ from pydantic import BaseModel
7
+
8
+ from prediction_market_agent_tooling.benchmark.utils import get_llm_api_call_cost
9
+
10
+
11
+ class Costs(BaseModel):
12
+ time: float
13
+ cost: float
14
+
15
+
16
+ @contextmanager
17
+ def openai_costs(model: str | None = None) -> t.Generator[Costs, None, None]:
18
+ costs = Costs(time=0, cost=0)
19
+ start_time = time()
20
+
21
+ with get_openai_callback() as cb:
22
+ yield costs
23
+ if cb.total_tokens > 0 and cb.total_cost == 0 and model is not None:
24
+ # TODO: this is a hack to get the cost for an unsupported model
25
+ cb.total_cost = get_llm_api_call_cost(
26
+ model=model,
27
+ prompt_tokens=cb.prompt_tokens,
28
+ completion_tokens=cb.completion_tokens,
29
+ )
30
+ costs.time = time() - start_time
31
+ costs.cost = cb.total_cost
@@ -1,7 +1,7 @@
1
1
  import concurrent
2
2
  from concurrent.futures import Executor
3
3
  from concurrent.futures.thread import ThreadPoolExecutor
4
- from typing import Callable, TypeVar
4
+ from typing import Callable, Generator, TypeVar
5
5
 
6
6
  # Max workers to 5 to avoid rate limiting on some APIs, create a custom executor if you need more workers.
7
7
  DEFAULT_THREADPOOL_EXECUTOR = ThreadPoolExecutor(max_workers=5)
@@ -25,3 +25,18 @@ def par_map(
25
25
  for fut in futures:
26
26
  results.append(fut.result())
27
27
  return results
28
+
29
+
30
+ def par_generator(
31
+ items: list[A],
32
+ func: Callable[[A], B],
33
+ executor: Executor = DEFAULT_THREADPOOL_EXECUTOR,
34
+ ) -> Generator[B, None, None]:
35
+ """Applies the function to each element using the specified executor. Yields results as they come.
36
+ If executor is ProcessPoolExecutor, make sure the function passed is picklable, e.g. no lambda functions.
37
+ """
38
+ futures: list[concurrent.futures._base.Future[B]] = [
39
+ executor.submit(func, item) for item in items
40
+ ]
41
+ for fut in concurrent.futures.as_completed(futures):
42
+ yield fut.result()
@@ -0,0 +1,130 @@
1
+ from eth_account.signers.local import LocalAccount
2
+ from eth_typing import ChecksumAddress
3
+ from gnosis.eth import EthereumClient
4
+ from gnosis.eth.constants import NULL_ADDRESS
5
+ from gnosis.eth.contracts import get_safe_V1_4_1_contract
6
+ from gnosis.safe.proxy_factory import ProxyFactoryV141
7
+ from gnosis.safe.safe import Safe
8
+ from loguru import logger
9
+ from safe_cli.safe_addresses import (
10
+ get_default_fallback_handler_address,
11
+ get_proxy_factory_address,
12
+ get_safe_contract_address,
13
+ get_safe_l2_contract_address,
14
+ )
15
+ from web3.types import Wei
16
+
17
+ from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
18
+ from prediction_market_agent_tooling.tools.web3_utils import wei_to_xdai
19
+
20
+
21
+ def create_safe(
22
+ ethereum_client: EthereumClient,
23
+ account: LocalAccount,
24
+ owners: list[str],
25
+ salt_nonce: int,
26
+ threshold: int = 1,
27
+ without_events: bool = False, # following safe-cli convention
28
+ ) -> ChecksumAddress | None:
29
+ to = NULL_ADDRESS
30
+ data = b""
31
+ payment_token = NULL_ADDRESS
32
+ payment = 0
33
+ payment_receiver = NULL_ADDRESS
34
+
35
+ if len(owners) < threshold:
36
+ raise ValueError("Threshold cannot be bigger than the number of unique owners")
37
+
38
+ ethereum_network = ethereum_client.get_network()
39
+
40
+ safe_contract_address = (
41
+ get_safe_contract_address(ethereum_client)
42
+ if without_events
43
+ else get_safe_l2_contract_address(ethereum_client)
44
+ )
45
+ proxy_factory_address = get_proxy_factory_address(ethereum_client)
46
+ fallback_handler = get_default_fallback_handler_address(ethereum_client)
47
+
48
+ if not ethereum_client.is_contract(safe_contract_address):
49
+ raise EnvironmentError(
50
+ f"Safe contract address {safe_contract_address} "
51
+ f"does not exist on network {ethereum_network.name}"
52
+ )
53
+ elif not ethereum_client.is_contract(proxy_factory_address):
54
+ raise EnvironmentError(
55
+ f"Proxy contract address {proxy_factory_address} "
56
+ f"does not exist on network {ethereum_network.name}"
57
+ )
58
+ elif fallback_handler != NULL_ADDRESS and not ethereum_client.is_contract(
59
+ fallback_handler
60
+ ):
61
+ raise EnvironmentError(
62
+ f"Fallback handler address {fallback_handler} "
63
+ f"does not exist on network {ethereum_network.name}"
64
+ )
65
+
66
+ account_balance = ethereum_client.get_balance(account.address)
67
+ account_balance_xdai = wei_to_xdai(account_balance)
68
+ # We set a reasonable expected balance below for Safe deployment not to fail.
69
+ if account_balance_xdai < 0.01:
70
+ raise ValueError(
71
+ f"Client's balance is {account_balance_xdai} xDAI, too low for deploying a Safe."
72
+ )
73
+
74
+ logger.info(
75
+ f"Network {ethereum_client.get_network().name} - Sender {account.address} - "
76
+ f"Balance: {account_balance_xdai} xDAI"
77
+ )
78
+
79
+ if not ethereum_client.w3.eth.get_code(
80
+ safe_contract_address
81
+ ) or not ethereum_client.w3.eth.get_code(proxy_factory_address):
82
+ raise EnvironmentError("Network not supported")
83
+
84
+ logger.info(
85
+ f"Creating new Safe with owners={owners} threshold={threshold} salt-nonce={salt_nonce}"
86
+ )
87
+
88
+ # We ignore mypy below because using the proper class SafeV141 yields an error and mypy
89
+ # doesn't understand that there is a hacky factory method (__new__) on this abstract class.
90
+ safe_version = Safe(safe_contract_address, ethereum_client).retrieve_version() # type: ignore
91
+ logger.info(
92
+ f"Safe-master-copy={safe_contract_address} version={safe_version}\n"
93
+ f"Fallback-handler={fallback_handler}\n"
94
+ f"Proxy factory={proxy_factory_address}"
95
+ )
96
+
97
+ # Note that by default a safe contract instance of version 1.4.1 will be fetched (determined
98
+ # by safe_contract_address), but if a different address (corresponding to an older version, e.g. 1.3.0)
99
+ # is passed, it will also work, since older versions also have the setup method.
100
+ safe_contract = get_safe_V1_4_1_contract(ethereum_client.w3, safe_contract_address)
101
+ safe_creation_tx_data = HexBytes(
102
+ safe_contract.functions.setup(
103
+ owners,
104
+ threshold,
105
+ to,
106
+ data,
107
+ fallback_handler,
108
+ payment_token,
109
+ payment,
110
+ payment_receiver,
111
+ ).build_transaction({"gas": 1, "gasPrice": Wei(1)})["data"]
112
+ )
113
+
114
+ proxy_factory = ProxyFactoryV141(proxy_factory_address, ethereum_client)
115
+ expected_safe_address = proxy_factory.calculate_proxy_address(
116
+ safe_contract_address, safe_creation_tx_data, salt_nonce
117
+ )
118
+ if ethereum_client.is_contract(expected_safe_address):
119
+ logger.info(f"Safe on {expected_safe_address} is already deployed")
120
+ return expected_safe_address
121
+
122
+ ethereum_tx_sent = proxy_factory.deploy_proxy_contract_with_nonce(
123
+ account, safe_contract_address, safe_creation_tx_data, salt_nonce
124
+ )
125
+ logger.info(
126
+ f"Sent tx with tx-hash={ethereum_tx_sent.tx_hash.hex()} "
127
+ f"Safe={ethereum_tx_sent.contract_address} is being created"
128
+ )
129
+ logger.info(f"Tx parameters={ethereum_tx_sent.tx}")
130
+ return ethereum_tx_sent.contract_address