agent-starter-pack 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-starter-pack might be problematic; see the package registry page for more details.

Files changed (64)
  1. {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.1.dist-info}/METADATA +8 -4
  2. {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.1.dist-info}/RECORD +61 -46
  3. agents/adk_base/README.md +14 -0
  4. agents/adk_base/app/agent.py +66 -0
  5. agents/adk_base/notebooks/adk_app_testing.ipynb +305 -0
  6. agents/adk_base/template/.templateconfig.yaml +21 -0
  7. agents/adk_base/tests/integration/test_agent.py +58 -0
  8. agents/agentic_rag/README.md +1 -0
  9. agents/agentic_rag/app/agent.py +44 -89
  10. agents/agentic_rag/app/templates.py +0 -25
  11. agents/agentic_rag/notebooks/adk_app_testing.ipynb +305 -0
  12. agents/agentic_rag/template/.templateconfig.yaml +3 -1
  13. agents/agentic_rag/tests/integration/test_agent.py +34 -27
  14. agents/langgraph_base_react/README.md +1 -1
  15. agents/langgraph_base_react/template/.templateconfig.yaml +1 -1
  16. src/base_template/Makefile +9 -0
  17. src/base_template/README.md +1 -1
  18. src/base_template/app/__init__.py +3 -0
  19. src/base_template/app/utils/tracing.py +12 -2
  20. src/base_template/app/utils/typing.py +54 -4
  21. src/base_template/deployment/terraform/dev/variables.tf +4 -0
  22. src/base_template/deployment/terraform/dev/vars/env.tfvars +0 -3
  23. src/base_template/deployment/terraform/variables.tf +4 -0
  24. src/base_template/deployment/terraform/vars/env.tfvars +0 -4
  25. src/base_template/pyproject.toml +5 -3
  26. src/{deployment_targets/agent_engine → base_template}/tests/unit/test_dummy.py +2 -1
  27. src/cli/commands/create.py +10 -2
  28. src/cli/commands/setup_cicd.py +3 -0
  29. src/cli/utils/gcp.py +1 -1
  30. src/cli/utils/template.py +32 -25
  31. src/data_ingestion/data_ingestion_pipeline/components/ingest_data.py +2 -1
  32. src/deployment_targets/agent_engine/app/agent_engine_app.py +62 -11
  33. src/deployment_targets/agent_engine/app/utils/gcs.py +1 -1
  34. src/deployment_targets/agent_engine/tests/integration/test_agent_engine_app.py +63 -0
  35. src/deployment_targets/agent_engine/tests/load_test/load_test.py +9 -2
  36. src/deployment_targets/cloud_run/app/server.py +41 -15
  37. src/deployment_targets/cloud_run/tests/integration/test_server_e2e.py +60 -3
  38. src/deployment_targets/cloud_run/tests/load_test/README.md +1 -1
  39. src/deployment_targets/cloud_run/tests/load_test/load_test.py +57 -24
  40. src/frontends/live_api_react/frontend/package-lock.json +3 -3
  41. src/frontends/streamlit_adk/frontend/side_bar.py +214 -0
  42. src/frontends/streamlit_adk/frontend/streamlit_app.py +314 -0
  43. src/frontends/streamlit_adk/frontend/style/app_markdown.py +37 -0
  44. src/frontends/streamlit_adk/frontend/utils/chat_utils.py +84 -0
  45. src/frontends/streamlit_adk/frontend/utils/local_chat_history.py +110 -0
  46. src/frontends/streamlit_adk/frontend/utils/message_editing.py +61 -0
  47. src/frontends/streamlit_adk/frontend/utils/multimodal_utils.py +223 -0
  48. src/frontends/streamlit_adk/frontend/utils/stream_handler.py +311 -0
  49. src/frontends/streamlit_adk/frontend/utils/title_summary.py +129 -0
  50. src/resources/locks/uv-adk_base-agent_engine.lock +5335 -0
  51. src/resources/locks/uv-adk_base-cloud_run.lock +5927 -0
  52. src/resources/locks/uv-agentic_rag-agent_engine.lock +882 -676
  53. src/resources/locks/uv-agentic_rag-cloud_run.lock +1014 -835
  54. src/resources/locks/uv-crewai_coding_crew-agent_engine.lock +712 -606
  55. src/resources/locks/uv-crewai_coding_crew-cloud_run.lock +770 -672
  56. src/resources/locks/uv-langgraph_base_react-agent_engine.lock +602 -529
  57. src/resources/locks/uv-langgraph_base_react-cloud_run.lock +763 -665
  58. src/resources/locks/uv-live_api-cloud_run.lock +760 -662
  59. agents/agentic_rag/notebooks/evaluating_langgraph_agent.ipynb +0 -1561
  60. src/base_template/tests/unit/test_utils/test_tracing_exporter.py +0 -140
  61. src/deployment_targets/cloud_run/tests/unit/test_server.py +0 -124
  62. {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.1.dist-info}/WHEEL +0 -0
  63. {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.1.dist-info}/entry_points.txt +0 -0
  64. {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -54,7 +54,14 @@ class ChatStreamUser(HttpUser):
54
54
  """Simulates a chat stream interaction."""
55
55
  headers = {"Content-Type": "application/json"}
56
56
  headers["Authorization"] = f"Bearer {os.environ['_AUTH_TOKEN']}"
57
-
57
+ {% if "adk" in cookiecutter.tags %}
58
+ data = {
59
+ "input": {
60
+ "message": "What's the weather in San Francisco?",
61
+ "user_id": "test",
62
+ }
63
+ }
64
+ {% else %}
58
65
  data = {
59
66
  "input": {
60
67
  "input": {
@@ -69,7 +76,7 @@ class ChatStreamUser(HttpUser):
69
76
  },
70
77
  }
71
78
  }
72
-
79
+ {% endif %}
73
80
  start_time = time.time()
74
81
  with self.client.post(
75
82
  url_path,
@@ -11,7 +11,32 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ {% if "adk" in cookiecutter.tags %}
15
+ import os
16
+
17
+ from fastapi import FastAPI
18
+ from google.adk.cli.fast_api import get_fast_api_app
19
+ from google.cloud import logging as google_cloud_logging
20
+ from opentelemetry import trace
21
+ from opentelemetry.sdk.trace import TracerProvider, export
22
+
23
+ from app.utils.tracing import CloudTraceLoggingSpanExporter
24
+ from app.utils.typing import Feedback
25
+
26
+ logging_client = google_cloud_logging.Client()
27
+ logger = logging_client.logger(__name__)
28
+
29
+ provider = TracerProvider()
30
+ processor = export.BatchSpanProcessor(CloudTraceLoggingSpanExporter())
31
+ provider.add_span_processor(processor)
32
+ trace.set_tracer_provider(provider)
14
33
 
34
+ AGENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
35
+ app: FastAPI = get_fast_api_app(agent_dir=AGENT_DIR, web=False)
36
+
37
+ app.title = "{{cookiecutter.project_name}}"
38
+ app.description = "API for interacting with the Agent {{cookiecutter.project_name}}"
39
+ {%- else %}
15
40
  import logging
16
41
  import os
17
42
  from collections.abc import Generator
@@ -40,7 +65,7 @@ try:
40
65
  app_name=app.title,
41
66
  disable_batch=False,
42
67
  exporter=CloudTraceLoggingSpanExporter(),
43
- instruments={% raw %}{{% endraw %}{%- for instrumentation in cookiecutter.otel_instrumentations %}{{ instrumentation }}{% if not loop.last %}, {% endif %}{%- endfor %}{% raw %}}{% endraw %},
68
+ instruments={Instruments.LANGCHAIN, Instruments.CREW},
44
69
  )
45
70
  except Exception as e:
46
71
  logging.error("Failed to initialize Telemetry: %s", str(e))
@@ -91,20 +116,6 @@ def redirect_root_to_docs() -> RedirectResponse:
91
116
  return RedirectResponse(url="/docs")
92
117
 
93
118
 
94
- @app.post("/feedback")
95
- def collect_feedback(feedback: Feedback) -> dict[str, str]:
96
- """Collect and log feedback.
97
-
98
- Args:
99
- feedback: The feedback data to log
100
-
101
- Returns:
102
- Success message
103
- """
104
- logger.log_struct(feedback.model_dump(), severity="INFO")
105
- return {"status": "success"}
106
-
107
-
108
119
  @app.post("/stream_messages")
109
120
  def stream_chat_events(request: Request) -> StreamingResponse:
110
121
  """Stream chat events in response to an input request.
@@ -119,6 +130,21 @@ def stream_chat_events(request: Request) -> StreamingResponse:
119
130
  stream_messages(input=request.input, config=request.config),
120
131
  media_type="text/event-stream",
121
132
  )
133
+ {%- endif %}
134
+
135
+
136
@app.post("/feedback")
def collect_feedback(feedback: Feedback) -> dict[str, str]:
    """Record a piece of user feedback in the structured log.

    Args:
        feedback: The validated feedback payload posted to ``/feedback``.

    Returns:
        The acknowledgement ``{"status": "success"}``.
    """
    # Serialise the pydantic model to a plain dict before logging it.
    payload = feedback.model_dump()
    logger.log_struct(payload, severity="INFO")
    return {"status": "success"}
122
148
 
123
149
 
124
150
  # Main execution
@@ -32,7 +32,11 @@ logging.basicConfig(level=logging.INFO)
32
32
  logger = logging.getLogger(__name__)
33
33
 
34
34
  BASE_URL = "http://127.0.0.1:8000/"
35
+ {%- if "adk" in cookiecutter.tags %}
36
+ STREAM_URL = BASE_URL + "run_sse"
37
+ {%- else %}
35
38
  STREAM_URL = BASE_URL + "stream_messages"
39
+ {%- endif %}
36
40
  FEEDBACK_URL = BASE_URL + "feedback"
37
41
 
38
42
  HEADERS = {"Content-Type": "application/json"}
@@ -116,23 +120,72 @@ def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]:
116
120
  def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
117
121
  """Test the chat stream functionality."""
118
122
  logger.info("Starting chat stream test")
123
+ {% if "adk" in cookiecutter.tags %}
124
+ # Create session first
125
+ user_id = "user_123"
126
+ session_id = "session_abc"
127
+ session_data = {"state": {"preferred_language": "English", "visit_count": 5}}
128
+ session_response = requests.post(
129
+ f"{BASE_URL}/apps/app/users/{user_id}/sessions/{session_id}",
130
+ headers=HEADERS,
131
+ json=session_data,
132
+ timeout=10,
133
+ )
134
+ assert session_response.status_code == 200
119
135
 
136
+ # Then send chat message
137
+ data = {
138
+ "app_name": "app",
139
+ "user_id": user_id,
140
+ "session_id": session_id,
141
+ "new_message": {
142
+ "role": "user",
143
+ "parts": [{"text": "What's the weather in San Francisco?"}],
144
+ },
145
+ "streaming": True,
146
+ }
147
+ {% else %}
120
148
  data = {
121
149
  "input": {
122
150
  "messages": [
123
151
  {"type": "human", "content": "Hello, AI!"},
124
152
  {"type": "ai", "content": "Hello!"},
125
- {"type": "human", "content": "What is the weather in NY?"},
153
+ {"type": "human", "content": "Who are you?"},
126
154
  ]
127
155
  },
128
156
  "config": {"metadata": {"user_id": "test-user", "session_id": "test-session"}},
129
157
  }
130
-
158
+ {% endif %}
131
159
  response = requests.post(
132
160
  STREAM_URL, headers=HEADERS, json=data, stream=True, timeout=10
133
161
  )
134
162
  assert response.status_code == 200
135
163
 
164
+ {%- if "adk" in cookiecutter.tags %}
165
+ # Parse SSE events from response
166
+ events = []
167
+ for line in response.iter_lines():
168
+ if line:
169
+ # SSE format is "data: {json}"
170
+ line_str = line.decode("utf-8")
171
+ if line_str.startswith("data: "):
172
+ event_json = line_str[6:] # Remove "data: " prefix
173
+ event = json.loads(event_json)
174
+ events.append(event)
175
+
176
+ assert events, "No events received from stream"
177
+ # Check for valid content in the response
178
+ has_text_content = False
179
+ for event in events:
180
+ content = event.get("content")
181
+ if (
182
+ content is not None
183
+ and content.get("parts")
184
+ and any(part.get("text") for part in content["parts"])
185
+ ):
186
+ has_text_content = True
187
+ break
188
+ {%- else %}
136
189
  events = [json.loads(line) for line in response.iter_lines() if line]
137
190
  assert events, "No events received from stream"
138
191
 
@@ -155,12 +208,12 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
155
208
  has_content = True
156
209
  break
157
210
  assert has_content, "At least one message should have content"
211
+ {%- endif %}
158
212
 
159
213
 
160
214
  def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None:
161
215
  """Test the chat stream error handling."""
162
216
  logger.info("Starting chat stream error handling test")
163
-
164
217
  data = {
165
218
  "input": {"messages": [{"type": "invalid_type", "content": "Cause an error"}]}
166
219
  }
@@ -182,7 +235,11 @@ def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None:
182
235
  # Create sample feedback data
183
236
  feedback_data = {
184
237
  "score": 4,
238
+ {%- if "adk" in cookiecutter.tags %}
239
+ "invocation_id": str(uuid.uuid4()),
240
+ {%- else %}
185
241
  "run_id": str(uuid.uuid4()),
242
+ {%- endif %}
186
243
  "text": "Great response!",
187
244
  }
188
245
 
@@ -28,7 +28,7 @@ Trigger the Locust load test with the following command:
28
28
  locust -f tests/load_test/load_test.py \
29
29
  -H http://127.0.0.1:8000 \
30
30
  --headless \
31
- -t 30s -u 60 -r 2 \
31
+ -t 30s -u 10 -r 2 \
32
32
  --csv=tests/load_test/.results/results \
33
33
  --html=tests/load_test/.results/report.html
34
34
  ```
@@ -15,9 +15,20 @@
15
15
  import json
16
16
  import os
17
17
  import time
18
+ {%- if "adk" in cookiecutter.tags %}
19
+ import uuid
18
20
 
21
+ import requests
19
22
  from locust import HttpUser, between, task
23
+ {%- else %}
20
24
 
25
+ from locust import HttpUser, between, task
26
+ {%- endif %}
27
+ {% if "adk" in cookiecutter.tags %}
28
+ ENDPOINT = "/run_sse"
29
+ {% else %}
30
+ ENDPOINT = "/stream_messages"
31
+ {% endif %}
21
32
 
22
33
  class ChatStreamUser(HttpUser):
23
34
  """Simulates a user interacting with the chat stream API."""
@@ -30,7 +41,30 @@ class ChatStreamUser(HttpUser):
30
41
  headers = {"Content-Type": "application/json"}
31
42
  if os.environ.get("_ID_TOKEN"):
32
43
  headers["Authorization"] = f"Bearer {os.environ['_ID_TOKEN']}"
44
+ {%- if "adk" in cookiecutter.tags %}
45
+ # Create session first
46
+ user_id = f"user_{uuid.uuid4()}"
47
+ session_id = f"session_{uuid.uuid4()}"
48
+ session_data = {"state": {"preferred_language": "English", "visit_count": 5}}
49
+ requests.post(
50
+ f"{self.client.base_url}/apps/app/users/{user_id}/sessions/{session_id}",
51
+ headers=headers,
52
+ json=session_data,
53
+ timeout=10,
54
+ )
33
55
 
56
+ # Send chat message
57
+ data = {
58
+ "app_name": "app",
59
+ "user_id": user_id,
60
+ "session_id": session_id,
61
+ "new_message": {
62
+ "role": "user",
63
+ "parts": [{"text": "What's the weather in San Francisco?"}],
64
+ },
65
+ "streaming": True,
66
+ }
67
+ {%- else %}
34
68
  data = {
35
69
  "input": {
36
70
  "messages": [
@@ -43,43 +77,42 @@ class ChatStreamUser(HttpUser):
43
77
  "metadata": {"user_id": "test-user", "session_id": "test-session"}
44
78
  },
45
79
  }
46
-
80
+ {%- endif %}
47
81
  start_time = time.time()
48
82
 
49
83
  with self.client.post(
50
- "/stream_messages",
84
+ ENDPOINT,
85
+ name=f"{ENDPOINT} message",
51
86
  headers=headers,
52
87
  json=data,
53
88
  catch_response=True,
54
- name="/stream_messages first message",
55
89
  stream=True,
90
+ params={"alt": "sse"},
56
91
  ) as response:
57
92
  if response.status_code == 200:
58
93
  events = []
59
94
  for line in response.iter_lines():
60
95
  if line:
96
+ {%- if "adk" in cookiecutter.tags %}
97
+ # SSE format is "data: {json}"
98
+ line_str = line.decode("utf-8")
99
+ if line_str.startswith("data: "):
100
+ event_json = line_str[6:] # Remove "data: " prefix
101
+ event = json.loads(event_json)
102
+ events.append(event)
103
+ {%- else %}
61
104
  event = json.loads(line)
62
105
  events.append(event)
63
- for chunk in event:
64
- if (
65
- isinstance(chunk, dict)
66
- and chunk.get("type") == "constructor"
67
- ):
68
- if not chunk.get("kwargs", {}).get("content"):
69
- continue
70
- response.success()
71
- end_time = time.time()
72
- total_time = end_time - start_time
73
- self.environment.events.request.fire(
74
- request_type="POST",
75
- name="/stream_messages end",
76
- response_time=total_time
77
- * 1000, # Convert to milliseconds
78
- response_length=len(json.dumps(events)),
79
- response=response,
80
- context={},
81
- )
82
- return
83
- response.failure("No valid response content received")
106
+ {%- endif %}
107
+ end_time = time.time()
108
+ total_time = end_time - start_time
109
+ self.environment.events.request.fire(
110
+ request_type="POST",
111
+ name=f"{ENDPOINT} end",
112
+ response_time=total_time * 1000, # Convert to milliseconds
113
+ response_length=len(json.dumps(events)),
114
+ response=response,
115
+ context={},
116
+ )
84
117
  else:
85
118
  response.failure(f"Unexpected status code: {response.status_code}")
@@ -2027,9 +2027,9 @@
2027
2027
  }
2028
2028
  },
2029
2029
  "node_modules/@babel/runtime": {
2030
- "version": "7.26.0",
2031
- "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.26.0.tgz",
2032
- "integrity": "sha512-FDSOghenHTiToteC/QRlv2q3DhPZ/oOXTBoirfWNx1Cx3TMVcGWQtMMmQcSvb/JjpNeGzx8Pq/b4fKEJuWm1sw==",
2030
+ "version": "7.27.0",
2031
+ "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.0.tgz",
2032
+ "integrity": "sha512-VtPOkrdPHZsKc/clNqyi9WUA8TINkZ4cGk63UUE3u4pmB2k+ZMQRDuIOagv8UVd6j7k0T3+RRIb7beKTebNbcw==",
2033
2033
  "license": "MIT",
2034
2034
  "dependencies": {
2035
2035
  "regenerator-runtime": "^0.14.0"
@@ -0,0 +1,214 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # ruff: noqa: RUF015
16
+ import json
17
+ import os
18
+ import uuid
19
+ from typing import Any
20
+
21
+ from frontend.utils.chat_utils import save_chat
22
+ from frontend.utils.multimodal_utils import (
23
+ HELP_GCS_CHECKBOX,
24
+ HELP_MESSAGE_MULTIMODALITY,
25
+ upload_files_to_gcs,
26
+ )
27
+
28
EMPTY_CHAT_NAME = "Empty chat"
NUM_CHAT_IN_RECENT = 3
DEFAULT_BASE_URL = "http://localhost:8000/"

# Default value for the "Remote Agent Engine ID" sidebar field. It is read
# from deployment_metadata.json in the working directory when that file
# exists and contains the key; otherwise it stays "N/A".
DEFAULT_REMOTE_AGENT_ENGINE_ID = "N/A"
if os.path.exists("deployment_metadata.json"):
    with open("deployment_metadata.json", encoding="utf-8") as _metadata_file:
        # A metadata file without the key should not crash the whole
        # frontend at import time; keep the placeholder instead.
        DEFAULT_REMOTE_AGENT_ENGINE_ID = json.load(_metadata_file).get(
            "remote_agent_engine_id", DEFAULT_REMOTE_AGENT_ENGINE_ID
        )
DEFAULT_AGENT_CALLABLE_PATH = "app.agent_engine_app.AgentEngineApp"
37
+
38
+
39
class SideBar:
    """Manages the sidebar components of the Streamlit application."""

    def __init__(self, st: Any) -> None:
        """
        Initialize the SideBar.

        Args:
            st (Any): The Streamlit object for rendering UI components.
        """
        self.st = st

    def init_side_bar(self) -> None:
        """Initialize and render the sidebar components.

        Renders, inside ``st.sidebar`` and in this order:
        - the agent-type selector with its connection settings;
        - the chat management buttons;
        - the chat history list;
        - the file-upload widgets.

        The widget values are stored on ``self`` for the caller to read:
        ``agent_callable_path``, ``remote_agent_engine_id``,
        ``url_input_field``, ``should_authenticate_request``,
        ``uploaded_files`` and ``gcs_uris``.
        """
        with self.st.sidebar:
            self._render_agent_settings()
            self._render_chat_controls()
            self._render_chat_history()
            self._render_uploads()

    def _select_chat(self, chat_id: str) -> None:
        """Make ``chat_id`` the active chat and load its session."""
        self.st.session_state.invocation_id = None
        self.st.session_state["session_id"] = chat_id
        self.st.session_state.session_db.get_session(
            session_id=self.st.session_state["session_id"],
        )

    def _render_agent_settings(self) -> None:
        """Render the agent-type selector and the inputs for the chosen type."""
        agent_types = ["Local Agent", "Remote Agent Engine ID", "Remote URL"]
        # Default to "Remote URL" when a Dockerfile exists in the working
        # directory, otherwise to "Local Agent".
        default_agent_type = (
            "Remote URL" if os.path.exists("Dockerfile") else "Local Agent"
        )
        use_agent_path = self.st.selectbox(
            "Select Agent Type",
            agent_types,
            index=agent_types.index(default_agent_type),
            help="'Local Agent' uses a local implementation, 'Remote Agent Engine ID' connects to a deployed Vertex AI agent, and 'Remote URL' connects to a custom endpoint.",
        )

        # Only the setting belonging to the selected agent type is populated;
        # the others stay None (and authentication stays off).
        self.agent_callable_path = None
        self.remote_agent_engine_id = None
        self.url_input_field = None
        self.should_authenticate_request = False

        if use_agent_path == "Local Agent":
            self.agent_callable_path = self.st.text_input(
                label="Agent Callable Path",
                value=os.environ.get(
                    "AGENT_CALLABLE_PATH", DEFAULT_AGENT_CALLABLE_PATH
                ),
            )
        elif use_agent_path == "Remote Agent Engine ID":
            self.remote_agent_engine_id = self.st.text_input(
                label="Remote Agent Engine ID",
                value=os.environ.get(
                    "REMOTE_AGENT_ENGINE_ID", DEFAULT_REMOTE_AGENT_ENGINE_ID
                ),
            )
        else:
            self.url_input_field = self.st.text_input(
                label="Service URL",
                value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL),
            )
            # Adjacent string literals are concatenated verbatim, so every
            # fragment must end with its own separating space.
            self.should_authenticate_request = self.st.checkbox(
                label="Authenticate request",
                value=False,
                help="If checked, any request to the server will contain an "
                "Identity token to allow authentication. "
                "See the Cloud Run documentation to know more about authentication: "
                "https://cloud.google.com/run/docs/authenticating/service-to-service",
            )

    def _render_chat_controls(self) -> None:
        """Render the "New chat", "Delete chat" and "Save chat" buttons."""
        state = self.st.session_state
        col1, col2, col3 = self.st.columns(3)
        with col1:
            if self.st.button("+ New chat"):
                current_messages = state.user_chats[state["session_id"]]["messages"]
                # Only open a new chat when the current one is not empty.
                if current_messages:
                    state.invocation_id = None
                    state["session_id"] = str(uuid.uuid4())
                    state.session_db.get_session(
                        session_id=state["session_id"],
                    )
                    state.user_chats[state["session_id"]] = {
                        "title": EMPTY_CHAT_NAME,
                        "messages": [],
                    }

        with col2:
            if self.st.button("Delete chat"):
                state.invocation_id = None
                state.session_db.clear()
                state.user_chats.pop(state["session_id"])
                if state.user_chats:
                    # Switch to the first remaining chat (insertion order).
                    self._select_chat(next(iter(state.user_chats)))
                else:
                    state["session_id"] = str(uuid.uuid4())
                    state.user_chats[state["session_id"]] = {
                        "title": EMPTY_CHAT_NAME,
                        "messages": [],
                    }

        with col3:
            if self.st.button("Save chat"):
                save_chat(self.st)

    def _render_chat_history(self) -> None:
        """Render chat buttons: the latest chats first, the rest in an expander."""
        self.st.subheader("Recent")  # Style the heading

        # Newest chats first (reverse insertion order).
        all_chats = list(reversed(self.st.session_state.user_chats.items()))
        for chat_id, chat in all_chats[:NUM_CHAT_IN_RECENT]:
            if self.st.button(chat["title"], key=chat_id):
                self._select_chat(chat_id)

        with self.st.expander("Other chats"):
            for chat_id, chat in all_chats[NUM_CHAT_IN_RECENT:]:
                if self.st.button(chat["title"], key=chat_id):
                    self._select_chat(chat_id)

    def _render_uploads(self) -> None:
        """Render the local-file uploader and the GCS URI input."""
        self.st.divider()
        self.st.header("Upload files from local")
        bucket_name = self.st.text_input(
            label="GCS Bucket for upload",
            value=os.environ.get("BUCKET_NAME", "gs://your-bucket-name"),
        )
        # The checkbox result is written to session state on every render.
        # NOTE(review): the label calls this option "suggested", yet it is
        # unchecked by default — confirm the intended default.
        self.st.session_state.checkbox_state = self.st.checkbox(
            "Upload to GCS first (suggested)", value=False, help=HELP_GCS_CHECKBOX
        )

        self.uploaded_files = self.st.file_uploader(
            label="Send files from local",
            accept_multiple_files=True,
            # Changing uploader_key gives the widget a new key, which resets it.
            key=f"uploader_images_{self.st.session_state.uploader_key}",
            type=[
                "png",
                "jpg",
                "jpeg",
                "txt",
                "docx",
                "pdf",
                "rtf",
                "csv",
                "tsv",
                "xlsx",
            ],
        )
        if self.uploaded_files and self.st.session_state.checkbox_state:
            upload_files_to_gcs(self.st, bucket_name, self.uploaded_files)

        self.st.divider()

        self.st.header("Upload files from GCS")
        self.gcs_uris = self.st.text_area(
            "GCS uris (comma-separated)",
            value=self.st.session_state["gcs_uris_to_be_sent"],
            key=f"upload_text_area_{self.st.session_state.uploader_key}",
            help=HELP_MESSAGE_MULTIMODALITY,
        )

        self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}")