letta-nightly 0.8.0.dev20250606104326__py3-none-any.whl → 0.8.2.dev20250606215616__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +1 -1
  3. letta/agents/letta_agent.py +49 -29
  4. letta/agents/letta_agent_batch.py +1 -2
  5. letta/agents/voice_agent.py +19 -13
  6. letta/agents/voice_sleeptime_agent.py +11 -3
  7. letta/constants.py +18 -0
  8. letta/data_sources/__init__.py +0 -0
  9. letta/data_sources/redis_client.py +282 -0
  10. letta/errors.py +0 -4
  11. letta/functions/function_sets/files.py +58 -0
  12. letta/functions/schema_generator.py +18 -1
  13. letta/groups/sleeptime_multi_agent_v2.py +1 -1
  14. letta/helpers/datetime_helpers.py +47 -3
  15. letta/helpers/decorators.py +69 -0
  16. letta/{services/helpers/noop_helper.py → helpers/singleton.py} +5 -0
  17. letta/interfaces/anthropic_streaming_interface.py +43 -24
  18. letta/interfaces/openai_streaming_interface.py +21 -19
  19. letta/llm_api/anthropic.py +1 -1
  20. letta/llm_api/anthropic_client.py +22 -14
  21. letta/llm_api/google_vertex_client.py +1 -1
  22. letta/llm_api/helpers.py +36 -30
  23. letta/llm_api/llm_api_tools.py +1 -1
  24. letta/llm_api/llm_client_base.py +29 -1
  25. letta/llm_api/openai.py +1 -1
  26. letta/llm_api/openai_client.py +6 -8
  27. letta/local_llm/chat_completion_proxy.py +1 -1
  28. letta/memory.py +1 -1
  29. letta/orm/enums.py +1 -0
  30. letta/orm/file.py +80 -3
  31. letta/orm/files_agents.py +13 -0
  32. letta/orm/sqlalchemy_base.py +34 -11
  33. letta/otel/__init__.py +0 -0
  34. letta/otel/context.py +25 -0
  35. letta/otel/events.py +0 -0
  36. letta/otel/metric_registry.py +122 -0
  37. letta/otel/metrics.py +66 -0
  38. letta/otel/resource.py +26 -0
  39. letta/{tracing.py → otel/tracing.py} +55 -78
  40. letta/plugins/README.md +22 -0
  41. letta/plugins/__init__.py +0 -0
  42. letta/plugins/defaults.py +11 -0
  43. letta/plugins/plugins.py +72 -0
  44. letta/schemas/enums.py +8 -0
  45. letta/schemas/file.py +12 -0
  46. letta/schemas/tool.py +4 -0
  47. letta/server/db.py +7 -7
  48. letta/server/rest_api/app.py +8 -6
  49. letta/server/rest_api/routers/v1/agents.py +37 -36
  50. letta/server/rest_api/routers/v1/groups.py +3 -3
  51. letta/server/rest_api/routers/v1/sources.py +26 -3
  52. letta/server/rest_api/utils.py +9 -6
  53. letta/server/server.py +18 -12
  54. letta/services/agent_manager.py +185 -193
  55. letta/services/block_manager.py +1 -1
  56. letta/services/context_window_calculator/token_counter.py +3 -2
  57. letta/services/file_processor/chunker/line_chunker.py +34 -0
  58. letta/services/file_processor/file_processor.py +40 -11
  59. letta/services/file_processor/parser/mistral_parser.py +11 -1
  60. letta/services/files_agents_manager.py +96 -7
  61. letta/services/group_manager.py +6 -6
  62. letta/services/helpers/agent_manager_helper.py +373 -3
  63. letta/services/identity_manager.py +1 -1
  64. letta/services/job_manager.py +1 -1
  65. letta/services/llm_batch_manager.py +1 -1
  66. letta/services/message_manager.py +1 -1
  67. letta/services/organization_manager.py +1 -1
  68. letta/services/passage_manager.py +1 -1
  69. letta/services/per_agent_lock_manager.py +1 -1
  70. letta/services/provider_manager.py +1 -1
  71. letta/services/sandbox_config_manager.py +1 -1
  72. letta/services/source_manager.py +178 -19
  73. letta/services/step_manager.py +2 -2
  74. letta/services/summarizer/summarizer.py +1 -1
  75. letta/services/telemetry_manager.py +1 -1
  76. letta/services/tool_executor/builtin_tool_executor.py +117 -0
  77. letta/services/tool_executor/composio_tool_executor.py +53 -0
  78. letta/services/tool_executor/core_tool_executor.py +474 -0
  79. letta/services/tool_executor/files_tool_executor.py +131 -0
  80. letta/services/tool_executor/mcp_tool_executor.py +45 -0
  81. letta/services/tool_executor/multi_agent_tool_executor.py +123 -0
  82. letta/services/tool_executor/tool_execution_manager.py +34 -14
  83. letta/services/tool_executor/tool_execution_sandbox.py +1 -1
  84. letta/services/tool_executor/tool_executor.py +3 -802
  85. letta/services/tool_executor/tool_executor_base.py +43 -0
  86. letta/services/tool_manager.py +55 -59
  87. letta/services/tool_sandbox/e2b_sandbox.py +1 -1
  88. letta/services/tool_sandbox/local_sandbox.py +6 -3
  89. letta/services/user_manager.py +6 -3
  90. letta/settings.py +21 -1
  91. letta/utils.py +7 -2
  92. {letta_nightly-0.8.0.dev20250606104326.dist-info → letta_nightly-0.8.2.dev20250606215616.dist-info}/METADATA +4 -2
  93. {letta_nightly-0.8.0.dev20250606104326.dist-info → letta_nightly-0.8.2.dev20250606215616.dist-info}/RECORD +96 -74
  94. {letta_nightly-0.8.0.dev20250606104326.dist-info → letta_nightly-0.8.2.dev20250606215616.dist-info}/LICENSE +0 -0
  95. {letta_nightly-0.8.0.dev20250606104326.dist-info → letta_nightly-0.8.2.dev20250606215616.dist-info}/WHEEL +0 -0
  96. {letta_nightly-0.8.0.dev20250606104326.dist-info → letta_nightly-0.8.2.dev20250606215616.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,282 @@
1
+ import asyncio
2
+ from functools import wraps
3
+ from typing import Any, Optional, Set, Union
4
+
5
+ import redis.asyncio as redis
6
+ from redis import RedisError
7
+
8
+ from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
9
+ from letta.log import get_logger
10
+
11
logger = get_logger(__name__)

# Process-wide client cached by get_redis_client(); None until the first call.
_client_instance: Optional["AsyncRedisClient"] = None
14
+
15
+
16
class AsyncRedisClient:
    """Async Redis client with connection pooling and error handling.

    A single ``redis.Redis`` client is created lazily by ``get_client`` on top of the
    ``ConnectionPool`` built in ``__init__``. ``get`` and ``smismember`` return a
    fallback value when Redis fails; every other command lets the error propagate.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        max_connections: int = 50,
        decode_responses: bool = True,
        socket_timeout: int = 5,
        socket_connect_timeout: int = 5,
        retry_on_timeout: bool = True,
        health_check_interval: int = 30,
    ):
        """
        Initialize Redis client with connection pool.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Database number
            password: Redis password if required
            max_connections: Maximum number of connections in pool
            decode_responses: Decode byte responses to strings
            socket_timeout: Socket timeout in seconds
            socket_connect_timeout: Socket connection timeout in seconds
            retry_on_timeout: Retry operations on timeout
            health_check_interval: Seconds between health checks
        """
        self.pool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            decode_responses=decode_responses,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            retry_on_timeout=retry_on_timeout,
            health_check_interval=health_check_interval,
        )
        # Created on first use by get_client(); reset to None by close().
        self._client = None
        # Serializes the lazy creation of self._client.
        self._lock = asyncio.Lock()

    async def get_client(self) -> redis.Redis:
        """Get or create the shared Redis client instance."""
        if self._client is None:
            async with self._lock:
                # Re-check under the lock: another task may have created it while we waited.
                if self._client is None:
                    self._client = redis.Redis(connection_pool=self.pool)
        return self._client

    async def close(self):
        """Close the Redis client and disconnect its pool.

        Does nothing if no client has been created yet.
        """
        if self._client:
            await self._client.close()
            await self.pool.disconnect()
            self._client = None

    async def __aenter__(self):
        """Async context manager entry: make sure the client exists."""
        await self.get_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the client and pool."""
        await self.close()

    # Health check and connection management
    async def ping(self) -> bool:
        """Return True if Redis answers PING, False if a Redis error occurs."""
        try:
            client = await self.get_client()
            await client.ping()
            return True
        except RedisError as e:
            # No traceback: wait_for_ready() calls this repeatedly while Redis is down.
            logger.warning(f"Redis ping failed: {e}")
            return False

    async def wait_for_ready(self, timeout: int = 30, interval: float = 0.5):
        """Poll ``ping`` every ``interval`` seconds until it succeeds.

        Raises:
            ConnectionError: if Redis did not answer within ``timeout`` seconds.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if await self.ping():
                return
            await asyncio.sleep(interval)
        raise ConnectionError(f"Redis not ready after {timeout} seconds")

    # Retry decorator for resilience. It lives in the class body so it can decorate
    # the methods below; it is not meant to be called on an instance.
    def with_retry(max_attempts: int = 3, delay: float = 0.1):
        """Decorator to retry Redis operations on connection/timeout failures.

        Sleeps ``delay * 2**attempt`` seconds between attempts and re-raises the
        last error once ``max_attempts`` attempts (at least one) have failed.
        """
        # redis-py raises its own ConnectionError/TimeoutError (subclasses of RedisError),
        # which are unrelated to the builtin exceptions of the same name.
        from redis.exceptions import ConnectionError as RedisConnectionError
        from redis.exceptions import TimeoutError as RedisTimeoutError

        retryable = (ConnectionError, TimeoutError, asyncio.TimeoutError, RedisConnectionError, RedisTimeoutError)
        attempts = max(1, max_attempts)

        def decorator(func):
            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                for attempt in range(attempts):
                    try:
                        return await func(self, *args, **kwargs)
                    except retryable as e:
                        if attempt == attempts - 1:
                            raise
                        logger.warning(f"Retry {attempt + 1}/{attempts} for {func.__name__}: {e}")
                        await asyncio.sleep(delay * (2**attempt))

            return wrapper

        return decorator

    # Basic operations with error handling
    @with_retry()
    async def _get_with_retry(self, key: str) -> Any:
        """GET with retries; errors propagate to the caller."""
        client = await self.get_client()
        return await client.get(key)

    async def get(self, key: str, default: Any = None) -> Any:
        """Get value by key.

        A missing key yields ``None`` (Redis GET semantics); ``default`` is returned
        only when the lookup still fails after retries.
        """
        try:
            return await self._get_with_retry(key)
        except Exception as e:
            # Best effort, but `except Exception` (not a bare except) so task cancellation propagates.
            logger.warning(f"Redis GET {key!r} failed, returning default: {e}")
            return default

    @with_retry()
    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        """
        Set key-value with options.

        Args:
            key: Redis key
            value: Value to store
            ex: Expire time in seconds
            px: Expire time in milliseconds
            nx: Only set if key doesn't exist
            xx: Only set if key exists
        """
        client = await self.get_client()
        return await client.set(key, value, ex=ex, px=px, nx=nx, xx=xx)

    @with_retry()
    async def delete(self, *keys: str) -> int:
        """Delete one or more keys."""
        client = await self.get_client()
        return await client.delete(*keys)

    @with_retry()
    async def exists(self, *keys: str) -> int:
        """Return how many of ``keys`` exist."""
        client = await self.get_client()
        return await client.exists(*keys)

    # Set operations
    async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
        """Add members to set."""
        client = await self.get_client()
        return await client.sadd(key, *members)

    async def smembers(self, key: str) -> Set[str]:
        """Get all set members."""
        client = await self.get_client()
        return await client.smembers(key)

    @with_retry()
    async def _smismember_with_retry(self, key: str, values: list[Any] | Any) -> list[int]:
        """SMISMEMBER with retries; errors propagate to the caller."""
        client = await self.get_client()
        return await client.smismember(key, values)

    async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
        """Check set membership for one value or a list of values (Redis SMISMEMBER).

        Returns a list of membership flags when ``values`` is a list, otherwise a
        single flag. If Redis fails, every value is reported as a non-member (0).
        """
        try:
            result = await self._smismember_with_retry(key, values)
            return result if isinstance(values, list) else result[0]
        except Exception as e:
            logger.warning(f"Redis SMISMEMBER on {key!r} failed, reporting non-members: {e}")
            return [0] * len(values) if isinstance(values, list) else 0

    async def srem(self, key: str, *members: Union[str, int, float]) -> int:
        """Remove members from set."""
        client = await self.get_client()
        return await client.srem(key, *members)

    async def scard(self, key: str) -> int:
        """Return the number of members in the set at ``key``."""
        client = await self.get_client()
        return await client.scard(key)

    # Atomic operations
    async def incr(self, key: str) -> int:
        """Increment key value."""
        client = await self.get_client()
        return await client.incr(key)

    async def decr(self, key: str) -> int:
        """Decrement key value."""
        client = await self.get_client()
        return await client.decr(key)

    async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
        """Return True if ``member`` is allowed in ``group``.

        ``create_inclusion_exclusion_keys`` seeds both sets with
        ``REDIS_SET_DEFAULT_VAL``, so a set only holds real entries when its
        cardinality is greater than 1. SCARD of a missing key is 0, so no separate
        EXISTS check is needed.
        """
        exclude_key = self._get_group_exclusion_key(group)
        include_key = self._get_group_inclusion_key(group)
        # 1. An explicitly excluded member is never allowed.
        if await self.scard(exclude_key) > 1 and await self.smismember(exclude_key, member):
            return False
        # 2. If the group HAS an include set, only members of that set are allowed.
        if await self.scard(include_key) > 1:
            return bool(await self.smismember(include_key, member))
        # 3. The group does NOT HAVE an include set and the member is NOT excluded.
        return True

    async def create_inclusion_exclusion_keys(self, group: str) -> None:
        """Create the include/exclude sets for ``group``, each seeded with the placeholder value."""
        redis_client = await self.get_client()
        await redis_client.sadd(self._get_group_inclusion_key(group), REDIS_SET_DEFAULT_VAL)
        await redis_client.sadd(self._get_group_exclusion_key(group), REDIS_SET_DEFAULT_VAL)

    @staticmethod
    def _get_group_inclusion_key(group: str) -> str:
        return f"{group}_{REDIS_INCLUDE}"

    @staticmethod
    def _get_group_exclusion_key(group: str) -> str:
        return f"{group}_{REDIS_EXCLUDE}"
239
+
240
+
241
class NoopAsyncRedisClient(AsyncRedisClient):
    """Stand-in used by ``get_redis_client`` when Redis could not be reached.

    Every command is overridden to return an "empty" result without contacting a
    server. This matters because the inherited ``__init__`` builds a pool for the
    default host (localhost:6379), not the configured one. Any command falling
    through to the base class would therefore talk to that server.
    """

    async def get(self, key: str, default: Any = None) -> Any:
        return default

    async def set(
        self,
        key: str,
        value: Union[str, int, float],
        ex: Optional[int] = None,
        px: Optional[int] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        # Nothing is stored.
        return False

    async def exists(self, *keys: str) -> int:
        return 0

    async def delete(self, *keys: str) -> int:
        return 0

    async def sadd(self, key: str, *members: Union[str, int, float]) -> int:
        return 0

    async def smembers(self, key: str) -> Set[str]:
        return set()

    async def smismember(self, key: str, values: list[Any] | Any) -> list[int] | int:
        return [0] * len(values) if isinstance(values, list) else 0

    async def srem(self, key: str, *members: Union[str, int, float]) -> int:
        return 0

    async def scard(self, key: str) -> int:
        return 0

    async def incr(self, key: str) -> int:
        # No counter is kept; always 0.
        return 0

    async def decr(self, key: str) -> int:
        # No counter is kept; always 0.
        return 0

    async def ping(self) -> bool:
        # There is no server behind this client.
        return False

    async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
        return False

    async def create_inclusion_exclusion_keys(self, group: str) -> None:
        return None
265
+
266
+
267
async def get_redis_client() -> AsyncRedisClient:
    """Return the process-wide Redis client, creating it on the first call.

    Connects to ``settings.redis_host`` / ``settings.redis_port`` (falling back to
    localhost:6379) and waits up to 5 seconds for PING to succeed. On any failure a
    ``NoopAsyncRedisClient`` is cached instead, so callers can treat Redis as optional.
    """
    global _client_instance
    if _client_instance is None:
        try:
            from letta.settings import settings

            _client_instance = AsyncRedisClient(
                host=settings.redis_host or "localhost",
                port=settings.redis_port or 6379,
            )
            await _client_instance.wait_for_ready(timeout=5)
            logger.info("Redis client initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize Redis: {e}")
            failed_client = _client_instance
            # Publish the fallback before awaiting cleanup so concurrent callers see it.
            _client_instance = NoopAsyncRedisClient()
            if failed_client is not None:
                # Release the pool of the client that never became ready (best effort).
                try:
                    await failed_client.close()
                except Exception:
                    logger.debug("Error while closing unready Redis client", exc_info=True)
    return _client_instance
letta/errors.py CHANGED
@@ -88,10 +88,6 @@ class LLMPermissionDeniedError(LLMError):
88
88
  """Error when permission is denied by LLM service"""
89
89
 
90
90
 
91
- class LLMContextWindowExceededError(LLMError):
92
- """Error when the context length is exceeded."""
93
-
94
-
95
91
  class LLMNotFoundError(LLMError):
96
92
  """Error when requested resource is not found"""
97
93
 
@@ -0,0 +1,58 @@
1
+ from typing import TYPE_CHECKING, List, Optional, Tuple
2
+
3
+ if TYPE_CHECKING:
4
+ from letta.schemas.agent import AgentState
5
+ from letta.schemas.file import FileMetadata
6
+
7
+
8
async def open_file(agent_state: "AgentState", file_name: str, view_range: Optional[Tuple[int, int]]) -> str:
    """
    Open up a file in core memory.

    Args:
        file_name (str): Name of the file to view.
        view_range (Optional[Tuple[int, int]]): Optional tuple indicating range to view.

    Returns:
        str: A status message
    """
    # NOTE(review): presumably only this signature/docstring is consumed (to build the
    # tool schema) and execution happens elsewhere; calling this body always raises.
    message = "Tool not implemented. Please contact the Letta team."
    raise NotImplementedError(message)
20
+
21
+
22
async def close_file(agent_state: "AgentState", file_name: str) -> str:
    """
    Close a file in core memory.

    Args:
        file_name (str): Name of the file to close.

    Returns:
        str: A status message
    """
    # NOTE(review): stub body; presumably the real behavior lives in a tool executor.
    message = "Tool not implemented. Please contact the Letta team."
    raise NotImplementedError(message)
33
+
34
+
35
async def grep(agent_state: "AgentState", pattern: str) -> str:
    """
    Grep tool to search files across data sources with keywords.

    Args:
        pattern (str): Keyword or regex pattern to search.

    Returns:
        str: Matching lines or summary output.
    """
    # NOTE(review): stub body; presumably the real behavior lives in a tool executor.
    message = "Tool not implemented. Please contact the Letta team."
    raise NotImplementedError(message)
46
+
47
+
48
async def search_files(agent_state: "AgentState", query: str) -> List["FileMetadata"]:
    """
    Get list of most relevant files across all data sources.

    Args:
        query (str): The search query.

    Returns:
        List[FileMetadata]: List of matching files.
    """
    # NOTE(review): stub body; presumably the real behavior lives in a tool executor.
    message = "Tool not implemented. Please contact the Letta team."
    raise NotImplementedError(message)
@@ -1,6 +1,6 @@
1
1
  import inspect
2
2
  import warnings
3
- from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
3
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
4
4
 
5
5
  from composio.client.collections import ActionParametersModel
6
6
  from docstring_parser import parse
@@ -76,6 +76,23 @@ def type_to_json_schema_type(py_type) -> dict:
76
76
  if get_origin(py_type) is Literal:
77
77
  return {"type": "string", "enum": get_args(py_type)}
78
78
 
79
+ # Handle tuple types (specifically fixed-length like Tuple[int, int])
80
+ if origin in (tuple, Tuple):
81
+ args = get_args(py_type)
82
+ if len(args) == 0:
83
+ raise ValueError("Tuple type must have at least one element")
84
+
85
+ # Support only fixed-length tuples like Tuple[int, int], not variable-length like Tuple[int, ...]
86
+ if len(args) == 2 and args[1] is Ellipsis:
87
+ raise NotImplementedError("Variable-length tuples (e.g., Tuple[int, ...]) are not supported")
88
+
89
+ return {
90
+ "type": "array",
91
+ "prefixItems": [type_to_json_schema_type(arg) for arg in args],
92
+ "minItems": len(args),
93
+ "maxItems": len(args),
94
+ }
95
+
79
96
  # Handle object types
80
97
  if py_type == dict or origin in (dict, Dict):
81
98
  args = get_args(py_type)
@@ -5,6 +5,7 @@ from typing import AsyncGenerator, List, Optional
5
5
  from letta.agents.base_agent import BaseAgent
6
6
  from letta.agents.letta_agent import LettaAgent
7
7
  from letta.groups.helpers import stringify_message
8
+ from letta.otel.tracing import trace_method
8
9
  from letta.schemas.enums import JobStatus
9
10
  from letta.schemas.group import Group, ManagerType
10
11
  from letta.schemas.job import JobUpdate
@@ -21,7 +22,6 @@ from letta.services.message_manager import MessageManager
21
22
  from letta.services.passage_manager import PassageManager
22
23
  from letta.services.step_manager import NoopStepManager, StepManager
23
24
  from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
24
- from letta.tracing import trace_method
25
25
 
26
26
 
27
27
  class SleeptimeMultiAgentV2(BaseAgent):
@@ -1,7 +1,9 @@
1
1
  import re
2
2
  import time
3
- from datetime import datetime, timedelta, timezone
3
+ from datetime import datetime, timedelta
4
+ from datetime import timezone as dt_timezone
4
5
  from time import strftime
6
+ from typing import Callable
5
7
 
6
8
  import pytz
7
9
 
@@ -66,7 +68,7 @@ def get_local_time(timezone=None):
66
68
def get_utc_time() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = dt_timezone.utc
    return datetime.now(tz=utc)
70
72
 
71
73
 
72
74
  def get_utc_time_int() -> int:
@@ -78,9 +80,13 @@ def get_utc_timestamp_ns() -> int:
78
80
  return int(time.time_ns())
79
81
 
80
82
 
83
def ns_to_ms(ns: int) -> int:
    """Convert nanoseconds to whole milliseconds, rounding toward negative infinity."""
    whole_ms, _remainder = divmod(ns, 1_000_000)
    return whole_ms
85
+
86
+
81
87
def timestamp_to_datetime(timestamp_seconds: int) -> datetime:
    """Turn a Unix timestamp (seconds since the epoch) into an aware UTC datetime."""
    utc = dt_timezone.utc
    return datetime.fromtimestamp(timestamp_seconds, tz=utc)
84
90
 
85
91
 
86
92
  def format_datetime(dt):
@@ -105,3 +111,41 @@ def extract_date_from_timestamp(timestamp):
105
111
 
106
112
def is_utc_datetime(dt: datetime) -> bool:
    """Report whether ``dt`` is timezone-aware with a zero UTC offset."""
    tz = dt.tzinfo
    if tz is None:
        return False
    return tz.utcoffset(dt) == timedelta(0)
114
+
115
+
116
+ class AsyncTimer:
117
+ """An async context manager for timing async code execution.
118
+
119
+ Takes in an optional callback_func to call on exit with arguments
120
+ taking in the elapsed_ms and exc if present.
121
+
122
+ Do not use the start and end times outside of this function as they are relative.
123
+ """
124
+
125
+ def __init__(self, callback_func: Callable | None = None):
126
+ self._start_time_ns = None
127
+ self._end_time_ns = None
128
+ self.elapsed_ns = None
129
+ self.callback_func = callback_func
130
+
131
+ async def __aenter__(self):
132
+ self._start_time_ns = time.perf_counter_ns()
133
+ return self
134
+
135
+ async def __aexit__(self, exc_type, exc, tb):
136
+ self._end_time_ns = time.perf_counter_ns()
137
+ self.elapsed_ns = self._end_time_ns - self._start_time_ns
138
+ if self.callback_func:
139
+ from asyncio import iscoroutinefunction
140
+
141
+ if iscoroutinefunction(self.callback_func):
142
+ await self.callback_func(self.elapsed_ms, exc)
143
+ else:
144
+ self.callback_func(self.elapsed_ms, exc)
145
+ return False
146
+
147
+ @property
148
+ def elapsed_ms(self):
149
+ if self.elapsed_ns is not None:
150
+ return ns_to_ms(self.elapsed_ns)
151
+ return None
@@ -0,0 +1,69 @@
1
+ import inspect
2
+ from functools import wraps
3
+ from typing import Callable
4
+
5
+ from letta.log import get_logger
6
+ from letta.plugins.plugins import get_experimental_checker
7
+ from letta.settings import settings
8
+
9
# Module-level logger, used by the decorators below.
logger = get_logger(__name__)
10
+
11
+
12
def experimental(feature_name: str, fallback_function: Callable, **kwargs):
    """Decorator that runs a fallback function if experimental feature is not enabled.

    The experimental checker is obtained once, when the decorator is applied, via
    ``get_experimental_checker()``. On every call it is invoked as
    ``checker(feature_name, **call_kwargs_merged_with_decorator_kwargs)``.
    If it returns a truthy value, the decorated function runs; otherwise
    ``fallback_function`` runs. Either way, the call's own positional and keyword
    arguments are passed through.

    - kwargs from the decorator are combined with the call kwargs (decorator values
      win) only for the experimental check; they are never passed to ``f`` or
      ``fallback_function``.
    - If the decorated function, ``fallback_function``, or the experimental checker
      is async, the returned wrapper is async.
    """

    def decorator(f):
        experimental_checker = get_experimental_checker()
        is_f_async = inspect.iscoroutinefunction(f)
        is_fallback_async = inspect.iscoroutinefunction(fallback_function)
        is_experimental_checker_async = inspect.iscoroutinefunction(experimental_checker)

        async def call_function(func, is_async, *args, **_kwargs):
            if is_async:
                return await func(*args, **_kwargs)
            return func(*args, **_kwargs)

        # asynchronous wrapper if any function is async
        if any((is_f_async, is_fallback_async, is_experimental_checker_async)):

            @wraps(f)
            async def async_wrapper(*args, **_kwargs):
                result = await call_function(experimental_checker, is_experimental_checker_async, feature_name, **dict(_kwargs, **kwargs))
                if result:
                    return await call_function(f, is_f_async, *args, **_kwargs)
                return await call_function(fallback_function, is_fallback_async, *args, **_kwargs)

            return async_wrapper

        @wraps(f)
        def wrapper(*args, **_kwargs):
            if experimental_checker(feature_name, **dict(_kwargs, **kwargs)):
                return f(*args, **_kwargs)
            # The fallback receives the call's own kwargs (matching the async path),
            # not the decorator kwargs.
            return fallback_function(*args, **_kwargs)

        return wrapper

    return decorator
55
+
56
+
57
def deprecated(message: str):
    """Mark a function as deprecated.

    The wrapped function behaves exactly like the original. When ``settings.debug``
    is enabled, each call first logs a warning that includes ``message``.
    """

    def mark(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            if settings.debug:
                logger.warning(f"Function {func.__name__} is deprecated: {message}.")
            return func(*args, **kwargs)

        return wrapped

    return mark
@@ -1,7 +1,12 @@
1
+ # TODO (cliandy): consolidate with decorators later
2
+ from functools import wraps
3
+
4
+
1
5
  def singleton(cls):
2
6
  """Decorator to make a class a Singleton class."""
3
7
  instances = {}
4
8
 
9
+ @wraps(cls)
5
10
  def get_instance(*args, **kwargs):
6
11
  if cls not in instances:
7
12
  instances[cls] = cls(*args, **kwargs)