idds-common 2.0.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
idds/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Authors:
9
+ # - Wen Guan, <wen.guan@cern.ch>, 2019
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Authors:
9
+ # - Wen Guan, <wen.guan@cern.ch>, 2019
@@ -0,0 +1,610 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Authors:
9
+ # - Wen Guan, <wen.guan@cern.ch>, 2021 - 2023
10
+
11
+ import datetime
12
+ import base64
13
+ import json
14
+ import jwt
15
+ import os
16
+ import re
17
+ import requests
18
+ import time
19
+
20
+
21
+ try:
22
+ import ConfigParser
23
+ except ImportError:
24
+ import configparser as ConfigParser
25
+
26
+ try:
27
+ # Python 2
28
+ from urllib import urlencode
29
+ except ImportError:
30
+ # Python 3
31
+ from urllib.parse import urlencode
32
+ raw_input = input
33
+
34
+ # from cryptography import x509
35
+ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
36
+ from cryptography.hazmat.backends import default_backend
37
+ from cryptography.hazmat.primitives import serialization
38
+
39
+ # from idds.common import exceptions
40
+ from idds.common.constants import HTTP_STATUS_CODE
41
+
42
+
43
def decode_value(val):
    """Decode an unpadded base64url value into a big-endian unsigned integer.

    Used for the ``n`` (modulus) and ``e`` (exponent) members of an RSA JWK.
    Two ``=`` characters are always appended; this relies on the non-strict
    base64 decoder ignoring padding beyond what the data needs.

    :param val: base64url text, as ``str`` or ``bytes``.
    :returns: the decoded integer.
    """
    raw = val.encode() if isinstance(val, str) else val
    return int.from_bytes(base64.urlsafe_b64decode(raw + b'=='), byteorder='big')
48
+
49
+
50
def should_verify(no_verify=False, ssl_verify=None):
    """Compute the ``verify`` argument to hand to :mod:`requests`.

    Precedence, highest first:

    1. a truthy ``no_verify`` argument, or a non-empty ``IDDS_AUTH_NO_VERIFY``
       environment variable -> ``False``;
    2. a non-empty ``IDDS_AUTH_SSL_VERIFY`` environment variable -> its value;
    3. a truthy ``ssl_verify`` argument -> returned unchanged;
    4. otherwise ``True``.
    """
    env = os.environ
    if no_verify or env.get('IDDS_AUTH_NO_VERIFY', None):
        return False

    env_ssl_verify = env.get('IDDS_AUTH_SSL_VERIFY', None)
    if env_ssl_verify:
        return env_ssl_verify

    return ssl_verify if ssl_verify else True
63
+
64
+
65
class Singleton(object):
    """Base class that hands out one shared instance per class.

    The first instantiation creates the instance and stores it on the class
    as ``_instance``; later instantiations return that same object. Python
    still runs ``__init__`` on every call, so subclasses must tolerate being
    re-initialized. ``_initialized`` is set to ``False`` when the instance is
    first created, so subclasses can detect their first initialization.
    """
    _instance = None

    def __new__(class_, *args, **kwargs):
        # isinstance() rather than "is None": _instance is looked up through the
        # MRO, so a subclass must not reuse an instance inherited from a base class.
        if not isinstance(class_._instance, class_):
            # Constructor arguments are consumed by __init__. They must not be
            # forwarded to object.__new__, which raises TypeError on extra
            # arguments when __new__ is overridden (e.g. Cls(timeout=5) failed).
            class_._instance = object.__new__(class_)
            class_._instance._initialized = False
        return class_._instance
73
+
74
+
75
class BaseAuthentication(Singleton):
    """Configuration loading and an in-memory, time-limited cache shared by the auth classes.

    Settings come from an ``auth.cfg`` file (see :meth:`load_auth_server_config`).
    The ``[common]`` section may define ``max_expires_in``, ``cache_time``,
    ``allow_vos`` and ``ssl_verify``.
    """

    def __init__(self, timeout=None):
        """(Re)load the configuration and apply the ``[common]`` settings.

        :param timeout: request timeout stored on the instance for HTTP calls.
        """
        self.timeout = timeout
        self.config = self.load_auth_server_config()
        self.max_expires_in = 60
        self.cache_time = 3600 * 6

        # Singleton.__new__ returns the existing instance on every call and
        # Python re-runs __init__ on it. Only create the cache the first time;
        # resetting it on each instantiation discarded every cached entry.
        if not getattr(self, '_initialized', False) or not hasattr(self, 'cache'):
            self.cache = {}
        self._initialized = True

        if self.config and self.config.has_section('common'):
            if self.config.has_option('common', 'max_expires_in'):
                self.max_expires_in = self.config.getint('common', 'max_expires_in')
            if self.config.has_option('common', 'cache_time'):
                self.cache_time = self.config.getint('common', 'cache_time')

    def get_cache_value(self, key):
        """Return the value cached under ``key``, or None if absent or older than ``cache_time`` seconds."""
        entry = self.cache.get(key)
        if entry and entry['time'] + self.cache_time > time.time():
            return entry['value']
        return None

    def set_cache_value(self, key, value):
        """Store ``value`` under ``key`` after evicting all expired entries."""
        now = time.time()
        self.cache = {k: v for k, v in self.cache.items() if v['time'] + self.cache_time > now}
        self.cache[key] = {'time': now, 'value': value}

    def load_auth_server_config(self):
        """Read the authentication configuration.

        ``$IDDS_AUTH_CONFIG`` is tried first. Otherwise the first readable file among
        ``$IDDS_HOME/etc/idds/auth/auth.cfg``, ``/etc/idds/auth/auth.cfg``,
        ``/opt/idds/etc/idds/auth/auth.cfg`` and ``$VIRTUAL_ENV/etc/idds/auth/auth.cfg``
        is used.

        :returns: a ConfigParser; it has no sections when no file could be read.
        """
        config = ConfigParser.ConfigParser()
        if os.environ.get('IDDS_AUTH_CONFIG', None):
            configfile = os.environ['IDDS_AUTH_CONFIG']
            if config.read(configfile) == [configfile]:
                return config

        configfiles = ['%s/etc/idds/auth/auth.cfg' % os.environ.get('IDDS_HOME', ''),
                       '/etc/idds/auth/auth.cfg', '/opt/idds/etc/idds/auth/auth.cfg',
                       '%s/etc/idds/auth/auth.cfg' % os.environ.get('VIRTUAL_ENV', '')]
        for configfile in configfiles:
            if config.read(configfile) == [configfile]:
                return config
        return config

    def get_allow_vos(self):
        """Return ``[common] allow_vos`` as a list of stripped, non-empty VO names."""
        section = 'common'
        if self.config and self.config.has_section(section) and self.config.has_option(section, 'allow_vos'):
            # Empty items (e.g. from a trailing comma) are not meaningful VO names.
            return [t.strip() for t in self.config.get(section, 'allow_vos').split(',') if t.strip()]
        return []

    def get_ssl_verify(self):
        """Return ``[common] ssl_verify`` as a string, or None when not configured."""
        section = 'common'
        if self.config and self.config.has_section(section) and self.config.has_option(section, 'ssl_verify'):
            return self.config.get(section, 'ssl_verify')
        return None
138
+
139
+
140
class OIDCAuthentication(BaseAuthentication):
    """OpenID Connect support: device-code sign-in, token refresh, client
    credentials and ID-token verification.

    Per-VO settings (``oidc_config_url``, ``client_id``, ``client_secret``,
    ``vo``, ``audience``, ``no_verify``) are read from the config section named
    after the VO. On a cache miss the VO must also be listed in
    ``[common] allow_vos``. Provider-facing methods report failures through
    their return values rather than raising.
    """

    def __init__(self, timeout=None):
        super(OIDCAuthentication, self).__init__(timeout=timeout)

    def get_auth_config(self, vo):
        """Return the settings dict for ``vo``, using the cached copy while it is valid."""
        ret = self.get_cache_value(vo)
        if ret:
            return ret

        ret = {'vo': vo, 'oidc_config_url': None, 'client_id': None,
               'client_secret': None, 'audience': None, 'no_verify': False}

        if self.config and self.config.has_section(vo):
            for name in ['oidc_config_url', 'client_id', 'client_secret', 'vo', 'audience']:
                if self.config.has_option(vo, name):
                    ret[name] = self.config.get(vo, name)
            if self.config.has_option(vo, 'no_verify'):
                ret['no_verify'] = self.config.getboolean(vo, 'no_verify')
        return ret

    def get_http_content(self, url, no_verify=False):
        """GET ``url`` (following redirects) and return the response body as bytes.

        :raises Exception: with a descriptive message when the request fails.
        """
        try:
            r = requests.get(url, allow_redirects=True, timeout=self.timeout,
                             verify=should_verify(no_verify, self.get_ssl_verify()))
        except Exception as error:
            # The old handler applied the format string without '%' (itself a
            # TypeError) and meant to return a tuple, which every caller would
            # then have passed to json.loads(). Raise a readable error instead.
            raise Exception('Failed to get http content for %s: %s' % (str(url), str(error)))
        return r.content

    def get_endpoint_config(self, auth_config):
        """Fetch and parse the provider discovery document at ``oidc_config_url``."""
        content = self.get_http_content(auth_config['oidc_config_url'], no_verify=auth_config['no_verify'])
        return json.loads(content)

    def get_auth_endpoint_config(self, vo):
        """Return ``(auth_config, endpoint_config)`` for ``vo``, caching both.

        On a cache miss, a VO that is not in ``allow_vos`` yields
        ``(False, message)`` instead.
        """
        auth_config = self.get_cache_value(vo)
        endpoint_config_key = vo + "_endpoint_config"
        endpoint_config = self.get_cache_value(endpoint_config_key)

        if not auth_config or not endpoint_config:
            if vo not in self.get_allow_vos():
                return False, "VO %s is not allowed." % vo

            auth_config = self.get_auth_config(vo)
            endpoint_config = self.get_endpoint_config(auth_config)

            self.set_cache_value(vo, auth_config)
            self.set_cache_value(endpoint_config_key, endpoint_config)
        return auth_config, endpoint_config

    def _post_form(self, url, data, no_verify=False):
        """POST ``data`` form-urlencoded to ``url`` and return the response.

        The session is closed afterwards; previously a new one was created per
        call and never closed. The body has already been downloaded because
        the request is not streamed.
        """
        headers = {'content-type': 'application/x-www-form-urlencoded'}
        with requests.Session() as session:
            return session.post(url,
                                urlencode(data).encode(),
                                timeout=self.timeout,
                                verify=should_verify(no_verify, self.get_ssl_verify()),
                                headers=headers)

    def get_oidc_sign_url(self, vo):
        """Start a device-authorization request for ``vo``.

        :returns: ``(True, response_json)`` on HTTP 200 with a body, else ``(False, message)``.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)
            if not auth_config:
                # VO not allowed: the second element carries the reason.
                return False, endpoint_config

            data = {'client_id': auth_config['client_id'],
                    'scope': "openid profile email offline_access",
                    'audience': auth_config['audience']}
            result = self._post_form(endpoint_config['device_authorization_endpoint'], data,
                                     no_verify=auth_config['no_verify'])

            if result is None:
                return False, "Failed to get oidc sign in URL. Response is None."
            if result.status_code == HTTP_STATUS_CODE.OK and result.text:
                return True, json.loads(result.text)
            return False, "Failed to get oidc sign in URL (status: %s, text: %s)" % (result.status_code, result.text)
        except requests.exceptions.ConnectionError as error:
            return False, 'Failed to get oidc sign in URL. ConnectionError: ' + str(error)
        except Exception as error:
            return False, 'Failed to get oidc sign in URL: ' + str(error)

    def get_id_token(self, vo, device_code, interval=5, expires_in=60):
        """Exchange ``device_code`` for tokens with a single token-endpoint request.

        ``interval`` and ``expires_in`` are validated and normalized to ints
        (``expires_in`` capped at ``max_expires_in``) but not otherwise used here.

        :returns: ``(True, token_json)`` on HTTP 200 with a body; ``(False, body_json)``
            for other replies; ``(False, message)`` on errors.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)
            if not auth_config:
                return False, endpoint_config

            data = {'client_id': auth_config['client_id'],
                    'client_secret': auth_config['client_secret'],
                    'grant_type': 'urn:ietf:params:oauth:grant-type:device_code',
                    'device_code': device_code}

            # Non-numeric values raise here and are reported by the except below.
            interval = int(interval) if interval else 5
            expires_in = int(expires_in) if expires_in else 60
            expires_in = min(expires_in, self.max_expires_in)

            result = self._post_form(endpoint_config['token_endpoint'], data,
                                     no_verify=auth_config['no_verify'])
            if result is None:
                return False, None
            if result.status_code == HTTP_STATUS_CODE.OK and result.text:
                return True, json.loads(result.text)
            # The provider's (JSON) error body is handed back to the caller.
            return False, json.loads(result.text)
        except Exception as error:
            return False, 'Failed to get oidc token: ' + str(error)

    def refresh_id_token(self, vo, refresh_token):
        """Obtain new tokens for ``vo`` with the refresh-token grant.

        :returns: ``(True, token_json)`` on HTTP 200 with a body, else ``(False, message)``.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)
            if not auth_config:
                return False, endpoint_config

            data = {'client_id': auth_config['client_id'],
                    'client_secret': auth_config['client_secret'],
                    'grant_type': 'refresh_token',
                    'refresh_token': refresh_token}
            result = self._post_form(endpoint_config['token_endpoint'], data,
                                     no_verify=auth_config['no_verify'])

            if result is None:
                return False, "Failed to refresh oidc token. Response is None."
            if result.status_code == HTTP_STATUS_CODE.OK and result.text:
                return True, json.loads(result.text)
            return False, "Failed to refresh oidc token (status: %s, text: %s)" % (result.status_code, result.text)
        except requests.exceptions.ConnectionError as error:
            return False, 'Failed to refresh oidc token. ConnectionError: ' + str(error)
        except Exception as error:
            return False, 'Failed to refresh oidc token: ' + str(error)

    def get_public_key(self, token, jwks_uri, no_verify=False):
        """Return the PEM-encoded RSA public key whose ``kid`` matches ``token``'s header.

        The JWKS document at ``jwks_uri`` is fetched on a cache miss and cached.

        :raises jwt.exceptions.InvalidTokenError: if the header has no ``kid`` or no key matches.
        """
        headers = jwt.get_unverified_header(token)
        if headers is None or 'kid' not in headers:
            raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
        kid = headers['kid']

        jwks = self.get_cache_value(jwks_uri)
        if not jwks:
            jwks = json.loads(self.get_http_content(jwks_uri, no_verify=no_verify))
            self.set_cache_value(jwks_uri, jwks)

        jwk = None
        for candidate in jwks.get('keys', []):
            if candidate.get('kid') == kid:
                jwk = candidate  # the last matching entry wins
        if jwk is None:
            raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks)))

        public_num = RSAPublicNumbers(n=decode_value(jwk['n']), e=decode_value(jwk['e']))
        public_key = public_num.public_key(default_backend())
        return public_key.public_bytes(encoding=serialization.Encoding.PEM,
                                       format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def verify_id_token(self, vo, token):
        """Verify ``token``'s RS256 signature, audience and issuer for ``vo``.

        :returns: ``(True, claims, name)``, where ``claims`` gains a ``'vo'`` key and
            ``name`` is the ``name`` claim or None; ``(False, message, None)`` on failure.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)
            if not auth_config:
                return False, endpoint_config, None

            # Read the claims without checking the signature, only to choose the
            # audience/issuer enforced below. verify=False is kept for PyJWT 1.x.
            decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False})
            audience = decoded_token['aud']
            if audience not in [auth_config['audience'], auth_config['client_id']]:
                return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." % (audience, auth_config['client_id']), None

            public_key = self.get_public_key(token, endpoint_config['jwks_uri'], no_verify=auth_config['no_verify'])

            token_issuer = decoded_token.get('iss')
            if token_issuer and token_issuer != endpoint_config['issuer'] and endpoint_config['issuer'].startswith(token_issuer):
                # iss is missing the last '/' in access tokens
                issuer = token_issuer
            else:
                issuer = endpoint_config['issuer']

            # Full verification, RS256 only. PyJWT documents `algorithms` as a
            # list; with a plain string its membership test is a substring check.
            decoded = jwt.decode(token, public_key, algorithms=['RS256'],
                                 audience=audience, issuer=issuer)
            decoded['vo'] = vo
            return True, decoded, decoded.get('name')
        except Exception as error:
            return False, 'Failed to verify oidc token: ' + str(error), None

    def setup_oidc_client_token(self, issuer, client_id, client_secret, scope, audience):
        """Obtain a token with the client-credentials grant from ``<issuer>/token``.

        :returns: ``(True, token_json)`` on HTTP 200 with a body, else ``(False, message)``.
        """
        try:
            data = {'client_id': client_id,
                    'client_secret': client_secret,
                    'grant_type': 'client_credentials',
                    'scope': scope,
                    'audience': audience}

            # Avoid a '//token' path when the issuer URL ends with '/'.
            endpoint = '{0}/token'.format(issuer.rstrip('/'))
            result = self._post_form(endpoint, data)

            if result is None:
                return False, "Failed to get oidc client token. Response is None."
            if result.status_code == HTTP_STATUS_CODE.OK and result.text:
                return True, json.loads(result.text)
            return False, "Failed to get oidc client token (status: %s, text: %s)" % (result.status_code, result.text)
        except requests.exceptions.ConnectionError as error:
            return False, 'Failed to get oidc client token. ConnectionError: ' + str(error)
        except Exception as error:
            return False, 'Failed to get oidc client token: ' + str(error)
375
+
376
+
377
class OIDCAuthenticationUtils(object):
    """Client-side helpers to persist, inspect and remove OIDC tokens."""

    def __init__(self):
        pass

    @staticmethod
    def _decode_payload(token):
        """Return the JSON claims of a JWT string without verifying its signature."""
        enc = token.split('.')[1]
        enc += '=' * (-len(enc) % 4)  # restore the base64 padding that JWTs omit
        return json.loads(base64.urlsafe_b64decode(enc.encode()))

    def save_token(self, path, token):
        """Write ``token`` as JSON to ``path``; a newly created file gets mode 0600 (before umask).

        :returns: ``(True, None)`` or ``(False, message)``.
        """
        try:
            # Tokens are credentials, so do not create the file world-readable.
            # An already existing file keeps its current permissions.
            fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as f:
                f.write(json.dumps(token))
            return True, None
        except Exception as error:
            return False, "Failed to save token: %s" % str(error)

    def load_token(self, path):
        """Read a JSON token from ``path``.

        :returns: ``(True, data)`` or ``(False, message)``.
        """
        try:
            with open(path) as f:
                data = json.load(f)
            return True, data
        except Exception as error:
            return False, "Failed to load token: %s" % str(error)

    def is_token_expired(self, token):
        """Check the ``exp`` claim of ``token['id_token']`` against the current time.

        :returns: ``(expired, None)``, or ``(True, message)`` when the token cannot be parsed.
        """
        try:
            dec = self._decode_payload(token['id_token'])
            # exp is seconds since the Unix epoch: compare with time.time()
            # directly instead of the deprecated utcfromtimestamp()/utcnow().
            return dec['exp'] < time.time(), None
        except Exception as error:
            return True, "Failed to parse token: %s" % str(error)

    def clean_token(self, path):
        """Delete the token file at ``path``.

        :returns: ``(True, None)`` or ``(False, message)``.
        """
        try:
            os.remove(path)
            return True, None
        except Exception as error:
            return False, "Failed to clean token: %s" % str(error)

    def get_token_info(self, token):
        """Decode the JWT string ``token`` and describe its expiry.

        :returns: ``(True, claims)``, where ``claims`` additionally has ``expire``
            (a naive UTC datetime) and the human-readable strings ``expire_time``
            and ``expire_at``; or ``(False, message)`` when parsing fails.
        """
        try:
            info = self._decode_payload(token)
            exp = info['exp']
            exp_time = datetime.datetime(1970, 1, 1) + datetime.timedelta(seconds=exp)
            minutes = (exp - time.time()) / 60

            info['expire'] = exp_time
            info['expire_time'] = 'Token will expire in %s minutes' % minutes
            info['expire_at'] = 'Token will expire at {0} UTC'.format(exp_time.strftime("%Y-%m-%d %H:%M:%S"))
            return True, info
        except Exception as error:
            # Previously returned True here, so callers mistook the message for token info.
            return False, "Failed to parse token: %s" % str(error)
435
+
436
+
437
class X509Authentication(BaseAuthentication):
    """User access lists for X.509 (DN-based) authentication, read from the ``[Users]`` section."""

    def __init__(self, timeout=None):
        super(X509Authentication, self).__init__(timeout=timeout)

    def _get_user_list(self, option):
        """Return ``[Users] <option>`` split on commas into stripped, non-empty entries."""
        section = "Users"
        if self.config and self.config.has_section(section) and self.config.has_option(section, option):
            users = self.config.get(section, option)
            # Empty items (e.g. from a trailing comma) are dropped: an empty
            # string is a substring of, and a regex matching, every DN, so it
            # would allow, ban or promote everyone.
            return [u.strip() for u in users.split(",") if u.strip()]
        return []

    def get_ban_user_list(self):
        """Return the ``ban_users`` entries."""
        return self._get_user_list("ban_users")

    def get_allow_user_list(self):
        """Return the ``allow_users`` entries."""
        return self._get_user_list("allow_users")

    def get_super_user_list(self):
        """Return the ``super_users`` entries."""
        return self._get_user_list("super_users")
470
+
471
+
472
# "/DC=ch/DC=cern/OU=Organic Units/OU=Users/CN=wguan/CN=667815/CN=Wen Guan/CN=1883443395"
def get_user_name_from_dn1(dn):
    """Extract a readable user name from a slash-separated DN ("/X=.../CN=...").

    For the example DN in the comment above this returns ``'Wen Guan'``.
    If any step raises, ``dn`` is returned unchanged. The order of the
    substitutions below matters.
    """
    try:
        # Drop DC/O/OU/C/L components.
        up = re.compile('/(DC|O|OU|C|L)=[^\/]+')  # noqa W605
        username = up.sub('', dn)
        # Drop purely numeric CN components (the substitution is applied twice).
        up2 = re.compile('/CN=[0-9]+')
        username = up2.sub('', username)
        up2 = re.compile('/CN=[0-9]+')
        username = up2.sub('', username)
        # Drop " <digits>" and "_<digits>" fragments.
        up3 = re.compile(' [0-9]+')
        username = up3.sub('', username)
        up4 = re.compile('_[0-9]+')
        username = up4.sub('', username)
        # Drop proxy markers.
        username = username.replace('/CN=proxy', '')
        username = username.replace('/CN=limited proxy', '')
        username = username.replace('limited proxy', '')
        # Drop robot and nickname CNs.
        username = re.sub('/CN=Robot:[^/]+', '', username)
        username = re.sub('/CN=Robot[^/]+', '', username)
        username = re.sub('/CN=nickname:[^/]+', '', username)
        # With two or more CNs left, take the value of the last one;
        # otherwise strip the "/CN=" markers.
        pat = re.compile('.*/CN=([^\/]+)/CN=([^\/]+)')  # noqa W605
        mat = pat.match(username)
        if mat:
            username = mat.group(2)
        else:
            username = username.replace('/CN=', '')
        # Cut off a trailing "/email..." component (case-insensitive search).
        if username.lower().find('/email') > 0:
            username = username[:username.lower().find('/email')]
        pat = re.compile('.*(limited.*proxy).*')
        mat = pat.match(username)
        if mat:
            username = mat.group(1)
        # Remove parentheses and single quotes.
        username = username.replace('(', '')
        username = username.replace(')', '')
        username = username.replace("'", '')
        return username
    except Exception:
        return dn
509
+
510
+
511
# 'CN=203633261,CN=Wen Guan,CN=667815,CN=wguan,OU=Users,OU=Organic Units,DC=cern,DC=ch'
def get_user_name_from_dn2(dn):
    """Extract a readable user name from a comma-separated DN ("CN=...,X=...").

    For the example DN in the comment above this returns ``'Wen Guan'``.
    If any step raises, ``dn`` is returned unchanged. The order of the
    substitutions below matters.
    """
    try:
        # Drop DC/O/OU/C/L components.
        up = re.compile(',(DC|O|OU|C|L)=[^\,]+')  # noqa W605
        username = up.sub('', dn)
        # Drop purely numeric CN components, both ",CN=<n>" and a leading "CN=<n>,".
        up2 = re.compile(',CN=[0-9]+')
        username = up2.sub('', username)
        up2 = re.compile('CN=[0-9]+,')
        username = up2.sub(',', username)
        # Drop " <digits>" and "_<digits>" fragments.
        up3 = re.compile(' [0-9]+')
        username = up3.sub('', username)
        up4 = re.compile('_[0-9]+')
        username = up4.sub('', username)
        # Drop proxy markers.
        username = username.replace(',CN=proxy', '')
        username = username.replace(',CN=limited proxy', '')
        username = username.replace('limited proxy', '')
        # Drop robot and nickname CNs.
        # NOTE(review): these patterns use [^/] although components here are
        # comma-separated, so a greedy match can also swallow later ",CN=..."
        # components; [^,] may have been intended — confirm before changing.
        username = re.sub(',CN=Robot:[^/]+,', ',', username)
        username = re.sub(',CN=Robot:[^/]+', '', username)
        username = re.sub(',CN=Robot[^/]+,', ',', username)
        username = re.sub(',CN=Robot[^/]+', '', username)
        username = re.sub(',CN=nickname:[^/]+,', ',', username)
        username = re.sub(',CN=nickname:[^/]+', '', username)
        # With two or more CNs left, take the second-to-last value;
        # otherwise strip the ",CN=" markers.
        pat = re.compile('.*,CN=([^\,]+),CN=([^\,]+)')  # noqa W605
        mat = pat.match(username)
        if mat:
            username = mat.group(1)
        else:
            username = username.replace(',CN=', '')
        # Cut off a trailing ",email..." component (case-insensitive search).
        if username.lower().find(',email') > 0:
            username = username[:username.lower().find(',email')]
        pat = re.compile('.*(limited.*proxy).*')
        mat = pat.match(username)
        if mat:
            username = mat.group(1)
        # Remove parentheses and single quotes.
        username = username.replace('(', '')
        username = username.replace(')', '')
        username = username.replace("'", '')
        return username
    except Exception:
        return dn
551
+
552
+
553
def get_user_name_from_dn(dn):
    """Return a readable user name for ``dn``, whether slash- or comma-separated.

    The slash-form parser runs first and the comma-form parser is applied to its result.
    """
    slash_parsed = get_user_name_from_dn1(dn)
    return get_user_name_from_dn2(slash_parsed)
557
+
558
+
559
def authenticate_x509(vo, dn, client_cert):
    """Authorize a client by X.509 DN against the configured allow and ban lists.

    :param vo: not used by this function.
    :param dn: the client's distinguished name.
    :param client_cert: the client certificate/proxy; it must be non-empty.
    :returns: ``(True, None, username)`` when allowed, otherwise ``(False, reason, None)``.
    """
    if not dn:
        return False, "User DN cannot be found.", None
    if not client_cert:
        return False, "Client certificate proxy cannot be found.", None

    # Load the configuration once for both lists.
    x509_auth = X509Authentication()

    # Allow-list entries are plain substrings of the DN. Blank entries are
    # skipped: '' is a substring of every DN and would allow anyone.
    allow_users = [u.strip() for u in x509_auth.get_allow_user_list() if u.strip()]
    if not any(allow_user in dn for allow_user in allow_users):
        return False, "User %s is not allowed" % str(dn), None

    # Ban-list entries are regular expressions matched at the start of the DN.
    # Blank entries are skipped for the same reason (an empty pattern matches anything).
    ban_users = [u.strip() for u in x509_auth.get_ban_user_list() if u.strip()]
    if any(re.match(ban_user, dn) for ban_user in ban_users):
        return False, "User %s is banned" % str(dn), None

    return True, None, get_user_name_from_dn(dn)
592
+
593
+
594
def authenticate_oidc(vo, token):
    """Verify the OIDC ID ``token`` for ``vo``.

    :returns: the ``(status, data, username)`` tuple from
        ``OIDCAuthentication.verify_id_token``: the decoded claims on success,
        an error message on failure.
    """
    # The former if/else returned the same tuple from both branches.
    return OIDCAuthentication().verify_id_token(vo, token)
601
+
602
+
603
def authenticate_is_super_user(username, dn=None):
    """Return True if ``username`` equals, or ``dn`` contains, a configured super user."""
    for super_user in X509Authentication().get_super_user_list():
        super_user = super_user.strip()
        if not super_user:
            # '' is "in" every DN and would grant super-user rights to anyone.
            continue
        if username == super_user or (dn and super_user in dn):
            return True
    return False
idds/common/cache.py ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # You may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Authors:
9
+ # - Wen Guan, <wen.guan@cern.ch>, 2022
10
+
11
+ import json
12
+ from dogpile.cache import make_region
13
+
14
+ from idds.common.config import config_has_section, config_has_option, config_get
15
+
16
+
17
def get_cache_url():
    """Return the memcached address: ``[cache] url`` from the iDDS config, else ``127.0.0.1:11211``."""
    configured = config_has_section('cache') and config_has_option('cache', 'url')
    return config_get('cache', 'url') if configured else '127.0.0.1:11211'
21
+
22
+
23
def make_region_memcached(expiration_time, function_key_generator=None):
    """
    Make and configure a dogpile.cache.memcached region.

    :param expiration_time: region expiration time passed to dogpile.
    :param function_key_generator: optional key generator passed to make_region().
    :returns: the configured region.
    """
    region_kwargs = {}
    if function_key_generator:
        region_kwargs['function_key_generator'] = function_key_generator
    region = make_region(**region_kwargs)

    region.configure(
        'dogpile.cache.memcached',
        expiration_time=expiration_time,
        arguments={
            # The server address itself: previously the get_cache_url function
            # object was passed here instead of being called.
            'url': get_cache_url(),
            'distributed_lock': True,
            'memcached_expire_time': expiration_time + 60,  # must be bigger than expiration_time
        }
    )

    return region
43
+
44
+
45
# Module-wide memcached region used by update_cache/get_cache/delete_cache,
# configured with a dogpile expiration_time of 3600. It is created at import time.
REGION = make_region_memcached(expiration_time=3600)
46
+
47
+
48
def update_cache(key, data):
    """Store ``data`` under ``key`` in the memcached region, serialized as JSON."""
    serialized = json.dumps(data)
    REGION.set(key, serialized)
50
+
51
+
52
def get_cache(key):
    """Return the JSON-decoded value cached under ``key``.

    A falsy value from the region is returned as-is, without decoding.
    NOTE(review): for a missing key this is presumably dogpile's falsy
    NO_VALUE sentinel rather than None — confirm callers only test truthiness.
    """
    cached = REGION.get(key)
    return json.loads(cached) if cached else cached
57
+
58
+
59
def delete_cache(key):
    """Remove ``key`` from the memcached region."""
    region = REGION
    region.delete(key)