nui-python-shared-utils 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nui_lambda_shared_utils/__init__.py +252 -0
- nui_lambda_shared_utils/base_client.py +323 -0
- nui_lambda_shared_utils/cli.py +225 -0
- nui_lambda_shared_utils/cloudwatch_metrics.py +367 -0
- nui_lambda_shared_utils/config.py +136 -0
- nui_lambda_shared_utils/db_client.py +623 -0
- nui_lambda_shared_utils/error_handler.py +372 -0
- nui_lambda_shared_utils/es_client.py +460 -0
- nui_lambda_shared_utils/es_query_builder.py +315 -0
- nui_lambda_shared_utils/jwt_auth.py +277 -0
- nui_lambda_shared_utils/lambda_helpers.py +84 -0
- nui_lambda_shared_utils/log_processors.py +172 -0
- nui_lambda_shared_utils/powertools_helpers.py +263 -0
- nui_lambda_shared_utils/secrets_helper.py +187 -0
- nui_lambda_shared_utils/slack_client.py +675 -0
- nui_lambda_shared_utils/slack_formatter.py +307 -0
- nui_lambda_shared_utils/slack_setup/__init__.py +14 -0
- nui_lambda_shared_utils/slack_setup/channel_creator.py +295 -0
- nui_lambda_shared_utils/slack_setup/channel_definitions.py +187 -0
- nui_lambda_shared_utils/slack_setup/setup_helpers.py +211 -0
- nui_lambda_shared_utils/timezone.py +117 -0
- nui_lambda_shared_utils/utils.py +291 -0
- nui_python_shared_utils-1.3.0.dist-info/METADATA +470 -0
- nui_python_shared_utils-1.3.0.dist-info/RECORD +28 -0
- nui_python_shared_utils-1.3.0.dist-info/WHEEL +5 -0
- nui_python_shared_utils-1.3.0.dist-info/entry_points.txt +2 -0
- nui_python_shared_utils-1.3.0.dist-info/licenses/LICENSE +21 -0
- nui_python_shared_utils-1.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Elasticsearch query builder utilities for consistent query patterns across Lambda services.
|
|
3
|
+
Provides helper functions for building common ES queries used in application monitoring.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Dict, List, Optional, Union
|
|
7
|
+
from datetime import datetime, timedelta
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ESQueryBuilder:
    """
    Builder for creating Elasticsearch queries with common patterns.

    All ``with_*`` / ``add_*`` methods mutate the builder in place and return
    ``self``, so calls can be chained fluently.

    Example:
        builder = ESQueryBuilder()
        query = (builder
            .with_time_range(start_time, end_time)
            .with_term('environment', 'prod')
            .with_service('order-service')
            .add_aggregation('error_rate', 'avg', 'response.status')
            .build()
        )
    """

    def __init__(self):
        # bool clauses are accumulated here and emitted verbatim by build()
        self.query = {"bool": {"must": [], "must_not": [], "should": [], "filter": []}}
        self.aggregations = {}
        self.size = 0  # Default to aggregation-only queries (no hits returned)
        self.sort = []

    def with_time_range(self, start: datetime, end: datetime, field: str = "@timestamp") -> "ESQueryBuilder":
        """Add inclusive time range filter.

        Accepts datetimes (serialized via isoformat) or pre-formatted strings.
        """
        self.query["bool"]["must"].append(
            {
                "range": {
                    field: {
                        "gte": start.isoformat() if hasattr(start, "isoformat") else start,
                        "lte": end.isoformat() if hasattr(end, "isoformat") else end,
                    }
                }
            }
        )
        return self

    def with_term(self, field: str, value: Union[str, int, bool]) -> "ESQueryBuilder":
        """Add exact term match."""
        self.query["bool"]["must"].append({"term": {field: value}})
        return self

    def with_terms(self, field: str, values: List[Union[str, int]]) -> "ESQueryBuilder":
        """Add terms match (OR condition for multiple values)."""
        self.query["bool"]["must"].append({"terms": {field: values}})
        return self

    def with_service(self, service: str) -> "ESQueryBuilder":
        """Add service name filter (exact match on ``service_name``)."""
        return self.with_term("service_name", service)

    def with_environment(self, env: str = "prod") -> "ESQueryBuilder":
        """Add environment filter (exact match on ``environment``)."""
        return self.with_term("environment", env)

    def with_error_filter(self, min_status: int = 400) -> "ESQueryBuilder":
        """Add filter for error responses (``response.status >= min_status``)."""
        self.query["bool"]["must"].append({"range": {"response.status": {"gte": min_status}}})
        return self

    def exclude_pattern(self, field: str, pattern: str) -> "ESQueryBuilder":
        """Exclude documents matching a wildcard pattern."""
        self.query["bool"]["must_not"].append({"wildcard": {field: pattern}})
        return self

    def with_prefix(self, field: str, prefix: str) -> "ESQueryBuilder":
        """Add prefix match."""
        self.query["bool"]["must"].append({"prefix": {field: prefix}})
        return self

    def add_aggregation(self, name: str, agg_type: str, field: str, **kwargs) -> "ESQueryBuilder":
        """Add a simple single-field aggregation (e.g. avg, max, cardinality)."""
        self.aggregations[name] = {agg_type: {"field": field, **kwargs}}
        return self

    def add_date_histogram(
        self, name: str, field: str = "@timestamp", interval: str = "5m", **kwargs
    ) -> "ESQueryBuilder":
        """Add date histogram aggregation (keeps empty buckets via min_doc_count=0)."""
        self.aggregations[name] = {
            "date_histogram": {"field": field, "fixed_interval": interval, "min_doc_count": 0, **kwargs}
        }
        return self

    def add_terms_aggregation(self, name: str, field: str, size: int = 10, **kwargs) -> "ESQueryBuilder":
        """Add terms aggregation for top values."""
        self.aggregations[name] = {"terms": {"field": field, "size": size, **kwargs}}
        return self

    def add_percentiles(self, name: str, field: str, percents: Optional[List[float]] = None) -> "ESQueryBuilder":
        """Add percentiles aggregation (defaults to p50/p95/p99)."""
        if percents is None:
            percents = [50, 95, 99]

        self.aggregations[name] = {"percentiles": {"field": field, "percents": percents}}
        return self

    def add_nested_aggregation(self, parent_name: str, child_aggs: Dict) -> "ESQueryBuilder":
        """Attach sub-aggregations to an already-registered parent aggregation.

        Raises:
            ValueError: If ``parent_name`` has not been added yet.
        """
        if parent_name not in self.aggregations:
            raise ValueError(f"Parent aggregation '{parent_name}' not found")

        self.aggregations[parent_name]["aggs"] = child_aggs
        return self

    def with_size(self, size: int) -> "ESQueryBuilder":
        """Set number of documents to return."""
        self.size = size
        return self

    def add_sort(self, field: str, order: str = "desc") -> "ESQueryBuilder":
        """Add sort criteria."""
        self.sort.append({field: {"order": order}})
        return self

    def build(self) -> Dict:
        """Build the final request body ready to pass to the ES search API."""
        query = {"query": self.query, "size": self.size}

        if self.aggregations:
            query["aggs"] = self.aggregations

        if self.sort:
            query["sort"] = self.sort

        return query
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Pre-built query templates for common patterns
|
|
137
|
+
def build_error_rate_query(service: str, start_time: datetime, end_time: datetime, interval: str = "5m") -> Dict:
    """Build query for service error rate over time."""
    # Per-bucket sub-aggregations: total vs. error counts, and a derived rate.
    per_bucket_aggs = {
        "total_requests": {"value_count": {"field": "request.id"}},
        "error_requests": {
            "filter": {"range": {"response.status": {"gte": 400}}},
            "aggs": {"count": {"value_count": {"field": "response.status"}}},
        },
        "error_rate": {
            "bucket_script": {
                "buckets_path": {"errors": "error_requests>count", "total": "total_requests"},
                "script": "params.errors / params.total * 100",
            }
        },
    }

    qb = ESQueryBuilder()
    qb.with_time_range(start_time, end_time)
    qb.with_service(service)
    qb.with_environment("prod")
    qb.add_date_histogram("error_timeline", interval=interval)
    qb.add_nested_aggregation("error_timeline", per_bucket_aggs)
    return qb.build()
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def build_top_errors_query(service: str, start_time: datetime, end_time: datetime, top_n: int = 10) -> Dict:
    """Build query for top error messages."""
    # Each top error bucket carries its status-code breakdown plus one sample doc.
    per_error_aggs = {
        "status_codes": {"terms": {"field": "response.status", "size": 5}},
        "sample_error": {"top_hits": {"size": 1, "_source": ["error", "request.path", "@timestamp"]}},
    }

    qb = ESQueryBuilder()
    qb.with_time_range(start_time, end_time)
    qb.with_service(service)
    qb.with_environment("prod")
    qb.with_error_filter()
    qb.add_terms_aggregation("top_errors", "error.keyword", size=top_n)
    qb.add_nested_aggregation("top_errors", per_error_aggs)
    return qb.build()
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def build_response_time_query(service: str, start_time: datetime, end_time: datetime) -> Dict:
    """Build query for response time metrics."""
    # Per-bucket latency stats for the timeline histogram.
    timeline_aggs = {
        "avg_time": {"avg": {"field": "response.time"}},
        "p95_time": {"percentiles": {"field": "response.time", "percents": [95]}},
    }

    qb = ESQueryBuilder()
    qb.with_time_range(start_time, end_time)
    qb.with_service(service)
    qb.with_environment("prod")
    qb.add_percentiles("response_times", "response.time", [50, 90, 95, 99])
    qb.add_aggregation("avg_response_time", "avg", "response.time")
    qb.add_aggregation("max_response_time", "max", "response.time")
    qb.add_date_histogram("response_timeline", interval="5m")
    qb.add_nested_aggregation("response_timeline", timeline_aggs)
    return qb.build()
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def build_service_volume_query(services: List[str], start_time: datetime, end_time: datetime) -> Dict:
    """Build query for request volume across multiple services."""
    # Per-service breakdown: request volume, error volume, average latency.
    per_service_aggs = {
        "request_count": {"value_count": {"field": "request.id"}},
        "error_count": {
            "filter": {"range": {"response.status": {"gte": 400}}},
            "aggs": {"count": {"value_count": {"field": "response.status"}}},
        },
        "avg_response_time": {"avg": {"field": "response.time"}},
    }

    qb = ESQueryBuilder()
    qb.with_time_range(start_time, end_time)
    qb.with_terms("service_name", services)
    qb.with_environment("prod")
    qb.add_terms_aggregation("service_breakdown", "service_name", size=20)
    qb.add_nested_aggregation("service_breakdown", per_service_aggs)
    return qb.build()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def build_user_activity_query(start_time: datetime, end_time: datetime, user_field: str = "user.id") -> Dict:
    """Build query for user activity metrics."""
    # Per-user detail for the top_users buckets.
    per_user_aggs = {
        "request_count": {"value_count": {"field": "request.id"}},
        "services_used": {"cardinality": {"field": "service_name"}},
        "last_activity": {"max": {"field": "@timestamp"}},
    }

    qb = ESQueryBuilder()
    qb.with_time_range(start_time, end_time)
    qb.with_environment("prod")
    qb.add_aggregation("unique_users", "cardinality", user_field)
    qb.add_terms_aggregation("top_users", user_field, size=20)
    qb.add_nested_aggregation("top_users", per_user_aggs)
    qb.add_date_histogram("user_timeline", interval="1h")
    qb.add_nested_aggregation("user_timeline", {"active_users": {"cardinality": {"field": user_field}}})
    return qb.build()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def build_pattern_detection_query(
    pattern: str, field: str = "message", start_time: Optional[datetime] = None, hours_back: int = 24
) -> Dict:
    """Build query to detect specific patterns in logs."""
    # Default lookback window ends now; explicit start_time overrides the start only.
    if not start_time:
        start_time = datetime.utcnow() - timedelta(hours=hours_back)
    end_time = datetime.utcnow()

    must_clauses = [
        {"wildcard": {field: f"*{pattern}*"}},
        {"range": {"@timestamp": {"gte": start_time.isoformat(), "lte": end_time.isoformat()}}},
    ]
    aggs = {
        "service_breakdown": {"terms": {"field": "service_name", "size": 10}},
        "timeline": {"date_histogram": {"field": "@timestamp", "fixed_interval": "1h", "min_doc_count": 1}},
    }

    return {
        "query": {"bool": {"must": must_clauses}},
        "size": 100,
        "sort": [{"@timestamp": {"order": "desc"}}],
        "aggs": aggs,
    }
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def build_tender_participant_query(tender_id: str, start_time: datetime, end_time: datetime) -> Dict:
    """Build query for tender participant analysis."""
    # Widen the window to capture pre-tender activity and late bids.
    window_start = start_time - timedelta(hours=2)
    window_end = end_time + timedelta(minutes=30)

    per_participant_aggs = {
        "participant_details": {"top_hits": {"size": 1, "_source": ["division.name", "company.name"]}},
        "bid_count": {"value_count": {"field": "created"}},
        "first_bid": {"min": {"field": "created"}},
        "last_bid": {"max": {"field": "created"}},
        "bid_values": {"stats": {"field": "price"}},
    }

    return {
        "query": {
            "bool": {
                "must": [
                    {"term": {"tender.id": tender_id}},
                    {"range": {"created": {"gte": window_start.isoformat(), "lte": window_end.isoformat()}}},
                ]
            }
        },
        "size": 0,
        "aggs": {
            "participants": {
                "terms": {"field": "division.id", "size": 100},
                "aggs": per_participant_aggs,
            },
            "total_bids": {"value_count": {"field": "created"}},
            "bid_timeline": {
                "date_histogram": {
                    "field": "created",
                    "fixed_interval": "1m",
                    "min_doc_count": 0,
                    # Bounds use the ORIGINAL tender window, not the expanded one.
                    "extended_bounds": {"min": start_time.isoformat(), "max": end_time.isoformat()},
                }
            },
        },
    }
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JWT validation utilities for AWS Lambda functions behind API Gateway.
|
|
3
|
+
|
|
4
|
+
Uses RS256 signature verification with public keys stored in AWS Secrets Manager.
|
|
5
|
+
Requires the `rsa` package (pure Python, ~100KB) — no PyJWT or cryptography needed at runtime.
|
|
6
|
+
|
|
7
|
+
Install: pip install nui-python-shared-utils[jwt]
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import json
|
|
12
|
+
import re
|
|
13
|
+
import base64
|
|
14
|
+
import time
|
|
15
|
+
import logging
|
|
16
|
+
from typing import TYPE_CHECKING, AbstractSet, Any, Dict, Optional, Tuple
|
|
17
|
+
from urllib.parse import unquote
|
|
18
|
+
|
|
19
|
+
from .secrets_helper import get_secret
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
import rsa
|
|
23
|
+
|
|
24
|
+
log = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
JWT_CLOCK_SKEW_SECONDS = 30
|
|
27
|
+
"""Tolerance in seconds for clock differences between token issuer and validator."""
|
|
28
|
+
|
|
29
|
+
_rsa = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _require_rsa():
    """Lazy-import the rsa package, raising a clear error if not installed."""
    global _rsa
    if _rsa is not None:
        return _rsa
    try:
        import rsa as _rsa_module
    except ImportError:
        raise ImportError("The 'rsa' package is required for JWT support. Install with: pip install nui-python-shared-utils[jwt]") from None
    # Cache at module level so subsequent calls skip the import machinery.
    _rsa = _rsa_module
    return _rsa
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class JWTValidationError(Exception):
    """Base exception for JWT validation failures."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class AuthenticationError(JWTValidationError):
    """Authentication failed — missing/invalid token or header."""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_jwt_public_key(secret_name: Optional[str] = None, key_field: str = "TOKEN_PUBLIC_KEY"):
    """
    Fetch a PEM public key from AWS Secrets Manager as an rsa.PublicKey.

    Caching is delegated to secrets_helper — calling this repeatedly with
    the same secret_name does not hit the Secrets Manager API again.

    Args:
        secret_name: Secrets Manager secret name. Falls back to JWT_PUBLIC_KEY_SECRET env var.
        key_field: JSON field containing the PEM-encoded public key.

    Returns:
        rsa.PublicKey ready for signature verification.

    Raises:
        JWTValidationError: If secret or key field is missing/invalid.
    """
    rsa = _require_rsa()

    secret_name = secret_name or os.environ.get("JWT_PUBLIC_KEY_SECRET")
    if not secret_name:
        raise JWTValidationError("No JWT public key secret name provided (set JWT_PUBLIC_KEY_SECRET or pass secret_name)")

    pem_str = get_secret(secret_name).get(key_field)
    if not pem_str:
        raise JWTValidationError(f"Field '{key_field}' not found in secret '{secret_name}'")

    try:
        return rsa.PublicKey.load_pkcs1_openssl_pem(pem_str.encode("utf-8"))
    except Exception as e:
        raise JWTValidationError(
            f"Failed to load public key from '{key_field}' — expected PKCS#8 PEM format (BEGIN PUBLIC KEY): {e}"
        ) from e
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _base64url_decode(data: str) -> bytes:
|
|
95
|
+
"""Decode base64url-encoded string (no padding required)."""
|
|
96
|
+
padding = 4 - len(data) % 4
|
|
97
|
+
if padding != 4:
|
|
98
|
+
data += "=" * padding
|
|
99
|
+
return base64.urlsafe_b64decode(data)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def validate_jwt(token: str, public_key: "rsa.PublicKey") -> dict:
    """
    Decode and verify an RS256-signed JWT.

    Verifies:
    - Token structure (3 dot-separated segments)
    - Algorithm is RS256
    - RSA signature using the provided public key
    - Expiration (exp claim, if present) with clock skew tolerance
    - Not-before (nbf claim, if present) with clock skew tolerance

    Args:
        token: Raw JWT string (without "Bearer " prefix).
        public_key: RSA public key for signature verification.

    Returns:
        Decoded claims dict (the JWT payload).

    Raises:
        JWTValidationError: On any structural, signature, or expiration failure.
    """
    rsa = _require_rsa()

    # Split into segments: header.payload.signature
    parts = token.split(".")
    if len(parts) != 3:
        raise JWTValidationError(f"Malformed JWT: expected 3 segments, got {len(parts)}")

    header_b64, payload_b64, signature_b64 = parts

    # Decode header and check algorithm. The alg check happens BEFORE any
    # verification so tokens claiming 'none' or HS256 are rejected outright
    # (guards against algorithm-confusion attacks).
    try:
        header = json.loads(_base64url_decode(header_b64))
    except Exception as e:
        raise JWTValidationError(f"Invalid JWT header: {e}") from e

    alg = header.get("alg")
    if alg != "RS256":
        raise JWTValidationError(f"Unsupported algorithm '{alg}' — only RS256 is accepted")

    # Decode signature
    try:
        signature = _base64url_decode(signature_b64)
    except Exception as e:
        raise JWTValidationError(f"Invalid JWT signature encoding: {e}") from e

    # Verify RS256 signature over "<header_b64>.<payload_b64>" — the signed
    # message is the raw base64url text, per RFC 7515, so it encodes as ASCII.
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    try:
        hash_method = rsa.verify(signing_input, signature, public_key)
    except rsa.VerificationError:
        # 'from None' deliberately hides the rsa internals from the traceback.
        raise JWTValidationError("JWT signature verification failed") from None

    # rsa.verify returns the hash name it detected; RS256 requires SHA-256.
    if hash_method != "SHA-256":
        raise JWTValidationError(f"Unexpected hash method '{hash_method}' — expected SHA-256 for RS256")

    # Decode payload only AFTER the signature has been verified, so claims
    # are never read from an unauthenticated token.
    try:
        claims = json.loads(_base64url_decode(payload_b64))
    except Exception as e:
        raise JWTValidationError(f"Invalid JWT payload: {e}") from e

    # Check expiration (with clock skew tolerance for distributed systems)
    now = time.time()
    exp = claims.get("exp")
    if exp is not None and now > exp + JWT_CLOCK_SKEW_SECONDS:
        raise JWTValidationError("JWT has expired")

    # Check not-before: token is rejected if used before nbf (minus skew)
    nbf = claims.get("nbf")
    if nbf is not None and now < nbf - JWT_CLOCK_SKEW_SECONDS:
        raise JWTValidationError("JWT is not yet valid (nbf)")

    return claims
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _extract_header(event: dict, header_name: str) -> Optional[str]:
|
|
179
|
+
"""Extract a header value from an API Gateway event (v1 or v2), case-insensitive."""
|
|
180
|
+
headers = event.get("headers") or {}
|
|
181
|
+
# API Gateway v2 (HTTP API) lowercases all header names
|
|
182
|
+
# API Gateway v1 (REST) preserves original case
|
|
183
|
+
for key, value in headers.items():
|
|
184
|
+
if key.lower() == header_name.lower():
|
|
185
|
+
return value
|
|
186
|
+
return None
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def require_auth(event: dict, secret_name: Optional[str] = None) -> dict:
    """
    Extract and validate a Bearer token from an API Gateway event.

    Supports both API Gateway v1 (REST) and v2 (HTTP API) event shapes.

    Args:
        event: API Gateway Lambda proxy integration event.
        secret_name: Optional Secrets Manager secret name for the public key.

    Returns:
        Decoded JWT claims dict.

    Raises:
        AuthenticationError: If the Authorization header is missing, malformed, or token is invalid.
    """
    auth_header = _extract_header(event, "Authorization")

    # Guard clauses: reject missing/malformed credentials before any crypto work.
    if not auth_header:
        raise AuthenticationError("Missing Authorization header")
    if not auth_header.startswith("Bearer "):
        raise AuthenticationError("Authorization header must use Bearer scheme")

    token = auth_header[len("Bearer "):]
    if not token:
        raise AuthenticationError("Empty Bearer token")

    try:
        return validate_jwt(token, get_jwt_public_key(secret_name=secret_name))
    except JWTValidationError as e:
        # Surface every validation failure as a single auth-domain error.
        raise AuthenticationError(f"Authentication failed: {e}") from e
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _normalize_path(path: str) -> str:
|
|
224
|
+
"""Normalize a URL path for safe comparison.
|
|
225
|
+
|
|
226
|
+
URL-decodes, collapses duplicate slashes, ensures a single leading slash,
|
|
227
|
+
and strips any trailing slash (except for root "/").
|
|
228
|
+
"""
|
|
229
|
+
path = unquote(path)
|
|
230
|
+
path = re.sub(r"/+", "/", path)
|
|
231
|
+
return "/" + path.strip("/") if path.strip("/") else "/"
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def check_auth(
    event: dict,
    public_paths: AbstractSet[str] = frozenset(),
    secret_name: Optional[str] = None,
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Check JWT authentication on an API Gateway event, skipping public paths.

    One-call wrapper combining path normalization, public-path bypass,
    JWT validation, and a standard JSON:API 401 error response.

    Args:
        event: API Gateway Lambda proxy integration event.
        public_paths: Set of normalized paths that skip auth (e.g. {"/health"}).
        secret_name: Optional Secrets Manager secret name for the public key.

    Returns:
        (claims, None) on success — claims is the decoded JWT dict,
            or None if the path is public.
        (None, response) on auth failure — response is a 401 dict
            ready to return from your Lambda handler.
    """
    # "path" is the v1 (REST) key; "rawPath" is the v2 (HTTP API) key.
    request_path = event.get("path") or event.get("rawPath") or ""
    if _normalize_path(request_path) in public_paths:
        return None, None

    try:
        return require_auth(event, secret_name=secret_name), None
    except AuthenticationError as e:
        log.warning("Authentication failed: %s", e)

    error_body = json.dumps({
        "errors": [{
            "status": "401",
            "title": "Unauthorized",
            "detail": "Authentication required",
        }]
    })
    return None, {
        "statusCode": 401,
        "headers": {
            "Content-Type": "application/json",
            "Access-Control-Allow-Origin": "*",
        },
        "body": error_body,
    }
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lambda context helpers for extracting environment information.
|
|
3
|
+
|
|
4
|
+
Provides standardized environment info extraction for logging and metrics context.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from typing import Dict, Union
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
__all__ = ["get_lambda_environment_info"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_lambda_environment_info() -> Dict[str, Union[str, bool]]:
    """
    Extract standard Lambda environment info.

    Returns a dict of Lambda runtime details useful as logging context,
    CloudWatch metric dimensions, or for environment-conditional behavior.

    Detection logic:
    - `is_local`: True when AWS_LAMBDA_RUNTIME_API is absent (local dev or tests)
    - `environment`: From ENVIRONMENT / ENV / STAGE env vars, else "unknown";
      "production"/"prd" normalize to "prod", "development" to "dev".

    Returns:
        Dict with keys: environment, aws_region, function_name,
        function_version, memory_limit, is_local.

    Example:
        >>> get_lambda_environment_info()
        {'environment': 'prod', 'aws_region': 'ap-southeast-2',
         'function_name': 'my-lambda', 'function_version': '$LATEST',
         'memory_limit': '512', 'is_local': False}
    """
    # AWS_LAMBDA_RUNTIME_API is only injected by the Lambda runtime itself.
    running_locally = os.getenv("AWS_LAMBDA_RUNTIME_API") is None

    # First matching env var wins: ENVIRONMENT, then ENV, then STAGE.
    raw_env = os.getenv("ENVIRONMENT") or os.getenv("ENV") or os.getenv("STAGE") or "unknown"
    environment = raw_env.lower()

    # Fold common aliases onto the canonical short names.
    aliases = {"production": "prod", "prd": "prod", "development": "dev"}
    environment = aliases.get(environment, environment)

    return {
        "environment": environment,
        "aws_region": os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")),
        "function_name": os.getenv("AWS_LAMBDA_FUNCTION_NAME", ""),
        "function_version": os.getenv("AWS_LAMBDA_FUNCTION_VERSION", ""),
        "memory_limit": os.getenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", ""),
        "is_local": running_locally,
    }
|