simile 0.3.13__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of simile might be problematic. Click here for more details.
- {simile-0.3.13 → simile-0.4.2}/PKG-INFO +1 -1
- {simile-0.3.13 → simile-0.4.2}/pyproject.toml +1 -1
- {simile-0.3.13 → simile-0.4.2}/simile/__init__.py +14 -0
- {simile-0.3.13 → simile-0.4.2}/simile/client.py +274 -19
- simile-0.4.2/simile/models.py +400 -0
- {simile-0.3.13 → simile-0.4.2}/simile/resources.py +5 -0
- {simile-0.3.13 → simile-0.4.2}/simile.egg-info/PKG-INFO +1 -1
- simile-0.3.13/simile/models.py +0 -231
- {simile-0.3.13 → simile-0.4.2}/LICENSE +0 -0
- {simile-0.3.13 → simile-0.4.2}/README.md +0 -0
- {simile-0.3.13 → simile-0.4.2}/setup.cfg +0 -0
- {simile-0.3.13 → simile-0.4.2}/setup.py +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile/auth_client.py +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile/exceptions.py +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile.egg-info/SOURCES.txt +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile.egg-info/dependency_links.txt +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile.egg-info/requires.txt +0 -0
- {simile-0.3.13 → simile-0.4.2}/simile.egg-info/top_level.txt +0 -0
|
@@ -14,6 +14,13 @@ from .models import (
|
|
|
14
14
|
OpenGenerationResponse,
|
|
15
15
|
ClosedGenerationRequest,
|
|
16
16
|
ClosedGenerationResponse,
|
|
17
|
+
MemoryStream,
|
|
18
|
+
MemoryTurn,
|
|
19
|
+
MemoryTurnType,
|
|
20
|
+
ContextMemoryTurn,
|
|
21
|
+
ImageMemoryTurn,
|
|
22
|
+
OpenQuestionMemoryTurn,
|
|
23
|
+
ClosedQuestionMemoryTurn,
|
|
17
24
|
)
|
|
18
25
|
from .exceptions import (
|
|
19
26
|
SimileAPIError,
|
|
@@ -38,6 +45,13 @@ __all__ = [
|
|
|
38
45
|
"OpenGenerationResponse",
|
|
39
46
|
"ClosedGenerationRequest",
|
|
40
47
|
"ClosedGenerationResponse",
|
|
48
|
+
"MemoryStream",
|
|
49
|
+
"MemoryTurn",
|
|
50
|
+
"MemoryTurnType",
|
|
51
|
+
"ContextMemoryTurn",
|
|
52
|
+
"ImageMemoryTurn",
|
|
53
|
+
"OpenQuestionMemoryTurn",
|
|
54
|
+
"ClosedQuestionMemoryTurn",
|
|
41
55
|
"SimileAPIError",
|
|
42
56
|
"SimileAuthenticationError",
|
|
43
57
|
"SimileNotFoundError",
|
|
@@ -21,6 +21,7 @@ from .models import (
|
|
|
21
21
|
InitialDataItemPayload,
|
|
22
22
|
SurveySessionCreateResponse,
|
|
23
23
|
SurveySessionDetailResponse,
|
|
24
|
+
MemoryStream,
|
|
24
25
|
)
|
|
25
26
|
from .resources import Agent, SurveySession
|
|
26
27
|
from .exceptions import (
|
|
@@ -345,6 +346,7 @@ class Simile:
|
|
|
345
346
|
exclude_data_types: Optional[List[str]] = None,
|
|
346
347
|
images: Optional[Dict[str, str]] = None,
|
|
347
348
|
reasoning: bool = False,
|
|
349
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
348
350
|
) -> AsyncGenerator[str, None]:
|
|
349
351
|
"""Streams an open response from an agent."""
|
|
350
352
|
endpoint = f"/generation/open-stream/{str(agent_id)}"
|
|
@@ -410,22 +412,58 @@ class Simile:
|
|
|
410
412
|
exclude_data_types: Optional[List[str]] = None,
|
|
411
413
|
images: Optional[Dict[str, str]] = None,
|
|
412
414
|
reasoning: bool = False,
|
|
415
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
416
|
+
use_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to load memory from
|
|
417
|
+
exclude_memory_ids: Optional[List[str]] = None, # Study/question IDs to exclude
|
|
418
|
+
save_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to save memory to
|
|
413
419
|
) -> OpenGenerationResponse:
|
|
414
|
-
"""Generates an open response from an agent based on a question.
|
|
420
|
+
"""Generates an open response from an agent based on a question.
|
|
421
|
+
|
|
422
|
+
Args:
|
|
423
|
+
agent_id: The agent to query
|
|
424
|
+
question: The question to ask
|
|
425
|
+
data_types: Optional data types to include
|
|
426
|
+
exclude_data_types: Optional data types to exclude
|
|
427
|
+
images: Optional images dict
|
|
428
|
+
reasoning: Whether to include reasoning
|
|
429
|
+
memory_stream: Explicit memory stream to use (overrides use_memory)
|
|
430
|
+
use_memory: Session ID to automatically load memory from
|
|
431
|
+
exclude_memory_ids: Study/question IDs to exclude from loaded memory
|
|
432
|
+
save_memory: Session ID to automatically save response to memory
|
|
433
|
+
"""
|
|
415
434
|
endpoint = f"/generation/open/{str(agent_id)}"
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
435
|
+
# Build request payload directly as dict to avoid serialization issues
|
|
436
|
+
request_payload = {
|
|
437
|
+
"question": question,
|
|
438
|
+
"data_types": data_types,
|
|
439
|
+
"exclude_data_types": exclude_data_types,
|
|
440
|
+
"images": images,
|
|
441
|
+
"reasoning": reasoning,
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
# Pass memory parameters to API for server-side handling
|
|
445
|
+
if use_memory:
|
|
446
|
+
request_payload["use_memory"] = str(use_memory)
|
|
447
|
+
if exclude_memory_ids:
|
|
448
|
+
request_payload["exclude_memory_ids"] = exclude_memory_ids
|
|
449
|
+
|
|
450
|
+
if save_memory:
|
|
451
|
+
request_payload["save_memory"] = str(save_memory)
|
|
452
|
+
|
|
453
|
+
# Only include explicit memory_stream if provided directly
|
|
454
|
+
if memory_stream:
|
|
455
|
+
request_payload["memory_stream"] = memory_stream.to_dict()
|
|
456
|
+
|
|
423
457
|
response_data = await self._request(
|
|
424
458
|
"POST",
|
|
425
459
|
endpoint,
|
|
426
|
-
json=request_payload
|
|
460
|
+
json=request_payload,
|
|
427
461
|
response_model=OpenGenerationResponse,
|
|
428
462
|
)
|
|
463
|
+
|
|
464
|
+
# Don't save memory here - API should handle it when save_memory is passed
|
|
465
|
+
# Memory saving is now handled server-side for better performance
|
|
466
|
+
|
|
429
467
|
return response_data
|
|
430
468
|
|
|
431
469
|
async def generate_closed_response(
|
|
@@ -437,25 +475,242 @@ class Simile:
|
|
|
437
475
|
exclude_data_types: Optional[List[str]] = None,
|
|
438
476
|
images: Optional[Dict[str, str]] = None,
|
|
439
477
|
reasoning: bool = False,
|
|
478
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
479
|
+
use_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to load memory from
|
|
480
|
+
exclude_memory_ids: Optional[List[str]] = None, # Study/question IDs to exclude
|
|
481
|
+
save_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to save memory to
|
|
440
482
|
) -> ClosedGenerationResponse:
|
|
441
|
-
"""Generates a closed response from an agent.
|
|
483
|
+
"""Generates a closed response from an agent.
|
|
484
|
+
|
|
485
|
+
Args:
|
|
486
|
+
agent_id: The agent to query
|
|
487
|
+
question: The question to ask
|
|
488
|
+
options: The options to choose from
|
|
489
|
+
data_types: Optional data types to include
|
|
490
|
+
exclude_data_types: Optional data types to exclude
|
|
491
|
+
images: Optional images dict
|
|
492
|
+
reasoning: Whether to include reasoning
|
|
493
|
+
memory_stream: Explicit memory stream to use (overrides use_memory)
|
|
494
|
+
use_memory: Session ID to automatically load memory from
|
|
495
|
+
exclude_memory_ids: Study/question IDs to exclude from loaded memory
|
|
496
|
+
save_memory: Session ID to automatically save response to memory
|
|
497
|
+
"""
|
|
442
498
|
endpoint = f"generation/closed/{str(agent_id)}"
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
499
|
+
# Build request payload directly as dict to avoid serialization issues
|
|
500
|
+
request_payload = {
|
|
501
|
+
"question": question,
|
|
502
|
+
"options": options,
|
|
503
|
+
"data_types": data_types,
|
|
504
|
+
"exclude_data_types": exclude_data_types,
|
|
505
|
+
"images": images,
|
|
506
|
+
"reasoning": reasoning,
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
# Pass memory parameters to API for server-side handling
|
|
510
|
+
if use_memory:
|
|
511
|
+
request_payload["use_memory"] = str(use_memory)
|
|
512
|
+
if exclude_memory_ids:
|
|
513
|
+
request_payload["exclude_memory_ids"] = exclude_memory_ids
|
|
514
|
+
|
|
515
|
+
if save_memory:
|
|
516
|
+
request_payload["save_memory"] = str(save_memory)
|
|
517
|
+
|
|
518
|
+
# Only include explicit memory_stream if provided directly
|
|
519
|
+
if memory_stream:
|
|
520
|
+
request_payload["memory_stream"] = memory_stream.to_dict()
|
|
521
|
+
|
|
451
522
|
response_data = await self._request(
|
|
452
523
|
"POST",
|
|
453
524
|
endpoint,
|
|
454
|
-
json=request_payload
|
|
525
|
+
json=request_payload,
|
|
455
526
|
response_model=ClosedGenerationResponse,
|
|
456
527
|
)
|
|
528
|
+
|
|
529
|
+
# Don't save memory here - API should handle it when save_memory is passed
|
|
530
|
+
# Memory saving is now handled server-side for better performance
|
|
531
|
+
|
|
457
532
|
return response_data
|
|
458
533
|
|
|
534
|
+
# Memory Management Methods
|
|
535
|
+
|
|
536
|
+
async def save_memory(
|
|
537
|
+
self,
|
|
538
|
+
agent_id: Union[str, uuid.UUID],
|
|
539
|
+
response: str,
|
|
540
|
+
session_id: Optional[Union[str, uuid.UUID]] = None,
|
|
541
|
+
question_id: Optional[Union[str, uuid.UUID]] = None,
|
|
542
|
+
study_id: Optional[Union[str, uuid.UUID]] = None,
|
|
543
|
+
memory_turn: Optional[Dict[str, Any]] = None,
|
|
544
|
+
memory_stream_used: Optional[Dict[str, Any]] = None,
|
|
545
|
+
reasoning: Optional[str] = None,
|
|
546
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
547
|
+
) -> str:
|
|
548
|
+
"""
|
|
549
|
+
Save a response with associated memory information.
|
|
550
|
+
|
|
551
|
+
Args:
|
|
552
|
+
agent_id: The agent ID
|
|
553
|
+
response: The agent's response text
|
|
554
|
+
session_id: Session ID for memory continuity
|
|
555
|
+
question_id: The question ID (optional)
|
|
556
|
+
study_id: The study ID (optional)
|
|
557
|
+
memory_turn: The memory turn to save
|
|
558
|
+
memory_stream_used: The memory stream that was used
|
|
559
|
+
reasoning: Optional reasoning
|
|
560
|
+
metadata: Additional metadata
|
|
561
|
+
|
|
562
|
+
Returns:
|
|
563
|
+
Response ID if saved successfully
|
|
564
|
+
"""
|
|
565
|
+
payload = {
|
|
566
|
+
"agent_id": str(agent_id),
|
|
567
|
+
"response": response,
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
if session_id:
|
|
571
|
+
payload["session_id"] = str(session_id)
|
|
572
|
+
if question_id:
|
|
573
|
+
payload["question_id"] = str(question_id)
|
|
574
|
+
if study_id:
|
|
575
|
+
payload["study_id"] = str(study_id)
|
|
576
|
+
if memory_turn:
|
|
577
|
+
payload["memory_turn"] = memory_turn
|
|
578
|
+
if memory_stream_used:
|
|
579
|
+
payload["memory_stream_used"] = memory_stream_used
|
|
580
|
+
if reasoning:
|
|
581
|
+
payload["reasoning"] = reasoning
|
|
582
|
+
if metadata:
|
|
583
|
+
payload["metadata"] = metadata
|
|
584
|
+
|
|
585
|
+
response = await self._request("POST", "memory/save", json=payload)
|
|
586
|
+
data = response.json()
|
|
587
|
+
if data.get("success"):
|
|
588
|
+
return data.get("response_id")
|
|
589
|
+
raise SimileAPIError("Failed to save memory")
|
|
590
|
+
|
|
591
|
+
async def get_memory(
|
|
592
|
+
self,
|
|
593
|
+
session_id: Union[str, uuid.UUID],
|
|
594
|
+
agent_id: Union[str, uuid.UUID],
|
|
595
|
+
exclude_study_ids: Optional[List[Union[str, uuid.UUID]]] = None,
|
|
596
|
+
exclude_question_ids: Optional[List[Union[str, uuid.UUID]]] = None,
|
|
597
|
+
limit: Optional[int] = None,
|
|
598
|
+
use_memory: bool = True,
|
|
599
|
+
) -> Optional[MemoryStream]:
|
|
600
|
+
"""
|
|
601
|
+
Retrieve the memory stream for an agent in a session.
|
|
602
|
+
|
|
603
|
+
Args:
|
|
604
|
+
session_id: Session ID to filter by
|
|
605
|
+
agent_id: The agent ID
|
|
606
|
+
exclude_study_ids: List of study IDs to exclude
|
|
607
|
+
exclude_question_ids: List of question IDs to exclude
|
|
608
|
+
limit: Maximum number of turns to include
|
|
609
|
+
use_memory: Whether to use memory at all
|
|
610
|
+
|
|
611
|
+
Returns:
|
|
612
|
+
MemoryStream object or None
|
|
613
|
+
"""
|
|
614
|
+
payload = {
|
|
615
|
+
"session_id": str(session_id),
|
|
616
|
+
"agent_id": str(agent_id),
|
|
617
|
+
"use_memory": use_memory,
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
if exclude_study_ids:
|
|
621
|
+
payload["exclude_study_ids"] = [str(id) for id in exclude_study_ids]
|
|
622
|
+
if exclude_question_ids:
|
|
623
|
+
payload["exclude_question_ids"] = [str(id) for id in exclude_question_ids]
|
|
624
|
+
if limit:
|
|
625
|
+
payload["limit"] = limit
|
|
626
|
+
|
|
627
|
+
response = await self._request("POST", "memory/get", json=payload)
|
|
628
|
+
data = response.json()
|
|
629
|
+
|
|
630
|
+
if data.get("success") and data.get("memory_stream"):
|
|
631
|
+
return MemoryStream.from_dict(data["memory_stream"])
|
|
632
|
+
return None
|
|
633
|
+
|
|
634
|
+
async def get_memory_summary(
|
|
635
|
+
self,
|
|
636
|
+
session_id: Union[str, uuid.UUID],
|
|
637
|
+
) -> Dict[str, Any]:
|
|
638
|
+
"""
|
|
639
|
+
Get a summary of memory usage for a session.
|
|
640
|
+
|
|
641
|
+
Args:
|
|
642
|
+
session_id: Session ID to analyze
|
|
643
|
+
|
|
644
|
+
Returns:
|
|
645
|
+
Dictionary with memory statistics
|
|
646
|
+
"""
|
|
647
|
+
response = await self._request("GET", f"memory/summary/{session_id}")
|
|
648
|
+
data = response.json()
|
|
649
|
+
if data.get("success"):
|
|
650
|
+
return data.get("summary", {})
|
|
651
|
+
return {}
|
|
652
|
+
|
|
653
|
+
async def clear_memory(
|
|
654
|
+
self,
|
|
655
|
+
session_id: Union[str, uuid.UUID],
|
|
656
|
+
agent_id: Optional[Union[str, uuid.UUID]] = None,
|
|
657
|
+
study_id: Optional[Union[str, uuid.UUID]] = None,
|
|
658
|
+
) -> bool:
|
|
659
|
+
"""
|
|
660
|
+
Clear memory for a session, optionally filtered by agent or study.
|
|
661
|
+
|
|
662
|
+
Args:
|
|
663
|
+
session_id: Session ID to clear memory for
|
|
664
|
+
agent_id: Optional agent ID to filter by
|
|
665
|
+
study_id: Optional study ID to filter by
|
|
666
|
+
|
|
667
|
+
Returns:
|
|
668
|
+
True if cleared successfully, False otherwise
|
|
669
|
+
"""
|
|
670
|
+
payload = {
|
|
671
|
+
"session_id": str(session_id),
|
|
672
|
+
}
|
|
673
|
+
|
|
674
|
+
if agent_id:
|
|
675
|
+
payload["agent_id"] = str(agent_id)
|
|
676
|
+
if study_id:
|
|
677
|
+
payload["study_id"] = str(study_id)
|
|
678
|
+
|
|
679
|
+
response = await self._request("POST", "memory/clear", json=payload)
|
|
680
|
+
data = response.json()
|
|
681
|
+
return data.get("success", False)
|
|
682
|
+
|
|
683
|
+
async def copy_memory(
|
|
684
|
+
self,
|
|
685
|
+
from_session_id: Union[str, uuid.UUID],
|
|
686
|
+
to_session_id: Union[str, uuid.UUID],
|
|
687
|
+
agent_id: Optional[Union[str, uuid.UUID]] = None,
|
|
688
|
+
) -> int:
|
|
689
|
+
"""
|
|
690
|
+
Copy memory from one session to another.
|
|
691
|
+
|
|
692
|
+
Args:
|
|
693
|
+
from_session_id: Source session ID
|
|
694
|
+
to_session_id: Destination session ID
|
|
695
|
+
agent_id: Optional agent ID to filter by
|
|
696
|
+
|
|
697
|
+
Returns:
|
|
698
|
+
Number of memory turns copied
|
|
699
|
+
"""
|
|
700
|
+
payload = {
|
|
701
|
+
"from_session_id": str(from_session_id),
|
|
702
|
+
"to_session_id": str(to_session_id),
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
if agent_id:
|
|
706
|
+
payload["agent_id"] = str(agent_id)
|
|
707
|
+
|
|
708
|
+
response = await self._request("POST", "memory/copy", json=payload)
|
|
709
|
+
data = response.json()
|
|
710
|
+
if data.get("success"):
|
|
711
|
+
return data.get("copied_turns", 0)
|
|
712
|
+
return 0
|
|
713
|
+
|
|
459
714
|
async def aclose(self):
|
|
460
715
|
await self._client.aclose()
|
|
461
716
|
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional, Union, Literal
|
|
2
|
+
from pydantic import BaseModel, Field, validator
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Population(BaseModel):
|
|
9
|
+
population_id: uuid.UUID
|
|
10
|
+
name: str
|
|
11
|
+
description: Optional[str] = None
|
|
12
|
+
created_at: datetime
|
|
13
|
+
updated_at: datetime
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PopulationInfo(BaseModel):
|
|
17
|
+
population_id: uuid.UUID
|
|
18
|
+
name: str
|
|
19
|
+
description: Optional[str] = None
|
|
20
|
+
agent_count: int
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DataItem(BaseModel):
|
|
24
|
+
id: uuid.UUID
|
|
25
|
+
agent_id: uuid.UUID
|
|
26
|
+
data_type: str
|
|
27
|
+
content: Any
|
|
28
|
+
created_at: datetime
|
|
29
|
+
updated_at: datetime
|
|
30
|
+
metadata: Optional[Dict[str, Any]] = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Agent(BaseModel):
|
|
34
|
+
agent_id: uuid.UUID
|
|
35
|
+
name: str
|
|
36
|
+
population_id: Optional[uuid.UUID] = None
|
|
37
|
+
created_at: datetime
|
|
38
|
+
updated_at: datetime
|
|
39
|
+
data_items: List[DataItem] = Field(default_factory=list)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CreatePopulationPayload(BaseModel):
|
|
43
|
+
name: str
|
|
44
|
+
description: Optional[str] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class InitialDataItemPayload(BaseModel):
|
|
48
|
+
data_type: str
|
|
49
|
+
content: Any
|
|
50
|
+
metadata: Optional[Dict[str, Any]] = None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class CreateAgentPayload(BaseModel):
|
|
54
|
+
name: str
|
|
55
|
+
population_id: Optional[uuid.UUID] = None
|
|
56
|
+
agent_data: Optional[List[InitialDataItemPayload]] = None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CreateDataItemPayload(BaseModel):
|
|
60
|
+
data_type: str
|
|
61
|
+
content: Any
|
|
62
|
+
metadata: Optional[Dict[str, Any]] = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class UpdateDataItemPayload(BaseModel):
|
|
66
|
+
content: Any
|
|
67
|
+
metadata: Optional[Dict[str, Any]] = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class DeletionResponse(BaseModel):
|
|
71
|
+
message: str
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# --- Generation Operation Models ---
|
|
75
|
+
class OpenGenerationRequest(BaseModel):
|
|
76
|
+
question: str
|
|
77
|
+
data_types: Optional[List[str]] = None
|
|
78
|
+
exclude_data_types: Optional[List[str]] = None
|
|
79
|
+
images: Optional[Dict[str, str]] = (
|
|
80
|
+
None # Dict of {description: url} for multiple images
|
|
81
|
+
)
|
|
82
|
+
reasoning: bool = False
|
|
83
|
+
memory_stream: Optional["MemoryStream"] = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class OpenGenerationResponse(BaseModel):
|
|
87
|
+
question: str
|
|
88
|
+
answer: str
|
|
89
|
+
reasoning: Optional[str] = ""
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ClosedGenerationRequest(BaseModel):
|
|
93
|
+
question: str
|
|
94
|
+
options: List[str]
|
|
95
|
+
data_types: Optional[List[str]] = None
|
|
96
|
+
exclude_data_types: Optional[List[str]] = None
|
|
97
|
+
images: Optional[Dict[str, str]] = None
|
|
98
|
+
reasoning: bool = False
|
|
99
|
+
memory_stream: Optional["MemoryStream"] = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class ClosedGenerationResponse(BaseModel):
|
|
103
|
+
question: str
|
|
104
|
+
options: List[str]
|
|
105
|
+
response: str
|
|
106
|
+
reasoning: Optional[str] = ""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AddContextRequest(BaseModel):
|
|
110
|
+
context: str
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class AddContextResponse(BaseModel):
|
|
114
|
+
message: str
|
|
115
|
+
session_id: uuid.UUID
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# --- Survey Session Models ---
|
|
119
|
+
class TurnType(str, Enum):
|
|
120
|
+
"""Enum for different types of conversation turns."""
|
|
121
|
+
|
|
122
|
+
CONTEXT = "context"
|
|
123
|
+
IMAGE = "image"
|
|
124
|
+
OPEN_QUESTION = "open_question"
|
|
125
|
+
CLOSED_QUESTION = "closed_question"
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class BaseTurn(BaseModel):
|
|
129
|
+
"""Base model for all conversation turns."""
|
|
130
|
+
|
|
131
|
+
timestamp: datetime = Field(default_factory=lambda: datetime.now())
|
|
132
|
+
type: TurnType
|
|
133
|
+
|
|
134
|
+
class Config:
|
|
135
|
+
use_enum_values = True
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class ContextTurn(BaseTurn):
|
|
139
|
+
"""A context turn that provides background information."""
|
|
140
|
+
|
|
141
|
+
type: Literal[TurnType.CONTEXT] = TurnType.CONTEXT
|
|
142
|
+
user_context: str
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class ImageTurn(BaseTurn):
|
|
146
|
+
"""A standalone image turn (e.g., for context or reference)."""
|
|
147
|
+
|
|
148
|
+
type: Literal[TurnType.IMAGE] = TurnType.IMAGE
|
|
149
|
+
images: Dict[str, str]
|
|
150
|
+
caption: Optional[str] = None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class OpenQuestionTurn(BaseTurn):
|
|
154
|
+
"""An open question-answer turn."""
|
|
155
|
+
|
|
156
|
+
type: Literal[TurnType.OPEN_QUESTION] = TurnType.OPEN_QUESTION
|
|
157
|
+
user_question: str
|
|
158
|
+
user_images: Optional[Dict[str, str]] = None
|
|
159
|
+
llm_response: Optional[str] = None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class ClosedQuestionTurn(BaseTurn):
|
|
163
|
+
"""A closed question-answer turn."""
|
|
164
|
+
|
|
165
|
+
type: Literal[TurnType.CLOSED_QUESTION] = TurnType.CLOSED_QUESTION
|
|
166
|
+
user_question: str
|
|
167
|
+
user_options: List[str]
|
|
168
|
+
user_images: Optional[Dict[str, str]] = None
|
|
169
|
+
llm_response: Optional[str] = None
|
|
170
|
+
|
|
171
|
+
@validator("user_options")
|
|
172
|
+
def validate_options(cls, v):
|
|
173
|
+
if not v:
|
|
174
|
+
raise ValueError("Closed questions must have at least one option")
|
|
175
|
+
if len(v) < 2:
|
|
176
|
+
raise ValueError("Closed questions should have at least two options")
|
|
177
|
+
return v
|
|
178
|
+
|
|
179
|
+
@validator("llm_response")
|
|
180
|
+
def validate_response(cls, v, values):
|
|
181
|
+
if (
|
|
182
|
+
v is not None
|
|
183
|
+
and "user_options" in values
|
|
184
|
+
and v not in values["user_options"]
|
|
185
|
+
):
|
|
186
|
+
raise ValueError(f"Response '{v}' must be one of the provided options")
|
|
187
|
+
return v
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# Union type for all possible turn types
|
|
191
|
+
SurveySessionTurn = Union[ContextTurn, ImageTurn, OpenQuestionTurn, ClosedQuestionTurn]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class SurveySessionCreateResponse(BaseModel):
|
|
195
|
+
id: uuid.UUID # Session ID
|
|
196
|
+
agent_id: uuid.UUID
|
|
197
|
+
created_at: datetime
|
|
198
|
+
status: str
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class SurveySessionDetailResponse(BaseModel):
|
|
202
|
+
"""Detailed survey session response with typed conversation turns."""
|
|
203
|
+
|
|
204
|
+
id: uuid.UUID
|
|
205
|
+
agent_id: uuid.UUID
|
|
206
|
+
created_at: datetime
|
|
207
|
+
updated_at: datetime
|
|
208
|
+
status: str
|
|
209
|
+
conversation_history: List[SurveySessionTurn] = Field(default_factory=list)
|
|
210
|
+
|
|
211
|
+
class Config:
|
|
212
|
+
json_encoders = {datetime: lambda v: v.isoformat()}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class SurveySessionListItemResponse(BaseModel):
|
|
216
|
+
"""Summary response for listing survey sessions."""
|
|
217
|
+
|
|
218
|
+
id: uuid.UUID
|
|
219
|
+
agent_id: uuid.UUID
|
|
220
|
+
created_at: datetime
|
|
221
|
+
updated_at: datetime
|
|
222
|
+
status: str
|
|
223
|
+
turn_count: int = Field(description="Number of turns in conversation history")
|
|
224
|
+
|
|
225
|
+
class Config:
|
|
226
|
+
json_encoders = {datetime: lambda v: v.isoformat()}
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class SurveySessionCloseResponse(BaseModel):
|
|
230
|
+
id: uuid.UUID # Session ID
|
|
231
|
+
status: str
|
|
232
|
+
updated_at: datetime
|
|
233
|
+
message: Optional[str] = None
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# --- Memory Stream Models (to replace Survey Sessions) ---
|
|
237
|
+
class MemoryTurnType(str, Enum):
|
|
238
|
+
"""Enum for different types of memory turns."""
|
|
239
|
+
|
|
240
|
+
CONTEXT = "context"
|
|
241
|
+
IMAGE = "image"
|
|
242
|
+
OPEN_QUESTION = "open_question"
|
|
243
|
+
CLOSED_QUESTION = "closed_question"
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class BaseMemoryTurn(BaseModel):
|
|
247
|
+
"""Base model for all memory turns."""
|
|
248
|
+
|
|
249
|
+
timestamp: datetime = Field(default_factory=lambda: datetime.now())
|
|
250
|
+
type: MemoryTurnType
|
|
251
|
+
|
|
252
|
+
class Config:
|
|
253
|
+
use_enum_values = True
|
|
254
|
+
|
|
255
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
256
|
+
"""Convert to dictionary for serialization."""
|
|
257
|
+
data = self.model_dump()
|
|
258
|
+
# Remove timestamp - let API handle it
|
|
259
|
+
data.pop("timestamp", None)
|
|
260
|
+
# Ensure enum is serialized as string
|
|
261
|
+
if "type" in data:
|
|
262
|
+
if hasattr(data["type"], "value"):
|
|
263
|
+
data["type"] = data["type"].value
|
|
264
|
+
return data
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class ContextMemoryTurn(BaseMemoryTurn):
|
|
268
|
+
"""A context turn that provides background information."""
|
|
269
|
+
|
|
270
|
+
type: MemoryTurnType = Field(default=MemoryTurnType.CONTEXT)
|
|
271
|
+
user_context: str
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class ImageMemoryTurn(BaseMemoryTurn):
|
|
275
|
+
"""A standalone image turn (e.g., for context or reference)."""
|
|
276
|
+
|
|
277
|
+
type: MemoryTurnType = Field(default=MemoryTurnType.IMAGE)
|
|
278
|
+
images: Dict[str, str]
|
|
279
|
+
caption: Optional[str] = None
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class OpenQuestionMemoryTurn(BaseMemoryTurn):
|
|
283
|
+
"""An open question-answer turn."""
|
|
284
|
+
|
|
285
|
+
type: MemoryTurnType = Field(default=MemoryTurnType.OPEN_QUESTION)
|
|
286
|
+
user_question: str
|
|
287
|
+
user_images: Optional[Dict[str, str]] = None
|
|
288
|
+
llm_response: Optional[str] = None
|
|
289
|
+
llm_reasoning: Optional[str] = None
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class ClosedQuestionMemoryTurn(BaseMemoryTurn):
|
|
293
|
+
"""A closed question-answer turn."""
|
|
294
|
+
|
|
295
|
+
type: MemoryTurnType = Field(default=MemoryTurnType.CLOSED_QUESTION)
|
|
296
|
+
user_question: str
|
|
297
|
+
user_options: List[str]
|
|
298
|
+
user_images: Optional[Dict[str, str]] = None
|
|
299
|
+
llm_response: Optional[str] = None
|
|
300
|
+
llm_reasoning: Optional[str] = None
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
# Discriminated union of all memory turn types
|
|
304
|
+
MemoryTurn = Union[
|
|
305
|
+
ContextMemoryTurn, ImageMemoryTurn, OpenQuestionMemoryTurn, ClosedQuestionMemoryTurn
|
|
306
|
+
]
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class MemoryStream(BaseModel):
|
|
310
|
+
"""
|
|
311
|
+
A flexible memory stream that can be passed to generation functions.
|
|
312
|
+
This replaces the session-based approach with a more flexible paradigm.
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
turns: List[MemoryTurn] = Field(default_factory=list)
|
|
316
|
+
|
|
317
|
+
def add_turn(self, turn: MemoryTurn) -> None:
|
|
318
|
+
"""Add a turn to the memory stream."""
|
|
319
|
+
self.turns.append(turn)
|
|
320
|
+
|
|
321
|
+
def remove_turn(self, index: int) -> Optional[MemoryTurn]:
|
|
322
|
+
"""Remove a turn at the specified index."""
|
|
323
|
+
if 0 <= index < len(self.turns):
|
|
324
|
+
return self.turns.pop(index)
|
|
325
|
+
return None
|
|
326
|
+
|
|
327
|
+
def get_turns_by_type(self, turn_type: MemoryTurnType) -> List[MemoryTurn]:
|
|
328
|
+
"""Get all turns of a specific type."""
|
|
329
|
+
return [turn for turn in self.turns if turn.type == turn_type]
|
|
330
|
+
|
|
331
|
+
def get_last_turn(self) -> Optional[MemoryTurn]:
|
|
332
|
+
"""Get the most recent turn."""
|
|
333
|
+
return self.turns[-1] if self.turns else None
|
|
334
|
+
|
|
335
|
+
def clear(self) -> None:
|
|
336
|
+
"""Clear all turns from the memory stream."""
|
|
337
|
+
self.turns = []
|
|
338
|
+
|
|
339
|
+
def __len__(self) -> int:
|
|
340
|
+
"""Return the number of turns in the memory stream."""
|
|
341
|
+
return len(self.turns)
|
|
342
|
+
|
|
343
|
+
def __bool__(self) -> bool:
|
|
344
|
+
"""Return True if the memory stream has any turns."""
|
|
345
|
+
return bool(self.turns)
|
|
346
|
+
|
|
347
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
348
|
+
"""Convert memory stream to a dictionary for serialization."""
|
|
349
|
+
return {
|
|
350
|
+
"turns": [turn.to_dict() for turn in self.turns]
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
@classmethod
|
|
354
|
+
def from_dict(cls, data: Dict[str, Any]) -> "MemoryStream":
|
|
355
|
+
"""Create a MemoryStream from a dictionary."""
|
|
356
|
+
memory = cls()
|
|
357
|
+
for turn_data in data.get("turns", []):
|
|
358
|
+
turn_type = turn_data.get("type")
|
|
359
|
+
if turn_type == MemoryTurnType.CONTEXT:
|
|
360
|
+
memory.add_turn(ContextMemoryTurn(**turn_data))
|
|
361
|
+
elif turn_type == MemoryTurnType.IMAGE:
|
|
362
|
+
memory.add_turn(ImageMemoryTurn(**turn_data))
|
|
363
|
+
elif turn_type == MemoryTurnType.OPEN_QUESTION:
|
|
364
|
+
memory.add_turn(OpenQuestionMemoryTurn(**turn_data))
|
|
365
|
+
elif turn_type == MemoryTurnType.CLOSED_QUESTION:
|
|
366
|
+
memory.add_turn(ClosedQuestionMemoryTurn(**turn_data))
|
|
367
|
+
return memory
|
|
368
|
+
|
|
369
|
+
def fork(self, up_to_index: Optional[int] = None) -> "MemoryStream":
|
|
370
|
+
"""Create a copy of this memory stream, optionally up to a specific index."""
|
|
371
|
+
new_memory = MemoryStream()
|
|
372
|
+
turns_to_copy = self.turns[:up_to_index] if up_to_index is not None else self.turns
|
|
373
|
+
for turn in turns_to_copy:
|
|
374
|
+
new_memory.add_turn(turn.model_copy())
|
|
375
|
+
return new_memory
|
|
376
|
+
|
|
377
|
+
def filter_by_type(self, turn_type: MemoryTurnType) -> "MemoryStream":
|
|
378
|
+
"""Create a new memory stream with only turns of a specific type."""
|
|
379
|
+
new_memory = MemoryStream()
|
|
380
|
+
for turn in self.get_turns_by_type(turn_type):
|
|
381
|
+
new_memory.add_turn(turn.model_copy())
|
|
382
|
+
return new_memory
|
|
383
|
+
|
|
384
|
+
def get_question_answer_pairs(self) -> List[tuple]:
|
|
385
|
+
"""Extract question-answer pairs from the memory."""
|
|
386
|
+
pairs = []
|
|
387
|
+
for turn in self.turns:
|
|
388
|
+
if isinstance(turn, (OpenQuestionMemoryTurn, ClosedQuestionMemoryTurn)):
|
|
389
|
+
if turn.llm_response:
|
|
390
|
+
pairs.append((turn.user_question, turn.llm_response))
|
|
391
|
+
return pairs
|
|
392
|
+
|
|
393
|
+
def truncate(self, max_turns: int) -> None:
|
|
394
|
+
"""Keep only the most recent N turns."""
|
|
395
|
+
if len(self.turns) > max_turns:
|
|
396
|
+
self.turns = self.turns[-max_turns:]
|
|
397
|
+
|
|
398
|
+
def insert_turn(self, index: int, turn: MemoryTurn) -> None:
|
|
399
|
+
"""Insert a turn at a specific position."""
|
|
400
|
+
self.turns.insert(index, turn)
|
|
@@ -11,6 +11,7 @@ from .models import (
|
|
|
11
11
|
AddContextResponse,
|
|
12
12
|
SurveySessionDetailResponse,
|
|
13
13
|
SurveySessionCreateResponse,
|
|
14
|
+
MemoryStream,
|
|
14
15
|
)
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -34,6 +35,7 @@ class Agent:
|
|
|
34
35
|
data_types: Optional[List[str]] = None,
|
|
35
36
|
exclude_data_types: Optional[List[str]] = None,
|
|
36
37
|
images: Optional[Dict[str, str]] = None,
|
|
38
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
37
39
|
) -> OpenGenerationResponse:
|
|
38
40
|
"""Generates an open response from this agent based on a question."""
|
|
39
41
|
return await self._client.generate_open_response(
|
|
@@ -42,6 +44,7 @@ class Agent:
|
|
|
42
44
|
data_types=data_types,
|
|
43
45
|
exclude_data_types=exclude_data_types,
|
|
44
46
|
images=images,
|
|
47
|
+
memory_stream=memory_stream,
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
async def generate_closed_response(
|
|
@@ -51,6 +54,7 @@ class Agent:
|
|
|
51
54
|
data_types: Optional[List[str]] = None,
|
|
52
55
|
exclude_data_types: Optional[List[str]] = None,
|
|
53
56
|
images: Optional[Dict[str, str]] = None,
|
|
57
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
54
58
|
) -> ClosedGenerationResponse:
|
|
55
59
|
"""Generates a closed response from this agent."""
|
|
56
60
|
return await self._client.generate_closed_response(
|
|
@@ -60,6 +64,7 @@ class Agent:
|
|
|
60
64
|
data_types=data_types,
|
|
61
65
|
exclude_data_types=exclude_data_types,
|
|
62
66
|
images=images,
|
|
67
|
+
memory_stream=memory_stream,
|
|
63
68
|
)
|
|
64
69
|
|
|
65
70
|
|
simile-0.3.13/simile/models.py
DELETED
|
@@ -1,231 +0,0 @@
|
|
|
1
|
-
from typing import List, Dict, Any, Optional, Union, Literal
|
|
2
|
-
from pydantic import BaseModel, Field, validator
|
|
3
|
-
from datetime import datetime
|
|
4
|
-
from enum import Enum
|
|
5
|
-
import uuid
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class Population(BaseModel):
    """A population record as returned by the Simile API."""

    population_id: uuid.UUID
    name: str
    description: Optional[str] = None
    created_at: datetime
    updated_at: datetime
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class PopulationInfo(BaseModel):
    """Summary view of a population, including how many agents it contains."""

    population_id: uuid.UUID
    name: str
    description: Optional[str] = None
    agent_count: int
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class DataItem(BaseModel):
    """A single piece of data attached to an agent."""

    id: uuid.UUID
    agent_id: uuid.UUID
    data_type: str
    content: Any  # arbitrary payload; presumably its schema depends on data_type — TODO confirm
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class Agent(BaseModel):
    """An agent record, optionally belonging to a population, with its data items."""

    agent_id: uuid.UUID
    name: str
    population_id: Optional[uuid.UUID] = None
    created_at: datetime
    updated_at: datetime
    data_items: List[DataItem] = Field(default_factory=list)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class CreatePopulationPayload(BaseModel):
    """Request body for creating a new population."""

    name: str
    description: Optional[str] = None
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class InitialDataItemPayload(BaseModel):
    """A data item supplied inline when an agent is first created."""

    data_type: str
    content: Any
    metadata: Optional[Dict[str, Any]] = None
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class CreateAgentPayload(BaseModel):
    """Request body for creating a new agent, optionally seeded with data."""

    name: str
    population_id: Optional[uuid.UUID] = None
    agent_data: Optional[List[InitialDataItemPayload]] = None
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class CreateDataItemPayload(BaseModel):
    """Request body for attaching a new data item to an existing agent."""

    data_type: str
    content: Any
    metadata: Optional[Dict[str, Any]] = None
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class UpdateDataItemPayload(BaseModel):
    """Request body for updating a data item's content and/or metadata."""

    content: Any
    metadata: Optional[Dict[str, Any]] = None
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
class DeletionResponse(BaseModel):
    """Generic acknowledgement returned by delete endpoints."""

    message: str
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
# --- Generation Operation Models ---
|
|
75
|
-
class OpenGenerationRequest(BaseModel):
    """Request body for generating a free-form (open) answer from an agent."""

    question: str
    data_types: Optional[List[str]] = None  # restrict generation to these data types
    exclude_data_types: Optional[List[str]] = None  # or exclude these data types
    images: Optional[Dict[str, str]] = (
        None  # Dict of {description: url} for multiple images
    )
    reasoning: bool = False  # request the model's reasoning alongside the answer
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
class OpenGenerationResponse(BaseModel):
    """Free-form answer returned for an open generation request."""

    question: str
    answer: str
    reasoning: Optional[str] = ""  # empty unless reasoning was requested
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
class ClosedGenerationRequest(BaseModel):
    """Request body for generating a multiple-choice (closed) answer from an agent."""

    question: str
    options: List[str]  # candidate answers the agent must choose from
    data_types: Optional[List[str]] = None
    exclude_data_types: Optional[List[str]] = None
    images: Optional[Dict[str, str]] = None  # {description: url} pairs
    reasoning: bool = False
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
class ClosedGenerationResponse(BaseModel):
    """Selected option returned for a closed generation request."""

    question: str
    options: List[str]
    response: str  # presumably one of `options` — not enforced here; TODO confirm
    reasoning: Optional[str] = ""
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
class AddContextRequest(BaseModel):
    """Request body for adding free-text context to a survey session."""

    context: str
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class AddContextResponse(BaseModel):
    """Acknowledgement for an add-context call, echoing the session id."""

    message: str
    session_id: uuid.UUID
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# --- Survey Session Models ---
|
|
117
|
-
class TurnType(str, Enum):
    """Enum for different types of conversation turns."""

    CONTEXT = "context"
    IMAGE = "image"
    OPEN_QUESTION = "open_question"
    CLOSED_QUESTION = "closed_question"
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class BaseTurn(BaseModel):
    """Base model for all conversation turns."""

    # NOTE(review): datetime.now() is timezone-naive — confirm whether
    # callers expect UTC-aware timestamps here.
    timestamp: datetime = Field(default_factory=lambda: datetime.now())
    type: TurnType

    class Config:
        # Serialize the TurnType enum as its string value (Pydantic v1 style).
        use_enum_values = True
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class ContextTurn(BaseTurn):
    """A context turn that provides background information."""

    type: Literal[TurnType.CONTEXT] = TurnType.CONTEXT
    user_context: str
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
class ImageTurn(BaseTurn):
    """A standalone image turn (e.g., for context or reference)."""

    type: Literal[TurnType.IMAGE] = TurnType.IMAGE
    images: Dict[str, str]  # presumably {description: url}, matching the generation models — TODO confirm
    caption: Optional[str] = None
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
class OpenQuestionTurn(BaseTurn):
    """An open question-answer turn."""

    type: Literal[TurnType.OPEN_QUESTION] = TurnType.OPEN_QUESTION
    user_question: str
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None  # None until the model has answered
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
class ClosedQuestionTurn(BaseTurn):
    """A closed question-answer turn."""

    type: Literal[TurnType.CLOSED_QUESTION] = TurnType.CLOSED_QUESTION
    user_question: str
    user_options: List[str]  # validated below to contain at least two entries
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None  # must match one of user_options once set

    # NOTE(review): `validator` is the Pydantic v1 API (v2: `field_validator`).
    @validator("user_options")
    def validate_options(cls, v):
        # The empty-list check is subsumed by len(v) < 2 but produces a
        # distinct, more specific error message.
        if not v:
            raise ValueError("Closed questions must have at least one option")
        if len(v) < 2:
            raise ValueError("Closed questions should have at least two options")
        return v

    @validator("llm_response")
    def validate_response(cls, v, values):
        # `values` holds previously-validated fields; user_options is absent
        # if its own validation failed, in which case we skip this check.
        if (
            v is not None
            and "user_options" in values
            and v not in values["user_options"]
        ):
            raise ValueError(f"Response '{v}' must be one of the provided options")
        return v
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
# Union type for all possible turn types
# NOTE(review): with Pydantic v1, plain Unions are coerced left-to-right rather
# than discriminated on the `type` field — verify parsing picks the intended variant.
SurveySessionTurn = Union[ContextTurn, ImageTurn, OpenQuestionTurn, ClosedQuestionTurn]
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
class SurveySessionCreateResponse(BaseModel):
    """Response returned when a new survey session is created."""

    id: uuid.UUID  # Session ID
    agent_id: uuid.UUID
    created_at: datetime
    status: str
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
class SurveySessionDetailResponse(BaseModel):
    """Detailed survey session response with typed conversation turns."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    conversation_history: List[SurveySessionTurn] = Field(default_factory=list)

    class Config:
        # Emit datetimes as ISO-8601 strings in JSON (Pydantic v1 style).
        json_encoders = {datetime: lambda v: v.isoformat()}
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
class SurveySessionListItemResponse(BaseModel):
    """Summary response for listing survey sessions."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    turn_count: int = Field(description="Number of turns in conversation history")

    class Config:
        # Emit datetimes as ISO-8601 strings in JSON (Pydantic v1 style).
        json_encoders = {datetime: lambda v: v.isoformat()}
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
class SurveySessionCloseResponse(BaseModel):
    """Response returned when a survey session is closed."""

    id: uuid.UUID  # Session ID
    status: str
    updated_at: datetime
    message: Optional[str] = None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|