vizro-mcp 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vizro_mcp/__init__.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
3
 
4
4
  from .server import mcp
5
5
 
6
- __version__ = "0.1.0"
6
+ __version__ = "0.1.2"
7
7
 
8
8
 
9
9
  def main():
@@ -1,29 +1,11 @@
1
1
  from .schemas import (
2
- MODEL_GROUPS,
3
2
  AgGridEnhanced,
4
3
  ChartPlan,
5
- ContainerSimplified,
6
- DashboardSimplified,
7
- FilterSimplified,
8
4
  GraphEnhanced,
9
- PageSimplified,
10
- ParameterSimplified,
11
- TabsSimplified,
12
- get_overview_vizro_models,
13
- get_simple_dashboard_config,
14
5
  )
15
6
 
16
7
  __all__ = [
17
- "MODEL_GROUPS",
18
8
  "AgGridEnhanced",
19
9
  "ChartPlan",
20
- "ContainerSimplified",
21
- "DashboardSimplified",
22
- "FilterSimplified",
23
10
  "GraphEnhanced",
24
- "PageSimplified",
25
- "ParameterSimplified",
26
- "TabsSimplified",
27
- "get_overview_vizro_models",
28
- "get_simple_dashboard_config",
29
11
  ]
@@ -1,15 +1,13 @@
1
1
  """Schema defining pydantic models for usage in the MCP server."""
2
2
 
3
- from typing import Annotated, Any, Literal, Optional
3
+ from typing import Annotated, Any, Optional
4
4
 
5
5
  import vizro.models as vm
6
- from pydantic import AfterValidator, BaseModel, Field, PrivateAttr, conlist
6
+ from pydantic import AfterValidator, BaseModel, Field, PrivateAttr, ValidationInfo
7
7
 
8
- from vizro_mcp._utils import SAMPLE_DASHBOARD_CONFIG, DFMetaData
8
+ from vizro_mcp._utils import DFMetaData
9
9
 
10
- # Constants used in chart validation
11
- CUSTOM_CHART_NAME = "custom_chart"
12
- ADDITIONAL_IMPORTS = [
10
+ BASE_IMPORTS = [
13
11
  "import vizro.plotly.express as px",
14
12
  "import plotly.graph_objects as go",
15
13
  "import pandas as pd",
@@ -17,91 +15,6 @@ ADDITIONAL_IMPORTS = [
17
15
  "from vizro.models.types import capture",
18
16
  ]
19
17
 
20
- # These types are used to simplify the schema for the LLM.
21
- SimplifiedComponentType = Literal["Card", "Button", "Text", "Container", "Tabs", "Graph", "AgGrid"]
22
- SimplifiedSelectorType = Literal[
23
- "Dropdown", "RadioItems", "Checklist", "DatePicker", "Slider", "RangeSlider", "DatePicker"
24
- ]
25
- SimplifiedControlType = Literal["Filter", "Parameter"]
26
- SimplifiedLayoutType = Literal["Grid", "Flex"]
27
-
28
- # This dict is used to give the model and overview of what is available in the vizro.models namespace.
29
- # It helps it to narrow down the choices when asking for a model.
30
- MODEL_GROUPS: dict[str, list[type[vm.VizroBaseModel]]] = {
31
- "main": [vm.Dashboard, vm.Page],
32
- "components": [vm.Card, vm.Button, vm.Text, vm.Container, vm.Tabs, vm.Graph, vm.AgGrid], #'Figure', 'Table'
33
- "layouts": [vm.Grid, vm.Flex],
34
- "controls": [vm.Filter, vm.Parameter],
35
- "selectors": [
36
- vm.Dropdown,
37
- vm.RadioItems,
38
- vm.Checklist,
39
- vm.DatePicker,
40
- vm.Slider,
41
- vm.RangeSlider,
42
- vm.DatePicker,
43
- ],
44
- "navigation": [vm.Navigation, vm.NavBar, vm.NavLink],
45
- }
46
-
47
-
48
- # These simplified page, container, tabs and dashboard models are used to return a flatter schema to the LLM in order to
49
- # reduce the context size. Especially the dashboard model schema is huge as it contains all other models.
50
-
51
-
52
- class FilterSimplified(vm.Filter):
53
- """Simplified Filter model for reduced schema. LLM should remember to insert actual components."""
54
-
55
- selector: Optional[SimplifiedSelectorType] = Field(
56
- default=None, description="Selector to be displayed. Only provide if asked for!"
57
- )
58
-
59
-
60
- class ParameterSimplified(vm.Parameter):
61
- """Simplified Parameter model for reduced schema. LLM should remember to insert actual components."""
62
-
63
- selector: SimplifiedSelectorType = Field(description="Selector to be displayed.")
64
-
65
-
66
- class ContainerSimplified(vm.Container):
67
- """Simplified Container model for reduced schema. LLM should remember to insert actual components."""
68
-
69
- components: list[SimplifiedComponentType] = Field(description="List of component names to be displayed.")
70
- layout: Optional[SimplifiedLayoutType] = Field(
71
- default=None, description="Layout to place components in. Only provide if asked for!"
72
- )
73
-
74
-
75
- class TabsSimplified(vm.Tabs):
76
- """Simplified Tabs model for reduced schema. LLM should remember to insert actual components."""
77
-
78
- tabs: conlist(ContainerSimplified, min_length=1)
79
-
80
-
81
- class PageSimplified(BaseModel):
82
- """Simplified Page modes for reduced schema. LLM should remember to insert actual components."""
83
-
84
- components: list[SimplifiedComponentType] = Field(description="List of component names to be displayed.")
85
- title: str = Field(description="Title to be displayed.")
86
- description: str = Field(default="", description="Description for meta tags.")
87
- layout: Optional[SimplifiedLayoutType] = Field(
88
- default=None, description="Layout to place components in. Only provide if asked for!"
89
- )
90
- controls: list[SimplifiedControlType] = Field(default=[], description="Controls to be displayed.")
91
-
92
-
93
- class DashboardSimplified(BaseModel):
94
- """Simplified Dashboard model for reduced schema. LLM should remember to insert actual components."""
95
-
96
- pages: list[Literal["Page"]] = Field(description="List of page names to be included in the dashboard.")
97
- theme: Literal["vizro_dark", "vizro_light"] = Field(
98
- default="vizro_dark", description="Theme to be applied across dashboard. Defaults to `vizro_dark`."
99
- )
100
- navigation: Optional[Literal["Navigation"]] = Field(
101
- default=None, description="Navigation component for the dashboard. Only provide if asked for!"
102
- )
103
- title: str = Field(default="", description="Dashboard title to appear on every page on top left-side.")
104
-
105
18
 
106
19
  # These enhanced models are used to return a more complete schema to the LLM. Although we do not have actual schemas for
107
20
  # the figure fields, we can prompt the model via the description to produce something likely correct.
@@ -110,15 +23,23 @@ class GraphEnhanced(vm.Graph):
110
23
 
111
24
  figure: dict[str, Any] = Field(
112
25
  description="""
26
+ For simpler charts and without need for data manipulation, use this field:
113
27
  This is the plotly express figure to be displayed. Only use valid plotly express functions to create the figure.
114
28
  Only use the arguments that are supported by the function you are using and where no extra modules such as statsmodels
115
- are needed (e.g. trendline).
116
-
29
+ are needed (e.g. trendline):
117
30
  - Configure a dictionary as if this would be added as **kwargs to the function you are using.
118
31
  - You must use the key: "_target_: "<function_name>" to specify the function you are using. Do NOT precede by
119
32
  namespace (like px.line)
120
- - you must refer to the dataframe by name, for now it is one of "gapminder", "iris", "tips".
33
+ - you must refer to the dataframe by name, check file_name in the data_infos field ("data_frame": "<file_name>")
121
34
  - do not use a title if your Graph model already has a title.
35
+
36
+ For more complex charts and those that require data manipulation, use the `custom_charts` field:
37
+ - create the suitable number of custom charts and add them to the `custom_charts` field
38
+ - refer here to the function signature you created
39
+ - you must use the key: "_target_: "<custom_chart_name>"
40
+ - you must refer to the dataframe by name, check file_name in the data_infos field ("data_frame": "<file_name>")
41
+ - in general, DO NOT modify the background (with plot_bgcolor) or color sequences unless explicitly asked for
42
+ - when creating hover templates, EXPLICITLY style them to work on light and dark mode
122
43
  """
123
44
  )
124
45
 
@@ -137,7 +58,7 @@ The only difference to the dash version is that:
137
58
  )
138
59
 
139
60
 
140
- ###### Chart functionality - not sure if I should include this in the MCP server
61
+ ###### Chart functionality ######
141
62
  def _strip_markdown(code_string: str) -> str:
142
63
  """Remove any code block wrappers (markdown or triple quotes)."""
143
64
  wrappers = [("```python\n", "```"), ("```py\n", "```"), ("```\n", "```"), ('"""', '"""'), ("'''", "'''")]
@@ -150,22 +71,19 @@ def _strip_markdown(code_string: str) -> str:
150
71
  return code_string.strip()
151
72
 
152
73
 
153
- def _check_chart_code(v: str) -> str:
74
+ def _check_chart_code(v: str, info: ValidationInfo) -> str:
154
75
  v = _strip_markdown(v)
155
76
 
156
77
  # TODO: add more checks: ends with return, has return, no second function def, only one indented line
157
- func_def = f"def {CUSTOM_CHART_NAME}("
78
+ func_def = f"def {info.data['chart_name']}("
158
79
  if func_def not in v:
159
- raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`")
80
+ raise ValueError(f"The chart code must be wrapped in a function named `{info.data['chart_name']}`")
160
81
 
161
82
  v = v[v.index(func_def) :].strip()
162
83
 
163
84
  first_line = v.split("\n")[0].strip()
164
- if "data_frame" not in first_line:
165
- raise ValueError(
166
- """The chart code must accept a single argument `data_frame`,
167
- and it should be the first argument of the chart."""
168
- )
85
+ if "(data_frame" not in first_line:
86
+ raise ValueError("""The chart code must accept as first argument `data_frame` which is a pandas DataFrame.""")
169
87
  return v
170
88
 
171
89
 
@@ -177,6 +95,12 @@ class ChartPlan(BaseModel):
177
95
  Describes the chart type that best reflects the user request.
178
96
  """,
179
97
  )
98
+ chart_name: str = Field(
99
+ description="""
100
+ The name of the chart function. Should be unique, concise and in snake_case.
101
+ """,
102
+ pattern=r"^[a-z][a-z0-9_]*$",
103
+ )
180
104
  imports: list[str] = Field(
181
105
  description="""
182
106
  List of import statements required to render the chart defined by the `chart_code` field. Ensure that every
@@ -192,20 +116,22 @@ class ChartPlan(BaseModel):
192
116
  Field(
193
117
  description="""
194
118
  Python code that generates a generates a plotly go.Figure object. It must fulfill the following criteria:
195
- 1. Must be wrapped in a function name
196
- 2. Must accept a single argument `data_frame` which is a pandas DataFrame
119
+ 1. Must be wrapped in a function that is named `chart_name`
120
+ 2. Must accept as first argument argument `data_frame` which is a pandas DataFrame
197
121
  3. Must return a plotly go.Figure object
198
122
  4. All data used in the chart must be derived from the data_frame argument, all data manipulations
199
123
  must be done within the function.
124
+ 5. DO NOT modify the background (with plot_bgcolor) or color sequences unless explicitly asked for
125
+ 6. When creating hover templates, explicitly ensure that it works on light and dark mode
200
126
  """,
201
127
  ),
202
128
  ]
203
129
 
204
- _additional_vizro_imports: list[str] = PrivateAttr(ADDITIONAL_IMPORTS)
130
+ _base_chart_imports: list[str] = PrivateAttr(BASE_IMPORTS)
205
131
 
206
132
  def get_imports(self, vizro: bool = False):
207
- imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates
208
- if vizro: # TODO: improve code of below
133
+ imports = list(dict.fromkeys(self.imports + self._base_chart_imports)) # remove duplicates
134
+ if vizro:
209
135
  imports = [imp for imp in imports if "import plotly.express as px" not in imp]
210
136
  else:
211
137
  imports = [imp for imp in imports if "vizro" not in imp]
@@ -214,9 +140,9 @@ class ChartPlan(BaseModel):
214
140
  def get_chart_code(self, chart_name: Optional[str] = None, vizro: bool = False):
215
141
  chart_code = self.chart_code
216
142
  if vizro:
217
- chart_code = chart_code.replace(f"def {CUSTOM_CHART_NAME}", f"@capture('graph')\ndef {CUSTOM_CHART_NAME}")
143
+ chart_code = chart_code.replace(f"def {self.chart_name}", f"@capture('graph')\ndef {self.chart_name}")
218
144
  if chart_name is not None:
219
- chart_code = chart_code.replace(f"def {CUSTOM_CHART_NAME}", f"def {chart_name}")
145
+ chart_code = chart_code.replace(f"def {self.chart_name}", f"def {chart_name}")
220
146
  return chart_code
221
147
 
222
148
  def get_dashboard_template(self, data_info: DFMetaData) -> str:
@@ -232,14 +158,14 @@ class ChartPlan(BaseModel):
232
158
  imports = self.get_imports(vizro=True)
233
159
 
234
160
  # Add the Vizro-specific imports if not present
235
- additional_imports = [
161
+ additional_dashboard_imports = [
236
162
  "import vizro.models as vm",
237
163
  "from vizro import Vizro",
238
164
  "from vizro.managers import data_manager",
239
165
  ]
240
166
 
241
167
  # Combine imports without duplicates
242
- all_imports = list(dict.fromkeys(additional_imports + imports.split("\n")))
168
+ all_imports = list(dict.fromkeys(additional_dashboard_imports + imports.split("\n")))
243
169
 
244
170
  dashboard_template = f"""
245
171
  {chr(10).join(imp for imp in all_imports if imp)}
@@ -259,7 +185,7 @@ dashboard = vm.Dashboard(
259
185
  components=[
260
186
  vm.Graph(
261
187
  id="{self.chart_type}_graph",
262
- figure={CUSTOM_CHART_NAME}("{data_info.file_name}"),
188
+ figure={self.chart_name}("{data_info.file_name}"),
263
189
  )
264
190
  ],
265
191
  )
@@ -274,24 +200,15 @@ Vizro().build(dashboard).run()
274
200
  return dashboard_template
275
201
 
276
202
 
277
- def get_overview_vizro_models() -> dict[str, list[dict[str, str]]]:
278
- """Get all available models in the vizro.models namespace.
279
-
280
- Returns:
281
- Dictionary with categories of models and their descriptions
282
- """
283
- result: dict[str, list[dict[str, str]]] = {}
284
- for category, models_list in MODEL_GROUPS.items():
285
- result[category] = [
286
- {
287
- "name": model_class.__name__,
288
- "description": (model_class.__doc__ or "No description available").split("\n")[0],
289
- }
290
- for model_class in models_list
291
- ]
292
- return result
293
-
203
+ if __name__ == "__main__":
204
+ plan = ChartPlan(
205
+ chart_type="scatter",
206
+ chart_name="scatter",
207
+ imports=["import pandas as pd", "import plotly.express as px"],
208
+ chart_code="""
209
+ def scatter(data_frame):
210
+ return px.scatter(data_frame, x="sepal_length", y="sepal_width")
211
+ """,
212
+ )
294
213
 
295
- def get_simple_dashboard_config() -> str:
296
- """Very simple Vizro dashboard configuration. Use this config as a starter when no other config is provided."""
297
- return SAMPLE_DASHBOARD_CONFIG
214
+ # print(plan.get_chart_code(chart_name="poo", vizro=True))
@@ -1,11 +1,21 @@
1
- from .utils import (
1
+ from .configs import (
2
2
  GAPMINDER,
3
3
  IRIS,
4
4
  SAMPLE_DASHBOARD_CONFIG,
5
5
  STOCKS,
6
6
  TIPS,
7
+ )
8
+ from .prompts import (
9
+ CHART_INSTRUCTIONS,
10
+ get_chart_prompt,
11
+ get_dashboard_instructions,
12
+ get_dashboard_prompt,
13
+ get_starter_dashboard_prompt,
14
+ )
15
+ from .utils import (
7
16
  DFInfo,
8
17
  DFMetaData,
18
+ NoDefsGenerateJsonSchema,
9
19
  VizroCodeAndPreviewLink,
10
20
  convert_github_url_to_raw,
11
21
  create_pycafe_url,
@@ -15,19 +25,28 @@ from .utils import (
15
25
  path_or_url_check,
16
26
  )
17
27
 
18
- __all__ = [
28
+ __all__ = [ # noqa: RUF022
29
+ # Classes
30
+ "DFInfo",
31
+ "DFMetaData",
32
+ "NoDefsGenerateJsonSchema",
33
+ "VizroCodeAndPreviewLink",
34
+ # Constants
35
+ "CHART_INSTRUCTIONS",
19
36
  "GAPMINDER",
20
37
  "IRIS",
21
38
  "SAMPLE_DASHBOARD_CONFIG",
22
39
  "STOCKS",
23
40
  "TIPS",
24
- "DFInfo",
25
- "DFMetaData",
26
- "VizroCodeAndPreviewLink",
41
+ # Functions
27
42
  "convert_github_url_to_raw",
28
43
  "create_pycafe_url",
29
44
  "get_dataframe_info",
45
+ "get_dashboard_instructions",
46
+ "get_dashboard_prompt",
47
+ "get_chart_prompt",
30
48
  "get_python_code_and_preview_link",
49
+ "get_starter_dashboard_prompt",
31
50
  "load_dataframe_by_format",
32
51
  "path_or_url_check",
33
52
  ]
@@ -0,0 +1,142 @@
1
+ """Pre-set configs for the Vizro MCP."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Literal, Optional
5
+
6
+
7
+ @dataclass
8
+ class DFMetaData:
9
+ file_name: str
10
+ file_path_or_url: str
11
+ file_location_type: Literal["local", "remote"]
12
+ read_function_string: Literal["pd.read_csv", "pd.read_json", "pd.read_html", "pd.read_parquet", "pd.read_excel"]
13
+ column_names_types: Optional[dict[str, str]] = None
14
+
15
+
16
+ @dataclass
17
+ class DFInfo:
18
+ general_info: str
19
+ sample: dict[str, Any]
20
+
21
+
22
+ IRIS = DFMetaData(
23
+ file_name="iris_data",
24
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/iris-id.csv",
25
+ file_location_type="remote",
26
+ read_function_string="pd.read_csv",
27
+ column_names_types={
28
+ "sepal_length": "float",
29
+ "sepal_width": "float",
30
+ "petal_length": "float",
31
+ "petal_width": "float",
32
+ "species": "str",
33
+ },
34
+ )
35
+
36
+ TIPS = DFMetaData(
37
+ file_name="tips_data",
38
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/tips.csv",
39
+ file_location_type="remote",
40
+ read_function_string="pd.read_csv",
41
+ column_names_types={
42
+ "total_bill": "float",
43
+ "tip": "float",
44
+ "sex": "str",
45
+ "smoker": "str",
46
+ "day": "str",
47
+ "time": "str",
48
+ "size": "int",
49
+ },
50
+ )
51
+
52
+ STOCKS = DFMetaData(
53
+ file_name="stocks_data",
54
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/stockdata.csv",
55
+ file_location_type="remote",
56
+ read_function_string="pd.read_csv",
57
+ column_names_types={
58
+ "Date": "str",
59
+ "IBM": "float",
60
+ "MSFT": "float",
61
+ "SBUX": "float",
62
+ "AAPL": "float",
63
+ "GSPC": "float",
64
+ },
65
+ )
66
+
67
+ GAPMINDER = DFMetaData(
68
+ file_name="gapminder_data",
69
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/gapminder_unfiltered.csv",
70
+ file_location_type="remote",
71
+ read_function_string="pd.read_csv",
72
+ column_names_types={
73
+ "country": "str",
74
+ "continent": "str",
75
+ "year": "int",
76
+ "lifeExp": "float",
77
+ "pop": "int",
78
+ "gdpPercap": "float",
79
+ },
80
+ )
81
+
82
+ SAMPLE_DASHBOARD_CONFIG = """
83
+ {
84
+ `config`: {
85
+ `pages`: [
86
+ {
87
+ `title`: `Iris Data Analysis`,
88
+ `controls`: [
89
+ {
90
+ `id`: `species_filter`,
91
+ `type`: `filter`,
92
+ `column`: `species`,
93
+ `targets`: [
94
+ `scatter_plot`
95
+ ],
96
+ `selector`: {
97
+ `type`: `dropdown`,
98
+ `multi`: true
99
+ }
100
+ }
101
+ ],
102
+ `components`: [
103
+ {
104
+ `id`: `scatter_plot`,
105
+ `type`: `graph`,
106
+ `title`: `Sepal Dimensions by Species`,
107
+ `figure`: {
108
+ `x`: `sepal_length`,
109
+ `y`: `sepal_width`,
110
+ `color`: `species`,
111
+ `_target_`: `scatter`,
112
+ `data_frame`: `iris_data`,
113
+ `hover_data`: [
114
+ `petal_length`,
115
+ `petal_width`
116
+ ]
117
+ }
118
+ }
119
+ ]
120
+ }
121
+ ],
122
+ `theme`: `vizro_dark`,
123
+ `title`: `Iris Dashboard`
124
+ },
125
+ `data_infos`: `
126
+ [
127
+ {
128
+ \"file_name\": \"iris_data\",
129
+ \"file_path_or_url\": \"https://raw.githubusercontent.com/plotly/datasets/master/iris-id.csv\",
130
+ \"file_location_type\": \"remote\",
131
+ \"read_function_string\": \"pd.read_csv\",
132
+ }
133
+ ]
134
+ `
135
+ }
136
+
137
+ """
138
+
139
+
140
+ def get_simple_dashboard_config() -> str:
141
+ """Very simple Vizro dashboard configuration. Use this config as a starter when no other config is provided."""
142
+ return SAMPLE_DASHBOARD_CONFIG