bella-openapi 1.0.2__py3-none-any.whl → 1.0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bella_openapi/__init__.py CHANGED
@@ -1,22 +1,26 @@
1
- from .authorize import validate_token, support_model, account_balance_enough, check_configuration
2
- from .log import operation_log, submit_log
3
- from .openapi_contexvar import trace_id_context, caller_id_context, request_url_context
4
- from .auth_billing import ErrorInfo, async_authenticate_decorator_args, authenticate_user, print_context, \
5
- get_context, set_context, clean_context, report
6
- __all__ = ["validate_token", "operation_log",
7
- "support_model",
8
- "account_balance_enough",
9
- "check_configuration",
10
- "trace_id_context",
11
- "caller_id_context",
12
- "request_url_context",
13
- "submit_log",
14
- "ErrorInfo",
15
- "async_authenticate_decorator_args",
16
- "authenticate_user",
17
- "print_context",
18
- "get_context",
19
- "set_context",
20
- "clean_context",
21
- "report"
22
- ]
1
+ from .authorize import validate_token, support_model, account_balance_enough, check_configuration
2
+ from .log import operation_log, submit_log
3
+ from .openapi_contexvar import trace_id_context, caller_id_context, request_url_context
4
+ from .auth_billing import ErrorInfo, async_authenticate_decorator_args, authenticate_user, print_context, \
5
+ get_context, set_context, clean_context, report
6
+ from .domtree import StandardDomTree, StandardNode
7
+
8
+ __all__ = ["validate_token", "operation_log",
9
+ "support_model",
10
+ "account_balance_enough",
11
+ "check_configuration",
12
+ "trace_id_context",
13
+ "caller_id_context",
14
+ "request_url_context",
15
+ "submit_log",
16
+ "ErrorInfo",
17
+ "async_authenticate_decorator_args",
18
+ "authenticate_user",
19
+ "print_context",
20
+ "get_context",
21
+ "set_context",
22
+ "clean_context",
23
+ "report",
24
+ "StandardDomTree",
25
+ "StandardNode"
26
+ ]
@@ -1,91 +1,91 @@
1
- from .log import operation_log
2
- from .authorize import validate_token, check_configuration, account_balance_enough
3
- from .openapi_contexvar import trace_id_context, caller_id_context, request_url_context
4
- from pydantic import BaseModel
5
- import uuid
6
- from fastapi import Request
7
- import logging
8
- # 创建一个日志记录器
9
- logger = logging.getLogger(__name__)
10
- logger.setLevel(logging.INFO) # 设置日志级别为INFO
11
-
12
- # 创建一个控制台处理器,并设置其级别和格式
13
- console_handler = logging.StreamHandler()
14
- console_handler.setLevel(logging.INFO)
15
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
16
- console_handler.setFormatter(formatter)
17
-
18
- # 将处理器添加到日志记录器
19
- logger.addHandler(console_handler)
20
-
21
- class ErrorInfo(BaseModel):
22
- task_id: str = ""
23
- result: str = ""
24
- status: int = 40000001
25
- message: str = ""
26
-
27
- def async_authenticate_decorator_args(end_point):
28
- def async_authenticate_decorator(func):
29
- async def wrapper(*args, **kwargs):
30
- request_arg = None
31
- for arg in args:
32
- if type(arg) == Request:
33
- request_arg = arg
34
- break
35
- if request_arg is not None:
36
- task_id = str(uuid.uuid4())
37
- supported, error_json, caller_id = authenticate_user(request_arg.headers.get("Authorization"), task_id)
38
- if not supported:
39
- return error_json
40
- t_token, c_token, r_token = set_context(task_id, caller_id, end_point)
41
- result = await func(*args, **kwargs)
42
- clean_context(t_token, c_token, r_token)
43
- return result
44
- else:
45
- logger.warn("please check your request param,have not a param's type is Request of fastapi!")
46
- return func(*args, **kwargs)
47
- return wrapper
48
- return async_authenticate_decorator
49
-
50
-
51
- def authenticate_user(token, task_id):
52
- if check_configuration():
53
- if token is None:
54
- return False, ErrorInfo(task_id=task_id, result="", status=40000001, message="token is missing").dict(), ""
55
- try:
56
- caller_id = validate_token(token)
57
- balance = account_balance_enough(token, cost=1.0)
58
-
59
- except Exception as e:
60
- return False, ErrorInfo(task_id=task_id, result="", status=40000001, message=str(e)[:100]).dict(), ""
61
- return balance, ErrorInfo(task_id=task_id, result="", status=40000001, message="Insufficient balance").dict(), caller_id
62
- else:
63
- return True, "", ""
64
-
65
-
66
- @operation_log(op_type='upload_cost_log', is_cost_log=True, ucid_key="ucid")
67
- def upload_cost_log(result_obejct, ucid):
68
- response = result_obejct.dict()
69
- return response
70
-
71
- def print_context(log):
72
- logger.info(f"{log} trace_id:{trace_id_context.get()}, caller_id:{caller_id_context.get()}, end_point:{request_url_context.get()}")
73
-
74
- def get_context():
75
- return trace_id_context.get(), caller_id_context.get(), request_url_context.get()
76
-
77
- def set_context(trace_id, caller_id, end_point):
78
- t_token = trace_id_context.set(trace_id)
79
- c_token = caller_id_context.set(caller_id)
80
- r_token = request_url_context.set(end_point)
81
- return t_token, c_token, r_token
82
-
83
- def clean_context(t_token, c_token, r_token):
84
- trace_id_context.reset(t_token)
85
- caller_id_context.reset(c_token)
86
- request_url_context.reset(r_token)
87
-
88
- def report(result_obejct, ucid = ""):
89
- if not check_configuration():
90
- return
91
- upload_cost_log(result_obejct, ucid)
1
+ from .log import operation_log
2
+ from .authorize import validate_token, check_configuration, account_balance_enough
3
+ from .openapi_contexvar import trace_id_context, caller_id_context, request_url_context
4
+ from pydantic import BaseModel
5
+ import uuid
6
+ from fastapi import Request
7
+ import logging
8
+ # 创建一个日志记录器
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.INFO) # 设置日志级别为INFO
11
+
12
+ # 创建一个控制台处理器,并设置其级别和格式
13
+ console_handler = logging.StreamHandler()
14
+ console_handler.setLevel(logging.INFO)
15
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
16
+ console_handler.setFormatter(formatter)
17
+
18
+ # 将处理器添加到日志记录器
19
+ logger.addHandler(console_handler)
20
+
21
+ class ErrorInfo(BaseModel):
22
+ task_id: str = ""
23
+ result: str = ""
24
+ status: int = 40000001
25
+ message: str = ""
26
+
27
+ def async_authenticate_decorator_args(end_point):
28
+ def async_authenticate_decorator(func):
29
+ async def wrapper(*args, **kwargs):
30
+ request_arg = None
31
+ for arg in args:
32
+ if type(arg) == Request:
33
+ request_arg = arg
34
+ break
35
+ if request_arg is not None:
36
+ task_id = str(uuid.uuid4())
37
+ supported, error_json, caller_id = authenticate_user(request_arg.headers.get("Authorization"), task_id)
38
+ if not supported:
39
+ return error_json
40
+ t_token, c_token, r_token = set_context(task_id, caller_id, end_point)
41
+ result = await func(*args, **kwargs)
42
+ clean_context(t_token, c_token, r_token)
43
+ return result
44
+ else:
45
+ logger.warn("please check your request param,have not a param's type is Request of fastapi!")
46
+ return func(*args, **kwargs)
47
+ return wrapper
48
+ return async_authenticate_decorator
49
+
50
+
51
+ def authenticate_user(token, task_id):
52
+ if check_configuration():
53
+ if token is None:
54
+ return False, ErrorInfo(task_id=task_id, result="", status=40000001, message="token is missing").dict(), ""
55
+ try:
56
+ caller_id = validate_token(token)
57
+ balance = account_balance_enough(token, cost=1.0)
58
+
59
+ except Exception as e:
60
+ return False, ErrorInfo(task_id=task_id, result="", status=40000001, message=str(e)[:100]).dict(), ""
61
+ return balance, ErrorInfo(task_id=task_id, result="", status=40000001, message="Insufficient balance").dict(), caller_id
62
+ else:
63
+ return True, "", ""
64
+
65
+
66
+ @operation_log(op_type='upload_cost_log', is_cost_log=True, ucid_key="ucid")
67
+ def upload_cost_log(result_obejct, ucid):
68
+ response = result_obejct.dict()
69
+ return response
70
+
71
+ def print_context(log):
72
+ logger.info(f"{log} trace_id:{trace_id_context.get()}, caller_id:{caller_id_context.get()}, end_point:{request_url_context.get()}")
73
+
74
+ def get_context():
75
+ return trace_id_context.get(), caller_id_context.get(), request_url_context.get()
76
+
77
+ def set_context(trace_id, caller_id, end_point):
78
+ t_token = trace_id_context.set(trace_id)
79
+ c_token = caller_id_context.set(caller_id)
80
+ r_token = request_url_context.set(end_point)
81
+ return t_token, c_token, r_token
82
+
83
+ def clean_context(t_token, c_token, r_token):
84
+ trace_id_context.reset(t_token)
85
+ caller_id_context.reset(c_token)
86
+ request_url_context.reset(r_token)
87
+
88
+ def report(result_obejct, ucid = ""):
89
+ if not check_configuration():
90
+ return
91
+ upload_cost_log(result_obejct, ucid)
@@ -1,61 +1,61 @@
1
- import httpx
2
- import logging
3
- from .config import openapi_config
4
- from .exception import AuthorizationException
5
-
6
-
7
- def check_configuration() -> bool:
8
- res = openapi_config.OPENAPI_HOST is not None
9
- return res
10
-
11
-
12
- def validate_token(token: str) -> str:
13
- """
14
- 根据传入token, 解析用户身份
15
- :param token:
16
- :return: 用户身份,返回账户id
17
- :raises: AuthorizationException 如果token无效, 抛出异常
18
- """
19
- return _validate_token_request(token)
20
-
21
-
22
- def support_model(token: str, model: str) -> bool:
23
- """
24
- 根据传入token,判断用户是否有对应的模型权限
25
- :param token: 用户token
26
- :param model: 模型名
27
- :return: bool
28
- """
29
- url = openapi_config.OPENAPI_HOST + "/v1/openapi/support/model"
30
- # 使用httpx发送get请求
31
- response = httpx.get(url, headers={"Authorization": token}, params={"model": model})
32
- if response.status_code == 200:
33
- return response.json()['data']
34
- else:
35
- raise AuthorizationException(response.text, response.status_code)
36
-
37
-
38
- def account_balance_enough(token: str, cost: float = 0) -> bool:
39
- """
40
- 账户余额判断,支持传入指定判断阈值,服务方可以在每次用户请求前,根据本次预估花费,对用户余额进行校验,若余额不足,则可拒绝请求
41
- :param token: 用户token
42
- :param cost: 消耗金额, 单位:元/RMB
43
- :return: bool
44
- """
45
- url = openapi_config.OPENAPI_HOST + "/v1/openapi/check/account/balance"
46
- # 使用httpx发送get请求
47
- response = httpx.get(url, headers={"Authorization": token}, params={"cost": cost})
48
- if response.status_code == 200:
49
- return response.json()['data']
50
- else:
51
- raise AuthorizationException(response.text, response.status_code)
52
-
53
-
54
- def _validate_token_request(token: str) -> str:
55
- url = openapi_config.OPENAPI_HOST + "/v1/openapi/validate/tokens"
56
- # 使用httpx发送get请求
57
- response = httpx.get(url, headers={"Authorization": token})
58
- if response.status_code == 200:
59
- return response.json()['data']
60
- else:
61
- raise AuthorizationException(response.text, response.status_code)
1
+ import httpx
2
+ import logging
3
+ from .config import openapi_config
4
+ from .exception import AuthorizationException
5
+
6
+
7
+ def check_configuration() -> bool:
8
+ res = openapi_config.OPENAPI_HOST is not None
9
+ return res
10
+
11
+
12
+ def validate_token(token: str) -> str:
13
+ """
14
+ 根据传入token, 解析用户身份
15
+ :param token:
16
+ :return: 用户身份,返回账户id
17
+ :raises: AuthorizationException 如果token无效, 抛出异常
18
+ """
19
+ return _validate_token_request(token)
20
+
21
+
22
+ def support_model(token: str, model: str) -> bool:
23
+ """
24
+ 根据传入token,判断用户是否有对应的模型权限
25
+ :param token: 用户token
26
+ :param model: 模型名
27
+ :return: bool
28
+ """
29
+ url = openapi_config.OPENAPI_HOST + "/v1/openapi/support/model"
30
+ # 使用httpx发送get请求
31
+ response = httpx.get(url, headers={"Authorization": token}, params={"model": model})
32
+ if response.status_code == 200:
33
+ return response.json()['data']
34
+ else:
35
+ raise AuthorizationException(response.text, response.status_code)
36
+
37
+
38
+ def account_balance_enough(token: str, cost: float = 0) -> bool:
39
+ """
40
+ 账户余额判断,支持传入指定判断阈值,服务方可以在每次用户请求前,根据本次预估花费,对用户余额进行校验,若余额不足,则可拒绝请求
41
+ :param token: 用户token
42
+ :param cost: 消耗金额, 单位:元/RMB
43
+ :return: bool
44
+ """
45
+ url = openapi_config.OPENAPI_HOST + "/v1/openapi/check/account/balance"
46
+ # 使用httpx发送get请求
47
+ response = httpx.get(url, headers={"Authorization": token}, params={"cost": cost})
48
+ if response.status_code == 200:
49
+ return response.json()['data']
50
+ else:
51
+ raise AuthorizationException(response.text, response.status_code)
52
+
53
+
54
+ def _validate_token_request(token: str) -> str:
55
+ url = openapi_config.OPENAPI_HOST + "/v1/openapi/validate/tokens"
56
+ # 使用httpx发送get请求
57
+ response = httpx.get(url, headers={"Authorization": token})
58
+ if response.status_code == 200:
59
+ return response.json()['data']
60
+ else:
61
+ raise AuthorizationException(response.text, response.status_code)
@@ -1,13 +1,13 @@
1
- # -*- coding: utf-8 -*-
2
- # ======================
3
- # Date : 2024/12/30
4
- # Author : Liu Yuchen
5
- # Content :
6
- # 协议规范:https://doc.weixin.qq.com/doc/w3_AagAxwZdAD4dsCIEHU3RL26Knh1x8?scode=AJMA1Qc4AAwYUI6MJrAAEASgZXANE
7
- # ======================
8
- from ._context import TraceContext, TRACE_ID
9
- from .fastapi_interceptor import FastapiBellaTraceMiddleware
10
- import bella_openapi.bella_trace.trace_requests as requests
11
- from .record_log import trace, BellaTraceHandler
12
-
13
- __all__ = ["TraceContext", "TRACE_ID", "FastapiBellaTraceMiddleware", "requests", "trace", "BellaTraceHandler"]
1
+ # -*- coding: utf-8 -*-
2
+ # ======================
3
+ # Date : 2024/12/30
4
+ # Author : Liu Yuchen
5
+ # Content :
6
+ # 协议规范:https://doc.weixin.qq.com/doc/w3_AagAxwZdAD4dsCIEHU3RL26Knh1x8?scode=AJMA1Qc4AAwYUI6MJrAAEASgZXANE
7
+ # ======================
8
+ from ._context import TraceContext, TRACE_ID
9
+ from .fastapi_interceptor import FastapiBellaTraceMiddleware
10
+ import bella_openapi.bella_trace.trace_requests as requests
11
+ from .record_log import trace, BellaTraceHandler
12
+
13
+ __all__ = ["TraceContext", "TRACE_ID", "FastapiBellaTraceMiddleware", "requests", "trace", "BellaTraceHandler"]
@@ -1,61 +1,61 @@
1
- # -*- coding: utf-8 -*-
2
- # ======================
3
- # Date : 2024/12/30
4
- # Author : Liu Yuchen
5
- # Content :
6
- #
7
- # ======================
8
- import os
9
- import uuid
10
- from contextvars import ContextVar
11
-
12
- __all__ = ["TraceContext", "TRACE_ID"]
13
-
14
-
15
- _trace_id = ContextVar("bella_trace_id", default="")
16
- _mock_request = ContextVar("mock_request", default="false")
17
-
18
-
19
- TRACE_ID = "X-BELLA-TRACE-ID"
20
- MOCK_REQUEST = "X-BELLA-MOCK-REQUEST"
21
-
22
-
23
- class _TraceContext(object):
24
- @property
25
- def trace_id(self) -> str:
26
- return _trace_id.get()
27
-
28
- @trace_id.setter
29
- def trace_id(self, value):
30
- _trace_id.set(value)
31
-
32
- @property
33
- def service_id(self) -> str:
34
- return _get_service_id()
35
-
36
- @staticmethod
37
- def generate_trace_id() -> str:
38
- return f"{_get_service_id()}-{uuid.uuid4().hex}"
39
-
40
- @property
41
- def mock_request(self):
42
- return _mock_request.get()
43
-
44
- @mock_request.setter
45
- def mock_request(self, value: str):
46
- _mock_request.set(value)
47
-
48
- @property
49
- def is_mock_request(self) -> bool:
50
- return self.mock_request.lower() == "true"
51
-
52
- @property
53
- def headers(self) -> dict:
54
- return {TRACE_ID: self.trace_id, MOCK_REQUEST: self.mock_request}
55
-
56
-
57
- TraceContext = _TraceContext()
58
-
59
-
60
- def _get_service_id() -> str:
61
- return os.environ.get("SERVICE_ID", "")
1
+ # -*- coding: utf-8 -*-
2
+ # ======================
3
+ # Date : 2024/12/30
4
+ # Author : Liu Yuchen
5
+ # Content :
6
+ #
7
+ # ======================
8
+ import os
9
+ import uuid
10
+ from contextvars import ContextVar
11
+
12
+ __all__ = ["TraceContext", "TRACE_ID"]
13
+
14
+
15
+ _trace_id = ContextVar("bella_trace_id", default="")
16
+ _mock_request = ContextVar("mock_request", default="false")
17
+
18
+
19
+ TRACE_ID = "X-BELLA-TRACE-ID"
20
+ MOCK_REQUEST = "X-BELLA-MOCK-REQUEST"
21
+
22
+
23
+ class _TraceContext(object):
24
+ @property
25
+ def trace_id(self) -> str:
26
+ return _trace_id.get()
27
+
28
+ @trace_id.setter
29
+ def trace_id(self, value):
30
+ _trace_id.set(value)
31
+
32
+ @property
33
+ def service_id(self) -> str:
34
+ return _get_service_id()
35
+
36
+ @staticmethod
37
+ def generate_trace_id() -> str:
38
+ return f"{_get_service_id()}-{uuid.uuid4().hex}"
39
+
40
+ @property
41
+ def mock_request(self):
42
+ return _mock_request.get()
43
+
44
+ @mock_request.setter
45
+ def mock_request(self, value: str):
46
+ _mock_request.set(value)
47
+
48
+ @property
49
+ def is_mock_request(self) -> bool:
50
+ return self.mock_request.lower() == "true"
51
+
52
+ @property
53
+ def headers(self) -> dict:
54
+ return {TRACE_ID: self.trace_id, MOCK_REQUEST: self.mock_request}
55
+
56
+
57
+ TraceContext = _TraceContext()
58
+
59
+
60
+ def _get_service_id() -> str:
61
+ return os.environ.get("SERVICE_ID", "")
@@ -1,28 +1,28 @@
1
- # -*- coding: utf-8 -*-
2
- # ======================
3
- # Date : 2024/12/30
4
- # Author : Liu Yuchen
5
- # Content :
6
- #
7
- # ======================
8
- from starlette.middleware.base import BaseHTTPMiddleware
9
-
10
- from ._context import TraceContext, TRACE_ID, MOCK_REQUEST
11
-
12
- __all__ = ["FastapiBellaTraceMiddleware"]
13
-
14
-
15
- class FastapiBellaTraceMiddleware(BaseHTTPMiddleware):
16
- async def dispatch(self, request, call_next):
17
- # 设置 trace_id
18
- if trace_id_h := request.headers.get(TRACE_ID):
19
- TraceContext.trace_id = trace_id_h
20
- else:
21
- TraceContext.trace_id = TraceContext.generate_trace_id()
22
-
23
- if mock_request_h := request.headers.get(MOCK_REQUEST):
24
- TraceContext.mock_request = mock_request_h
25
-
26
- # 继续处理请求
27
- return await call_next(request)
28
-
1
+ # -*- coding: utf-8 -*-
2
+ # ======================
3
+ # Date : 2024/12/30
4
+ # Author : Liu Yuchen
5
+ # Content :
6
+ #
7
+ # ======================
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+
10
+ from ._context import TraceContext, TRACE_ID, MOCK_REQUEST
11
+
12
+ __all__ = ["FastapiBellaTraceMiddleware"]
13
+
14
+
15
+ class FastapiBellaTraceMiddleware(BaseHTTPMiddleware):
16
+ async def dispatch(self, request, call_next):
17
+ # 设置 trace_id
18
+ if trace_id_h := request.headers.get(TRACE_ID):
19
+ TraceContext.trace_id = trace_id_h
20
+ else:
21
+ TraceContext.trace_id = TraceContext.generate_trace_id()
22
+
23
+ if mock_request_h := request.headers.get(MOCK_REQUEST):
24
+ TraceContext.mock_request = mock_request_h
25
+
26
+ # 继续处理请求
27
+ return await call_next(request)
28
+