prediction-market-agent-tooling 0.55.2.dev117__py3-none-any.whl → 0.56.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. prediction_market_agent_tooling/deploy/agent.py +17 -8
  2. prediction_market_agent_tooling/jobs/jobs_models.py +27 -2
  3. prediction_market_agent_tooling/jobs/omen/omen_jobs.py +67 -41
  4. prediction_market_agent_tooling/markets/agent_market.py +8 -2
  5. prediction_market_agent_tooling/markets/base_subgraph_handler.py +51 -0
  6. prediction_market_agent_tooling/markets/markets.py +12 -0
  7. prediction_market_agent_tooling/markets/metaculus/metaculus.py +1 -1
  8. prediction_market_agent_tooling/markets/omen/data_models.py +11 -2
  9. prediction_market_agent_tooling/markets/omen/omen.py +16 -9
  10. prediction_market_agent_tooling/markets/omen/omen_subgraph_handler.py +29 -51
  11. prediction_market_agent_tooling/markets/seer/data_models.py +27 -0
  12. prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py +142 -0
  13. prediction_market_agent_tooling/tools/caches/db_cache.py +351 -0
  14. prediction_market_agent_tooling/tools/google.py +3 -2
  15. prediction_market_agent_tooling/tools/is_invalid.py +4 -3
  16. prediction_market_agent_tooling/tools/is_predictable.py +3 -3
  17. prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_analysis.py +6 -10
  18. prediction_market_agent_tooling/tools/tavily/tavily_models.py +0 -66
  19. prediction_market_agent_tooling/tools/tavily/tavily_search.py +12 -44
  20. prediction_market_agent_tooling/tools/utils.py +2 -0
  21. {prediction_market_agent_tooling-0.55.2.dev117.dist-info → prediction_market_agent_tooling-0.56.0.dist-info}/METADATA +2 -1
  22. {prediction_market_agent_tooling-0.55.2.dev117.dist-info → prediction_market_agent_tooling-0.56.0.dist-info}/RECORD +26 -24
  23. prediction_market_agent_tooling/jobs/jobs.py +0 -45
  24. prediction_market_agent_tooling/tools/tavily/tavily_storage.py +0 -105
  25. /prediction_market_agent_tooling/tools/{cache.py → caches/inmemory_cache.py} +0 -0
  26. {prediction_market_agent_tooling-0.55.2.dev117.dist-info → prediction_market_agent_tooling-0.56.0.dist-info}/LICENSE +0 -0
  27. {prediction_market_agent_tooling-0.55.2.dev117.dist-info → prediction_market_agent_tooling-0.56.0.dist-info}/WHEEL +0 -0
  28. {prediction_market_agent_tooling-0.55.2.dev117.dist-info → prediction_market_agent_tooling-0.56.0.dist-info}/entry_points.txt +0 -0
prediction_market_agent_tooling/markets/seer/data_models.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel, ConfigDict, Field
+
+from prediction_market_agent_tooling.gtypes import HexBytes
+
+
+class SeerMarket(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+    id: HexBytes
+    title: str = Field(alias="marketName")
+    outcomes: list[str]
+    parent_market: HexBytes = Field(alias="parentMarket")
+    wrapped_tokens: list[HexBytes] = Field(alias="wrappedTokens")
+
+
+class SeerToken(BaseModel):
+    id: HexBytes
+    name: str
+    symbol: str
+
+
+class SeerPool(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+    id: HexBytes
+    liquidity: int
+    token0: SeerToken
+    token1: SeerToken
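These models map the subgraph's camelCase fields onto snake_case attributes via aliases, and `populate_by_name=True` lets either spelling be used. A minimal sketch of validating a raw subgraph payload into `SeerMarket` (the payload values below are made up for illustration):

    market = SeerMarket.model_validate(
        {
            "id": "0x1234",
            "marketName": "Will GNO reach $500 by the end of 2025?",
            "outcomes": ["Yes", "No", "Invalid result"],
            "parentMarket": "0x0000000000000000000000000000000000000000",
            "wrappedTokens": ["0xaa11", "0xbb22", "0xcc33"],
        }
    )
    assert market.title.startswith("Will GNO")  # populated via the `marketName` alias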
prediction_market_agent_tooling/markets/seer/seer_subgraph_handler.py
@@ -0,0 +1,142 @@
+from typing import Any
+
+from subgrounds import FieldPath
+from web3.constants import ADDRESS_ZERO
+
+from prediction_market_agent_tooling.markets.base_subgraph_handler import (
+    BaseSubgraphHandler,
+)
+from prediction_market_agent_tooling.markets.seer.data_models import (
+    SeerMarket,
+    SeerPool,
+)
+from prediction_market_agent_tooling.tools.hexbytes_custom import HexBytes
+
+INVALID_OUTCOME = "Invalid result"
+
+
+class SeerSubgraphHandler(BaseSubgraphHandler):
+    """
+    Class responsible for handling interactions with Seer subgraphs.
+    """
+
+    SEER_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/B4vyRqJaSHD8dRDb3BFRoAzuBK18c1QQcXq94JbxDxWH"
+
+    SWAPR_ALGEBRA_SUBGRAPH = "https://gateway-arbitrum.network.thegraph.com/api/{graph_api_key}/subgraphs/id/AAA1vYjxwFHzbt6qKwLHNcDSASyr1J1xVViDH8gTMFMR"
+
+    INVALID_ANSWER = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.seer_subgraph = self.sg.load_subgraph(
+            self.SEER_SUBGRAPH.format(
+                graph_api_key=self.keys.graph_api_key.get_secret_value()
+            )
+        )
+        self.swapr_algebra_subgraph = self.sg.load_subgraph(
+            self.SWAPR_ALGEBRA_SUBGRAPH.format(
+                graph_api_key=self.keys.graph_api_key.get_secret_value()
+            )
+        )
+
+    def _get_fields_for_markets(self, markets_field: FieldPath) -> list[FieldPath]:
+        fields = [
+            markets_field.id,
+            markets_field.factory,
+            markets_field.creator,
+            markets_field.marketName,
+            markets_field.outcomes,
+            markets_field.parentMarket,
+            markets_field.finalizeTs,
+            markets_field.wrappedTokens,
+        ]
+        return fields
+
+    @staticmethod
+    def filter_bicategorical_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
+        # We do an extra check for the invalid outcome for safety.
+        return [
+            m for m in markets if len(m.outcomes) == 3 and INVALID_OUTCOME in m.outcomes
+        ]
+
+    @staticmethod
+    def filter_binary_markets(markets: list[SeerMarket]) -> list[SeerMarket]:
+        return [
+            market
+            for market in markets
+            if {"yes", "no"}.issubset({o.lower() for o in market.outcomes})
+        ]
+
+    @staticmethod
+    def build_filter_for_conditional_markets(
+        include_conditional_markets: bool = True,
+    ) -> dict[Any, Any]:
+        return (
+            {}
+            if include_conditional_markets
+            else {"parentMarket": ADDRESS_ZERO.lower()}
+        )
+
+    def get_bicategorical_markets(
+        self, include_conditional_markets: bool = True
+    ) -> list[SeerMarket]:
+        """Returns markets that contain 2 categories plus an invalid outcome."""
+        # Binary markets on Seer contain 3 outcomes: OutcomeA, outcomeB and an Invalid option.
+        query_filter = self.build_filter_for_conditional_markets(
+            include_conditional_markets
+        )
+        query_filter["outcomes_contains"] = [INVALID_OUTCOME]
+        markets_field = self.seer_subgraph.Query.markets(where=query_filter)
+        fields = self._get_fields_for_markets(markets_field)
+        markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
+        two_category_markets = self.filter_bicategorical_markets(markets)
+        return two_category_markets
+
+    def get_binary_markets(
+        self, include_conditional_markets: bool = True
+    ) -> list[SeerMarket]:
+        two_category_markets = self.get_bicategorical_markets(
+            include_conditional_markets=include_conditional_markets
+        )
+        # Now we additionally filter markets based on YES/NO being the only outcomes.
+        binary_markets = self.filter_binary_markets(two_category_markets)
+        return binary_markets
+
+    def get_market_by_id(self, market_id: HexBytes) -> SeerMarket:
+        markets_field = self.seer_subgraph.Query.market(id=market_id.hex().lower())
+        fields = self._get_fields_for_markets(markets_field)
+        markets = self.do_query(fields=fields, pydantic_model=SeerMarket)
+        if len(markets) != 1:
+            raise ValueError(
+                f"Fetched wrong number of markets. Expected 1 but got {len(markets)}"
+            )
+        return markets[0]
+
+    def _get_fields_for_pools(self, pools_field: FieldPath) -> list[FieldPath]:
+        fields = [
+            pools_field.id,
+            pools_field.liquidity,
+            pools_field.token0.id,
+            pools_field.token0.name,
+            pools_field.token0.symbol,
+            pools_field.token1.id,
+            pools_field.token1.name,
+            pools_field.token1.symbol,
+        ]
+        return fields
+
+    def get_pools_for_market(self, market: SeerMarket) -> list[SeerPool]:
+        # We iterate through the wrapped tokens and put them in a where clause so that we hit the subgraph endpoint just once.
+        wheres = []
+        for wrapped_token in market.wrapped_tokens:
+            wheres.extend(
+                [
+                    {"token0": wrapped_token.hex().lower()},
+                    {"token1": wrapped_token.hex().lower()},
+                ]
+            )
+        pools_field = self.swapr_algebra_subgraph.Query.pools(where={"or": wheres})
+        fields = self._get_fields_for_pools(pools_field)
+        pools = self.do_query(fields=fields, pydantic_model=SeerPool)
+        return pools
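A usage sketch for the new handler (it assumes `GRAPH_API_KEY` is available to `APIKeys`, since both subgraph URLs are templated with it):

    handler = SeerSubgraphHandler()

    # Fetch strictly YES/NO markets, excluding conditional markets (those with a parent).
    markets = handler.get_binary_markets(include_conditional_markets=False)

    for market in markets:
        # A single query fetches every Swapr Algebra pool whose token0 or
        # token1 is one of the market's wrapped outcome tokens.
        pools = handler.get_pools_for_market(market)
        liquid_pools = [p for p in pools if p.liquidity > 0]
        print(market.title, len(liquid_pools))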
prediction_market_agent_tooling/tools/caches/db_cache.py
@@ -0,0 +1,351 @@
+import hashlib
+import inspect
+import json
+from datetime import date, timedelta
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+    get_origin,
+    overload,
+)
+
+from pydantic import BaseModel
+from sqlalchemy import Column
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlmodel import Field, Session, SQLModel, create_engine, desc, select
+
+from prediction_market_agent_tooling.config import APIKeys
+from prediction_market_agent_tooling.loggers import logger
+from prediction_market_agent_tooling.tools.datetime_utc import DatetimeUTC
+from prediction_market_agent_tooling.tools.utils import utcnow
+
+FunctionT = TypeVar("FunctionT", bound=Callable[..., Any])
+
+
+class FunctionCache(SQLModel, table=True):
+    __tablename__ = "function_cache"
+    id: int | None = Field(default=None, primary_key=True)
+    function_name: str = Field(index=True)
+    full_function_name: str = Field(index=True)
+    # Args are stored to see what the function was called with.
+    args: Any = Field(sa_column=Column(JSONB, nullable=False))
+    # The args hash is stored as a fast look-up option when searching for cache hits.
+    args_hash: str = Field(index=True)
+    result: Any = Field(sa_column=Column(JSONB, nullable=False))
+    created_at: DatetimeUTC = Field(default_factory=utcnow, index=True)
+
+
+@overload
+def db_cache(
+    func: None = None,
+    *,
+    max_age: timedelta | None = None,
+    cache_none: bool = True,
+    api_keys: APIKeys | None = None,
+    ignore_args: Sequence[str] | None = None,
+    ignore_arg_types: Sequence[type] | None = None,
+) -> Callable[[FunctionT], FunctionT]:
+    ...
+
+
+@overload
+def db_cache(
+    func: FunctionT,
+    *,
+    max_age: timedelta | None = None,
+    cache_none: bool = True,
+    api_keys: APIKeys | None = None,
+    ignore_args: Sequence[str] | None = None,
+    ignore_arg_types: Sequence[type] | None = None,
+) -> FunctionT:
+    ...
+
+
+def db_cache(
+    func: FunctionT | None = None,
+    *,
+    max_age: timedelta | None = None,
+    cache_none: bool = True,
+    api_keys: APIKeys | None = None,
+    ignore_args: Sequence[str] | None = None,
+    ignore_arg_types: Sequence[type] | None = None,
+) -> FunctionT | Callable[[FunctionT], FunctionT]:
+    if func is None:
+        # Ugly Pythonic way to support this decorator both as `@db_cache` and as `@db_cache(max_age=timedelta(days=3))`.
+        def decorator(func: FunctionT) -> FunctionT:
+            return db_cache(
+                func,
+                max_age=max_age,
+                cache_none=cache_none,
+                api_keys=api_keys,
+                ignore_args=ignore_args,
+                ignore_arg_types=ignore_arg_types,
+            )
+
+        return decorator
+
+    api_keys = api_keys if api_keys is not None else APIKeys()
+
+    sqlalchemy_db_url = api_keys.SQLALCHEMY_DB_URL
+    if sqlalchemy_db_url is None:
+        logger.warning(
+            "SQLALCHEMY_DB_URL not provided in the environment, skipping function caching."
+        )
+
+    engine = (
+        create_engine(
+            sqlalchemy_db_url.get_secret_value(),
+            # Use a custom json serializer and deserializer, because otherwise, for example, `datetime` serialization would fail.
+            json_serializer=json_serializer,
+            json_deserializer=json_deserializer,
+        )
+        if sqlalchemy_db_url is not None
+        else None
+    )
+
+    # Create the table if it doesn't exist.
+    if engine is not None:
+        SQLModel.metadata.create_all(engine)
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        # If caching is disabled, just call the function and return its result.
+        if not api_keys.ENABLE_CACHE:
+            return func(*args, **kwargs)
+
+        # Convert *args and **kwargs to a single dictionary, so that positional arguments are named as well.
+        signature = inspect.signature(func)
+        bound_arguments = signature.bind(*args, **kwargs)
+        bound_arguments.apply_defaults()
+
+        # Convert any argument that is a Pydantic model into a plain dictionary, otherwise it won't be json-serializable.
+        args_dict: dict[str, Any] = bound_arguments.arguments
+
+        # Remove `self` or `cls` if present (in case of class methods).
+        if "self" in args_dict:
+            del args_dict["self"]
+        if "cls" in args_dict:
+            del args_dict["cls"]
+
+        # Remove ignored arguments.
+        if ignore_args:
+            for arg in ignore_args:
+                if arg in args_dict:
+                    del args_dict[arg]
+
+        # Remove arguments of ignored types.
+        if ignore_arg_types:
+            args_dict = {
+                k: v
+                for k, v in args_dict.items()
+                if not isinstance(v, tuple(ignore_arg_types))
+            }
+
+        # Compute a hash of the function arguments, used to look up cached results.
+        arg_string = json.dumps(args_dict, sort_keys=True, default=str)
+        args_hash = hashlib.md5(arg_string.encode()).hexdigest()
+
+        # Get the full function name as a concatenation of module and qualname, to not accidentally clash.
+        full_function_name = func.__module__ + "." + func.__qualname__
+        # But also get the plain function name, to easily search for it in the database.
+        function_name = func.__name__
+
+        # Determine if the function returns or contains Pydantic BaseModel(s).
+        return_type = func.__annotations__.get("return", None)
+        is_pydantic_model = False
+
+        if return_type is not None and contains_pydantic_model(return_type):
+            is_pydantic_model = True
+
+        # If postgres access was specified, try to find a hit.
+        if engine is not None:
+            with Session(engine) as session:
+                # Try to get a cached result.
+                statement = (
+                    select(FunctionCache)
+                    .where(
+                        FunctionCache.function_name == function_name,
+                        FunctionCache.full_function_name == full_function_name,
+                        FunctionCache.args_hash == args_hash,
+                    )
+                    .order_by(desc(FunctionCache.created_at))
+                )
+                if max_age is not None:
+                    cutoff_time = utcnow() - max_age
+                    statement = statement.where(FunctionCache.created_at >= cutoff_time)
+                cached_result = session.exec(statement).first()
+        else:
+            cached_result = None
+
+        if cached_result:
+            logger.info(
+                # Keep the special [cache-hit] identifier so we can easily track it in GCP.
+                f"[cache-hit] Cache hit for {full_function_name} with args {args_dict} and output {cached_result.result}"
+            )
+            if is_pydantic_model:
+                # If the output contains any Pydantic models, we need to initialise them.
+                try:
+                    return convert_cached_output_to_pydantic(
+                        return_type, cached_result.result
+                    )
+                except ValueError as e:
+                    # In case of a backward-incompatible Pydantic model, just treat it as a cache miss, to not error out.
+                    logger.warning(
+                        f"Can not validate {cached_result=} into {return_type=} because {e=}, treating as cache miss."
+                    )
+                    cached_result = None
+            else:
+                return cached_result.result
+
+        # On cache miss, compute the result.
+        computed_result = func(*args, **kwargs)
+        # Keep the special [cache-miss] identifier so we can easily track it in GCP.
+        logger.info(
+            f"[cache-miss] Cache miss for {full_function_name} with args {args_dict}, computed the output {computed_result}"
+        )
+
+        # If postgres access was specified, save the result.
+        if engine is not None and (cache_none or computed_result is not None):
+            cache_entry = FunctionCache(
+                function_name=function_name,
+                full_function_name=full_function_name,
+                args_hash=args_hash,
+                args=args_dict,
+                result=computed_result,
+                created_at=utcnow(),
+            )
+            with Session(engine) as session:
+                logger.info(f"Saving {cache_entry} into database.")
+                session.add(cache_entry)
+                session.commit()
+
+        return computed_result
+
+    return cast(FunctionT, wrapper)
+
+
+def contains_pydantic_model(return_type: Any) -> bool:
+    """
+    Check if the return type contains anything that's a Pydantic model (including nested structures, like `list[BaseModel]`, `dict[str, list[BaseModel]]`, etc.).
+    """
+    if return_type is None:
+        return False
+    origin = get_origin(return_type)
+    if origin is not None:
+        return any(contains_pydantic_model(arg) for arg in get_args(return_type))
+    if inspect.isclass(return_type):
+        return issubclass(return_type, BaseModel)
+    return False
+
+
+def json_serializer_default_fn(
+    y: DatetimeUTC | timedelta | date | BaseModel,
+) -> str | dict[str, Any]:
+    """
+    Used to serialize objects that don't support it by default into a specific string that can be deserialized later.
+    If this function returns a dictionary, it will be called recursively.
+    If you add something here, also add it to `replace_custom_stringified_objects` below.
+    """
+    if isinstance(y, DatetimeUTC):
+        return f"DatetimeUTC::{y.isoformat()}"
+    elif isinstance(y, timedelta):
+        return f"timedelta::{y.total_seconds()}"
+    elif isinstance(y, date):
+        return f"date::{y.isoformat()}"
+    elif isinstance(y, BaseModel):
+        return y.model_dump()
+    raise TypeError(
+        f"Unsupported type for the default json serialize function, value is {y}."
+    )
+
+
+def json_serializer(x: Any) -> str:
+    return json.dumps(x, default=json_serializer_default_fn)
+
+
+def replace_custom_stringified_objects(obj: Any) -> Any:
+    """
+    Used to deserialize objects from `json_serializer_default_fn` back into their proper form.
+    """
+    if isinstance(obj, str):
+        if obj.startswith("DatetimeUTC::"):
+            iso_str = obj[len("DatetimeUTC::") :]
+            return DatetimeUTC.to_datetime_utc(iso_str)
+        elif obj.startswith("timedelta::"):
+            total_seconds_str = obj[len("timedelta::") :]
+            return timedelta(seconds=float(total_seconds_str))
+        elif obj.startswith("date::"):
+            iso_str = obj[len("date::") :]
+            return date.fromisoformat(iso_str)
+        else:
+            return obj
+    elif isinstance(obj, dict):
+        return {k: replace_custom_stringified_objects(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [replace_custom_stringified_objects(item) for item in obj]
+    else:
+        return obj
+
+
+def json_deserializer(s: str) -> Any:
+    data = json.loads(s)
+    return replace_custom_stringified_objects(data)
+
+
+def convert_cached_output_to_pydantic(return_type: Any, data: Any) -> Any:
+    """
+    Used to initialize Pydantic models from anything cached that was originally a Pydantic model in the output, including models in nested structures.
+    """
+    # Get the origin and arguments of the model type.
+    origin = get_origin(return_type)
+    args = get_args(return_type)
+
+    # Check if the data is a dictionary.
+    if isinstance(data, dict):
+        # If the model has no origin, check if it is a subclass of BaseModel.
+        if origin is None:
+            if inspect.isclass(return_type) and issubclass(return_type, BaseModel):
+                # Convert the dictionary to a Pydantic model.
+                return return_type(
+                    **{
+                        k: convert_cached_output_to_pydantic(
+                            getattr(return_type, k, None), v
+                        )
+                        for k, v in data.items()
+                    }
+                )
+            else:
+                # If not a Pydantic model, return the data as is.
+                return data
+        # If the origin is a dictionary, convert keys and values.
+        elif origin is dict:
+            key_type, value_type = args
+            return {
+                convert_cached_output_to_pydantic(
+                    key_type, k
+                ): convert_cached_output_to_pydantic(value_type, v)
+                for k, v in data.items()
+            }
+        else:
+            # If the origin is not a dictionary, return the data as is.
+            return data
+    # Check if the data is a list.
+    elif isinstance(data, (list, tuple)):
+        # If the origin is a list or tuple, convert each item.
+        if origin in {list, tuple}:
+            item_type = args[0]
+            converted_items = [
+                convert_cached_output_to_pydantic(item_type, item) for item in data
+            ]
+            return type(data)(converted_items)
+        else:
+            # If the origin is not a list or tuple, return the data as is.
+            return data
+    else:
+        # If the data is neither a dictionary nor a list, return it as is.
+        return data
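A usage sketch of the new decorator, covering both its bare and parameterized forms (the decorated functions below are hypothetical, not part of the package):

    from datetime import timedelta

    from pydantic import BaseModel

    from prediction_market_agent_tooling.tools.caches.db_cache import db_cache


    class Forecast(BaseModel):
        question: str
        p_yes: float


    @db_cache  # Bare form: cached entries never expire.
    def slow_forecast(question: str) -> Forecast:
        # Imagine an expensive LLM call here; the result is stored in Postgres.
        return Forecast(question=question, p_yes=0.42)


    @db_cache(max_age=timedelta(days=1), cache_none=False, ignore_args=["api_key"])
    def fetch_news(question: str, api_key: str) -> list[str] | None:
        # Entries older than one day count as misses, `None` results are not
        # stored, and `api_key` is excluded from the cache key.
        return None  # placeholder for a real API call

Because `Forecast` is a Pydantic model, a cache hit is re-validated via `convert_cached_output_to_pydantic`; if the model has since changed incompatibly, the hit is downgraded to a miss instead of raising.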
prediction_market_agent_tooling/tools/google.py
@@ -1,11 +1,12 @@
 import typing as t
+from datetime import timedelta
 
 import tenacity
 from googleapiclient.discovery import build
 
 from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.loggers import logger
-from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache
+from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
 
 
 @tenacity.retry(
@@ -13,7 +14,7 @@ from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cach
     stop=tenacity.stop_after_attempt(3),
     after=lambda x: logger.debug(f"search_google failed, {x.attempt_number=}."),
 )
-@persistent_inmemory_cache
+@db_cache(max_age=timedelta(days=1))
 def search_google(
     query: str | None = None,
     num: int = 3,
prediction_market_agent_tooling/tools/is_invalid.py
@@ -2,7 +2,7 @@ import tenacity
 
 from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.loggers import logger
-from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache
+from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
 from prediction_market_agent_tooling.tools.is_predictable import (
     parse_decision_yes_no_completion,
 )
@@ -34,9 +34,10 @@ QUESTION_IS_INVALID_PROMPT = """Main signs about an invalid question (sometimes
 - Which could give an incentive only to specific participants to commit an immoral violent action, but are in practice unlikely.
 - Valid: Will the US be engaged in a military conflict with a UN member state in 2021? (It’s unlikely for the US to declare war in order to win a bet on this market).
 - Valid: Will Derek Chauvin go to jail for the murder of George Floyd? (It’s unlikely that the jurors would collude to make a wrong verdict in order to win this market).
-- Questions with relative dates will resolve as invalid. Dates must be stated in absolute terms, not relative depending on the current time.
+- Questions with relative dates will resolve as invalid. Dates must be stated in absolute terms, not relative depending on the current time. But they can be relative to the event specified in the question itself.
 - Invalid: Who will be the president of the United States in 6 months? ("in 6 months" depends on the current time).
 - Invalid: In the next 14 days, will Gnosis Chain gain another 1M users? ("in the next 14 days" depends on the current time).
+- Valid: Will GNO price go up 10 days after the Gnosis Pay cashback program is announced? ("10 days after" is relative to the event in the question, so we can determine the absolute date).
 - Questions about moral values and not facts will be resolved as invalid.
 - Invalid: "Is it ethical to eat meat?".
 
@@ -54,9 +55,9 @@ Finally, write your final decision, write `decision: ` followed by either "yes i
 """
 
 
-@persistent_inmemory_cache
 @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1))
 @observe()
+@db_cache
 def is_invalid(
     question: str,
     engine: str = "gpt-4o",
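Worth noting in the hunk above: `@persistent_inmemory_cache` used to sit outermost, while `@db_cache` is now applied closest to the function. Decorators apply bottom-up, so a call now resolves as retry → observe → cache → function, meaning cache hits still pass through `@observe()` tracing and the tenacity retries wrap the cache lookup itself (presumably the intent of the reordering). The same reordering appears in `is_predictable.py` below.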
prediction_market_agent_tooling/tools/is_predictable.py
@@ -2,7 +2,7 @@ import tenacity
 
 from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.loggers import logger
-from prediction_market_agent_tooling.tools.cache import persistent_inmemory_cache
+from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
 from prediction_market_agent_tooling.tools.langfuse_ import (
     get_langfuse_langchain_config,
     observe,
@@ -76,9 +76,9 @@ Finally, write your final decision, write `decision: ` followed by either "yes i
 """
 
 
-@persistent_inmemory_cache
 @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1))
 @observe()
+@db_cache
 def is_predictable_binary(
     question: str,
     engine: str = "gpt-4-1106-preview",
@@ -112,9 +112,9 @@ def is_predictable_binary(
     return parse_decision_yes_no_completion(question, completion)
 
 
-@persistent_inmemory_cache
 @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1))
 @observe()
+@db_cache
 def is_predictable_without_description(
     question: str,
     description: str,
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_analysis.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta
+from datetime import date, timedelta
 
 from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.prompts import PromptTemplate
@@ -20,8 +20,7 @@ from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_
 from prediction_market_agent_tooling.tools.tavily.tavily_search import (
     get_relevant_news_since,
 )
-from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage
-from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow
+from prediction_market_agent_tooling.tools.utils import check_not_none
 
 SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE = """
 You are an expert news analyst, tracking stories that may affect your prediction to the outcome of a particular QUESTION.
@@ -55,7 +54,7 @@ For your analysis, you should:
 def analyse_news_relevance(
     raw_content: str,
     question: str,
-    date_of_interest: datetime,
+    date_of_interest: date,
     model: str,
     temperature: float,
 ) -> RelevantNewsAnalysis:
@@ -91,19 +90,18 @@ def analyse_news_relevance(
 def get_certified_relevant_news_since(
     question: str,
     days_ago: int,
-    tavily_storage: TavilyStorage | None = None,
 ) -> RelevantNews | None:
     """
     Get relevant news since a given date for a given question. Retrieves
     possibly relevant news from tavily, then checks that it is relevant via
     an LLM call.
     """
+    news_since = date.today() - timedelta(days=days_ago)
     results = get_relevant_news_since(
         question=question,
-        days_ago=days_ago,
+        news_since=news_since,
         score_threshold=0.0,  # Be conservative to avoid missing relevant information
         max_results=3,  # A tradeoff between cost and quality. 3 seems to be a good balance.
-        tavily_storage=tavily_storage,
     )
 
     # Sort results by descending 'relevance score' to maximise the chance of
@@ -118,7 +116,7 @@ def get_certified_relevant_news_since(
     relevant_news_analysis = analyse_news_relevance(
         raw_content=check_not_none(result.raw_content),
         question=question,
-        date_of_interest=utcnow() - timedelta(days=days_ago),
+        date_of_interest=news_since,
         model="gpt-4o",  # 4o-mini isn't good enough, 1o and 1o-mini are too expensive
         temperature=0.0,
     )
@@ -140,7 +138,6 @@ def get_certified_relevant_news_since_cached(
     question: str,
     days_ago: int,
     cache: RelevantNewsResponseCache,
-    tavily_storage: TavilyStorage | None = None,
 ) -> RelevantNews | None:
     cached = cache.find(question=question, days_ago=days_ago)
 
@@ -150,7 +147,6 @@ def get_certified_relevant_news_since_cached(
     relevant_news = get_certified_relevant_news_since(
         question=question,
         days_ago=days_ago,
-        tavily_storage=tavily_storage,
    )
     cache.save(
         question=question,
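With `TavilyStorage` removed, call sites shrink to just the question, the look-back window, and the response cache. A sketch of the updated call (the `RelevantNewsResponseCache` import path and constructor are assumed from prior versions):

    from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis import (
        get_certified_relevant_news_since_cached,
    )
    from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
        RelevantNewsResponseCache,
    )

    cache = RelevantNewsResponseCache()
    news = get_certified_relevant_news_since_cached(
        question="Will GNO reach $500 by the end of 2025?",
        days_ago=7,
        cache=cache,  # `tavily_storage` is no longer a parameter
    )
    if news is not None:
        print(news)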