ccproxy-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ccproxy/__init__.py +4 -0
- ccproxy/__main__.py +7 -0
- ccproxy/_version.py +21 -0
- ccproxy/adapters/__init__.py +11 -0
- ccproxy/adapters/base.py +80 -0
- ccproxy/adapters/openai/__init__.py +43 -0
- ccproxy/adapters/openai/adapter.py +915 -0
- ccproxy/adapters/openai/models.py +412 -0
- ccproxy/adapters/openai/streaming.py +449 -0
- ccproxy/api/__init__.py +28 -0
- ccproxy/api/app.py +225 -0
- ccproxy/api/dependencies.py +140 -0
- ccproxy/api/middleware/__init__.py +11 -0
- ccproxy/api/middleware/auth.py +0 -0
- ccproxy/api/middleware/cors.py +55 -0
- ccproxy/api/middleware/errors.py +703 -0
- ccproxy/api/middleware/headers.py +51 -0
- ccproxy/api/middleware/logging.py +175 -0
- ccproxy/api/middleware/request_id.py +69 -0
- ccproxy/api/middleware/server_header.py +62 -0
- ccproxy/api/responses.py +84 -0
- ccproxy/api/routes/__init__.py +16 -0
- ccproxy/api/routes/claude.py +181 -0
- ccproxy/api/routes/health.py +489 -0
- ccproxy/api/routes/metrics.py +1033 -0
- ccproxy/api/routes/proxy.py +238 -0
- ccproxy/auth/__init__.py +75 -0
- ccproxy/auth/bearer.py +68 -0
- ccproxy/auth/credentials_adapter.py +93 -0
- ccproxy/auth/dependencies.py +229 -0
- ccproxy/auth/exceptions.py +79 -0
- ccproxy/auth/manager.py +102 -0
- ccproxy/auth/models.py +118 -0
- ccproxy/auth/oauth/__init__.py +26 -0
- ccproxy/auth/oauth/models.py +49 -0
- ccproxy/auth/oauth/routes.py +396 -0
- ccproxy/auth/oauth/storage.py +0 -0
- ccproxy/auth/storage/__init__.py +12 -0
- ccproxy/auth/storage/base.py +57 -0
- ccproxy/auth/storage/json_file.py +159 -0
- ccproxy/auth/storage/keyring.py +192 -0
- ccproxy/claude_sdk/__init__.py +20 -0
- ccproxy/claude_sdk/client.py +169 -0
- ccproxy/claude_sdk/converter.py +331 -0
- ccproxy/claude_sdk/options.py +120 -0
- ccproxy/cli/__init__.py +14 -0
- ccproxy/cli/commands/__init__.py +8 -0
- ccproxy/cli/commands/auth.py +553 -0
- ccproxy/cli/commands/config/__init__.py +14 -0
- ccproxy/cli/commands/config/commands.py +766 -0
- ccproxy/cli/commands/config/schema_commands.py +119 -0
- ccproxy/cli/commands/serve.py +630 -0
- ccproxy/cli/docker/__init__.py +34 -0
- ccproxy/cli/docker/adapter_factory.py +157 -0
- ccproxy/cli/docker/params.py +278 -0
- ccproxy/cli/helpers.py +144 -0
- ccproxy/cli/main.py +193 -0
- ccproxy/cli/options/__init__.py +14 -0
- ccproxy/cli/options/claude_options.py +216 -0
- ccproxy/cli/options/core_options.py +40 -0
- ccproxy/cli/options/security_options.py +48 -0
- ccproxy/cli/options/server_options.py +117 -0
- ccproxy/config/__init__.py +40 -0
- ccproxy/config/auth.py +154 -0
- ccproxy/config/claude.py +124 -0
- ccproxy/config/cors.py +79 -0
- ccproxy/config/discovery.py +87 -0
- ccproxy/config/docker_settings.py +265 -0
- ccproxy/config/loader.py +108 -0
- ccproxy/config/observability.py +158 -0
- ccproxy/config/pricing.py +88 -0
- ccproxy/config/reverse_proxy.py +31 -0
- ccproxy/config/scheduler.py +89 -0
- ccproxy/config/security.py +14 -0
- ccproxy/config/server.py +81 -0
- ccproxy/config/settings.py +534 -0
- ccproxy/config/validators.py +231 -0
- ccproxy/core/__init__.py +274 -0
- ccproxy/core/async_utils.py +675 -0
- ccproxy/core/constants.py +97 -0
- ccproxy/core/errors.py +256 -0
- ccproxy/core/http.py +328 -0
- ccproxy/core/http_transformers.py +428 -0
- ccproxy/core/interfaces.py +247 -0
- ccproxy/core/logging.py +189 -0
- ccproxy/core/middleware.py +114 -0
- ccproxy/core/proxy.py +143 -0
- ccproxy/core/system.py +38 -0
- ccproxy/core/transformers.py +259 -0
- ccproxy/core/types.py +129 -0
- ccproxy/core/validators.py +288 -0
- ccproxy/docker/__init__.py +67 -0
- ccproxy/docker/adapter.py +588 -0
- ccproxy/docker/docker_path.py +207 -0
- ccproxy/docker/middleware.py +103 -0
- ccproxy/docker/models.py +228 -0
- ccproxy/docker/protocol.py +192 -0
- ccproxy/docker/stream_process.py +264 -0
- ccproxy/docker/validators.py +173 -0
- ccproxy/models/__init__.py +123 -0
- ccproxy/models/errors.py +42 -0
- ccproxy/models/messages.py +243 -0
- ccproxy/models/requests.py +85 -0
- ccproxy/models/responses.py +227 -0
- ccproxy/models/types.py +102 -0
- ccproxy/observability/__init__.py +51 -0
- ccproxy/observability/access_logger.py +400 -0
- ccproxy/observability/context.py +447 -0
- ccproxy/observability/metrics.py +539 -0
- ccproxy/observability/pushgateway.py +366 -0
- ccproxy/observability/sse_events.py +303 -0
- ccproxy/observability/stats_printer.py +755 -0
- ccproxy/observability/storage/__init__.py +1 -0
- ccproxy/observability/storage/duckdb_simple.py +665 -0
- ccproxy/observability/storage/models.py +55 -0
- ccproxy/pricing/__init__.py +19 -0
- ccproxy/pricing/cache.py +212 -0
- ccproxy/pricing/loader.py +267 -0
- ccproxy/pricing/models.py +106 -0
- ccproxy/pricing/updater.py +309 -0
- ccproxy/scheduler/__init__.py +39 -0
- ccproxy/scheduler/core.py +335 -0
- ccproxy/scheduler/exceptions.py +34 -0
- ccproxy/scheduler/manager.py +186 -0
- ccproxy/scheduler/registry.py +150 -0
- ccproxy/scheduler/tasks.py +484 -0
- ccproxy/services/__init__.py +10 -0
- ccproxy/services/claude_sdk_service.py +614 -0
- ccproxy/services/credentials/__init__.py +55 -0
- ccproxy/services/credentials/config.py +105 -0
- ccproxy/services/credentials/manager.py +562 -0
- ccproxy/services/credentials/oauth_client.py +482 -0
- ccproxy/services/proxy_service.py +1536 -0
- ccproxy/static/.keep +0 -0
- ccproxy/testing/__init__.py +34 -0
- ccproxy/testing/config.py +148 -0
- ccproxy/testing/content_generation.py +197 -0
- ccproxy/testing/mock_responses.py +262 -0
- ccproxy/testing/response_handlers.py +161 -0
- ccproxy/testing/scenarios.py +241 -0
- ccproxy/utils/__init__.py +6 -0
- ccproxy/utils/cost_calculator.py +210 -0
- ccproxy/utils/streaming_metrics.py +199 -0
- ccproxy_api-0.1.0.dist-info/METADATA +253 -0
- ccproxy_api-0.1.0.dist-info/RECORD +148 -0
- ccproxy_api-0.1.0.dist-info/WHEEL +4 -0
- ccproxy_api-0.1.0.dist-info/entry_points.txt +2 -0
- ccproxy_api-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""Pricing updater for managing periodic refresh of pricing data."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from decimal import Decimal
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from structlog import get_logger
|
|
8
|
+
|
|
9
|
+
from ccproxy.config.pricing import PricingSettings
|
|
10
|
+
|
|
11
|
+
from .cache import PricingCache
|
|
12
|
+
from .loader import PricingLoader
|
|
13
|
+
from .models import PricingData
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PricingUpdater:
    """Manages periodic updates of pricing data.

    Serves pricing from a small in-memory cache, reloading from the file
    cache (and, when configured, re-downloading or falling back to embedded
    defaults) whenever the in-memory copy is stale or the cache file has
    changed on disk.
    """

    # Minimum seconds between cache-file mtime checks, to keep stat() I/O
    # off the hot path.
    _FILE_CHECK_INTERVAL: float = 30.0

    def __init__(
        self,
        cache: PricingCache,
        settings: PricingSettings,
    ) -> None:
        """Initialize pricing updater.

        Args:
            cache: Pricing cache instance
            settings: Pricing configuration settings
        """
        self.cache = cache
        self.settings = settings
        # In-memory copy of the last successfully loaded pricing data.
        self._cached_pricing: PricingData | None = None
        # time.time() of the last in-memory load.
        self._last_load_time: float = 0
        # time.time() of the last cache-file mtime check.
        self._last_file_check_time: float = 0
        # mtime observed during the last file check (0 = never seen).
        self._cached_file_mtime: float = 0

    async def get_current_pricing(
        self, force_refresh: bool = False
    ) -> PricingData | None:
        """Get current pricing data with automatic updates.

        Args:
            force_refresh: Force refresh even if cache is valid

        Returns:
            Current pricing data as PricingData model, or None when no
            source is available and embedded fallback is disabled.
        """
        import time

        current_time = time.time()

        # Fast path: serve the in-memory copy while it is fresh and no
        # refresh was requested.
        if (
            not force_refresh
            and self._cached_pricing is not None
            and (current_time - self._last_load_time) < self.settings.memory_cache_ttl
        ):
            # Throttle on-disk change detection to reduce I/O.
            if (current_time - self._last_file_check_time) > self._FILE_CHECK_INTERVAL:
                # Record the check time up front so the changed and unchanged
                # branches are treated alike (previously only the unchanged
                # branch recorded it, causing a redundant stat next call).
                self._last_file_check_time = current_time
                if self._has_cache_file_changed():
                    logger.info("cache_file_changed")
                    # File changed behind our back: reload it.
                    pricing_data = await self._load_pricing_data()
                    self._cached_pricing = pricing_data
                    self._last_load_time = current_time
                    return pricing_data
            return self._cached_pricing

        # Refresh from the external source when forced, or when auto-update
        # is enabled and the file cache has expired.
        should_refresh = force_refresh or (
            self.settings.auto_update and not self.cache.is_cache_valid()
        )

        if should_refresh:
            # _refresh_pricing logs pricing_refresh_start itself; logging it
            # here as well produced duplicate entries.
            await self._refresh_pricing()

        # Load pricing data (from the freshly written cache when refreshed).
        pricing_data = await self._load_pricing_data()

        # Cache the result in memory.
        self._cached_pricing = pricing_data
        self._last_load_time = current_time
        self._last_file_check_time = current_time

        return pricing_data

    def _has_cache_file_changed(self) -> bool:
        """Check if the cache file has changed since last load.

        Returns:
            True if file has changed or doesn't exist
        """
        try:
            if not self.cache.cache_file.exists():
                # Report a change only if we had previously seen the file
                # (i.e. it was deleted); a never-existing file is unchanged.
                return self._cached_file_mtime != 0

            current_mtime = self.cache.cache_file.stat().st_mtime
            if current_mtime != self._cached_file_mtime:
                self._cached_file_mtime = current_mtime
                return True
            return False
        except OSError:
            # If we can't check, assume it changed.
            return True

    async def _refresh_pricing(self) -> bool:
        """Refresh pricing data from external source.

        Returns:
            True if refresh was successful
        """
        try:
            logger.info("pricing_refresh_start")

            # Download fresh data.
            raw_data = await self.cache.download_pricing_data()
            if raw_data is None:
                logger.error("pricing_download_failed")
                return False

            # Persist to the file cache for subsequent loads.
            if not self.cache.save_to_cache(raw_data):
                logger.error("cache_save_failed")
                return False

            logger.info("pricing_refresh_completed")
            return True

        except Exception as e:
            logger.error("pricing_refresh_failed", error=str(e))
            return False

    async def _load_pricing_data(self) -> PricingData | None:
        """Load pricing data from available sources.

        Tries the file cache / download first, then falls back to embedded
        pricing when enabled.

        Returns:
            Pricing data as PricingData model, or None when unavailable.
        """
        # Try to get data from cache or download.
        raw_data = await self.cache.get_pricing_data()

        if raw_data is not None:
            # Load and validate pricing data using Pydantic.
            pricing_data = PricingLoader.load_pricing_from_data(raw_data, verbose=False)

            if pricing_data:
                # Include cache age in the log entry when known.
                cache_info = self.cache.get_cache_info()
                age_hours = cache_info.get("age_hours")

                if age_hours is not None:
                    logger.info(
                        "pricing_loaded_from_external",
                        model_count=len(pricing_data),
                        cache_age_hours=round(age_hours, 2),
                    )
                else:
                    logger.info(
                        "pricing_loaded_from_external", model_count=len(pricing_data)
                    )
                return pricing_data
            else:
                logger.warning("external_pricing_validation_failed")

        # Fallback to embedded pricing.
        if self.settings.fallback_to_embedded:
            logger.info("using_embedded_pricing_fallback")
            return self._get_embedded_pricing()
        else:
            logger.error("pricing_unavailable_no_fallback")
            return None

    def _get_embedded_pricing(self) -> PricingData:
        """Get embedded (hardcoded) pricing as fallback.

        Returns:
            Embedded pricing data as PricingData model
        """

        def rates(
            input_: str, output: str, cache_read: str, cache_write: str
        ) -> dict[str, Decimal]:
            # One pricing entry; values are USD per million tokens
            # (presumably — TODO confirm against PricingData docs).
            return {
                "input": Decimal(input_),
                "output": Decimal(output),
                "cache_read": Decimal(cache_read),
                "cache_write": Decimal(cache_write),
            }

        # Hardcoded pricing mirrored from CostCalculator.
        embedded_data = {
            "claude-3-5-sonnet-20241022": rates("3.00", "15.00", "0.30", "3.75"),
            "claude-3-5-haiku-20241022": rates("0.25", "1.25", "0.03", "0.30"),
            "claude-3-opus-20240229": rates("15.00", "75.00", "1.50", "18.75"),
            "claude-3-sonnet-20240229": rates("3.00", "15.00", "0.30", "3.75"),
            "claude-3-haiku-20240307": rates("0.25", "1.25", "0.03", "0.30"),
        }

        return PricingData.from_dict(embedded_data)

    async def force_refresh(self) -> bool:
        """Force a refresh of pricing data.

        Returns:
            True if refresh was successful
        """
        import time

        logger.info("pricing_force_refresh_start")

        # Drop the in-memory copy so stale data cannot be served.
        self._cached_pricing = None
        self._last_load_time = 0

        # Refresh from external source.
        success = await self._refresh_pricing()

        if success:
            # Reload directly from the freshly written cache.  Calling
            # get_current_pricing(force_refresh=True) here would trigger a
            # second _refresh_pricing() and download the data twice.
            self._cached_pricing = await self._load_pricing_data()
            self._last_load_time = time.time()
            self._last_file_check_time = self._last_load_time

        return success

    def clear_cache(self) -> bool:
        """Clear all cached pricing data.

        Returns:
            True if cache was cleared successfully
        """
        logger.info("pricing_cache_clear_start")

        # Clear in-memory cache.
        self._cached_pricing = None
        self._last_load_time = 0

        # Clear file cache.
        return self.cache.clear_cache()

    async def get_pricing_info(self) -> dict[str, Any]:
        """Get information about current pricing state.

        Returns:
            Dictionary with pricing information
        """
        pricing_data = await self.get_current_pricing()

        return {
            "models_loaded": len(pricing_data) if pricing_data else 0,
            "model_names": pricing_data.model_names() if pricing_data else [],
            "auto_update": self.settings.auto_update,
            "fallback_to_embedded": self.settings.fallback_to_embedded,
            "has_cached_pricing": self._cached_pricing is not None,
        }

    async def validate_external_source(self) -> bool:
        """Validate that external pricing source is accessible.

        Returns:
            True if external source is accessible and has valid data
        """
        try:
            logger.debug("external_pricing_validation_start")

            # Try to download data with a short timeout.
            raw_data = await self.cache.download_pricing_data(timeout=10)
            if raw_data is None:
                return False

            # Try to parse Claude models out of the payload.
            claude_models = PricingLoader.extract_claude_models(raw_data)
            if not claude_models:
                logger.warning("claude_models_not_found_in_external")
                return False

            # Try to load and validate using Pydantic.
            pricing_data = PricingLoader.load_pricing_from_data(raw_data, verbose=False)
            if not pricing_data:
                logger.warning("external_pricing_load_failed")
                return False

            logger.info(
                "external_pricing_validation_completed", model_count=len(pricing_data)
            )
            return True

        except Exception as e:
            logger.error("external_pricing_validation_failed", error=str(e))
            return False
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Scheduler system for periodic tasks.
|
|
3
|
+
|
|
4
|
+
This module provides a generic, extensible scheduler for managing periodic tasks
|
|
5
|
+
in the CCProxy API. It provides a centralized system that supports:
|
|
6
|
+
|
|
7
|
+
- Generic task scheduling with configurable intervals
|
|
8
|
+
- Task registration and discovery via registry pattern
|
|
9
|
+
- Graceful startup and shutdown with FastAPI integration
|
|
10
|
+
- Error handling with exponential backoff
|
|
11
|
+
- Structured logging and monitoring
|
|
12
|
+
|
|
13
|
+
Key components:
|
|
14
|
+
- Scheduler: Core scheduler engine for task management
|
|
15
|
+
- BaseScheduledTask: Abstract base class for all scheduled tasks
|
|
16
|
+
- TaskRegistry: Dynamic task registration and discovery system
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from .core import Scheduler
|
|
20
|
+
from .registry import TaskRegistry, register_task
|
|
21
|
+
from .tasks import (
|
|
22
|
+
BaseScheduledTask,
|
|
23
|
+
PricingCacheUpdateTask,
|
|
24
|
+
PushgatewayTask,
|
|
25
|
+
StatsPrintingTask,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Task registration is now handled in manager.py during scheduler startup
|
|
30
|
+
# to avoid side effects during module imports (e.g., CLI help display)
|
|
31
|
+
|
|
32
|
+
# Public API of the scheduler package; keep this list in sync with the
# imports above when adding new task types or scheduler components.
__all__ = [
    "Scheduler",
    "TaskRegistry",
    "BaseScheduledTask",
    "PushgatewayTask",
    "StatsPrintingTask",
    "PricingCacheUpdateTask",
]
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""Core scheduler for managing periodic tasks."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
|
|
8
|
+
from .exceptions import (
|
|
9
|
+
SchedulerError,
|
|
10
|
+
SchedulerShutdownError,
|
|
11
|
+
TaskNotFoundError,
|
|
12
|
+
TaskRegistrationError,
|
|
13
|
+
)
|
|
14
|
+
from .registry import TaskRegistry, get_task_registry
|
|
15
|
+
from .tasks import BaseScheduledTask
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
logger = structlog.get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Scheduler:
    """
    Scheduler for managing multiple periodic tasks.

    Centralizes the lifecycle of scheduled tasks: dynamic registration,
    graceful startup and shutdown, status reporting, and error recovery.
    """

    def __init__(
        self,
        max_concurrent_tasks: int = 10,
        graceful_shutdown_timeout: float = 30.0,
        task_registry: TaskRegistry | None = None,
    ):
        """
        Initialize the scheduler.

        Args:
            max_concurrent_tasks: Maximum number of tasks to run concurrently
            graceful_shutdown_timeout: Timeout for graceful shutdown in seconds
            task_registry: Task registry instance (uses global if None)
        """
        self.max_concurrent_tasks = max_concurrent_tasks
        self.graceful_shutdown_timeout = graceful_shutdown_timeout
        self.task_registry = task_registry or get_task_registry()

        # Internal state: run flag, name -> task mapping, concurrency limiter.
        # NOTE(review): _semaphore is created in start() but never acquired,
        # so max_concurrent_tasks is not actually enforced — confirm intent.
        self._running = False
        self._tasks: dict[str, BaseScheduledTask] = {}
        self._semaphore: asyncio.Semaphore | None = None

    async def start(self) -> None:
        """Start the scheduler and all enabled tasks."""
        if self._running:
            logger.warning("scheduler_already_running")
            return

        self._running = True
        self._semaphore = asyncio.Semaphore(self.max_concurrent_tasks)

        logger.info(
            "scheduler_starting",
            max_concurrent_tasks=self.max_concurrent_tasks,
            registered_tasks=self.task_registry.list_tasks(),
        )

        try:
            # Tasks are added explicitly via add_task(); none are created here.
            currently_running = [
                name for name, task in self._tasks.items() if task.is_running
            ]
            logger.info(
                "scheduler_started",
                active_tasks=len(self._tasks),
                running_tasks=currently_running,
            )
        except Exception as e:
            # Roll the run flag back so a retry is possible.
            self._running = False
            logger.error(
                "scheduler_start_failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise SchedulerError(f"Failed to start scheduler: {e}") from e

    async def stop(self) -> None:
        """Stop the scheduler and all running tasks."""
        if not self._running:
            return

        self._running = False
        logger.info("scheduler_stopping", active_tasks=len(self._tasks))

        # Collect stop() coroutines for every task still running.
        pending = []
        for name, task in self._tasks.items():
            if not task.is_running:
                continue
            logger.debug("stopping_task", task_name=name)
            pending.append(task.stop())

        if pending:
            try:
                # Give all tasks a bounded window to stop gracefully.
                await asyncio.wait_for(
                    asyncio.gather(*pending, return_exceptions=True),
                    timeout=self.graceful_shutdown_timeout,
                )
                logger.info("scheduler_stopped_gracefully")
            except TimeoutError:
                logger.warning(
                    "scheduler_shutdown_timeout",
                    timeout=self.graceful_shutdown_timeout,
                )
                # Tasks should have cancelled themselves by now; surface any
                # stragglers in the log.
                for name, task in self._tasks.items():
                    if task.is_running:
                        logger.warning(
                            "task_still_running_after_shutdown", task_name=name
                        )
            except Exception as e:
                logger.error(
                    "scheduler_shutdown_error",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise SchedulerShutdownError(
                    f"Error during scheduler shutdown: {e}"
                ) from e

        self._tasks.clear()
        logger.info("scheduler_stopped")

    async def add_task(
        self,
        task_name: str,
        task_type: str,
        **task_kwargs: Any,
    ) -> None:
        """
        Add and start a task.

        Args:
            task_name: Unique name for this task instance
            task_type: Type of task (must be registered in task registry)
            **task_kwargs: Additional arguments to pass to task constructor

        Raises:
            TaskRegistrationError: If task type is not registered
            SchedulerError: If task name already exists or task creation fails
        """
        if task_name in self._tasks:
            raise SchedulerError(f"Task '{task_name}' already exists")

        if not self.task_registry.is_registered(task_type):
            raise TaskRegistrationError(f"Task type '{task_type}' is not registered")

        try:
            # Instantiate the task from its registered class and track it.
            task_cls = self.task_registry.get(task_type)
            instance = task_cls(name=task_name, **task_kwargs)
            self._tasks[task_name] = instance

            # Launch immediately only when the scheduler is running and the
            # task itself is enabled; otherwise it sits idle until started.
            if self._running and instance.enabled:
                await instance.start()
                logger.info(
                    "task_added_and_started",
                    task_name=task_name,
                    task_type=task_type,
                )
            else:
                logger.info(
                    "task_added_not_started",
                    task_name=task_name,
                    task_type=task_type,
                    scheduler_running=self._running,
                    task_enabled=instance.enabled,
                )

        except Exception as e:
            # Roll back a partial registration.
            self._tasks.pop(task_name, None)

            logger.error(
                "task_add_failed",
                task_name=task_name,
                task_type=task_type,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise SchedulerError(f"Failed to add task '{task_name}': {e}") from e

    async def remove_task(self, task_name: str) -> None:
        """
        Remove and stop a task.

        Args:
            task_name: Name of task to remove

        Raises:
            TaskNotFoundError: If task does not exist
        """
        task = self._tasks.get(task_name)
        if task is None:
            raise TaskNotFoundError(f"Task '{task_name}' does not exist")

        try:
            if task.is_running:
                await task.stop()

            self._tasks.pop(task_name)
            logger.info("task_removed", task_name=task_name)

        except Exception as e:
            logger.error(
                "task_remove_failed",
                task_name=task_name,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise SchedulerError(f"Failed to remove task '{task_name}': {e}") from e

    def get_task(self, task_name: str) -> BaseScheduledTask:
        """
        Get a task instance by name.

        Args:
            task_name: Name of task to retrieve

        Returns:
            Task instance

        Raises:
            TaskNotFoundError: If task does not exist
        """
        task = self._tasks.get(task_name)
        if task is None:
            raise TaskNotFoundError(f"Task '{task_name}' does not exist")
        return task

    def list_tasks(self) -> list[str]:
        """
        Get list of all task names.

        Returns:
            List of task names
        """
        return list(self._tasks)

    def get_task_status(self, task_name: str) -> dict[str, Any]:
        """
        Get status information for a specific task.

        Args:
            task_name: Name of task

        Returns:
            Task status dictionary

        Raises:
            TaskNotFoundError: If task does not exist
        """
        task = self._tasks.get(task_name)
        if task is None:
            raise TaskNotFoundError(f"Task '{task_name}' does not exist")
        return task.get_status()

    def get_scheduler_status(self) -> dict[str, Any]:
        """
        Get overall scheduler status information.

        Returns:
            Scheduler status dictionary
        """
        active = [name for name, task in self._tasks.items() if task.is_running]

        return {
            "running": self._running,
            "total_tasks": len(self._tasks),
            "running_tasks": len(active),
            "max_concurrent_tasks": self.max_concurrent_tasks,
            "graceful_shutdown_timeout": self.graceful_shutdown_timeout,
            "task_names": list(self._tasks),
            "running_task_names": active,
            "registered_task_types": self.task_registry.list_tasks(),
        }

    @property
    def is_running(self) -> bool:
        """Whether the scheduler has been started and not yet stopped."""
        return self._running

    @property
    def task_count(self) -> int:
        """Number of tasks currently managed by this scheduler."""
        return len(self._tasks)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
# Global scheduler instance
|
|
305
|
+
_global_scheduler: Scheduler | None = None
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
async def get_scheduler() -> Scheduler:
    """
    Get or create the global scheduler instance.

    Returns:
        Global Scheduler instance
    """
    global _global_scheduler

    scheduler = _global_scheduler
    if scheduler is None:
        # Lazily create the singleton on first use.
        scheduler = Scheduler()
        _global_scheduler = scheduler

    return scheduler
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
async def start_scheduler() -> None:
    """Start the global scheduler, creating it on first use."""
    # Resolve (or lazily create) the singleton, then start it.
    await (await get_scheduler()).start()
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
async def stop_scheduler() -> None:
    """Stop and discard the global scheduler, if one exists."""
    global _global_scheduler

    scheduler = _global_scheduler
    if scheduler is not None:
        await scheduler.stop()
        # Drop the reference so a later start creates a fresh instance.
        _global_scheduler = None
|