simile 0.3.12__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of simile might be problematic. Click here for more details.
- {simile-0.3.12 → simile-0.4.1}/PKG-INFO +1 -1
- {simile-0.3.12 → simile-0.4.1}/pyproject.toml +1 -1
- {simile-0.3.12 → simile-0.4.1}/simile/__init__.py +14 -0
- {simile-0.3.12 → simile-0.4.1}/simile/client.py +321 -19
- simile-0.4.1/simile/models.py +400 -0
- {simile-0.3.12 → simile-0.4.1}/simile/resources.py +5 -0
- {simile-0.3.12 → simile-0.4.1}/simile.egg-info/PKG-INFO +1 -1
- simile-0.3.12/simile/models.py +0 -231
- {simile-0.3.12 → simile-0.4.1}/LICENSE +0 -0
- {simile-0.3.12 → simile-0.4.1}/README.md +0 -0
- {simile-0.3.12 → simile-0.4.1}/setup.cfg +0 -0
- {simile-0.3.12 → simile-0.4.1}/setup.py +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile/auth_client.py +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile/exceptions.py +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile.egg-info/SOURCES.txt +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile.egg-info/dependency_links.txt +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile.egg-info/requires.txt +0 -0
- {simile-0.3.12 → simile-0.4.1}/simile.egg-info/top_level.txt +0 -0
|
@@ -14,6 +14,13 @@ from .models import (
|
|
|
14
14
|
OpenGenerationResponse,
|
|
15
15
|
ClosedGenerationRequest,
|
|
16
16
|
ClosedGenerationResponse,
|
|
17
|
+
MemoryStream,
|
|
18
|
+
MemoryTurn,
|
|
19
|
+
MemoryTurnType,
|
|
20
|
+
ContextMemoryTurn,
|
|
21
|
+
ImageMemoryTurn,
|
|
22
|
+
OpenQuestionMemoryTurn,
|
|
23
|
+
ClosedQuestionMemoryTurn,
|
|
17
24
|
)
|
|
18
25
|
from .exceptions import (
|
|
19
26
|
SimileAPIError,
|
|
@@ -38,6 +45,13 @@ __all__ = [
|
|
|
38
45
|
"OpenGenerationResponse",
|
|
39
46
|
"ClosedGenerationRequest",
|
|
40
47
|
"ClosedGenerationResponse",
|
|
48
|
+
"MemoryStream",
|
|
49
|
+
"MemoryTurn",
|
|
50
|
+
"MemoryTurnType",
|
|
51
|
+
"ContextMemoryTurn",
|
|
52
|
+
"ImageMemoryTurn",
|
|
53
|
+
"OpenQuestionMemoryTurn",
|
|
54
|
+
"ClosedQuestionMemoryTurn",
|
|
41
55
|
"SimileAPIError",
|
|
42
56
|
"SimileAuthenticationError",
|
|
43
57
|
"SimileNotFoundError",
|
|
@@ -21,6 +21,7 @@ from .models import (
|
|
|
21
21
|
InitialDataItemPayload,
|
|
22
22
|
SurveySessionCreateResponse,
|
|
23
23
|
SurveySessionDetailResponse,
|
|
24
|
+
MemoryStream,
|
|
24
25
|
)
|
|
25
26
|
from .resources import Agent, SurveySession
|
|
26
27
|
from .exceptions import (
|
|
@@ -254,6 +255,18 @@ class Simile:
|
|
|
254
255
|
"DELETE", f"agents/{str(agent_id)}/populations/{str(population_id)}"
|
|
255
256
|
)
|
|
256
257
|
return raw_response.json()
|
|
258
|
+
|
|
259
|
+
async def batch_add_agents_to_population(
    self, agent_ids: List[Union[str, uuid.UUID]], population_id: Union[str, uuid.UUID]
) -> Dict[str, Any]:
    """Attach several agents to one population using a single request.

    Args:
        agent_ids: Agents to add; each is converted to its string form.
        population_id: The population that receives the agents.

    Returns:
        The decoded JSON body of the ``populations/{id}/agents/batch`` response.
    """
    endpoint = f"populations/{population_id!s}/agents/batch"
    # The endpoint receives a bare JSON array of agent-ID strings.
    body = list(map(str, agent_ids))
    reply = await self._request("POST", endpoint, json=body)
    return reply.json()
|
|
257
270
|
|
|
258
271
|
async def get_populations_for_agent(
|
|
259
272
|
self, agent_id: Union[str, uuid.UUID]
|
|
@@ -333,6 +346,7 @@ class Simile:
|
|
|
333
346
|
exclude_data_types: Optional[List[str]] = None,
|
|
334
347
|
images: Optional[Dict[str, str]] = None,
|
|
335
348
|
reasoning: bool = False,
|
|
349
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
336
350
|
) -> AsyncGenerator[str, None]:
|
|
337
351
|
"""Streams an open response from an agent."""
|
|
338
352
|
endpoint = f"/generation/open-stream/{str(agent_id)}"
|
|
@@ -398,22 +412,75 @@ class Simile:
|
|
|
398
412
|
exclude_data_types: Optional[List[str]] = None,
|
|
399
413
|
images: Optional[Dict[str, str]] = None,
|
|
400
414
|
reasoning: bool = False,
|
|
415
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
416
|
+
use_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to load memory from
|
|
417
|
+
exclude_memory_ids: Optional[List[str]] = None, # Study/question IDs to exclude
|
|
418
|
+
save_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to save memory to
|
|
401
419
|
) -> OpenGenerationResponse:
|
|
402
|
-
"""Generates an open response from an agent based on a question.
|
|
420
|
+
"""Generates an open response from an agent based on a question.
|
|
421
|
+
|
|
422
|
+
Args:
|
|
423
|
+
agent_id: The agent to query
|
|
424
|
+
question: The question to ask
|
|
425
|
+
data_types: Optional data types to include
|
|
426
|
+
exclude_data_types: Optional data types to exclude
|
|
427
|
+
images: Optional images dict
|
|
428
|
+
reasoning: Whether to include reasoning
|
|
429
|
+
memory_stream: Explicit memory stream to use (overrides use_memory)
|
|
430
|
+
use_memory: Session ID to automatically load memory from
|
|
431
|
+
exclude_memory_ids: Study/question IDs to exclude from loaded memory
|
|
432
|
+
save_memory: Session ID to automatically save response to memory
|
|
433
|
+
"""
|
|
434
|
+
# If use_memory is provided and no explicit memory_stream, load it
|
|
435
|
+
if use_memory and not memory_stream:
|
|
436
|
+
memory_stream = await self.get_memory(
|
|
437
|
+
session_id=use_memory,
|
|
438
|
+
agent_id=agent_id,
|
|
439
|
+
exclude_study_ids=exclude_memory_ids,
|
|
440
|
+
use_memory=True
|
|
441
|
+
)
|
|
442
|
+
|
|
403
443
|
endpoint = f"/generation/open/{str(agent_id)}"
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
444
|
+
# Build request payload directly as dict to avoid serialization issues
|
|
445
|
+
request_payload = {
|
|
446
|
+
"question": question,
|
|
447
|
+
"data_types": data_types,
|
|
448
|
+
"exclude_data_types": exclude_data_types,
|
|
449
|
+
"images": images,
|
|
450
|
+
"reasoning": reasoning,
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
if memory_stream:
|
|
454
|
+
request_payload["memory_stream"] = memory_stream.to_dict()
|
|
455
|
+
|
|
411
456
|
response_data = await self._request(
|
|
412
457
|
"POST",
|
|
413
458
|
endpoint,
|
|
414
|
-
json=request_payload
|
|
459
|
+
json=request_payload,
|
|
415
460
|
response_model=OpenGenerationResponse,
|
|
416
461
|
)
|
|
462
|
+
|
|
463
|
+
# If save_memory is provided, save the response
|
|
464
|
+
if save_memory and response_data:
|
|
465
|
+
from .models import OpenQuestionMemoryTurn
|
|
466
|
+
|
|
467
|
+
memory_turn = OpenQuestionMemoryTurn(
|
|
468
|
+
user_question=question,
|
|
469
|
+
user_images=images,
|
|
470
|
+
llm_response=response_data.answer,
|
|
471
|
+
llm_reasoning=response_data.reasoning if reasoning else None
|
|
472
|
+
)
|
|
473
|
+
|
|
474
|
+
await self.save_memory(
|
|
475
|
+
agent_id=agent_id,
|
|
476
|
+
response=response_data.answer,
|
|
477
|
+
session_id=save_memory,
|
|
478
|
+
memory_turn=memory_turn.to_dict(),
|
|
479
|
+
memory_stream_used=memory_stream.to_dict() if memory_stream else None,
|
|
480
|
+
reasoning=response_data.reasoning if reasoning else None,
|
|
481
|
+
metadata={"question_type": "open"}
|
|
482
|
+
)
|
|
483
|
+
|
|
417
484
|
return response_data
|
|
418
485
|
|
|
419
486
|
async def generate_closed_response(
|
|
@@ -425,25 +492,260 @@ class Simile:
|
|
|
425
492
|
exclude_data_types: Optional[List[str]] = None,
|
|
426
493
|
images: Optional[Dict[str, str]] = None,
|
|
427
494
|
reasoning: bool = False,
|
|
495
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
496
|
+
use_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to load memory from
|
|
497
|
+
exclude_memory_ids: Optional[List[str]] = None, # Study/question IDs to exclude
|
|
498
|
+
save_memory: Optional[Union[str, uuid.UUID]] = None, # Session ID to save memory to
|
|
428
499
|
) -> ClosedGenerationResponse:
|
|
429
|
-
"""Generates a closed response from an agent.
|
|
500
|
+
"""Generates a closed response from an agent.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
agent_id: The agent to query
|
|
504
|
+
question: The question to ask
|
|
505
|
+
options: The options to choose from
|
|
506
|
+
data_types: Optional data types to include
|
|
507
|
+
exclude_data_types: Optional data types to exclude
|
|
508
|
+
images: Optional images dict
|
|
509
|
+
reasoning: Whether to include reasoning
|
|
510
|
+
memory_stream: Explicit memory stream to use (overrides use_memory)
|
|
511
|
+
use_memory: Session ID to automatically load memory from
|
|
512
|
+
exclude_memory_ids: Study/question IDs to exclude from loaded memory
|
|
513
|
+
save_memory: Session ID to automatically save response to memory
|
|
514
|
+
"""
|
|
515
|
+
# If use_memory is provided and no explicit memory_stream, load it
|
|
516
|
+
if use_memory and not memory_stream:
|
|
517
|
+
memory_stream = await self.get_memory(
|
|
518
|
+
session_id=use_memory,
|
|
519
|
+
agent_id=agent_id,
|
|
520
|
+
exclude_study_ids=exclude_memory_ids,
|
|
521
|
+
use_memory=True
|
|
522
|
+
)
|
|
523
|
+
|
|
430
524
|
endpoint = f"generation/closed/{str(agent_id)}"
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
525
|
+
# Build request payload directly as dict to avoid serialization issues
|
|
526
|
+
request_payload = {
|
|
527
|
+
"question": question,
|
|
528
|
+
"options": options,
|
|
529
|
+
"data_types": data_types,
|
|
530
|
+
"exclude_data_types": exclude_data_types,
|
|
531
|
+
"images": images,
|
|
532
|
+
"reasoning": reasoning,
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
if memory_stream:
|
|
536
|
+
request_payload["memory_stream"] = memory_stream.to_dict()
|
|
537
|
+
|
|
439
538
|
response_data = await self._request(
|
|
440
539
|
"POST",
|
|
441
540
|
endpoint,
|
|
442
|
-
json=request_payload
|
|
541
|
+
json=request_payload,
|
|
443
542
|
response_model=ClosedGenerationResponse,
|
|
444
543
|
)
|
|
544
|
+
|
|
545
|
+
# If save_memory is provided, save the response
|
|
546
|
+
if save_memory and response_data:
|
|
547
|
+
from .models import ClosedQuestionMemoryTurn
|
|
548
|
+
|
|
549
|
+
memory_turn = ClosedQuestionMemoryTurn(
|
|
550
|
+
user_question=question,
|
|
551
|
+
user_options=options,
|
|
552
|
+
user_images=images,
|
|
553
|
+
llm_response=response_data.response,
|
|
554
|
+
llm_reasoning=response_data.reasoning if reasoning else None
|
|
555
|
+
)
|
|
556
|
+
|
|
557
|
+
await self.save_memory(
|
|
558
|
+
agent_id=agent_id,
|
|
559
|
+
response=response_data.response,
|
|
560
|
+
session_id=save_memory,
|
|
561
|
+
memory_turn=memory_turn.to_dict(),
|
|
562
|
+
memory_stream_used=memory_stream.to_dict() if memory_stream else None,
|
|
563
|
+
reasoning=response_data.reasoning if reasoning else None,
|
|
564
|
+
metadata={"question_type": "closed", "options": options}
|
|
565
|
+
)
|
|
566
|
+
|
|
445
567
|
return response_data
|
|
446
568
|
|
|
569
|
+
# Memory Management Methods
|
|
570
|
+
|
|
571
|
+
async def save_memory(
    self,
    agent_id: Union[str, uuid.UUID],
    response: str,
    session_id: Optional[Union[str, uuid.UUID]] = None,
    question_id: Optional[Union[str, uuid.UUID]] = None,
    study_id: Optional[Union[str, uuid.UUID]] = None,
    memory_turn: Optional[Dict[str, Any]] = None,
    memory_stream_used: Optional[Dict[str, Any]] = None,
    reasoning: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """
    Save a response with associated memory information.

    Args:
        agent_id: The agent ID
        response: The agent's response text
        session_id: Session ID for memory continuity
        question_id: The question ID (optional)
        study_id: The study ID (optional)
        memory_turn: The memory turn to save
        memory_stream_used: The memory stream that was used
        reasoning: Optional reasoning
        metadata: Additional metadata

    Returns:
        The ``response_id`` reported by the API when the save succeeds. This is
        ``None`` if the success response does not include that key.

    Raises:
        SimileAPIError: If the API response does not report success.
    """
    payload: Dict[str, Any] = {
        "agent_id": str(agent_id),
        "response": response,
    }

    # Optional identifiers are stringified so uuid.UUID values serialize as
    # JSON strings. Falsy values are omitted from the payload entirely.
    for key, value in (
        ("session_id", session_id),
        ("question_id", question_id),
        ("study_id", study_id),
    ):
        if value:
            payload[key] = str(value)

    # Structured/extra fields are forwarded unchanged when non-empty.
    for key, value in (
        ("memory_turn", memory_turn),
        ("memory_stream_used", memory_stream_used),
        ("reasoning", reasoning),
        ("metadata", metadata),
    ):
        if value:
            payload[key] = value

    # Use a distinct name so the ``response`` argument (the agent's text) is
    # not shadowed by the HTTP response object.
    http_response = await self._request("POST", "memory/save", json=payload)
    data = http_response.json()
    if data.get("success"):
        return data.get("response_id")
    raise SimileAPIError("Failed to save memory")
|
|
625
|
+
|
|
626
|
+
async def get_memory(
    self,
    session_id: Union[str, uuid.UUID],
    agent_id: Union[str, uuid.UUID],
    exclude_study_ids: Optional[List[Union[str, uuid.UUID]]] = None,
    exclude_question_ids: Optional[List[Union[str, uuid.UUID]]] = None,
    limit: Optional[int] = None,
    use_memory: bool = True,
) -> Optional[MemoryStream]:
    """
    Retrieve the memory stream for an agent in a session.

    Args:
        session_id: Session ID to filter by
        agent_id: The agent ID
        exclude_study_ids: List of study IDs to exclude
        exclude_question_ids: List of question IDs to exclude
        limit: Maximum number of turns to include. ``None`` sends no limit;
            any integer, including 0, is forwarded to the API.
        use_memory: Whether to use memory at all

    Returns:
        A MemoryStream built from the API's ``memory_stream`` field. Returns
        None if the call did not report success or the stream was absent/empty.
    """
    payload: Dict[str, Any] = {
        "session_id": str(session_id),
        "agent_id": str(agent_id),
        "use_memory": use_memory,
    }

    if exclude_study_ids:
        payload["exclude_study_ids"] = [str(sid) for sid in exclude_study_ids]
    if exclude_question_ids:
        payload["exclude_question_ids"] = [str(qid) for qid in exclude_question_ids]
    # Compare against None so an explicit ``limit=0`` is sent rather than
    # silently dropped. Dropping it would presumably leave the server-side
    # result unbounded.
    if limit is not None:
        payload["limit"] = limit

    http_response = await self._request("POST", "memory/get", json=payload)
    data = http_response.json()

    if data.get("success") and data.get("memory_stream"):
        return MemoryStream.from_dict(data["memory_stream"])
    return None
|
|
668
|
+
|
|
669
|
+
async def get_memory_summary(
    self,
    session_id: Union[str, uuid.UUID],
) -> Dict[str, Any]:
    """Fetch memory-usage statistics for a single session.

    Args:
        session_id: The session to summarize.

    Returns:
        The API's ``summary`` mapping when the call reports success. Returns
        an empty dict if that key is missing or if the call did not succeed.
    """
    reply = await self._request("GET", f"memory/summary/{session_id}")
    body = reply.json()
    if not body.get("success"):
        return {}
    return body.get("summary", {})
|
|
687
|
+
|
|
688
|
+
async def clear_memory(
    self,
    session_id: Union[str, uuid.UUID],
    agent_id: Optional[Union[str, uuid.UUID]] = None,
    study_id: Optional[Union[str, uuid.UUID]] = None,
) -> bool:
    """Delete stored memory for a session, optionally narrowed by agent/study.

    Args:
        session_id: The session whose memory is cleared.
        agent_id: If given (and truthy), restrict clearing to this agent.
        study_id: If given (and truthy), restrict clearing to this study.

    Returns:
        The API's ``success`` flag, defaulting to False when it is absent.
    """
    optional_filters = {"agent_id": agent_id, "study_id": study_id}
    payload = {"session_id": str(session_id)}
    # Only truthy filters are sent, each stringified (handles uuid.UUID).
    payload.update(
        {name: str(value) for name, value in optional_filters.items() if value}
    )
    reply = await self._request("POST", "memory/clear", json=payload)
    return reply.json().get("success", False)
|
|
717
|
+
|
|
718
|
+
async def copy_memory(
    self,
    from_session_id: Union[str, uuid.UUID],
    to_session_id: Union[str, uuid.UUID],
    agent_id: Optional[Union[str, uuid.UUID]] = None,
) -> int:
    """Duplicate memory turns from one session into another.

    Args:
        from_session_id: Session to copy from.
        to_session_id: Session to copy into.
        agent_id: If given (and truthy), copy only this agent's memory.

    Returns:
        The API's ``copied_turns`` count on success (0 if that key is absent).
        Returns 0 when the call does not report success.
    """
    payload = {
        "from_session_id": str(from_session_id),
        "to_session_id": str(to_session_id),
    }
    if agent_id:
        payload["agent_id"] = str(agent_id)

    reply = await self._request("POST", "memory/copy", json=payload)
    body = reply.json()
    return body.get("copied_turns", 0) if body.get("success") else 0
|
|
748
|
+
|
|
447
749
|
async def aclose(self):
    """Close the underlying ``self._client``; await this when finished with the SDK client."""
    await self._client.aclose()
|
|
449
751
|
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional, Union, Literal
|
|
2
|
+
from pydantic import BaseModel, Field, validator
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Population(BaseModel):
    """A named population with an optional description and timestamps."""

    population_id: uuid.UUID
    name: str
    description: Union[str, None] = Field(default=None)
    created_at: datetime
    updated_at: datetime


class PopulationInfo(BaseModel):
    """Population summary that also reports how many agents it contains."""

    population_id: uuid.UUID
    name: str
    description: Union[str, None] = Field(default=None)
    agent_count: int


class DataItem(BaseModel):
    """A single typed piece of content attached to an agent."""

    id: uuid.UUID
    agent_id: uuid.UUID
    data_type: str
    content: Any
    created_at: datetime
    updated_at: datetime
    metadata: Union[Dict[str, Any], None] = Field(default=None)


class Agent(BaseModel):
    """An agent, its optional population, and the data items attached to it."""

    agent_id: uuid.UUID
    name: str
    population_id: Union[uuid.UUID, None] = Field(default=None)
    created_at: datetime
    updated_at: datetime
    data_items: List[DataItem] = Field(default_factory=list)


class CreatePopulationPayload(BaseModel):
    """Payload for creating a population."""

    name: str
    description: Union[str, None] = Field(default=None)


class InitialDataItemPayload(BaseModel):
    """A data item supplied inline when creating an agent."""

    data_type: str
    content: Any
    metadata: Union[Dict[str, Any], None] = Field(default=None)


class CreateAgentPayload(BaseModel):
    """Payload for creating an agent, optionally seeded with data items."""

    name: str
    population_id: Union[uuid.UUID, None] = Field(default=None)
    agent_data: Union[List[InitialDataItemPayload], None] = Field(default=None)


class CreateDataItemPayload(BaseModel):
    """Payload for adding a data item to an existing agent."""

    data_type: str
    content: Any
    metadata: Union[Dict[str, Any], None] = Field(default=None)


class UpdateDataItemPayload(BaseModel):
    """Payload for replacing a data item's content and metadata."""

    content: Any
    metadata: Union[Dict[str, Any], None] = Field(default=None)


class DeletionResponse(BaseModel):
    """Confirmation message returned by delete operations."""

    message: str
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# --- Generation Operation Models ---
class OpenGenerationRequest(BaseModel):
    """Inputs for an open-response generation."""

    question: str
    data_types: Union[List[str], None] = Field(default=None)
    exclude_data_types: Union[List[str], None] = Field(default=None)
    # Dict of {description: url} for multiple images.
    images: Union[Dict[str, str], None] = Field(default=None)
    reasoning: bool = Field(default=False)
    # String forward reference: MemoryStream is defined later in this module.
    memory_stream: Optional["MemoryStream"] = Field(default=None)


class OpenGenerationResponse(BaseModel):
    """Result of an open-response generation."""

    question: str
    answer: str
    # Note: defaults to the empty string, not None.
    reasoning: Optional[str] = Field(default="")


class ClosedGenerationRequest(BaseModel):
    """Inputs for a closed (choose-from-options) generation."""

    question: str
    options: List[str]
    data_types: Union[List[str], None] = Field(default=None)
    exclude_data_types: Union[List[str], None] = Field(default=None)
    images: Union[Dict[str, str], None] = Field(default=None)
    reasoning: bool = Field(default=False)
    # String forward reference: MemoryStream is defined later in this module.
    memory_stream: Optional["MemoryStream"] = Field(default=None)


class ClosedGenerationResponse(BaseModel):
    """Result of a closed generation: the options and the chosen response."""

    question: str
    options: List[str]
    response: str
    # Note: defaults to the empty string, not None.
    reasoning: Optional[str] = Field(default="")


class AddContextRequest(BaseModel):
    """Context text to add."""

    context: str


class AddContextResponse(BaseModel):
    """Acknowledgement of added context, with the associated session ID."""

    message: str
    session_id: uuid.UUID
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# --- Survey Session Models ---
class TurnType(str, Enum):
    """Kinds of conversation turn recorded in a survey session."""

    CONTEXT = "context"
    IMAGE = "image"
    OPEN_QUESTION = "open_question"
    CLOSED_QUESTION = "closed_question"


class BaseTurn(BaseModel):
    """Fields shared by every survey-session conversation turn."""

    # Evaluated per instance; produces a naive local-time timestamp.
    timestamp: datetime = Field(default_factory=lambda: datetime.now())
    type: TurnType

    class Config:
        use_enum_values = True


class ContextTurn(BaseTurn):
    """Turn carrying background information supplied by the user."""

    type: Literal[TurnType.CONTEXT] = TurnType.CONTEXT
    user_context: str


class ImageTurn(BaseTurn):
    """Standalone image turn (e.g. for context or reference)."""

    type: Literal[TurnType.IMAGE] = TurnType.IMAGE
    images: Dict[str, str]
    caption: Union[str, None] = None


class OpenQuestionTurn(BaseTurn):
    """Open question from the user and the model's free-text answer."""

    type: Literal[TurnType.OPEN_QUESTION] = TurnType.OPEN_QUESTION
    user_question: str
    user_images: Union[Dict[str, str], None] = None
    llm_response: Union[str, None] = None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class ClosedQuestionTurn(BaseTurn):
    """A closed question-answer turn.

    ``user_options`` must contain at least two choices. ``llm_response``, when
    present, must be one of those choices.
    """

    type: Literal[TurnType.CLOSED_QUESTION] = TurnType.CLOSED_QUESTION
    user_question: str
    user_options: List[str]
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None

    @validator("user_options")
    def validate_options(cls, v):
        # One length check also covers the empty list. A separate
        # "at least one option" branch would report a weaker requirement than
        # the one actually enforced.
        if len(v) < 2:
            raise ValueError("Closed questions must have at least two options")
        return v

    @validator("llm_response")
    def validate_response(cls, v, values):
        # Fields are validated in declaration order. ``values`` therefore holds
        # ``user_options`` only if it passed validation; otherwise the
        # membership check is skipped.
        if (
            v is not None
            and "user_options" in values
            and v not in values["user_options"]
        ):
            raise ValueError(f"Response '{v}' must be one of the provided options")
        return v


# Union type for all possible turn types
SurveySessionTurn = Union[ContextTurn, ImageTurn, OpenQuestionTurn, ClosedQuestionTurn]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class SurveySessionCreateResponse(BaseModel):
    """Returned after a survey session is created."""

    id: uuid.UUID  # the session's ID
    agent_id: uuid.UUID
    created_at: datetime
    status: str


class SurveySessionDetailResponse(BaseModel):
    """Full survey session, including its typed conversation history."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    conversation_history: List[SurveySessionTurn] = Field(default_factory=list)

    class Config:
        # Emit datetimes as ISO-8601 strings when encoding to JSON.
        json_encoders = {datetime: lambda v: v.isoformat()}


class SurveySessionListItemResponse(BaseModel):
    """Condensed survey-session entry used when listing sessions."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    turn_count: int = Field(description="Number of turns in conversation history")

    class Config:
        # Emit datetimes as ISO-8601 strings when encoding to JSON.
        json_encoders = {datetime: lambda v: v.isoformat()}


class SurveySessionCloseResponse(BaseModel):
    """Returned after a survey session is closed."""

    id: uuid.UUID  # the session's ID
    status: str
    updated_at: datetime
    message: Union[str, None] = Field(default=None)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# --- Memory Stream Models (to replace Survey Sessions) ---
class MemoryTurnType(str, Enum):
    """Kinds of turn that can appear in a memory stream."""

    CONTEXT = "context"
    IMAGE = "image"
    OPEN_QUESTION = "open_question"
    CLOSED_QUESTION = "closed_question"


class BaseMemoryTurn(BaseModel):
    """Fields and serialization shared by every memory turn."""

    # Evaluated per instance; produces a naive local-time timestamp.
    timestamp: datetime = Field(default_factory=lambda: datetime.now())
    type: MemoryTurnType

    class Config:
        use_enum_values = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the turn without ``timestamp`` and with ``type`` as a string.

        The timestamp is dropped so the API can assign its own. ``type`` can
        still be an enum member here: a default value is not run through
        validation, so ``use_enum_values`` does not convert it. It is
        therefore flattened to its string value explicitly.
        """
        serialized = self.model_dump(exclude={"timestamp"})
        kind = serialized.get("type")
        if kind is not None and hasattr(kind, "value"):
            serialized["type"] = kind.value
        return serialized
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class ContextMemoryTurn(BaseMemoryTurn):
    """A context turn that provides background information."""

    # Default only: ``type`` is not Literal-restricted, so another
    # MemoryTurnType value is accepted if passed explicitly.
    type: MemoryTurnType = Field(default=MemoryTurnType.CONTEXT)
    user_context: str


class ImageMemoryTurn(BaseMemoryTurn):
    """A standalone image turn (e.g., for context or reference)."""

    type: MemoryTurnType = Field(default=MemoryTurnType.IMAGE)
    # Presumably {description: url}, matching the generation requests' images.
    images: Dict[str, str]
    caption: Optional[str] = None


class OpenQuestionMemoryTurn(BaseMemoryTurn):
    """An open question-answer turn."""

    type: MemoryTurnType = Field(default=MemoryTurnType.OPEN_QUESTION)
    user_question: str
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None
    llm_reasoning: Optional[str] = None


class ClosedQuestionMemoryTurn(BaseMemoryTurn):
    """A closed question-answer turn.

    Unlike the survey-session ClosedQuestionTurn, this model does not
    validate ``user_options`` or check that ``llm_response`` is one of them.
    """

    type: MemoryTurnType = Field(default=MemoryTurnType.CLOSED_QUESTION)
    user_question: str
    user_options: List[str]
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None
    llm_reasoning: Optional[str] = None


# Plain (NOT discriminated) union of all memory turn types. The ``type``
# fields above are not Literal-typed, so pydantic cannot dispatch on them.
# Build turns from raw dicts via MemoryStream.from_dict, which selects the
# concrete class from ``type`` explicitly.
MemoryTurn = Union[
    ContextMemoryTurn, ImageMemoryTurn, OpenQuestionMemoryTurn, ClosedQuestionMemoryTurn
]
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class MemoryStream(BaseModel):
|
|
310
|
+
"""
|
|
311
|
+
A flexible memory stream that can be passed to generation functions.
|
|
312
|
+
This replaces the session-based approach with a more flexible paradigm.
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
turns: List[MemoryTurn] = Field(default_factory=list)
|
|
316
|
+
|
|
317
|
+
def add_turn(self, turn: MemoryTurn) -> None:
|
|
318
|
+
"""Add a turn to the memory stream."""
|
|
319
|
+
self.turns.append(turn)
|
|
320
|
+
|
|
321
|
+
def remove_turn(self, index: int) -> Optional[MemoryTurn]:
|
|
322
|
+
"""Remove a turn at the specified index."""
|
|
323
|
+
if 0 <= index < len(self.turns):
|
|
324
|
+
return self.turns.pop(index)
|
|
325
|
+
return None
|
|
326
|
+
|
|
327
|
+
def get_turns_by_type(self, turn_type: MemoryTurnType) -> List[MemoryTurn]:
|
|
328
|
+
"""Get all turns of a specific type."""
|
|
329
|
+
return [turn for turn in self.turns if turn.type == turn_type]
|
|
330
|
+
|
|
331
|
+
def get_last_turn(self) -> Optional[MemoryTurn]:
|
|
332
|
+
"""Get the most recent turn."""
|
|
333
|
+
return self.turns[-1] if self.turns else None
|
|
334
|
+
|
|
335
|
+
def clear(self) -> None:
|
|
336
|
+
"""Clear all turns from the memory stream."""
|
|
337
|
+
self.turns = []
|
|
338
|
+
|
|
339
|
+
def __len__(self) -> int:
|
|
340
|
+
"""Return the number of turns in the memory stream."""
|
|
341
|
+
return len(self.turns)
|
|
342
|
+
|
|
343
|
+
def __bool__(self) -> bool:
|
|
344
|
+
"""Return True if the memory stream has any turns."""
|
|
345
|
+
return bool(self.turns)
|
|
346
|
+
|
|
347
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
348
|
+
"""Convert memory stream to a dictionary for serialization."""
|
|
349
|
+
return {
|
|
350
|
+
"turns": [turn.to_dict() for turn in self.turns]
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
@classmethod
|
|
354
|
+
def from_dict(cls, data: Dict[str, Any]) -> "MemoryStream":
|
|
355
|
+
"""Create a MemoryStream from a dictionary."""
|
|
356
|
+
memory = cls()
|
|
357
|
+
for turn_data in data.get("turns", []):
|
|
358
|
+
turn_type = turn_data.get("type")
|
|
359
|
+
if turn_type == MemoryTurnType.CONTEXT:
|
|
360
|
+
memory.add_turn(ContextMemoryTurn(**turn_data))
|
|
361
|
+
elif turn_type == MemoryTurnType.IMAGE:
|
|
362
|
+
memory.add_turn(ImageMemoryTurn(**turn_data))
|
|
363
|
+
elif turn_type == MemoryTurnType.OPEN_QUESTION:
|
|
364
|
+
memory.add_turn(OpenQuestionMemoryTurn(**turn_data))
|
|
365
|
+
elif turn_type == MemoryTurnType.CLOSED_QUESTION:
|
|
366
|
+
memory.add_turn(ClosedQuestionMemoryTurn(**turn_data))
|
|
367
|
+
return memory
|
|
368
|
+
|
|
369
|
+
def fork(self, up_to_index: Optional[int] = None) -> "MemoryStream":
|
|
370
|
+
"""Create a copy of this memory stream, optionally up to a specific index."""
|
|
371
|
+
new_memory = MemoryStream()
|
|
372
|
+
turns_to_copy = self.turns[:up_to_index] if up_to_index is not None else self.turns
|
|
373
|
+
for turn in turns_to_copy:
|
|
374
|
+
new_memory.add_turn(turn.model_copy())
|
|
375
|
+
return new_memory
|
|
376
|
+
|
|
377
|
+
def filter_by_type(self, turn_type: MemoryTurnType) -> "MemoryStream":
|
|
378
|
+
"""Create a new memory stream with only turns of a specific type."""
|
|
379
|
+
new_memory = MemoryStream()
|
|
380
|
+
for turn in self.get_turns_by_type(turn_type):
|
|
381
|
+
new_memory.add_turn(turn.model_copy())
|
|
382
|
+
return new_memory
|
|
383
|
+
|
|
384
|
+
def get_question_answer_pairs(self) -> List[tuple]:
    """Collect ``(user_question, llm_response)`` tuples from answered questions.

    Only open and closed question turns are considered, and only those whose
    ``llm_response`` is truthy, so ``None`` and ``""`` are skipped. Pairs are
    returned in stream order.
    """
    question_kinds = (OpenQuestionMemoryTurn, ClosedQuestionMemoryTurn)
    return [
        (entry.user_question, entry.llm_response)
        for entry in self.turns
        if isinstance(entry, question_kinds) and entry.llm_response
    ]
|
|
392
|
+
|
|
393
|
+
def truncate(self, max_turns: int) -> None:
    """Keep only the most recent ``max_turns`` turns, replacing ``self.turns``.

    Args:
        max_turns: Number of trailing turns to retain; ``0`` removes every
            turn. If the stream already has ``max_turns`` or fewer turns it
            is left untouched.

    Raises:
        ValueError: If ``max_turns`` is negative.
    """
    if max_turns < 0:
        raise ValueError(f"max_turns must be non-negative, got {max_turns}")
    total = len(self.turns)
    if total <= max_turns:
        return
    # Slice from an explicit start index rather than ``[-max_turns:]``:
    # ``turns[-0:]`` is the whole list, which made truncate(0) a no-op.
    self.turns = self.turns[total - max_turns:]
|
|
397
|
+
|
|
398
|
+
def insert_turn(self, index: int, turn: MemoryTurn) -> None:
    """Place ``turn`` into the stream at position ``index``.

    Uses ``list.insert`` semantics: indices past the end append, and
    negative indices count from the end.
    """
    existing = self.turns
    existing.insert(index, turn)
|
|
@@ -11,6 +11,7 @@ from .models import (
|
|
|
11
11
|
AddContextResponse,
|
|
12
12
|
SurveySessionDetailResponse,
|
|
13
13
|
SurveySessionCreateResponse,
|
|
14
|
+
MemoryStream,
|
|
14
15
|
)
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
@@ -34,6 +35,7 @@ class Agent:
|
|
|
34
35
|
data_types: Optional[List[str]] = None,
|
|
35
36
|
exclude_data_types: Optional[List[str]] = None,
|
|
36
37
|
images: Optional[Dict[str, str]] = None,
|
|
38
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
37
39
|
) -> OpenGenerationResponse:
|
|
38
40
|
"""Generates an open response from this agent based on a question."""
|
|
39
41
|
return await self._client.generate_open_response(
|
|
@@ -42,6 +44,7 @@ class Agent:
|
|
|
42
44
|
data_types=data_types,
|
|
43
45
|
exclude_data_types=exclude_data_types,
|
|
44
46
|
images=images,
|
|
47
|
+
memory_stream=memory_stream,
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
async def generate_closed_response(
|
|
@@ -51,6 +54,7 @@ class Agent:
|
|
|
51
54
|
data_types: Optional[List[str]] = None,
|
|
52
55
|
exclude_data_types: Optional[List[str]] = None,
|
|
53
56
|
images: Optional[Dict[str, str]] = None,
|
|
57
|
+
memory_stream: Optional[MemoryStream] = None,
|
|
54
58
|
) -> ClosedGenerationResponse:
|
|
55
59
|
"""Generates a closed response from this agent."""
|
|
56
60
|
return await self._client.generate_closed_response(
|
|
@@ -60,6 +64,7 @@ class Agent:
|
|
|
60
64
|
data_types=data_types,
|
|
61
65
|
exclude_data_types=exclude_data_types,
|
|
62
66
|
images=images,
|
|
67
|
+
memory_stream=memory_stream,
|
|
63
68
|
)
|
|
64
69
|
|
|
65
70
|
|
simile-0.3.12/simile/models.py
DELETED
|
@@ -1,231 +0,0 @@
|
|
|
1
|
-
from typing import List, Dict, Any, Optional, Union, Literal
|
|
2
|
-
from pydantic import BaseModel, Field, validator
|
|
3
|
-
from datetime import datetime
|
|
4
|
-
from enum import Enum
|
|
5
|
-
import uuid
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
# Pydantic model for a population record: UUID, name, optional description,
# and creation/update timestamps.
class Population(BaseModel):
    population_id: uuid.UUID
    name: str
    description: Optional[str] = None  # optional; defaults to None
    created_at: datetime
    updated_at: datetime
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# Pydantic summary model for a population. Unlike Population, it carries an
# agent count instead of timestamps.
class PopulationInfo(BaseModel):
    population_id: uuid.UUID
    name: str
    description: Optional[str] = None
    agent_count: int  # number of agents; required
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# Pydantic model for one data item belonging to an agent (see Agent.data_items).
class DataItem(BaseModel):
    id: uuid.UUID
    agent_id: uuid.UUID  # owning agent's id
    data_type: str
    content: Any  # typed as Any: no payload structure is enforced here
    created_at: datetime
    updated_at: datetime
    metadata: Optional[Dict[str, Any]] = None  # optional free-form mapping
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
# Pydantic model for an agent and its attached data items.
class Agent(BaseModel):
    agent_id: uuid.UUID
    name: str
    population_id: Optional[uuid.UUID] = None  # optional; defaults to None
    created_at: datetime
    updated_at: datetime
    # default_factory gives each instance its own empty list (no shared default).
    data_items: List[DataItem] = Field(default_factory=list)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# Request payload for creating a population: a required name and an optional
# description.
class CreatePopulationPayload(BaseModel):
    name: str
    description: Optional[str] = None
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
# One data item supplied inline when creating an agent
# (used by CreateAgentPayload.agent_data).
class InitialDataItemPayload(BaseModel):
    data_type: str
    content: Any  # untyped payload
    metadata: Optional[Dict[str, Any]] = None
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
# Request payload for creating an agent, optionally inside a population and
# optionally seeded with initial data items.
class CreateAgentPayload(BaseModel):
    name: str
    population_id: Optional[uuid.UUID] = None
    agent_data: Optional[List[InitialDataItemPayload]] = None  # None = no initial data sent
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
# Request payload for adding a data item; same fields as InitialDataItemPayload.
class CreateDataItemPayload(BaseModel):
    data_type: str
    content: Any  # untyped payload
    metadata: Optional[Dict[str, Any]] = None
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
# Request payload for updating a data item. There is no data_type field, so
# only content and metadata can be sent.
class UpdateDataItemPayload(BaseModel):
    content: Any
    metadata: Optional[Dict[str, Any]] = None
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
# Response model for a delete operation: a single message string.
class DeletionResponse(BaseModel):
    message: str
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
# --- Generation Operation Models ---
|
|
75
|
-
# Request body for an open-ended generation.
class OpenGenerationRequest(BaseModel):
    question: str
    # Presumably include/exclude filters on the agent's data types — confirm
    # against the server API.
    data_types: Optional[List[str]] = None
    exclude_data_types: Optional[List[str]] = None
    images: Optional[Dict[str, str]] = (
        None  # Dict of {description: url} for multiple images
    )
    reasoning: bool = False  # opt-in flag; off by default
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
# Response for an open-ended generation: the echoed question and the answer.
class OpenGenerationResponse(BaseModel):
    question: str
    answer: str
    # Typed Optional but defaults to "" (not None) when the field is omitted.
    reasoning: Optional[str] = ""
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
# Request body for a closed (multiple-choice) generation over ``options``.
class ClosedGenerationRequest(BaseModel):
    question: str
    options: List[str]  # required; no length check is applied here
    data_types: Optional[List[str]] = None
    exclude_data_types: Optional[List[str]] = None
    images: Optional[Dict[str, str]] = None
    reasoning: bool = False
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
# Response for a closed generation: echoed question/options plus the chosen
# response. Membership of ``response`` in ``options`` is not validated here.
class ClosedGenerationResponse(BaseModel):
    question: str
    options: List[str]
    response: str
    # Typed Optional but defaults to "" (not None) when the field is omitted.
    reasoning: Optional[str] = ""
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
# Request body carrying a single context string.
class AddContextRequest(BaseModel):
    context: str
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
# Response to adding context: a message and the id of the affected session.
class AddContextResponse(BaseModel):
    message: str
    session_id: uuid.UUID
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# --- Survey Session Models ---
|
|
117
|
-
class TurnType(str, Enum):
    """Enum for different types of conversation turns."""

    # str mix-in: members compare equal to their string values.
    CONTEXT = "context"
    IMAGE = "image"
    OPEN_QUESTION = "open_question"
    CLOSED_QUESTION = "closed_question"
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
class BaseTurn(BaseModel):
    """Base model for all conversation turns."""

    # Defaults to creation time via datetime.now(), a naive (timezone-less)
    # local time.
    # NOTE(review): consider an aware timestamp if turns are compared across
    # hosts or time zones — confirm what the server expects before changing.
    timestamp: datetime = Field(default_factory=lambda: datetime.now())
    # Turn-kind tag; each subclass narrows it to one TurnType via Literal.
    type: TurnType

    class Config:
        # Store enum-typed fields as their raw values (e.g. "context") rather
        # than as TurnType members.
        use_enum_values = True
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class ContextTurn(BaseTurn):
    """A context turn that provides background information."""

    type: Literal[TurnType.CONTEXT] = TurnType.CONTEXT  # fixed tag
    user_context: str  # the context text; required
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
class ImageTurn(BaseTurn):
    """A standalone image turn (e.g., for context or reference)."""

    type: Literal[TurnType.IMAGE] = TurnType.IMAGE  # fixed tag
    # Presumably {description: url}, as documented on
    # OpenGenerationRequest.images — confirm.
    images: Dict[str, str]
    caption: Optional[str] = None
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
class OpenQuestionTurn(BaseTurn):
    """An open question-answer turn."""

    type: Literal[TurnType.OPEN_QUESTION] = TurnType.OPEN_QUESTION  # fixed tag
    user_question: str
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None  # None until/unless an answer is recorded
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
class ClosedQuestionTurn(BaseTurn):
    """A closed question-answer turn."""

    # Fixed tag identifying this turn kind.
    type: Literal[TurnType.CLOSED_QUESTION] = TurnType.CLOSED_QUESTION
    user_question: str
    user_options: List[str]
    user_images: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None

    @validator("user_options")
    def validate_options(cls, v):
        # An empty list gets its own message; a single option is rejected
        # separately, so at least two options are always required.
        option_count = len(v)
        if option_count == 0:
            raise ValueError("Closed questions must have at least one option")
        if option_count == 1:
            raise ValueError("Closed questions should have at least two options")
        return v

    @validator("llm_response")
    def validate_response(cls, v, values):
        # Nothing to check without a response, or when user_options is absent
        # from ``values`` (e.g. it failed its own validation).
        if v is None or "user_options" not in values:
            return v
        if v not in values["user_options"]:
            raise ValueError(f"Response '{v}' must be one of the provided options")
        return v
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
# Union of every concrete turn model, used to type conversation_history.
# Each member pins ``type`` to a distinct Literal. Member order is left as-is,
# since pydantic may try union members in declaration order.
SurveySessionTurn = Union[ContextTurn, ImageTurn, OpenQuestionTurn, ClosedQuestionTurn]
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
# Response returned when a survey session is created.
class SurveySessionCreateResponse(BaseModel):
    id: uuid.UUID  # Session ID
    agent_id: uuid.UUID
    created_at: datetime
    status: str  # free-form string; allowed values are not constrained here
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
class SurveySessionDetailResponse(BaseModel):
    """Detailed survey session response with typed conversation turns."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    # Each entry is validated against the SurveySessionTurn union; defaults to
    # a fresh empty list per instance.
    conversation_history: List[SurveySessionTurn] = Field(default_factory=list)

    class Config:
        # Encode datetimes as ISO-8601 strings in JSON output.
        json_encoders = {datetime: lambda v: v.isoformat()}
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
class SurveySessionListItemResponse(BaseModel):
    """Summary response for listing survey sessions."""

    id: uuid.UUID
    agent_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
    status: str
    # Required: Field() here sets only a description, not a default.
    turn_count: int = Field(description="Number of turns in conversation history")

    class Config:
        # Encode datetimes as ISO-8601 strings in JSON output.
        json_encoders = {datetime: lambda v: v.isoformat()}
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
# Response returned when a survey session is closed.
class SurveySessionCloseResponse(BaseModel):
    id: uuid.UUID  # Session ID
    status: str
    updated_at: datetime
    message: Optional[str] = None  # optional; defaults to None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|