neuro-simulator 0.4.4__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,102 +0,0 @@
1
- # neuro_simulator/services/audience.py
2
- import logging
3
-
4
- from google import genai
5
- from google.genai import types
6
- from openai import AsyncOpenAI
7
-
8
- from ..core.config import config_manager, AppSettings
9
- from ..utils.state import app_state
10
-
11
# Module logger. The first occurrence of "neuro_simulator" in the module name is
# replaced by "server" (e.g. "neuro_simulator.services.audience" becomes
# "server.services.audience"), so this module logs under the "server" hierarchy.
logger = logging.getLogger(__name__.replace("neuro_simulator", "server", 1))
12
-
13
class AudienceLLMClient:
    """Common interface for LLM backends that produce simulated audience chat.

    Concrete backends override ``generate_chat_messages``; the base version
    only signals that no implementation is available.
    """

    async def generate_chat_messages(self, prompt: str, max_tokens: int) -> str:
        """Return raw chat text generated from ``prompt``.

        Args:
            prompt: Full prompt text to send to the model.
            max_tokens: Upper bound on the number of tokens the model may produce.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
16
-
17
class GeminiAudienceLLM(AudienceLLMClient):
    """Audience chat generator backed by the Google Gen AI SDK (``google.genai``)."""

    def __init__(self, api_key: str, model_name: str):
        """Create a Gemini client for the given model.

        Args:
            api_key: Gemini API key; must be non-empty.
            model_name: Model identifier passed to ``generate_content``.

        Raises:
            ValueError: If ``api_key`` is empty or missing.
        """
        if not api_key:
            raise ValueError("Gemini API Key is not provided for GeminiAudienceLLM.")
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        logger.info(f"Initialized GeminiAudienceLLM (new SDK), model: {self.model_name}")

    async def generate_chat_messages(self, prompt: str, max_tokens: int) -> str:
        """Generate audience chat text for ``prompt`` using the async Gemini API.

        The sampling temperature is read from
        ``config_manager.settings.audience_simulation.llm_temperature`` at call time.

        Args:
            prompt: Full prompt text sent as the request contents.
            max_tokens: Passed through as ``max_output_tokens``.

        Returns:
            The generated text with surrounding whitespace stripped (matching
            ``OpenAIAudienceLLM``), or ``""`` when the response carries no text.
        """
        response = await self.client.aio.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=config_manager.settings.audience_simulation.llm_temperature,
                max_output_tokens=max_tokens,
            ),
        )
        # Strip so both providers hand callers identically normalised text.
        return self._extract_text(response).strip()

    @staticmethod
    def _extract_text(response) -> str:
        """Return the text carried by ``response``, or ``""`` when there is none."""
        # Prefer the SDK's aggregated ``text`` property when it is populated.
        text = getattr(response, "text", None)
        if text:
            return text
        # Fallback: concatenate the text parts of the first candidate by hand.
        candidates = getattr(response, "candidates", None)
        if not candidates:
            return ""
        content = candidates[0].content
        if not content or not content.parts:
            return ""
        return "".join(
            part.text for part in content.parts if getattr(part, "text", None)
        )
42
-
43
class OpenAIAudienceLLM(AudienceLLMClient):
    """Audience chat generator that calls an OpenAI-compatible chat completions API."""

    def __init__(self, api_key: str, model_name: str, base_url: str | None):
        """Create an ``AsyncOpenAI`` client.

        Args:
            api_key: OpenAI API key; must be non-empty.
            model_name: Chat model identifier used for every request.
            base_url: Optional API base URL; ``None`` lets the SDK choose its default.

        Raises:
            ValueError: If ``api_key`` is empty or missing.
        """
        if not api_key:
            raise ValueError("OpenAI API Key is not provided for OpenAIAudienceLLM.")
        self.model_name = model_name
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        logger.info(f"Initialized OpenAIAudienceLLM, model: {self.model_name}, API Base: {base_url}")

    async def generate_chat_messages(self, prompt: str, max_tokens: int) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        The sampling temperature is read from
        ``config_manager.settings.audience_simulation.llm_temperature`` at call time.

        Returns:
            The first choice's message content with surrounding whitespace
            stripped, or ``""`` if there is no choice, message, or content.
        """
        completion = await self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=config_manager.settings.audience_simulation.llm_temperature,
            max_tokens=max_tokens,
        )
        if not completion.choices:
            return ""
        message = completion.choices[0].message
        if not message or not message.content:
            return ""
        return message.content.strip()
61
-
62
async def get_dynamic_audience_prompt() -> str:
    """Build the audience-simulation prompt from Neuro's most recent speech.

    ``app_state.neuro_last_speech`` is read while holding
    ``app_state.neuro_last_speech_lock``. It is then substituted, together with
    the configured ``chats_per_batch``, into ``prompt_template`` via
    ``str.format``, so any literal braces in the template must be doubled.

    Returns:
        The formatted prompt string.
    """
    async with app_state.neuro_last_speech_lock:
        latest_speech = app_state.neuro_last_speech

    sim_settings = config_manager.settings.audience_simulation
    return sim_settings.prompt_template.format(
        neuro_speech=latest_speech,
        num_chats_to_generate=sim_settings.chats_per_batch,
    )
72
-
73
class AudienceChatbotManager:
    """Owns the active audience LLM client and rebuilds it when its config changes."""

    def __init__(self):
        """Build the initial client from the current settings and snapshot its inputs.

        Raises:
            ValueError: Propagated from ``_create_client`` for an unsupported
                provider or a missing API key.
        """
        settings = config_manager.settings
        self.client: AudienceLLMClient = self._create_client(settings)
        # Snapshots of every setting _create_client reads. handle_config_update
        # compares against them to decide whether the client must be rebuilt.
        self._last_checked_settings: dict = settings.audience_simulation.model_dump()
        self._last_checked_api_keys: dict = self._api_key_snapshot(settings)
        logger.info("AudienceChatbotManager initialized.")

    @staticmethod
    def _api_key_snapshot(settings: AppSettings) -> dict:
        """Return the ``api_keys`` fields that ``_create_client`` depends on."""
        keys = settings.api_keys
        return {
            "gemini_api_key": keys.gemini_api_key,
            "openai_api_key": keys.openai_api_key,
            "openai_api_base_url": keys.openai_api_base_url,
        }

    def _create_client(self, settings: AppSettings) -> AudienceLLMClient:
        """Instantiate the backend selected by ``audience_simulation.llm_provider``.

        The provider name is matched case-insensitively against "gemini" and "openai".

        Raises:
            ValueError: If the required API key is empty or the provider is unsupported.
        """
        provider = settings.audience_simulation.llm_provider
        logger.info(f"Creating new audience LLM client for provider: {provider}")
        normalized = provider.lower()
        if normalized == "gemini":
            if not settings.api_keys.gemini_api_key:
                raise ValueError("GEMINI_API_KEY not set in config")
            return GeminiAudienceLLM(
                api_key=settings.api_keys.gemini_api_key,
                model_name=settings.audience_simulation.gemini_model,
            )
        if normalized == "openai":
            if not settings.api_keys.openai_api_key:
                raise ValueError("OPENAI_API_KEY not set in config")
            return OpenAIAudienceLLM(
                api_key=settings.api_keys.openai_api_key,
                model_name=settings.audience_simulation.openai_model,
                base_url=settings.api_keys.openai_api_base_url,
            )
        raise ValueError(f"Unsupported AUDIENCE_LLM_PROVIDER: {provider}")

    def handle_config_update(self, new_settings: AppSettings):
        """Rebuild the client if any setting it depends on has changed.

        Both ``audience_simulation`` and the relevant ``api_keys`` fields are
        compared, so rotating a key or base URL also triggers a reload. On
        failure the error is logged and the previous client and snapshots are
        kept, which means the next update with the same settings retries.
        """
        new_audience_settings = new_settings.audience_simulation.model_dump()
        new_api_keys = self._api_key_snapshot(new_settings)
        if (
            new_audience_settings == self._last_checked_settings
            and new_api_keys == self._last_checked_api_keys
        ):
            return
        logger.info("Audience simulation settings changed, re-initializing LLM client...")
        try:
            new_client = self._create_client(new_settings)
        except Exception as e:
            # Hot-reload is best effort: keep serving with the old client.
            logger.error(f"Error hot-reloading LLM client: {e}", exc_info=True)
        else:
            self.client = new_client
            self._last_checked_settings = new_audience_settings
            self._last_checked_api_keys = new_api_keys
            logger.info("LLM client hot-reloaded successfully.")