prediction-market-agent-tooling 0.56.3.dev135__py3-none-any.whl → 0.57.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. prediction_market_agent_tooling/config.py +31 -2
  2. prediction_market_agent_tooling/deploy/agent.py +71 -48
  3. prediction_market_agent_tooling/markets/data_models.py +21 -0
  4. prediction_market_agent_tooling/markets/manifold/data_models.py +0 -1
  5. prediction_market_agent_tooling/markets/omen/data_models.py +11 -10
  6. prediction_market_agent_tooling/markets/omen/omen.py +2 -1
  7. prediction_market_agent_tooling/markets/omen/omen_contracts.py +1 -1
  8. prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py +8 -4
  9. prediction_market_agent_tooling/markets/polymarket/utils.py +1 -1
  10. prediction_market_agent_tooling/monitor/financial_metrics/financial_metrics.py +66 -0
  11. prediction_market_agent_tooling/monitor/monitor.py +1 -1
  12. prediction_market_agent_tooling/tools/caches/db_cache.py +9 -75
  13. prediction_market_agent_tooling/tools/caches/serializers.py +61 -0
  14. prediction_market_agent_tooling/tools/contract.py +3 -7
  15. prediction_market_agent_tooling/tools/custom_exceptions.py +6 -0
  16. prediction_market_agent_tooling/tools/db/db_manager.py +76 -0
  17. prediction_market_agent_tooling/tools/{google.py → google_utils.py} +28 -1
  18. prediction_market_agent_tooling/tools/image_gen/market_thumbnail_gen.py +1 -1
  19. prediction_market_agent_tooling/tools/is_invalid.py +1 -1
  20. prediction_market_agent_tooling/tools/is_predictable.py +8 -3
  21. prediction_market_agent_tooling/tools/langfuse_client_utils.py +1 -0
  22. prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py +7 -14
  23. prediction_market_agent_tooling/tools/transaction_cache.py +48 -0
  24. prediction_market_agent_tooling/tools/utils.py +1 -23
  25. prediction_market_agent_tooling/tools/web3_utils.py +5 -1
  26. {prediction_market_agent_tooling-0.56.3.dev135.dist-info → prediction_market_agent_tooling-0.57.1.dist-info}/METADATA +1 -1
  27. {prediction_market_agent_tooling-0.56.3.dev135.dist-info → prediction_market_agent_tooling-0.57.1.dist-info}/RECORD +30 -27
  28. prediction_market_agent_tooling/tools/gnosis_rpc.py +0 -6
  29. prediction_market_agent_tooling/tools/pickle_utils.py +0 -31
  30. {prediction_market_agent_tooling-0.56.3.dev135.dist-info → prediction_market_agent_tooling-0.57.1.dist-info}/LICENSE +0 -0
  31. {prediction_market_agent_tooling-0.56.3.dev135.dist-info → prediction_market_agent_tooling-0.57.1.dist-info}/WHEEL +0 -0
  32. {prediction_market_agent_tooling-0.56.3.dev135.dist-info → prediction_market_agent_tooling-0.57.1.dist-info}/entry_points.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  import typing as t
2
2
 
3
+ from pydantic import Field
3
4
  from pydantic.types import SecretStr
4
5
  from pydantic.v1.types import SecretStr as SecretStrV1
5
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -7,6 +8,7 @@ from safe_eth.eth import EthereumClient
7
8
  from safe_eth.safe.safe import SafeV141
8
9
 
9
10
  from prediction_market_agent_tooling.gtypes import (
11
+ ChainID,
10
12
  ChecksumAddress,
11
13
  PrivateKey,
12
14
  secretstr_to_v1_secretstr,
@@ -186,14 +188,14 @@ class APIKeys(BaseSettings):
186
188
  return {
187
189
  k: v
188
190
  for k, v in self.model_dump().items()
189
- if APIKeys.model_fields[k].annotation not in SECRET_TYPES and v is not None
191
+ if self.model_fields[k].annotation not in SECRET_TYPES and v is not None
190
192
  }
191
193
 
192
194
  def model_dump_secrets(self) -> dict[str, t.Any]:
193
195
  return {
194
196
  k: v.get_secret_value() if isinstance(v, SecretStr) else v
195
197
  for k, v in self.model_dump().items()
196
- if APIKeys.model_fields[k].annotation in SECRET_TYPES and v is not None
198
+ if self.model_fields[k].annotation in SECRET_TYPES and v is not None
197
199
  }
198
200
 
199
201
  def check_if_is_safe_owner(self, ethereum_client: EthereumClient) -> bool:
@@ -203,3 +205,30 @@ class APIKeys(BaseSettings):
203
205
  s = SafeV141(self.SAFE_ADDRESS, ethereum_client)
204
206
  public_key_from_signer = private_key_to_public_key(self.bet_from_private_key)
205
207
  return s.retrieve_is_owner(public_key_from_signer)
208
+
209
+
210
+ class RPCConfig(BaseSettings):
211
+ model_config = SettingsConfigDict(
212
+ env_file=".env", env_file_encoding="utf-8", extra="ignore"
213
+ )
214
+
215
+ GNOSIS_RPC_URL: str = Field(default="https://rpc.gnosischain.com")
216
+ CHAIN_ID: ChainID = Field(default=ChainID(100))
217
+
218
+ @property
219
+ def gnosis_rpc_url(self) -> str:
220
+ return check_not_none(
221
+ self.GNOSIS_RPC_URL, "GNOSIS_RPC_URL missing in the environment."
222
+ )
223
+
224
+ @property
225
+ def chain_id(self) -> ChainID:
226
+ return check_not_none(self.CHAIN_ID, "CHAIN_ID missing in the environment.")
227
+
228
+
229
+ class CloudCredentials(BaseSettings):
230
+ model_config = SettingsConfigDict(
231
+ env_file=".env", env_file_encoding="utf-8", extra="ignore"
232
+ )
233
+
234
+ GOOGLE_APPLICATION_CREDENTIALS: t.Optional[str] = None
@@ -8,8 +8,7 @@ from datetime import timedelta
8
8
  from enum import Enum
9
9
  from functools import cached_property
10
10
 
11
- from pydantic import BeforeValidator, computed_field
12
- from typing_extensions import Annotated
11
+ from pydantic import computed_field
13
12
 
14
13
  from prediction_market_agent_tooling.config import APIKeys
15
14
  from prediction_market_agent_tooling.deploy.betting_strategy import (
@@ -59,33 +58,16 @@ from prediction_market_agent_tooling.markets.omen.omen import (
59
58
  from prediction_market_agent_tooling.monitor.monitor_app import (
60
59
  MARKET_TYPE_TO_DEPLOYED_AGENT,
61
60
  )
61
+ from prediction_market_agent_tooling.tools.custom_exceptions import (
62
+ CantPayForGasError,
63
+ OutOfFundsError,
64
+ )
62
65
  from prediction_market_agent_tooling.tools.is_invalid import is_invalid
63
66
  from prediction_market_agent_tooling.tools.is_predictable import is_predictable_binary
64
67
  from prediction_market_agent_tooling.tools.langfuse_ import langfuse_context, observe
65
68
  from prediction_market_agent_tooling.tools.utils import DatetimeUTC, utcnow
66
69
 
67
70
  MAX_AVAILABLE_MARKETS = 20
68
- TRADER_TAG = "trader"
69
-
70
-
71
- def to_boolean_outcome(value: str | bool) -> bool:
72
- if isinstance(value, bool):
73
- return value
74
-
75
- elif isinstance(value, str):
76
- value = value.lower().strip()
77
-
78
- if value in {"true", "yes", "y", "1"}:
79
- return True
80
-
81
- elif value in {"false", "no", "n", "0"}:
82
- return False
83
-
84
- else:
85
- raise ValueError(f"Expected a boolean string, but got {value}")
86
-
87
- else:
88
- raise ValueError(f"Expected a boolean or a string, but got {value}")
89
71
 
90
72
 
91
73
  def initialize_langfuse(enable_langfuse: bool) -> None:
@@ -103,23 +85,21 @@ def initialize_langfuse(enable_langfuse: bool) -> None:
103
85
  langfuse_context.configure(enabled=enable_langfuse)
104
86
 
105
87
 
106
- Decision = Annotated[bool, BeforeValidator(to_boolean_outcome)]
107
-
108
-
109
- class CantPayForGasError(ValueError):
110
- pass
111
-
112
-
113
- class OutOfFundsError(ValueError):
114
- pass
115
-
116
-
117
88
  class AnsweredEnum(str, Enum):
118
89
  ANSWERED = "answered"
119
90
  NOT_ANSWERED = "not_answered"
120
91
 
121
92
 
93
+ class AgentTagEnum(str, Enum):
94
+ PREDICTOR = "predictor"
95
+ TRADER = "trader"
96
+
97
+
122
98
  class DeployableAgent:
99
+ """
100
+ Subclass this class to create agent with standardized interface.
101
+ """
102
+
123
103
  def __init__(
124
104
  self,
125
105
  enable_langfuse: bool = APIKeys().default_enable_langfuse,
@@ -180,20 +160,25 @@ class DeployableAgent:
180
160
  )
181
161
 
182
162
  def load(self) -> None:
183
- pass
163
+ """
164
+ Implement this method to load arbitrary instances needed across the whole run of the agent.
165
+
166
+ Do not customize __init__ method.
167
+ """
184
168
 
185
169
  def deploy_local(
186
170
  self,
187
171
  market_type: MarketType,
188
172
  sleep_time: float,
189
- timeout: float,
173
+ run_time: float | None,
190
174
  ) -> None:
175
+ """
176
+ Run the agent in the forever cycle every `sleep_time` seconds, until the `run_time` is met.
177
+ """
191
178
  start_time = time.time()
192
- while True:
179
+ while run_time is None or time.time() - start_time < run_time:
193
180
  self.run(market_type=market_type)
194
181
  time.sleep(sleep_time)
195
- if time.time() - start_time > timeout:
196
- break
197
182
 
198
183
  def deploy_gcp(
199
184
  self,
@@ -209,6 +194,9 @@ class DeployableAgent:
209
194
  start_time: DatetimeUTC | None = None,
210
195
  timeout: int = 180,
211
196
  ) -> None:
197
+ """
198
+ Deploy the agent as GCP Function.
199
+ """
212
200
  path_to_agent_file = os.path.relpath(inspect.getfile(self.__class__))
213
201
 
214
202
  entrypoint_function_name = "main"
@@ -275,6 +263,9 @@ def {entrypoint_function_name}(request) -> str:
275
263
  schedule_deployed_gcp_function(fname, cron_schedule=cron_schedule)
276
264
 
277
265
  def run(self, market_type: MarketType) -> None:
266
+ """
267
+ Run single iteration of the agent.
268
+ """
278
269
  raise NotImplementedError("This method must be implemented by the subclass.")
279
270
 
280
271
  def get_gcloud_fname(self, market_type: MarketType) -> str:
@@ -282,6 +273,14 @@ def {entrypoint_function_name}(request) -> str:
282
273
 
283
274
 
284
275
  class DeployablePredictionAgent(DeployableAgent):
276
+ """
277
+ Subclass this class to create your own prediction market agent.
278
+
279
+ The agent will process markets and make predictions.
280
+ """
281
+
282
+ AGENT_TAG: AgentTagEnum = AgentTagEnum.PREDICTOR
283
+
285
284
  bet_on_n_markets_per_run: int = 1
286
285
 
287
286
  # Agent behaviour when fetching markets
@@ -301,10 +300,10 @@ class DeployablePredictionAgent(DeployableAgent):
301
300
  def __init__(
302
301
  self,
303
302
  enable_langfuse: bool = APIKeys().default_enable_langfuse,
304
- store_prediction: bool = True,
303
+ store_predictions: bool = True,
305
304
  ) -> None:
306
305
  super().__init__(enable_langfuse=enable_langfuse)
307
- self.store_prediction = store_prediction
306
+ self.store_predictions = store_predictions
308
307
 
309
308
  def initialize_langfuse(self) -> None:
310
309
  super().initialize_langfuse()
@@ -332,7 +331,7 @@ class DeployablePredictionAgent(DeployableAgent):
332
331
  ) -> None:
333
332
  self.langfuse_update_current_trace(
334
333
  tags=[
335
- TRADER_TAG,
334
+ self.AGENT_TAG,
336
335
  (
337
336
  AnsweredEnum.ANSWERED
338
337
  if processed_market is not None
@@ -390,6 +389,9 @@ class DeployablePredictionAgent(DeployableAgent):
390
389
  self,
391
390
  market_type: MarketType,
392
391
  ) -> t.Sequence[AgentMarket]:
392
+ """
393
+ Override this method to customize what markets will fetch for processing.
394
+ """
393
395
  cls = market_type.market_class
394
396
  # Fetch the soonest closing markets to choose from
395
397
  available_markets = cls.get_binary_markets(
@@ -403,6 +405,9 @@ class DeployablePredictionAgent(DeployableAgent):
403
405
  def before_process_market(
404
406
  self, market_type: MarketType, market: AgentMarket
405
407
  ) -> None:
408
+ """
409
+ Executed before processing of each market.
410
+ """
406
411
  api_keys = APIKeys()
407
412
 
408
413
  if market_type.is_blockchain_market:
@@ -446,19 +451,22 @@ class DeployablePredictionAgent(DeployableAgent):
446
451
  market: AgentMarket,
447
452
  processed_market: ProcessedMarket | None,
448
453
  ) -> None:
454
+ """
455
+ Executed after processing of each market.
456
+ """
449
457
  keys = APIKeys()
450
- if self.store_prediction:
458
+ if self.store_predictions:
451
459
  market.store_prediction(
452
460
  processed_market=processed_market, keys=keys, agent_name=self.agent_name
453
461
  )
454
462
  else:
455
463
  logger.info(
456
- f"Prediction {processed_market} not stored because {self.store_prediction=}."
464
+ f"Prediction {processed_market} not stored because {self.store_predictions=}."
457
465
  )
458
466
 
459
467
  def before_process_markets(self, market_type: MarketType) -> None:
460
468
  """
461
- Executes actions that occur before bets are placed.
469
+ Executed before market processing loop starts.
462
470
  """
463
471
  api_keys = APIKeys()
464
472
  self.check_min_required_balance_to_operate(market_type)
@@ -489,7 +497,9 @@ class DeployablePredictionAgent(DeployableAgent):
489
497
  logger.info("All markets processed.")
490
498
 
491
499
  def after_process_markets(self, market_type: MarketType) -> None:
492
- "Executes actions that occur after bets are placed."
500
+ """
501
+ Executed after market processing loop ends.
502
+ """
493
503
 
494
504
  def run(self, market_type: MarketType) -> None:
495
505
  if market_type not in self.supported_markets:
@@ -502,6 +512,14 @@ class DeployablePredictionAgent(DeployableAgent):
502
512
 
503
513
 
504
514
  class DeployableTraderAgent(DeployablePredictionAgent):
515
+ """
516
+ Subclass this class to create your own prediction market trading agent.
517
+
518
+ The agent will process markets, make predictions and place trades (bets) based off these predictions.
519
+ """
520
+
521
+ AGENT_TAG: AgentTagEnum = AgentTagEnum.TRADER
522
+
505
523
  # These markets require place of bet, not just predictions.
506
524
  supported_markets: t.Sequence[MarketType] = [
507
525
  MarketType.OMEN,
@@ -512,12 +530,12 @@ class DeployableTraderAgent(DeployablePredictionAgent):
512
530
  def __init__(
513
531
  self,
514
532
  enable_langfuse: bool = APIKeys().default_enable_langfuse,
515
- store_prediction: bool = True,
533
+ store_predictions: bool = True,
516
534
  store_trades: bool = True,
517
535
  place_trades: bool = True,
518
536
  ) -> None:
519
537
  super().__init__(
520
- enable_langfuse=enable_langfuse, store_prediction=store_prediction
538
+ enable_langfuse=enable_langfuse, store_predictions=store_predictions
521
539
  )
522
540
  self.store_trades = store_trades
523
541
  self.place_trades = place_trades
@@ -541,6 +559,11 @@ class DeployableTraderAgent(DeployablePredictionAgent):
541
559
  )
542
560
 
543
561
  def get_betting_strategy(self, market: AgentMarket) -> BettingStrategy:
562
+ """
563
+ Override this method to customize betting strategy of your agent.
564
+
565
+ Given the market and prediction, agent uses this method to calculate optimal outcome and bet size.
566
+ """
544
567
  user_id = market.get_user_id(api_keys=APIKeys())
545
568
 
546
569
  total_amount = market.get_tiny_bet_amount().amount
@@ -140,3 +140,24 @@ class PlacedTrade(Trade):
140
140
  amount=trade.amount,
141
141
  id=id,
142
142
  )
143
+
144
+
145
+ class SimulationDetail(BaseModel):
146
+ strategy: str
147
+ url: str
148
+ market_p_yes: float
149
+ agent_p_yes: float
150
+ agent_conf: float
151
+ org_bet: float
152
+ sim_bet: float
153
+ org_dir: bool
154
+ sim_dir: bool
155
+ org_profit: float
156
+ sim_profit: float
157
+ timestamp: DatetimeUTC
158
+
159
+
160
+ class SharpeOutput(BaseModel):
161
+ annualized_volatility: float
162
+ mean_daily_return: float
163
+ annualized_sharpe_ratio: float
@@ -200,7 +200,6 @@ class ManifoldBet(BaseModel):
200
200
  if self.get_resolved_boolean_outcome() == market_outcome
201
201
  else -self.amount
202
202
  )
203
- profit -= self.fees.get_total()
204
203
  return ProfitAmount(
205
204
  amount=profit,
206
205
  currency=Currency.Mana,
@@ -531,7 +531,6 @@ class OmenBet(BaseModel):
531
531
  if self.boolean_outcome == self.fpmm.boolean_outcome
532
532
  else -bet_amount_xdai
533
533
  )
534
- profit -= wei_to_xdai(self.feeAmount)
535
534
  return ProfitAmount(
536
535
  amount=profit,
537
536
  currency=Currency.xDai,
@@ -539,9 +538,8 @@ class OmenBet(BaseModel):
539
538
 
540
539
  def to_bet(self) -> Bet:
541
540
  return Bet(
542
- id=str(
543
- self.transactionHash
544
- ), # Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
541
+ id=str(self.transactionHash),
542
+ # Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
545
543
  amount=BetAmount(amount=self.collateralAmountUSD, currency=Currency.xDai),
546
544
  outcome=self.boolean_outcome,
547
545
  created_time=self.creation_datetime,
@@ -556,9 +554,8 @@ class OmenBet(BaseModel):
556
554
  )
557
555
 
558
556
  return ResolvedBet(
559
- id=str(
560
- self.transactionHash
561
- ), # Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
557
+ id=self.transactionHash.hex(),
558
+ # Use the transaction hash instead of the bet id - both are valid, but we return the transaction hash from the trade functions, so be consistent here.
562
559
  amount=BetAmount(amount=self.collateralAmountUSD, currency=Currency.xDai),
563
560
  outcome=self.boolean_outcome,
564
561
  created_time=self.creation_datetime,
@@ -800,9 +797,13 @@ class ContractPrediction(BaseModel):
800
797
  return Web3.to_checksum_address(self.publisher)
801
798
 
802
799
  @staticmethod
803
- def from_tuple(values: tuple[t.Any]) -> "ContractPrediction":
804
- data = {k: v for k, v in zip(ContractPrediction.model_fields.keys(), values)}
805
- return ContractPrediction.model_validate(data)
800
+ def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
801
+ return ContractPrediction(
802
+ publisher=values[0],
803
+ ipfs_hash=values[1],
804
+ tx_hashes=values[2],
805
+ estimated_probability_bps=values[3],
806
+ )
806
807
 
807
808
 
808
809
  class IPFSAgentResult(BaseModel):
@@ -75,6 +75,7 @@ from prediction_market_agent_tooling.tools.contract import (
75
75
  init_collateral_token_contract,
76
76
  to_gnosis_chain_contract,
77
77
  )
78
+ from prediction_market_agent_tooling.tools.custom_exceptions import OutOfFundsError
78
79
  from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
79
80
  from prediction_market_agent_tooling.tools.ipfs.ipfs_handler import IPFSHandler
80
81
  from prediction_market_agent_tooling.tools.utils import (
@@ -1339,7 +1340,7 @@ def withdraw_wxdai_to_xdai_to_keep_balance(
1339
1340
  )
1340
1341
 
1341
1342
  if current_balances.wxdai < need_to_withdraw:
1342
- raise ValueError(
1343
+ raise OutOfFundsError(
1343
1344
  f"Current wxDai balance {current_balances.wxdai} is less than the required minimum wxDai to withdraw {need_to_withdraw}."
1344
1345
  )
1345
1346
 
@@ -538,7 +538,7 @@ class OmenKlerosContract(ContractOnGnosisChain):
538
538
  address = "0xe40DD83a262da3f56976038F1554Fe541Fa75ecd"
539
539
 
540
540
  elif arbitrator == Arbitrator.KLEROS_31_JURORS_WITH_APPEAL:
541
- address = "0x29f39de98d750eb77b5fafb31b2837f079fce222"
541
+ address = "0x5562Ac605764DC4039fb6aB56a74f7321396Cdf2"
542
542
 
543
543
  else:
544
544
  raise ValueError(f"Unsupported arbitrator: {arbitrator=}")
@@ -435,10 +435,14 @@ class OmenSubgraphHandler(BaseSubgraphHandler):
435
435
  omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)
436
436
  return omen_markets
437
437
 
438
- def get_omen_market_by_market_id(self, market_id: HexAddress) -> OmenMarket:
439
- markets = self.trades_subgraph.Query.fixedProductMarketMaker(
440
- id=market_id.lower()
441
- )
438
+ def get_omen_market_by_market_id(
439
+ self, market_id: HexAddress, block_number: int | None = None
440
+ ) -> OmenMarket:
441
+ query_filters: dict[str, t.Any] = {"id": market_id.lower()}
442
+ if block_number:
443
+ query_filters["block"] = {"number": block_number}
444
+
445
+ markets = self.trades_subgraph.Query.fixedProductMarketMaker(**query_filters)
442
446
 
443
447
  fields = self._get_fields_for_markets(markets)
444
448
  omen_markets = self.do_query(fields=fields, pydantic_model=OmenMarket)
@@ -3,7 +3,7 @@ from prediction_market_agent_tooling.markets.markets import MarketType
3
3
  from prediction_market_agent_tooling.markets.polymarket.data_models_web import (
4
4
  PolymarketFullMarket,
5
5
  )
6
- from prediction_market_agent_tooling.tools.google import search_google
6
+ from prediction_market_agent_tooling.tools.google_utils import search_google
7
7
 
8
8
 
9
9
  def find_resolution_on_polymarket(question: str) -> Resolution | None:
@@ -0,0 +1,66 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from prediction_market_agent_tooling.markets.data_models import (
5
+ SharpeOutput,
6
+ SimulationDetail,
7
+ )
8
+
9
+
10
+ class SharpeRatioCalculator:
11
+ def __init__(
12
+ self, details: list[SimulationDetail], risk_free_rate: float = 0.0
13
+ ) -> None:
14
+ self.details = details
15
+ self.df = pd.DataFrame([d.model_dump() for d in self.details])
16
+ self.risk_free_rate = risk_free_rate
17
+
18
+ def __has_df_valid_columns_else_exception(
19
+ self, required_columns: list[str]
20
+ ) -> None:
21
+ if not set(required_columns).issubset(self.df.columns):
22
+ raise ValueError("Dataframe doesn't contain all the required columns.")
23
+
24
+ def prepare_wallet_daily_balance_df(
25
+ self, timestamp_col_name: str, profit_col_name: str
26
+ ) -> pd.DataFrame:
27
+ self.__has_df_valid_columns_else_exception(
28
+ [timestamp_col_name, profit_col_name]
29
+ )
30
+ df = self.df.copy()
31
+ df[timestamp_col_name] = pd.to_datetime(df[timestamp_col_name])
32
+ df.sort_values(timestamp_col_name, ascending=True, inplace=True)
33
+
34
+ df["profit_cumsum"] = df[profit_col_name].cumsum()
35
+ df["profit_cumsum"] = df["profit_cumsum"] + 50
36
+
37
+ df = df.drop_duplicates(subset=timestamp_col_name, keep="last")
38
+ df.set_index(timestamp_col_name, inplace=True)
39
+ # We generate a new Dataframe with daily wallet balances, derived by the final wallet balance
40
+ # from the previous day.
41
+ wallet_balance_daily_df = df[["profit_cumsum"]].resample("D").ffill()
42
+ wallet_balance_daily_df.dropna(inplace=True)
43
+ wallet_balance_daily_df["returns"] = wallet_balance_daily_df[
44
+ "profit_cumsum"
45
+ ].pct_change()
46
+ return wallet_balance_daily_df
47
+
48
+ def calculate_annual_sharpe_ratio(
49
+ self, timestamp_col_name: str = "timestamp", profit_col_name: str = "sim_profit"
50
+ ) -> SharpeOutput:
51
+ wallet_daily_balance_df = self.prepare_wallet_daily_balance_df(
52
+ timestamp_col_name=timestamp_col_name, profit_col_name=profit_col_name
53
+ )
54
+
55
+ daily_volatility = wallet_daily_balance_df["returns"].std()
56
+ annualized_volatility = daily_volatility * np.sqrt(365)
57
+ mean_daily_return = wallet_daily_balance_df["returns"].mean()
58
+ daily_sharpe_ratio = (
59
+ mean_daily_return - self.risk_free_rate
60
+ ) / daily_volatility
61
+ annualized_sharpe_ratio = daily_sharpe_ratio * np.sqrt(365)
62
+ return SharpeOutput(
63
+ annualized_volatility=annualized_volatility,
64
+ mean_daily_return=mean_daily_return,
65
+ annualized_sharpe_ratio=annualized_sharpe_ratio,
66
+ )
@@ -223,7 +223,7 @@ def monitor_agent(agent: DeployedAgent) -> None:
223
223
  )
224
224
  .interactive()
225
225
  )
226
- st.altair_chart( # type: ignore # Doesn't expect `LayerChart`, but `Chart`, yet it works.
226
+ st.altair_chart(
227
227
  per_day_accuracy_chart.mark_line()
228
228
  + per_day_accuracy_chart.transform_loess("x-axis-day", "Is Correct").mark_line(
229
229
  color="red", strokeDash=[5, 5]
@@ -1,7 +1,7 @@
1
1
  import hashlib
2
2
  import inspect
3
3
  import json
4
- from datetime import date, timedelta
4
+ from datetime import timedelta
5
5
  from functools import wraps
6
6
  from typing import (
7
7
  Any,
@@ -17,12 +17,12 @@ from typing import (
17
17
  from pydantic import BaseModel
18
18
  from sqlalchemy import Column
19
19
  from sqlalchemy.dialects.postgresql import JSONB
20
- from sqlmodel import Field, Session, SQLModel, create_engine, desc, select
20
+ from sqlmodel import Field, SQLModel, desc, select
21
21
 
22
22
  from prediction_market_agent_tooling.config import APIKeys
23
23
  from prediction_market_agent_tooling.loggers import logger
24
24
  from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
25
- from prediction_market_agent_tooling.tools.pickle_utils import InitialiseNonPickable
25
+ from prediction_market_agent_tooling.tools.db.db_manager import DBManager
26
26
  from prediction_market_agent_tooling.tools.utils import utcnow
27
27
 
28
28
  FunctionT = TypeVar("FunctionT", bound=Callable[..., Any])
@@ -91,17 +91,6 @@ def db_cache(
91
91
  return decorator
92
92
 
93
93
  api_keys = api_keys if api_keys is not None else APIKeys()
94
- wrapped_engine = InitialiseNonPickable(
95
- lambda: create_engine(
96
- api_keys.sqlalchemy_db_url.get_secret_value(),
97
- # Use custom json serializer and deserializer, because otherwise, for example `datetime` serialization would fail.
98
- json_serializer=json_serializer,
99
- json_deserializer=json_deserializer,
100
- )
101
- )
102
-
103
- if api_keys.ENABLE_CACHE:
104
- SQLModel.metadata.create_all(wrapped_engine.get_value())
105
94
 
106
95
  @wraps(func)
107
96
  def wrapper(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +98,7 @@ def db_cache(
109
98
  if not api_keys.ENABLE_CACHE:
110
99
  return func(*args, **kwargs)
111
100
 
112
- engine = wrapped_engine.get_value()
101
+ DBManager(api_keys).create_tables([FunctionCache])
113
102
 
114
103
  # Convert *args and **kwargs to a single dictionary, where we have names for arguments passed as args as well.
115
104
  signature = inspect.signature(func)
@@ -150,12 +139,11 @@ def db_cache(
150
139
 
151
140
  # Determine if the function returns or contains Pydantic BaseModel(s)
152
141
  return_type = func.__annotations__.get("return", None)
153
- is_pydantic_model = False
154
-
155
- if return_type is not None and contains_pydantic_model(return_type):
156
- is_pydantic_model = True
142
+ is_pydantic_model = return_type is not None and contains_pydantic_model(
143
+ return_type
144
+ )
157
145
 
158
- with Session(engine) as session:
146
+ with DBManager(api_keys).get_session() as session:
159
147
  # Try to get cached result
160
148
  statement = (
161
149
  select(FunctionCache)
@@ -208,7 +196,7 @@ def db_cache(
208
196
  result=computed_result,
209
197
  created_at=utcnow(),
210
198
  )
211
- with Session(engine) as session:
199
+ with DBManager(api_keys).get_session() as session:
212
200
  logger.info(f"Saving {cache_entry} into database.")
213
201
  session.add(cache_entry)
214
202
  session.commit()
@@ -232,60 +220,6 @@ def contains_pydantic_model(return_type: Any) -> bool:
232
220
  return False
233
221
 
234
222
 
235
- def json_serializer_default_fn(
236
- y: DatetimeUTC | timedelta | date | BaseModel,
237
- ) -> str | dict[str, Any]:
238
- """
239
- Used to serialize objects that don't support it by default into a specific string that can be deserialized out later.
240
- If this function returns a dictionary, it will be called recursivelly.
241
- If you add something here, also add it to `replace_custom_stringified_objects` below.
242
- """
243
- if isinstance(y, DatetimeUTC):
244
- return f"DatetimeUTC::{y.isoformat()}"
245
- elif isinstance(y, timedelta):
246
- return f"timedelta::{y.total_seconds()}"
247
- elif isinstance(y, date):
248
- return f"date::{y.isoformat()}"
249
- elif isinstance(y, BaseModel):
250
- return y.model_dump()
251
- raise TypeError(
252
- f"Unsuported type for the default json serialize function, value is {y}."
253
- )
254
-
255
-
256
- def json_serializer(x: Any) -> str:
257
- return json.dumps(x, default=json_serializer_default_fn)
258
-
259
-
260
- def replace_custom_stringified_objects(obj: Any) -> Any:
261
- """
262
- Used to deserialize objects from `json_serializer_default_fn` into their proper form.
263
- """
264
- if isinstance(obj, str):
265
- if obj.startswith("DatetimeUTC::"):
266
- iso_str = obj[len("DatetimeUTC::") :]
267
- return DatetimeUTC.to_datetime_utc(iso_str)
268
- elif obj.startswith("timedelta::"):
269
- total_seconds_str = obj[len("timedelta::") :]
270
- return timedelta(seconds=float(total_seconds_str))
271
- elif obj.startswith("date::"):
272
- iso_str = obj[len("date::") :]
273
- return date.fromisoformat(iso_str)
274
- else:
275
- return obj
276
- elif isinstance(obj, dict):
277
- return {k: replace_custom_stringified_objects(v) for k, v in obj.items()}
278
- elif isinstance(obj, list):
279
- return [replace_custom_stringified_objects(item) for item in obj]
280
- else:
281
- return obj
282
-
283
-
284
- def json_deserializer(s: str) -> Any:
285
- data = json.loads(s)
286
- return replace_custom_stringified_objects(data)
287
-
288
-
289
223
  def convert_cached_output_to_pydantic(return_type: Any, data: Any) -> Any:
290
224
  """
291
225
  Used to initialize Pydantic models from anything cached that was originally a Pydantic model in the output. Including models in nested structures.