aixtools-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic.
- aixtools/.chainlit/config.toml +113 -0
- aixtools/.chainlit/translations/bn.json +214 -0
- aixtools/.chainlit/translations/en-US.json +214 -0
- aixtools/.chainlit/translations/gu.json +214 -0
- aixtools/.chainlit/translations/he-IL.json +214 -0
- aixtools/.chainlit/translations/hi.json +214 -0
- aixtools/.chainlit/translations/ja.json +214 -0
- aixtools/.chainlit/translations/kn.json +214 -0
- aixtools/.chainlit/translations/ml.json +214 -0
- aixtools/.chainlit/translations/mr.json +214 -0
- aixtools/.chainlit/translations/nl.json +214 -0
- aixtools/.chainlit/translations/ta.json +214 -0
- aixtools/.chainlit/translations/te.json +214 -0
- aixtools/.chainlit/translations/zh-CN.json +214 -0
- aixtools/__init__.py +11 -0
- aixtools/_version.py +34 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/google_sdk/__init__.py +0 -0
- aixtools/a2a/google_sdk/card.py +27 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
- aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
- aixtools/a2a/google_sdk/utils.py +59 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +71 -0
- aixtools/agents/prompt.py +97 -0
- aixtools/app.py +143 -0
- aixtools/chainlit.md +14 -0
- aixtools/compliance/__init__.py +9 -0
- aixtools/compliance/private_data.py +138 -0
- aixtools/context.py +17 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/google/client.py +25 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +161 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/client.py +375 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +319 -0
- aixtools/model_patch/model_patch.py +63 -0
- aixtools/server/__init__.py +29 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/server/workspace_privacy.py +65 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +149 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +131 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/files.py +17 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +167 -0
- aixtools/vault/__init__.py +7 -0
- aixtools/vault/vault.py +137 -0
- aixtools-0.0.0.dist-info/METADATA +669 -0
- aixtools-0.0.0.dist-info/RECORD +88 -0
- aixtools-0.0.0.dist-info/WHEEL +5 -0
- aixtools-0.0.0.dist-info/entry_points.txt +2 -0
- aixtools-0.0.0.dist-info/top_level.txt +1 -0
aixtools/db/vector_db.py
ADDED
@@ -0,0 +1,115 @@
"""
Vector database implementation for embedding storage and similarity search.
"""

from langchain_chroma import Chroma
from langchain_core.embeddings import Embeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

from aixtools.logging.logging_config import get_logger
from aixtools.utils.config import (
    AZURE_OPENAI_API_KEY,
    AZURE_VDB_EMBEDDINGS_MODEL_NAME,
    OLLAMA_VDB_EMBEDDINGS_MODEL_NAME,
    OPENAI_API_KEY,
    OPENAI_VDB_EMBEDDINGS_MODEL_NAME,
    VDB_CHROMA_PATH,
    VDB_EMBEDDINGS_MODEL_FAMILY,
)

CREATE_DB = False

_vector_dbs = {}

logger = get_logger(__name__)


def get_vdb_embedding(model_family=VDB_EMBEDDINGS_MODEL_FAMILY) -> Embeddings:
    """Get the embedding model for vector storage"""
    match model_family:
        case "openai":
            return OpenAIEmbeddings(model=OPENAI_VDB_EMBEDDINGS_MODEL_NAME, api_key=OPENAI_API_KEY)  # type: ignore
        case "azure":
            return AzureOpenAIEmbeddings(  # type: ignore
                model=AZURE_VDB_EMBEDDINGS_MODEL_NAME, api_key=AZURE_OPENAI_API_KEY
            )
        case "ollama":
            return OllamaEmbeddings(model=OLLAMA_VDB_EMBEDDINGS_MODEL_NAME)  # type: ignore
        case _:
            raise ValueError(f"Model family {model_family} not supported")


def get_vector_db(collection_name: str) -> Chroma:
    """Implement singleton pattern for database connections"""
    # _vector_dbs will not be re-assigned, but it will be modified
    global _vector_dbs  # noqa: PLW0602, pylint: disable=protected-access,global-variable-not-assigned
    if collection_name not in _vector_dbs:
        print(f"Creating new DB connection: {collection_name=}")
        vdb = Chroma(
            persist_directory=str(VDB_CHROMA_PATH),
            collection_name=collection_name,
            embedding_function=get_vdb_embedding(),
        )
        _vector_dbs[collection_name] = vdb
    return _vector_dbs[collection_name]


def vdb_add(vdb: Chroma, text: str, doc_id: str, meta: list[dict] | dict | None = None, force=False) -> str | None:
    """
    Add a document to the database if it's not already there.
    """
    if not force and vdb_has_id(vdb, doc_id):
        return None  # Document already exists, return None
    if isinstance(meta, list):
        metadatas = meta
    elif isinstance(meta, dict):
        metadatas = [meta]
    else:
        metadatas = None
    ids = vdb.add_texts(texts=[text], ids=[doc_id], metadatas=metadatas)  # type: ignore
    if not ids:
        return None
    return ids[0]  # Return the id of the added document


def vdb_get_by_id(vdb: Chroma, doc_id: str):
    """Get a document by id"""
    collection = vdb._collection  # pylint: disable=protected-access
    return collection.get(ids=[doc_id])  # query by id


def vdb_has_id(vdb: Chroma, doc_id: str):
    """Check if a document with a given id exists in the database"""
    result = vdb_get_by_id(vdb, doc_id)
    return len(result["ids"]) > 0


def vdb_query(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    vdb: Chroma,
    query: str,
    filter: dict[str, str] | None = None,  # pylint: disable=redefined-builtin
    where_document: dict[str, str] | None = None,
    max_items=10,
    similarity_threshold=None,
):
    """
    Query vector database with a given query, return top k results.
    Args:
        query: str, query string
        max_items: int, maximum number of items to return
        similarity_threshold: float, similarity threshold to filter the results
    """
    results = vdb.similarity_search_with_relevance_scores(
        query, k=max_items, filter=filter, where_document=where_document
    )
    logger.debug(
        "Got %s results before filter, first one's similarity score is: %s",
        len(results),
        results[0][1] if results else None,
    )
    if similarity_threshold is not None:
        results = [(doc_id, score) for doc_id, score in results if score > similarity_threshold]
        print(f"Got {len(results)} results after filter")
    return results
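
A minimal usage sketch for the helpers above, assuming the embedding settings in aixtools.utils.config (API keys, model family, Chroma path) are populated; the collection name, document text, and threshold are hypothetical examples, and the (document, score) tuples come from LangChain's similarity_search_with_relevance_scores.

from aixtools.db.vector_db import get_vector_db, vdb_add, vdb_has_id, vdb_query

# Hypothetical collection name; repeat calls reuse the cached Chroma connection.
vdb = get_vector_db("example_docs")

# Insert a document once; a second call with the same id is a no-op unless force=True.
vdb_add(vdb, text="Chroma stores embeddings on disk.", doc_id="doc-1", meta={"source": "example"})
assert vdb_has_id(vdb, "doc-1")

# Similarity search, keeping only reasonably close matches.
hits = vdb_query(vdb, "Where are embeddings stored?", max_items=3, similarity_threshold=0.2)
for doc, score in hits:
    print(score, doc.page_content)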
aixtools/google/client.py
ADDED

@@ -0,0 +1,25 @@
import os
from pathlib import Path

from google import genai

from aixtools.logging.logging_config import get_logger
from aixtools.utils.config import GOOGLE_CLOUD_LOCATION, GOOGLE_CLOUD_PROJECT, GOOGLE_GENAI_USE_VERTEXAI

logger = get_logger(__name__)


def get_genai_client(service_account_key_path: Path | None = None) -> genai.Client:
    """Initialize and return a Google GenAI client using Vertex AI / Gemini Developer API."""
    assert GOOGLE_CLOUD_PROJECT, "GOOGLE_CLOUD_PROJECT is not set"
    assert GOOGLE_CLOUD_LOCATION, "GOOGLE_CLOUD_LOCATION is not set"
    if service_account_key_path:
        if not service_account_key_path.exists():
            raise FileNotFoundError(f"Service account key file not found: {service_account_key_path}")
        logger.info(f"✅ GCP Service Account Key File: {service_account_key_path}")
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(service_account_key_path)
    return genai.Client(
        vertexai=GOOGLE_GENAI_USE_VERTEXAI,
        project=GOOGLE_CLOUD_PROJECT,
        location=GOOGLE_CLOUD_LOCATION,
    )
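
A usage sketch, assuming GOOGLE_CLOUD_PROJECT, GOOGLE_CLOUD_LOCATION, and GOOGLE_GENAI_USE_VERTEXAI are set in the config; the key-file path and model id are placeholders, and the generate_content call is the standard google-genai SDK surface rather than anything defined in this package.

from pathlib import Path

from aixtools.google.client import get_genai_client

# Optionally point at a service-account key; omit the argument to rely on ambient credentials.
client = get_genai_client(service_account_key_path=Path("key.json"))  # hypothetical path

# Plain google-genai call; "gemini-2.0-flash" is only an example model id.
response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="Summarize what aixtools provides in one sentence.",
)
print(response.text)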
aixtools/log_view/__init__.py
ADDED

@@ -0,0 +1,17 @@
"""
Streamlit application to visualize agent nodes from log files.

This package provides tools to:
- View the most recent log file by default
- Open and analyze other log files
- Visualize nodes from agent runs with expandable/collapsible sections
- Filter nodes by various criteria
- Export visualizations
"""

from aixtools.log_view.app import main, main_cli

__all__ = [
    "main",
    "main_cli",
]
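
The re-exports above make the viewer importable at package level. A minimal launch sketch; the logs directory is a hypothetical example.

from pathlib import Path

from aixtools.log_view import main, main_cli

# Render the viewer when this code already runs inside a Streamlit script...
main(log_dir=Path("./logs"))

# ...or, from a plain terminal, let main_cli() re-launch the module under `streamlit run`.
# main_cli()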
aixtools/log_view/app.py
ADDED
@@ -0,0 +1,195 @@
"""
Main application module for the Agent Log Viewer.
"""

import argparse
import os
import subprocess
from pathlib import Path

import streamlit as st

from aixtools.log_view.display import display_node
from aixtools.log_view.export import export_nodes_to_json
from aixtools.log_view.filters import filter_nodes
from aixtools.log_view.log_utils import format_timestamp_from_filename, get_log_files
from aixtools.log_view.node_summary import NodeTitle, extract_node_types

# Now we can import our modules
from aixtools.logging.log_objects import load_from_log
from aixtools.utils.config import LOGS_DIR


def main(log_dir: Path | None = None):  # noqa: PLR0915, pylint: disable=too-many-locals,too-many-statements
    """Main function to run the Streamlit app."""
    st.set_page_config(
        page_title="Agent Log Viewer",
        layout="wide",
    )

    st.title("Agent Log Viewer")

    # Use provided log directory or default
    if log_dir is None:
        log_dir = LOGS_DIR

    # Create the logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    st.sidebar.header("Settings")

    # Allow user to select a different log directory
    custom_log_dir = st.sidebar.text_input("Log Directory", value=str(log_dir))
    if custom_log_dir and custom_log_dir != str(log_dir):
        log_dir = Path(custom_log_dir)
        # Create the custom directory if it doesn't exist
        os.makedirs(log_dir, exist_ok=True)

    # Get log files
    log_files = get_log_files(log_dir)

    if not log_files:
        st.warning(f"No log files found in {log_dir}")
        st.info("Run an agent with logging enabled to create log files.")
        return

    # Create a dictionary of log files with formatted timestamps as display names
    log_file_options = {f"{format_timestamp_from_filename(f.name)} - {f.name}": f for f in log_files}

    # Select log file (default to most recent)
    selected_log_file_name = st.sidebar.selectbox(
        "Select Log File",
        options=list(log_file_options.keys()),
        index=0,
    )

    selected_log_file = log_file_options[selected_log_file_name]

    st.sidebar.info(f"Selected: {selected_log_file.name}")

    # Load nodes
    try:
        with st.spinner("Loading log file..."):
            nodes = load_from_log(selected_log_file)

        st.success(f"Loaded {len(nodes)} nodes from {selected_log_file.name}")

        # Create filter section in sidebar
        st.sidebar.header("Filters")

        # Text filter
        filter_text = st.sidebar.text_input("Text Search", help="Filter nodes containing this text")

        # Extract node types for filtering
        node_types = extract_node_types(nodes)

        # Type filter
        selected_types = st.sidebar.multiselect(
            "Node Types", options=sorted(node_types), default=[], help="Select node types to display"
        )

        # Attribute filter
        filter_attribute = st.sidebar.text_input("Has Attribute", help="Filter nodes that have this attribute")

        # Regex filter
        filter_regex = st.sidebar.text_input("Regex Pattern", help="Filter nodes matching this regex pattern")

        # Combine all filters
        filters = {"text": filter_text, "types": selected_types, "attribute": filter_attribute, "regex": filter_regex}

        # Apply filters
        filtered_nodes = filter_nodes(nodes, filters)

        # Show filter results
        if len(filtered_nodes) != len(nodes):
            st.info(f"Filtered to {len(filtered_nodes)} of {len(nodes)} nodes")

        # Display options
        st.sidebar.header("Display Options")

        # Option to expand all nodes by default
        expand_all = st.sidebar.checkbox("Expand All Nodes", value=False)

        # Option to select output format
        display_format = st.sidebar.radio(
            "Display Format",
            options=["Markdown", "Rich", "JSON"],
            index=0,
            help="Select the format for displaying node content",
        )

        # Export options
        st.sidebar.header("Export")

        # Export to JSON
        if st.sidebar.button("Export to JSON"):
            json_str = export_nodes_to_json(filtered_nodes)
            st.sidebar.download_button(
                label="Download JSON",
                data=json_str,
                file_name=f"agent_nodes_{selected_log_file.stem}.json",
                mime="application/json",
            )

        # Main content area - display nodes
        if filtered_nodes:
            node_title = NodeTitle()
            # Display nodes with proper formatting
            for i, node in enumerate(filtered_nodes):
                # Create a header for each node
                node_header = f"{i}: {node_title.summary(node)}"

                # Display the node with proper formatting
                with st.expander(node_header, expanded=expand_all):
                    try:
                        display_node(node, display_format=display_format)
                    except Exception as e:  # pylint: disable=broad-exception-caught
                        st.error(f"Error displaying node: {e}")
                        st.exception(e)
        else:
            st.warning("No nodes match the current filters")

    except Exception as e:  # pylint: disable=broad-exception-caught
        st.error(f"Error loading or processing log file: {e}")
        st.exception(e)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Agent Log Viewer")
    parser.add_argument("log_dir", nargs="?", type=Path, help="Directory containing log files (default: DATA_DIR/logs)")
    return parser.parse_args()


def main_cli():
    """Entry point for the command-line tool."""
    cmd_args = parse_args()

    # Print a message to indicate the app is starting
    print("Starting Agent Log Viewer...")
    print(f"Log directory: {cmd_args.log_dir or LOGS_DIR}")

    # Launch the Streamlit app

    # Get the path to this script
    script_path = Path(__file__).resolve()

    # Use streamlit run to start the app
    cmd = ["streamlit", "run", str(script_path)]

    # Add log_dir argument if provided
    if cmd_args.log_dir:
        cmd.extend(["--", str(cmd_args.log_dir)])

    # Run the command
    try:
        subprocess.run(cmd, check=False)
    except KeyboardInterrupt:
        print("\nShutting down Agent Log Viewer...")
    except Exception as e:  # pylint: disable=broad-exception-caught
        print(f"Error running Streamlit app: {e}")


if __name__ == "__main__":
    args = parse_args()
    main(args.log_dir)
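
The app assembles a plain dict of criteria and hands it to filter_nodes. A sketch of the same flow outside Streamlit, assuming a log file exists at the hypothetical path and that empty criteria in aixtools/log_view/filters.py act as no-ops:

from pathlib import Path

from aixtools.log_view.filters import filter_nodes
from aixtools.logging.log_objects import load_from_log

nodes = load_from_log(Path("logs/agent_run.log"))  # hypothetical log file

# Same filter structure the Streamlit sidebar builds above.
filters = {"text": "tool_call", "types": [], "attribute": "", "regex": ""}
matching = filter_nodes(nodes, filters)
print(f"{len(matching)} of {len(nodes)} nodes mention 'tool_call'")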
aixtools/log_view/display.py
ADDED

@@ -0,0 +1,285 @@
"""
Functions for displaying nodes in the Streamlit interface.
Provides enhanced display capabilities for various object types,
including dataclasses, with proper handling of nested structures.
"""

import inspect
import json
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass

import pandas as pd
import streamlit as st
from rich.console import Console

from aixtools.utils.utils import prepend_all_lines

# Toggle for using markdown display instead of JSON
USE_MARKDOWN = True


def filter_private_fields(data_dict: dict) -> dict:
    """Filter out private fields from the data dictionary."""
    return {k: v for k, v in data_dict.items() if not k.startswith("_")}


def filter_private_attributes(obj) -> dict:
    """
    Filter out private attributes and methods from an object.
    Returns a dictionary of public attributes and their values.
    """
    if not hasattr(obj, "__dict__"):
        return {}

    result = {}
    for attr, value in vars(obj).items():
        if not attr.startswith("_"):
            result[attr] = value

    return result


def is_method(obj, attr_name: str) -> bool:
    """Check if an attribute is a method."""
    try:
        attr = getattr(obj, attr_name)
        return inspect.ismethod(attr) or inspect.isfunction(attr)
    except (AttributeError, TypeError):
        return False


def get_object_type_str(obj) -> str:  # noqa: PLR0911, pylint: disable=too-many-return-statements
    """Get a string representation of the object's type."""
    if obj is None:
        return "null"
    if isinstance(obj, bool):
        return "bool"
    if isinstance(obj, int):
        return "int"
    if isinstance(obj, float):
        return "float"
    if isinstance(obj, str):
        return "str"
    if isinstance(obj, list):
        return f"list[{len(obj)}]"
    if isinstance(obj, tuple):
        return f"tuple[{len(obj)}]"
    if isinstance(obj, dict):
        return f"dict[{len(obj)}]"
    if isinstance(obj, set):
        return f"set[{len(obj)}]"
    if is_dataclass(obj):
        return f"dataclass:{type(obj).__name__}"
    if hasattr(obj, "__dict__"):
        return type(obj).__name__

    return type(obj).__name__


def object_to_json_with_types(obj, max_depth: int = 5, current_depth: int = 0):  # noqa: PLR0911, PLR0912, pylint: disable=too-many-return-statements,too-many-branches
    """
    Convert an object to a JSON-serializable dictionary with type information.
    Handles nested objects up to max_depth.
    """
    # Prevent infinite recursion
    if current_depth > max_depth:
        return {"__type": get_object_type_str(obj), "__value": str(obj)}

    # Handle None
    if obj is None:
        return None

    # Handle basic types
    if isinstance(obj, (bool, int, float, str)):
        return obj

    # Handle lists and tuples
    if isinstance(obj, (list, tuple)):
        items = []
        for item in obj:
            items.append(object_to_json_with_types(item, max_depth, current_depth + 1))
        return items

    # Handle dictionaries
    if isinstance(obj, dict):
        result = {}
        for key, value in filter_private_fields(obj).items():
            result[key] = object_to_json_with_types(value, max_depth, current_depth + 1)
        return result

    # Handle sets
    if isinstance(obj, set):
        items = []
        for item in obj:
            items.append(object_to_json_with_types(item, max_depth, current_depth + 1))
        return {"__type": "set", "__items": items}

    # Handle dataclasses
    if is_dataclass(obj):
        result = {"__type": f"dataclass:{type(obj).__name__}"}
        for field in dataclass_fields(obj):
            if field.name.startswith("_"):  # Skip private fields
                continue
            if not hasattr(obj, field.name):  # Skip not found
                continue
            value = getattr(obj, field.name)
            result[field.name] = object_to_json_with_types(value, max_depth, current_depth + 1)
        return result

    # Handle objects with __dict__
    if hasattr(obj, "__dict__"):
        result = {"__type": type(obj).__name__}
        for attr, value in filter_private_attributes(obj).items():
            if not is_method(obj, attr):  # Skip methods
                result[attr] = object_to_json_with_types(value, max_depth, current_depth + 1)
        return result

    # Handle other types
    return {"__type": get_object_type_str(obj), "__value": str(obj)}


def object_to_markdown(  # noqa: PLR0911, PLR0912, PLR0915, pylint: disable=too-many-locals,too-many-return-statements,too-many-branches,too-many-statements
    obj, max_depth: int = 5, current_depth: int = 0, indent: str = ""
) -> str:
    """
    Convert an object to a compact markdown representation.
    Handles nested objects up to max_depth.
    """
    max_display_items = 10  # Show only first MAX_DISPLAY_ITEMS items for large collections, dicts, and sets

    # Prevent infinite recursion
    if current_depth > max_depth:
        return f"`{get_object_type_str(obj)}`: {str(obj)}"

    # Handle None
    if obj is None:
        return "`None`"

    # Handle basic types
    if isinstance(obj, bool):
        return f"`{str(obj).lower()}`"

    if isinstance(obj, (int, float)):
        return f"`{obj}`"

    if isinstance(obj, str):
        lines = str(obj).splitlines()
        if len(lines) > 1:
            return f"\n{indent}```\n{prepend_all_lines(obj, prepend=indent)}\n{indent}```\n"
        return obj

    # Handle lists and tuples
    if isinstance(obj, (list, tuple)):
        if not obj:  # Empty collection
            return f"`{get_object_type_str(obj)}`: empty"

        max_inline_length = 3  # For small collections, show inline
        if (
            len(obj) <= max_inline_length
            and current_depth > 0
            and all(isinstance(x, (bool, int, float, str, type(None))) for x in obj)
        ):
            items = [object_to_markdown(item, max_depth, current_depth + 1) for item in obj]
            return f"`{get_object_type_str(obj)}`: [{', '.join(items)}]"

        # For larger collections, use bullet points
        result = [f"`{get_object_type_str(obj)}`:"]
        for i, item in enumerate(obj):
            if i >= max_display_items and len(obj) > max_display_items + 2:
                result.append(f"{indent}* ... ({len(obj) - 10} more items)")
                break
            item_md = object_to_markdown(item, max_depth, current_depth + 1, indent + " ")
            result.append(f"{indent}* {item_md}")
        return "\n".join(result)

    # Handle dictionaries
    if isinstance(obj, dict):
        if not obj:  # Empty dict
            return "`dict`: empty"

        result = [f"`dict[{len(obj)}]`:"]
        for i, (key, value) in enumerate(filter_private_fields(obj).items()):
            if i >= max_display_items and len(obj) > max_display_items + 2:
                result.append(f"{indent}* ... ({len(obj) - 10} more items)")
                break
            value_md = object_to_markdown(value, max_depth, current_depth + 1, indent + " ")
            result.append(f"{indent}* **{key}**: {value_md}")
        return "\n".join(result)

    # Handle sets
    if isinstance(obj, set):
        if not obj:  # Empty set
            return "`set`: empty"

        result = [f"`set[{len(obj)}]`:"]
        for i, item in enumerate(obj):
            if i >= max_display_items and len(obj) > max_display_items + 2:
                result.append(f"{indent}* ... ({len(obj) - 10} more items)")
                break
            item_md = object_to_markdown(item, max_depth, current_depth + 1, indent + " ")
            result.append(f"{indent}* {item_md}")
        return "\n".join(result)

    # Handle dataclasses
    if is_dataclass(obj):
        result = [f"`{type(obj).__name__}:`"]
        for field in dataclass_fields(obj):
            if field.name.startswith("_"):  # Skip private fields
                continue
            if not hasattr(obj, field.name):  # Skip not found
                continue
            value = getattr(obj, field.name)
            value_md = object_to_markdown(value, max_depth, current_depth + 1, indent + " ")
            result.append(f"{indent}* **{field.name}**: {value_md}")
        return "\n".join(result)

    # Handle objects with __dict__
    if hasattr(obj, "__dict__"):
        attrs = filter_private_attributes(obj)
        if not attrs:  # No public attributes
            return f"`{type(obj).__name__}`: (no public attributes)"

        result = [f"`{type(obj).__name__}`:"]
        for attr, value in attrs.items():
            if not is_method(obj, attr):  # Skip methods
                value_md = object_to_markdown(value, max_depth, current_depth + 1, indent + " ")
                result.append(f"{indent}* **{attr}**: {value_md}")
        return "\n".join(result)

    # Handle other types
    return f"`{get_object_type_str(obj)}`: {str(obj)}"


def format_json_for_display(json_obj) -> str:
    """Format a JSON object for display with proper indentation."""
    return json.dumps(json_obj, indent=2, default=str)


def display_node(node, display_format: str) -> None:
    """
    Display node content based on its type, with enhanced formatting.
    """
    # Special handling for specific types
    if isinstance(node, pd.DataFrame):
        st.dataframe(node)
        return

    # Toggle between markdown and JSON display
    match display_format:
        case "Markdown":
            st.markdown(object_to_markdown(node))
        case "JSON":
            st.json(object_to_json_with_types(node))
        case "Rich":
            st.write(rich_print(node))
        case _:
            raise ValueError(f"Unsupported display format: {display_format}")


def rich_print(node):
    """Display a node using rich print."""
    console = Console(color_system=None)
    with console.capture() as capture:
        console.print(node)
    return f"```\n{capture.get()}\n```"
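
A small sketch of the two converters above on an ad-hoc dataclass; the ToolCall class is illustrative only, and importing the module pulls in streamlit and pandas, so those need to be installed.

import json
from dataclasses import dataclass, field

from aixtools.log_view.display import object_to_json_with_types, object_to_markdown


@dataclass
class ToolCall:  # illustrative stand-in for a logged node
    name: str
    args: dict = field(default_factory=dict)
    _secret: str = "hidden"  # private fields are skipped by both converters


call = ToolCall(name="search", args={"query": "vector db"})

print(object_to_markdown(call))     # bullet-list markdown: name, args (no _secret)
print(json.dumps(object_to_json_with_types(call), indent=2))  # dict with "__type": "dataclass:ToolCall"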
aixtools/log_view/export.py
ADDED

@@ -0,0 +1,51 @@
"""
Functions for exporting nodes to various formats.
"""

import json


def export_nodes_to_json(nodes: list) -> str:
    """Export nodes to a JSON string for download."""
    # Convert nodes to a serializable format
    serializable_nodes = []

    for node in nodes:
        if hasattr(node, "__dict__"):
            # For objects with attributes
            node_dict = {
                "type": type(node).__name__,
                "attributes": {
                    attr: str(value) if not isinstance(value, (dict, list, int, float, bool, type(None))) else value
                    for attr, value in vars(node).items()
                    if not attr.startswith("_")
                },
            }
            serializable_nodes.append(node_dict)
        elif isinstance(node, dict):
            # For dictionaries
            serializable_nodes.append(
                {
                    "type": "dict",
                    "content": {
                        str(k): str(v) if not isinstance(v, (dict, list, int, float, bool, type(None))) else v
                        for k, v in node.items()
                    },
                }
            )
        elif isinstance(node, (list, tuple)):
            # For lists and tuples
            serializable_nodes.append(
                {
                    "type": "list" if isinstance(node, list) else "tuple",
                    "content": [
                        str(item) if not isinstance(item, (dict, list, int, float, bool, type(None))) else item
                        for item in node
                    ],
                }
            )
        else:
            # For primitive types
            serializable_nodes.append({"type": type(node).__name__, "value": str(node)})

    return json.dumps(serializable_nodes, indent=2)
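
A quick sketch of the exporter on mixed node types; the FakeNode class is an illustrative stand-in, and the round-trip through json.loads just confirms the output is valid JSON.

import json

from aixtools.log_view.export import export_nodes_to_json


class FakeNode:  # illustrative stand-in for a logged agent node
    def __init__(self):
        self.kind = "model_request"
        self.tokens = 128
        self._raw = object()  # private attribute, dropped by the exporter


nodes = [FakeNode(), {"event": "tool_result", "ok": True}, ["a", "b"], 42]
json_str = export_nodes_to_json(nodes)

parsed = json.loads(json_str)
print(parsed[0]["type"], parsed[0]["attributes"])  # FakeNode {'kind': 'model_request', 'tokens': 128}
print(parsed[3])                                   # {'type': 'int', 'value': '42'}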