agent-starter-pack 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agent-starter-pack might be problematic. Click here for more details.
- {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.0.dist-info}/METADATA +8 -4
- {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.0.dist-info}/RECORD +61 -46
- agents/adk_base/README.md +14 -0
- agents/adk_base/app/agent.py +66 -0
- agents/adk_base/notebooks/adk_app_testing.ipynb +305 -0
- agents/adk_base/template/.templateconfig.yaml +21 -0
- agents/adk_base/tests/integration/test_agent.py +58 -0
- agents/agentic_rag/README.md +1 -0
- agents/agentic_rag/app/agent.py +44 -89
- agents/agentic_rag/app/templates.py +0 -25
- agents/agentic_rag/notebooks/adk_app_testing.ipynb +305 -0
- agents/agentic_rag/template/.templateconfig.yaml +3 -1
- agents/agentic_rag/tests/integration/test_agent.py +34 -27
- agents/langgraph_base_react/README.md +1 -1
- agents/langgraph_base_react/template/.templateconfig.yaml +1 -1
- src/base_template/Makefile +9 -0
- src/base_template/README.md +1 -1
- src/base_template/app/__init__.py +3 -0
- src/base_template/app/utils/tracing.py +11 -1
- src/base_template/app/utils/typing.py +54 -4
- src/base_template/deployment/terraform/dev/variables.tf +4 -0
- src/base_template/deployment/terraform/dev/vars/env.tfvars +0 -3
- src/base_template/deployment/terraform/variables.tf +4 -0
- src/base_template/deployment/terraform/vars/env.tfvars +0 -4
- src/base_template/pyproject.toml +5 -3
- src/{deployment_targets/agent_engine → base_template}/tests/unit/test_dummy.py +2 -1
- src/cli/commands/create.py +10 -2
- src/cli/commands/setup_cicd.py +3 -0
- src/cli/utils/gcp.py +1 -1
- src/cli/utils/template.py +27 -25
- src/data_ingestion/data_ingestion_pipeline/components/ingest_data.py +2 -1
- src/deployment_targets/agent_engine/app/agent_engine_app.py +62 -11
- src/deployment_targets/agent_engine/app/utils/gcs.py +1 -1
- src/deployment_targets/agent_engine/tests/integration/test_agent_engine_app.py +63 -0
- src/deployment_targets/agent_engine/tests/load_test/load_test.py +9 -2
- src/deployment_targets/cloud_run/app/server.py +41 -15
- src/deployment_targets/cloud_run/tests/integration/test_server_e2e.py +60 -3
- src/deployment_targets/cloud_run/tests/load_test/README.md +1 -1
- src/deployment_targets/cloud_run/tests/load_test/load_test.py +57 -24
- src/frontends/live_api_react/frontend/package-lock.json +3 -3
- src/frontends/streamlit_adk/frontend/side_bar.py +214 -0
- src/frontends/streamlit_adk/frontend/streamlit_app.py +314 -0
- src/frontends/streamlit_adk/frontend/style/app_markdown.py +37 -0
- src/frontends/streamlit_adk/frontend/utils/chat_utils.py +84 -0
- src/frontends/streamlit_adk/frontend/utils/local_chat_history.py +110 -0
- src/frontends/streamlit_adk/frontend/utils/message_editing.py +61 -0
- src/frontends/streamlit_adk/frontend/utils/multimodal_utils.py +223 -0
- src/frontends/streamlit_adk/frontend/utils/stream_handler.py +311 -0
- src/frontends/streamlit_adk/frontend/utils/title_summary.py +129 -0
- src/resources/locks/uv-adk_base-agent_engine.lock +5335 -0
- src/resources/locks/uv-adk_base-cloud_run.lock +5927 -0
- src/resources/locks/uv-agentic_rag-agent_engine.lock +882 -676
- src/resources/locks/uv-agentic_rag-cloud_run.lock +1014 -835
- src/resources/locks/uv-crewai_coding_crew-agent_engine.lock +712 -606
- src/resources/locks/uv-crewai_coding_crew-cloud_run.lock +770 -672
- src/resources/locks/uv-langgraph_base_react-agent_engine.lock +602 -529
- src/resources/locks/uv-langgraph_base_react-cloud_run.lock +763 -665
- src/resources/locks/uv-live_api-cloud_run.lock +760 -662
- agents/agentic_rag/notebooks/evaluating_langgraph_agent.ipynb +0 -1561
- src/base_template/tests/unit/test_utils/test_tracing_exporter.py +0 -140
- src/deployment_targets/cloud_run/tests/unit/test_server.py +0 -124
- {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.0.dist-info}/WHEEL +0 -0
- {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.0.dist-info}/entry_points.txt +0 -0
- {agent_starter_pack-0.2.3.dist-info → agent_starter_pack-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -54,7 +54,14 @@ class ChatStreamUser(HttpUser):
|
|
|
54
54
|
"""Simulates a chat stream interaction."""
|
|
55
55
|
headers = {"Content-Type": "application/json"}
|
|
56
56
|
headers["Authorization"] = f"Bearer {os.environ['_AUTH_TOKEN']}"
|
|
57
|
-
|
|
57
|
+
{% if "adk" in cookiecutter.tags %}
|
|
58
|
+
data = {
|
|
59
|
+
"input": {
|
|
60
|
+
"message": "What's the weather in San Francisco?",
|
|
61
|
+
"user_id": "test",
|
|
62
|
+
}
|
|
63
|
+
}
|
|
64
|
+
{% else %}
|
|
58
65
|
data = {
|
|
59
66
|
"input": {
|
|
60
67
|
"input": {
|
|
@@ -69,7 +76,7 @@ class ChatStreamUser(HttpUser):
|
|
|
69
76
|
},
|
|
70
77
|
}
|
|
71
78
|
}
|
|
72
|
-
|
|
79
|
+
{% endif %}
|
|
73
80
|
start_time = time.time()
|
|
74
81
|
with self.client.post(
|
|
75
82
|
url_path,
|
|
@@ -11,7 +11,32 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
{% if "adk" in cookiecutter.tags %}
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
from fastapi import FastAPI
|
|
18
|
+
from google.adk.cli.fast_api import get_fast_api_app
|
|
19
|
+
from google.cloud import logging as google_cloud_logging
|
|
20
|
+
from opentelemetry import trace
|
|
21
|
+
from opentelemetry.sdk.trace import TracerProvider, export
|
|
22
|
+
|
|
23
|
+
from app.utils.tracing import CloudTraceLoggingSpanExporter
|
|
24
|
+
from app.utils.typing import Feedback
|
|
25
|
+
|
|
26
|
+
logging_client = google_cloud_logging.Client()
|
|
27
|
+
logger = logging_client.logger(__name__)
|
|
28
|
+
|
|
29
|
+
provider = TracerProvider()
|
|
30
|
+
processor = export.BatchSpanProcessor(CloudTraceLoggingSpanExporter())
|
|
31
|
+
provider.add_span_processor(processor)
|
|
32
|
+
trace.set_tracer_provider(provider)
|
|
14
33
|
|
|
34
|
+
AGENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
35
|
+
app: FastAPI = get_fast_api_app(agent_dir=AGENT_DIR, web=False)
|
|
36
|
+
|
|
37
|
+
app.title = "{{cookiecutter.project_name}}"
|
|
38
|
+
app.description = "API for interacting with the Agent {{cookiecutter.project_name}}"
|
|
39
|
+
{%- else %}
|
|
15
40
|
import logging
|
|
16
41
|
import os
|
|
17
42
|
from collections.abc import Generator
|
|
@@ -40,7 +65,7 @@ try:
|
|
|
40
65
|
app_name=app.title,
|
|
41
66
|
disable_batch=False,
|
|
42
67
|
exporter=CloudTraceLoggingSpanExporter(),
|
|
43
|
-
instruments={
|
|
68
|
+
instruments={Instruments.LANGCHAIN, Instruments.CREW},
|
|
44
69
|
)
|
|
45
70
|
except Exception as e:
|
|
46
71
|
logging.error("Failed to initialize Telemetry: %s", str(e))
|
|
@@ -91,20 +116,6 @@ def redirect_root_to_docs() -> RedirectResponse:
|
|
|
91
116
|
return RedirectResponse(url="/docs")
|
|
92
117
|
|
|
93
118
|
|
|
94
|
-
@app.post("/feedback")
|
|
95
|
-
def collect_feedback(feedback: Feedback) -> dict[str, str]:
|
|
96
|
-
"""Collect and log feedback.
|
|
97
|
-
|
|
98
|
-
Args:
|
|
99
|
-
feedback: The feedback data to log
|
|
100
|
-
|
|
101
|
-
Returns:
|
|
102
|
-
Success message
|
|
103
|
-
"""
|
|
104
|
-
logger.log_struct(feedback.model_dump(), severity="INFO")
|
|
105
|
-
return {"status": "success"}
|
|
106
|
-
|
|
107
|
-
|
|
108
119
|
@app.post("/stream_messages")
|
|
109
120
|
def stream_chat_events(request: Request) -> StreamingResponse:
|
|
110
121
|
"""Stream chat events in response to an input request.
|
|
@@ -119,6 +130,21 @@ def stream_chat_events(request: Request) -> StreamingResponse:
|
|
|
119
130
|
stream_messages(input=request.input, config=request.config),
|
|
120
131
|
media_type="text/event-stream",
|
|
121
132
|
)
|
|
133
|
+
{%- endif %}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@app.post("/feedback")
|
|
137
|
+
def collect_feedback(feedback: Feedback) -> dict[str, str]:
|
|
138
|
+
"""Collect and log feedback.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
feedback: The feedback data to log
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
Success message
|
|
145
|
+
"""
|
|
146
|
+
logger.log_struct(feedback.model_dump(), severity="INFO")
|
|
147
|
+
return {"status": "success"}
|
|
122
148
|
|
|
123
149
|
|
|
124
150
|
# Main execution
|
|
@@ -32,7 +32,11 @@ logging.basicConfig(level=logging.INFO)
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
34
|
BASE_URL = "http://127.0.0.1:8000/"
|
|
35
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
36
|
+
STREAM_URL = BASE_URL + "run_sse"
|
|
37
|
+
{%- else %}
|
|
35
38
|
STREAM_URL = BASE_URL + "stream_messages"
|
|
39
|
+
{%- endif %}
|
|
36
40
|
FEEDBACK_URL = BASE_URL + "feedback"
|
|
37
41
|
|
|
38
42
|
HEADERS = {"Content-Type": "application/json"}
|
|
@@ -116,23 +120,72 @@ def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]:
|
|
|
116
120
|
def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
|
|
117
121
|
"""Test the chat stream functionality."""
|
|
118
122
|
logger.info("Starting chat stream test")
|
|
123
|
+
{% if "adk" in cookiecutter.tags %}
|
|
124
|
+
# Create session first
|
|
125
|
+
user_id = "user_123"
|
|
126
|
+
session_id = "session_abc"
|
|
127
|
+
session_data = {"state": {"preferred_language": "English", "visit_count": 5}}
|
|
128
|
+
session_response = requests.post(
|
|
129
|
+
f"{BASE_URL}/apps/app/users/{user_id}/sessions/{session_id}",
|
|
130
|
+
headers=HEADERS,
|
|
131
|
+
json=session_data,
|
|
132
|
+
timeout=10,
|
|
133
|
+
)
|
|
134
|
+
assert session_response.status_code == 200
|
|
119
135
|
|
|
136
|
+
# Then send chat message
|
|
137
|
+
data = {
|
|
138
|
+
"app_name": "app",
|
|
139
|
+
"user_id": user_id,
|
|
140
|
+
"session_id": session_id,
|
|
141
|
+
"new_message": {
|
|
142
|
+
"role": "user",
|
|
143
|
+
"parts": [{"text": "What's the weather in San Francisco?"}],
|
|
144
|
+
},
|
|
145
|
+
"streaming": True,
|
|
146
|
+
}
|
|
147
|
+
{% else %}
|
|
120
148
|
data = {
|
|
121
149
|
"input": {
|
|
122
150
|
"messages": [
|
|
123
151
|
{"type": "human", "content": "Hello, AI!"},
|
|
124
152
|
{"type": "ai", "content": "Hello!"},
|
|
125
|
-
{"type": "human", "content": "
|
|
153
|
+
{"type": "human", "content": "Who are you?"},
|
|
126
154
|
]
|
|
127
155
|
},
|
|
128
156
|
"config": {"metadata": {"user_id": "test-user", "session_id": "test-session"}},
|
|
129
157
|
}
|
|
130
|
-
|
|
158
|
+
{% endif %}
|
|
131
159
|
response = requests.post(
|
|
132
160
|
STREAM_URL, headers=HEADERS, json=data, stream=True, timeout=10
|
|
133
161
|
)
|
|
134
162
|
assert response.status_code == 200
|
|
135
163
|
|
|
164
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
165
|
+
# Parse SSE events from response
|
|
166
|
+
events = []
|
|
167
|
+
for line in response.iter_lines():
|
|
168
|
+
if line:
|
|
169
|
+
# SSE format is "data: {json}"
|
|
170
|
+
line_str = line.decode("utf-8")
|
|
171
|
+
if line_str.startswith("data: "):
|
|
172
|
+
event_json = line_str[6:] # Remove "data: " prefix
|
|
173
|
+
event = json.loads(event_json)
|
|
174
|
+
events.append(event)
|
|
175
|
+
|
|
176
|
+
assert events, "No events received from stream"
|
|
177
|
+
# Check for valid content in the response
|
|
178
|
+
has_text_content = False
|
|
179
|
+
for event in events:
|
|
180
|
+
content = event.get("content")
|
|
181
|
+
if (
|
|
182
|
+
content is not None
|
|
183
|
+
and content.get("parts")
|
|
184
|
+
and any(part.get("text") for part in content["parts"])
|
|
185
|
+
):
|
|
186
|
+
has_text_content = True
|
|
187
|
+
break
|
|
188
|
+
{%- else %}
|
|
136
189
|
events = [json.loads(line) for line in response.iter_lines() if line]
|
|
137
190
|
assert events, "No events received from stream"
|
|
138
191
|
|
|
@@ -155,12 +208,12 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None:
|
|
|
155
208
|
has_content = True
|
|
156
209
|
break
|
|
157
210
|
assert has_content, "At least one message should have content"
|
|
211
|
+
{%- endif %}
|
|
158
212
|
|
|
159
213
|
|
|
160
214
|
def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None:
|
|
161
215
|
"""Test the chat stream error handling."""
|
|
162
216
|
logger.info("Starting chat stream error handling test")
|
|
163
|
-
|
|
164
217
|
data = {
|
|
165
218
|
"input": {"messages": [{"type": "invalid_type", "content": "Cause an error"}]}
|
|
166
219
|
}
|
|
@@ -182,7 +235,11 @@ def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None:
|
|
|
182
235
|
# Create sample feedback data
|
|
183
236
|
feedback_data = {
|
|
184
237
|
"score": 4,
|
|
238
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
239
|
+
"invocation_id": str(uuid.uuid4()),
|
|
240
|
+
{%- else %}
|
|
185
241
|
"run_id": str(uuid.uuid4()),
|
|
242
|
+
{%- endif %}
|
|
186
243
|
"text": "Great response!",
|
|
187
244
|
}
|
|
188
245
|
|
|
@@ -28,7 +28,7 @@ Trigger the Locust load test with the following command:
|
|
|
28
28
|
locust -f tests/load_test/load_test.py \
|
|
29
29
|
-H http://127.0.0.1:8000 \
|
|
30
30
|
--headless \
|
|
31
|
-
-t 30s -u
|
|
31
|
+
-t 30s -u 10 -r 2 \
|
|
32
32
|
--csv=tests/load_test/.results/results \
|
|
33
33
|
--html=tests/load_test/.results/report.html
|
|
34
34
|
```
|
|
@@ -15,9 +15,20 @@
|
|
|
15
15
|
import json
|
|
16
16
|
import os
|
|
17
17
|
import time
|
|
18
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
19
|
+
import uuid
|
|
18
20
|
|
|
21
|
+
import requests
|
|
19
22
|
from locust import HttpUser, between, task
|
|
23
|
+
{%- else %}
|
|
20
24
|
|
|
25
|
+
from locust import HttpUser, between, task
|
|
26
|
+
{%- endif %}
|
|
27
|
+
{% if "adk" in cookiecutter.tags %}
|
|
28
|
+
ENDPOINT = "/run_sse"
|
|
29
|
+
{% else %}
|
|
30
|
+
ENDPOINT = "/stream_messages"
|
|
31
|
+
{% endif %}
|
|
21
32
|
|
|
22
33
|
class ChatStreamUser(HttpUser):
|
|
23
34
|
"""Simulates a user interacting with the chat stream API."""
|
|
@@ -30,7 +41,30 @@ class ChatStreamUser(HttpUser):
|
|
|
30
41
|
headers = {"Content-Type": "application/json"}
|
|
31
42
|
if os.environ.get("_ID_TOKEN"):
|
|
32
43
|
headers["Authorization"] = f"Bearer {os.environ['_ID_TOKEN']}"
|
|
44
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
45
|
+
# Create session first
|
|
46
|
+
user_id = f"user_{uuid.uuid4()}"
|
|
47
|
+
session_id = f"session_{uuid.uuid4()}"
|
|
48
|
+
session_data = {"state": {"preferred_language": "English", "visit_count": 5}}
|
|
49
|
+
requests.post(
|
|
50
|
+
f"{self.client.base_url}/apps/app/users/{user_id}/sessions/{session_id}",
|
|
51
|
+
headers=headers,
|
|
52
|
+
json=session_data,
|
|
53
|
+
timeout=10,
|
|
54
|
+
)
|
|
33
55
|
|
|
56
|
+
# Send chat message
|
|
57
|
+
data = {
|
|
58
|
+
"app_name": "app",
|
|
59
|
+
"user_id": user_id,
|
|
60
|
+
"session_id": session_id,
|
|
61
|
+
"new_message": {
|
|
62
|
+
"role": "user",
|
|
63
|
+
"parts": [{"text": "What's the weather in San Francisco?"}],
|
|
64
|
+
},
|
|
65
|
+
"streaming": True,
|
|
66
|
+
}
|
|
67
|
+
{%- else %}
|
|
34
68
|
data = {
|
|
35
69
|
"input": {
|
|
36
70
|
"messages": [
|
|
@@ -43,43 +77,42 @@ class ChatStreamUser(HttpUser):
|
|
|
43
77
|
"metadata": {"user_id": "test-user", "session_id": "test-session"}
|
|
44
78
|
},
|
|
45
79
|
}
|
|
46
|
-
|
|
80
|
+
{%- endif %}
|
|
47
81
|
start_time = time.time()
|
|
48
82
|
|
|
49
83
|
with self.client.post(
|
|
50
|
-
|
|
84
|
+
ENDPOINT,
|
|
85
|
+
name=f"{ENDPOINT} message",
|
|
51
86
|
headers=headers,
|
|
52
87
|
json=data,
|
|
53
88
|
catch_response=True,
|
|
54
|
-
name="/stream_messages first message",
|
|
55
89
|
stream=True,
|
|
90
|
+
params={"alt": "sse"},
|
|
56
91
|
) as response:
|
|
57
92
|
if response.status_code == 200:
|
|
58
93
|
events = []
|
|
59
94
|
for line in response.iter_lines():
|
|
60
95
|
if line:
|
|
96
|
+
{%- if "adk" in cookiecutter.tags %}
|
|
97
|
+
# SSE format is "data: {json}"
|
|
98
|
+
line_str = line.decode("utf-8")
|
|
99
|
+
if line_str.startswith("data: "):
|
|
100
|
+
event_json = line_str[6:] # Remove "data: " prefix
|
|
101
|
+
event = json.loads(event_json)
|
|
102
|
+
events.append(event)
|
|
103
|
+
{%- else %}
|
|
61
104
|
event = json.loads(line)
|
|
62
105
|
events.append(event)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
request_type="POST",
|
|
75
|
-
name="/stream_messages end",
|
|
76
|
-
response_time=total_time
|
|
77
|
-
* 1000, # Convert to milliseconds
|
|
78
|
-
response_length=len(json.dumps(events)),
|
|
79
|
-
response=response,
|
|
80
|
-
context={},
|
|
81
|
-
)
|
|
82
|
-
return
|
|
83
|
-
response.failure("No valid response content received")
|
|
106
|
+
{%- endif %}
|
|
107
|
+
end_time = time.time()
|
|
108
|
+
total_time = end_time - start_time
|
|
109
|
+
self.environment.events.request.fire(
|
|
110
|
+
request_type="POST",
|
|
111
|
+
name=f"{ENDPOINT} end",
|
|
112
|
+
response_time=total_time * 1000, # Convert to milliseconds
|
|
113
|
+
response_length=len(json.dumps(events)),
|
|
114
|
+
response=response,
|
|
115
|
+
context={},
|
|
116
|
+
)
|
|
84
117
|
else:
|
|
85
118
|
response.failure(f"Unexpected status code: {response.status_code}")
|
|
@@ -2027,9 +2027,9 @@
|
|
|
2027
2027
|
}
|
|
2028
2028
|
},
|
|
2029
2029
|
"node_modules/@babel/runtime": {
|
|
2030
|
-
"version": "7.
|
|
2031
|
-
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.
|
|
2032
|
-
"integrity": "sha512-
|
|
2030
|
+
"version": "7.27.0",
|
|
2031
|
+
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.0.tgz",
|
|
2032
|
+
"integrity": "sha512-VtPOkrdPHZsKc/clNqyi9WUA8TINkZ4cGk63UUE3u4pmB2k+ZMQRDuIOagv8UVd6j7k0T3+RRIb7beKTebNbcw==",
|
|
2033
2033
|
"license": "MIT",
|
|
2034
2034
|
"dependencies": {
|
|
2035
2035
|
"regenerator-runtime": "^0.14.0"
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# ruff: noqa: RUF015
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
import uuid
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from frontend.utils.chat_utils import save_chat
|
|
22
|
+
from frontend.utils.multimodal_utils import (
|
|
23
|
+
HELP_GCS_CHECKBOX,
|
|
24
|
+
HELP_MESSAGE_MULTIMODALITY,
|
|
25
|
+
upload_files_to_gcs,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Title given to a chat entry that has no messages yet.
EMPTY_CHAT_NAME = "Empty chat"
# Number of chats listed under "Recent"; the rest go under "Other chats".
NUM_CHAT_IN_RECENT = 3
# Pre-filled value of the "Service URL" field when SERVICE_URL is not set.
DEFAULT_BASE_URL = "http://localhost:8000/"

# Pre-filled value of the "Remote Agent Engine ID" field. It is replaced by the
# id stored in deployment_metadata.json when that file exists in the current
# working directory.
DEFAULT_REMOTE_AGENT_ENGINE_ID = "N/A"
if os.path.exists("deployment_metadata.json"):
    # Decode explicitly as UTF-8 rather than relying on the locale's default
    # encoding, which differs between platforms.
    with open("deployment_metadata.json", encoding="utf-8") as f:
        DEFAULT_REMOTE_AGENT_ENGINE_ID = json.load(f)["remote_agent_engine_id"]
# Pre-filled value of the "Agent Callable Path" field when AGENT_CALLABLE_PATH
# is not set.
DEFAULT_AGENT_CALLABLE_PATH = "app.agent_engine_app.AgentEngineApp"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SideBar:
    """Manages the sidebar components of the Streamlit application."""

    # Agent backends offered by the "Select Agent Type" drop-down, in display
    # order. Kept in one place so the options and the default index agree.
    _AGENT_TYPES = ("Local Agent", "Remote Agent Engine ID", "Remote URL")

    # Tooltip of the "Authenticate request" checkbox. The fragments are joined
    # by implicit literal concatenation, so each one carries its own trailing
    # space where a word break is needed.
    _AUTH_HELP = (
        "If checked, any request to the server will contain an "
        "Identity token to allow authentication. "
        "See the Cloud Run documentation to know more about authentication: "
        "https://cloud.google.com/run/docs/authenticating/service-to-service"
    )

    def __init__(self, st: Any) -> None:
        """
        Initialize the SideBar.

        Args:
            st (Any): The Streamlit object for rendering UI components.
        """
        self.st = st

    def init_side_bar(self) -> None:
        """Initialize and render the sidebar components.

        Widgets are created top to bottom in this order: agent selection,
        chat action buttons, chat history, local upload and GCS upload.
        """
        with self.st.sidebar:
            self._render_agent_selector()
            self._render_chat_controls()
            self._render_chat_history()
            self._render_upload_controls()

    def _render_agent_selector(self) -> None:
        """Render the agent type drop-down and the input for the chosen type.

        Sets ``agent_callable_path``, ``remote_agent_engine_id``,
        ``url_input_field`` and ``should_authenticate_request``; the
        attributes that do not apply to the chosen type are reset to
        ``None`` / ``False``.
        """
        # A Dockerfile in the working directory preselects "Remote URL".
        default_agent_type = (
            "Remote URL" if os.path.exists("Dockerfile") else "Local Agent"
        )
        use_agent_path = self.st.selectbox(
            "Select Agent Type",
            list(self._AGENT_TYPES),
            index=self._AGENT_TYPES.index(default_agent_type),
            help="'Local Agent' uses a local implementation, 'Remote Agent Engine ID' connects to a deployed Vertex AI agent, and 'Remote URL' connects to a custom endpoint.",
        )

        # Start from "nothing selected"; the branch below fills in only the
        # settings that belong to the chosen agent type.
        self.agent_callable_path = None
        self.remote_agent_engine_id = None
        self.url_input_field = None
        self.should_authenticate_request = False

        if use_agent_path == "Local Agent":
            self.agent_callable_path = self.st.text_input(
                label="Agent Callable Path",
                value=os.environ.get(
                    "AGENT_CALLABLE_PATH", DEFAULT_AGENT_CALLABLE_PATH
                ),
            )
        elif use_agent_path == "Remote Agent Engine ID":
            self.remote_agent_engine_id = self.st.text_input(
                label="Remote Agent Engine ID",
                value=os.environ.get(
                    "REMOTE_AGENT_ENGINE_ID", DEFAULT_REMOTE_AGENT_ENGINE_ID
                ),
            )
        else:
            self.url_input_field = self.st.text_input(
                label="Service URL",
                value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL),
            )
            self.should_authenticate_request = self.st.checkbox(
                label="Authenticate request",
                value=False,
                help=self._AUTH_HELP,
            )

    def _open_session(self, session_id: str) -> None:
        """Make ``session_id`` the current session and fetch it from session_db."""
        state = self.st.session_state
        state["session_id"] = session_id
        state.session_db.get_session(session_id=state["session_id"])

    def _add_empty_chat(self) -> None:
        """Register an empty chat under the current session id."""
        state = self.st.session_state
        state.user_chats[state["session_id"]] = {
            "title": EMPTY_CHAT_NAME,
            "messages": [],
        }

    def _render_chat_controls(self) -> None:
        """Render the "+ New chat", "Delete chat" and "Save chat" buttons."""
        state = self.st.session_state
        col1, col2, col3 = self.st.columns(3)

        with col1:
            if self.st.button("+ New chat"):
                # A new chat is only started when the current one has messages.
                if state.user_chats[state["session_id"]]["messages"]:
                    state.invocation_id = None
                    self._open_session(str(uuid.uuid4()))
                    self._add_empty_chat()

        with col2:
            if self.st.button("Delete chat"):
                state.invocation_id = None
                state.session_db.clear()
                state.user_chats.pop(state["session_id"])
                if state.user_chats:
                    # Fall back to the first remaining chat.
                    self._open_session(next(iter(state.user_chats.keys())))
                else:
                    # Nothing left: start over with a fresh empty chat.
                    state["session_id"] = str(uuid.uuid4())
                    self._add_empty_chat()

        with col3:
            if self.st.button("Save chat"):
                save_chat(self.st)

    def _render_chat_buttons(self, chats: list) -> None:
        """Render one button per ``(chat_id, chat)`` pair; a click opens that chat."""
        for chat_id, chat in chats:
            if self.st.button(chat["title"], key=chat_id):
                self.st.session_state.invocation_id = None
                self._open_session(chat_id)

    def _render_chat_history(self) -> None:
        """Render the "Recent" list and the "Other chats" expander."""
        self.st.subheader("Recent")  # Style the heading

        # Reverse the stored order so the last-stored chats are listed first.
        all_chats = list(reversed(self.st.session_state.user_chats.items()))
        self._render_chat_buttons(all_chats[:NUM_CHAT_IN_RECENT])

        with self.st.expander("Other chats"):
            self._render_chat_buttons(all_chats[NUM_CHAT_IN_RECENT:])

    def _render_upload_controls(self) -> None:
        """Render the local file upload and GCS URI inputs.

        Sets ``uploaded_files`` and ``gcs_uris``. Selected files are passed to
        ``upload_files_to_gcs`` when the "Upload to GCS first" box is checked.
        """
        state = self.st.session_state

        self.st.divider()
        self.st.header("Upload files from local")
        bucket_name = self.st.text_input(
            label="GCS Bucket for upload",
            value=os.environ.get("BUCKET_NAME", "gs://your-bucket-name"),
        )
        if "checkbox_state" not in state:
            state.checkbox_state = True

        # NOTE: the value stored just above is overwritten here by the
        # checkbox result, and the checkbox itself defaults to unchecked.
        state.checkbox_state = self.st.checkbox(
            "Upload to GCS first (suggested)", value=False, help=HELP_GCS_CHECKBOX
        )

        self.uploaded_files = self.st.file_uploader(
            label="Send files from local",
            accept_multiple_files=True,
            key=f"uploader_images_{state.uploader_key}",
            type=[
                "png",
                "jpg",
                "jpeg",
                "txt",
                "docx",
                "pdf",
                "rtf",
                "csv",
                "tsv",
                "xlsx",
            ],
        )
        if self.uploaded_files and state.checkbox_state:
            upload_files_to_gcs(self.st, bucket_name, self.uploaded_files)

        self.st.divider()

        self.st.header("Upload files from GCS")
        self.gcs_uris = self.st.text_area(
            "GCS uris (comma-separated)",
            value=state["gcs_uris_to_be_sent"],
            key=f"upload_text_area_{state.uploader_key}",
            help=HELP_MESSAGE_MULTIMODALITY,
        )

        self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}")
|