dataknobs-bots 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. dataknobs_bots/__init__.py +42 -0
  2. dataknobs_bots/api/__init__.py +42 -0
  3. dataknobs_bots/api/dependencies.py +140 -0
  4. dataknobs_bots/api/exceptions.py +289 -0
  5. dataknobs_bots/bot/__init__.py +15 -0
  6. dataknobs_bots/bot/base.py +1091 -0
  7. dataknobs_bots/bot/context.py +102 -0
  8. dataknobs_bots/bot/manager.py +430 -0
  9. dataknobs_bots/bot/registry.py +629 -0
  10. dataknobs_bots/config/__init__.py +39 -0
  11. dataknobs_bots/config/resolution.py +353 -0
  12. dataknobs_bots/knowledge/__init__.py +82 -0
  13. dataknobs_bots/knowledge/query/__init__.py +25 -0
  14. dataknobs_bots/knowledge/query/expander.py +262 -0
  15. dataknobs_bots/knowledge/query/transformer.py +288 -0
  16. dataknobs_bots/knowledge/rag.py +738 -0
  17. dataknobs_bots/knowledge/retrieval/__init__.py +23 -0
  18. dataknobs_bots/knowledge/retrieval/formatter.py +249 -0
  19. dataknobs_bots/knowledge/retrieval/merger.py +279 -0
  20. dataknobs_bots/memory/__init__.py +56 -0
  21. dataknobs_bots/memory/base.py +38 -0
  22. dataknobs_bots/memory/buffer.py +58 -0
  23. dataknobs_bots/memory/vector.py +188 -0
  24. dataknobs_bots/middleware/__init__.py +11 -0
  25. dataknobs_bots/middleware/base.py +92 -0
  26. dataknobs_bots/middleware/cost.py +421 -0
  27. dataknobs_bots/middleware/logging.py +184 -0
  28. dataknobs_bots/reasoning/__init__.py +65 -0
  29. dataknobs_bots/reasoning/base.py +50 -0
  30. dataknobs_bots/reasoning/react.py +299 -0
  31. dataknobs_bots/reasoning/simple.py +51 -0
  32. dataknobs_bots/registry/__init__.py +41 -0
  33. dataknobs_bots/registry/backend.py +181 -0
  34. dataknobs_bots/registry/memory.py +244 -0
  35. dataknobs_bots/registry/models.py +102 -0
  36. dataknobs_bots/registry/portability.py +210 -0
  37. dataknobs_bots/tools/__init__.py +5 -0
  38. dataknobs_bots/tools/knowledge_search.py +113 -0
  39. dataknobs_bots/utils/__init__.py +1 -0
  40. dataknobs_bots-0.2.4.dist-info/METADATA +591 -0
  41. dataknobs_bots-0.2.4.dist-info/RECORD +42 -0
  42. dataknobs_bots-0.2.4.dist-info/WHEEL +4 -0
@@ -0,0 +1,421 @@
1
+ """Cost tracking middleware for monitoring LLM usage."""
2
+
3
+ import json
4
+ import logging
5
+ from typing import Any
6
+
7
+ from dataknobs_bots.bot.context import BotContext
8
+
9
+ from .base import Middleware
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class CostTrackingMiddleware(Middleware):
    """Middleware for tracking LLM API costs and usage.

    Monitors token usage across different providers (Ollama, OpenAI,
    Anthropic, etc.) to help optimize costs and track budgets.

    Attributes:
        track_tokens: Whether to track token usage.
        cost_rates: Token cost rates per provider/model. This is a private
            deep copy of ``DEFAULT_RATES`` merged with any custom rates, so
            customizing one instance never alters the class defaults or
            other instances.
        _usage_stats: Accumulated usage statistics keyed by client_id; read
            it through ``get_client_stats`` / ``get_all_stats``.

    Example:
        ```python
        # Create middleware with default rates
        middleware = CostTrackingMiddleware()

        # Or with custom rates
        middleware = CostTrackingMiddleware(
            cost_rates={
                "openai": {
                    "gpt-4o": {"input": 0.0025, "output": 0.01},
                },
            }
        )

        # Get stats
        stats = middleware.get_client_stats("my-client")
        total = middleware.get_total_cost()

        # Export to JSON
        json_data = middleware.export_stats_json()
        ```
    """

    # Default cost rates (USD per 1K tokens) - Updated Dec 2024
    DEFAULT_RATES: dict[str, Any] = {
        "ollama": {"input": 0.0, "output": 0.0},  # Free (infrastructure cost only)
        "openai": {
            "gpt-4o": {"input": 0.0025, "output": 0.01},
            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "o1": {"input": 0.015, "output": 0.06},
            "o1-mini": {"input": 0.003, "output": 0.012},
        },
        "anthropic": {
            "claude-3-5-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-5-haiku": {"input": 0.0008, "output": 0.004},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
        },
        "google": {
            "gemini-1.5-pro": {"input": 0.00125, "output": 0.005},
            "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
            "gemini-2.0-flash": {"input": 0.0001, "output": 0.0004},
        },
    }

    def __init__(
        self,
        track_tokens: bool = True,
        cost_rates: dict[str, Any] | None = None,
    ):
        """Initialize cost tracking middleware.

        Args:
            track_tokens: Enable token tracking
            cost_rates: Optional custom cost rates (merged with defaults).
                When both the custom and default entry for a provider are
                dicts, the custom entries are merged in; otherwise the custom
                entry replaces the default.
        """
        import copy  # only needed here, for the defensive rate copies

        self.track_tokens = track_tokens
        # Deep-copy: a shallow copy would share the nested provider dicts, and
        # the ``update`` below would then mutate the class-level DEFAULT_RATES
        # for every other instance.
        self.cost_rates: dict[str, Any] = copy.deepcopy(self.DEFAULT_RATES)
        for provider, rates in (cost_rates or {}).items():
            existing = self.cost_rates.get(provider)
            if isinstance(rates, dict) and isinstance(existing, dict):
                existing.update(copy.deepcopy(rates))
            else:
                self.cost_rates[provider] = copy.deepcopy(rates)

        self._usage_stats: dict[str, dict[str, Any]] = {}
        self._logger = logging.getLogger(f"{__name__}.CostTracker")

    async def before_message(self, message: str, context: BotContext) -> None:
        """Track message before processing (mainly for logging).

        Only emits a DEBUG log with a rough token estimate; no statistics
        are recorded here.

        Args:
            message: User's input message
            context: Bot context
        """
        # Rough approximation: ~4 characters per token.
        estimated_tokens = len(message) // 4
        self._logger.debug("Estimated input tokens: %s", estimated_tokens)

    async def after_message(
        self, response: str, context: BotContext, **kwargs: Any
    ) -> None:
        """Track costs after bot response.

        Args:
            response: Bot's generated response
            context: Bot context
            **kwargs: May contain 'tokens_used' (int total or a mapping with
                'input'/'output' or 'prompt_tokens'/'completion_tokens'),
                'provider' and 'model'. Missing values default to "unknown"
                or to a ~4-chars-per-token estimate.
        """
        if not self.track_tokens:
            return

        client_id = context.client_id
        provider = kwargs.get("provider", "unknown")
        model = kwargs.get("model", "unknown")
        input_tokens, output_tokens = self._resolve_token_counts(
            kwargs.get("tokens_used"), response, context
        )

        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        stats = self._record_usage(
            client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            "Client %s: %s/%s - %s in + %s out tokens, cost: $%.6f, total: $%.6f",
            client_id,
            provider,
            model,
            input_tokens,
            output_tokens,
            cost,
            stats["total_cost_usd"],
        )

    async def post_stream(
        self, message: str, response: str, context: BotContext
    ) -> None:
        """Track costs after streaming completes.

        For streaming responses, token counts are estimated from text length
        (~4 characters per token) because exact counts are not passed in.
        Provider and model are read from ``context.session_metadata``.

        Args:
            message: Original user message
            response: Complete accumulated response from streaming
            context: Bot context
        """
        if not self.track_tokens:
            return

        client_id = context.client_id
        input_tokens = len(message) // 4
        output_tokens = len(response) // 4
        provider = context.session_metadata.get("provider", "unknown")
        model = context.session_metadata.get("model", "unknown")

        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        # Same recorder as after_message, so stream-created entries have the
        # full schema (client_id, by_model) and later non-stream calls work.
        stats = self._record_usage(
            client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            "Stream complete - Client %s: %s/%s - "
            "~%s in + ~%s out tokens (estimated), cost: $%.6f, total: $%.6f",
            client_id,
            provider,
            model,
            input_tokens,
            output_tokens,
            cost,
            stats["total_cost_usd"],
        )

    async def on_error(
        self, error: Exception, message: str, context: BotContext
    ) -> None:
        """Log errors but don't track costs for failed requests.

        Args:
            error: The exception that occurred
            message: User message that caused the error
            context: Bot context
        """
        self._logger.warning(
            "Error during request for client %s: %s", context.client_id, error
        )

    @staticmethod
    def _resolve_token_counts(
        tokens_used: Any, response: str, context: BotContext
    ) -> tuple[int, int]:
        """Derive ``(input_tokens, output_tokens)`` from an after_message payload.

        Args:
            tokens_used: An int (treated as the total), a mapping with
                'input'/'output' or 'prompt_tokens'/'completion_tokens', or
                None/missing.
            response: Bot response, used to estimate output tokens.
            context: Bot context; ``session_metadata["last_message"]`` is used
                to estimate input tokens.

        Returns:
            Non-negative input and output counts when given an int total;
            mapping values are passed through ``int()``.
        """
        last_message = context.session_metadata.get("last_message") or ""
        estimated_input = len(last_message) // 4

        if tokens_used is None:
            tokens_used = {}

        if isinstance(tokens_used, int):
            # Only a total is known: attribute the estimated prompt size to
            # input, capped at the total so output can never go negative.
            total = max(tokens_used, 0)
            input_tokens = min(estimated_input, total)
            return input_tokens, total - input_tokens

        input_tokens = int(
            tokens_used.get(
                "input", tokens_used.get("prompt_tokens", estimated_input)
            )
        )
        output_tokens = int(
            tokens_used.get(
                "output", tokens_used.get("completion_tokens", len(response) // 4)
            )
        )
        return input_tokens, output_tokens

    @staticmethod
    def _new_bucket() -> dict[str, Any]:
        """Return a zeroed per-provider / per-model counter bucket."""
        return {"requests": 0, "input_tokens": 0, "output_tokens": 0, "cost_usd": 0.0}

    @staticmethod
    def _add_usage(
        bucket: dict[str, Any], input_tokens: int, output_tokens: int, cost: float
    ) -> None:
        """Add one request's usage to a provider/model counter bucket."""
        bucket["requests"] += 1
        bucket["input_tokens"] += input_tokens
        bucket["output_tokens"] += output_tokens
        bucket["cost_usd"] += cost

    def _record_usage(
        self,
        client_id: str,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: float,
    ) -> dict[str, Any]:
        """Accumulate one request into the client, provider and model totals.

        Args:
            client_id: Client identifier (stats key).
            provider: Provider name.
            model: Model name.
            input_tokens: Input tokens for this request.
            output_tokens: Output tokens for this request.
            cost: Cost in USD for this request.

        Returns:
            The updated per-client statistics dictionary.
        """
        stats = self._usage_stats.get(client_id)
        if stats is None:
            stats = self._usage_stats[client_id] = {
                "client_id": client_id,
                "total_requests": 0,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
                "by_provider": {},
            }

        stats["total_requests"] += 1
        stats["total_input_tokens"] += input_tokens
        stats["total_output_tokens"] += output_tokens
        stats["total_cost_usd"] += cost

        provider_stats = stats["by_provider"].get(provider)
        if provider_stats is None:
            provider_stats = {**self._new_bucket(), "by_model": {}}
            stats["by_provider"][provider] = provider_stats
        self._add_usage(provider_stats, input_tokens, output_tokens, cost)

        model_stats = provider_stats["by_model"].setdefault(model, self._new_bucket())
        self._add_usage(model_stats, input_tokens, output_tokens, cost)

        return stats

    def _lookup_rates(self, provider: str, model: str) -> dict[str, Any] | None:
        """Find the ``{"input": ..., "output": ...}`` rates for provider/model.

        Resolution order:
            1. An exact model entry.
            2. Provider-wide generic rates (an "input" key, e.g. ollama).
            3. The longest model key contained in ``model``, so that
               "gpt-4o-mini-2024-07-18" resolves to "gpt-4o-mini", not "gpt-4o".
            4. The first model key that contains a non-empty ``model``.

        Returns:
            The rate dict, or None when nothing matches.
        """
        provider_rates = self.cost_rates.get(provider)
        if not isinstance(provider_rates, dict):
            return None

        exact = provider_rates.get(model)
        if isinstance(exact, dict):
            return exact
        if "input" in provider_rates:
            return provider_rates

        model_rates = {
            key: value for key, value in provider_rates.items() if isinstance(value, dict)
        }
        contained = [key for key in model_rates if key in model]
        if contained:
            return model_rates[max(contained, key=len)]
        if model:  # an empty model name would otherwise match every key
            for key, value in model_rates.items():
                if model in key:
                    return value
        return None

    def _calculate_cost(
        self, provider: str, model: str, input_tokens: int, output_tokens: int
    ) -> float:
        """Calculate cost for token usage.

        Args:
            provider: LLM provider name
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost in USD (0.0 when no rates are known for provider/model)
        """
        rates = self._lookup_rates(provider, model)
        if rates is None:
            return 0.0

        # Rates are USD per 1K tokens.
        input_cost = (input_tokens / 1000) * float(rates.get("input", 0.0))
        output_cost = (output_tokens / 1000) * float(rates.get("output", 0.0))
        return float(input_cost + output_cost)

    def get_client_stats(self, client_id: str) -> dict[str, Any] | None:
        """Get usage statistics for a client.

        Args:
            client_id: Client identifier

        Returns:
            Usage statistics (the live dict) or None if not found
        """
        return self._usage_stats.get(client_id)

    def get_all_stats(self) -> dict[str, dict[str, Any]]:
        """Get all usage statistics.

        Returns:
            Dictionary mapping client_id to statistics. This is a shallow
            copy; the per-client dicts are shared with the middleware.
        """
        return self._usage_stats.copy()

    def get_total_cost(self) -> float:
        """Get total cost across all clients.

        Returns:
            Total cost in USD
        """
        return float(
            sum(stats["total_cost_usd"] for stats in self._usage_stats.values())
        )

    def get_total_tokens(self) -> dict[str, int]:
        """Get total tokens across all clients.

        Returns:
            Dictionary with 'input', 'output', and 'total' token counts
        """
        input_tokens = sum(
            stats["total_input_tokens"] for stats in self._usage_stats.values()
        )
        output_tokens = sum(
            stats["total_output_tokens"] for stats in self._usage_stats.values()
        )
        return {
            "input": input_tokens,
            "output": output_tokens,
            "total": input_tokens + output_tokens,
        }

    def clear_stats(self, client_id: str | None = None) -> None:
        """Clear usage statistics.

        Args:
            client_id: If provided (including an empty string), clear only
                this client. If None, clear all clients.
        """
        if client_id is None:
            self._usage_stats.clear()
        else:
            self._usage_stats.pop(client_id, None)

    def export_stats_json(self, indent: int = 2) -> str:
        """Export all statistics as JSON.

        Args:
            indent: JSON indentation level

        Returns:
            JSON string of all statistics
        """
        return json.dumps(self._usage_stats, indent=indent)

    def export_stats_csv(self) -> str:
        """Export statistics as CSV (one row per client).

        Fields containing commas, quotes or newlines (e.g. an unusual
        client_id) are quoted per the csv module's minimal quoting.

        Returns:
            CSV string with headers, rows separated by "\\n", no trailing newline
        """
        import csv
        import io

        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator="\n")
        writer.writerow(
            [
                "client_id",
                "total_requests",
                "total_input_tokens",
                "total_output_tokens",
                "total_cost_usd",
            ]
        )
        for client_id, stats in self._usage_stats.items():
            writer.writerow(
                [
                    client_id,
                    stats["total_requests"],
                    stats["total_input_tokens"],
                    stats["total_output_tokens"],
                    f"{stats['total_cost_usd']:.6f}",
                ]
            )
        # Keep the historical shape: "\n"-joined rows without a trailing newline.
        return buffer.getvalue().removesuffix("\n")
@@ -0,0 +1,184 @@
1
+ """Logging middleware for conversation tracking."""
2
+
3
+ import json
4
+ import logging
5
+ from datetime import datetime, timezone
6
+ from typing import Any
7
+
8
+ from dataknobs_bots.bot.context import BotContext
9
+
10
+ from .base import Middleware
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class LoggingMiddleware(Middleware):
    """Middleware for tracking conversation interactions.

    Logs all user messages and bot responses with context
    for monitoring, debugging, and analytics.

    Attributes:
        log_level: Logging level to use (default: INFO)
        include_metadata: Whether to include full context metadata
        json_format: Whether to output logs in JSON format. Values that are
            not JSON-serializable are rendered with ``str()``.

    Example:
        ```python
        # Basic usage
        middleware = LoggingMiddleware()

        # With JSON format for log aggregation
        middleware = LoggingMiddleware(
            log_level="INFO",
            include_metadata=True,
            json_format=True
        )
        ```
    """

    # How many characters of message/response text are echoed at DEBUG level.
    _PREVIEW_CHARS = 200

    def __init__(
        self,
        log_level: str | int = "INFO",
        include_metadata: bool = True,
        json_format: bool = False,
    ):
        """Initialize logging middleware.

        Args:
            log_level: Logging level name (DEBUG, INFO, WARNING, ERROR) or a
                numeric ``logging`` level.
            include_metadata: Whether to log full context metadata
            json_format: Whether to output in JSON format

        Raises:
            AttributeError: If ``log_level`` is a string that does not name an
                attribute of the ``logging`` module.
        """
        self.log_level = log_level
        self.include_metadata = include_metadata
        self.json_format = json_format
        self._logger = logging.getLogger(f"{__name__}.ConversationLogger")
        # logging.getLogger returns the same object for the same name, so the
        # most recently constructed instance decides this logger's level.
        level = (
            log_level
            if isinstance(log_level, int)
            else getattr(logging, log_level.upper())
        )
        self._logger.setLevel(level)

    @staticmethod
    def _base_log_data(
        event: str, context: BotContext, **fields: Any
    ) -> dict[str, Any]:
        """Build the common record: timestamp, event, identifiers, then fields."""
        return {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event": event,
            "client_id": context.client_id,
            "user_id": context.user_id,
            "conversation_id": context.conversation_id,
            **fields,
        }

    def _add_metadata(self, log_data: dict[str, Any], context: BotContext) -> None:
        """Attach session/request metadata when ``include_metadata`` is set."""
        if self.include_metadata:
            log_data["session_metadata"] = context.session_metadata
            log_data["request_metadata"] = context.request_metadata

    def _emit(
        self,
        log_data: dict[str, Any],
        text_label: str,
        *,
        level: int = logging.INFO,
        exc_info: BaseException | None = None,
    ) -> None:
        """Write one record as JSON or as ``"<text_label>: <dict>"``."""
        if self.json_format:
            # default=str keeps a non-serializable metadata value (e.g. a
            # datetime) from raising TypeError and failing the request.
            self._logger.log(
                level, json.dumps(log_data, default=str), exc_info=exc_info
            )
        else:
            self._logger.log(
                level, "%s: %s", text_label, log_data, exc_info=exc_info
            )

    @classmethod
    def _preview(cls, text: str) -> str:
        """Return at most ``_PREVIEW_CHARS`` of text, adding "..." only if cut."""
        if len(text) <= cls._PREVIEW_CHARS:
            return text
        return text[: cls._PREVIEW_CHARS] + "..."

    async def before_message(self, message: str, context: BotContext) -> None:
        """Called before processing user message.

        Args:
            message: User's input message
            context: Bot context with conversation and user info
        """
        log_data = self._base_log_data(
            "user_message", context, message_length=len(message)
        )
        self._add_metadata(log_data, context)
        self._emit(log_data, "User message")

        # Content itself only at DEBUG level, truncated.
        self._logger.debug("Message content: %s", self._preview(message))

    async def after_message(
        self, response: str, context: BotContext, **kwargs: Any
    ) -> None:
        """Called after generating bot response.

        Args:
            response: Bot's generated response
            context: Bot context
            **kwargs: Additional data; 'tokens_used', 'response_time_ms',
                'provider' and 'model' are copied into the record if present.
        """
        log_data = self._base_log_data(
            "bot_response", context, response_length=len(response)
        )

        # Optional metrics, in a fixed order.
        for key in ("tokens_used", "response_time_ms", "provider", "model"):
            if key in kwargs:
                log_data[key] = kwargs[key]

        self._add_metadata(log_data, context)
        self._emit(log_data, "Bot response")

        self._logger.debug("Response content: %s", self._preview(response))

    async def post_stream(
        self, message: str, response: str, context: BotContext
    ) -> None:
        """Called after streaming response completes.

        Args:
            message: Original user message
            response: Complete accumulated response from streaming
            context: Bot context
        """
        log_data = self._base_log_data(
            "stream_complete",
            context,
            message_length=len(message),
            response_length=len(response),
        )
        self._add_metadata(log_data, context)
        self._emit(log_data, "Stream complete")

        self._logger.debug("Streamed message: %s", self._preview(message))
        self._logger.debug("Streamed response: %s", self._preview(response))

    async def on_error(
        self, error: Exception, message: str, context: BotContext
    ) -> None:
        """Called when an error occurs during message processing.

        Logs at ERROR level with the exception's traceback attached; metadata
        is not included in error records.

        Args:
            error: The exception that occurred
            message: User message that caused the error
            context: Bot context
        """
        log_data = self._base_log_data(
            "error",
            context,
            error_type=type(error).__name__,
            error_message=str(error),
        )
        self._emit(
            log_data,
            "Error processing message",
            level=logging.ERROR,
            exc_info=error,
        )
@@ -0,0 +1,65 @@
1
+ """Reasoning strategies for DynaBot."""
2
+
3
+ from typing import Any
4
+
5
+ from .base import ReasoningStrategy
6
+ from .react import ReActReasoning
7
+ from .simple import SimpleReasoning
8
+
9
# Public API of the reasoning package: the ReasoningStrategy type, the
# SimpleReasoning / ReActReasoning strategies imported above, and the
# create_reasoning_from_config factory defined in this module.
__all__ = [
    "ReasoningStrategy",
    "SimpleReasoning",
    "ReActReasoning",
    "create_reasoning_from_config",
]
15
+
16
+
17
def create_reasoning_from_config(config: dict[str, Any]) -> ReasoningStrategy:
    """Create reasoning strategy from configuration.

    Args:
        config: Reasoning configuration with:
            - strategy: Strategy type ('simple', 'react'); case-insensitive,
              surrounding whitespace ignored. Missing or None means 'simple'.
            - max_iterations: For ReAct, max reasoning loops (default: 5)
            - verbose: Enable debug logging for reasoning steps (default: False)
            - store_trace: Store reasoning trace in conversation metadata (default: False)

    Returns:
        Configured reasoning strategy instance

    Raises:
        ValueError: If strategy type is not supported

    Example:
        ```python
        # Simple reasoning
        config = {"strategy": "simple"}
        strategy = create_reasoning_from_config(config)

        # ReAct reasoning with trace storage
        config = {
            "strategy": "react",
            "max_iterations": 5,
            "verbose": True,
            "store_trace": True
        }
        strategy = create_reasoning_from_config(config)
        ```
    """
    raw_strategy = config.get("strategy")
    if raw_strategy is None:
        # An explicit null (e.g. from YAML "strategy:") means "use the default"
        # rather than crashing on None.lower().
        raw_strategy = "simple"
    strategy_type = str(raw_strategy).strip().lower()

    if strategy_type == "simple":
        return SimpleReasoning()

    if strategy_type == "react":
        return ReActReasoning(
            max_iterations=config.get("max_iterations", 5),
            verbose=config.get("verbose", False),
            store_trace=config.get("store_trace", False),
        )

    raise ValueError(
        f"Unknown reasoning strategy: {strategy_type}. "
        "Available strategies: simple, react"
    )
+ )