remdb 0.3.127__py3-none-any.whl → 0.3.172__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remdb might be problematic. Click here for more details.
- rem/agentic/agents/__init__.py +16 -0
- rem/agentic/agents/agent_manager.py +311 -0
- rem/agentic/context.py +81 -3
- rem/agentic/context_builder.py +36 -9
- rem/agentic/mcp/tool_wrapper.py +132 -15
- rem/agentic/providers/phoenix.py +371 -108
- rem/agentic/providers/pydantic_ai.py +163 -45
- rem/agentic/schema.py +8 -4
- rem/api/deps.py +3 -5
- rem/api/main.py +22 -3
- rem/api/mcp_router/resources.py +15 -10
- rem/api/mcp_router/server.py +2 -0
- rem/api/mcp_router/tools.py +94 -2
- rem/api/middleware/tracking.py +5 -5
- rem/api/routers/auth.py +349 -6
- rem/api/routers/chat/completions.py +5 -3
- rem/api/routers/chat/streaming.py +95 -22
- rem/api/routers/messages.py +24 -15
- rem/auth/__init__.py +13 -3
- rem/auth/jwt.py +352 -0
- rem/auth/middleware.py +115 -10
- rem/auth/providers/__init__.py +4 -1
- rem/auth/providers/email.py +215 -0
- rem/cli/commands/configure.py +3 -4
- rem/cli/commands/experiments.py +226 -50
- rem/cli/commands/session.py +336 -0
- rem/cli/dreaming.py +2 -2
- rem/cli/main.py +2 -0
- rem/models/core/experiment.py +58 -14
- rem/models/entities/__init__.py +4 -0
- rem/models/entities/ontology.py +1 -1
- rem/models/entities/ontology_config.py +1 -1
- rem/models/entities/subscriber.py +175 -0
- rem/models/entities/user.py +1 -0
- rem/schemas/agents/core/agent-builder.yaml +235 -0
- rem/schemas/agents/examples/contract-analyzer.yaml +1 -1
- rem/schemas/agents/examples/contract-extractor.yaml +1 -1
- rem/schemas/agents/examples/cv-parser.yaml +1 -1
- rem/services/__init__.py +3 -1
- rem/services/content/service.py +4 -3
- rem/services/email/__init__.py +10 -0
- rem/services/email/service.py +513 -0
- rem/services/email/templates.py +360 -0
- rem/services/postgres/README.md +38 -0
- rem/services/postgres/diff_service.py +19 -3
- rem/services/postgres/pydantic_to_sqlalchemy.py +45 -13
- rem/services/postgres/repository.py +5 -4
- rem/services/session/compression.py +113 -50
- rem/services/session/reload.py +14 -7
- rem/services/user_service.py +41 -9
- rem/settings.py +292 -5
- rem/sql/migrations/001_install.sql +1 -1
- rem/sql/migrations/002_install_models.sql +91 -91
- rem/sql/migrations/005_schema_update.sql +145 -0
- rem/utils/README.md +45 -0
- rem/utils/files.py +157 -1
- rem/utils/schema_loader.py +45 -7
- rem/utils/vision.py +1 -1
- {remdb-0.3.127.dist-info → remdb-0.3.172.dist-info}/METADATA +7 -5
- {remdb-0.3.127.dist-info → remdb-0.3.172.dist-info}/RECORD +62 -52
- {remdb-0.3.127.dist-info → remdb-0.3.172.dist-info}/WHEEL +0 -0
- {remdb-0.3.127.dist-info → remdb-0.3.172.dist-info}/entry_points.txt +0 -0
rem/cli/commands/experiments.py
CHANGED
|
@@ -63,6 +63,7 @@ def experiments():
|
|
|
63
63
|
@experiments.command("create")
|
|
64
64
|
@click.argument("name")
|
|
65
65
|
@click.option("--agent", "-a", required=True, help="Agent schema name (e.g., 'cv-parser')")
|
|
66
|
+
@click.option("--task", "-t", default="general", help="Task name for organizing experiments (e.g., 'risk-assessment')")
|
|
66
67
|
@click.option("--evaluator", "-e", default="default", help="Evaluator schema name (default: 'default')")
|
|
67
68
|
@click.option("--description", "-d", help="Experiment description")
|
|
68
69
|
@click.option("--dataset-location", type=click.Choice(["git", "s3", "hybrid"]), default="git",
|
|
@@ -74,6 +75,7 @@ def experiments():
|
|
|
74
75
|
def create(
|
|
75
76
|
name: str,
|
|
76
77
|
agent: str,
|
|
78
|
+
task: str,
|
|
77
79
|
evaluator: str,
|
|
78
80
|
description: Optional[str],
|
|
79
81
|
dataset_location: str,
|
|
@@ -123,19 +125,17 @@ def create(
|
|
|
123
125
|
# Resolve base path: CLI arg > EXPERIMENTS_HOME env var > default "experiments"
|
|
124
126
|
if base_path is None:
|
|
125
127
|
base_path = os.getenv("EXPERIMENTS_HOME", "experiments")
|
|
126
|
-
# Build dataset reference
|
|
128
|
+
# Build dataset reference (format auto-detected from file extension)
|
|
127
129
|
if dataset_location == "git":
|
|
128
130
|
dataset_ref = DatasetReference(
|
|
129
131
|
location=DatasetLocation.GIT,
|
|
130
132
|
path="ground-truth/dataset.csv",
|
|
131
|
-
format="csv",
|
|
132
133
|
description="Ground truth Q&A dataset for evaluation"
|
|
133
134
|
)
|
|
134
135
|
else: # s3 or hybrid
|
|
135
136
|
dataset_ref = DatasetReference(
|
|
136
137
|
location=DatasetLocation(dataset_location),
|
|
137
138
|
path=f"s3://rem-experiments/{name}/datasets/ground_truth.parquet",
|
|
138
|
-
format="parquet",
|
|
139
139
|
schema_path="datasets/schema.yaml" if dataset_location == "hybrid" else None,
|
|
140
140
|
description="Ground truth dataset for evaluation"
|
|
141
141
|
)
|
|
@@ -170,7 +170,8 @@ def create(
|
|
|
170
170
|
# Create experiment config
|
|
171
171
|
config = ExperimentConfig(
|
|
172
172
|
name=name,
|
|
173
|
-
|
|
173
|
+
task=task,
|
|
174
|
+
description=description or f"Evaluation experiment for {agent} agent ({task} task)",
|
|
174
175
|
agent_schema_ref=SchemaReference(
|
|
175
176
|
name=agent,
|
|
176
177
|
version=None, # Use latest by default
|
|
@@ -912,58 +913,61 @@ def run(
|
|
|
912
913
|
click.echo(f" Last error: {evaluator_load_error}")
|
|
913
914
|
raise click.Abort()
|
|
914
915
|
|
|
915
|
-
#
|
|
916
|
-
|
|
916
|
+
# Validate evaluator credentials before running expensive agent tasks
|
|
917
|
+
if evaluator_fn is not None and not only_vibes:
|
|
918
|
+
from rem.agentic.providers.phoenix import validate_evaluator_credentials
|
|
919
|
+
|
|
920
|
+
click.echo("Validating evaluator credentials...")
|
|
921
|
+
is_valid, error_msg = validate_evaluator_credentials()
|
|
922
|
+
if not is_valid:
|
|
923
|
+
click.echo(click.style(f"\n⚠️ Evaluator validation failed: {error_msg}", fg="yellow"))
|
|
924
|
+
click.echo("\nOptions:")
|
|
925
|
+
click.echo(" 1. Fix the credentials issue and re-run")
|
|
926
|
+
click.echo(" 2. Run with --only-vibes to skip LLM evaluation")
|
|
927
|
+
click.echo(" 3. Use --evaluator-model to specify a different model")
|
|
928
|
+
raise click.Abort()
|
|
929
|
+
click.echo("✓ Evaluator credentials validated")
|
|
930
|
+
|
|
931
|
+
# Load dataset using read_dataframe utility (auto-detects format from extension)
|
|
932
|
+
from rem.utils.files import read_dataframe
|
|
917
933
|
|
|
918
934
|
click.echo(f"Loading dataset: {list(config.datasets.keys())[0]}")
|
|
919
935
|
dataset_ref = list(config.datasets.values())[0]
|
|
920
936
|
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
if dataset_ref.format == "csv":
|
|
929
|
-
dataset_df = pl.read_csv(dataset_path)
|
|
930
|
-
elif dataset_ref.format == "parquet":
|
|
931
|
-
dataset_df = pl.read_parquet(dataset_path)
|
|
932
|
-
elif dataset_ref.format == "jsonl":
|
|
933
|
-
dataset_df = pl.read_ndjson(dataset_path)
|
|
934
|
-
else:
|
|
935
|
-
click.echo(f"Error: Format '{dataset_ref.format}' not yet supported")
|
|
936
|
-
raise click.Abort()
|
|
937
|
-
elif dataset_ref.location.value in ["s3", "hybrid"]:
|
|
938
|
-
# Load from S3 using FS provider
|
|
939
|
-
from rem.services.fs import FS
|
|
940
|
-
from io import BytesIO
|
|
937
|
+
try:
|
|
938
|
+
if dataset_ref.location.value == "git":
|
|
939
|
+
# Load from Git (local filesystem)
|
|
940
|
+
dataset_path = Path(base_path) / name / dataset_ref.path
|
|
941
|
+
if not dataset_path.exists():
|
|
942
|
+
click.echo(f"Error: Dataset not found: {dataset_path}")
|
|
943
|
+
raise click.Abort()
|
|
941
944
|
|
|
942
|
-
|
|
945
|
+
dataset_df = read_dataframe(dataset_path)
|
|
943
946
|
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
dataset_df = pl.read_csv(BytesIO(content.encode() if isinstance(content, str) else content))
|
|
948
|
-
elif dataset_ref.format == "parquet":
|
|
949
|
-
content_bytes = fs.read(dataset_ref.path)
|
|
950
|
-
dataset_df = pl.read_parquet(BytesIO(content_bytes if isinstance(content_bytes, bytes) else content_bytes.encode()))
|
|
951
|
-
elif dataset_ref.format == "jsonl":
|
|
952
|
-
content = fs.read(dataset_ref.path)
|
|
953
|
-
dataset_df = pl.read_ndjson(BytesIO(content.encode() if isinstance(content, str) else content))
|
|
954
|
-
else:
|
|
955
|
-
click.echo(f"Error: Format '{dataset_ref.format}' not yet supported")
|
|
956
|
-
raise click.Abort()
|
|
947
|
+
elif dataset_ref.location.value in ["s3", "hybrid"]:
|
|
948
|
+
# Load from S3 using FS provider
|
|
949
|
+
from rem.services.fs import FS
|
|
957
950
|
|
|
951
|
+
fs = FS()
|
|
952
|
+
content = fs.read(dataset_ref.path)
|
|
953
|
+
# Ensure we have bytes
|
|
954
|
+
if isinstance(content, str):
|
|
955
|
+
content = content.encode()
|
|
956
|
+
dataset_df = read_dataframe(content, filename=dataset_ref.path)
|
|
958
957
|
click.echo(f"✓ Loaded dataset from S3")
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
click.echo(f"Error:
|
|
962
|
-
click.echo(f" Path: {dataset_ref.path}")
|
|
963
|
-
click.echo(f" Format: {dataset_ref.format}")
|
|
958
|
+
|
|
959
|
+
else:
|
|
960
|
+
click.echo(f"Error: Unknown dataset location: {dataset_ref.location.value}")
|
|
964
961
|
raise click.Abort()
|
|
965
|
-
|
|
966
|
-
|
|
962
|
+
|
|
963
|
+
except ValueError as e:
|
|
964
|
+
# Unsupported format error from read_dataframe
|
|
965
|
+
click.echo(f"Error: {e}")
|
|
966
|
+
raise click.Abort()
|
|
967
|
+
except Exception as e:
|
|
968
|
+
logger.error(f"Failed to load dataset: {e}")
|
|
969
|
+
click.echo(f"Error: Could not load dataset")
|
|
970
|
+
click.echo(f" Path: {dataset_ref.path}")
|
|
967
971
|
raise click.Abort()
|
|
968
972
|
|
|
969
973
|
click.echo(f"✓ Loaded dataset: {len(dataset_df)} examples")
|
|
@@ -1283,7 +1287,7 @@ def prompt():
|
|
|
1283
1287
|
@click.option("--system-prompt", "-s", required=True, help="System prompt text")
|
|
1284
1288
|
@click.option("--description", "-d", help="Prompt description")
|
|
1285
1289
|
@click.option("--model-provider", default="OPENAI", help="Model provider (OPENAI, ANTHROPIC)")
|
|
1286
|
-
@click.option("--model-name", "-m", help="Model name (e.g., gpt-
|
|
1290
|
+
@click.option("--model-name", "-m", help="Model name (e.g., gpt-4.1, claude-sonnet-4-5)")
|
|
1287
1291
|
@click.option("--type", "-t", "prompt_type", default="Agent", help="Prompt type (Agent or Evaluator)")
|
|
1288
1292
|
def prompt_create(
|
|
1289
1293
|
name: str,
|
|
@@ -1299,7 +1303,7 @@ def prompt_create(
|
|
|
1299
1303
|
# Create agent prompt
|
|
1300
1304
|
rem experiments prompt create hello-world \\
|
|
1301
1305
|
--system-prompt "You are a helpful assistant." \\
|
|
1302
|
-
--model-name gpt-
|
|
1306
|
+
--model-name gpt-4.1
|
|
1303
1307
|
|
|
1304
1308
|
# Create evaluator prompt
|
|
1305
1309
|
rem experiments prompt create correctness-evaluator \\
|
|
@@ -1317,7 +1321,7 @@ def prompt_create(
|
|
|
1317
1321
|
try:
|
|
1318
1322
|
# Set default model if not specified
|
|
1319
1323
|
if not model_name:
|
|
1320
|
-
model_name = "gpt-
|
|
1324
|
+
model_name = "gpt-4.1" if model_provider == "OPENAI" else "claude-sonnet-4-5-20250929"
|
|
1321
1325
|
|
|
1322
1326
|
# Get config
|
|
1323
1327
|
phoenix_client = PhoenixClient()
|
|
@@ -1520,3 +1524,175 @@ def trace_list(
|
|
|
1520
1524
|
logger.error(f"Failed to list traces: {e}")
|
|
1521
1525
|
click.echo(f"Error: {e}", err=True)
|
|
1522
1526
|
raise click.Abort()
|
|
1527
|
+
|
|
1528
|
+
|
|
1529
|
+
# =============================================================================
|
|
1530
|
+
# EXPORT COMMAND
|
|
1531
|
+
# =============================================================================
|
|
1532
|
+
|
|
1533
|
+
|
|
1534
|
+
def _collect_export_files(directory: Path, root: Path) -> list[tuple[str, Path]]:
    """List the files under *directory* as ``(s3_key, local_path)`` pairs.

    Keys are relative to *root* and use forward slashes (``as_posix``) because
    S3 object keys are '/'-separated regardless of the local OS. Files are
    sorted so the export order is deterministic. A missing directory yields ``[]``.
    """
    if not directory.exists():
        return []
    return [
        (path.relative_to(root).as_posix(), path)
        for path in sorted(directory.rglob("*"))
        if path.is_file()
    ]


@experiments.command("export")
@click.argument("name")
@click.option("--base-path", help="Base directory for experiments (default: EXPERIMENTS_HOME or 'experiments')")
@click.option("--bucket", "-b", help="S3 bucket name (default: DATA_LAKE__BUCKET_NAME)")
@click.option("--version", "-v", default="v0", help="Data lake version prefix (default: v0)")
@click.option("--plan", is_flag=True, help="Show what would be exported without uploading")
@click.option("--include-results", is_flag=True, help="Include results directory in export")
def export(
    name: str,
    base_path: Optional[str],
    bucket: Optional[str],
    version: str,
    plan: bool,
    include_results: bool,
):
    """Export experiment to S3 data lake.

    Exports experiment configuration, ground truth, and optionally results
    to the S3 data lake following the convention:

        s3://{bucket}/{version}/datasets/calibration/experiments/{agent}/{task}/

    The export includes:
    - experiment.yaml (configuration)
    - README.md (documentation)
    - ground-truth/ (evaluation datasets)
    - seed-data/ (optional seed data)
    - results/ (optional, with --include-results)

    Examples:
        # Preview what would be exported
        rem experiments export my-experiment --plan

        # Export to configured data lake bucket
        rem experiments export my-experiment

        # Export to specific bucket
        rem experiments export my-experiment --bucket siggy-data

        # Include results in export
        rem experiments export my-experiment --include-results

        # Export with custom version prefix
        rem experiments export my-experiment --version v1
    """
    from rem.models.core.experiment import ExperimentConfig
    from rem.settings import settings
    from rem.services.fs.s3_provider import S3Provider
    import os

    try:
        # Resolve base path: CLI arg > EXPERIMENTS_HOME env var > "experiments"
        if base_path is None:
            base_path = os.getenv("EXPERIMENTS_HOME", "experiments")

        # Load experiment configuration
        config_path = Path(base_path) / name / "experiment.yaml"
        if not config_path.exists():
            click.echo(f"Experiment not found: {name}")
            click.echo(f" Looked in: {config_path}")
            raise click.Abort()

        config = ExperimentConfig.from_yaml(config_path)
        click.echo(f"✓ Loaded experiment: {name}")

        # Resolve bucket: CLI arg > DATA_LAKE__BUCKET_NAME setting
        if bucket is None:
            bucket = settings.data_lake.bucket_name
        if not bucket:
            # Also rejects an empty string, which would otherwise produce a
            # malformed "s3:///..." destination.
            click.echo("Error: No S3 bucket configured.")
            click.echo(" Set DATA_LAKE__BUCKET_NAME environment variable or use --bucket option")
            raise click.Abort()

        # Build S3 paths
        s3_base = config.get_s3_export_path(bucket, version)
        exp_dir = config.get_experiment_dir(base_path)

        # Top-level files are exported when present; directories are walked
        # recursively, in the order ground-truth, seed-data, results.
        files_to_export: list[tuple[str, Path]] = [
            (s3_name, exp_dir / s3_name)
            for s3_name in ("experiment.yaml", "README.md")
            if (exp_dir / s3_name).exists()
        ]
        export_dirs = ["ground-truth", "seed-data"]
        if include_results:
            export_dirs.append("results")
        for dir_name in export_dirs:
            files_to_export.extend(_collect_export_files(exp_dir / dir_name, exp_dir))

        # Display export plan
        click.echo(f"\n{'=' * 60}")
        click.echo(f"EXPORT {'PLAN' if plan else 'TO S3'}")
        click.echo(f"{'=' * 60}")
        click.echo(f"\nExperiment: {config.name}")
        click.echo(f"Agent: {config.agent_schema_ref.name}")
        click.echo(f"Task: {config.task}")
        click.echo(f"Evaluator file: {config.get_evaluator_filename()}")
        click.echo(f"\nDestination: {s3_base}/")
        click.echo(f"\nFiles to export ({len(files_to_export)}):")

        for s3_name, local_path in files_to_export:
            if plan:
                click.echo(f" {local_path}")
                click.echo(f" → {s3_base}/{s3_name}")
            else:
                click.echo(f" {s3_name}")

        if plan:
            click.echo("\n[PLAN MODE] No files were uploaded.")
            click.echo("Run without --plan to execute the export.")
            return

        # Execute export; keep going past individual failures so the user
        # sees every file that failed, then fail the command as a whole.
        click.echo("\n⏳ Uploading to S3...")
        s3 = S3Provider()

        failed: list[str] = []
        for s3_name, local_path in files_to_export:
            try:
                s3.copy(str(local_path), f"{s3_base}/{s3_name}")
            except Exception as e:
                failed.append(s3_name)
                click.echo(f" ✗ {s3_name}: {e}")
            else:
                click.echo(f" ✓ {s3_name}")

        total = len(files_to_export)
        uploaded = total - len(failed)
        if failed:
            # A partial export must not look like success (exit code 0).
            click.echo(
                f"\n✗ Exported {uploaded}/{total} files to {s3_base}/ ({len(failed)} failed)",
                err=True,
            )
            raise click.Abort()

        click.echo(f"\n✓ Exported {uploaded}/{total} files to {s3_base}/")

        # Show next steps
        click.echo("\nNext steps:")
        click.echo(f" - View in S3: aws s3 ls {s3_base}/ --recursive")
        click.echo(f" - Download: aws s3 sync {s3_base}/ ./{config.agent_schema_ref.name}/{config.task}/")

    except click.Abort:
        # Raised deliberately above after a specific message was printed; let it
        # propagate instead of re-reporting it as an unexpected failure below.
        raise
    except Exception as e:
        logger.error(f"Failed to export experiment: {e}")
        click.echo(f"Error: {e}", err=True)
        raise click.Abort()
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI command for viewing and simulating session conversations.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
rem session show <user_id> [--session-id] [--role user|assistant|system]
|
|
6
|
+
rem session show <user_id> --simulate-next [--save] [--custom-sim-prompt "..."]
|
|
7
|
+
|
|
8
|
+
Examples:
|
|
9
|
+
# Show all messages for a user
|
|
10
|
+
rem session show 11111111-1111-1111-1111-111111111001
|
|
11
|
+
|
|
12
|
+
# Show only user messages
|
|
13
|
+
rem session show 11111111-1111-1111-1111-111111111001 --role user
|
|
14
|
+
|
|
15
|
+
# Simulate next user message
|
|
16
|
+
rem session show 11111111-1111-1111-1111-111111111001 --simulate-next
|
|
17
|
+
|
|
18
|
+
# Simulate with custom prompt and save
|
|
19
|
+
rem session show 11111111-1111-1111-1111-111111111001 --simulate-next --save \
|
|
20
|
+
--custom-sim-prompt "Respond as an anxious patient"
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import asyncio
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Literal
|
|
26
|
+
|
|
27
|
+
import click
|
|
28
|
+
import yaml
|
|
29
|
+
from loguru import logger
|
|
30
|
+
|
|
31
|
+
from ...models.entities.user import User
|
|
32
|
+
from ...models.entities.message import Message
|
|
33
|
+
from ...services.postgres import get_postgres_service
|
|
34
|
+
from ...services.postgres.repository import Repository
|
|
35
|
+
from ...settings import settings
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
SIMULATOR_PROMPT = """You are simulating a patient in a mental health conversation.
|
|
39
|
+
|
|
40
|
+
## Context
|
|
41
|
+
You are continuing a conversation with a clinical evaluation agent. Based on the
|
|
42
|
+
user profile and conversation history below, generate the next realistic patient message.
|
|
43
|
+
|
|
44
|
+
## User Profile
|
|
45
|
+
{user_profile}
|
|
46
|
+
|
|
47
|
+
## Conversation History
|
|
48
|
+
{conversation_history}
|
|
49
|
+
|
|
50
|
+
## Instructions
|
|
51
|
+
- Stay in character as the patient described in the profile
|
|
52
|
+
- Your response should be natural, conversational, and consistent with the patient's presentation
|
|
53
|
+
- Consider the patient's risk level, symptoms, and communication style
|
|
54
|
+
- Do NOT include any metadata or role labels - just the raw message content
|
|
55
|
+
- Keep responses concise (1-3 sentences typical for conversation)
|
|
56
|
+
|
|
57
|
+
Generate the next patient message:"""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
async def _load_user_and_messages(
    user_id: str,
    session_id: str | None = None,
    role_filter: str | None = None,
    limit: int = 100,
) -> tuple[User | None, list[Message]]:
    """Load a user's profile and their most recent messages from PostgreSQL.

    Args:
        user_id: User identifier, looked up in the "default" tenant.
        session_id: If given, only messages from this session are loaded.
        role_filter: If given, keep only messages whose ``message_type`` equals
            it. The filter is applied after ``limit``, so fewer than ``limit``
            messages may be returned.
        limit: Maximum number of messages fetched from the database.

    Returns:
        ``(user, messages)``, where ``messages`` is in chronological
        (oldest-first) order and holds the newest ``limit`` rows. Returns
        ``(None, [])`` when no PostgreSQL service is available.
    """
    pg = get_postgres_service()
    if not pg:
        logger.error("PostgreSQL not available")
        return None, []

    await pg.connect()

    try:
        # Load user
        user_repo = Repository(User, "users", db=pg)
        user = await user_repo.get_by_id(user_id, tenant_id="default")

        # Load messages
        message_repo = Repository(Message, "messages", db=pg)
        filters = {"user_id": user_id}
        if session_id:
            filters["session_id"] = session_id

        # Take the *newest* `limit` rows, then restore chronological order.
        # Ordering ASC with a LIMIT returned the oldest messages instead, so
        # long conversations lost their latest turns. Those turns are exactly
        # the context --simulate-next needs.
        newest_first = await message_repo.find(
            filters=filters,
            order_by="created_at DESC",
            limit=limit,
        )
        messages = list(reversed(newest_first))

        # Filter by role if specified
        if role_filter:
            messages = [m for m in messages if m.message_type == role_filter]

        return user, messages

    finally:
        await pg.disconnect()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _format_user_yaml(user: User | None) -> str:
    """Render the user's profile fields as a YAML document.

    Args:
        user: The user to render, or ``None``.

    Returns:
        YAML text with keys in the order listed below (id and name first),
        or the comment line ``# No user found`` when *user* is falsy.
    """
    if not user:
        return "# No user found"

    profile = {
        "id": str(user.id),
        "name": user.name,
        "summary": user.summary,
        "interests": user.interests,
        "preferred_topics": user.preferred_topics,
        "metadata": user.metadata,
    }
    # yaml.dump sorts mapping keys alphabetically by default, which would
    # scatter the identifying fields; keep the curated order instead.
    return yaml.dump(profile, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _format_messages_yaml(messages: list[Message]) -> str:
    """Render messages as a YAML list, one mapping per message.

    Each entry has ``role`` (``message_type`` or ``"unknown"``), ``content``,
    ``session_id`` and ``created_at`` (ISO-8601 string or ``None``), in that
    order.

    Args:
        messages: Messages to render, in display order.

    Returns:
        YAML text, or the comment line ``# No messages found`` for an empty list.
    """
    if not messages:
        return "# No messages found"

    entries = [
        {
            "role": msg.message_type or "unknown",
            "content": msg.content,
            # If session_id is not a plain str (e.g. a UUID), yaml.dump would
            # emit a `!!python/object` tag; stringify it for readable output.
            "session_id": str(msg.session_id) if msg.session_id is not None else None,
            "created_at": msg.created_at.isoformat() if msg.created_at else None,
        }
        for msg in messages
    ]
    # Keep role first rather than yaml.dump's default alphabetical key order.
    return yaml.dump(entries, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _format_conversation_for_llm(messages: list[Message]) -> str:
    """Build a plain-text transcript of *messages* for the simulator prompt.

    Every message becomes ``[ROLE]: content``, with the role upper-cased
    (``UNKNOWN`` when ``message_type`` is missing). Entries are separated by
    a blank line. An empty list produces a placeholder instead of an empty
    transcript.
    """
    if not messages:
        return "(No previous messages)"

    return "\n\n".join(
        f"[{(message.message_type or 'unknown').upper()}]: {message.content}"
        for message in messages
    )
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _resolve_custom_prompt(custom_prompt: str) -> str:
    """Return the prompt template given via --custom-sim-prompt.

    The value is read as a UTF-8 file when it names an existing regular file.
    Otherwise it is used verbatim as the template text.
    """
    try:
        candidate = Path(custom_prompt)
        is_file = candidate.is_file()
    except (OSError, ValueError):
        # Free-text prompts can exceed the OS filename length limit
        # (ENAMETOOLONG), and text containing NUL bytes is rejected; depending
        # on the Python version these raise instead of returning False.
        is_file = False
    if is_file:
        return candidate.read_text(encoding="utf-8")
    return custom_prompt


async def _simulate_next_message(
    user: User | None,
    messages: list[Message],
    custom_prompt: str | None = None,
) -> str:
    """Use an LLM to generate the next patient (user-role) message.

    Args:
        user: Profile of the simulated patient; ``None`` renders as
            ``"Unknown patient"``.
        messages: Conversation so far, in chronological order.
        custom_prompt: Optional template text, or a path to a file with the
            template. ``{user_profile}`` and ``{conversation_history}``
            placeholders are substituted. When omitted, ``SIMULATOR_PROMPT``
            is used.

    Returns:
        The generated message text (``result.output`` of the agent run).
    """
    from pydantic_ai import Agent

    # Build context
    user_profile = _format_user_yaml(user) if user else "Unknown patient"
    conversation_history = _format_conversation_for_llm(messages)

    if custom_prompt:
        prompt_template = _resolve_custom_prompt(custom_prompt)
        # Plain replace instead of str.format: user-supplied templates may
        # contain other literal braces, which str.format would choke on.
        prompt = prompt_template.replace("{user_profile}", user_profile)
        prompt = prompt.replace("{conversation_history}", conversation_history)
    else:
        prompt = SIMULATOR_PROMPT.format(
            user_profile=user_profile,
            conversation_history=conversation_history,
        )

    # Create simple agent for simulation
    agent = Agent(
        model=settings.llm.default_model,
        system_prompt="You are a patient simulator. Generate realistic patient responses.",
    )

    result = await agent.run(prompt)
    return result.output
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def _save_message(
    user_id: str,
    session_id: str | None,
    content: str,
    role: str = "user",
) -> Message:
    """Persist a single message for *user_id* and return the stored object.

    The row is written to the "default" tenant. A falsy *session_id* is
    replaced by a freshly generated one, so the message then starts a new
    session.

    Raises:
        RuntimeError: if no PostgreSQL service is available.
    """
    from uuid import uuid4

    db = get_postgres_service()
    if not db:
        raise RuntimeError("PostgreSQL not available")

    await db.connect()
    try:
        repo = Repository(Message, "messages", db=db)
        new_message = Message(
            id=uuid4(),
            user_id=user_id,
            tenant_id="default",
            session_id=session_id or str(uuid4()),
            content=content,
            message_type=role,
        )
        await repo.upsert(new_message)
        return new_message
    finally:
        await db.disconnect()
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@click.group()
def session():
    """Session viewing and simulation commands."""
    # Pure grouping command: subcommands such as `show` are attached below.
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@session.command("show")
@click.argument("user_id")
@click.option("--session-id", "-s", help="Filter by session ID")
@click.option(
    "--role", "-r",
    type=click.Choice(["user", "assistant", "system", "tool"]),
    help="Filter messages by role",
)
@click.option("--limit", "-l", default=100, help="Max messages to load")
@click.option("--simulate-next", is_flag=True, help="Simulate the next patient message")
@click.option("--save", is_flag=True, help="Save simulated message to database")
@click.option(
    "--custom-sim-prompt", "-p",
    help="Custom simulation prompt (text or file path)",
)
def show(
    user_id: str,
    session_id: str | None,
    role: str | None,
    limit: int,
    simulate_next: bool,
    save: bool,
    custom_sim_prompt: str | None,
):
    """
    Show user profile and session messages.

    USER_ID: The user identifier to load.

    Examples:

        # Show user and all messages
        rem session show 11111111-1111-1111-1111-111111111001

        # Show only assistant responses
        rem session show 11111111-1111-1111-1111-111111111001 --role assistant

        # Simulate next patient message
        rem session show 11111111-1111-1111-1111-111111111001 --simulate-next

        # Simulate and save to database
        rem session show 11111111-1111-1111-1111-111111111001 --simulate-next --save
    """
    # These options only affect simulation. Without --simulate-next they were
    # silently ignored (e.g. --save saved nothing), so reject them up front.
    if save and not simulate_next:
        raise click.UsageError("--save requires --simulate-next")
    if custom_sim_prompt and not simulate_next:
        raise click.UsageError("--custom-sim-prompt requires --simulate-next")

    asyncio.run(_show_async(
        user_id=user_id,
        session_id=session_id,
        role_filter=role,
        limit=limit,
        simulate_next=simulate_next,
        save=save,
        custom_sim_prompt=custom_sim_prompt,
    ))
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
async def _show_async(
    user_id: str,
    session_id: str | None,
    role_filter: str | None,
    limit: int,
    simulate_next: bool,
    save: bool,
    custom_sim_prompt: str | None,
):
    """Async implementation of the ``session show`` command.

    Prints the user profile and messages as YAML. With *simulate_next*, it also
    generates the next user-role message and, with *save*, persists it.
    """
    import textwrap

    # The simulator needs the whole conversation, so the role filter is only
    # applied at load time when we are not simulating.
    user, messages = await _load_user_and_messages(
        user_id=user_id,
        session_id=session_id,
        role_filter=role_filter if not simulate_next else None,
        limit=limit,
    )

    # Display user profile
    click.echo("\n# User Profile")
    click.echo("---")
    click.echo(_format_user_yaml(user))

    # When simulating, `messages` is unfiltered; apply the role filter for display only.
    display_messages = messages
    if simulate_next and role_filter:
        display_messages = [m for m in messages if m.message_type == role_filter]

    click.echo("\n# Messages")
    click.echo("---")
    click.echo(_format_messages_yaml(display_messages))

    if not simulate_next:
        return

    click.echo("\n# Simulated Next Message")
    click.echo("---")

    try:
        simulated = await _simulate_next_message(
            user=user,
            messages=messages,
            custom_prompt=custom_sim_prompt,
        )
        click.echo("role: user")
        # Indent every line, not just the first, so multi-line output stays
        # inside the YAML `|` block scalar.
        click.echo("content: |\n" + textwrap.indent(str(simulated), "  "))

        if save:
            # Without an explicit --session-id, attach the simulated turn to the
            # session of the latest loaded message so it continues the
            # conversation it was generated from. Otherwise _save_message would
            # start a new random session.
            target_session = session_id
            if not target_session and messages and messages[-1].session_id:
                target_session = str(messages[-1].session_id)

            saved_msg = await _save_message(
                user_id=user_id,
                session_id=target_session,
                content=simulated,
                role="user",
            )
            logger.success(f"Saved message: {saved_msg.id}")

    except Exception as e:
        logger.error(f"Simulation failed: {e}")
        raise
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def register_command(cli_group):
    """Attach the ``session`` command group to the given Click group."""
    add_command = cli_group.add_command
    add_command(session)
|