ziya 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ziya might be problematic.

Files changed (49)
  1. app/agents/.agent.py.swp +0 -0
  2. app/agents/agent.py +315 -113
  3. app/agents/models.py +439 -0
  4. app/agents/prompts.py +32 -4
  5. app/main.py +70 -7
  6. app/server.py +403 -14
  7. app/utils/code_util.py +641 -215
  8. pyproject.toml +3 -3
  9. templates/asset-manifest.json +18 -20
  10. templates/index.html +1 -1
  11. templates/static/css/{main.87f30840.css → main.2bddf34e.css} +2 -2
  12. templates/static/css/main.2bddf34e.css.map +1 -0
  13. templates/static/js/46907.90c6a4f3.chunk.js +2 -0
  14. templates/static/js/46907.90c6a4f3.chunk.js.map +1 -0
  15. templates/static/js/56122.1d6a5c10.chunk.js +3 -0
  16. templates/static/js/56122.1d6a5c10.chunk.js.LICENSE.txt +9 -0
  17. templates/static/js/56122.1d6a5c10.chunk.js.map +1 -0
  18. templates/static/js/83953.61a908f4.chunk.js +3 -0
  19. templates/static/js/83953.61a908f4.chunk.js.map +1 -0
  20. templates/static/js/88261.1e90079d.chunk.js +3 -0
  21. templates/static/js/88261.1e90079d.chunk.js.map +1 -0
  22. templates/static/js/{96603.863a8f96.chunk.js → 96603.18c5d644.chunk.js} +2 -2
  23. templates/static/js/{96603.863a8f96.chunk.js.map → 96603.18c5d644.chunk.js.map} +1 -1
  24. templates/static/js/{97902.75670155.chunk.js → 97902.d1e262d6.chunk.js} +3 -3
  25. templates/static/js/{97902.75670155.chunk.js.map → 97902.d1e262d6.chunk.js.map} +1 -1
  26. templates/static/js/main.9b2b2b57.js +3 -0
  27. templates/static/js/{main.ee8b3c96.js.LICENSE.txt → main.9b2b2b57.js.LICENSE.txt} +8 -2
  28. templates/static/js/main.9b2b2b57.js.map +1 -0
  29. {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/METADATA +5 -5
  30. {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/RECORD +36 -35
  31. templates/static/css/main.87f30840.css.map +0 -1
  32. templates/static/js/23416.c33f07ab.chunk.js +0 -3
  33. templates/static/js/23416.c33f07ab.chunk.js.map +0 -1
  34. templates/static/js/3799.fedb612f.chunk.js +0 -2
  35. templates/static/js/3799.fedb612f.chunk.js.map +0 -1
  36. templates/static/js/46907.4a730107.chunk.js +0 -2
  37. templates/static/js/46907.4a730107.chunk.js.map +0 -1
  38. templates/static/js/64754.cf383335.chunk.js +0 -2
  39. templates/static/js/64754.cf383335.chunk.js.map +0 -1
  40. templates/static/js/88261.33450351.chunk.js +0 -3
  41. templates/static/js/88261.33450351.chunk.js.map +0 -1
  42. templates/static/js/main.ee8b3c96.js +0 -3
  43. templates/static/js/main.ee8b3c96.js.map +0 -1
  44. /templates/static/js/{23416.c33f07ab.chunk.js.LICENSE.txt → 83953.61a908f4.chunk.js.LICENSE.txt} +0 -0
  45. /templates/static/js/{88261.33450351.chunk.js.LICENSE.txt → 88261.1e90079d.chunk.js.LICENSE.txt} +0 -0
  46. /templates/static/js/{97902.75670155.chunk.js.LICENSE.txt → 97902.d1e262d6.chunk.js.LICENSE.txt} +0 -0
  47. {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/LICENSE +0 -0
  48. {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/WHEEL +0 -0
  49. {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/entry_points.txt +0 -0
app/server.py CHANGED
@@ -1,7 +1,7 @@
  import os
  import time
  import json
- from typing import Dict, Any, List, Tuple, Optional
+ from typing import Dict, Any, List, Tuple, Optional, Union

  import tiktoken
  from fastapi import FastAPI, Request, HTTPException
@@ -10,15 +10,20 @@ from fastapi.responses import JSONResponse
  from fastapi.staticfiles import StaticFiles
  from fastapi.templating import Jinja2Templates
  from langserve import add_routes
- from app.agents.agent import model
+ from app.agents.agent import model, RetryingChatBedrock
  from app.agents.agent import agent_executor
- from fastapi.responses import FileResponse
- from pydantic import BaseModel
+ from app.agents.agent import update_conversation_state, update_and_return
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
+ from fastapi.responses import FileResponse, StreamingResponse
+ from pydantic import BaseModel, Field
+ from app.agents.models import ModelManager
  from botocore.exceptions import ClientError, BotoCoreError, CredentialRetrievalError
  from botocore.exceptions import EventStreamError
+ import botocore.errorfactory
  from starlette.responses import StreamingResponse

  # import pydevd_pycharm
+ from google.api_core.exceptions import ResourceExhausted
  import uvicorn

  from app.utils.code_util import use_git_to_apply_code_diff, correct_git_diff, PatchApplicationError
@@ -26,6 +31,13 @@ from app.utils.directory_util import get_ignored_patterns
  from app.utils.logging_utils import logger
  from app.utils.gitignore_parser import parse_gitignore_patterns

+ # Server configuration defaults
+ DEFAULT_PORT = 6969
+ # For model configurations, see app/agents/model.py
+
+ class SetModelRequest(BaseModel):
+     model_id: str
+
  app = FastAPI()

  app.add_middleware(
@@ -74,7 +86,6 @@ async def credential_exception_handler(request: Request, exc: CredentialRetrieva
          headers={"WWW-Authenticate": "Bearer"}
      )

-
  @app.exception_handler(ClientError)
  async def boto_client_exception_handler(request: Request, exc: ClientError):
      error_message = str(exc)
@@ -84,6 +95,13 @@ async def boto_client_exception_handler(request: Request, exc: ClientError):
              content={"detail": "AWS credentials have expired. Please refresh your credentials."},
              headers={"WWW-Authenticate": "Bearer"}
          )
+     elif "ValidationException" in error_message:
+         logger.error(f"Bedrock validation error: {error_message}")
+         return JSONResponse(
+             status_code=400,
+             content={"error": "validation_error",
+                      "detail": "Invalid request format for Bedrock service. Please check your input format.",
+                      "message": error_message})
      elif "ServiceUnavailableException" in error_message:
          return JSONResponse(
              status_code=503,
@@ -94,11 +112,77 @@ async def boto_client_exception_handler(request: Request, exc: ClientError):
          content={"detail": f"AWS Service Error: {str(exc)}"}
      )

+ @app.exception_handler(ResourceExhausted)
+ async def resource_exhausted_handler(request: Request, exc: ResourceExhausted):
+     """Handle Google API quota exceeded errors."""
+     logger.error(f"Google API quota exceeded: {str(exc)}")
+     return JSONResponse(
+         status_code=429, # Too Many Requests
+         content={
+             "error": "quota_exceeded",
+             "detail": "API quota has been exceeded. Please try again in a few minutes.",
+             "original_error": str(exc)
+         }
+     )
+
+ @app.exception_handler(ResourceExhausted)
+ async def resource_exhausted_handler(request: Request, exc: ResourceExhausted):
+     """Handle Google API quota exceeded errors."""
+     logger.error(f"Google API quota exceeded: {str(exc)}")
+     return JSONResponse(
+         status_code=429, # Too Many Requests
+         content={
+             "error": "quota_exceeded",
+             "detail": "API quota has been exceeded. Please try again in a few minutes.",
+             "original_error": str(exc)
+         }
+     )
+
  @app.exception_handler(Exception)
  async def general_exception_handler(request: Request, exc: Exception):
      error_message = str(exc)
      status_code = 500
      error_type = "unknown_error"
+
+     # Check for empty text parameter error from Gemini
+     if "Unable to submit request because it has an empty text parameter" in error_message:
+         logger.error("Caught empty text parameter error from Gemini")
+         return JSONResponse(
+             status_code=400,
+             content={
+                 "error": "validation_error",
+                 "detail": "Empty message content detected. Please provide a question."
+             }
+         )
+
+     # Check for Google API quota exceeded error
+     if "Resource has been exhausted" in error_message and "check quota" in error_message:
+         return JSONResponse(
+             status_code=429, # Too Many Requests
+             content={
+                 "error": "quota_exceeded",
+                 "detail": "API quota has been exceeded. Please try again in a few minutes."
+             })
+
+     # Check for Gemini token limit error
+     if isinstance(exc, ChatGoogleGenerativeAIError) and "token count" in error_message:
+         return JSONResponse(
+             status_code=413,
+             content={
+                 "error": "validation_error",
+                 "detail": "Selected content is too large for the model. Please reduce the number of files."
+             }
+         )
+
+     # Check for Google API quota exceeded error
+     if "Resource has been exhausted" in error_message and "check quota" in error_message:
+         return JSONResponse(
+             status_code=429, # Too Many Requests
+             content={
+                 "error": "quota_exceeded",
+                 "detail": "API quota has been exceeded. Please try again in a few minutes."
+             })
+
      try:
          # Check if this is a streaming error
          if isinstance(exc, EventStreamError):
@@ -124,7 +208,21 @@ async def general_exception_handler(request: Request, exc: Exception):
          logger.error(f"Error in exception handler: {str(e)}", exc_info=True)
          raise

- app.mount("/static", StaticFiles(directory="../templates/static"), name="static")
+ # Get the absolute path to the project root directory
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+ # Define paths relative to project root
+ static_dir = os.path.join(project_root, "templates", "static")
+ testcases_dir = os.path.join(project_root, "tests", "frontend", "testcases")
+ templates_dir = os.path.join(project_root, "templates")
+
+ # Create directories if they don't exist
+ os.makedirs(static_dir, exist_ok=True)
+ os.makedirs(testcases_dir, exist_ok=True)
+ os.makedirs(templates_dir, exist_ok=True)
+
+ # Mount static files and templates
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")

  # Only mount testcases directory if it exists
  testcases_dir = "../tests/frontend/testcases"
@@ -133,27 +231,109 @@ if os.path.exists(testcases_dir):
  else:
      logger.info(f"Testcases directory '{testcases_dir}' does not exist - skipping mount")

- templates = Jinja2Templates(directory="../templates")
+ templates = Jinja2Templates(directory=templates_dir)

  # Add a route for the frontend
  add_routes(app, agent_executor, disabled_endpoints=["playground"], path="/ziya")
  # Override the stream endpoint with our error handling
  @app.post("/ziya/stream")
  async def stream_endpoint(body: dict):
+
+     # Debug logging
+     logger.info("Stream endpoint request body:")
+     logger.info(f"Question: '{body.get('question', 'EMPTY')}'")
+     logger.info(f"Chat history length: {len(body.get('chat_history', []))}")
+     logger.info(f"Files count: {len(body.get('config', {}).get('files', []))}")
+     logger.info(f"Question type: {type(body.get('question', None))}")
+
+     # Log the first few files
+     if 'config' in body and 'files' in body['config']:
+         logger.info(f"First few files: {body['config']['files'][:5]}")
+
+     # Check if the question is empty or missing
+     if not body.get("question") or not body.get("question").strip():
+         logger.warning("Empty question detected, returning error response")
+         error_response = json.dumps({
+             "error": "validation_error",
+             "detail": "Please provide a question to continue."
+         })
+
+         # Return a properly formatted SSE response with the error
+         async def error_stream():
+             # Send the error message
+             yield f"data: {error_response}\n\n"
+             # Wait a moment to ensure the client receives it
+             await asyncio.sleep(0.1)
+             # Send an end message
+             yield "data: [DONE]\n\n"
+
+         return StreamingResponse(
+             error_stream(),
+             media_type="text/event-stream",
+             headers={"Cache-Control": "no-cache"}
+         )
      try:
+         # Check for empty question
+         if not body.get("question") or not body.get("question").strip():
+             logger.warning("Empty question detected in stream request")
+             # Return a friendly error message
+             return StreamingResponse(
+                 iter([f'data: {json.dumps({"error": "validation_error", "detail": "Please enter a question"})}' + '\n\n']),
+                 media_type="text/event-stream",
+                 headers={"Cache-Control": "no-cache"}
+             )
+
+         # Check for empty messages in chat history
+         if "chat_history" in body:
+             cleaned_history = []
+             for pair in body["chat_history"]:
+                 try:
+                     if not isinstance(pair, (list, tuple)) or len(pair) != 2:
+                         logger.warning(f"Invalid chat history pair format: {pair}")
+                         continue
+
+                     human, ai = pair
+                     if not isinstance(human, str) or not isinstance(ai, str):
+                         logger.warning(f"Non-string message in pair: human={type(human)}, ai={type(ai)}")
+                         continue
+
+                     if human.strip() and ai.strip():
+                         cleaned_history.append((human.strip(), ai.strip()))
+                     else:
+                         logger.warning(f"Empty message in pair: {pair}")
+                 except Exception as e:
+                     logger.error(f"Error processing chat history pair: {str(e)}")
+
+             logger.debug(f"Cleaned chat history from {len(body['chat_history'])} to {len(cleaned_history)} pairs")
+             body["chat_history"] = cleaned_history
+             logger.debug(f"Cleaned chat history: {json.dumps(cleaned_history)}")
+
          logger.info("Starting stream endpoint with body size: %d", len(str(body)))
         # Define the streaming response with proper error handling
         async def error_handled_stream():
+             response = None
             try:
+                 # Convert to ChatPromptValue before streaming
+                 if isinstance(body, dict) and "messages" in body:
+                     from langchain_core.prompt_values import ChatPromptValue
+                     from langchain_core.messages import HumanMessage
+                     body["messages"] = [HumanMessage(content=msg) for msg in body["messages"]]
+                     body = ChatPromptValue(messages=body["messages"])
                  # Create the iterator inside the error handling context
-                 iterator = agent_executor.astream_log(body, {})
+                 iterator = agent_executor.astream_log(body)
                  async for chunk in iterator:
                      logger.info("Processing chunk: %s",
                                  chunk if isinstance(chunk, dict) else chunk[:200] + "..." if len(chunk) > 200 else chunk)
                      if isinstance(chunk, dict) and "error" in chunk:
                          # Format error as SSE message
                          yield f"data: {json.dumps(chunk)}\n\n"
-                         logger.info("Sent error message: %s", error_msg)
+                         # Update file state before returning
+                         update_and_return(body)
+                         logger.info(f"Sent error message: {chunk}")
+                         return
+                     elif isinstance(chunk, Generation) and hasattr(chunk, 'text') and "quota_exceeded" in chunk.text:
+                         yield f"data: {chunk.text}\n\n"
+                         update_and_return(body)
                          return
                      else:
                          try:
@@ -166,9 +346,30 @@ async def stream_endpoint(body: dict):
                                  "detail": "Selected content is too large for the model. Please reduce the number of files."
                              }
                              yield f"data: {json.dumps(error_msg)}\n\n"
+                             update_and_return(body)
                              await response.flush()
                              logger.info("Sent EventStreamError message: %s", error_msg)
                              return
+                     except ChatGoogleGenerativeAIError as e:
+                         if "token count" in str(e):
+                             error_msg = {
+                                 "error": "validation_error",
+                                 "detail": "Selected content is too large for the model. Please reduce the number of files."
+                             }
+                             yield f"data: {json.dumps(error_msg)}\n\n"
+                             update_and_return(body)
+                             await response.flush()
+                             logger.info("Sent token limit error message: %s", error_msg)
+                             return
+                     except ResourceExhausted as e:
+                         error_msg = {
+                             "error": "quota_exceeded",
+                             "detail": "API quota has been exceeded. Please try again in a few minutes."
+                         }
+                         yield f"data: {json.dumps(error_msg)}\n\n"
+                         update_and_return(body)
+                         logger.error(f"Caught ResourceExhausted error: {str(e)}")
+                         return
              except EventStreamError as e:
                  if "validationException" in str(e):
                      error_msg = {
@@ -176,15 +377,18 @@ async def stream_endpoint(body: dict):
                      "detail": "Selected content is too large for the model. Please reduce the number of files."
                  }
                  yield f"data: {json.dumps(error_msg)}\n\n"
+                 update_and_return(body)
                  await response.flush()
                  return
              raise
+         finally:
+             update_and_return(body)
          return StreamingResponse(error_handled_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
      except Exception as e:
          logger.error(f"Error in stream endpoint: {str(e)}")
          error_msg = {"error": "stream_error", "detail": str(e)}
          logger.error(f"Sending error response: {error_msg}")
-         logger.error(f"Sending error response: {error_msg}")
+         update_and_return(body)
          return StreamingResponse(iter([f"data: {json.dumps(error_msg)}\n\n"]), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})


@@ -287,24 +491,145 @@ async def get_folders():
      return get_cached_folder_structure(user_codebase_dir, ignored_patterns, max_depth)

  @app.get('/api/default-included-folders')
- def get_default_included_folders():
+ def get_model_id():
      return {'defaultIncludedFolders': []}

+ @app.get('/api/current-model')
+ def get_current_model():
+     """Get detailed information about the currently active model."""
+     logger.info(
+         "Current model info request: %s",
+         { 'model_id': model.model_id,
+           'endpoint': os.environ.get("ZIYA_ENDPOINT", "bedrock")
+         })
+
+     # Get actual model settings
+     model_kwargs = {}
+     if hasattr(model, 'model') and hasattr(model.model, 'model_kwargs'):
+         model_kwargs = model.model.model_kwargs
+     elif hasattr(model, 'model_kwargs'):
+         model_kwargs = model.model_kwargs
+
+     logger.info("Current model configuration:")
+     logger.info(f" Model ID: {model.model_id}")
+     logger.info(f" Temperature: {model_kwargs.get('temperature', 'Not set')} (env: {os.environ.get('ZIYA_TEMPERATURE', 'Not set')})")
+     logger.info(f" Top K: {model_kwargs.get('top_k', 'Not set')} (env: {os.environ.get('ZIYA_TOP_K', 'Not set')})")
+     logger.info(f" Max tokens: {model_kwargs.get('max_tokens', 'Not set')} (env: {os.environ.get('ZIYA_MAX_OUTPUT_TOKENS', 'Not set')})")
+     logger.info(f" Thinking mode: {os.environ.get('ZIYA_THINKING_MODE', 'Not set')}")
+
+
+     return {
+         'model_id': model.model_id,
+         'endpoint': os.environ.get("ZIYA_ENDPOINT", "bedrock"),
+         'settings': {
+             'temperature': model_kwargs.get('temperature',
+                                             float(os.environ.get("ZIYA_TEMPERATURE", 0.3))),
+             'max_output_tokens': model_kwargs.get('max_tokens',
+                                                   int(os.environ.get("ZIYA_MAX_OUTPUT_TOKENS", 4096))),
+             'top_k': model_kwargs.get('top_k',
+                                       int(os.environ.get("ZIYA_TOP_K", 15))),
+             'thinking_mode': os.environ.get("ZIYA_THINKING_MODE") == "1"

+         }
+     }
+
  @app.get('/api/model-id')
  def get_model_id():
-     # Get the model ID from the configured Bedrock client
-     return {'model_id': model.model_id.split(':')[0].split('/')[-1]}
+     if os.environ.get("ZIYA_ENDPOINT") == "google":
+         model_name = os.environ.get("ZIYA_MODEL", "gemini-pro")
+         return {'model_id': model_name}
+     elif os.environ.get("ZIYA_MODEL"):
+         return {'model_id': os.environ.get("ZIYA_MODEL")}
+     else:
+         # Bedrock
+         return {'model_id': model.model_id.split(':')[0].split('/')[-1]}
+
+ @app.post('/api/set-model')
+ async def set_model(request: SetModelRequest):
+     """Set the active model for the current endpoint."""
+     try:
+         model_id = request.model_id
+         if not model_id:
+             logger.error("Empty model ID provided")
+             raise HTTPException(status_code=400, detail="Model ID is required")
+
+         # Update environment variable
+         os.environ["ZIYA_MODEL"] = model_id
+         logger.info(f"Setting model to: {model_id}")
+
+         # Reinitialize the model
+         try:
+             logger.info(f"Reinitializing model with ID: {model_id}")
+             new_model = ModelManager.initialize_model(force_reinit=True)
+             new_model.model_id = model_id # Ensure model ID is set correctly
+
+             # Update the global model instance
+             global model
+             model = RetryingChatBedrock(new_model)
+
+             return {"status": "success", "model": model_id}
+         except Exception as e:
+             logger.error(f"Failed to initialize model: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to initialize model: {str(e)}")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get('/api/available-models')
+ def get_available_models():
+     """Get list of available models for the current endpoint."""
+     endpoint = os.environ.get("ZIYA_ENDPOINT", "bedrock")
+     try:
+         models = []
+         for name, config in ModelManager.MODEL_CONFIGS[endpoint].items():
+             models.append({
+                 "id": config["model_id"],
+                 "name": name
+             })
+         return models
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+     return {'model_id': model.model_id.split(':')[0].split('/')[-1]}
+
+ @app.get('/api/model-capabilities')
+ def get_model_capabilities(model: str = None):
+     """Get the capabilities of the current model."""
+     endpoint = os.environ.get("ZIYA_ENDPOINT", "bedrock")
+     # If model parameter is provided, get capabilities for that model
+     # Otherwise use current model
+     model_name = model if model else os.environ.get("ZIYA_MODEL")
+
+     try:
+         model_config = ModelManager.get_model_config(endpoint, model_name)
+         capabilities = {
+             "supports_thinking": model_config.get("supports_thinking", False),
+             "max_output_tokens": model_config.get("max_output_tokens", 4096),
+             "temperature_range": {"min": 0, "max": 1, "default": model_config.get("temperature", 0.3)},
+             "top_k_range": {"min": 0, "max": 500, "default": model_config.get("top_k", 15)} if endpoint == "bedrock" else None
+         }
+         return capabilities
+     except Exception as e:
+         logger.error(f"Error getting model capabilities: {str(e)}")
+         return {"error": str(e)}

  class ApplyChangesRequest(BaseModel):
      diff: str
      filePath: str

+ class ModelSettingsRequest(BaseModel):
+     temperature: float = Field(default=0.3, ge=0, le=1)
+     top_k: int = Field(default=15, ge=0, le=500)
+     max_output_tokens: int = Field(default=4096, ge=1, le=128000)
+     thinking_mode: bool = Field(default=False)
+
+
  class TokenCountRequest(BaseModel):
      text: str

  def count_tokens_fallback(text: str) -> int:
      """Fallback methods for counting tokens when primary method fails."""
      try:
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
          # First try using tiktoken directly with cl100k_base (used by Claude)
          encoding = tiktoken.get_encoding("cl100k_base")
          return len(encoding.encode(text))
@@ -346,6 +671,70 @@ async def count_tokens(request: TokenCountRequest) -> Dict[str, int]:
          # Return 0 in case of error to avoid breaking the frontend
          return {"token_count": 0}

+ @app.post('/api/model-settings')
+ async def update_model_settings(settings: ModelSettingsRequest):
+     global model
+     try:
+         # Log the requested settings
+         logger.info(f"Requested model settings update:")
+         logger.info(f" Temperature: {settings.temperature}")
+         logger.info(f" Top K: {settings.top_k}")
+         logger.info(f" Max Output Tokens: {settings.max_output_tokens}")
+         logger.info(f" Thinking Mode: {settings.thinking_mode}")
+
+         # Store settings in environment variables for the agent to use
+         os.environ["ZIYA_TEMPERATURE"] = str(settings.temperature)
+         os.environ["ZIYA_TOP_K"] = str(settings.top_k)
+         os.environ["ZIYA_MAX_OUTPUT_TOKENS"] = str(settings.max_output_tokens)
+         os.environ["ZIYA_THINKING_MODE"] = "1" if settings.thinking_mode else "0"
+
+         # Update the model's kwargs directly
+         if hasattr(model, 'model'):
+             # For wrapped models (e.g., RetryingChatBedrock)
+             if hasattr(model.model, 'model_kwargs'):
+                 model.model.model_kwargs.update({
+                     'temperature': settings.temperature,
+                     'top_k': settings.top_k,
+                     'max_tokens': settings.max_output_tokens
+                 })
+         elif hasattr(model, 'model_kwargs'):
+             # For direct model instances
+             model.model_kwargs.update({
+                 'temperature': settings.temperature,
+                 'top_k': settings.top_k,
+                 'max_tokens': settings.max_output_tokens
+             })
+
+         # Force model reinitialization to apply new settings
+         from app.agents.models import ModelManager
+         model = ModelManager.initialize_model(force_reinit=True)
+         model.model_id = os.environ.get("ZIYA_MODEL", model.model_id)
+
+         # Get the model's current settings for verification
+         model_kwargs = {}
+         if hasattr(model, 'model') and hasattr(model.model, 'model_kwargs'):
+             model_kwargs = model.model.model_kwargs
+         elif hasattr(model, 'model_kwargs'):
+             model_kwargs = model.model_kwargs
+
+         logger.info("Current model settings after update:")
+         logger.info(f" Model kwargs temperature: {model_kwargs.get('temperature', 'Not set')}")
+         logger.info(f" Model kwargs top_k: {model_kwargs.get('top_k', 'Not set')}")
+         logger.info(f" Model kwargs max_tokens: {model_kwargs.get('max_tokens', 'Not set')}")
+         logger.info(f" Environment ZIYA_THINKING_MODE: {os.environ.get('ZIYA_THINKING_MODE')}")
+
+         return {
+             'status': 'success',
+             'message': 'Model settings updated',
+             'settings': model_kwargs
+         }
+     except Exception as e:
+         logger.error(f"Error updating model settings: {str(e)}", exc_info=True)
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error updating model settings: {str(e)}"
+         )
+
  @app.post('/api/apply-changes')
  async def apply_changes(request: ApplyChangesRequest):
      try:
@@ -405,4 +794,4 @@ async def apply_changes(request: ApplyChangesRequest):
          )

  if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=DEFAULT_PORT)