spaik-sdk 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spaik_sdk/__init__.py +21 -0
- spaik_sdk/agent/__init__.py +0 -0
- spaik_sdk/agent/base_agent.py +249 -0
- spaik_sdk/attachments/__init__.py +22 -0
- spaik_sdk/attachments/builder.py +61 -0
- spaik_sdk/attachments/file_storage_provider.py +27 -0
- spaik_sdk/attachments/mime_types.py +118 -0
- spaik_sdk/attachments/models.py +63 -0
- spaik_sdk/attachments/provider_support.py +53 -0
- spaik_sdk/attachments/storage/__init__.py +0 -0
- spaik_sdk/attachments/storage/base_file_storage.py +32 -0
- spaik_sdk/attachments/storage/impl/__init__.py +0 -0
- spaik_sdk/attachments/storage/impl/local_file_storage.py +101 -0
- spaik_sdk/audio/__init__.py +12 -0
- spaik_sdk/audio/options.py +53 -0
- spaik_sdk/audio/providers/__init__.py +1 -0
- spaik_sdk/audio/providers/google_tts.py +77 -0
- spaik_sdk/audio/providers/openai_stt.py +71 -0
- spaik_sdk/audio/providers/openai_tts.py +111 -0
- spaik_sdk/audio/stt.py +61 -0
- spaik_sdk/audio/tts.py +124 -0
- spaik_sdk/config/credentials_provider.py +10 -0
- spaik_sdk/config/env.py +59 -0
- spaik_sdk/config/env_credentials_provider.py +7 -0
- spaik_sdk/config/get_credentials_provider.py +14 -0
- spaik_sdk/image_gen/__init__.py +9 -0
- spaik_sdk/image_gen/image_generator.py +83 -0
- spaik_sdk/image_gen/options.py +24 -0
- spaik_sdk/image_gen/providers/__init__.py +0 -0
- spaik_sdk/image_gen/providers/google.py +75 -0
- spaik_sdk/image_gen/providers/openai.py +60 -0
- spaik_sdk/llm/__init__.py +0 -0
- spaik_sdk/llm/cancellation_handle.py +10 -0
- spaik_sdk/llm/consumption/__init__.py +0 -0
- spaik_sdk/llm/consumption/consumption_estimate.py +26 -0
- spaik_sdk/llm/consumption/consumption_estimate_builder.py +113 -0
- spaik_sdk/llm/consumption/consumption_extractor.py +59 -0
- spaik_sdk/llm/consumption/token_usage.py +31 -0
- spaik_sdk/llm/converters.py +146 -0
- spaik_sdk/llm/cost/__init__.py +1 -0
- spaik_sdk/llm/cost/builtin_cost_provider.py +83 -0
- spaik_sdk/llm/cost/cost_estimate.py +8 -0
- spaik_sdk/llm/cost/cost_provider.py +28 -0
- spaik_sdk/llm/extract_error_message.py +37 -0
- spaik_sdk/llm/langchain_loop_manager.py +270 -0
- spaik_sdk/llm/langchain_service.py +196 -0
- spaik_sdk/llm/message_handler.py +188 -0
- spaik_sdk/llm/streaming/__init__.py +1 -0
- spaik_sdk/llm/streaming/block_manager.py +152 -0
- spaik_sdk/llm/streaming/models.py +42 -0
- spaik_sdk/llm/streaming/streaming_content_handler.py +157 -0
- spaik_sdk/llm/streaming/streaming_event_handler.py +215 -0
- spaik_sdk/llm/streaming/streaming_state_manager.py +58 -0
- spaik_sdk/models/__init__.py +0 -0
- spaik_sdk/models/factories/__init__.py +0 -0
- spaik_sdk/models/factories/anthropic_factory.py +33 -0
- spaik_sdk/models/factories/base_model_factory.py +71 -0
- spaik_sdk/models/factories/google_factory.py +30 -0
- spaik_sdk/models/factories/ollama_factory.py +41 -0
- spaik_sdk/models/factories/openai_factory.py +50 -0
- spaik_sdk/models/llm_config.py +46 -0
- spaik_sdk/models/llm_families.py +7 -0
- spaik_sdk/models/llm_model.py +17 -0
- spaik_sdk/models/llm_wrapper.py +25 -0
- spaik_sdk/models/model_registry.py +156 -0
- spaik_sdk/models/providers/__init__.py +0 -0
- spaik_sdk/models/providers/anthropic_provider.py +29 -0
- spaik_sdk/models/providers/azure_provider.py +31 -0
- spaik_sdk/models/providers/base_provider.py +62 -0
- spaik_sdk/models/providers/google_provider.py +26 -0
- spaik_sdk/models/providers/ollama_provider.py +26 -0
- spaik_sdk/models/providers/openai_provider.py +26 -0
- spaik_sdk/models/providers/provider_type.py +90 -0
- spaik_sdk/orchestration/__init__.py +24 -0
- spaik_sdk/orchestration/base_orchestrator.py +238 -0
- spaik_sdk/orchestration/checkpoint.py +80 -0
- spaik_sdk/orchestration/models.py +103 -0
- spaik_sdk/prompt/__init__.py +0 -0
- spaik_sdk/prompt/get_prompt_loader.py +13 -0
- spaik_sdk/prompt/local_prompt_loader.py +21 -0
- spaik_sdk/prompt/prompt_loader.py +48 -0
- spaik_sdk/prompt/prompt_loader_mode.py +14 -0
- spaik_sdk/py.typed +1 -0
- spaik_sdk/recording/__init__.py +1 -0
- spaik_sdk/recording/base_playback.py +90 -0
- spaik_sdk/recording/base_recorder.py +50 -0
- spaik_sdk/recording/conditional_recorder.py +38 -0
- spaik_sdk/recording/impl/__init__.py +1 -0
- spaik_sdk/recording/impl/local_playback.py +76 -0
- spaik_sdk/recording/impl/local_recorder.py +85 -0
- spaik_sdk/recording/langchain_serializer.py +88 -0
- spaik_sdk/server/__init__.py +1 -0
- spaik_sdk/server/api/routers/__init__.py +0 -0
- spaik_sdk/server/api/routers/api_builder.py +149 -0
- spaik_sdk/server/api/routers/audio_router_factory.py +201 -0
- spaik_sdk/server/api/routers/file_router_factory.py +111 -0
- spaik_sdk/server/api/routers/thread_router_factory.py +284 -0
- spaik_sdk/server/api/streaming/__init__.py +0 -0
- spaik_sdk/server/api/streaming/format_sse_event.py +41 -0
- spaik_sdk/server/api/streaming/negotiate_streaming_response.py +8 -0
- spaik_sdk/server/api/streaming/streaming_negotiator.py +10 -0
- spaik_sdk/server/authorization/__init__.py +0 -0
- spaik_sdk/server/authorization/base_authorizer.py +64 -0
- spaik_sdk/server/authorization/base_user.py +13 -0
- spaik_sdk/server/authorization/dummy_authorizer.py +17 -0
- spaik_sdk/server/job_processor/__init__.py +0 -0
- spaik_sdk/server/job_processor/base_job_processor.py +8 -0
- spaik_sdk/server/job_processor/thread_job_processor.py +32 -0
- spaik_sdk/server/pubsub/__init__.py +1 -0
- spaik_sdk/server/pubsub/cancellation_publisher.py +7 -0
- spaik_sdk/server/pubsub/cancellation_subscriber.py +38 -0
- spaik_sdk/server/pubsub/event_publisher.py +13 -0
- spaik_sdk/server/pubsub/impl/__init__.py +1 -0
- spaik_sdk/server/pubsub/impl/local_cancellation_pubsub.py +48 -0
- spaik_sdk/server/pubsub/impl/signalr_publisher.py +36 -0
- spaik_sdk/server/queue/__init__.py +1 -0
- spaik_sdk/server/queue/agent_job_queue.py +27 -0
- spaik_sdk/server/queue/impl/__init__.py +1 -0
- spaik_sdk/server/queue/impl/azure_queue.py +24 -0
- spaik_sdk/server/response/__init__.py +0 -0
- spaik_sdk/server/response/agent_response_generator.py +39 -0
- spaik_sdk/server/response/response_generator.py +13 -0
- spaik_sdk/server/response/simple_agent_response_generator.py +14 -0
- spaik_sdk/server/services/__init__.py +0 -0
- spaik_sdk/server/services/thread_converters.py +113 -0
- spaik_sdk/server/services/thread_models.py +90 -0
- spaik_sdk/server/services/thread_service.py +91 -0
- spaik_sdk/server/storage/__init__.py +1 -0
- spaik_sdk/server/storage/base_thread_repository.py +51 -0
- spaik_sdk/server/storage/impl/__init__.py +0 -0
- spaik_sdk/server/storage/impl/in_memory_thread_repository.py +100 -0
- spaik_sdk/server/storage/impl/local_file_thread_repository.py +217 -0
- spaik_sdk/server/storage/thread_filter.py +166 -0
- spaik_sdk/server/storage/thread_metadata.py +53 -0
- spaik_sdk/thread/__init__.py +0 -0
- spaik_sdk/thread/adapters/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/block_display.py +92 -0
- spaik_sdk/thread/adapters/cli/display_manager.py +84 -0
- spaik_sdk/thread/adapters/cli/live_cli.py +235 -0
- spaik_sdk/thread/adapters/event_adapter.py +28 -0
- spaik_sdk/thread/adapters/streaming_block_adapter.py +57 -0
- spaik_sdk/thread/adapters/sync_adapter.py +76 -0
- spaik_sdk/thread/models.py +224 -0
- spaik_sdk/thread/thread_container.py +468 -0
- spaik_sdk/tools/__init__.py +0 -0
- spaik_sdk/tools/impl/__init__.py +0 -0
- spaik_sdk/tools/impl/mcp_tool_provider.py +93 -0
- spaik_sdk/tools/impl/search_tool_provider.py +18 -0
- spaik_sdk/tools/tool_provider.py +131 -0
- spaik_sdk/tracing/__init__.py +13 -0
- spaik_sdk/tracing/agent_trace.py +72 -0
- spaik_sdk/tracing/get_trace_sink.py +15 -0
- spaik_sdk/tracing/local_trace_sink.py +23 -0
- spaik_sdk/tracing/trace_sink.py +19 -0
- spaik_sdk/tracing/trace_sink_mode.py +14 -0
- spaik_sdk/utils/__init__.py +0 -0
- spaik_sdk/utils/init_logger.py +24 -0
- spaik_sdk-0.6.2.dist-info/METADATA +379 -0
- spaik_sdk-0.6.2.dist-info/RECORD +161 -0
- spaik_sdk-0.6.2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
|
4
|
+
from fastapi.responses import Response
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from spaik_sdk.attachments.storage.base_file_storage import BaseFileStorage
|
|
8
|
+
from spaik_sdk.server.authorization.base_authorizer import BaseAuthorizer
|
|
9
|
+
from spaik_sdk.server.authorization.base_user import BaseUser
|
|
10
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
11
|
+
|
|
12
|
+
# Module-level logger for this file, created through the SDK's init_logger helper.
logger = init_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FileUploadResponse(BaseModel):
    """Result returned to the client after a file has been stored."""

    # Identifier assigned by the file storage backend.
    file_id: str
    # MIME type recorded for the stored file.
    mime_type: str
    # Original upload filename; may be None.
    filename: Optional[str]
    # Stored size in bytes.
    size_bytes: int
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FileDeleteResponse(BaseModel):
    """Result of a file deletion request."""

    # Value reported by the storage backend's delete() call.
    success: bool
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class FileRouterFactory:
    """Build a FastAPI router with upload, download and delete endpoints for stored files.

    When ``authorizer`` is ``None``, every request is attributed to
    ``BaseUser("anonymous")`` and no permission checks are performed.
    """

    def __init__(
        self,
        file_storage: BaseFileStorage,
        authorizer: Optional[BaseAuthorizer[BaseUser]] = None,
    ):
        self.file_storage = file_storage
        self.authorizer = authorizer

    @staticmethod
    def _content_disposition(filename: str) -> str:
        """Return an ``inline`` Content-Disposition header value that is safe for any filename.

        A name made only of printable ASCII characters, with no double quote and no
        backslash, yields exactly ``inline; filename="<name>"``. Any other name gets:
        - a sanitized ASCII ``filename`` fallback, with offending characters replaced by ``_``;
        - an RFC 5987 ``filename*`` parameter carrying the percent-encoded UTF-8 original.

        This avoids broken quoting, header injection via CR/LF, and header-encoding
        failures for non-latin-1 names.
        """
        from urllib.parse import quote

        def is_plain(ch: str) -> bool:
            # Printable ASCII (0x20-0x7E) that needs no escaping inside a quoted-string.
            return " " <= ch <= "~" and ch not in '"\\'

        if all(is_plain(ch) for ch in filename):
            return f'inline; filename="{filename}"'

        fallback = "".join(ch if is_plain(ch) else "_" for ch in filename)
        return f"inline; filename=\"{fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"

    def create_router(self, prefix: str = "/files") -> APIRouter:
        """Create the router exposing the file endpoints under ``prefix``."""
        router = APIRouter(prefix=prefix, tags=["files"])

        async def get_current_user(request: Request) -> BaseUser:
            # Without an authorizer, all callers share a single anonymous identity.
            if self.authorizer is None:
                return BaseUser("anonymous")
            user = await self.authorizer.get_user(request)
            if not user:
                raise HTTPException(status_code=401, detail="Unauthorized")
            return user

        @router.post("", response_model=FileUploadResponse)
        async def upload_file(
            file: UploadFile = File(...),
            user: BaseUser = Depends(get_current_user),
        ):
            if self.authorizer and not await self.authorizer.can_upload_file(user):
                raise HTTPException(status_code=403, detail="Not authorized to upload files")

            # NOTE(review): the whole upload is read into memory; no size limit is enforced here.
            content = await file.read()
            mime_type = file.content_type or "application/octet-stream"

            metadata = await self.file_storage.store(
                data=content,
                mime_type=mime_type,
                owner_id=user.get_id(),
                filename=file.filename,
            )

            return FileUploadResponse(
                file_id=metadata.file_id,
                mime_type=metadata.mime_type,
                filename=metadata.filename,
                size_bytes=metadata.size_bytes,
            )

        @router.get("/{file_id}")
        async def download_file(
            file_id: str,
            user: BaseUser = Depends(get_current_user),
        ):
            metadata = await self.file_storage.get_metadata(file_id)
            if metadata is None:
                raise HTTPException(status_code=404, detail="File not found")

            if self.authorizer and not await self.authorizer.can_read_file(user, metadata):
                raise HTTPException(status_code=403, detail="Not authorized to access this file")

            # Metadata may exist while the content has gone missing; report that as 404 too.
            try:
                content, _ = await self.file_storage.retrieve(file_id)
            except FileNotFoundError:
                raise HTTPException(status_code=404, detail="File not found")

            return Response(
                content=content,
                media_type=metadata.mime_type,
                headers={
                    # The filename comes from the uploader, so it must be escaped/encoded.
                    "Content-Disposition": self._content_disposition(metadata.filename or file_id),
                },
            )

        @router.delete("/{file_id}", response_model=FileDeleteResponse)
        async def delete_file(
            file_id: str,
            user: BaseUser = Depends(get_current_user),
        ):
            metadata = await self.file_storage.get_metadata(file_id)
            if metadata is None:
                raise HTTPException(status_code=404, detail="File not found")

            if self.authorizer and not await self.authorizer.can_delete_file(user, metadata):
                raise HTTPException(status_code=403, detail="Not authorized to delete this file")

            success = await self.file_storage.delete(file_id)
            return FileDeleteResponse(success=success)

        return router
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
import uuid
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Awaitable, List, Optional
|
|
6
|
+
|
|
7
|
+
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
|
|
10
|
+
from spaik_sdk.server.api.streaming.streaming_negotiator import StreamingNegotiator
|
|
11
|
+
from spaik_sdk.server.authorization.base_authorizer import BaseAuthorizer
|
|
12
|
+
from spaik_sdk.server.authorization.base_user import BaseUser
|
|
13
|
+
from spaik_sdk.server.authorization.dummy_authorizer import DummyAuthorizer
|
|
14
|
+
from spaik_sdk.server.job_processor.thread_job_processor import ThreadJobProcessor
|
|
15
|
+
from spaik_sdk.server.pubsub.cancellation_publisher import CancellationPublisher
|
|
16
|
+
from spaik_sdk.server.pubsub.cancellation_subscriber import CancellationSubscriber
|
|
17
|
+
from spaik_sdk.server.queue.agent_job_queue import AgentJob, AgentJobQueue, JobType
|
|
18
|
+
from spaik_sdk.server.services.thread_converters import ThreadConverters
|
|
19
|
+
from spaik_sdk.server.services.thread_models import (
|
|
20
|
+
CreateMessageRequest,
|
|
21
|
+
CreateThreadRequest,
|
|
22
|
+
ListThreadsResponse,
|
|
23
|
+
MessageResponse,
|
|
24
|
+
ThreadResponse,
|
|
25
|
+
)
|
|
26
|
+
from spaik_sdk.server.services.thread_service import ThreadService
|
|
27
|
+
from spaik_sdk.server.storage.thread_filter import ThreadFilter
|
|
28
|
+
from spaik_sdk.thread.models import MessageAddedEvent, MessageBlock, MessageBlockType, ThreadMessage
|
|
29
|
+
from spaik_sdk.thread.thread_container import ThreadContainer
|
|
30
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
31
|
+
|
|
32
|
+
# Module-level logger used by the thread endpoints below, created through the SDK's init_logger helper.
logger = init_logger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ThreadRouterFactory:
    """Build the FastAPI router for thread and message endpoints.

    Optional collaborators enable optional endpoints. If one is missing, the
    endpoint that needs it answers 501:
    - ``streaming_negotiator`` for negotiate-streaming;
    - ``thread_job_processor`` for messages/stream;
    - ``job_queue`` for messages;
    - ``cancellation_publisher`` for cancel.

    When no ``authorizer`` is given, a ``DummyAuthorizer`` is used.
    """

    def __init__(
        self,
        service: ThreadService,
        authorizer: Optional[BaseAuthorizer[BaseUser]] = None,
        streaming_negotiator: Optional[StreamingNegotiator] = None,
        job_queue: Optional[AgentJobQueue] = None,
        thread_job_processor: Optional[ThreadJobProcessor] = None,
        cancellation_subscriber_provider: Optional[Callable[[str], Awaitable[CancellationSubscriber]]] = None,
        cancellation_publisher: Optional[CancellationPublisher] = None,
    ):
        self.authorizer = authorizer or DummyAuthorizer()
        self.service = service
        self.streaming_negotiator = streaming_negotiator
        self.job_queue = job_queue
        self.thread_job_processor = thread_job_processor
        self.cancellation_subscriber_provider = cancellation_subscriber_provider
        self.cancellation_publisher = cancellation_publisher

    def create_router(self, prefix: str = "/threads") -> APIRouter:
        """Create the router exposing the thread endpoints under ``prefix``."""
        # Needed for the negotiate-streaming response model; imported locally so this
        # block does not depend on the module's import list.
        from spaik_sdk.server.api.streaming.negotiate_streaming_response import NegotiateStreamingResponse

        router = APIRouter(prefix=prefix, tags=["threads"])

        async def get_current_user(request: Request) -> BaseUser:
            user = await self.authorizer.get_user(request)
            if not user:
                raise HTTPException(status_code=401, detail="Unauthorized")
            return user

        @router.post("", response_model=ThreadResponse)
        async def create_thread(request: CreateThreadRequest, user: BaseUser = Depends(get_current_user)):
            """Create a new thread"""
            if not await self.authorizer.can_create_thread(user):
                raise HTTPException(status_code=403, detail="Access denied")
            thread_container = ThreadContainer()  # TODO giving no args (eg system prompt) might cause issues
            created_thread = await self.service.create_thread(thread_container)
            return ThreadConverters.thread_model_to_response(created_thread)

        @router.get("/{thread_id}", response_model=ThreadResponse)
        async def get_thread(thread_id: str, user: BaseUser = Depends(get_current_user)):
            """Get thread by ID"""

            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            if not await self.authorizer.can_read_thread(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            return ThreadConverters.thread_model_to_response(thread)

        @router.delete("/{thread_id}")
        async def delete_thread(thread_id: str, user: BaseUser = Depends(get_current_user)):
            """Delete thread"""
            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            if not await self.authorizer.can_delete_thread(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            await self.service.delete_thread(thread_id)
            return {"message": "Thread deleted successfully"}

        @router.get("", response_model=ListThreadsResponse)
        async def list_threads(
            thread_type: Optional[str] = Query(None),
            title_contains: Optional[str] = Query(None),
            min_messages: Optional[int] = Query(None),
            max_messages: Optional[int] = Query(None),
            hours_ago: Optional[int] = Query(None),
            user: BaseUser = Depends(get_current_user),
        ):
            """List threads with filtering"""
            filter_builder = ThreadFilter.builder()

            # Users only ever see their own threads.
            filter_builder.with_author_id(user.get_id())
            if thread_type:
                filter_builder.with_type(thread_type)
            if title_contains:
                filter_builder.with_title_containing(title_contains)
            # Compare with None so an explicit 0 is honoured (e.g. max_messages=0
            # selects empty threads); a truthiness test silently dropped it.
            if min_messages is not None:
                filter_builder.with_min_messages(min_messages)
            if max_messages is not None:
                filter_builder.with_max_messages(max_messages)
            # hours_ago=0, like an absent value, applies no time filter.
            if hours_ago:
                # Activity timestamps are compared in epoch milliseconds.
                cutoff = int((time.time() - hours_ago * 3600) * 1000)
                filter_builder.with_activity_after(cutoff)

            thread_filter = filter_builder.build()
            threads = await self.service.list_threads(thread_filter)

            return ListThreadsResponse(threads=ThreadConverters.metadata_list_to_response(threads), total_count=len(threads))

        @router.get("/{thread_id}/messages", response_model=List[MessageResponse])
        async def get_thread_messages(thread_id: str, user: BaseUser = Depends(get_current_user)):
            """Get all messages for a thread"""
            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            if not await self.authorizer.can_read_thread(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            messages = await self.service.get_thread_messages(thread_id)
            return ThreadConverters.messages_model_to_response(messages)

        @router.get("/{thread_id}/messages/{message_id}", response_model=MessageResponse)
        async def get_message(thread_id: str, message_id: str, user: BaseUser = Depends(get_current_user)):
            """Get specific message by ID"""
            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            if not await self.authorizer.can_read_thread(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            message = await self.service.get_message(thread_id, message_id)
            if not message:
                raise HTTPException(status_code=404, detail="Message not found")
            return ThreadConverters.message_model_to_response(message)

        @router.delete("/{thread_id}/messages/{message_id}")
        async def delete_message(thread_id: str, message_id: str, user: BaseUser = Depends(get_current_user)):
            """Delete message from thread"""
            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            # NOTE(review): deletion is gated by can_post_message, not can_edit_message,
            # so anyone allowed to post may delete any message (including AI ones).
            if not await self.authorizer.can_post_message(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            success = await self.service.delete_message(thread_id, message_id)
            if not success:
                raise HTTPException(status_code=404, detail="Message not found")
            return {"message": "Message deleted successfully"}

        # The negotiator returns a NegotiateStreamingResponse; declaring ThreadResponse
        # here made FastAPI validate the result against the wrong schema.
        @router.get("/{thread_id}/negotiate-streaming", response_model=NegotiateStreamingResponse)
        async def negotiate_streaming(thread_id: str, user: BaseUser = Depends(get_current_user)):
            """Negotiate streaming"""
            if not self.streaming_negotiator:
                raise HTTPException(status_code=501, detail="Streaming negotiation not supported")

            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")
            if not await self.authorizer.can_read_thread(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            return self.streaming_negotiator.negotiate_thread_streaming(thread_id, user)

        async def _create_message(thread_id: str, request: CreateMessageRequest, user: BaseUser) -> ThreadMessage:
            """Create a new message"""
            # Shared by the queued and streaming endpoints: authorize, append and persist.

            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")

            if not await self.authorizer.can_post_message(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            attachments = None
            if request.attachments:
                attachments = [ThreadConverters.attachment_request_to_model(att) for att in request.attachments]

            message = ThreadMessage(
                id=str(uuid.uuid4()),
                author_id=user.get_id(),
                author_name=user.get_name(),
                timestamp=int(time.time() * 1000),
                ai=False,
                blocks=[MessageBlock(content=request.content, type=MessageBlockType.PLAIN, id=str(uuid.uuid4()), streaming=False)],
                attachments=attachments,
            )
            thread.add_message(message)
            await self.service.update_thread(thread)
            return message

        @router.post("/{thread_id}/messages/stream")
        async def create_message_stream(thread_id: str, request: CreateMessageRequest, user: BaseUser = Depends(get_current_user)):
            """Create a new message and stream the response immediately"""

            if not self.thread_job_processor:
                raise HTTPException(status_code=501, detail="Thread job processor not supported")
            message = await _create_message(thread_id, request, user)

            job = AgentJob(job_type=JobType.THREAD_MESSAGE, id=thread_id)
            cancellation_subscriber = (
                await self.cancellation_subscriber_provider(thread_id) if self.cancellation_subscriber_provider else None
            )
            cancellation_handle = cancellation_subscriber.get_cancellation_handle() if cancellation_subscriber else None

            def on_complete():
                # Passed to process_job so the cancellation subscription is stopped when processing completes.
                if cancellation_subscriber:
                    cancellation_subscriber.stop()

            logger.debug(f"Starting processing for thread {thread_id}")

            async def generate_stream():
                try:
                    logger.debug(f"Starting streaming stream for thread {thread_id}")
                    # Echo the user's message first so the client can render it immediately.
                    yield MessageAddedEvent(message=message).dump_json(thread_id) + "\n\n"

                    async for event_response in self.thread_job_processor.process_job(
                        job=job, cancellation_handle=cancellation_handle, on_complete=on_complete
                    ):
                        logger.debug(f"Received event response: {event_response}")
                        yield json.dumps(event_response) + "\n\n"

                except Exception as e:
                    logger.error(f"Error in SSE stream: {e}")
                    # json.dumps escapes quotes, backslashes and newlines in the message;
                    # building the JSON by hand produced invalid JSON for such errors.
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"

            logger.info(f"StreamingResponse for job {thread_id}")
            return StreamingResponse(
                generate_stream(),
                media_type="text/plain; charset=utf-8",  # devtools are not happy with the proper mime type
                headers={
                    "Cache-Control": "no-cache",
                    # "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "Cache-Control",
                    "X-Accel-Buffering": "no",  # Disable nginx buffering
                },
            )

        @router.post("/{thread_id}/messages", response_model=MessageResponse)
        async def create_message(thread_id: str, request: CreateMessageRequest, user: BaseUser = Depends(get_current_user)):
            """Create a new message"""
            if not self.job_queue:
                raise HTTPException(status_code=501, detail="Job queue not supported")

            message = await _create_message(thread_id, request, user)

            # The agent reply is produced asynchronously by whatever consumes the queue.
            job = AgentJob(job_type=JobType.THREAD_MESSAGE, id=thread_id)
            await self.job_queue.push(job)
            return ThreadConverters.message_model_to_response(message)

        @router.post("/{thread_id}/cancel")
        async def cancel_generation(thread_id: str, user: BaseUser = Depends(get_current_user)):
            """Cancel the current generation for a thread"""
            thread = await self.service.get_thread(thread_id)
            if not thread:
                raise HTTPException(status_code=404, detail="Thread not found")

            if not await self.authorizer.can_post_message(user, thread):
                raise HTTPException(status_code=403, detail="Access denied")

            if self.cancellation_publisher:
                self.cancellation_publisher.publish_cancellation(thread_id)
                logger.info(f"Cancellation published for thread {thread_id}")
                return {"success": True, "message": "Cancellation signal sent"}
            else:
                raise HTTPException(status_code=501, detail="Cancellation not supported")

        return router
|
|
File without changes
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
7
|
+
|
|
8
|
+
# Module-level logger; format_sse_event uses it to warn when a payload cannot be JSON-serialized.
logger = init_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EventStreamResponse(BaseModel):
    """One streamed event: an optional event name plus its payload."""

    # Optional event name; None when the event carries no explicit type.
    event: Optional[str] = None
    # Either raw text or a dict keyed by strings.
    payload: Union[str, Dict[str, Any]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def format_sse_event(event_response: Dict[str, Any]) -> str:
    """Serialize an event dict into a Server-Sent Events (SSE) message.

    Args:
        event_response: Mapping with two optional keys:
            - ``"event_type"``: emitted as the SSE ``event:`` field when truthy.
            - ``"data"``: the payload. A string payload is sent verbatim; any other
              value is JSON-encoded, falling back to ``repr`` if that fails.

    Returns:
        The SSE-formatted message terminated by a blank line. A payload spanning
        several lines becomes one ``data:`` line per line, as the SSE format
        requires, so embedded newlines cannot end the event early.
    """
    sse_lines: list[str] = []

    # Add event type if present
    event_type = event_response.get("event_type")
    if event_type:
        sse_lines.append(f"event: {event_type}")

    payload_obj = event_response.get("data")

    if isinstance(payload_obj, str):
        data_content = payload_obj
    else:
        try:
            data_content = json.dumps(payload_obj)
        except Exception as err:
            # Best effort: still emit something rather than dropping the event.
            logger.warning(f"Failed to JSON serialize payload ({err}), using repr")
            data_content = repr(payload_obj)

    # SSE treats CRLF, CR and LF as line terminators. A raw newline inside one
    # "data:" field would end that field, so emit a separate "data:" line per line.
    # The client rejoins consecutive data lines with "\n".
    normalized = data_content.replace("\r\n", "\n").replace("\r", "\n")
    sse_lines.extend(f"data: {line}" for line in normalized.split("\n"))

    # SSE format requires double newline at end
    return "\n".join(sse_lines) + "\n\n"
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from spaik_sdk.server.api.streaming.negotiate_streaming_response import NegotiateStreamingResponse
|
|
4
|
+
from spaik_sdk.server.authorization.base_user import BaseUser
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StreamingNegotiator(ABC):
    """Abstract interface that produces a ``NegotiateStreamingResponse`` for a thread and user."""

    @abstractmethod
    def negotiate_thread_streaming(self, thread_id: str, user: BaseUser) -> NegotiateStreamingResponse:
        """Return the streaming negotiation result for ``thread_id`` on behalf of ``user``.

        Concrete subclasses must override this method.
        """
|
|
File without changes
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
|
|
3
|
+
|
|
4
|
+
from fastapi import Request
|
|
5
|
+
|
|
6
|
+
from spaik_sdk.server.authorization.base_user import BaseUser
|
|
7
|
+
from spaik_sdk.thread.models import ThreadMessage
|
|
8
|
+
from spaik_sdk.thread.thread_container import ThreadContainer
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from spaik_sdk.attachments.models import FileMetadata
|
|
12
|
+
|
|
13
|
+
# Type variable for the concrete user type an authorizer works with; restricted to BaseUser subclasses.
TUser = TypeVar("TUser", bound=BaseUser)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BaseAuthorizer(ABC, Generic[TUser]):
    """Abstract base for authorization.

    Subclasses must implement :meth:`get_user`. Each permission hook below has a
    default that may be overridden individually:
    - thread checks use :meth:`is_thread_owner` (author of the thread's first message);
    - message-edit and file-delete checks compare the author/owner id with the user id;
    - file reads are also allowed for files owned by ``"system"``.
    """

    @abstractmethod
    async def get_user(self, request: Request) -> Optional[TUser]:
        """Get user from request"""
        pass

    async def can_create_thread(
        self,
        user: TUser,
    ) -> bool:
        """Check if user has permission to create a thread (allowed by default)."""
        return True

    async def can_read_thread(self, user: TUser, thread_container: ThreadContainer) -> bool:
        """Check if user has permission to read the thread"""
        return self.is_thread_owner(user, thread_container)

    async def can_post_message(self, user: TUser, thread_container: ThreadContainer) -> bool:
        """Check if user has permission to post a message to the thread"""
        return self.is_thread_owner(user, thread_container)

    async def can_edit_message(self, user: TUser, thread_container: ThreadContainer, message: ThreadMessage) -> bool:
        """Check if user has permission to edit a message in the thread"""
        return message.author_id == user.get_id()

    async def can_delete_thread(self, user: TUser, thread_container: ThreadContainer) -> bool:
        """Check if user has permission to delete the thread"""
        return self.is_thread_owner(user, thread_container)

    def is_thread_owner(self, user: TUser, thread_container: ThreadContainer) -> bool:
        """Check if user is the owner of the thread.

        The owner is the author of the thread's first message. A thread with no
        messages has no derivable owner, so this returns ``False`` instead of
        raising ``IndexError``.

        NOTE(review): under these defaults a thread created empty cannot receive
        its first message. Subclasses that record ownership elsewhere should
        override this method.
        """
        messages = thread_container.messages
        if not messages:
            return False
        return messages[0].author_id == user.get_id()

    async def can_upload_file(self, user: TUser) -> bool:
        """Check if user has permission to upload files"""
        return True

    async def can_read_file(self, user: TUser, file_metadata: "FileMetadata") -> bool:
        """Check if user has permission to read a file.

        By default, users can read files they own, or files owned by 'system' (agent-generated).
        """
        return file_metadata.owner_id == user.get_id() or file_metadata.owner_id == "system"

    async def can_delete_file(self, user: TUser, file_metadata: "FileMetadata") -> bool:
        """Check if user has permission to delete a file"""
        return file_metadata.owner_id == user.get_id()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from fastapi import Request
|
|
4
|
+
|
|
5
|
+
from spaik_sdk.server.authorization.base_authorizer import BaseAuthorizer
|
|
6
|
+
from spaik_sdk.server.authorization.base_user import BaseUser
|
|
7
|
+
from spaik_sdk.thread.thread_container import ThreadContainer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DummyAuthorizer(BaseAuthorizer[BaseUser]):
    """Placeholder authorizer that performs no real authentication.

    What this class changes:

    * every request resolves to the same user, ``BaseUser("user")``, whatever
      the request contains;
    * :meth:`is_thread_owner` always returns ``True``, so the inherited thread
      checks that go through it (read, post, delete) succeed for any thread.

    What it does not override: message editing and file read/delete are
    inherited from :class:`BaseAuthorizer` and still compare the user's id
    against the message author / file owner.
    """

    async def get_user(self, request: Request) -> Optional[BaseUser]:
        """Ignore ``request`` and return the fixed placeholder user."""
        return BaseUser("user")

    def is_thread_owner(self, user: BaseUser, thread_container: ThreadContainer) -> bool:
        """Treat every user as the owner of every thread."""
        return True
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any, AsyncGenerator, Dict, Optional
|
|
3
|
+
|
|
4
|
+
from fastapi import HTTPException
|
|
5
|
+
|
|
6
|
+
from spaik_sdk.llm.cancellation_handle import CancellationHandle
|
|
7
|
+
from spaik_sdk.server.job_processor.base_job_processor import BaseJobProcessor
|
|
8
|
+
from spaik_sdk.server.queue.agent_job_queue import AgentJob
|
|
9
|
+
from spaik_sdk.server.response.response_generator import ResponseGenerator
|
|
10
|
+
from spaik_sdk.server.services.thread_service import ThreadService
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ThreadJobProcessor(BaseJobProcessor):
    """Job processor that streams a generated response into a stored thread.

    For each job it does three things:

    1. loads the thread whose id equals ``job.id``;
    2. streams the response generator's chunks to the caller;
    3. saves the updated thread back through the thread service.
    """

    def __init__(self, thread_service: ThreadService, response_generator: ResponseGenerator):
        super().__init__()
        self.thread_service = thread_service
        self.response_generator = response_generator

    async def process_job(
        self,
        job: AgentJob,
        cancellation_handle: Optional[CancellationHandle],
        on_complete: Callable[[], None] = lambda: None,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream response chunks for the thread identified by ``job.id``.

        Args:
            job: The job to process; its ``id`` is used as the thread id.
            cancellation_handle: Optional handle forwarded unchanged to
                ``ResponseGenerator.stream_response``.
            on_complete: Callback invoked once processing ends. It runs
                whether the stream finished normally, raised, or was closed
                early by the consumer.

        Yields:
            Chunks produced by ``ResponseGenerator.stream_response``.

        Raises:
            HTTPException: 404 if no thread exists for ``job.id``.
        """
        try:
            thread_container = await self.thread_service.get_thread(job.id)
            if thread_container is None:
                raise HTTPException(status_code=404, detail="Thread not found")
            async for chunk in self.response_generator.stream_response(thread_container, cancellation_handle):
                yield chunk
            # Persist only after the full response has been streamed into the container.
            await self.thread_service.update_thread(thread_container)
        finally:
            # Run the completion hook on every exit path, including errors and an
            # early aclose() by the consumer (GeneratorExit raised at ``yield``).
            # Otherwise any cleanup the caller attached to it would never run.
            on_complete()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from spaik_sdk.llm.cancellation_handle import CancellationHandle
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CancellationSubscriberHandler(CancellationHandle):
    """Cancellation handle backed by a boolean flag that :meth:`cancel` sets."""

    def __init__(self) -> None:
        # Starts un-cancelled; nothing in this class ever clears the flag once set.
        self.cancelled: bool = False

    def cancel(self) -> None:
        """Mark this handle as cancelled."""
        self.cancelled = True

    async def is_cancelled(self) -> bool:
        """Report whether :meth:`cancel` has been called on this handle."""
        return self.cancelled
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CancellationSubscriber(ABC):
    """Links a flag-based cancellation handle to an external cancellation source.

    Subclasses implement the subscribe/unsubscribe hooks. Calling
    :meth:`on_cancellation` trips the handle returned by
    :meth:`get_cancellation_handle`. Presumably subclasses invoke it when their
    subscription delivers a cancel signal; confirm in the implementations.
    """

    def __init__(self, id: str):
        # ``id`` shadows the builtin but is part of the public signature.
        self.id = id
        handle = CancellationSubscriberHandler()
        self.cancellation_handle = handle
        # The subscription is made from the constructor. Any subclass state the
        # hook relies on must therefore exist before ``super().__init__`` runs.
        self._subscribe_to_cancellation()

    @abstractmethod
    def _subscribe_to_cancellation(self) -> None:
        """Start listening for cancellation of ``self.id``."""

    @abstractmethod
    def _unsubscribe_from_cancellation(self) -> None:
        """Stop listening for cancellation of ``self.id``."""

    def get_cancellation_handle(self) -> CancellationHandle:
        """Return the handle that consumers poll for cancellation."""
        return self.cancellation_handle

    def on_cancellation(self) -> None:
        """Flag the owned handle as cancelled."""
        self.cancellation_handle.cancel()

    def stop(self) -> None:
        """Tear down the subscription via the subclass hook."""
        self._unsubscribe_from_cancellation()
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
5
|
+
|
|
6
|
+
# Module-level logger. EventPublisher itself does not reference it in this module.
logger = init_logger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EventPublisher(ABC):
    """Abstract sink for events represented as string-keyed dictionaries."""

    @abstractmethod
    def publish_event(self, event: Dict[str, Any]) -> None:
        """Publish ``event``; concrete subclasses define how it is delivered."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|