vizro-mcp 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
vizro_mcp/__init__.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
3
 
4
4
  from .server import mcp
5
5
 
6
- __version__ = "0.1.1"
6
+ __version__ = "0.1.3"
7
7
 
8
8
 
9
9
  def main():
@@ -1,29 +1,13 @@
1
1
  from .schemas import (
2
- MODEL_GROUPS,
3
2
  AgGridEnhanced,
4
3
  ChartPlan,
5
- ContainerSimplified,
6
- DashboardSimplified,
7
- FilterSimplified,
4
+ FigureEnhanced,
8
5
  GraphEnhanced,
9
- PageSimplified,
10
- ParameterSimplified,
11
- TabsSimplified,
12
- get_overview_vizro_models,
13
- get_simple_dashboard_config,
14
6
  )
15
7
 
16
8
  __all__ = [
17
- "MODEL_GROUPS",
18
9
  "AgGridEnhanced",
19
10
  "ChartPlan",
20
- "ContainerSimplified",
21
- "DashboardSimplified",
22
- "FilterSimplified",
11
+ "FigureEnhanced",
23
12
  "GraphEnhanced",
24
- "PageSimplified",
25
- "ParameterSimplified",
26
- "TabsSimplified",
27
- "get_overview_vizro_models",
28
- "get_simple_dashboard_config",
29
13
  ]
@@ -1,15 +1,14 @@
1
1
  """Schema defining pydantic models for usage in the MCP server."""
2
2
 
3
- from typing import Annotated, Any, Literal, Optional
3
+ from typing import Annotated, Any, Optional
4
4
 
5
+ import vizro.figures as vf
5
6
  import vizro.models as vm
6
- from pydantic import AfterValidator, BaseModel, Field, PrivateAttr, conlist
7
+ from pydantic import AfterValidator, BaseModel, Field, PrivateAttr, ValidationInfo
7
8
 
8
- from vizro_mcp._utils import SAMPLE_DASHBOARD_CONFIG, DFMetaData
9
+ from vizro_mcp._utils import DFMetaData
9
10
 
10
- # Constants used in chart validation
11
- CUSTOM_CHART_NAME = "custom_chart"
12
- ADDITIONAL_IMPORTS = [
11
+ BASE_IMPORTS = [
13
12
  "import vizro.plotly.express as px",
14
13
  "import plotly.graph_objects as go",
15
14
  "import pandas as pd",
@@ -17,91 +16,6 @@ ADDITIONAL_IMPORTS = [
17
16
  "from vizro.models.types import capture",
18
17
  ]
19
18
 
20
- # These types are used to simplify the schema for the LLM.
21
- SimplifiedComponentType = Literal["Card", "Button", "Text", "Container", "Tabs", "Graph", "AgGrid"]
22
- SimplifiedSelectorType = Literal[
23
- "Dropdown", "RadioItems", "Checklist", "DatePicker", "Slider", "RangeSlider", "DatePicker"
24
- ]
25
- SimplifiedControlType = Literal["Filter", "Parameter"]
26
- SimplifiedLayoutType = Literal["Grid", "Flex"]
27
-
28
- # This dict is used to give the model and overview of what is available in the vizro.models namespace.
29
- # It helps it to narrow down the choices when asking for a model.
30
- MODEL_GROUPS: dict[str, list[type[vm.VizroBaseModel]]] = {
31
- "main": [vm.Dashboard, vm.Page],
32
- "components": [vm.Card, vm.Button, vm.Text, vm.Container, vm.Tabs, vm.Graph, vm.AgGrid], #'Figure', 'Table'
33
- "layouts": [vm.Grid, vm.Flex],
34
- "controls": [vm.Filter, vm.Parameter],
35
- "selectors": [
36
- vm.Dropdown,
37
- vm.RadioItems,
38
- vm.Checklist,
39
- vm.DatePicker,
40
- vm.Slider,
41
- vm.RangeSlider,
42
- vm.DatePicker,
43
- ],
44
- "navigation": [vm.Navigation, vm.NavBar, vm.NavLink],
45
- }
46
-
47
-
48
- # These simplified page, container, tabs and dashboard models are used to return a flatter schema to the LLM in order to
49
- # reduce the context size. Especially the dashboard model schema is huge as it contains all other models.
50
-
51
-
52
- class FilterSimplified(vm.Filter):
53
- """Simplified Filter model for reduced schema. LLM should remember to insert actual components."""
54
-
55
- selector: Optional[SimplifiedSelectorType] = Field(
56
- default=None, description="Selector to be displayed. Only provide if asked for!"
57
- )
58
-
59
-
60
- class ParameterSimplified(vm.Parameter):
61
- """Simplified Parameter model for reduced schema. LLM should remember to insert actual components."""
62
-
63
- selector: SimplifiedSelectorType = Field(description="Selector to be displayed.")
64
-
65
-
66
- class ContainerSimplified(vm.Container):
67
- """Simplified Container model for reduced schema. LLM should remember to insert actual components."""
68
-
69
- components: list[SimplifiedComponentType] = Field(description="List of component names to be displayed.")
70
- layout: Optional[SimplifiedLayoutType] = Field(
71
- default=None, description="Layout to place components in. Only provide if asked for!"
72
- )
73
-
74
-
75
- class TabsSimplified(vm.Tabs):
76
- """Simplified Tabs model for reduced schema. LLM should remember to insert actual components."""
77
-
78
- tabs: conlist(ContainerSimplified, min_length=1)
79
-
80
-
81
- class PageSimplified(BaseModel):
82
- """Simplified Page modes for reduced schema. LLM should remember to insert actual components."""
83
-
84
- components: list[SimplifiedComponentType] = Field(description="List of component names to be displayed.")
85
- title: str = Field(description="Title to be displayed.")
86
- description: str = Field(default="", description="Description for meta tags.")
87
- layout: Optional[SimplifiedLayoutType] = Field(
88
- default=None, description="Layout to place components in. Only provide if asked for!"
89
- )
90
- controls: list[SimplifiedControlType] = Field(default=[], description="Controls to be displayed.")
91
-
92
-
93
- class DashboardSimplified(BaseModel):
94
- """Simplified Dashboard model for reduced schema. LLM should remember to insert actual components."""
95
-
96
- pages: list[Literal["Page"]] = Field(description="List of page names to be included in the dashboard.")
97
- theme: Literal["vizro_dark", "vizro_light"] = Field(
98
- default="vizro_dark", description="Theme to be applied across dashboard. Defaults to `vizro_dark`."
99
- )
100
- navigation: Optional[Literal["Navigation"]] = Field(
101
- default=None, description="Navigation component for the dashboard. Only provide if asked for!"
102
- )
103
- title: str = Field(default="", description="Dashboard title to appear on every page on top left-side.")
104
-
105
19
 
106
20
  # These enhanced models are used to return a more complete schema to the LLM. Although we do not have actual schemas for
107
21
  # the figure fields, we can prompt the model via the description to produce something likely correct.
@@ -110,15 +24,23 @@ class GraphEnhanced(vm.Graph):
110
24
 
111
25
  figure: dict[str, Any] = Field(
112
26
  description="""
27
+ For simpler charts and without need for data manipulation, use this field:
113
28
  This is the plotly express figure to be displayed. Only use valid plotly express functions to create the figure.
114
29
  Only use the arguments that are supported by the function you are using and where no extra modules such as statsmodels
115
- are needed (e.g. trendline).
116
-
30
+ are needed (e.g. trendline):
117
31
  - Configure a dictionary as if this would be added as **kwargs to the function you are using.
118
32
  - You must use the key: "_target_: "<function_name>" to specify the function you are using. Do NOT precede by
119
33
  namespace (like px.line)
120
- - you must refer to the dataframe by name, for now it is one of "gapminder", "iris", "tips".
34
+ - you must refer to the dataframe by name, check file_name in the data_infos field ("data_frame": "<file_name>")
121
35
  - do not use a title if your Graph model already has a title.
36
+
37
+ For more complex charts and those that require data manipulation, use the `custom_charts` field:
38
+ - create the suitable number of custom charts and add them to the `custom_charts` field
39
+ - refer here to the function signature you created
40
+ - you must use the key: "_target_: "<custom_chart_name>"
41
+ - you must refer to the dataframe by name, check file_name in the data_infos field ("data_frame": "<file_name>")
42
+ - in general, DO NOT modify the background (with plot_bgcolor) or color sequences unless explicitly asked for
43
+ - when creating hover templates, EXPLICITLY style them to work on light and dark mode
122
44
  """
123
45
  )
124
46
 
@@ -137,7 +59,22 @@ The only difference to the dash version is that:
137
59
  )
138
60
 
139
61
 
140
- ###### Chart functionality - not sure if I should include this in the MCP server
62
+ FIGURE_NAMESPACE_FUNCTION_DOCS = {func: vf.__dict__[func].__doc__ for func in vf.__all__}
63
+
64
+
65
+ class FigureEnhanced(vm.Figure):
66
+ """Figure model that allows to use dynamic figure functions."""
67
+
68
+ figure: dict[str, Any] = Field(
69
+ description=f"""This is the figure function to be displayed.
70
+
71
+ Only use arguments from the below mapping of _target_ to figure function documentation:
72
+
73
+ {FIGURE_NAMESPACE_FUNCTION_DOCS}"""
74
+ )
75
+
76
+
77
+ ###### Chart functionality ######
141
78
  def _strip_markdown(code_string: str) -> str:
142
79
  """Remove any code block wrappers (markdown or triple quotes)."""
143
80
  wrappers = [("```python\n", "```"), ("```py\n", "```"), ("```\n", "```"), ('"""', '"""'), ("'''", "'''")]
@@ -150,22 +87,19 @@ def _strip_markdown(code_string: str) -> str:
150
87
  return code_string.strip()
151
88
 
152
89
 
153
- def _check_chart_code(v: str) -> str:
90
+ def _check_chart_code(v: str, info: ValidationInfo) -> str:
154
91
  v = _strip_markdown(v)
155
92
 
156
93
  # TODO: add more checks: ends with return, has return, no second function def, only one indented line
157
- func_def = f"def {CUSTOM_CHART_NAME}("
94
+ func_def = f"def {info.data['chart_name']}("
158
95
  if func_def not in v:
159
- raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`")
96
+ raise ValueError(f"The chart code must be wrapped in a function named `{info.data['chart_name']}`")
160
97
 
161
98
  v = v[v.index(func_def) :].strip()
162
99
 
163
100
  first_line = v.split("\n")[0].strip()
164
- if "data_frame" not in first_line:
165
- raise ValueError(
166
- """The chart code must accept a single argument `data_frame`,
167
- and it should be the first argument of the chart."""
168
- )
101
+ if "(data_frame" not in first_line:
102
+ raise ValueError("""The chart code must accept as first argument `data_frame` which is a pandas DataFrame.""")
169
103
  return v
170
104
 
171
105
 
@@ -177,6 +111,12 @@ class ChartPlan(BaseModel):
177
111
  Describes the chart type that best reflects the user request.
178
112
  """,
179
113
  )
114
+ chart_name: str = Field(
115
+ description="""
116
+ The name of the chart function. Should be unique, concise and in snake_case.
117
+ """,
118
+ pattern=r"^[a-z][a-z0-9_]*$",
119
+ )
180
120
  imports: list[str] = Field(
181
121
  description="""
182
122
  List of import statements required to render the chart defined by the `chart_code` field. Ensure that every
@@ -192,20 +132,22 @@ class ChartPlan(BaseModel):
192
132
  Field(
193
133
  description="""
194
134
  Python code that generates a generates a plotly go.Figure object. It must fulfill the following criteria:
195
- 1. Must be wrapped in a function name
196
- 2. Must accept a single argument `data_frame` which is a pandas DataFrame
135
+ 1. Must be wrapped in a function that is named `chart_name`
136
+ 2. Must accept as first argument argument `data_frame` which is a pandas DataFrame
197
137
  3. Must return a plotly go.Figure object
198
138
  4. All data used in the chart must be derived from the data_frame argument, all data manipulations
199
139
  must be done within the function.
140
+ 5. DO NOT modify the background (with plot_bgcolor) or color sequences unless explicitly asked for
141
+ 6. When creating hover templates, explicitly ensure that it works on light and dark mode
200
142
  """,
201
143
  ),
202
144
  ]
203
145
 
204
- _additional_vizro_imports: list[str] = PrivateAttr(ADDITIONAL_IMPORTS)
146
+ _base_chart_imports: list[str] = PrivateAttr(BASE_IMPORTS)
205
147
 
206
148
  def get_imports(self, vizro: bool = False):
207
- imports = list(dict.fromkeys(self.imports + self._additional_vizro_imports)) # remove duplicates
208
- if vizro: # TODO: improve code of below
149
+ imports = list(dict.fromkeys(self.imports + self._base_chart_imports)) # remove duplicates
150
+ if vizro:
209
151
  imports = [imp for imp in imports if "import plotly.express as px" not in imp]
210
152
  else:
211
153
  imports = [imp for imp in imports if "vizro" not in imp]
@@ -214,9 +156,9 @@ class ChartPlan(BaseModel):
214
156
  def get_chart_code(self, chart_name: Optional[str] = None, vizro: bool = False):
215
157
  chart_code = self.chart_code
216
158
  if vizro:
217
- chart_code = chart_code.replace(f"def {CUSTOM_CHART_NAME}", f"@capture('graph')\ndef {CUSTOM_CHART_NAME}")
159
+ chart_code = chart_code.replace(f"def {self.chart_name}", f"@capture('graph')\ndef {self.chart_name}")
218
160
  if chart_name is not None:
219
- chart_code = chart_code.replace(f"def {CUSTOM_CHART_NAME}", f"def {chart_name}")
161
+ chart_code = chart_code.replace(f"def {self.chart_name}", f"def {chart_name}")
220
162
  return chart_code
221
163
 
222
164
  def get_dashboard_template(self, data_info: DFMetaData) -> str:
@@ -232,14 +174,14 @@ class ChartPlan(BaseModel):
232
174
  imports = self.get_imports(vizro=True)
233
175
 
234
176
  # Add the Vizro-specific imports if not present
235
- additional_imports = [
177
+ additional_dashboard_imports = [
236
178
  "import vizro.models as vm",
237
179
  "from vizro import Vizro",
238
180
  "from vizro.managers import data_manager",
239
181
  ]
240
182
 
241
183
  # Combine imports without duplicates
242
- all_imports = list(dict.fromkeys(additional_imports + imports.split("\n")))
184
+ all_imports = list(dict.fromkeys(additional_dashboard_imports + imports.split("\n")))
243
185
 
244
186
  dashboard_template = f"""
245
187
  {chr(10).join(imp for imp in all_imports if imp)}
@@ -259,7 +201,7 @@ dashboard = vm.Dashboard(
259
201
  components=[
260
202
  vm.Graph(
261
203
  id="{self.chart_type}_graph",
262
- figure={CUSTOM_CHART_NAME}("{data_info.file_name}"),
204
+ figure={self.chart_name}("{data_info.file_name}"),
263
205
  )
264
206
  ],
265
207
  )
@@ -274,24 +216,15 @@ Vizro().build(dashboard).run()
274
216
  return dashboard_template
275
217
 
276
218
 
277
- def get_overview_vizro_models() -> dict[str, list[dict[str, str]]]:
278
- """Get all available models in the vizro.models namespace.
279
-
280
- Returns:
281
- Dictionary with categories of models and their descriptions
282
- """
283
- result: dict[str, list[dict[str, str]]] = {}
284
- for category, models_list in MODEL_GROUPS.items():
285
- result[category] = [
286
- {
287
- "name": model_class.__name__,
288
- "description": (model_class.__doc__ or "No description available").split("\n")[0],
289
- }
290
- for model_class in models_list
291
- ]
292
- return result
293
-
219
+ if __name__ == "__main__":
220
+ plan = ChartPlan(
221
+ chart_type="scatter",
222
+ chart_name="scatter",
223
+ imports=["import pandas as pd", "import plotly.express as px"],
224
+ chart_code="""
225
+ def scatter(data_frame):
226
+ return px.scatter(data_frame, x="sepal_length", y="sepal_width")
227
+ """,
228
+ )
294
229
 
295
- def get_simple_dashboard_config() -> str:
296
- """Very simple Vizro dashboard configuration. Use this config as a starter when no other config is provided."""
297
- return SAMPLE_DASHBOARD_CONFIG
230
+ # print(plan.get_chart_code(chart_name="poo", vizro=True))
@@ -1,11 +1,22 @@
1
- from .utils import (
1
+ from .configs import (
2
2
  GAPMINDER,
3
3
  IRIS,
4
4
  SAMPLE_DASHBOARD_CONFIG,
5
5
  STOCKS,
6
6
  TIPS,
7
+ )
8
+ from .prompts import (
9
+ CHART_INSTRUCTIONS,
10
+ LAYOUT_INSTRUCTIONS,
11
+ get_chart_prompt,
12
+ get_dashboard_instructions,
13
+ get_dashboard_prompt,
14
+ get_starter_dashboard_prompt,
15
+ )
16
+ from .utils import (
7
17
  DFInfo,
8
18
  DFMetaData,
19
+ NoDefsGenerateJsonSchema,
9
20
  VizroCodeAndPreviewLink,
10
21
  convert_github_url_to_raw,
11
22
  create_pycafe_url,
@@ -15,19 +26,29 @@ from .utils import (
15
26
  path_or_url_check,
16
27
  )
17
28
 
18
- __all__ = [
29
+ __all__ = [ # noqa: RUF022
30
+ # Classes
31
+ "DFInfo",
32
+ "DFMetaData",
33
+ "NoDefsGenerateJsonSchema",
34
+ "VizroCodeAndPreviewLink",
35
+ # Constants
36
+ "CHART_INSTRUCTIONS",
37
+ "LAYOUT_INSTRUCTIONS",
19
38
  "GAPMINDER",
20
39
  "IRIS",
21
40
  "SAMPLE_DASHBOARD_CONFIG",
22
41
  "STOCKS",
23
42
  "TIPS",
24
- "DFInfo",
25
- "DFMetaData",
26
- "VizroCodeAndPreviewLink",
43
+ # Functions
27
44
  "convert_github_url_to_raw",
28
45
  "create_pycafe_url",
29
46
  "get_dataframe_info",
47
+ "get_dashboard_instructions",
48
+ "get_dashboard_prompt",
49
+ "get_chart_prompt",
30
50
  "get_python_code_and_preview_link",
51
+ "get_starter_dashboard_prompt",
31
52
  "load_dataframe_by_format",
32
53
  "path_or_url_check",
33
54
  ]
@@ -0,0 +1,142 @@
1
+ """Pre-set configs for the Vizro MCP."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Literal, Optional
5
+
6
+
7
+ @dataclass
8
+ class DFMetaData:
9
+ file_name: str
10
+ file_path_or_url: str
11
+ file_location_type: Literal["local", "remote"]
12
+ read_function_string: Literal["pd.read_csv", "pd.read_json", "pd.read_html", "pd.read_parquet", "pd.read_excel"]
13
+ column_names_types: Optional[dict[str, str]] = None
14
+
15
+
16
+ @dataclass
17
+ class DFInfo:
18
+ general_info: str
19
+ sample: dict[str, Any]
20
+
21
+
22
+ IRIS = DFMetaData(
23
+ file_name="iris_data",
24
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/iris-id.csv",
25
+ file_location_type="remote",
26
+ read_function_string="pd.read_csv",
27
+ column_names_types={
28
+ "sepal_length": "float",
29
+ "sepal_width": "float",
30
+ "petal_length": "float",
31
+ "petal_width": "float",
32
+ "species": "str",
33
+ },
34
+ )
35
+
36
+ TIPS = DFMetaData(
37
+ file_name="tips_data",
38
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/tips.csv",
39
+ file_location_type="remote",
40
+ read_function_string="pd.read_csv",
41
+ column_names_types={
42
+ "total_bill": "float",
43
+ "tip": "float",
44
+ "sex": "str",
45
+ "smoker": "str",
46
+ "day": "str",
47
+ "time": "str",
48
+ "size": "int",
49
+ },
50
+ )
51
+
52
+ STOCKS = DFMetaData(
53
+ file_name="stocks_data",
54
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/stockdata.csv",
55
+ file_location_type="remote",
56
+ read_function_string="pd.read_csv",
57
+ column_names_types={
58
+ "Date": "str",
59
+ "IBM": "float",
60
+ "MSFT": "float",
61
+ "SBUX": "float",
62
+ "AAPL": "float",
63
+ "GSPC": "float",
64
+ },
65
+ )
66
+
67
+ GAPMINDER = DFMetaData(
68
+ file_name="gapminder_data",
69
+ file_path_or_url="https://raw.githubusercontent.com/plotly/datasets/master/gapminder_unfiltered.csv",
70
+ file_location_type="remote",
71
+ read_function_string="pd.read_csv",
72
+ column_names_types={
73
+ "country": "str",
74
+ "continent": "str",
75
+ "year": "int",
76
+ "lifeExp": "float",
77
+ "pop": "int",
78
+ "gdpPercap": "float",
79
+ },
80
+ )
81
+
82
+ SAMPLE_DASHBOARD_CONFIG = """
83
+ {
84
+ `config`: {
85
+ `pages`: [
86
+ {
87
+ `title`: `Iris Data Analysis`,
88
+ `controls`: [
89
+ {
90
+ `id`: `species_filter`,
91
+ `type`: `filter`,
92
+ `column`: `species`,
93
+ `targets`: [
94
+ `scatter_plot`
95
+ ],
96
+ `selector`: {
97
+ `type`: `dropdown`,
98
+ `multi`: true
99
+ }
100
+ }
101
+ ],
102
+ `components`: [
103
+ {
104
+ `id`: `scatter_plot`,
105
+ `type`: `graph`,
106
+ `title`: `Sepal Dimensions by Species`,
107
+ `figure`: {
108
+ `x`: `sepal_length`,
109
+ `y`: `sepal_width`,
110
+ `color`: `species`,
111
+ `_target_`: `scatter`,
112
+ `data_frame`: `iris_data`,
113
+ `hover_data`: [
114
+ `petal_length`,
115
+ `petal_width`
116
+ ]
117
+ }
118
+ }
119
+ ]
120
+ }
121
+ ],
122
+ `theme`: `vizro_dark`,
123
+ `title`: `Iris Dashboard`
124
+ },
125
+ `data_infos`: `
126
+ [
127
+ {
128
+ \"file_name\": \"iris_data\",
129
+ \"file_path_or_url\": \"https://raw.githubusercontent.com/plotly/datasets/master/iris-id.csv\",
130
+ \"file_location_type\": \"remote\",
131
+ \"read_function_string\": \"pd.read_csv\",
132
+ }
133
+ ]
134
+ `
135
+ }
136
+
137
+ """
138
+
139
+
140
+ def get_simple_dashboard_config() -> str:
141
+ """Very simple Vizro dashboard configuration. Use this config as a starter when no other config is provided."""
142
+ return SAMPLE_DASHBOARD_CONFIG