lexsi-sdk 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lexsi_sdk/__init__.py +5 -0
- lexsi_sdk/client/__init__.py +0 -0
- lexsi_sdk/client/client.py +176 -0
- lexsi_sdk/common/__init__.py +0 -0
- lexsi_sdk/common/config/.env.prod +3 -0
- lexsi_sdk/common/constants.py +143 -0
- lexsi_sdk/common/enums.py +8 -0
- lexsi_sdk/common/environment.py +49 -0
- lexsi_sdk/common/monitoring.py +81 -0
- lexsi_sdk/common/trigger.py +75 -0
- lexsi_sdk/common/types.py +122 -0
- lexsi_sdk/common/utils.py +93 -0
- lexsi_sdk/common/validation.py +110 -0
- lexsi_sdk/common/xai_uris.py +197 -0
- lexsi_sdk/core/__init__.py +0 -0
- lexsi_sdk/core/agent.py +62 -0
- lexsi_sdk/core/alert.py +56 -0
- lexsi_sdk/core/case.py +618 -0
- lexsi_sdk/core/dashboard.py +131 -0
- lexsi_sdk/core/guardrails/__init__.py +0 -0
- lexsi_sdk/core/guardrails/guard_template.py +299 -0
- lexsi_sdk/core/guardrails/guardrail_autogen.py +554 -0
- lexsi_sdk/core/guardrails/guardrails_langgraph.py +525 -0
- lexsi_sdk/core/guardrails/guardrails_openai.py +541 -0
- lexsi_sdk/core/guardrails/openai_runner.py +1328 -0
- lexsi_sdk/core/model_summary.py +110 -0
- lexsi_sdk/core/organization.py +549 -0
- lexsi_sdk/core/project.py +5131 -0
- lexsi_sdk/core/synthetic.py +387 -0
- lexsi_sdk/core/text.py +595 -0
- lexsi_sdk/core/tracer.py +208 -0
- lexsi_sdk/core/utils.py +36 -0
- lexsi_sdk/core/workspace.py +325 -0
- lexsi_sdk/core/wrapper.py +766 -0
- lexsi_sdk/core/xai.py +306 -0
- lexsi_sdk/version.py +34 -0
- lexsi_sdk-0.1.16.dist-info/METADATA +100 -0
- lexsi_sdk-0.1.16.dist-info/RECORD +40 -0
- lexsi_sdk-0.1.16.dist-info/WHEEL +5 -0
- lexsi_sdk-0.1.16.dist-info/top_level.txt +1 -0
lexsi_sdk/core/text.py
ADDED
@@ -0,0 +1,595 @@
from datetime import datetime
import io
from typing import Optional, List, Dict, Any, Union

import httpx
from lexsi_sdk.common.types import InferenceCompute, InferenceSettings
from lexsi_sdk.common.utils import poll_events
from lexsi_sdk.common.xai_uris import (
    AVAILABLE_GUARDRAILS_URI,
    CONFIGURE_GUARDRAILS_URI,
    DELETE_GUARDRAILS_URI,
    GET_AVAILABLE_TEXT_MODELS_URI,
    GET_GUARDRAILS_URI,
    INITIALIZE_TEXT_MODEL_URI,
    LIST_DATA_CONNECTORS,
    MESSAGES_URI,
    QUANTIZE_MODELS_URI,
    SESSIONS_URI,
    TEXT_MODEL_INFERENCE_SETTINGS_URI,
    TRACES_URI,
    UPDATE_GUARDRAILS_STATUS_URI,
    UPLOAD_DATA_FILE_URI,
    UPLOAD_DATA_URI,
    UPLOAD_FILE_DATA_CONNECTORS,
    RUN_CHAT_COMPLETION,
    RUN_IMAGE_GENERATION,
    RUN_CREATE_EMBEDDING,
    RUN_COMPLETION
)
from lexsi_sdk.core.project import Project
import pandas as pd

from lexsi_sdk.core.utils import build_list_data_connector_url
from lexsi_sdk.core.wrapper import LexsiModels, monitor
import json
import aiohttp
from typing import AsyncIterator, Iterator
import requests
from uuid import UUID

class TextProject(Project):
    """Project for text modality

    :return: TextProject
    """

    def llm_monitor(self, client, session_id=None):
        """LLM monitoring for custom clients

        :param client: client to monitor, e.g. an OpenAI client
        :param session_id: id of the session
        :return: monitored client
        """
        return monitor(project=self, client=client, session_id=session_id)
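A minimal usage sketch for llm_monitor (editorial example, not part of the packaged file; it assumes `proj` is a TextProject handle obtained from the Lexsi client, that the openai package is installed, and that the returned wrapper proxies the underlying client's interface):

    from openai import OpenAI

    client = proj.llm_monitor(OpenAI(api_key="sk-..."), session_id=None)
    # Calls made through the wrapped client are traced against the project.
    reply = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
    )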
    def sessions(self) -> pd.DataFrame:
        """All sessions

        :return: response
        """
        res = self.api_client.get(f"{SESSIONS_URI}?project_name={self.project_name}")
        if not res["success"]:
            raise Exception(res.get("details"))

        return pd.DataFrame(res.get("details"))

    def messages(self, session_id: str) -> pd.DataFrame:
        """All messages for a session

        :param session_id: id of the session
        :return: response
        """
        res = self.api_client.get(
            f"{MESSAGES_URI}?project_name={self.project_name}&session_id={session_id}"
        )
        if not res["success"]:
            raise Exception(res.get("details"))

        return pd.DataFrame(res.get("details"))

    def traces(self, trace_id: str) -> pd.DataFrame:
        """Traces generated for trace_id

        :param trace_id: id of the trace
        :return: response
        """
        res = self.api_client.get(
            f"{TRACES_URI}?project_name={self.project_name}&trace_id={trace_id}"
        )
        if not res["success"]:
            raise Exception(res.get("details"))

        return pd.DataFrame(res.get("details"))
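A sketch of browsing monitored activity with these accessors (`proj` as above; the column names used for chaining are assumptions about the returned DataFrames):

    sessions = proj.sessions()
    msgs = proj.messages(session_id=sessions.iloc[0]["session_id"])   # column name assumed
    trace = proj.traces(trace_id=msgs.iloc[0]["trace_id"])            # column name assumed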
    def guardrails(self) -> pd.DataFrame:
        """Guardrails configured in project

        :return: response
        """
        res = self.api_client.get(
            f"{GET_GUARDRAILS_URI}?project_name={self.project_name}"
        )
        if not res["success"]:
            raise Exception(res.get("details"))

        return pd.DataFrame(res.get("details"))

    def update_guardrail_status(self, guardrail_id: str, status: bool) -> str:
        """Update Guardrail Status

        :param guardrail_id: id of the guardrail
        :param status: True to set the guardrail active, False to set it inactive
        :return: response
        """
        payload = {
            "project_name": self.project_name,
            "guardrail_id": guardrail_id,
            "status": status,
        }
        res = self.api_client.post(UPDATE_GUARDRAILS_STATUS_URI, payload=payload)
        if not res["success"]:
            raise Exception(res.get("details"))

        return res.get("details")

    def delete_guardrail(self, guardrail_id: str) -> str:
        """Deletes Guardrail

        :param guardrail_id: id of the guardrail
        :return: response
        """
        payload = {
            "project_name": self.project_name,
            "guardrail_id": guardrail_id,
        }
        res = self.api_client.post(DELETE_GUARDRAILS_URI, payload=payload)
        if not res["success"]:
            raise Exception(res.get("details"))

        return res.get("details")

    def available_guardrails(self) -> pd.DataFrame:
        """Available guardrails to configure

        :return: response
        """
        res = self.api_client.get(AVAILABLE_GUARDRAILS_URI)
        if not res["success"]:
            raise Exception(res.get("details"))

        return pd.DataFrame(res.get("details"))

    def configure_guardrail(
        self,
        guardrail_name: str,
        guardrail_config: dict,
        model_name: str,
        apply_on: str,
    ) -> str:
        """Configure guardrail for project

        :param guardrail_name: name of the guardrail
        :param guardrail_config: config for the guardrail
        :param model_name: name of the model
        :param apply_on: when to apply the guardrail, "input" or "output"
        :return: response
        """
        payload = {
            "name": guardrail_name,
            "config": guardrail_config,
            "model_name": model_name,
            "apply_on": apply_on,
            "project_name": self.project_name,
        }
        res = self.api_client.post(CONFIGURE_GUARDRAILS_URI, payload)
        if not res["success"]:
            raise Exception(res.get("details"))

        return res.get("details")
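A sketch of the guardrail lifecycle (`proj` as above; the guardrail name and config keys are illustrative, and treating the returned details as the guardrail id is an assumption):

    available = proj.available_guardrails()            # browse configurable guardrails
    details = proj.configure_guardrail(
        guardrail_name="toxicity",                     # hypothetical name
        guardrail_config={"threshold": 0.8},           # hypothetical config keys
        model_name="my-model",
        apply_on="output",
    )
    proj.update_guardrail_status(guardrail_id="<guardrail-id>", status=False)
    proj.delete_guardrail(guardrail_id="<guardrail-id>")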
    def initialize_text_model(
        self,
        model_provider: str,
        model_name: str,
        model_task_type: str,
        model_type: str,
        inference_compute: InferenceCompute,
        inference_settings: InferenceSettings,
        assets: Optional[dict] = None,
    ) -> str:
        """Initialize text model

        :param model_provider: provider of the model
        :param model_name: name of the model to be initialized
        :param model_task_type: task type of the model
        :param model_type: type of the model
        :param inference_compute: compute configuration for inference
        :param inference_settings: settings used at inference time
        :param assets: optional assets to attach to the model
        :return: response
        """
        payload = {
            "model_provider": model_provider,
            "model_name": model_name,
            "model_task_type": model_task_type,
            "project_name": self.project_name,
            "model_type": model_type,
            "inference_compute": inference_compute,
            "inference_settings": inference_settings,
        }
        if assets:
            payload["assets"] = assets
        res = self.api_client.post(f"{INITIALIZE_TEXT_MODEL_URI}", payload)
        if not res["success"]:
            raise Exception(res.get("details", "Model Initialization Failed"))
        poll_events(self.api_client, self.project_name, res["event_id"])

    def model_inference_settings(
        self,
        model_name: str,
        inference_compute: InferenceCompute,
        inference_settings: InferenceSettings,
    ) -> str:
        """Update inference settings for an initialized model

        :param model_name: name of the model to update
        :param inference_compute: compute configuration for inference
        :param inference_settings: settings used at inference time
        :return: response
        """
        payload = {
            "model_name": model_name,
            "project_name": self.project_name,
            "inference_compute": inference_compute,
            "inference_settings": inference_settings,
        }

        res = self.api_client.post(f"{TEXT_MODEL_INFERENCE_SETTINGS_URI}", payload)
        if not res["success"]:
            raise Exception(res.get("details", "Failed to update inference settings"))
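A sketch of initializing a model and later updating its serving configuration (`proj` as above; all values are placeholders, and the InferenceCompute / InferenceSettings constructors live in lexsi_sdk.common.types, whose fields are not shown in this file). Note that initialize_text_model blocks in poll_events until the initialization event completes:

    compute = InferenceCompute(...)      # build per lexsi_sdk.common.types
    settings = InferenceSettings(...)    # build per lexsi_sdk.common.types
    proj.initialize_text_model(
        model_provider="huggingface",    # illustrative
        model_name="mistralai/Mistral-7B-Instruct-v0.2",   # illustrative
        model_task_type="text-generation",
        model_type="llm",                # illustrative
        inference_compute=compute,
        inference_settings=settings,
    )
    proj.model_inference_settings("mistralai/Mistral-7B-Instruct-v0.2", compute, settings)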
    def generate_text_case(
        self,
        model_name: str,
        prompt: str,
        instance_type: Optional[str] = "gova-2",
        serverless_instance_type: Optional[str] = "xsmall",
        explainability_method: Optional[list] = ["DLB"],
        explain_model: Optional[bool] = False,
        session_id: Optional[str] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        stream: Optional[bool] = False,
    ) -> dict:
        """Generate Text Case

        :param model_name: name of the model
        :param prompt: prompt for the case
        :param instance_type: instance type for the case, defaults to "gova-2"
        :param serverless_instance_type: serverless instance type, defaults to "xsmall"
        :param explainability_method: explainability method(s) for the case, defaults to ["DLB"]
        :param explain_model: whether to explain the model, defaults to False
        :param session_id: id of the session, defaults to None
        :param max_tokens: maximum tokens to generate, defaults to None
        :param min_tokens: minimum tokens to generate, defaults to None
        :param stream: whether to stream the response, defaults to False
        :return: response
        """
        llm = monitor(
            project=self,
            client=LexsiModels(project=self, api_client=self.api_client),
            session_id=session_id,
        )
        res = llm.generate_text_case(
            model_name=model_name,
            prompt=prompt,
            instance_type=instance_type,
            serverless_instance_type=serverless_instance_type,
            explainability_method=explainability_method,
            explain_model=explain_model,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            stream=stream,
        )
        return res
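A sketch of generating a case once a model is initialized (`proj` as above; the model name is illustrative, other arguments fall back to the defaults in the signature):

    res = proj.generate_text_case(
        model_name="mistralai/Mistral-7B-Instruct-v0.2",
        prompt="Summarize the attached policy in two sentences.",
        explain_model=True,
        max_tokens=256,
    )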
    def available_text_models(self) -> pd.DataFrame:
        """Get available text models

        :return: list of available text models
        """
        res = self.api_client.get(f"{GET_AVAILABLE_TEXT_MODELS_URI}")
        if not res["success"]:
            raise Exception(res.get("details", "Failed to fetch available text models"))
        return pd.DataFrame(res.get("details"))
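To see which models can be initialized, a one-line sketch (`proj` as above; the returned columns depend on the API):

    print(proj.available_text_models().head())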
    def upload_data(
        self,
        data: str | pd.DataFrame,
        tag: str,
    ) -> str:
        """Upload text data for the current project.

        :param data: File path or DataFrame containing rows to upload.
        :param tag: Tag to associate with the uploaded data.
        :return: Server response details.
        """

        def build_upload_data(data):
            """Prepare file payload from path or DataFrame.

            :param data: File path or DataFrame to convert.
            :return: Tuple or file handle suitable for multipart upload.
            """
            if isinstance(data, str):
                file = open(data, "rb")
                return file
            elif isinstance(data, pd.DataFrame):
                csv_buffer = io.BytesIO()
                data.to_csv(csv_buffer, index=False, encoding="utf-8")
                csv_buffer.seek(0)
                file_name = f"{tag}_sdk_{datetime.now().replace(microsecond=0)}.csv"
                file = (file_name, csv_buffer.getvalue())
                return file
            else:
                raise Exception("Invalid Data Type")

        def upload_file_and_return_path(data, data_type, tag=None) -> str:
            """Upload a file and return the stored path.

            :param data: Data payload (path or DataFrame).
            :param data_type: Type of data being uploaded.
            :param tag: Optional tag.
            :return: File path stored on the server.
            """
            files = {"in_file": build_upload_data(data)}
            res = self.api_client.file(
                f"{UPLOAD_DATA_FILE_URI}?project_name={self.project_name}&data_type={data_type}&tag={tag}",
                files,
            )

            if not res["success"]:
                raise Exception(res.get("details"))
            uploaded_path = res.get("metadata").get("filepath")

            return uploaded_path

        uploaded_path = upload_file_and_return_path(data, "data", tag)

        payload = {
            "path": uploaded_path,
            "tag": tag,
            "type": "data",
            "project_name": self.project_name,
        }
        res = self.api_client.post(UPLOAD_DATA_URI, payload)

        if not res["success"]:
            self.delete_file(uploaded_path)
            raise Exception(res.get("details"))

        return res.get("details")
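A sketch of uploading project data from either an in-memory DataFrame or a CSV path (`proj` as above; file name and tags are illustrative):

    import pandas as pd

    df = pd.DataFrame({"text": ["first example", "second example"]})
    proj.upload_data(df, tag="training-v1")                  # DataFrame is serialized to CSV in memory
    proj.upload_data("local/data.csv", tag="training-v2")    # or pass a local file path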
    def upload_data_dataconnectors(
        self,
        data_connector_name: str,
        tag: str,
        bucket_name: Optional[str] = None,
        file_path: Optional[str] = None,
        dataset_name: Optional[str] = None,
    ):
        """Upload text data stored in a configured data connector.

        :param data_connector_name: Name of the configured connector.
        :param tag: Tag to associate with uploaded data.
        :param bucket_name: Bucket/location name when required by connector.
        :param file_path: File path within the connector store.
        :param dataset_name: Optional dataset name to persist.
        :return: Server response details.
        """

        def get_connector() -> str | pd.DataFrame:
            """Fetch connector metadata for the requested link service.

            :return: DataFrame of connector info or error string.
            """
            url = build_list_data_connector_url(
                LIST_DATA_CONNECTORS, self.project_name, self.organization_id
            )
            res = self.api_client.post(url)

            if res["success"]:
                df = pd.DataFrame(res["details"])
                filtered_df = df.loc[df["link_service_name"] == data_connector_name]
                if filtered_df.empty:
                    return "No data connector found"
                return filtered_df

            return res["details"]

        connectors = get_connector()
        if isinstance(connectors, pd.DataFrame):
            value = connectors.loc[
                connectors["link_service_name"] == data_connector_name,
                "link_service_type",
            ].values[0]
            ds_type = value

            if ds_type == "s3" or ds_type == "gcs":
                if not bucket_name:
                    return "Missing argument bucket_name"
                if not file_path:
                    return "Missing argument file_path"
        else:
            return connectors

        def upload_file_and_return_path(file_path, data_type, tag=None) -> str:
            """Upload a file from connector storage and return stored path.

            :param file_path: Path within the connector store.
            :param data_type: Type of data being uploaded.
            :param tag: Optional tag for the upload.
            :return: Stored file path returned by the API.
            """
            if not self.project_name:
                return "Missing Project Name"
            query_params = f"project_name={self.project_name}&link_service_name={data_connector_name}&data_type={data_type}&tag={tag}&bucket_name={bucket_name}&file_path={file_path}&dataset_name={dataset_name}"
            if self.organization_id:
                query_params += f"&organization_id={self.organization_id}"
            res = self.api_client.post(f"{UPLOAD_FILE_DATA_CONNECTORS}?{query_params}")
            if not res["success"]:
                raise Exception(res.get("details"))
            uploaded_path = res.get("metadata").get("filepath")

            return uploaded_path

        uploaded_path = upload_file_and_return_path(file_path, "data", tag)

        payload = {
            "path": uploaded_path,
            "tag": tag,
            "type": "data",
            "project_name": self.project_name,
        }
        res = self.api_client.post(UPLOAD_DATA_URI, payload)

        if not res["success"]:
            self.delete_file(uploaded_path)
            raise Exception(res.get("details"))

        return res.get("details")
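A sketch of pulling data in from a configured S3 connector (`proj` as above; connector, bucket, and path names are illustrative; bucket_name and file_path are required for s3/gcs connectors per the checks above):

    proj.upload_data_dataconnectors(
        data_connector_name="my-s3-connector",
        tag="training-v3",
        bucket_name="my-bucket",
        file_path="datasets/train.csv",
    )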
    def quantize_model(
        self,
        model_name: str,
        quant_name: str,
        quantization_type: str,
        qbit: int,
        instance_type: str,
        tag: Optional[str] = None,
        input_column: Optional[str] = None,
        no_of_samples: Optional[str] = None,
    ):
        """Quantize Model

        :param model_name: name of the model
        :param quant_name: name for the quantized model
        :param quantization_type: type of quantization
        :param qbit: quantization bit width
        :param instance_type: instance type for the quantization
        :param tag: optional data tag to use for quantization
        :param input_column: input column in the data
        :param no_of_samples: number of samples to use for quantization
        :return: response
        """
        payload = {
            "project_name": self.project_name,
            "model_name": model_name,
            "quant_name": quant_name,
            "quantization_type": quantization_type,
            "qbit": qbit,
            "instance_type": instance_type,
            "tag": tag,
            "input_column": input_column,
            "no_of_samples": no_of_samples,
        }

        res = self.api_client.post(QUANTIZE_MODELS_URI, payload)
        if not res["success"]:
            raise Exception(res.get("details"))

        poll_events(self.api_client, self.project_name, res.get("event_id"))
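A sketch of quantizing an initialized model (`proj` as above; the quantization_type and instance_type values are illustrative and not confirmed by the SDK; no_of_samples is passed as a string to match the annotation). The call blocks in poll_events until the quantization event completes:

    proj.quantize_model(
        model_name="mistralai/Mistral-7B-Instruct-v0.2",
        quant_name="mistral-7b-int4",
        quantization_type="gptq",        # illustrative
        qbit=4,
        instance_type="gova-2",          # illustrative
        tag="training-v1",
        input_column="text",
        no_of_samples="128",
    )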
    def chat_completion(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        provider: str,
        api_key: Optional[str] = None,
        session_id: Optional[UUID] = None,
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = False,
    ) -> Union[dict, Iterator[str]]:
        """Chat completion endpoint wrapper

        :param model: name of the model
        :param messages: list of chat messages
        :param provider: model provider (e.g., "openai", "anthropic")
        :param api_key: API key for the provider
        :param session_id: id of the session
        :param max_tokens: maximum tokens to generate
        :param stream: whether to stream the response
        :return: chat completion response or stream iterator
        """
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": stream,
            "project_name": self.project_name,
            "provider": provider,
            "api_key": api_key,
            "session_id": session_id,
        }

        if not stream:
            return self.api_client.post(RUN_CHAT_COMPLETION, payload=payload)

        return self.api_client.stream(uri=RUN_CHAT_COMPLETION, method="POST", payload=payload)
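A sketch of the chat wrapper in both modes (`proj` as above; model, provider, and key are placeholders; with stream=True the method returns an iterator of text chunks):

    res = proj.chat_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hi"}],
        provider="openai",
        api_key="sk-...",
        max_tokens=128,
    )
    for chunk in proj.chat_completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hi"}],
        provider="openai",
        api_key="sk-...",
        stream=True,
    ):
        print(chunk, end="")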
    def create_embeddings(
        self,
        input: Union[str, List[str]],
        model: str,
        api_key: str,
        provider: str,
        session_id: Optional[UUID] = None,
    ) -> dict:
        """Embedding endpoint wrapper

        :param input: text or list of texts to embed
        :param model: name of the embedding model
        :param api_key: API key for the provider
        :param provider: model provider (e.g., "openai")
        :param session_id: id of the session
        :return: embedding response
        """
        payload = {
            "model": model,
            "input": input,
            "project_name": self.project_name,
            "provider": provider,
            "api_key": api_key,
            "session_id": session_id,
        }

        res = self.api_client.post(RUN_CREATE_EMBEDDING, payload=payload)
        return res
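A sketch of the embeddings wrapper (`proj` as above; the model name is illustrative):

    res = proj.create_embeddings(
        input=["first sentence", "second sentence"],
        model="text-embedding-3-small",
        api_key="sk-...",
        provider="openai",
    )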
    def completion(
        self,
        model: str,
        prompt: str,
        provider: str,
        api_key: Optional[str] = None,
        session_id: Optional[UUID] = None,
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = False,
    ) -> dict:
        """Text completion endpoint wrapper

        :param model: name of the model
        :param prompt: completion prompt
        :param provider: model provider (e.g., "openai")
        :param api_key: API key for the provider
        :param session_id: id of the session
        :param max_tokens: maximum tokens to generate
        :param stream: whether to stream the response
        :return: completion response or stream iterator
        """
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "stream": stream,
            "project_name": self.project_name,
            "provider": provider,
            "api_key": api_key,
            "session_id": session_id,
        }
        if not stream:
            return self.api_client.post(RUN_COMPLETION, payload=payload)

        return self.api_client.stream(uri=RUN_COMPLETION, method="POST", payload=payload)
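A sketch of the plain completion wrapper (`proj` as above; values are placeholders; stream=True returns the same style of iterator as chat_completion):

    res = proj.completion(
        model="gpt-3.5-turbo-instruct",   # illustrative
        prompt="Once upon a time",
        provider="openai",
        api_key="sk-...",
        max_tokens=64,
    )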
    def image_generation(
        self,
        model: str,
        prompt: str,
        provider: str,
        api_key: str,
        session_id: Optional[UUID] = None,
    ) -> dict:
        """Image generation endpoint wrapper

        :param model: name of the model
        :param prompt: image generation prompt
        :param provider: model provider (e.g., "openai", "stability")
        :param api_key: API key for the provider
        :param session_id: id of the session
        :return: image generation response
        """
        payload = {
            "model": model,
            "prompt": prompt,
            "project_name": self.project_name,
            "provider": provider,
            "api_key": api_key,
            "session_id": session_id,
        }

        res = self.api_client.post(RUN_IMAGE_GENERATION, payload=payload)

        return res
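A sketch of the image generation wrapper (`proj` as above; the model name is illustrative):

    res = proj.image_generation(
        model="dall-e-3",
        prompt="A lighthouse at dusk, watercolor",
        provider="openai",
        api_key="sk-...",
    )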