llmstudio 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llmstudio-0.3.2 → llmstudio-0.3.4}/PKG-INFO +3 -2
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/cli.py +16 -40
- llmstudio-0.3.4/llmstudio/config.py +47 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/__init__.py +9 -7
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/azure.py +39 -19
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/provider.py +3 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/llm/__init__.py +8 -7
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/llm/langchain.py +1 -1
- llmstudio-0.3.4/llmstudio/tracking/__init__.py +64 -0
- llmstudio-0.3.4/llmstudio/tracking/database.py +25 -0
- llmstudio-0.3.4/llmstudio/tracking/logs/__init__.py +0 -0
- {llmstudio-0.3.2/llmstudio/tracking → llmstudio-0.3.4/llmstudio/tracking/logs}/crud.py +2 -1
- llmstudio-0.3.4/llmstudio/tracking/logs/endpoints.py +57 -0
- {llmstudio-0.3.2/llmstudio/tracking → llmstudio-0.3.4/llmstudio/tracking/logs}/models.py +1 -2
- llmstudio-0.3.4/llmstudio/tracking/session/__init__.py +0 -0
- llmstudio-0.3.4/llmstudio/tracking/session/crud.py +49 -0
- llmstudio-0.3.4/llmstudio/tracking/session/endpoints.py +46 -0
- llmstudio-0.3.4/llmstudio/tracking/session/models.py +16 -0
- llmstudio-0.3.4/llmstudio/tracking/session/schemas.py +20 -0
- llmstudio-0.3.4/llmstudio/tracking/tracker.py +47 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/__init__.py +6 -6
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/DataTable.tsx +6 -1
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/columns.tsx +3 -2
- {llmstudio-0.3.2 → llmstudio-0.3.4}/pyproject.toml +3 -2
- llmstudio-0.3.2/llmstudio/.env.template +0 -12
- llmstudio-0.3.2/llmstudio/tracking/__init__.py +0 -96
- llmstudio-0.3.2/llmstudio/tracking/database.py +0 -27
- llmstudio-0.3.2/llmstudio/tracking/tracker.py +0 -24
- {llmstudio-0.3.2 → llmstudio-0.3.4}/LICENSE +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/README.md +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/__init__.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/client.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/config.yaml +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/__init__.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/anthropic.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/ollama.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/engine/providers/openai.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/tests/__init__.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/tests/conftest.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/tests/engine/test_engine.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/tests/engine/test_providers.py +0 -0
- {llmstudio-0.3.2/llmstudio/tracking → llmstudio-0.3.4/llmstudio/tracking/logs}/schemas.py +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/.eslintrc.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/.gitignore +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/.prettierrc.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/components.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/global.d.ts +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/next.config.js +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/package-lock.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/package.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/postcss.config.js +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/logo.json +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/ai.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/arrow.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/compare.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/home.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/load.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/magic.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/play.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/playground.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/plus.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/settings.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/public/svg/sparkles.svg +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/compare/page.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/dashboard/hooks/useDashboardFetch.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/dashboard/page.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/ColumnHeader.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/FacetedFilter.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/Pagination.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/RowActions.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/Toolbar.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/UserNav.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/DataTable/ViewOptions.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/Input.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/LogSheet.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/ModelItem.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/ModelSelector.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/Output.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/Parameters.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/components/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/hooks/useChat.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/hooks/useExport.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/hooks/useLogsFetch.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/hooks/useModelFetch.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/hooks/useParameterFetch.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/page.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/(llm)/playground/store.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/favicon.ico +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/globals.css +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/layout.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/app/page.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/CodeBlock/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/CopyButton/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/Header/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/Markdown/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/Theme/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/Toaster/index.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/theme-provider.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/accordion.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/alert-dialog.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/alert.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/aspect-ratio.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/avatar.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/badge.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/button.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/calendar.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/card.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/checkbox.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/collapsible.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/command.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/context-menu.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/dialog.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/dropdown-menu.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/form.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/hover-card.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/input.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/label.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/menubar.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/navigation-menu.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/popover.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/progress.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/radio-group.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/scroll-area.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/select.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/separator.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/sheet.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/skeleton.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/slider.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/switch.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/table.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/tabs.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/textarea.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/toast.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/toaster.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/toggle-group.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/toggle.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/tooltip.tsx +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/components/ui/use-toast.ts +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/src/lib/utils.ts +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/tailwind.config.js +0 -0
- {llmstudio-0.3.2 → llmstudio-0.3.4}/llmstudio/ui/tsconfig.json +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: llmstudio
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: Prompt Perfection at Your Fingertips
|
|
5
5
|
Home-page: https://llmstudio.ai/
|
|
6
6
|
License: MIT
|
|
@@ -16,7 +16,8 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
17
|
Requires-Dist: aiohttp (>=3.9.1,<4.0.0)
|
|
18
18
|
Requires-Dist: anthropic (>=0.16.0,<0.17.0)
|
|
19
|
-
Requires-Dist: fastapi (>=0.
|
|
19
|
+
Requires-Dist: fastapi (>=0.109.1,<0.110.0)
|
|
20
|
+
Requires-Dist: langchain-experimental (>=0.0.52,<0.0.53)
|
|
20
21
|
Requires-Dist: openai (>=1.6.1,<2.0.0)
|
|
21
22
|
Requires-Dist: pydantic (>=2.5.3,<3.0.0)
|
|
22
23
|
Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
|
|
@@ -1,35 +1,22 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import signal
|
|
3
|
-
import socket
|
|
4
3
|
from threading import Thread
|
|
5
4
|
|
|
6
5
|
import click
|
|
7
6
|
import requests
|
|
8
|
-
from dotenv import load_dotenv
|
|
9
7
|
|
|
8
|
+
from llmstudio.config import (
|
|
9
|
+
ENGINE_HOST,
|
|
10
|
+
ENGINE_PORT,
|
|
11
|
+
TRACKING_HOST,
|
|
12
|
+
TRACKING_PORT,
|
|
13
|
+
UI_HOST,
|
|
14
|
+
UI_PORT,
|
|
15
|
+
)
|
|
10
16
|
from llmstudio.engine import run_engine_app
|
|
11
17
|
from llmstudio.tracking import run_tracking_app
|
|
12
18
|
from llmstudio.ui import run_ui_app
|
|
13
19
|
|
|
14
|
-
load_dotenv(os.path.join(os.getcwd(), ".env"))
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def assign_port():
|
|
18
|
-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
19
|
-
s.bind(("", 0))
|
|
20
|
-
return s.getsockname()[1]
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
os.environ["LLMSTUDIO_ENGINE_PORT"] = str(assign_port())
|
|
24
|
-
os.environ["NEXT_PUBLIC_LLMSTUDIO_ENGINE_PORT"] = os.environ.get(
|
|
25
|
-
"LLMSTUDIO_ENGINE_PORT"
|
|
26
|
-
)
|
|
27
|
-
os.environ["LLMSTUDIO_TRACKING_PORT"] = str(assign_port())
|
|
28
|
-
os.environ["NEXT_PUBLIC_LLMSTUDIO_TRACKING_PORT"] = os.environ.get(
|
|
29
|
-
"LLMSTUDIO_TRACKING_PORT"
|
|
30
|
-
)
|
|
31
|
-
os.environ["LLMSTUDIO_UI_PORT"] = str(assign_port())
|
|
32
|
-
|
|
33
20
|
|
|
34
21
|
def is_server_running(host, port, path="/health"):
|
|
35
22
|
try:
|
|
@@ -42,16 +29,11 @@ def is_server_running(host, port, path="/health"):
|
|
|
42
29
|
|
|
43
30
|
|
|
44
31
|
def start_server():
|
|
45
|
-
|
|
46
|
-
tracking_port = int(os.environ.get("LLMSTUDIO_TRACKING_PORT"))
|
|
47
|
-
engine_host = os.environ.get("LLMSTUDIO_ENGINE_HOST", "localhost")
|
|
48
|
-
tracking_host = os.environ.get("LLMSTUDIO_TRACKING_HOST", "localhost")
|
|
49
|
-
|
|
50
|
-
if not is_server_running(engine_host, engine_port):
|
|
32
|
+
if not is_server_running(ENGINE_HOST, ENGINE_PORT):
|
|
51
33
|
engine_thread = Thread(target=run_engine_app, daemon=True)
|
|
52
34
|
engine_thread.start()
|
|
53
35
|
|
|
54
|
-
if not is_server_running(
|
|
36
|
+
if not is_server_running(TRACKING_HOST, TRACKING_PORT):
|
|
55
37
|
tracking_thread = Thread(target=run_tracking_app, daemon=True)
|
|
56
38
|
tracking_thread.start()
|
|
57
39
|
|
|
@@ -77,34 +59,28 @@ def server(ui):
|
|
|
77
59
|
# Register the signal handler
|
|
78
60
|
signal.signal(signal.SIGINT, handle_shutdown)
|
|
79
61
|
|
|
80
|
-
engine_host = os.getenv("LLMSTUDIO_ENGINE_HOST", "localhost")
|
|
81
|
-
tracking_host = os.getenv("LLMSTUDIO_TRACKING_HOST", "localhost")
|
|
82
|
-
engine_port = int(os.getenv("LLMSTUDIO_ENGINE_PORT"))
|
|
83
|
-
tracking_port = int(os.getenv("LLMSTUDIO_TRACKING_PORT"))
|
|
84
|
-
|
|
85
62
|
# Start the engine if it's not already running
|
|
86
|
-
if not is_server_running(
|
|
63
|
+
if not is_server_running(ENGINE_HOST, ENGINE_PORT):
|
|
87
64
|
engine_thread = Thread(target=run_engine_app, daemon=True)
|
|
88
65
|
engine_thread.start()
|
|
89
66
|
else:
|
|
90
|
-
print(f"Engine server already running on {
|
|
67
|
+
print(f"Engine server already running on {ENGINE_HOST}:{ENGINE_PORT}")
|
|
91
68
|
|
|
92
69
|
# Start the tracking if it's not already running
|
|
93
|
-
if not is_server_running(
|
|
70
|
+
if not is_server_running(TRACKING_HOST, TRACKING_PORT):
|
|
94
71
|
tracking_thread = Thread(target=run_tracking_app, daemon=True)
|
|
95
72
|
tracking_thread.start()
|
|
96
73
|
else:
|
|
97
|
-
print(f"Tracking server already running on {
|
|
74
|
+
print(f"Tracking server already running on {TRACKING_HOST}:{TRACKING_PORT}")
|
|
98
75
|
|
|
99
76
|
# Start the UI if requested and not already running
|
|
100
77
|
if ui:
|
|
101
|
-
|
|
102
|
-
if not is_server_running("localhost", ui_port):
|
|
78
|
+
if not is_server_running(UI_HOST, UI_PORT):
|
|
103
79
|
ui_thread = Thread(target=run_ui_app, daemon=True)
|
|
104
80
|
ui_thread.start()
|
|
105
81
|
ui_thread.join()
|
|
106
82
|
else:
|
|
107
|
-
print(f"UI server already running on
|
|
83
|
+
print(f"UI server already running on {UI_HOST}:{UI_PORT}")
|
|
108
84
|
|
|
109
85
|
if engine_thread:
|
|
110
86
|
engine_thread.join()
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import socket
|
|
3
|
+
|
|
4
|
+
from dotenv import load_dotenv
|
|
5
|
+
|
|
6
|
+
load_dotenv(os.path.join(os.getcwd(), ".env"))
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def assign_port(default_port=None):
|
|
10
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
11
|
+
try:
|
|
12
|
+
if default_port is not None:
|
|
13
|
+
s.bind(("", default_port))
|
|
14
|
+
return default_port
|
|
15
|
+
else:
|
|
16
|
+
s.bind(("", 0))
|
|
17
|
+
return s.getsockname()[1]
|
|
18
|
+
except OSError:
|
|
19
|
+
s.bind(("", 0))
|
|
20
|
+
return s.getsockname()[1]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
defaults = {
|
|
24
|
+
"LLMSTUDIO_ENGINE_HOST": "localhost",
|
|
25
|
+
"LLMSTUDIO_TRACKING_HOST": "localhost",
|
|
26
|
+
"LLMSTUDIO_UI_HOST": "localhost",
|
|
27
|
+
"LLMSTUDIO_ENGINE_PORT": str(assign_port(50001)),
|
|
28
|
+
"LLMSTUDIO_TRACKING_PORT": str(assign_port(50002)),
|
|
29
|
+
"LLMSTUDIO_UI_PORT": str(assign_port(50003)),
|
|
30
|
+
"LLMSTUDIO_TRACKING_URI": "sqlite:///./llmstudio_mgmt.db",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
for key, default in defaults.items():
|
|
34
|
+
os.environ[key] = os.getenv(key, default)
|
|
35
|
+
|
|
36
|
+
ENGINE_PORT = os.environ["LLMSTUDIO_ENGINE_PORT"]
|
|
37
|
+
TRACKING_PORT = os.environ["LLMSTUDIO_TRACKING_PORT"]
|
|
38
|
+
UI_PORT = os.environ["LLMSTUDIO_UI_PORT"]
|
|
39
|
+
ENGINE_HOST = os.environ["LLMSTUDIO_ENGINE_HOST"]
|
|
40
|
+
TRACKING_HOST = os.environ["LLMSTUDIO_TRACKING_HOST"]
|
|
41
|
+
UI_HOST = os.environ["LLMSTUDIO_UI_HOST"]
|
|
42
|
+
TRACKING_URI = os.environ["LLMSTUDIO_TRACKING_URI"]
|
|
43
|
+
|
|
44
|
+
os.environ["NEXT_PUBLIC_LLMSTUDIO_ENGINE_PORT"] = ENGINE_PORT
|
|
45
|
+
os.environ["NEXT_PUBLIC_LLMSTUDIO_TRACKING_PORT"] = TRACKING_PORT
|
|
46
|
+
os.environ["NEXT_PUBLIC_LLMSTUDIO_ENGINE_HOST"] = ENGINE_HOST
|
|
47
|
+
os.environ["NEXT_PUBLIC_LLMSTUDIO_TRACKING_HOST"] = TRACKING_HOST
|
|
@@ -3,7 +3,6 @@ import os
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Any, Dict, List, Optional
|
|
5
5
|
|
|
6
|
-
import requests
|
|
7
6
|
import uvicorn
|
|
8
7
|
import yaml
|
|
9
8
|
from fastapi import FastAPI, Request
|
|
@@ -11,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
|
11
10
|
from fastapi.responses import StreamingResponse
|
|
12
11
|
from pydantic import BaseModel, ValidationError
|
|
13
12
|
|
|
13
|
+
from llmstudio.config import ENGINE_HOST, ENGINE_PORT
|
|
14
14
|
from llmstudio.engine.providers import *
|
|
15
15
|
|
|
16
16
|
ENGINE_BASE_ENDPOINT = "/api/engine"
|
|
@@ -160,19 +160,21 @@ def create_engine_app(config: EngineConfig = _load_engine_config()) -> FastAPI:
|
|
|
160
160
|
iter([csv_content]), media_type="text/csv", headers=headers
|
|
161
161
|
)
|
|
162
162
|
|
|
163
|
+
@app.on_event("startup")
|
|
164
|
+
async def startup_event():
|
|
165
|
+
print(f"Running LLMstudio Engine on http://{ENGINE_HOST}:{ENGINE_PORT} ")
|
|
166
|
+
|
|
163
167
|
return app
|
|
164
168
|
|
|
165
169
|
|
|
166
170
|
def run_engine_app():
|
|
167
|
-
print(
|
|
168
|
-
f"Running Engine on http://{os.getenv('LLMSTUDIO_ENGINE_HOST')}:{os.getenv('LLMSTUDIO_ENGINE_PORT')}"
|
|
169
|
-
)
|
|
170
171
|
try:
|
|
171
172
|
engine = create_engine_app()
|
|
172
173
|
uvicorn.run(
|
|
173
174
|
engine,
|
|
174
|
-
host=
|
|
175
|
-
port=
|
|
175
|
+
host=ENGINE_HOST,
|
|
176
|
+
port=ENGINE_PORT,
|
|
177
|
+
log_level="warning",
|
|
176
178
|
)
|
|
177
179
|
except Exception as e:
|
|
178
|
-
print(f"Error running
|
|
180
|
+
print(f"Error running LLMstudio Engine: {e}")
|
|
@@ -4,9 +4,10 @@ from typing import Any, AsyncGenerator, Coroutine, Dict, Generator, List, Option
|
|
|
4
4
|
|
|
5
5
|
import openai
|
|
6
6
|
from fastapi import HTTPException
|
|
7
|
-
from openai import AzureOpenAI
|
|
7
|
+
from openai import AzureOpenAI, OpenAI
|
|
8
8
|
from pydantic import BaseModel, Field
|
|
9
9
|
|
|
10
|
+
|
|
10
11
|
from llmstudio.engine.providers.provider import ChatRequest, Provider, provider
|
|
11
12
|
|
|
12
13
|
|
|
@@ -21,6 +22,7 @@ class AzureParameters(BaseModel):
|
|
|
21
22
|
class AzureRequest(ChatRequest):
|
|
22
23
|
api_endpoint: Optional[str] = None
|
|
23
24
|
api_version: Optional[str] = None
|
|
25
|
+
base_url: Optional[str] = None
|
|
24
26
|
parameters: Optional[AzureParameters] = AzureParameters()
|
|
25
27
|
functions: Optional[List[Dict[str, Any]]] = None
|
|
26
28
|
chat_input: Any
|
|
@@ -33,6 +35,7 @@ class AzureProvider(Provider):
|
|
|
33
35
|
self.API_KEY = os.getenv("AZURE_API_KEY")
|
|
34
36
|
self.API_ENDPOINT = os.getenv("AZURE_API_ENDPOINT")
|
|
35
37
|
self.API_VERSION = os.getenv("AZURE_API_VERSION")
|
|
38
|
+
self.BASE_URL = os.getenv("AZURE_BASE_URL")
|
|
36
39
|
|
|
37
40
|
def validate_request(self, request: AzureRequest):
|
|
38
41
|
return AzureRequest(**request)
|
|
@@ -42,24 +45,41 @@ class AzureProvider(Provider):
|
|
|
42
45
|
) -> Coroutine[Any, Any, Generator]:
|
|
43
46
|
"""Generate an AzureOpenAI client"""
|
|
44
47
|
try:
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
48
|
+
if request.api_endpoint or self.API_ENDPOINT:
|
|
49
|
+
client = AzureOpenAI(
|
|
50
|
+
api_key=request.api_key or self.API_KEY,
|
|
51
|
+
azure_endpoint=request.api_endpoint or self.API_ENDPOINT,
|
|
52
|
+
api_version=request.api_version or self.API_VERSION,
|
|
53
|
+
)
|
|
54
|
+
return await asyncio.to_thread(
|
|
55
|
+
client.chat.completions.create,
|
|
56
|
+
model=request.model,
|
|
57
|
+
messages=(
|
|
58
|
+
[{"role": "user", "content": request.chat_input}]
|
|
59
|
+
if isinstance(request.chat_input, str)
|
|
60
|
+
else request.chat_input
|
|
61
|
+
),
|
|
62
|
+
functions=request.functions,
|
|
63
|
+
function_call="auto" if request.functions else None,
|
|
64
|
+
stream=True,
|
|
65
|
+
**request.parameters.model_dump(),
|
|
66
|
+
)
|
|
67
|
+
elif request.base_url or self.BASE_URL:
|
|
68
|
+
client = OpenAI(
|
|
69
|
+
api_key=request.api_key or self.API_KEY,
|
|
70
|
+
base_url=request.base_url or self.BASE_URL,
|
|
71
|
+
)
|
|
72
|
+
return await asyncio.to_thread(
|
|
73
|
+
client.chat.completions.create,
|
|
74
|
+
model=request.model,
|
|
75
|
+
messages=(
|
|
76
|
+
[{"role": "user", "content": request.chat_input}]
|
|
77
|
+
if isinstance(request.chat_input, str)
|
|
78
|
+
else request.chat_input
|
|
79
|
+
),
|
|
80
|
+
stream=True,
|
|
81
|
+
)
|
|
82
|
+
|
|
63
83
|
except openai._exceptions.APIError as e:
|
|
64
84
|
raise HTTPException(status_code=e.status_code, detail=e.response.json())
|
|
65
85
|
|
|
@@ -211,6 +211,9 @@ class Provider:
|
|
|
211
211
|
function_calls = [
|
|
212
212
|
chunk.get("choices")[0].get("delta").get("function_call")
|
|
213
213
|
for chunk in chunks[1:-1]
|
|
214
|
+
if chunk.get("choices")
|
|
215
|
+
and chunk.get("choices")[0].get("delta")
|
|
216
|
+
and chunk.get("choices")[0].get("delta").get("function_call")
|
|
214
217
|
]
|
|
215
218
|
|
|
216
219
|
if isinstance(request, AzureRequest):
|
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
1
|
import aiohttp
|
|
4
2
|
import requests
|
|
5
3
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
6
4
|
|
|
7
5
|
from llmstudio.cli import start_server
|
|
6
|
+
from llmstudio.config import ENGINE_HOST, ENGINE_PORT
|
|
8
7
|
|
|
9
8
|
|
|
10
9
|
class LLM:
|
|
@@ -15,6 +14,7 @@ class LLM:
|
|
|
15
14
|
self.api_key = kwargs.get("api_key")
|
|
16
15
|
self.api_endpoint = kwargs.get("api_endpoint")
|
|
17
16
|
self.api_version = kwargs.get("api_version")
|
|
17
|
+
self.base_url = kwargs.get("base_url")
|
|
18
18
|
self.temperature = kwargs.get("temperature")
|
|
19
19
|
self.top_p = kwargs.get("top_p")
|
|
20
20
|
self.top_k = kwargs.get("top_k")
|
|
@@ -22,13 +22,14 @@ class LLM:
|
|
|
22
22
|
|
|
23
23
|
def chat(self, input: str, is_stream: bool = False, **kwargs):
|
|
24
24
|
response = requests.post(
|
|
25
|
-
f"http://{
|
|
25
|
+
f"http://{ENGINE_HOST}:{ENGINE_PORT}/api/engine/chat/{self.provider}",
|
|
26
26
|
json={
|
|
27
27
|
"model": self.model,
|
|
28
28
|
"session_id": self.session_id,
|
|
29
29
|
"api_key": self.api_key,
|
|
30
30
|
"api_endpoint": self.api_endpoint,
|
|
31
31
|
"api_version": self.api_version,
|
|
32
|
+
"base_url": self.base_url,
|
|
32
33
|
"chat_input": input,
|
|
33
34
|
"is_stream": is_stream,
|
|
34
35
|
"parameters": {
|
|
@@ -64,7 +65,7 @@ class LLM:
|
|
|
64
65
|
async def async_non_stream(self, input: str, **kwargs):
|
|
65
66
|
async with aiohttp.ClientSession() as session:
|
|
66
67
|
async with session.post(
|
|
67
|
-
f"http://{
|
|
68
|
+
f"http://{ENGINE_HOST}:{ENGINE_PORT}/api/engine/chat/{self.provider}",
|
|
68
69
|
json={
|
|
69
70
|
"model": self.model,
|
|
70
71
|
"api_key": self.api_key,
|
|
@@ -78,12 +79,12 @@ class LLM:
|
|
|
78
79
|
) as response:
|
|
79
80
|
response.raise_for_status()
|
|
80
81
|
|
|
81
|
-
return
|
|
82
|
+
return ChatCompletion(**await response.json())
|
|
82
83
|
|
|
83
84
|
async def async_stream(self, input: str, **kwargs):
|
|
84
85
|
async with aiohttp.ClientSession() as session:
|
|
85
86
|
async with session.post(
|
|
86
|
-
f"http://{
|
|
87
|
+
f"http://{ENGINE_HOST}:{ENGINE_PORT}/api/engine/chat/{self.provider}",
|
|
87
88
|
json={
|
|
88
89
|
"model": self.model,
|
|
89
90
|
"api_key": self.api_key,
|
|
@@ -99,4 +100,4 @@ class LLM:
|
|
|
99
100
|
|
|
100
101
|
async for chunk in response.content.iter_any():
|
|
101
102
|
if chunk:
|
|
102
|
-
yield ChatCompletionChunk(**chunk.decode("utf-8"))
|
|
103
|
+
yield ChatCompletionChunk(**await chunk.decode("utf-8"))
|
|
@@ -46,7 +46,7 @@ class ChatLLMstudio(BaseChatModel):
|
|
|
46
46
|
token_usage = response.get("usage", {})
|
|
47
47
|
llm_output = {
|
|
48
48
|
"token_usage": token_usage,
|
|
49
|
-
"model_name": "
|
|
49
|
+
"model_name": response["model"],
|
|
50
50
|
"system_fingerprint": response.get("system_fingerprint", ""),
|
|
51
51
|
}
|
|
52
52
|
return ChatResult(generations=generations, llm_output=llm_output)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import uvicorn
|
|
2
|
+
from fastapi import APIRouter, FastAPI
|
|
3
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
4
|
+
|
|
5
|
+
from llmstudio.config import TRACKING_HOST, TRACKING_PORT
|
|
6
|
+
from llmstudio.engine.providers import *
|
|
7
|
+
from llmstudio.tracking.logs.endpoints import LogsRoutes
|
|
8
|
+
from llmstudio.tracking.session.endpoints import SessionsRoutes
|
|
9
|
+
|
|
10
|
+
TRACKING_HEALTH_ENDPOINT = "/health"
|
|
11
|
+
TRACKING_TITLE = "LLMstudio Tracking API"
|
|
12
|
+
TRACKING_DESCRIPTION = "The tracking API for LLM interactions"
|
|
13
|
+
TRACKING_VERSION = "0.0.1"
|
|
14
|
+
TRACKING_BASE_ENDPOINT = "/api/tracking"
|
|
15
|
+
|
|
16
|
+
## Tracking
|
|
17
|
+
def create_tracking_app() -> FastAPI:
|
|
18
|
+
app = FastAPI(
|
|
19
|
+
title=TRACKING_TITLE,
|
|
20
|
+
description=TRACKING_DESCRIPTION,
|
|
21
|
+
version=TRACKING_VERSION,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
app.add_middleware(
|
|
25
|
+
CORSMiddleware,
|
|
26
|
+
allow_origins=["*"],
|
|
27
|
+
allow_credentials=True,
|
|
28
|
+
allow_methods=["*"],
|
|
29
|
+
allow_headers=["*"],
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
@app.get(TRACKING_HEALTH_ENDPOINT)
|
|
33
|
+
def health_check():
|
|
34
|
+
"""Health check endpoint to ensure the API is running."""
|
|
35
|
+
return {"status": "healthy", "message": "Tracking is up and running"}
|
|
36
|
+
|
|
37
|
+
tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
|
|
38
|
+
LogsRoutes(tracking_router)
|
|
39
|
+
SessionsRoutes(tracking_router)
|
|
40
|
+
|
|
41
|
+
app.include_router(tracking_router)
|
|
42
|
+
|
|
43
|
+
@app.on_event("startup")
|
|
44
|
+
async def startup_event():
|
|
45
|
+
print(f"Running LLMstudio Tracking on http://{TRACKING_HOST}:{TRACKING_PORT} ")
|
|
46
|
+
|
|
47
|
+
return app
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def run_tracking_app():
|
|
51
|
+
try:
|
|
52
|
+
tracking = create_tracking_app()
|
|
53
|
+
uvicorn.run(
|
|
54
|
+
tracking,
|
|
55
|
+
host=TRACKING_HOST,
|
|
56
|
+
port=TRACKING_PORT,
|
|
57
|
+
log_level="warning",
|
|
58
|
+
)
|
|
59
|
+
except Exception as e:
|
|
60
|
+
print(f"Error running LLMstudio Tracking: {e}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
if __name__ == "__main__":
|
|
64
|
+
run_tracking_app()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from sqlalchemy import create_engine
|
|
2
|
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
|
3
|
+
|
|
4
|
+
from llmstudio.config import TRACKING_URI
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def create_tracking_engine(uri: str):
|
|
8
|
+
if uri.split("://")[0] == "sqlite":
|
|
9
|
+
return create_engine(uri, connect_args={"check_same_thread": False})
|
|
10
|
+
return create_engine(uri)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
engine = create_tracking_engine(TRACKING_URI)
|
|
14
|
+
|
|
15
|
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
16
|
+
|
|
17
|
+
Base = declarative_base()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_db():
|
|
21
|
+
db = SessionLocal()
|
|
22
|
+
try:
|
|
23
|
+
yield db
|
|
24
|
+
finally:
|
|
25
|
+
db.close()
|
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from sqlalchemy.orm import Session
|
|
2
2
|
|
|
3
|
-
from llmstudio.tracking import models, schemas
|
|
3
|
+
from llmstudio.tracking.logs import models, schemas
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def get_project_by_name(db: Session, name: str):
|
|
@@ -22,6 +22,7 @@ def add_log(db: Session, log: schemas.LogDefaultCreate):
|
|
|
22
22
|
db.add(db_log)
|
|
23
23
|
db.commit()
|
|
24
24
|
db.refresh(db_log)
|
|
25
|
+
|
|
25
26
|
return db_log
|
|
26
27
|
|
|
27
28
|
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Depends
|
|
4
|
+
from sqlalchemy.orm import Session
|
|
5
|
+
|
|
6
|
+
from llmstudio.tracking.database import engine, get_db
|
|
7
|
+
from llmstudio.tracking.logs import crud, models, schemas
|
|
8
|
+
|
|
9
|
+
models.Base.metadata.create_all(bind=engine)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LogsRoutes:
    """Register the log-tracking HTTP endpoints on a FastAPI router.

    NOTE(review): the handlers are ``async def`` but call the synchronous
    SQLAlchemy ``Session`` without awaiting, so each query blocks the event
    loop while it runs; plain ``def`` handlers would run in FastAPI's
    threadpool instead.
    """

    def __init__(self, router: APIRouter):
        self.router = router
        self.define_routes()

    def define_routes(self):
        """Attach the log endpoints to ``self.router`` in a fixed order."""
        # (registrar, path, response model, handler)
        route_table = (
            (self.router.post, "/logs", schemas.LogDefault, self.add_log),
            (self.router.get, "/logs", List[schemas.LogDefault], self.read_logs),
            (
                self.router.get,
                "/logs_by_session",
                List[schemas.LogDefault],
                self.read_logs_by_session,
            ),
        )
        for register, path, model, handler in route_table:
            register(path, response_model=model)(handler)

    async def add_log(
        self, log: schemas.LogDefaultCreate, db: Session = Depends(get_db)
    ):
        """Persist one log entry and return the stored row."""
        return crud.add_log(db=db, log=log)

    async def read_logs(
        self, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
    ):
        """Return up to ``limit`` log entries after skipping ``skip``."""
        return crud.get_logs(db, skip=skip, limit=limit)

    async def read_logs_by_session(
        self,
        session_id: str,
        skip: int = 0,
        limit: int = 1000,
        db: Session = Depends(get_db),
    ):
        """Return up to ``limit`` log entries for ``session_id``."""
        return crud.get_logs_by_session(
            db, session_id=session_id, skip=skip, limit=limit
        )
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
from sqlalchemy import JSON,
|
|
2
|
-
from sqlalchemy.orm import relationship
|
|
1
|
+
from sqlalchemy import JSON, Column, DateTime, Integer, String
|
|
3
2
|
from sqlalchemy.sql import func
|
|
4
3
|
|
|
5
4
|
from llmstudio.tracking.database import Base
|
|
File without changes
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from sqlalchemy.orm import Session
|
|
2
|
+
|
|
3
|
+
from llmstudio.tracking.session import models, schemas
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_project_by_name(db: Session, name: str):
    """Return the first ``models.Project`` row whose ``name`` matches, or None.

    NOTE(review): ``models`` here is ``llmstudio.tracking.session.models``;
    confirm it actually defines ``Project``. This function looks copied from
    the logs CRUD module and would raise AttributeError otherwise.
    """
    project_query = db.query(models.Project)
    return project_query.filter(models.Project.name == name).first()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_session_by_session_id(
    db: Session, session_id: str, skip: int = 0, limit: int = 100
):
    """Return a page of session entries for ``session_id``, oldest first.

    Args:
        db: Open SQLAlchemy session.
        session_id: Session identifier to filter on.
        skip: Number of matching rows to skip (offset).
        limit: Maximum number of rows to return.

    Returns:
        List of ``SessionDefault`` rows ordered by ``created_at`` ascending.
    """
    entry = models.SessionDefault
    matching = db.query(entry).filter(entry.session_id == session_id)
    ordered = matching.order_by(entry.created_at.asc())
    return ordered.offset(skip).limit(limit).all()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_session_by_message_id(db: Session, message_id: int):
    """Return the first session entry with ``message_id``, or None if absent."""
    entry = models.SessionDefault
    lookup = db.query(entry).filter(entry.message_id == message_id)
    return lookup.first()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def add_session(db: Session, session: schemas.SessionDefaultCreate):
    """Insert a new session entry built from ``session`` and return it.

    The row is committed immediately and then refreshed, which reloads its
    attributes from the database before it is returned.
    """
    fields = session.dict()
    new_entry = models.SessionDefault(**fields)
    db.add(new_entry)
    db.commit()
    db.refresh(new_entry)
    return new_entry
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def update_session(db: Session, message_id: int, extras: dict):
    """Replace the ``extras`` of the session entry identified by ``message_id``.

    Args:
        db: Open SQLAlchemy session; the change is committed here.
        message_id: ``message_id`` of the entry to update.
        extras: New value assigned to the entry's ``extras`` attribute.

    Returns:
        The updated ``SessionDefault`` row, refreshed after commit.

    Raises:
        LookupError: If no entry with ``message_id`` exists.
    """
    existing_session = get_session_by_message_id(db, message_id)
    if existing_session is None:
        # Fail with a descriptive error rather than an AttributeError on None.
        raise LookupError(f"No session found for message_id={message_id}")
    existing_session.extras = extras
    db.commit()
    db.refresh(existing_session)
    return existing_session
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def upsert_session(db: Session, session: schemas.SessionDefaultCreate):
    """Store ``session`` and return the persisted row.

    NOTE(review): despite the name, this always inserts a new row. It simply
    delegates to ``add_session``; no existing entry is looked up or updated.
    """
    return add_session(db=db, session=session)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Depends
|
|
4
|
+
from sqlalchemy.orm import Session
|
|
5
|
+
|
|
6
|
+
from llmstudio.tracking.database import engine, get_db
|
|
7
|
+
from llmstudio.tracking.session import crud, models, schemas
|
|
8
|
+
|
|
9
|
+
# Create every table registered on the shared declarative Base that does not
# exist yet. Being a module-level statement, this runs when the module is
# first imported.
models.Base.metadata.create_all(bind=engine)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SessionsRoutes:
    """Register the session-tracking HTTP endpoints on a FastAPI router.

    NOTE(review): the handlers are ``async def`` but call the synchronous
    SQLAlchemy ``Session`` without awaiting, so each query blocks the event
    loop while it runs; plain ``def`` handlers would run in FastAPI's
    threadpool instead.
    """

    def __init__(self, router: APIRouter):
        self.router = router
        self.define_routes()

    def define_routes(self):
        """Attach the session endpoints to ``self.router``."""
        # Add session
        self.router.post(
            "/session",
            response_model=schemas.SessionDefault,
        )(self.add_session)

        # Read session
        self.router.get(
            "/session/{session_id}", response_model=List[schemas.SessionDefault]
        )(self.get_session)

        # Update the extras of a single session entry
        self.router.patch(
            "/session/{message_id}", response_model=schemas.SessionDefault
        )(self.update_session)

    async def add_session(
        self, session: schemas.SessionDefaultCreate, db: Session = Depends(get_db)
    ):
        """Store a new session entry via ``crud.upsert_session``."""
        return crud.upsert_session(db=db, session=session)

    async def update_session(
        self, message_id: int, extras: dict, db: Session = Depends(get_db)
    ):
        """Replace ``extras`` on the session entry with ``message_id``.

        Raises:
            HTTPException: 404 if no entry with ``message_id`` exists.
        """
        # Local import keeps this fix self-contained; fastapi is already a
        # dependency of this module.
        from fastapi import HTTPException

        # crud.update_session assumes the entry exists; check here so the
        # client gets a 404 instead of an internal server error.
        if crud.get_session_by_message_id(db, message_id) is None:
            raise HTTPException(
                status_code=404,
                detail=f"No session found for message_id={message_id}",
            )
        return crud.update_session(db, message_id=message_id, extras=extras)

    async def get_session(self, session_id: str, db: Session = Depends(get_db)):
        """Return the session entries for ``session_id`` via the CRUD layer."""
        sessions = crud.get_session_by_session_id(db, session_id=session_id)
        return sessions
|