rasa-pro 3.11.0a3__py3-none-any.whl → 3.11.0a4.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +17 -396
- rasa/api.py +4 -0
- rasa/cli/arguments/train.py +14 -0
- rasa/cli/inspect.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/endpoints.yml +7 -2
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -2
- rasa/cli/train.py +3 -0
- rasa/constants.py +2 -0
- rasa/core/actions/action.py +75 -33
- rasa/core/actions/action_repeat_bot_messages.py +72 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/channels/socketio.py +5 -1
- rasa/core/channels/voice_ready/utils.py +6 -5
- rasa/core/channels/voice_stream/browser_audio.py +1 -1
- rasa/core/channels/voice_stream/twilio_media_streams.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +19 -2
- rasa/core/persistor.py +87 -21
- rasa/core/utils.py +53 -22
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +19 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +5 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/e2e_test/utils/io.py +2 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +18 -0
- rasa/model_manager/model_api.py +469 -0
- rasa/model_manager/runner_service.py +279 -0
- rasa/model_manager/socket_bridge.py +143 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +332 -0
- rasa/model_manager/utils.py +66 -0
- rasa/model_service.py +109 -0
- rasa/model_training.py +25 -7
- rasa/shared/constants.py +6 -0
- rasa/shared/core/constants.py +2 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +15 -3
- rasa/shared/utils/yaml.py +10 -1
- rasa/utils/endpoints.py +27 -1
- rasa/version.py +1 -1
- rasa_pro-3.11.0a4.dev1.dist-info/METADATA +197 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/RECORD +48 -38
- rasa/keys +0 -1
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- rasa_pro-3.11.0a3.dist-info/METADATA +0 -576
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a3.dist-info → rasa_pro-3.11.0a4.dev1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from functools import wraps
|
|
3
|
+
import os
|
|
4
|
+
from http import HTTPStatus
|
|
5
|
+
from typing import Any, Dict, Optional, Callable
|
|
6
|
+
import dotenv
|
|
7
|
+
from sanic import Blueprint, Sanic, response
|
|
8
|
+
from sanic.response import json
|
|
9
|
+
from sanic.exceptions import NotFound
|
|
10
|
+
from sanic.request import Request
|
|
11
|
+
import structlog
|
|
12
|
+
from socketio import AsyncServer
|
|
13
|
+
|
|
14
|
+
from rasa.exceptions import ModelNotFound
|
|
15
|
+
from rasa.model_manager import config, studio_jwt_auth
|
|
16
|
+
from rasa.model_manager.config import SERVER_BASE_URL
|
|
17
|
+
from rasa.constants import MODEL_ARCHIVE_EXTENSION
|
|
18
|
+
from rasa.model_manager.runner_service import (
|
|
19
|
+
BotSession,
|
|
20
|
+
fetch_remote_model_to_dir,
|
|
21
|
+
run_bot,
|
|
22
|
+
terminate_bot,
|
|
23
|
+
update_bot_status,
|
|
24
|
+
)
|
|
25
|
+
from rasa.model_manager.socket_bridge import create_bridge_server
|
|
26
|
+
from rasa.model_manager.trainer_service import (
|
|
27
|
+
TrainingSession,
|
|
28
|
+
run_training,
|
|
29
|
+
terminate_training,
|
|
30
|
+
update_training_status,
|
|
31
|
+
)
|
|
32
|
+
from rasa.model_manager.utils import (
|
|
33
|
+
get_logs_content,
|
|
34
|
+
logs_base_path,
|
|
35
|
+
models_base_path,
|
|
36
|
+
subpath,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Populate os.environ from a local .env file, if one is present.
# NOTE(review): rasa.model_manager.config (and SERVER_BASE_URL from it) are
# imported above, before this call runs. If config reads environment
# variables at import time, values defined only in .env are not seen by it —
# confirm, and consider loading .env before those imports.
dotenv.load_dotenv()

structlogger = structlog.get_logger()


# A simple in-memory store for training sessions and running bots.
# `trainings` is keyed by training id and `running_bots` by deployment id (see
# the request handlers below). The contents live only as long as this process.
trainings: Dict[str, TrainingSession] = {}
running_bots: Dict[str, BotSession] = {}
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def prepare_working_directories() -> None:
    """Create the log and model base directories when they are missing.

    Directories that already exist are left as they are.
    """
    for base_path in (logs_base_path, models_base_path):
        os.makedirs(base_path(), exist_ok=True)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def cleanup_training_processes() -> None:
    """Terminate every training session tracked in ``trainings``."""
    structlogger.debug("model_trainer.cleanup_processes.started")
    # Work on a snapshot so the loop is unaffected if the dict changes meanwhile.
    for session in tuple(trainings.values()):
        terminate_training(session)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def cleanup_bot_processes() -> None:
    """Terminate every bot session tracked in ``running_bots``."""
    structlogger.debug("model_runner.cleanup_processes.started")
    # Work on a snapshot so the loop is unaffected if the dict changes meanwhile.
    for session in tuple(running_bots.values()):
        terminate_bot(session)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def update_status_of_all_trainings() -> None:
    """Refresh the recorded status of every tracked training session."""
    snapshot = tuple(trainings.values())
    for session in snapshot:
        update_training_status(session)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
async def update_status_of_all_bots() -> None:
    """Refresh the recorded status of every tracked bot session.

    The sessions are copied out of ``running_bots`` before looping. Each
    update is awaited, so other coroutines on the event loop may add or
    remove entries while this one is suspended. Resizing a dict during
    iteration raises a RuntimeError, which the snapshot avoids.
    """
    snapshot = tuple(running_bots.values())
    for session in snapshot:
        await update_bot_status(session)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def base_server_url(request: Request) -> str:
    """Return the base URL under which this server is reachable.

    A configured ``SERVER_BASE_URL`` wins and is returned without trailing
    slashes. Otherwise the URL is assembled from the scheme and host of the
    incoming request.
    """
    if not SERVER_BASE_URL:
        return f"{request.scheme}://{request.host}"
    return SERVER_BASE_URL.rstrip("/")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
async def continuously_update_process_status() -> None:
    """Refresh the status of all training and bot processes once per second.

    Runs until the surrounding task is cancelled. Errors raised while
    refreshing are logged and the loop carries on, so one failing update does
    not stop status tracking.

    Raises:
        asyncio.CancelledError: re-raised (after logging) when the task is
            cancelled, so the task ends in the cancelled state.
    """
    structlogger.debug("model_api.update_process_status.started")

    try:
        while True:
            try:
                update_status_of_all_trainings()
                await update_status_of_all_bots()
            except Exception as e:
                structlogger.error(
                    "model_api.update_process_status.error", error=str(e)
                )
            # Kept outside the inner try on purpose: inside a `finally`, a
            # cancellation arriving during this sleep (where the loop spends
            # most of its time) skipped the cancel handler, and a cancellation
            # caught in the body still waited another second before stopping.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        structlogger.debug("model_api.update_process_status.cancelled")
        raise
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def extract_auth_token(header: str) -> str:
    """Return the credential carried by an ``Authorization`` header value.

    A ``Bearer`` scheme prefix is removed; the scheme name is matched
    case-insensitively, as HTTP auth schemes are. A value without that scheme
    is returned unchanged, so bare tokens keep working.
    """
    scheme, _, token = header.partition(" ")
    if scheme.lower() == "bearer":
        return token.strip()
    return header


def requires_studio_auth() -> Callable:
    """Wraps a request handler with token authentication.

    The token is read from the ``Authorization`` header (see
    ``extract_auth_token``) and checked with
    ``studio_jwt_auth.authenticate_user_to_service``. When that check raises
    ``UserToServiceAuthenticationError``, a 401 JSON response is returned and
    the wrapped handler is not called.
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        async def decorated(
            request: Request, *args: Any, **kwargs: Any
        ) -> response.HTTPResponse:
            provided = extract_auth_token(request.headers.get("Authorization", ""))
            try:
                studio_jwt_auth.authenticate_user_to_service(provided)
            except studio_jwt_auth.UserToServiceAuthenticationError:
                return response.json({"message": "User not authenticated."}, status=401)
            # Only the authentication call is guarded. An exception of the
            # same type raised inside the handler must not be turned into a 401.
            return await f(request, *args, **kwargs)

        return decorated

    return decorator
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def internal_blueprint() -> Blueprint:
    """Create the blueprint serving the internal model manager API.

    Routes:
        GET    /                           overview of all bots and trainings
        GET    /training?assistant_id=...  training sessions of one assistant
        POST   /training                   start a training session
        GET    /training/<training_id>     status and logs of one training
        DELETE /training/<training_id>     terminate a training
        POST   /bot                        start a bot for a trained model
        GET    /bot                        list all bot sessions
        GET    /bot/<deployment_id>        status and logs of one bot
        DELETE /bot/<deployment_id>        terminate a bot

    Sessions are stored in the module-level ``trainings`` and
    ``running_bots`` dicts. Request middlewares cap how many trainings and
    bots may be started in parallel. All tracked processes are terminated
    before the server stops.
    """
    bp = Blueprint("model_api_internal")

    def _json_body(request: Request) -> Dict[str, Any]:
        """Return the JSON request body, or ``{}`` if it is not a JSON object."""
        # Without a body `request.json` is None. Treating that as empty lets the
        # required-field checks answer with a 400 instead of the handler
        # crashing on `None.get`.
        body = request.json
        return body if isinstance(body, dict) else {}

    def _is_start_request(request: Request, path_suffix: str) -> bool:
        """Tell whether ``request`` starts a new session at ``path_suffix``."""
        # Only POSTs start sessions; a GET on the same path merely lists them
        # and must keep working when the limit is reached. `request.path`
        # excludes the query string (unlike `request.url`), so appending
        # "?x=y" can no longer bypass the limit.
        return request.method == "POST" and request.path.endswith(path_suffix)

    def _too_many_parallel(what: str, limit: Any) -> response.HTTPResponse:
        """Build the 429 response returned once a parallelism limit is hit."""
        return response.json(
            {
                "message": f"Too many parallel {what}, above "
                f"the limit of {limit}. "
                f"Retry later or increase your server's "
                f"memory and CPU resources."
            },
            status=HTTPStatus.TOO_MANY_REQUESTS,
        )

    @bp.before_server_stop
    async def cleanup_processes(app: Sanic, loop: asyncio.AbstractEventLoop) -> None:
        """Terminate all running processes before the server stops."""
        structlogger.debug("model_api.cleanup_processes.started")
        cleanup_training_processes()
        cleanup_bot_processes()

    @bp.on_request  # type: ignore[misc]
    async def limit_parallel_training_requests(request: Request) -> Any:
        """Reject new trainings while the parallel training limit is reached."""
        if not _is_start_request(request, "/training"):
            return None

        # A training counts as active while it is marked running and its
        # process has not exited yet (poll() returns None).
        active = sum(
            1
            for training in trainings.values()
            if training.status == "running" and training.process.poll() is None
        )

        # Read at request time, like the original function-local import.
        limit = config.MAX_PARALLEL_TRAININGS
        if active >= int(limit):
            return _too_many_parallel("training requests", limit)
        return None

    @bp.on_request  # type: ignore[misc]
    async def limit_parallel_bot_runs(request: Request) -> Any:
        """Reject new bot runs while the parallel limit is reached."""
        if not _is_start_request(request, "/bot"):
            return None

        active = sum(
            1 for bot in running_bots.values() if bot.status in {"running", "queued"}
        )

        # NOTE(review): bot runs are capped by the *training* limit, as in the
        # original code. Confirm whether a dedicated bot-run limit is intended.
        limit = config.MAX_PARALLEL_TRAININGS
        if active >= int(limit):
            return _too_many_parallel("bot runs", limit)
        return None

    @bp.get("/")
    async def health(request: Request) -> response.HTTPResponse:
        """Report every known bot and training session."""
        return json(
            {
                "status": "ok",
                "bots": [
                    {
                        "deployment_id": bot.deployment_id,
                        "status": bot.status,
                        "internal_url": bot.internal_url,
                        "url": bot.url,
                    }
                    for bot in running_bots.values()
                ],
                "trainings": [
                    {
                        "training_id": training.training_id,
                        "assistant_id": training.assistant_id,
                        "client_id": training.client_id,
                        "progress": training.progress,
                        "status": training.status,
                    }
                    for training in trainings.values()
                ],
            }
        )

    @bp.get("/training")
    async def get_training_list(request: Request) -> response.HTTPResponse:
        """Return a list of all training sessions for an assistant."""
        assistant_id = request.args.get("assistant_id")
        sessions = [
            {
                "training_id": session.training_id,
                "assistant_id": session.assistant_id,
                "client_id": session.client_id,
                "progress": session.progress,
                "status": session.status,
                "model_name": session.model_name,
                "runtime_metadata": None,
            }
            for session in trainings.values()
            if session.assistant_id == assistant_id
        ]
        return json({"training_sessions": sessions, "total_number": len(sessions)})

    @bp.post("/training")
    async def start_training(request: Request) -> response.HTTPResponse:
        """Start a new training session.

        Expects a JSON object with ``id`` and ``assistant_id`` (required),
        ``client_id`` (optional) and the encoded training files under
        ``bot_config.data``. Answers 409 for a duplicate id, 400 for missing
        fields and 500 when launching the training fails.
        """
        data = _json_body(request)
        training_id: Optional[str] = data.get("id")
        assistant_id: Optional[str] = data.get("assistant_id")
        client_id: Optional[str] = data.get("client_id")
        bot_config = data.get("bot_config") or {}
        encoded_training_data: Dict[str, str] = bot_config.get("data", {})

        if training_id in trainings:
            # fail, because there apparently is already a training with this id
            return json({"message": "Training with this id already exists"}, status=409)

        if not assistant_id:
            return json({"message": "Assistant id is required"}, status=400)

        if not training_id:
            return json({"message": "Training id is required"}, status=400)

        try:
            training_session = run_training(
                training_id=training_id,
                assistant_id=assistant_id,
                client_id=client_id,
                encoded_training_data=encoded_training_data,
            )
            trainings[training_id] = training_session
            return json(
                {"training_id": training_id, "model_name": training_session.model_name}
            )
        except Exception as e:
            structlogger.error(
                "model_api.training.start_failed",
                training_id=training_id,
                error=str(e),
            )
            return json({"message": str(e)}, status=500)

    @bp.get("/training/<training_id>")
    async def get_training(request: Request, training_id: str) -> response.HTTPResponse:
        """Return the status of a training session."""
        if training := trainings.get(training_id):
            return json(
                {
                    "training_id": training_id,
                    "assistant_id": training.assistant_id,
                    "client_id": training.client_id,
                    "progress": training.progress,
                    "model_name": training.model_name,
                    "status": training.status,
                    "logs": get_logs_content(training_id),
                }
            )
        else:
            return json({"message": "Training not found"}, status=404)

    @bp.delete("/training/<training_id>")
    async def stop_training(
        request: Request, training_id: str
    ) -> response.HTTPResponse:
        """Terminate a training session; the session stays listed."""
        # this is a no-op if the training is already done
        if not (training := trainings.get(training_id)):
            return json({"message": "Training session not found"}, status=404)

        terminate_training(training)
        return json({"training_id": training_id})

    @bp.post("/bot")
    async def start_bot(request: Request) -> response.HTTPResponse:
        """Start a bot for a trained model under a new deployment id.

        Expects a JSON object with ``deployment_id`` and ``model_name``
        (required) and optional encoded configs under ``bot_config``. Answers
        409 for a duplicate deployment id, 400 for missing fields, 404 when
        the model cannot be found and 500 for any other launch failure.
        """
        data = _json_body(request)
        deployment_id: Optional[str] = data.get("deployment_id")
        model_name: Optional[str] = data.get("model_name")
        encoded_configs: Dict[str, str] = data.get("bot_config") or {}

        if deployment_id in running_bots:
            # fail, because there apparently is already a bot running with this id
            return json(
                {"message": "Bot with this deployment id already exists"}, status=409
            )

        if not deployment_id:
            return json({"message": "Deployment id is required"}, status=400)

        if not model_name:
            return json({"message": "Model name is required"}, status=400)

        base_url_path = base_server_url(request)
        try:
            bot_session = run_bot(
                deployment_id,
                model_name,
                base_url_path,
                encoded_configs,
            )
            running_bots[deployment_id] = bot_session
            return json(
                {
                    "deployment_id": deployment_id,
                    "status": bot_session.status,
                    "url": bot_session.url,
                }
            )
        except ModelNotFound:
            return json(
                {"message": f"Model with name '{model_name}' could not be found."},
                status=404,
            )
        except Exception as e:
            structlogger.error(
                "model_api.bot.start_failed",
                deployment_id=deployment_id,
                error=str(e),
            )
            return json({"message": str(e)}, status=500)

    @bp.delete("/bot/<deployment_id>")
    async def stop_bot(request: Request, deployment_id: str) -> response.HTTPResponse:
        """Terminate a bot; the session stays listed with its new status."""
        bot = running_bots.get(deployment_id)
        if bot is None:
            return json({"message": "Bot not found"}, status=404)

        terminate_bot(bot)

        return json(
            {"deployment_id": deployment_id, "status": bot.status, "url": bot.url}
        )

    @bp.get("/bot/<deployment_id>")
    async def get_bot(request: Request, deployment_id: str) -> response.HTTPResponse:
        """Return status, URL and logs of one bot session."""
        bot = running_bots.get(deployment_id)
        if bot is None:
            return json({"message": "Bot not found"}, status=404)

        return json(
            {
                "deployment_id": deployment_id,
                "status": bot.status,
                "url": bot.url,
                "logs": get_logs_content(deployment_id),
            }
        )

    @bp.get("/bot")
    async def list_bots(request: Request) -> response.HTTPResponse:
        """List every known bot session."""
        bots = [
            {
                "deployment_id": bot.deployment_id,
                "status": bot.status,
                "url": bot.url,
            }
            for bot in running_bots.values()
        ]
        return json({"deployment_sessions": bots, "total_number": len(bots)})

    return bp
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def external_blueprint() -> Blueprint:
    """Create the externally reachable blueprint of the model manager.

    Besides plain HTTP routes — a health overview and a Studio-authenticated
    model download — the blueprint carries a Socket.IO server that
    ``create_bridge_server`` wires up with the sessions in ``running_bots``.
    """
    from rasa.core.channels.socketio import SocketBlueprint

    socket_server = AsyncServer(async_mode="sanic", cors_allowed_origins=[])
    blueprint = SocketBlueprint(socket_server, "", "model_api_external")

    create_bridge_server(socket_server, running_bots)

    def model_not_found() -> response.HTTPResponse:
        return json({"message": "Model not found"}, status=404)

    @blueprint.get("/health")
    async def health(request: Request) -> response.HTTPResponse:
        bot_entries = [
            {
                "deployment_id": session.deployment_id,
                "status": session.status,
                "internal_url": session.internal_url,
                "url": session.url,
            }
            for session in running_bots.values()
        ]
        training_entries = [
            {
                "training_id": session.training_id,
                "assistant_id": session.assistant_id,
                "client_id": session.client_id,
                "progress": session.progress,
                "status": session.status,
            }
            for session in trainings.values()
        ]
        return json(
            {"status": "ok", "bots": bot_entries, "trainings": training_entries}
        )

    @blueprint.route("/models/<model_name>")
    @requires_studio_auth()
    async def send_model(request: Request, model_name: str) -> response.HTTPResponse:
        # Every flavour of "not found" — no path, a missing file, or an
        # unknown model — maps to the same 404 response.
        try:
            model_path = path_to_model(model_name)
            if model_path:
                return await response.file(model_path)
            return model_not_found()
        except (NotFound, ModelNotFound):
            return model_not_found()

    return blueprint
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def path_to_model(model_name: str) -> Optional[str]:
    """Locate the archive of ``model_name``, fetching it remotely if needed.

    The local models directory is checked first. If the archive is missing
    there and ``config.SERVER_MODEL_REMOTE_STORAGE`` is set, the result of
    ``fetch_remote_model_to_dir`` is returned (presumably the local path of
    the downloaded archive — confirm against runner_service). ``None`` is
    returned when the archive is absent locally and no remote storage is
    configured.
    """
    archive_name = f"{model_name}.{MODEL_ARCHIVE_EXTENSION}"
    local_path = subpath(models_base_path(), archive_name)

    if os.path.exists(local_path):
        return local_path

    remote_storage = config.SERVER_MODEL_REMOTE_STORAGE
    if not remote_storage:
        return None

    structlogger.info(
        "model_api.storage.fetching_remote_model",
        model_name=archive_name,
    )
    return fetch_remote_model_to_dir(
        archive_name,
        models_base_path(),
        remote_storage,
    )
|