github-ai-scraper 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_scraper/__init__.py +3 -0
- ai_scraper/api/__init__.py +6 -0
- ai_scraper/api/github.py +340 -0
- ai_scraper/api/gitlab.py +418 -0
- ai_scraper/api/rate_limiter.py +120 -0
- ai_scraper/api_server.py +196 -0
- ai_scraper/auth.py +68 -0
- ai_scraper/backup.py +112 -0
- ai_scraper/cache.py +95 -0
- ai_scraper/classifier.py +135 -0
- ai_scraper/cli.py +747 -0
- ai_scraper/config.py +237 -0
- ai_scraper/config_watcher.py +82 -0
- ai_scraper/dedup.py +148 -0
- ai_scraper/filters/__init__.py +5 -0
- ai_scraper/filters/ai_filter.py +93 -0
- ai_scraper/health.py +155 -0
- ai_scraper/i18n.py +141 -0
- ai_scraper/interactive.py +96 -0
- ai_scraper/keywords/__init__.py +5 -0
- ai_scraper/keywords/extractor.py +274 -0
- ai_scraper/logging_config.py +74 -0
- ai_scraper/models/__init__.py +5 -0
- ai_scraper/models/repository.py +72 -0
- ai_scraper/output/__init__.py +6 -0
- ai_scraper/output/excel.py +79 -0
- ai_scraper/output/html.py +152 -0
- ai_scraper/output/markdown.py +338 -0
- ai_scraper/output/rss.py +82 -0
- ai_scraper/output/translator.py +303 -0
- ai_scraper/plugin_system.py +146 -0
- ai_scraper/plugins/__init__.py +5 -0
- ai_scraper/retry.py +134 -0
- ai_scraper/scheduler.py +84 -0
- ai_scraper/scrape_progress.py +99 -0
- ai_scraper/secure_storage.py +127 -0
- ai_scraper/storage/__init__.py +5 -0
- ai_scraper/storage/async_database.py +237 -0
- ai_scraper/storage/database.py +456 -0
- ai_scraper/webhooks.py +95 -0
- github_ai_scraper-0.1.2.dist-info/METADATA +299 -0
- github_ai_scraper-0.1.2.dist-info/RECORD +44 -0
- github_ai_scraper-0.1.2.dist-info/WHEEL +4 -0
- github_ai_scraper-0.1.2.dist-info/entry_points.txt +2 -0
ai_scraper/api/gitlab.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""GitLab API client."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import aiohttp
|
|
10
|
+
|
|
11
|
+
from ai_scraper.api.rate_limiter import RateLimitInfo, RateLimiter
|
|
12
|
+
from ai_scraper.cache import RequestCache
|
|
13
|
+
from ai_scraper.models.repository import Repository
|
|
14
|
+
|
|
15
|
+
# Module-level logger; used by the client for cache-hit, rate-limit wait and
# page-failure messages.
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GitLabAPIError(Exception):
    """Raised when the GitLab API answers with an error status.

    Attributes:
        status: HTTP status code returned by GitLab.
        message: Human-readable explanation (or the raw response body).
    """

    def __init__(self, status: int, message: str):
        super().__init__("GitLab API error {}: {}".format(status, message))
        self.status, self.message = status, message
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GitLabClient:
    """Asynchronous GitLab API client.

    Supports both gitlab.com and self-hosted GitLab instances.
    """

    DEFAULT_BASE_URL = "https://gitlab.com/api/v4"

    # Activity look-back window (in days) for each get_trending_projects() period.
    _TRENDING_WINDOWS = {"daily": 1, "weekly": 7, "monthly": 30}

    def __init__(
        self,
        token: Optional[str] = None,
        base_url: Optional[str] = None,
        cache_dir: Optional[Path] = None,
        cache_ttl: int = 3600,
        connection_pool_size: int = 10,
    ):
        """Initialize GitLab client.

        Args:
            token: GitLab Personal Access Token (optional).
            base_url: GitLab API base URL (default: gitlab.com).
                For self-hosted: "https://your-gitlab.com/api/v4"
            cache_dir: Directory for cache files (optional).
            cache_ttl: Cache time-to-live in seconds.
            connection_pool_size: Maximum number of connections in pool.
        """
        self.token = token
        # Strip a trailing slash so f"{base_url}{endpoint}" never produces "//".
        self.base_url = (base_url or self.DEFAULT_BASE_URL).rstrip("/")
        self.session: Optional[aiohttp.ClientSession] = None
        self.connection_pool_size = connection_pool_size

        # Client-side throttle. The per-minute budgets assumed here are
        # 2000 (authenticated) and 100 (anonymous); RateLimiter takes an
        # hourly budget, hence the * 60.
        rate = 2000 if token else 100
        self.rate_limiter = RateLimiter(requests_per_hour=rate * 60)

        # Response cache for GET requests; only created when cache_dir is given.
        self.cache: Optional[RequestCache] = None
        if cache_dir:
            self.cache = RequestCache(cache_dir=cache_dir, ttl=cache_ttl)

    @staticmethod
    def _encode_id(identifier: int | str) -> str:
        """Return a project/group identifier ready to embed in a URL path.

        GitLab accepts a numeric ID or a namespaced path such as
        "group/project"; the path form must be URL-encoded
        ("group%2Fproject"). Values without a "/" (numeric IDs, or paths the
        caller already encoded) are passed through unchanged.
        """
        from urllib.parse import quote

        text = str(identifier)
        return quote(text, safe="") if "/" in text else text

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session with connection pooling."""
        if self.session is None or self.session.closed:
            headers = {"Content-Type": "application/json"}
            if self.token:
                headers["PRIVATE-TOKEN"] = self.token

            # Configure connection pool
            connector = aiohttp.TCPConnector(
                limit=self.connection_pool_size,
                limit_per_host=self.connection_pool_size,
                enable_cleanup_closed=True,
            )

            self.session = aiohttp.ClientSession(
                headers=headers,
                connector=connector,
                timeout=aiohttp.ClientTimeout(total=30),
            )
        return self.session

    async def close(self) -> None:
        """Close the HTTP session (no-op if it was never opened or is closed)."""
        if self.session and not self.session.closed:
            await self.session.close()

    async def _request(
        self,
        endpoint: str,
        params: Optional[dict] = None,
        method: str = "GET",
    ) -> dict | list:
        """Make an API request.

        GET responses are looked up in, and stored into, ``self.cache`` when a
        cache is configured. Requests that reach the network first wait for a
        token from ``self.rate_limiter``.

        Args:
            endpoint: API endpoint (without base URL).
            params: Query parameters.
            method: HTTP method.

        Returns:
            Decoded JSON response data (object or array).

        Raises:
            GitLabAPIError: On any HTTP status >= 400.
        """
        url = f"{self.base_url}{endpoint}"

        # Check cache first (only for GET requests)
        if method == "GET" and self.cache:
            cached = self.cache.get(url, params)
            if cached is not None:
                logger.debug(f"Cache hit for {endpoint}")
                return cached

        # Wait for the rate limiter, sleeping at most 1s between re-checks.
        while not self.rate_limiter.try_acquire():
            wait_time = self.rate_limiter.wait_time()
            logger.debug(f"Rate limited, waiting {wait_time:.1f}s")
            await asyncio.sleep(min(wait_time, 1.0))

        session = await self._get_session()

        async with session.request(method, url, params=params) as response:
            if response.status == 401:
                raise GitLabAPIError(401, "Unauthorized - check your token")
            elif response.status == 403:
                raise GitLabAPIError(403, "Forbidden - check your permissions")
            elif response.status == 404:
                raise GitLabAPIError(404, "Resource not found")
            elif response.status == 429:
                # Rate limited
                retry_after = response.headers.get("Retry-After", "60")
                raise GitLabAPIError(429, f"Rate limited, retry after {retry_after}s")
            elif response.status >= 400:
                text = await response.text()
                raise GitLabAPIError(response.status, text)

            data = await response.json()

            # Cache successful GET response
            if method == "GET" and self.cache:
                self.cache.set(url, params, data)
                logger.debug(f"Cached response for {endpoint}")

            return data

    async def search_projects(
        self,
        query: str,
        sort: str = "star_count",
        order: str = "desc",
        page: int = 1,
        per_page: int = 100,
        min_stars: int = 0,
        topics: Optional[list[str]] = None,
    ) -> list[Repository]:
        """Search projects.

        Args:
            query: Search query.
            sort: Sort field (star_count, name, created_at, updated_at).
            order: Sort order (asc, desc).
            page: Page number (1-indexed).
            per_page: Results per page (capped at 100).
            min_stars: Minimum star count. GitLab's /projects endpoint has no
                star-count filter, so this is applied client-side to the page
                returned by the API (a page may therefore hold fewer items).
            topics: List of topics to filter by (projects must match all).

        Returns:
            List of repositories.
        """
        params: dict = {
            "search": query,
            "order_by": sort,
            "sort": order,
            "page": page,
            "per_page": min(per_page, 100),
        }

        # Add topic filters
        if topics:
            params["topic"] = ",".join(topics)

        data = await self._request("/projects", params)

        if not isinstance(data, list):
            return []
        # `star_count` may be missing/None; treat that as zero stars.
        return [
            self._parse_project(item)
            for item in data
            if (item.get("star_count") or 0) >= min_stars
        ]

    async def search_projects_concurrent(
        self,
        query: str,
        max_pages: int = 5,
        per_page: int = 100,
        sort: str = "star_count",
        order: str = "desc",
        max_concurrent: int = 5,
        min_stars: int = 0,
        topics: Optional[list[str]] = None,
    ) -> list[Repository]:
        """Search projects concurrently across multiple pages.

        Pages 1..max_pages are fetched in parallel via ``search_projects``, so
        filtering (including ``min_stars``) is identical to single-page search.
        A page that fails is logged and contributes no results.

        Args:
            query: Search query.
            max_pages: Maximum number of pages to fetch.
            per_page: Results per page (max 100).
            sort: Sort field.
            order: Sort order (asc, desc).
            max_concurrent: Maximum concurrent requests (values < 1 are
                treated as 1; a zero-sized semaphore would never be acquired).
            min_stars: Minimum star count filter.
            topics: List of topics to filter by.

        Returns:
            List of repositories from all pages, in page order.
        """
        semaphore = asyncio.Semaphore(max(1, max_concurrent))

        async def fetch_page(page: int) -> list[Repository]:
            async with semaphore:
                try:
                    return await self.search_projects(
                        query,
                        sort=sort,
                        order=order,
                        page=page,
                        per_page=per_page,
                        min_stars=min_stars,
                        topics=topics,
                    )
                except GitLabAPIError as e:
                    logger.warning(f"Page {page} fetch failed: {e}")
                    return []

        # Create tasks for all pages
        tasks = [fetch_page(page) for page in range(1, max_pages + 1)]

        # Execute concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Flatten results, skipping failures. BaseException (not Exception) so
        # a CancelledError result is skipped instead of crashing extend().
        all_repos: list[Repository] = []
        for result in results:
            if isinstance(result, BaseException):
                logger.warning(f"Page fetch failed: {result}")
                continue
            all_repos.extend(result)

        return all_repos

    async def get_project(self, project_id: int | str) -> Repository:
        """Get a single project.

        Args:
            project_id: Project ID or path (e.g., "group/project"); a path is
                URL-encoded automatically, a pre-encoded path is also accepted.

        Returns:
            Repository data.
        """
        data = await self._request(f"/projects/{self._encode_id(project_id)}")
        return self._parse_project(data)

    async def get_project_languages(self, project_id: int | str) -> dict[str, float]:
        """Get languages used in a project.

        Args:
            project_id: Project ID or path (URL-encoded automatically).

        Returns:
            Dictionary of language -> percentage (empty if the response is
            not a JSON object).
        """
        data = await self._request(
            f"/projects/{self._encode_id(project_id)}/languages"
        )
        return data if isinstance(data, dict) else {}

    async def get_trending_projects(
        self,
        since: str = "weekly",
        page: int = 1,
        per_page: int = 100,
    ) -> list[Repository]:
        """Get trending projects.

        GitLab has no trending endpoint, so this approximates one: the
        most-starred projects with activity inside the requested window.

        Args:
            since: Time period (daily, weekly, monthly). Any other value
                disables the activity filter (all-time most-starred).
            page: Page number.
            per_page: Results per page (capped at 100).

        Returns:
            List of trending repositories.
        """
        from datetime import timedelta, timezone

        params: dict = {
            "order_by": "star_count",
            "sort": "desc",
            "page": page,
            "per_page": min(per_page, 100),
        }

        days = self._TRENDING_WINDOWS.get(since)
        if days is not None:
            # Round the cutoff down to the hour so repeated calls within the
            # same hour send identical params and can be served from cache.
            now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
            cutoff = now - timedelta(days=days)
            params["last_activity_after"] = cutoff.strftime("%Y-%m-%dT%H:%M:%SZ")
        else:
            logger.warning(f"Unknown trending period {since!r}; not filtering by activity")

        data = await self._request("/projects", params)

        if isinstance(data, list):
            return [self._parse_project(item) for item in data]
        return []

    async def get_group_projects(
        self,
        group_id: int | str,
        include_subgroups: bool = True,
        page: int = 1,
        per_page: int = 100,
    ) -> list[Repository]:
        """Get projects in a group.

        Args:
            group_id: Group ID or path (e.g. "group/subgroup"; URL-encoded
                automatically).
            include_subgroups: Include projects from subgroups.
            page: Page number.
            per_page: Results per page (capped at 100).

        Returns:
            List of repositories in the group.
        """
        params = {
            "page": page,
            "per_page": min(per_page, 100),
            # Sent as the literal strings "true"/"false".
            "include_subgroups": str(include_subgroups).lower(),
        }

        data = await self._request(
            f"/groups/{self._encode_id(group_id)}/projects", params
        )

        if isinstance(data, list):
            return [self._parse_project(item) for item in data]
        return []

    def _parse_project(self, data: dict) -> Repository:
        """Parse project data from API response.

        Args:
            data: API response data (one project object).

        Returns:
            Repository object.

        Raises:
            KeyError: If the project has no "id".
        """
        # GitLab uses different field names than GitHub
        return Repository(
            id=data["id"],
            name=data.get("path_with_namespace", data.get("name", "")),
            full_name=data.get("path_with_namespace", data.get("name", "")),
            description=data.get("description"),
            stars=data.get("star_count", 0),
            language=self._get_primary_language(data),
            # Newer GitLab uses "topics", older "tag_list"; never pass None.
            topics=data.get("topics") or data.get("tag_list") or [],
            created_at=self._parse_datetime(data.get("created_at")),
            # GitLab exposes a single activity timestamp; used for both fields.
            updated_at=self._parse_datetime(data.get("last_activity_at")),
            pushed_at=self._parse_datetime(data.get("last_activity_at")),
            url=data.get("web_url", ""),
            open_issues=data.get("open_issues_count"),
            forks=data.get("forks_count"),
        )

    def _get_primary_language(self, data: dict) -> Optional[str]:
        """Get primary language from project data.

        Args:
            data: Project data.

        Returns:
            The language with the highest share in a "languages" mapping if
            one is present, else the "repository_language" field, else None.
        """
        languages = data.get("languages")
        if isinstance(languages, dict) and languages:
            # Return the language with highest percentage
            return max(languages, key=languages.get)

        # Fallback to repository_language if available
        return data.get("repository_language")

    def _parse_datetime(self, value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 datetime string into a naive UTC datetime.

        Offset-aware inputs ("Z", "+00:00", "+05:30", ...) are converted to
        UTC and stripped of tzinfo so every returned value is naive and
        mutually comparable; inputs without an offset are returned as-is.

        Args:
            value: ISO datetime string.

        Returns:
            datetime object, or None for empty/unparseable input.
        """
        if not value:
            return None

        # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11.
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"

        try:
            parsed = datetime.fromisoformat(value)
        except ValueError:
            return None

        offset = parsed.utcoffset()
        if offset is None:
            return parsed
        return parsed.replace(tzinfo=None) - offset
|
|
ai_scraper/api/rate_limiter.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Rate limiter for GitHub API."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import threading
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from collections import deque
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass()
class RateLimitInfo:
    """Snapshot of the quota counters reported by the GitHub API.

    The three ``search_*`` fields are required; the ``core_*`` fields
    default to zero when the caller does not supply them.
    """

    # Search API bucket (required).
    search_limit: int
    search_remaining: int
    search_reset: int

    # Core API bucket (optional, zero by default).
    core_limit: int = 0
    core_remaining: int = 0
    core_reset: int = 0
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RateLimiter:
    """Token bucket rate limiter with adaptive control.

    The bucket holds at most ``effective_limit`` tokens (the hourly budget
    minus ``safety_margin``) and refills continuously at
    ``effective_limit / 3600`` tokens per second. ``set_rate`` changes the
    budget at runtime. Public methods serialize access through a re-entrant
    lock.
    """

    def __init__(self, requests_per_hour: int = 60, safety_margin: float = 0.1):
        """Initialize rate limiter.

        Args:
            requests_per_hour: Maximum requests per hour.
            safety_margin: Fraction of requests to reserve.
        """
        self.requests_per_hour = requests_per_hour
        self.safety_margin = safety_margin
        self.effective_limit = self._compute_effective_limit(
            requests_per_hour, safety_margin
        )
        self.refill_rate = self.effective_limit / 3600.0
        self.tokens = float(self.effective_limit)
        self.last_update = time.time()

        # Statistics
        self._total_requests = 0
        self._total_wait_time = 0.0
        self._request_history: deque = deque(maxlen=1000)
        self._lock = threading.RLock()

    @staticmethod
    def _compute_effective_limit(requests_per_hour: int, safety_margin: float) -> int:
        """Return the usable hourly budget, floored at one request.

        Without the floor, small budgets (e.g. ``requests_per_hour=1`` with
        the default 10% margin) truncate to 0: the bucket never refills and
        ``wait_time`` divides by zero.
        """
        return max(1, int(requests_per_hour * (1 - safety_margin)))

    def try_acquire(self) -> bool:
        """Try to acquire a token without blocking.

        Returns:
            True if a token was consumed, False if the bucket is empty.
        """
        with self._lock:
            self._refill()

            if self.tokens >= 1.0:
                self.tokens -= 1.0
                self._total_requests += 1
                self._request_history.append(time.time())
                return True

            return False

    def wait_time(self) -> float:
        """Get time (seconds) to wait for the next token; 0.0 if one is ready."""
        with self._lock:
            self._refill()
            if self.tokens >= 1.0:
                return 0.0
            return (1.0 - self.tokens) / self.refill_rate

    def acquire(self, timeout: Optional[float] = None) -> bool:
        """Acquire a token, blocking the calling thread if necessary.

        Args:
            timeout: Maximum time to wait (None = forever).

        Returns:
            True if token acquired, False if timeout.
        """
        # Monotonic clock: elapsed-time accounting must not be skewed by
        # wall-clock adjustments.
        start_time = time.monotonic()

        while True:
            if self.try_acquire():
                wait_duration = time.monotonic() - start_time
                with self._lock:
                    self._total_wait_time += wait_duration
                return True

            # Sleep in short slices so a refill is noticed promptly.
            wait = min(self.wait_time(), 0.1)

            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed + wait > timeout:
                    return False

            time.sleep(wait)

    def set_rate(self, requests_per_hour: int) -> None:
        """Update the rate limit dynamically.

        Existing tokens are kept but capped at the new effective limit.
        """
        with self._lock:
            self.requests_per_hour = requests_per_hour
            self.effective_limit = self._compute_effective_limit(
                requests_per_hour, self.safety_margin
            )
            self.refill_rate = self.effective_limit / 3600.0
            self.tokens = min(self.tokens, float(self.effective_limit))

    def get_stats(self) -> dict:
        """Get rate limiter statistics."""
        with self._lock:
            return {
                "total_requests": self._total_requests,
                "total_wait_time": self._total_wait_time,
                "current_tokens": self.tokens,
                "effective_limit": self.effective_limit,
                "requests_per_hour": self.requests_per_hour,
            }

    def _refill(self) -> None:
        """Refill tokens based on elapsed time (caller must hold the lock)."""
        now = time.time()
        # Clamp at zero: if the wall clock steps backwards, a negative delta
        # would otherwise drain the bucket below zero.
        elapsed = max(0.0, now - self.last_update)
        self.tokens = min(
            float(self.effective_limit),
            self.tokens + elapsed * self.refill_rate,
        )
        self.last_update = now
|