lexsi-sdk 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lexsi_sdk/__init__.py +5 -0
- lexsi_sdk/client/__init__.py +0 -0
- lexsi_sdk/client/client.py +176 -0
- lexsi_sdk/common/__init__.py +0 -0
- lexsi_sdk/common/config/.env.prod +3 -0
- lexsi_sdk/common/constants.py +143 -0
- lexsi_sdk/common/enums.py +8 -0
- lexsi_sdk/common/environment.py +49 -0
- lexsi_sdk/common/monitoring.py +81 -0
- lexsi_sdk/common/trigger.py +75 -0
- lexsi_sdk/common/types.py +122 -0
- lexsi_sdk/common/utils.py +93 -0
- lexsi_sdk/common/validation.py +110 -0
- lexsi_sdk/common/xai_uris.py +197 -0
- lexsi_sdk/core/__init__.py +0 -0
- lexsi_sdk/core/agent.py +62 -0
- lexsi_sdk/core/alert.py +56 -0
- lexsi_sdk/core/case.py +618 -0
- lexsi_sdk/core/dashboard.py +131 -0
- lexsi_sdk/core/guardrails/__init__.py +0 -0
- lexsi_sdk/core/guardrails/guard_template.py +299 -0
- lexsi_sdk/core/guardrails/guardrail_autogen.py +554 -0
- lexsi_sdk/core/guardrails/guardrails_langgraph.py +525 -0
- lexsi_sdk/core/guardrails/guardrails_openai.py +541 -0
- lexsi_sdk/core/guardrails/openai_runner.py +1328 -0
- lexsi_sdk/core/model_summary.py +110 -0
- lexsi_sdk/core/organization.py +549 -0
- lexsi_sdk/core/project.py +5131 -0
- lexsi_sdk/core/synthetic.py +387 -0
- lexsi_sdk/core/text.py +595 -0
- lexsi_sdk/core/tracer.py +208 -0
- lexsi_sdk/core/utils.py +36 -0
- lexsi_sdk/core/workspace.py +325 -0
- lexsi_sdk/core/wrapper.py +766 -0
- lexsi_sdk/core/xai.py +306 -0
- lexsi_sdk/version.py +34 -0
- lexsi_sdk-0.1.16.dist-info/METADATA +100 -0
- lexsi_sdk-0.1.16.dist-info/RECORD +40 -0
- lexsi_sdk-0.1.16.dist-info/WHEEL +5 -0
- lexsi_sdk-0.1.16.dist-info/top_level.txt +1 -0
lexsi_sdk/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import httpx
|
|
3
|
+
from lexsi_sdk.common.xai_uris import LOGIN_URI
|
|
4
|
+
import jwt
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class APIClient(BaseModel):
    """API client to interact with Lexsi Ai services.

    Wraps httpx for plain, multipart and streaming HTTP calls and keeps
    the short-lived JWT bearer token refreshed from the long-lived
    access token.
    """

    debug: bool = False
    base_url: str = ""
    access_token: str = ""  # long-lived token, exchanged for a JWT via LOGIN_URI
    auth_token: str = ""  # short-lived JWT bearer token
    headers: dict = {}  # sent with every request; pydantic copies defaults per instance

    def __init__(self, **kwargs):
        """Initialize the API client with provided configuration."""
        super().__init__(**kwargs)

    def get_auth_token(self) -> str:
        """Get jwt auth token value.

        Returns:
            str: jwt auth token
        """
        return self.auth_token

    def set_auth_token(self, auth_token):
        """Set jwt auth token value.

        :param auth_token: jwt auth token
        """
        self.auth_token = auth_token

    def set_access_token(self, access_token):
        """Set access token value.

        :param access_token: long-lived access token
        """
        self.access_token = access_token

    def get_url(self, uri) -> str:
        """Get url by appending uri to base url.

        :param uri: uri of endpoint
        :return: url
        """
        return f"{self.base_url}/{uri}"

    def update_headers(self, auth_token):
        """Set jwt auth token and update headers for all requests."""
        self.set_auth_token(auth_token)
        self.headers = {
            "Authorization": f"Bearer {self.auth_token}",
        }

    def refresh_bearer_token(self):
        """Refresh the bearer token if the current token is expired."""
        try:
            if self.auth_token:
                # Only the expiry claim is checked here; the signature is
                # verified server-side.
                jwt.decode(
                    self.auth_token,
                    options={"verify_signature": False, "verify_exp": True},
                )
        except jwt.exceptions.ExpiredSignatureError:
            # Mint a fresh JWT from the long-lived access token.
            response = self.base_request(
                "POST", LOGIN_URI, {"access_token": self.access_token}
            ).json()
            self.update_headers(response["access_token"])

    def base_request(self, method, uri, payload=None, files=None):
        """Make a request to the xai base service.

        :param method: GET, POST, PUT, DELETE
        :param uri: api uri
        :param payload: JSON body, defaults to {}
        :param files: multipart files mapping, optional
        :raises Exception: on any 4xx/5xx response, carrying the server's
            "details" field (or the raw body) as the message
        :return: httpx.Response
        """
        if payload is None:  # avoid a shared mutable default argument
            payload = {}
        url = f"{self.base_url}/{uri}"
        with httpx.Client(http2=True, timeout=None) as client:
            response = client.request(
                method=method,
                url=url,
                headers=self.headers,
                json=payload,
                files=files or None,
            )

        # Prefer the server-provided "details" message, then the whole JSON
        # body, then raw text when the body is not JSON at all.
        try:
            body = response.json()
            res = body.get("details") or body
        except Exception:
            res = response.text
        if 400 <= response.status_code < 600:
            raise Exception(res)
        return response

    def request(self, method, uri, payload):
        """Refresh credentials and dispatch a base request."""
        self.refresh_bearer_token()
        return self.base_request(method, uri, payload)

    def get(self, uri):
        """Make a get request to the xai base service.

        :param uri: api uri
        :raises Exception: Request exception
        :return: JSON response
        """
        self.refresh_bearer_token()
        return self.base_request("GET", uri).json()

    def post(self, uri, payload=None):
        """Make a post request to the xai base service.

        :param uri: api uri
        :param payload: api payload, defaults to {}
        :raises Exception: Request exception
        :return: JSON response
        """
        self.refresh_bearer_token()
        # Map the None sentinel back to {} so the request body matches the
        # original default-payload behaviour.
        body = payload if payload is not None else {}
        return self.base_request("POST", uri, body).json()

    def stream(self, uri, method, payload=None):
        """Consume a Server-Sent Events / line-streaming endpoint.

        :param uri: api uri
        :param method: HTTP method
        :param payload: optional JSON body
        :yield: parsed JSON object per "data: " line, until "data: [DONE]"
        """
        self.refresh_bearer_token()
        url = f"{self.base_url}/{uri}"
        # Advertising SSE support helps servers that key streaming off Accept.
        headers = {**self.headers, "Accept": "text/event-stream"}

        with httpx.Client(http2=True, timeout=None) as client:
            # The stream must be consumed inside the client context.
            with client.stream(method, url, headers=headers, json=payload) as response:
                response.raise_for_status()
                for line in response.iter_lines():  # httpx yields decoded str
                    if not line:
                        continue
                    if line.startswith("data: "):  # typical SSE prefix
                        if line.strip() == "data: [DONE]":
                            break
                        yield json.loads(line[6:])

    def file(self, uri, files):
        """Make a multipart request to send files.

        :param uri: api uri
        :param files: multipart files mapping
        :return: JSON response
        """
        self.refresh_bearer_token()
        response = self.base_request("POST", uri, files=files)
        return response.json()
|
|
File without changes
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# Supported model types for tabular projects.
MODEL_TYPES = ["classification", "regression"]

# Required payload keys per dashboard kind.
DATA_DRIFT_DASHBOARD_REQUIRED_FIELDS = [
    "base_line_tag",
    "current_tag",
    "stat_test_name",
]

# Statistical tests accepted for data drift detection.
DATA_DRIFT_STAT_TESTS = [
    "chisquare",
    "jensenshannon",
    "ks",
    "kl_div",
    "psi",
    "wasserstein",
    "z",
]


TARGET_DRIFT_DASHBOARD_REQUIRED_FIELDS = [
    "base_line_tag",
    "current_tag",
    "baseline_true_label",
    "current_true_label",
    "model_type",
    "stat_test_name",
]

# Statistical tests accepted for target drift (all model types).
TARGET_DRIFT_STAT_TESTS = ["chisquare", "jensenshannon", "kl_div", "psi", "z"]

# Per-model-type subsets of the target drift tests.
TARGET_DRIFT_STAT_TESTS_CLASSIFICATION = [
    "chisquare",
    "jensenshannon",
    "kl_div",
    "psi",
]

TARGET_DRIFT_STAT_TESTS_REGRESSION = [
    "jensenshannon",
    "kl_div",
    "ks",
    "psi",
    "wasserstein",
]


BIAS_MONITORING_DASHBOARD_REQUIRED_FIELDS = [
    "base_line_tag",
    "baseline_true_label",
    "baseline_pred_label",
    "model_type",
]

MODEL_PERF_DASHBOARD_REQUIRED_FIELDS = [
    "base_line_tag",
    "current_tag",
    "baseline_true_label",
    "baseline_pred_label",
    "current_true_label",
    "current_pred_label",
    "model_type",
]


# Required payload keys per alert-trigger kind.
DATA_DRIFT_TRIGGER_REQUIRED_FIELDS = [
    "trigger_name",
    "trigger_type",
    "mail_list",
    "frequency",
    "stat_test_name",
    "datadrift_features_per",
    "base_line_tag",
    "current_tag",
]

TARGET_DRIFT_TRIGGER_REQUIRED_FIELDS = [
    "trigger_name",
    "trigger_type",
    "mail_list",
    "frequency",
    "model_type",
    "stat_test_name",
    "baseline_true_label",
    "current_true_label",
    "base_line_tag",
    "current_tag",
]

MODEL_PERF_TRIGGER_REQUIRED_FIELDS = [
    "trigger_name",
    "trigger_type",
    "mail_list",
    "frequency",
    "model_type",
    "model_performance_metric",
    "model_performance_threshold",
    "baseline_true_label",
    "baseline_pred_label",
    "base_line_tag",
]

# Metric names accepted by model-performance triggers, per model type.
MODEL_PERF_METRICS_CLASSIFICATION = [
    "accuracy",
    "f1",
    "false_negative_rate",
    "false_positive_rate",
    "precision",
    "recall",
    "true_negative_rate",
    "true_positive_rate",
]

# NOTE(review): the original listed "mean_abs_perc_error" twice; the
# duplicate was dropped here. One entry may have been intended as
# "mean_abs_error" — confirm against the backend API.
MODEL_PERF_METRICS_REGRESSION = [
    "mean_abs_perc_error",
    "mean_squared_error",
    "r2_score",
]

# Accepted alert e-mail frequencies.
MAIL_FREQUENCIES = [
    "1 hour",
    "3 hour",
    "6 hour",
    "daily",
    "weekly",
    "monthly",
    "quarterly",
    "yearly",
]

# Default hyperparameters per synthetic-data model family.
SYNTHETIC_MODELS_DEFAULT_HYPER_PARAMS = {
    "GPT2": {
        "batch_size": 250,
        "early_stopping_patience": 10,
        "early_stopping_threshold": 0.11,
        "epochs": 100,
        "model_type": "tabular",
        "random_state": 1,
        "tabular_config": "GPT2Config",
        "train_size": 0.8,
    },
    "CTGAN": {"epochs": 100, "test_ratio": 0.2},
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Environment(BaseModel):
    """
    Environment class to load current environment
    """

    debug: bool = False  # mirrors the DEBUG env var after load_environment()
    # Evaluated once at import time; defaults to the prod environment.
    XAI_ENV: str = os.getenv("XAI_ENV", "prod")

    def __init__(self):
        """Load environment configuration on instantiation."""
        super().__init__()

        self.load_environment()

    def load_environment(self):
        """
        load current environment config

        :return: None
        """
        env_file = f".env.{self.XAI_ENV}"

        # The .env files ship next to this module under config/.
        basedir = os.path.abspath(os.path.dirname(__file__))
        load_dotenv(os.path.join(basedir, "config", env_file))

        logger_on = self.get_debug()

        if logger_on:
            self.debug = logger_on
            print(f"Connected to: {self.XAI_ENV} environment")

    def get_base_url(self) -> str:
        """get base url of XAI platform

        :return: base url
        """
        return os.getenv("XAI_API_URL", "https://apiv1.lexsi.ai")

    def get_debug(self) -> bool:
        """get debug flag

        :return: True when DEBUG is set to a truthy string
        """
        # The original bool(os.getenv("DEBUG", False)) returned True for ANY
        # non-empty value — including "false" and "0" — because os.getenv
        # yields a string. Parse the common truthy spellings instead.
        return os.getenv("DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from typing import List, Optional, TypedDict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ImageDashboardPayload(TypedDict):
    """Payload schema for image monitoring dashboards."""

    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DataDriftPayload(TypedDict):
    """Payload schema for data drift dashboards."""

    project_name: Optional[str]  # target project; presumably defaults server-side
    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches

    # Optional date filtering on both windows.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    features_to_use: List[str]  # features included in the drift computation

    stat_test_name: str  # one of DATA_DRIFT_STAT_TESTS
    # NOTE(review): annotated str here but float in TargetDriftPayload —
    # confirm the expected type with the API.
    stat_test_threshold: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TargetDriftPayload(TypedDict):
    """Payload schema for target drift dashboards."""

    project_name: str
    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches

    # Optional date filtering on both windows.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    model_type: str  # "classification" or "regression" (see MODEL_TYPES)

    baseline_true_label: str  # true-label column in the baseline data
    current_true_label: str  # true-label column in the current data

    stat_test_name: str  # one of TARGET_DRIFT_STAT_TESTS
    stat_test_threshold: float
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class BiasMonitoringPayload(TypedDict):
    """Payload schema for bias monitoring dashboards."""

    project_name: str
    base_line_tag: List[str]  # tags selecting the baseline data batches

    # Optional date filtering of the baseline window.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    baseline_true_label: str  # true-label column in the baseline data
    baseline_pred_label: str  # predicted-label column in the baseline data

    features_to_use: List[str]  # features examined for bias
    model_type: str  # "classification" or "regression" (see MODEL_TYPES)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ModelPerformancePayload(TypedDict):
    """Payload schema for model performance dashboards."""

    project_name: str
    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches

    # Optional date filtering on both windows.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    baseline_true_label: str  # true-label column in the baseline data
    baseline_pred_label: str  # predicted-label column in the baseline data
    current_true_label: str  # true-label column in the current data
    current_pred_label: str  # predicted-label column in the current data

    model_type: str  # "classification" or "regression" (see MODEL_TYPES)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import List, Optional, TypedDict
|
|
2
|
+
|
|
3
|
+
class DataDriftTriggerPayload(TypedDict):
    """Payload schema for creating data drift alert triggers."""

    project_name: str
    trigger_name: str
    trigger_type: str

    mail_list: List[str]  # recipients of the alert e-mails
    frequency: str  # one of MAIL_FREQUENCIES

    stat_test_name: str  # one of DATA_DRIFT_STAT_TESTS
    stat_test_threshold: Optional[float]

    datadrift_features_per: float  # % of drifted features that fires the alert

    features_to_use: List[str]  # features included in the drift computation

    # Optional date filtering on both windows.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TargetDriftTriggerPayload(TypedDict):
    """Payload schema for creating target drift alert triggers."""

    project_name: str
    trigger_name: str
    trigger_type: str

    mail_list: List[str]  # recipients of the alert e-mails
    frequency: str  # one of MAIL_FREQUENCIES

    model_type: str  # "classification" or "regression" (see MODEL_TYPES)

    stat_test_name: str  # one of TARGET_DRIFT_STAT_TESTS
    stat_test_threshold: Optional[float]

    baseline_true_label: str  # true-label column in the baseline data
    current_true_label: str  # true-label column in the current data

    # Optional date filtering on both windows.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    base_line_tag: List[str]  # tags selecting the baseline data batches
    current_tag: List[str]  # tags selecting the current data batches
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ModelPerfTriggerPayload(TypedDict):
    """Payload schema for creating model performance alert triggers."""

    project_name: str
    trigger_name: str
    trigger_type: str

    mail_list: List[str]  # recipients of the alert e-mails
    frequency: str  # one of MAIL_FREQUENCIES

    model_type: str  # "classification" or "regression" (see MODEL_TYPES)
    # NOTE(review): annotated float, but the companion constants list metric
    # NAMES (e.g. "accuracy") — presumably this should be str; confirm.
    model_performance_metric: float
    model_performance_threshold: float

    baseline_true_label: str  # true-label column in the baseline data
    baseline_pred_label: str  # predicted-label column in the baseline data

    # Optional date filtering of the baseline window.
    date_feature: Optional[str]
    baseline_date: Optional[dict]
    current_date: Optional[dict]

    base_line_tag: List[str]  # tags selecting the baseline data batches
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import List, Optional, TypedDict, Dict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ProjectConfig(TypedDict):
    """Configuration keys required to describe a project.

    Note: per PEP 589, TypedDict items may not carry default values; the
    original "= None" assignments were ignored at runtime and flagged by
    type checkers, so they have been removed.
    """

    project_type: str
    model_name: Optional[str]
    unique_identifier: str  # column uniquely identifying each row
    true_label: str  # ground-truth label column
    tag: str  # tag assigned to the uploaded data
    pred_label: Optional[str]  # predicted-label column, if available
    feature_exclude: Optional[List[str]]  # features to drop before training
    drop_duplicate_uid: Optional[bool]
    handle_errors: Optional[bool]
    feature_encodings: Optional[dict]  # per-feature encoding overrides
    handle_data_imbalance: Optional[bool]
    explainability_method: Optional[List[str]]
|
|
20
|
+
|
|
21
|
+
class DataConfig(TypedDict):
    """Training data configuration for tabular workloads.

    Note: per PEP 589, TypedDict items may not carry default values; the
    original "use_optuna ... = False" assignment was ignored at runtime and
    flagged by type checkers, so it has been removed.
    """

    tags: List[str]  # tags selecting the training data batches
    test_tags: Optional[List[str]]  # tags selecting a held-out test set
    use_optuna: Optional[bool]  # enable hyperparameter search
    feature_exclude: List[str]  # features to drop before training
    feature_encodings: Dict[str, str]  # per-feature encoding overrides
    drop_duplicate_uid: bool
    sample_percentage: float  # fraction of data used for training
    explainability_sample_percentage: float
    lime_explainability_iterations: int
    explainability_method: List[str]
    handle_data_imbalance: Optional[bool]
|
|
35
|
+
|
|
36
|
+
class SyntheticDataConfig(TypedDict):
    """Configuration required when generating synthetic data."""

    model_name: str  # synthetic model family, e.g. "GPT2" or "CTGAN"
    tags: List[str]  # tags selecting the source data batches
    feature_exclude: List[str]  # features to drop from generation
    feature_include: List[str]  # features to keep for generation
    feature_actual_used: List[str]  # features actually consumed by the model
    drop_duplicate_uid: bool
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class SyntheticModelHyperParams(TypedDict):
    """Common hyperparameter keys for supported synthetic models.

    The original declared "epochs" twice (once per model family); duplicate
    annotations in a class body silently overwrite each other, so it is
    declared once here, shared by both GPT2 and CTGAN.
    """

    # GPT2 hyper params
    batch_size: Optional[int]
    early_stopping_patience: Optional[int]
    early_stopping_threshold: Optional[float]
    epochs: Optional[int]  # shared by GPT2 and CTGAN
    model_type: Optional[str]
    random_state: Optional[int]
    tabular_config: Optional[str]
    train_size: Optional[float]

    # CTGAN hyper params
    test_ratio: Optional[float]
|
|
63
|
+
|
|
64
|
+
class GCSConfig(TypedDict):
    """Google Cloud Storage connector configuration.

    Mirrors the fields of a GCP service-account key file, plus the
    project identifiers used by the connector.
    """

    project_id: str
    gcp_project_name: str
    type: str  # service-account credential type field
    private_key_id: str
    private_key: str
    client_email: str
    client_id: str
    auth_uri: str
    token_uri: str
|
|
76
|
+
|
|
77
|
+
class S3Config(TypedDict):
    """Amazon S3 connector configuration.

    Note: per PEP 589, TypedDict items may not carry default values; the
    original "region ... = None" assignment was ignored at runtime and
    flagged by type checkers, so it has been removed.
    """

    region: Optional[str]
    access_key: str
    secret_key: str
|
|
83
|
+
|
|
84
|
+
class GDriveConfig(TypedDict):
    """Google Drive connector configuration.

    Mirrors the fields of a GCP service-account key file.
    """

    project_id: str
    type: str  # service-account credential type field
    private_key_id: str
    private_key: str
    client_email: str
    client_id: str
    auth_uri: str
    token_uri: str
|
|
95
|
+
|
|
96
|
+
class SFTPConfig(TypedDict):
    """SFTP connector configuration."""

    hostname: str
    # NOTE(review): annotated str rather than int — presumably the backend
    # expects the port as a string; confirm before changing.
    port: str
    username: str
    password: str
|
|
103
|
+
|
|
104
|
+
class CustomServerConfig(TypedDict):
    """Scheduling options when requesting dedicated inference compute.

    Note: per PEP 589, TypedDict items may not carry default values; the
    original "= None" / "= 1" / "= False" assignments were ignored at
    runtime and flagged by type checkers, so they have been removed.
    """

    start: Optional[datetime]  # scheduled start time
    stop: Optional[datetime]  # scheduled stop time
    shutdown_after: Optional[int]  # idle shutdown delay
    op_hours: Optional[bool]  # restrict to operating hours
    auto_start: bool
|
|
112
|
+
|
|
113
|
+
class InferenceCompute(TypedDict):
    """Inference compute selection payload.

    Note: per PEP 589, TypedDict items may not carry default values; the
    original "= CustomServerConfig()" assignment was ignored at runtime
    (and would have shared one dict instance), so it has been removed.
    """

    instance_type: str  # compute instance type identifier
    custom_server_config: Optional[CustomServerConfig]
|
|
118
|
+
|
|
119
|
+
class InferenceSettings(TypedDict):
    """Inference settings that can be applied to text models."""

    inference_engine: str  # name of the serving engine to use
|