simile 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of simile might be problematic. Click here for more details.

simile/client.py ADDED
@@ -0,0 +1,885 @@
1
+ import httpx
2
+ from httpx import AsyncClient, Limits
3
+ from typing import List, Dict, Any, Optional, Union, Type, AsyncGenerator
4
+ import uuid
5
+ from pydantic import BaseModel
6
+
7
+ from .models import (
8
+ Population,
9
+ PopulationInfo,
10
+ UpdatePopulationMetadataPayload,
11
+ UpdatePopulationInfoPayload,
12
+ Agent as AgentModel,
13
+ DataItem,
14
+ DeletionResponse,
15
+ OpenGenerationRequest,
16
+ OpenGenerationResponse,
17
+ ClosedGenerationRequest,
18
+ ClosedGenerationResponse,
19
+ CreatePopulationPayload,
20
+ CreateAgentPayload,
21
+ CreateDataItemPayload,
22
+ UpdateDataItemPayload,
23
+ InitialDataItemPayload,
24
+ SurveySessionCreateResponse,
25
+ SurveySessionDetailResponse,
26
+ MemoryStream,
27
+ UpdateAgentInfoPayload,
28
+ )
29
+ from .resources import Agent, SurveySession
30
+ from .exceptions import (
31
+ SimileAPIError,
32
+ SimileAuthenticationError,
33
+ SimileNotFoundError,
34
+ SimileBadRequestError,
35
+ )
36
+
37
# Root of the Simile REST API; endpoint paths are joined onto it.
DEFAULT_BASE_URL = "https://api.simile.ai/api/v1"
# Per-phase httpx limits in seconds: 5 s to establish a connection, 60 s
# waiting for response data, 30 s to send request data, and 30 s to obtain
# a connection from the pool.
TIMEOUT_CONFIG = httpx.Timeout(connect=5.0, read=60.0, write=30.0, pool=30.0)
39
+
40
+
41
class Simile:
    """Async client for the Simile REST API.

    All requests go through one shared ``httpx.AsyncClient`` that sends the
    ``X-API-Key`` header and uses the connection-pool limits passed to the
    constructor. Use the client as an async context manager
    (``async with Simile(api_key) as client: ...``) or call :meth:`aclose`
    when finished so pooled connections are released.
    """

    # Exception classes re-exported on the client so callers can write
    # ``except Simile.NotFoundError`` without importing the exceptions module.
    APIError = SimileAPIError
    AuthenticationError = SimileAuthenticationError
    NotFoundError = SimileNotFoundError
    BadRequestError = SimileBadRequestError

    def __init__(
        self,
        api_key: str,
        base_url: str = DEFAULT_BASE_URL,
        max_connections: int = 5000,
        max_keepalive_connections: int = 2000,
        keepalive_expiry: float = 300.0,
    ):
        """Create a client.

        Args:
            api_key: Simile API key, sent on every request as ``X-API-Key``.
            base_url: API root URL; a trailing slash is stripped.
            max_connections: Maximum number of pooled connections.
            max_keepalive_connections: Maximum number of idle connections kept alive.
            keepalive_expiry: Seconds an idle connection may stay in the pool.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("API key is required.")
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")

        limits = Limits(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
        )
        self._client = AsyncClient(
            headers={"X-API-Key": self.api_key},
            timeout=TIMEOUT_CONFIG,
            limits=limits,
        )

    # ------------------------------------------------------------------
    # Low-level HTTP helpers
    # ------------------------------------------------------------------

    def _url(self, endpoint: str) -> str:
        """Join ``endpoint`` onto ``base_url`` with exactly one slash between them."""
        return f"{self.base_url}/{endpoint.lstrip('/')}"

    @staticmethod
    def _error_from_response(
        response: httpx.Response, cause: httpx.HTTPStatusError
    ) -> Exception:
        """Translate a non-success HTTP response into the matching Simile exception.

        The response body must already be read. That is always true for
        non-streamed responses; streamed responses need ``aread()`` first.

        Args:
            response: The failed response.
            cause: The ``HTTPStatusError`` raised by ``raise_for_status``.

        Returns:
            ``SimileAuthenticationError`` for 401, ``SimileNotFoundError`` for 404,
            ``SimileBadRequestError`` for 400, otherwise ``SimileAPIError``.
        """
        status_code = response.status_code
        try:
            detail = response.json().get("detail", response.text)
        except Exception:
            # Body is not JSON (or not a JSON object): fall back to raw text.
            detail = response.text

        if status_code == 401:
            return SimileAuthenticationError(detail=detail)
        if status_code == 404:
            return SimileNotFoundError(detail=detail)
        if status_code == 400:
            return SimileBadRequestError(detail=detail)
        return SimileAPIError(
            f"API request failed: {cause}", status_code=status_code, detail=detail
        )

    @staticmethod
    def _error_from_transport(exc: httpx.RequestError) -> SimileAPIError:
        """Translate an httpx transport-level failure into a ``SimileAPIError``."""
        # The first matching class decides the message.
        known_failures = (
            (httpx.ConnectTimeout, "Connection timed out while trying to connect."),
            (httpx.ReadTimeout, "Timed out waiting for data from the server."),
            (httpx.WriteTimeout, "Timed out while sending data to the server."),
            (httpx.PoolTimeout, "Timed out waiting for a connection from the pool."),
            (httpx.ConnectError, "Failed to connect to the server."),
            (httpx.ProtocolError, "A protocol error occurred."),
            (httpx.DecodingError, "Failed to decode the response."),
        )
        for exc_type, message in known_failures:
            if isinstance(exc, exc_type):
                return SimileAPIError(message)
        return SimileAPIError(
            f"An unknown request error occurred: {type(exc).__name__}: {exc}"
        )

    async def _request(
        self, method: str, endpoint: str, **kwargs
    ) -> Union[httpx.Response, BaseModel]:
        """Send a request through the shared client and map failures to Simile errors.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            endpoint: Path relative to ``base_url``; a leading slash is ignored.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``. The extra key
                ``response_model`` (a pydantic model class) is consumed here. When
                given, the JSON body is parsed into that model.

        Returns:
            An instance of ``response_model`` if one was given, otherwise the
            raw ``httpx.Response``.

        Raises:
            SimileAPIError (or a subclass chosen by status code) for non-success
            responses and for transport failures.
        """
        response_model_cls: Optional[Type[BaseModel]] = kwargs.pop(
            "response_model", None
        )

        try:
            response = await self._client.request(method, self._url(endpoint), **kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise self._error_from_response(e.response, e) from e
        except httpx.RequestError as e:
            raise self._error_from_transport(e) from e

        if response_model_cls:
            return response_model_cls(**response.json())
        return response

    async def _stream_lines(
        self, endpoint: str, payload: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """POST ``payload`` and yield each non-empty line of the streamed body.

        This uses the shared client, so the stream carries the ``X-API-Key``
        header and the pool limits. The timeout is disabled for this request
        only, because a generation stream may stay open longer than the normal
        read timeout. Lines starting with ``"data: "`` (SSE framing) have that
        prefix removed; every yielded line is stripped.

        Raises:
            SimileAPIError (or a subclass) on non-success status or transport failure.
        """
        try:
            async with self._client.stream(
                "POST", self._url(endpoint), json=payload, timeout=None
            ) as response:
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError as e:
                    # Streamed bodies are not read automatically; read the body
                    # so the error detail can be extracted.
                    await response.aread()
                    raise self._error_from_response(response, e) from e

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue  # skip blank keep-alive / separator lines
                    if line.startswith("data: "):
                        yield line.removeprefix("data: ").strip()
                    else:
                        yield line.strip()
        except httpx.RequestError as e:
            raise self._error_from_transport(e) from e

    @staticmethod
    def _build_generation_payload(
        base: Dict[str, Any],
        *,
        question_id: Optional[str],
        study_id: Optional[str],
        data_types: Optional[List[str]],
        exclude_data_types: Optional[List[str]],
        images: Optional[Dict[str, str]],
        reasoning: bool,
        evidence: bool,
        confidence: bool,
        memory_stream: Optional[MemoryStream],
        use_memory: Optional[Union[str, uuid.UUID]],
        exclude_memory_ids: Optional[List[str]],
        save_memory: Optional[Union[str, uuid.UUID]],
        include_data_room: Optional[bool],
        organization_id: Optional[str],
    ) -> Dict[str, Any]:
        """Assemble the JSON body shared by the open and closed generation endpoints.

        ``base`` holds the question-specific fields: ``question`` and, for
        closed questions, ``options``. The optional fields below are omitted
        from the body when unset.
        """
        payload: Dict[str, Any] = {
            **base,
            "question_id": question_id,
            "study_id": study_id,
            "data_types": data_types,
            "exclude_data_types": exclude_data_types,
            "images": images,
            "reasoning": reasoning,
            "evidence": evidence,
            "confidence": confidence,
        }

        if include_data_room is not None:
            payload["include_data_room"] = include_data_room
        if organization_id is not None:
            payload["organization_id"] = organization_id

        # Memory loading/saving is handled server-side; the client only
        # forwards the session IDs.
        if use_memory:
            payload["use_memory"] = str(use_memory)
        if exclude_memory_ids:
            payload["exclude_memory_ids"] = exclude_memory_ids
        if save_memory:
            payload["save_memory"] = str(save_memory)

        # An explicit memory stream is only sent when provided directly.
        if memory_stream:
            payload["memory_stream"] = memory_stream.to_dict()
        return payload

    # ------------------------------------------------------------------
    # Survey sessions
    # ------------------------------------------------------------------

    def agent(self, agent_id: uuid.UUID) -> Agent:
        """Returns an Agent object to interact with a specific agent."""
        return Agent(agent_id=agent_id, client=self)

    async def create_survey_session(self, agent_id: uuid.UUID) -> SurveySession:
        """Create a survey session for ``agent_id`` and wrap it in a SurveySession.

        Sends ``POST sessions/``. The returned object is bound to this client.
        """
        response_data = await self._request(
            "POST",
            "sessions/",
            json={"agent_id": str(agent_id)},
            response_model=SurveySessionCreateResponse,
        )
        return SurveySession(
            id=response_data.id,
            agent_id=response_data.agent_id,
            status=response_data.status,
            client=self,
        )

    async def get_survey_session_details(
        self, session_id: Union[str, uuid.UUID]
    ) -> SurveySessionDetailResponse:
        """Retrieves detailed information about a survey session including typed conversation history."""
        return await self._request(
            "GET",
            f"sessions/{str(session_id)}",
            response_model=SurveySessionDetailResponse,
        )

    async def get_survey_session(
        self, session_id: Union[str, uuid.UUID]
    ) -> SurveySession:
        """Resume an existing survey session by ID.

        Raises:
            ValueError: If the session's status is ``"closed"``.
        """
        session_details = await self.get_survey_session_details(session_id)

        if session_details.status == "closed":
            raise ValueError(f"Session {session_id} is already closed")

        return SurveySession(
            id=session_details.id,
            agent_id=session_details.agent_id,
            status=session_details.status,
            client=self,
        )

    # ------------------------------------------------------------------
    # Populations
    # ------------------------------------------------------------------

    async def create_population(
        self, name: str, description: Optional[str] = None
    ) -> Population:
        """Create a population; ``description`` is omitted from the body when None."""
        payload = CreatePopulationPayload(name=name, description=description)
        return await self._request(
            "POST",
            "populations/create",
            json=payload.model_dump(mode="json", exclude_none=True),
            response_model=Population,
        )

    async def update_population_metadata(
        self,
        population_id: Union[str, uuid.UUID],
        metadata: Dict[str, Any],
        mode: str = "merge",
    ) -> Population:
        """
        Update a population's metadata (jsonb).

        Args:
            population_id: The ID of the population
            metadata: A dictionary of metadata to merge or replace
            mode: Either "merge" (default) or "replace"

        Returns:
            Updated Population object
        """
        payload = UpdatePopulationMetadataPayload(metadata=metadata, mode=mode)
        return await self._request(
            "PATCH",
            f"populations/{str(population_id)}/metadata",
            json=payload.model_dump(mode="json", exclude_none=True),
            response_model=Population,
        )

    async def get_population(self, population_id: Union[str, uuid.UUID]) -> Population:
        """Fetch a single population by ID."""
        return await self._request(
            "GET", f"populations/get/{str(population_id)}", response_model=Population
        )

    async def update_population_info(
        self,
        population_id: Union[str, uuid.UUID],
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Population:
        """
        Updates population information (name, description, metadata).
        At least one of name, description, or metadata must be provided.
        Requires write access to the population.

        Fields left as None are omitted from the request body.
        """
        payload = UpdatePopulationInfoPayload(
            name=name, description=description, metadata=metadata
        )
        return await self._request(
            "PUT",
            f"populations/update/{str(population_id)}",
            json=payload.model_dump(mode="json", exclude_none=True),
            response_model=Population,
        )

    async def get_population_info(
        self, population_id: Union[str, uuid.UUID]
    ) -> PopulationInfo:
        """Gets basic population info (name and agent count) without full population data."""
        return await self._request(
            "GET",
            f"populations/info/{str(population_id)}",
            response_model=PopulationInfo,
        )

    async def delete_population(
        self, population_id: Union[str, uuid.UUID]
    ) -> DeletionResponse:
        """Delete a population by ID."""
        return await self._request(
            "DELETE",
            f"populations/delete/{str(population_id)}",
            response_model=DeletionResponse,
        )

    async def get_agents_in_population(
        self, population_id: Union[str, uuid.UUID]
    ) -> List[AgentModel]:
        """Retrieves all agents belonging to a specific population."""
        raw_response = await self._request(
            "GET", f"populations/{str(population_id)}/agents"
        )
        return [AgentModel(**data) for data in raw_response.json()]

    async def get_agent_ids_in_population(
        self, population_id: Union[str, uuid.UUID]
    ) -> List[str]:
        """Retrieves only agent IDs for a population without full agent data.

        This is a lightweight alternative to get_agents_in_population when
        only agent IDs are needed.

        Args:
            population_id: The ID of the population

        Returns:
            List of agent ID strings
        """
        raw_response = await self._request(
            "GET", f"populations/{str(population_id)}/agents/ids"
        )
        return raw_response.json()

    # ------------------------------------------------------------------
    # Agents
    # ------------------------------------------------------------------

    async def create_agent(
        self,
        name: str,
        source: Optional[str] = None,
        source_id: Optional[str] = None,
        population_id: Optional[Union[str, uuid.UUID]] = None,
        agent_data: Optional[List[Dict[str, Any]]] = None,
    ) -> AgentModel:
        """Create an agent, optionally in a population and with initial data items.

        Raises:
            ValueError: If ``population_id`` is a string that is not a valid UUID.
        """
        pop_id_uuid: Optional[uuid.UUID] = None
        if population_id:
            pop_id_uuid = (
                population_id
                if isinstance(population_id, uuid.UUID)
                else uuid.UUID(str(population_id))
            )

        payload = CreateAgentPayload(
            name=name,
            population_id=pop_id_uuid,
            agent_data=agent_data,
            source=source,
            source_id=source_id,
        )
        return await self._request(
            "POST",
            "agents/create",
            json=payload.model_dump(mode="json", exclude_none=True),
            response_model=AgentModel,
        )

    async def get_agent(self, agent_id: Union[str, uuid.UUID]) -> AgentModel:
        """Fetch a single agent by ID."""
        return await self._request(
            "GET", f"agents/get/{str(agent_id)}", response_model=AgentModel
        )

    async def update_agent_info(
        self,
        agent_id: Union[str, uuid.UUID],
        name: str,
    ) -> AgentModel:
        """
        Updates agent information (name).
        Name must be provided.
        Requires write access to the agent.
        """
        payload = UpdateAgentInfoPayload(name=name)
        return await self._request(
            "PUT",
            f"agents/update/{str(agent_id)}",
            json=payload.model_dump(mode="json", exclude_none=True),
            response_model=AgentModel,
        )

    async def delete_agent(self, agent_id: Union[str, uuid.UUID]) -> DeletionResponse:
        """Delete an agent by ID."""
        return await self._request(
            "DELETE", f"agents/delete/{str(agent_id)}", response_model=DeletionResponse
        )

    async def add_agent_to_population(
        self, agent_id: Union[str, uuid.UUID], population_id: Union[str, uuid.UUID]
    ) -> Dict[str, str]:
        """Add an agent to an additional population; returns the decoded JSON body."""
        raw_response = await self._request(
            "POST", f"agents/{str(agent_id)}/populations/{str(population_id)}"
        )
        return raw_response.json()

    async def remove_agent_from_population(
        self, agent_id: Union[str, uuid.UUID], population_id: Union[str, uuid.UUID]
    ) -> Dict[str, str]:
        """Remove an agent from a population; returns the decoded JSON body."""
        raw_response = await self._request(
            "DELETE", f"agents/{str(agent_id)}/populations/{str(population_id)}"
        )
        return raw_response.json()

    async def batch_add_agents_to_population(
        self,
        agent_ids: List[Union[str, uuid.UUID]],
        population_id: Union[str, uuid.UUID],
    ) -> Dict[str, Any]:
        """Add several agents to a population in one request.

        The request body is a JSON array of agent ID strings.
        """
        raw_response = await self._request(
            "POST",
            f"populations/{str(population_id)}/agents/batch",
            json=[str(aid) for aid in agent_ids],
        )
        return raw_response.json()

    async def get_populations_for_agent(
        self, agent_id: Union[str, uuid.UUID]
    ) -> Dict[str, Any]:
        """Get all populations an agent belongs to, as the decoded JSON body."""
        raw_response = await self._request("GET", f"agents/{str(agent_id)}/populations")
        return raw_response.json()

    # ------------------------------------------------------------------
    # Data items
    # ------------------------------------------------------------------

    async def create_data_item(
        self,
        agent_id: Union[str, uuid.UUID],
        data_type: str,
        content: Any,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> DataItem:
        """Creates a new data item for a specific agent."""
        payload = CreateDataItemPayload(
            data_type=data_type, content=content, metadata=metadata
        )
        return await self._request(
            "POST",
            f"data_item/create/{str(agent_id)}",
            json=payload.model_dump(mode="json"),
            response_model=DataItem,
        )

    async def get_data_item(self, data_item_id: Union[str, uuid.UUID]) -> DataItem:
        """Fetch a single data item by ID."""
        return await self._request(
            "GET", f"data_item/get/{str(data_item_id)}", response_model=DataItem
        )

    async def list_data_items(
        self, agent_id: Union[str, uuid.UUID], data_type: Optional[str] = None
    ) -> List[DataItem]:
        """List an agent's data items, optionally filtered by ``data_type``.

        An empty or None ``data_type`` sends no filter.
        """
        params = {"data_type": data_type} if data_type else {}
        raw_response = await self._request(
            "GET", f"data_item/list/{str(agent_id)}", params=params
        )
        return [DataItem(**item) for item in raw_response.json()]

    async def update_data_item(
        self,
        data_item_id: Union[str, uuid.UUID],
        content: Any,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> DataItem:
        """Update an existing data item's content and metadata."""
        payload = UpdateDataItemPayload(content=content, metadata=metadata)
        return await self._request(
            "POST",
            f"data_item/update/{str(data_item_id)}",
            # mode="json" (as in create_data_item) turns UUIDs, datetimes etc.
            # into JSON-native values; httpx's encoder cannot serialize them.
            json=payload.model_dump(mode="json"),
            response_model=DataItem,
        )

    async def delete_data_item(
        self, data_item_id: Union[str, uuid.UUID]
    ) -> DeletionResponse:
        """Delete a data item by ID."""
        return await self._request(
            "DELETE",
            f"data_item/delete/{str(data_item_id)}",
            response_model=DeletionResponse,
        )

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    async def stream_open_response(
        self,
        agent_id: uuid.UUID,
        question: str,
        data_types: Optional[List[str]] = None,
        exclude_data_types: Optional[List[str]] = None,
        images: Optional[Dict[str, str]] = None,
        reasoning: bool = False,
        evidence: bool = False,
        confidence: bool = False,
        memory_stream: Optional[MemoryStream] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream an open-ended response from an agent, one text line at a time.

        ``memory_stream``, when given, is sent as in ``generate_open_response``
        (it was previously accepted but dropped).

        Raises:
            SimileAPIError (or a subclass) on non-success status or transport failure.
        """
        request_payload = OpenGenerationRequest(
            question=question,
            data_types=data_types,
            exclude_data_types=exclude_data_types,
            images=images,
            reasoning=reasoning,
            evidence=evidence,
            confidence=confidence,
        ).model_dump(mode="json")
        if memory_stream:
            request_payload["memory_stream"] = memory_stream.to_dict()

        async for chunk in self._stream_lines(
            f"generation/open-stream/{str(agent_id)}", request_payload
        ):
            yield chunk

    async def stream_closed_response(
        self,
        agent_id: uuid.UUID,
        question: str,
        options: List[str],
        data_types: Optional[List[str]] = None,
        exclude_data_types: Optional[List[str]] = None,
        images: Optional[Dict[str, str]] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a closed (multiple-choice) response from an agent, one text line at a time.

        Raises:
            SimileAPIError (or a subclass) on non-success status or transport failure.
        """
        request_payload = {
            "question": question,
            "options": options,
            "data_types": data_types,
            "exclude_data_types": exclude_data_types,
            "images": images,
        }
        async for chunk in self._stream_lines(
            f"generation/closed-stream/{str(agent_id)}", request_payload
        ):
            yield chunk

    async def generate_open_response(
        self,
        agent_id: uuid.UUID,
        question: str,
        question_id: Optional[str] = None,
        study_id: Optional[str] = None,
        data_types: Optional[List[str]] = None,
        exclude_data_types: Optional[List[str]] = None,
        images: Optional[Dict[str, str]] = None,
        reasoning: bool = False,
        evidence: bool = False,
        confidence: bool = False,
        memory_stream: Optional[MemoryStream] = None,
        use_memory: Optional[
            Union[str, uuid.UUID]
        ] = None,  # Session ID to load memory from
        exclude_memory_ids: Optional[List[str]] = None,  # Study/question IDs to exclude
        save_memory: Optional[
            Union[str, uuid.UUID]
        ] = None,  # Session ID to save memory to
        include_data_room: Optional[bool] = False,
        organization_id: Optional[str] = None,
    ) -> OpenGenerationResponse:
        """Generates an open response from an agent based on a question.

        Args:
            agent_id: The agent to query
            question: The question to ask
            question_id: Optional question ID for tracking
            study_id: Optional study ID for tracking
            data_types: Optional data types to include
            exclude_data_types: Optional data types to exclude
            images: Optional images dict
            reasoning: Whether to include reasoning
            evidence: Flag forwarded to the API as ``evidence``
            confidence: Flag forwarded to the API as ``confidence``
            memory_stream: Explicit memory stream to use (overrides use_memory)
            use_memory: Session ID to automatically load memory from
            exclude_memory_ids: Study/question IDs to exclude from loaded memory
            save_memory: Session ID to automatically save response to memory
            include_data_room: Whether to include data room info
            organization_id: Optional organization ID

        Memory is loaded and saved by the API from the IDs passed here; this
        method performs no extra memory calls.
        """
        request_payload = self._build_generation_payload(
            {"question": question},
            question_id=question_id,
            study_id=study_id,
            data_types=data_types,
            exclude_data_types=exclude_data_types,
            images=images,
            reasoning=reasoning,
            evidence=evidence,
            confidence=confidence,
            memory_stream=memory_stream,
            use_memory=use_memory,
            exclude_memory_ids=exclude_memory_ids,
            save_memory=save_memory,
            include_data_room=include_data_room,
            organization_id=organization_id,
        )
        return await self._request(
            "POST",
            f"generation/open/{str(agent_id)}",
            json=request_payload,
            response_model=OpenGenerationResponse,
        )

    async def generate_closed_response(
        self,
        agent_id: uuid.UUID,
        question: str,
        options: List[str],
        question_id: Optional[str] = None,
        study_id: Optional[str] = None,
        data_types: Optional[List[str]] = None,
        exclude_data_types: Optional[List[str]] = None,
        images: Optional[Dict[str, str]] = None,
        reasoning: bool = False,
        evidence: bool = False,
        confidence: bool = False,
        memory_stream: Optional[MemoryStream] = None,
        use_memory: Optional[
            Union[str, uuid.UUID]
        ] = None,  # Session ID to load memory from
        exclude_memory_ids: Optional[List[str]] = None,  # Study/question IDs to exclude
        save_memory: Optional[
            Union[str, uuid.UUID]
        ] = None,  # Session ID to save memory to
        summary_mode: bool = True,
        method: Optional[str] = None,  # Sampling method; defaults to tool-call-random-sampling
        include_data_room: Optional[bool] = False,
        organization_id: Optional[str] = None,
    ) -> ClosedGenerationResponse:
        """Generates a closed response from an agent.

        Args:
            agent_id: The agent to query
            question: The question to ask
            options: The options to choose from
            question_id: Optional question ID for tracking
            study_id: Optional study ID for tracking
            data_types: Optional data types to include
            exclude_data_types: Optional data types to exclude
            images: Optional images dict
            reasoning: Whether to include reasoning
            evidence: Flag forwarded to the API as ``evidence``
            confidence: Flag forwarded to the API as ``confidence``
            memory_stream: Explicit memory stream to use (overrides use_memory)
            use_memory: Session ID to automatically load memory from
            exclude_memory_ids: Study/question IDs to exclude from loaded memory
            save_memory: Session ID to automatically save response to memory
            summary_mode: Forwarded to the API as ``summary_mode``
            method: Optional sampling method, omitted when None. The original
                note says the default is tool-call-random-sampling.
            include_data_room: Whether to include data room info
            organization_id: Optional organization ID
        """
        request_payload = self._build_generation_payload(
            {"question": question, "options": options},
            question_id=question_id,
            study_id=study_id,
            data_types=data_types,
            exclude_data_types=exclude_data_types,
            images=images,
            reasoning=reasoning,
            evidence=evidence,
            confidence=confidence,
            memory_stream=memory_stream,
            use_memory=use_memory,
            exclude_memory_ids=exclude_memory_ids,
            save_memory=save_memory,
            include_data_room=include_data_room,
            organization_id=organization_id,
        )
        # These were accepted by the signature but previously never sent.
        request_payload["summary_mode"] = summary_mode
        if method is not None:
            request_payload["method"] = method

        return await self._request(
            "POST",
            f"generation/closed/{str(agent_id)}",
            json=request_payload,
            response_model=ClosedGenerationResponse,
        )

    # ------------------------------------------------------------------
    # Memory management
    # ------------------------------------------------------------------

    async def save_memory(
        self,
        agent_id: Union[str, uuid.UUID],
        response: str,
        session_id: Optional[Union[str, uuid.UUID]] = None,
        question_id: Optional[Union[str, uuid.UUID]] = None,
        study_id: Optional[Union[str, uuid.UUID]] = None,
        memory_turn: Optional[Dict[str, Any]] = None,
        memory_stream_used: Optional[Dict[str, Any]] = None,
        reasoning: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Save a response with associated memory information.

        Args:
            agent_id: The agent ID
            response: The agent's response text
            session_id: Session ID for memory continuity
            question_id: The question ID (optional)
            study_id: The study ID (optional)
            memory_turn: The memory turn to save
            memory_stream_used: The memory stream that was used
            reasoning: Optional reasoning
            metadata: Additional metadata

        Optional arguments that are falsy (None, empty) are not sent.

        Returns:
            The ``response_id`` from the API when it reports success.

        Raises:
            SimileAPIError: If the API response does not report success.
        """
        payload: Dict[str, Any] = {
            "agent_id": str(agent_id),
            "response": response,
        }

        # ID-like fields are sent as strings; the rest are sent as-is.
        for key, value in (
            ("session_id", session_id),
            ("question_id", question_id),
            ("study_id", study_id),
        ):
            if value:
                payload[key] = str(value)
        for key, value in (
            ("memory_turn", memory_turn),
            ("memory_stream_used", memory_stream_used),
            ("reasoning", reasoning),
            ("metadata", metadata),
        ):
            if value:
                payload[key] = value

        # Named distinctly so the ``response`` parameter is not shadowed.
        raw_response = await self._request("POST", "memory/save", json=payload)
        data = raw_response.json()
        if data.get("success"):
            return data.get("response_id")
        raise SimileAPIError("Failed to save memory")

    async def get_memory(
        self,
        session_id: Union[str, uuid.UUID],
        agent_id: Union[str, uuid.UUID],
        exclude_study_ids: Optional[List[Union[str, uuid.UUID]]] = None,
        exclude_question_ids: Optional[List[Union[str, uuid.UUID]]] = None,
        limit: Optional[int] = None,
        use_memory: bool = True,
    ) -> Optional[MemoryStream]:
        """
        Retrieve the memory stream for an agent in a session.

        Args:
            session_id: Session ID to filter by
            agent_id: The agent ID
            exclude_study_ids: List of study IDs to exclude
            exclude_question_ids: List of question IDs to exclude
            limit: Maximum number of turns to include (a falsy value such as
                None or 0 is not sent)
            use_memory: Whether to use memory at all

        Returns:
            MemoryStream object, or None if the API does not report success or
            returns no stream.
        """
        payload: Dict[str, Any] = {
            "session_id": str(session_id),
            "agent_id": str(agent_id),
            "use_memory": use_memory,
        }

        if exclude_study_ids:
            payload["exclude_study_ids"] = [str(sid) for sid in exclude_study_ids]
        if exclude_question_ids:
            payload["exclude_question_ids"] = [str(qid) for qid in exclude_question_ids]
        if limit:
            payload["limit"] = limit

        raw_response = await self._request("POST", "memory/get", json=payload)
        data = raw_response.json()

        if data.get("success") and data.get("memory_stream"):
            return MemoryStream.from_dict(data["memory_stream"])
        return None

    async def get_memory_summary(
        self,
        session_id: Union[str, uuid.UUID],
    ) -> Dict[str, Any]:
        """
        Get a summary of memory usage for a session.

        Args:
            session_id: Session ID to analyze

        Returns:
            The ``summary`` dict from the API, or ``{}`` if the API does not
            report success.
        """
        raw_response = await self._request("GET", f"memory/summary/{str(session_id)}")
        data = raw_response.json()
        if data.get("success"):
            return data.get("summary", {})
        return {}

    async def clear_memory(
        self,
        session_id: Union[str, uuid.UUID],
        agent_id: Optional[Union[str, uuid.UUID]] = None,
        study_id: Optional[Union[str, uuid.UUID]] = None,
    ) -> bool:
        """
        Clear memory for a session, optionally filtered by agent or study.

        Args:
            session_id: Session ID to clear memory for
            agent_id: Optional agent ID to filter by
            study_id: Optional study ID to filter by

        Returns:
            The API's ``success`` flag (False if absent).
        """
        payload: Dict[str, Any] = {"session_id": str(session_id)}
        if agent_id:
            payload["agent_id"] = str(agent_id)
        if study_id:
            payload["study_id"] = str(study_id)

        raw_response = await self._request("POST", "memory/clear", json=payload)
        return raw_response.json().get("success", False)

    async def copy_memory(
        self,
        from_session_id: Union[str, uuid.UUID],
        to_session_id: Union[str, uuid.UUID],
        agent_id: Optional[Union[str, uuid.UUID]] = None,
    ) -> int:
        """
        Copy memory from one session to another.

        Args:
            from_session_id: Source session ID
            to_session_id: Destination session ID
            agent_id: Optional agent ID to filter by

        Returns:
            Number of memory turns copied (0 if the API does not report success).
        """
        payload: Dict[str, Any] = {
            "from_session_id": str(from_session_id),
            "to_session_id": str(to_session_id),
        }
        if agent_id:
            payload["agent_id"] = str(agent_id)

        raw_response = await self._request("POST", "memory/copy", json=payload)
        data = raw_response.json()
        if data.get("success"):
            return data.get("copied_turns", 0)
        return 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def aclose(self):
        """Close the shared HTTP client and release pooled connections."""
        await self._client.aclose()

    async def __aenter__(self):
        """Enter an ``async with`` block and return this client."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the client when leaving an ``async with`` block; exceptions propagate."""
        await self.aclose()