ai-data-science-team 0.0.0.9009__py3-none-any.whl → 0.0.0.9011__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +1 -0
  3. ai_data_science_team/agents/data_cleaning_agent.py +6 -6
  4. ai_data_science_team/agents/data_loader_tools_agent.py +272 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +6 -7
  6. ai_data_science_team/agents/data_wrangling_agent.py +6 -6
  7. ai_data_science_team/agents/feature_engineering_agent.py +6 -6
  8. ai_data_science_team/agents/sql_database_agent.py +6 -6
  9. ai_data_science_team/ml_agents/__init__.py +1 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +206 -385
  11. ai_data_science_team/ml_agents/h2o_ml_tools_agent.py +0 -0
  12. ai_data_science_team/ml_agents/mlflow_tools_agent.py +350 -0
  13. ai_data_science_team/multiagents/sql_data_analyst.py +3 -4
  14. ai_data_science_team/parsers/__init__.py +0 -0
  15. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  16. ai_data_science_team/templates/agent_templates.py +6 -6
  17. ai_data_science_team/tools/data_loader.py +448 -0
  18. ai_data_science_team/tools/dataframe.py +139 -0
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/{metadata.py → sql.py} +1 -137
  22. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/METADATA +40 -19
  23. ai_data_science_team-0.0.0.9011.dist-info/RECORD +36 -0
  24. ai_data_science_team-0.0.0.9009.dist-info/RECORD +0 -28
  25. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  26. /ai_data_science_team/{tools → utils}/regex.py +0 -0
  27. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9009.dist-info → ai_data_science_team-0.0.0.9011.dist-info}/top_level.txt +0 -0
ai_data_science_team/tools/mlflow.py (new file)
@@ -0,0 +1,961 @@
from typing import Optional, Dict, Any, Union, List, Annotated
from langgraph.prebuilt import InjectedState
from langchain.tools import tool


@tool(response_format='content_and_artifact')
def mlflow_search_experiments(
    filter_string: Optional[str] = None,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search and list existing MLflow experiments.

    Parameters
    ----------
    filter_string : str, optional
        Filter query string (e.g., "name = 'my_experiment'"). Defaults to
        searching all experiments.
    tracking_uri : str, optional
        Address of the local or remote tracking server. If not provided,
        defaults to the service set by mlflow.tracking.set_tracking_uri. See
        `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
        for more info.
    registry_uri : str, optional
        Address of the local or remote model registry server. If not provided,
        defaults to the service set by mlflow.tracking.set_registry_uri. If no
        such service was set, defaults to the tracking URI of the client.

    Returns
    -------
    tuple
        - Dictionary of experiment metadata (ID, name, timestamps, etc.)
          returned as the content.
        - The same dictionary returned as the artifact.
    """
    print(" * Tool: mlflow_search_experiments")
    from mlflow.tracking import MlflowClient
    import pandas as pd

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    experiments = client.search_experiments(filter_string=filter_string)
    # Convert each Experiment entity to a dictionary
    experiments_data = [
        dict(e)
        for e in experiments
    ]
    # Convert to a DataFrame
    experiments_df = pd.DataFrame(experiments_data)
    # Convert millisecond timestamps to datetime objects
    if not experiments_df.empty:
        experiments_df["last_update_time"] = pd.to_datetime(experiments_df["last_update_time"], unit="ms")
        experiments_df["creation_time"] = pd.to_datetime(experiments_df["creation_time"], unit="ms")

    return (experiments_df.to_dict(), experiments_df.to_dict())


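A minimal usage sketch, assuming a local tracking server at http://localhost:5000 and LangChain's Runnable `.invoke()` interface; with response_format='content_and_artifact', a plain invoke returns only the content. The filter value is a made-up experiment name prefix.

    # Sketch: direct invocation outside of an agent (tracking URI and filter are assumptions).
    content = mlflow_search_experiments.invoke({
        "filter_string": "name LIKE 'churn%'",
        "tracking_uri": "http://localhost:5000",
    })
    print(content)  # dict keyed by column, e.g. experiment_id, name, creation_time
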
@tool(response_format='content_and_artifact')
def mlflow_search_runs(
    experiment_ids: Optional[Union[List[str], List[int], str, int]] = None,
    filter_string: Optional[str] = None,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search runs within one or more MLflow experiments, optionally filtering by a filter_string.

    Parameters
    ----------
    experiment_ids : list or str or int, optional
        One or more experiment IDs.
    filter_string : str, optional
        MLflow filter expression, e.g. "metrics.rmse < 1.0".
    tracking_uri : str, optional
        Address of the local or remote tracking server. If not provided,
        defaults to the service set by mlflow.tracking.set_tracking_uri. See
        `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
        for more info.
    registry_uri : str, optional
        Address of the local or remote model registry server. If not provided,
        defaults to the service set by mlflow.tracking.set_registry_uri. If no
        such service was set, defaults to the tracking URI of the client.

    Returns
    -------
    tuple
        - Dictionary of the first 15 columns of the matching runs (content).
        - Dictionary of the full run table, including flattened metrics,
          parameters, and tags (artifact).
    """
    print(" * Tool: mlflow_search_runs")
    from mlflow.tracking import MlflowClient
    import pandas as pd

    client = MlflowClient(
        tracking_uri=tracking_uri,
        registry_uri=registry_uri
    )

    if experiment_ids is None:
        experiment_ids = []
    if isinstance(experiment_ids, (str, int)):
        experiment_ids = [experiment_ids]

    runs = client.search_runs(
        experiment_ids=experiment_ids,
        filter_string=filter_string
    )

    # If no runs are found, return a message and an empty DataFrame
    if not runs:
        return "No runs found.", pd.DataFrame()

    # Extract relevant information
    data = []
    for run in runs:
        run_info = {
            "run_id": run.info.run_id,
            "run_name": run.info.run_name,
            "status": run.info.status,
            "start_time": pd.to_datetime(run.info.start_time, unit="ms"),
            "end_time": pd.to_datetime(run.info.end_time, unit="ms"),
            "experiment_id": run.info.experiment_id,
            "user_id": run.info.user_id
        }

        # Flatten metrics, parameters, and tags
        run_info.update(run.data.metrics)
        run_info.update({f"param_{k}": v for k, v in run.data.params.items()})
        run_info.update({f"tag_{k}": v for k, v in run.data.tags.items()})

        data.append(run_info)

    # Convert to DataFrame
    df = pd.DataFrame(data)

    return (df.iloc[:, 0:15].to_dict(), df.to_dict())


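Filter strings follow MLflow's run-search syntax; a sketch in which the experiment ID, metric name, and status clause are all assumptions about what a particular tracking server contains.

    # Sketch: completed runs with low RMSE in experiment "3" (ID and metric are placeholders).
    content = mlflow_search_runs.invoke({
        "experiment_ids": "3",
        "filter_string": "metrics.rmse < 1.0 and attributes.status = 'FINISHED'",
    })
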
@tool(response_format='content')
def mlflow_create_experiment(experiment_name: str) -> str:
    """
    Create a new MLflow experiment by name.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment to create.

    Returns
    -------
    str
        Confirmation message containing the new experiment ID and name.
    """
    print(" * Tool: mlflow_create_experiment")
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    exp_id = client.create_experiment(experiment_name)
    return f"Experiment created with ID: {exp_id}, name: {experiment_name}"


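MLflow raises an exception if the experiment name already exists, so callers that want idempotent behavior can check first; a minimal sketch using mlflow.get_experiment_by_name with a placeholder name.

    # Sketch: "get or create" around the tool (experiment name is a placeholder).
    import mlflow

    name = "customer_churn"
    exp = mlflow.get_experiment_by_name(name)
    if exp is None:
        print(mlflow_create_experiment.invoke({"experiment_name": name}))
    else:
        print(f"Experiment already exists with ID: {exp.experiment_id}")
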
@tool(response_format='content_and_artifact')
def mlflow_predict_from_run_id(
    run_id: str,
    data_raw: Annotated[dict, InjectedState("data_raw")],
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    Predict using an MLflow model (PyFunc) directly from a given run ID.

    Parameters
    ----------
    run_id : str
        The ID of the MLflow run that logged the model.
    data_raw : dict
        The incoming data as a dictionary (injected from the agent state).
    tracking_uri : str, optional
        Address of local or remote tracking server.

    Returns
    -------
    tuple
        (user_facing_message, artifact_dict)
    """
    print(" * Tool: mlflow_predict_from_run_id")
    import mlflow
    import mlflow.pyfunc
    import pandas as pd

    # 1. Check if data is loaded
    if not data_raw:
        return "No data provided for prediction. Please use the `data_raw` parameter inside of `invoke_agent()` or `ainvoke_agent()`.", {}
    df = pd.DataFrame(data_raw)

    # 2. Point MLflow at the requested tracking server (if given) and prepare the model URI
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)
    model_uri = f"runs:/{run_id}/model"

    # 3. Load the MLflow model
    model = mlflow.pyfunc.load_model(model_uri)

    # 4. Make predictions
    try:
        preds = model.predict(df)
    except Exception as e:
        return f"Error during inference: {str(e)}", {}

    # 5. Convert predictions to a user-friendly summary + artifact
    if isinstance(preds, pd.DataFrame):
        sample_json = preds.head().to_json(orient='records')
        artifact_dict = preds.to_dict(orient='records')  # entire DataFrame
        message = f"Predictions returned. Sample: {sample_json}"
    elif hasattr(preds, "to_json"):
        # e.g., a pd.Series
        sample_json = preds[:5].to_json(orient='records')
        artifact_dict = preds.to_dict()
        message = f"Predictions returned. Sample: {sample_json}"
    elif hasattr(preds, "tolist"):
        # e.g., a NumPy array
        preds_list = preds.tolist()
        artifact_dict = {"predictions": preds_list}
        message = f"Predictions returned. First 5: {preds_list[:5]}"
    else:
        # fallback
        preds_str = str(preds)
        artifact_dict = {"predictions": preds_str}
        message = f"Predictions returned (unrecognized type). Example: {preds_str[:100]}..."

    return (message, artifact_dict)


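Because data_raw is annotated with InjectedState, this tool is normally executed inside the package's MLflow tools agent, which injects the data from graph state (per the invoke_agent()/ainvoke_agent() note in the docstring). For ad-hoc testing, the wrapped function can be called through the tool's .func attribute with the data passed by hand; the run ID, data, and URI below are placeholders.

    # Sketch: call the underlying function directly, supplying the injected argument manually.
    import pandas as pd

    df = pd.DataFrame({"age": [42, 31], "balance": [1200.5, 87.0]})  # made-up data
    message, artifact = mlflow_predict_from_run_id.func(
        run_id="abc123",                       # placeholder run ID
        data_raw=df.to_dict(),
        tracking_uri="http://localhost:5000",  # assumption: local tracking server
    )
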
# MLflow tool to launch the MLflow UI (GUI)
@tool(response_format='content')
def mlflow_launch_ui(
    port: int = 5000,
    host: str = "localhost",
    tracking_uri: Optional[str] = None
) -> str:
    """
    Launch the MLflow UI.

    Parameters
    ----------
    port : int, optional
        Preferred port for the UI; if it is busy, the next free port is used.
    host : str, optional
        The host address to bind the UI to.
    tracking_uri : str, optional
        Address of local or remote tracking server (passed to the UI as its
        backend store URI).

    Returns
    -------
    str
        Confirmation message.
    """
    print(" * Tool: mlflow_launch_ui")
    import subprocess

    # Try binding to the user-specified port first
    allocated_port = _find_free_port(start_port=port, host=host)

    cmd = ["mlflow", "ui", "--host", host, "--port", str(allocated_port)]
    if tracking_uri:
        cmd.extend(["--backend-store-uri", tracking_uri])

    process = subprocess.Popen(cmd)
    return (f"MLflow UI launched at http://{host}:{allocated_port}. "
            f"(PID: {process.pid})")


def _find_free_port(start_port: int, host: str) -> int:
    """
    Find a free port >= start_port on the specified host.
    If start_port is free, return it; otherwise try subsequent ports.
    """
    import socket
    for port_candidate in range(start_port, start_port + 1000):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port_candidate))
            except OSError:
                # Port is in use, try the next one
                continue
            # If bind succeeds, the port is free
            return port_candidate

    raise OSError("No available ports found in the range "
                  f"{start_port}-{start_port + 999}")


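A sketch of launching against a file-based backend store (the SQLite URI is an assumption); if port 5000 is taken, _find_free_port() silently moves to the next open port, so the actual address has to be read back from the returned message.

    # Sketch: launch the UI against a local SQLite store (path is a placeholder).
    msg = mlflow_launch_ui.invoke({
        "port": 5000,
        "tracking_uri": "sqlite:///mlflow.db",
    })
    print(msg)  # e.g. "MLflow UI launched at http://localhost:5001. (PID: ...)" if 5000 was busy
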
@tool(response_format='content')
def mlflow_stop_ui(port: int = 5000) -> str:
    """
    Kill any process currently listening on the given MLflow UI port.
    Requires `pip install psutil`.

    Parameters
    ----------
    port : int, optional
        The port on which the UI is running.

    Returns
    -------
    str
        Message describing whether a process was found and killed.
    """
    print(" * Tool: mlflow_stop_ui")
    import psutil

    # Gather system-wide inet connections
    for conn in psutil.net_connections(kind="inet"):
        # Check if this connection has a local address (laddr) and if
        # the port matches the one we're trying to free
        if conn.laddr and conn.laddr.port == port:
            # Some connections may not have an associated PID
            if conn.pid is not None:
                try:
                    p = psutil.Process(conn.pid)
                    p_name = p.name()  # optional: get process name for clarity
                    p.kill()           # forcibly terminate the process
                    return (
                        f"Killed process {conn.pid} ({p_name}) listening on port {port}."
                    )
                except psutil.NoSuchProcess:
                    return (
                        "Process was already terminated before we could kill it."
                    )
    return f"No process found listening on port {port}."


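Stopping is purely port-based, so it pairs with whatever port the launch message reported; a sketch with an assumed port, plus an optional socket check that the port is free again.

    # Sketch: shut down a UI previously started on port 5001 (value taken from the launch message).
    print(mlflow_stop_ui.invoke({"port": 5001}))

    # Optionally confirm the port is free again before relaunching.
    import socket
    with socket.socket() as s:
        print("port free:", s.connect_ex(("localhost", 5001)) != 0)
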
@tool(response_format='content_and_artifact')
def mlflow_list_artifacts(
    run_id: str,
    path: Optional[str] = None,
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    List artifacts under a given MLflow run.

    Parameters
    ----------
    run_id : str
        The ID of the run whose artifacts to list.
    path : str, optional
        Path within the run's artifact directory to list. Defaults to the root.
    tracking_uri : str, optional
        Custom tracking server URI.

    Returns
    -------
    tuple
        (summary_message, artifact_listing)
    """
    print(" * Tool: mlflow_list_artifacts")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri)
    # If path is None, list the root folder
    artifact_list = client.list_artifacts(run_id, path or "")

    # Convert to a more user-friendly structure
    artifacts_data = []
    for artifact in artifact_list:
        artifacts_data.append({
            "path": artifact.path,
            "is_dir": artifact.is_dir,
            "file_size": artifact.file_size
        })

    return (
        f"Found {len(artifacts_data)} artifacts.",
        artifacts_data
    )


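To get the structured listing rather than just the summary string, the tool can be invoked as a tool call; with recent langchain-core versions this returns a ToolMessage whose .artifact carries the list. The run ID and call ID below are placeholders.

    # Sketch: ToolCall-style invocation to retrieve the artifact.
    result = mlflow_list_artifacts.invoke({
        "name": "mlflow_list_artifacts",
        "args": {"run_id": "abc123"},
        "id": "call_1",
        "type": "tool_call",
    })
    for item in result.artifact:
        print(item["path"], "(dir)" if item["is_dir"] else f'{item["file_size"]} bytes')
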
@tool(response_format='content_and_artifact')
def mlflow_download_artifacts(
    run_id: str,
    path: Optional[str] = None,
    dst_path: Optional[str] = "./downloaded_artifacts",
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    Download artifacts from MLflow to a local directory.

    Parameters
    ----------
    run_id : str
        The ID of the run whose artifacts to download.
    path : str, optional
        Path within the run's artifact directory to download. Defaults to the root.
    dst_path : str, optional
        Local destination path to store artifacts.
    tracking_uri : str, optional
        MLflow tracking server URI.

    Returns
    -------
    tuple
        (summary_message, artifact_dict)
    """
    print(" * Tool: mlflow_download_artifacts")
    from mlflow.tracking import MlflowClient
    import os

    client = MlflowClient(tracking_uri=tracking_uri)
    local_path = client.download_artifacts(run_id, path or "", dst_path)

    # Build a recursive listing of what was downloaded
    downloaded_files = []
    for root, dirs, files in os.walk(local_path):
        for f in files:
            downloaded_files.append(os.path.join(root, f))

    message = (
        f"Artifacts for run_id='{run_id}' have been downloaded to: {local_path}. "
        f"Total files: {len(downloaded_files)}."
    )

    return (
        message,
        {"downloaded_files": downloaded_files}
    )


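Passing path narrows the download to a single artifact subfolder; a sketch with a placeholder run ID and an assumed "model" subfolder logged by the training run.

    # Sketch: pull only the "model" subfolder of a run into a scratch directory.
    msg = mlflow_download_artifacts.invoke({
        "run_id": "abc123",     # placeholder
        "path": "model",        # assumed subfolder name
        "dst_path": "./tmp_model",
    })
    print(msg)
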
@tool(response_format='content_and_artifact')
def mlflow_list_registered_models(
    max_results: int = 100,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    List all registered models in MLflow's model registry.

    Parameters
    ----------
    max_results : int, optional
        Maximum number of models to return.
    tracking_uri : str, optional
    registry_uri : str, optional

    Returns
    -------
    tuple
        (summary_message, model_list)
    """
    print(" * Tool: mlflow_list_registered_models")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    # search_registered_models() with no filter lists every model; results are
    # paginated, so for simplicity we just pass max_results. (The older
    # list_registered_models() API was removed in MLflow 2.x.)
    models = client.search_registered_models(max_results=max_results)

    models_data = []
    for m in models:
        models_data.append({
            "name": m.name,
            "latest_versions": [
                {
                    "version": v.version,
                    "run_id": v.run_id,
                    "current_stage": v.current_stage,
                }
                for v in m.latest_versions
            ]
        })

    return (
        f"Found {len(models_data)} registered models.",
        models_data
    )


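A sketch listing a capped number of models against an assumed remote registry URI; a plain invoke returns only the summary string, while the full model list travels as the artifact.

    # Sketch: list up to 10 registered models (registry URI is an assumption).
    msg = mlflow_list_registered_models.invoke({
        "max_results": 10,
        "registry_uri": "http://mlflow.internal:5000",
    })
    print(msg)  # e.g. "Found 4 registered models."
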
@tool(response_format='content_and_artifact')
def mlflow_search_registered_models(
    filter_string: Optional[str] = None,
    order_by: Optional[List[str]] = None,
    max_results: int = 100,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search registered models in MLflow's registry using optional filters.

    Parameters
    ----------
    filter_string : str, optional
        e.g. "name LIKE 'my_model%'" or "tags.stage = 'production'".
    order_by : list, optional
        e.g. ["name ASC"] or ["timestamp DESC"].
    max_results : int, optional
        Max number of results.
    tracking_uri : str, optional
    registry_uri : str, optional

    Returns
    -------
    tuple
        (summary_message, model_dict_list)
    """
    print(" * Tool: mlflow_search_registered_models")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    models = client.search_registered_models(
        filter_string=filter_string,
        order_by=order_by,
        max_results=max_results
    )

    models_data = []
    for m in models:
        models_data.append({
            "name": m.name,
            "description": m.description,
            "creation_timestamp": m.creation_timestamp,
            "last_updated_timestamp": m.last_updated_timestamp,
            "latest_versions": [
                {
                    "version": v.version,
                    "run_id": v.run_id,
                    "current_stage": v.current_stage
                }
                for v in m.latest_versions
            ]
        })

    return (
        f"Found {len(models_data)} models matching filter={filter_string}.",
        models_data
    )


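Registry filters support name and tag predicates; a sketch with a made-up tag key, ordering by last_updated_timestamp (which MLflow's registry search accepts, to the best of my knowledge).

    # Sketch: registry searches by name prefix and by tag (tag key is assumed).
    mlflow_search_registered_models.invoke({"filter_string": "name LIKE 'churn%'"})
    mlflow_search_registered_models.invoke({
        "filter_string": "tags.stage = 'production'",
        "order_by": ["last_updated_timestamp DESC"],
    })
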
@tool(response_format='content_and_artifact')
def mlflow_get_model_version_details(
    name: str,
    version: str,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Retrieve details about a specific model version in the MLflow registry.

    Parameters
    ----------
    name : str
        Name of the registered model.
    version : str
        Version number of that model.
    tracking_uri : str, optional
    registry_uri : str, optional

    Returns
    -------
    tuple
        (summary_message, version_data_dict)
    """
    print(" * Tool: mlflow_get_model_version_details")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    version_details = client.get_model_version(name, version)

    data = {
        "name": version_details.name,
        "version": version_details.version,
        "run_id": version_details.run_id,
        "creation_timestamp": version_details.creation_timestamp,
        "current_stage": version_details.current_stage,
        "description": version_details.description,
        "status": version_details.status
    }

    return (
        f"Model version details retrieved for {name} v{version}",
        data
    )


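The version number usually comes from one of the search tools above; a sketch that chains the registry search with the detail lookup, using the tool's .func attribute to get the (message, artifact) tuple directly. The model name is a placeholder.

    # Sketch: look up the newest version reported by the registry search, then fetch its details.
    msg, models = mlflow_search_registered_models.func(filter_string="name = 'churn_classifier'")
    if models:
        latest = models[0]["latest_versions"][0]
        print(mlflow_get_model_version_details.invoke({
            "name": "churn_classifier",          # placeholder model name
            "version": str(latest["version"]),
        }))
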
# @tool
# def get_or_create_experiment(experiment_name):
#     """
#     Retrieve the ID of an existing MLflow experiment or create a new one if it doesn't exist.
#
#     This function checks if an experiment with the given name exists within MLflow.
#     If it does, the function returns its ID. If not, it creates a new experiment
#     with the provided name and returns its ID.
#
#     Parameters:
#     - experiment_name (str): Name of the MLflow experiment.
#
#     Returns:
#     - str: ID of the existing or newly created MLflow experiment.
#     """
#     import mlflow
#     if experiment := mlflow.get_experiment_by_name(experiment_name):
#         return experiment.experiment_id
#     else:
#         return mlflow.create_experiment(experiment_name)


# @tool("mlflow_set_tracking_uri", return_direct=True)
# def mlflow_set_tracking_uri(tracking_uri: str) -> str:
#     """
#     Set or change the MLflow tracking URI.
#
#     Parameters
#     ----------
#     tracking_uri : str
#         The URI/path where MLflow logs & metrics are stored.
#
#     Returns
#     -------
#     str
#         Confirmation message.
#     """
#     import mlflow
#     mlflow.set_tracking_uri(tracking_uri)
#     return f"MLflow tracking URI set to: {tracking_uri}"


# @tool("mlflow_list_experiments", return_direct=True)
# def mlflow_list_experiments() -> str:
#     """
#     List existing MLflow experiments.
#
#     Returns
#     -------
#     str
#         JSON-serialized list of experiment metadata (ID, name, etc.).
#     """
#     from mlflow.tracking import MlflowClient
#     import json
#
#     client = MlflowClient()
#     experiments = client.list_experiments()
#     # Convert to a JSON-like structure
#     experiments_data = [
#         dict(experiment_id=e.experiment_id, name=e.name, artifact_location=e.artifact_location)
#         for e in experiments
#     ]
#
#     return json.dumps(experiments_data)


# @tool("mlflow_create_experiment", return_direct=True)
# def mlflow_create_experiment(experiment_name: str) -> str:
#     """
#     Create a new MLflow experiment by name.
#
#     Parameters
#     ----------
#     experiment_name : str
#         The name of the experiment to create.
#
#     Returns
#     -------
#     str
#         The experiment ID or an error message if creation failed.
#     """
#     from mlflow.tracking import MlflowClient
#
#     client = MlflowClient()
#     exp_id = client.create_experiment(experiment_name)
#     return f"Experiment created with ID: {exp_id}"


# @tool("mlflow_set_experiment", return_direct=True)
# def mlflow_set_experiment(experiment_name: str) -> str:
#     """
#     Set or create an MLflow experiment for subsequent logging.
#
#     Parameters
#     ----------
#     experiment_name : str
#         The name of the experiment to set.
#
#     Returns
#     -------
#     str
#         Confirmation of the chosen experiment name.
#     """
#     import mlflow
#     mlflow.set_experiment(experiment_name)
#     return f"Active MLflow experiment set to: {experiment_name}"


# @tool("mlflow_start_run", return_direct=True)
# def mlflow_start_run(run_name: Optional[str] = None) -> str:
#     """
#     Start a new MLflow run under the current experiment.
#
#     Parameters
#     ----------
#     run_name : str, optional
#         Optional run name.
#
#     Returns
#     -------
#     str
#         The run_id of the newly started MLflow run.
#     """
#     import mlflow
#     with mlflow.start_run(run_name=run_name) as run:
#         run_id = run.info.run_id
#     return f"MLflow run started with run_id: {run_id}"


# @tool("mlflow_log_params", return_direct=True)
# def mlflow_log_params(params: Dict[str, Any]) -> str:
#     """
#     Log a batch of parameters to the current MLflow run.
#
#     Parameters
#     ----------
#     params : dict
#         A dictionary of parameter name -> parameter value.
#
#     Returns
#     -------
#     str
#         Confirmation message.
#     """
#     import mlflow
#     mlflow.log_params(params)
#     return f"Logged parameters: {params}"


# @tool("mlflow_log_metrics", return_direct=True)
# def mlflow_log_metrics(metrics: Dict[str, float], step: Optional[int] = None) -> str:
#     """
#     Log a dictionary of metrics to the current MLflow run.
#
#     Parameters
#     ----------
#     metrics : dict
#         Metric name -> numeric value.
#     step : int, optional
#         The training step or iteration number.
#
#     Returns
#     -------
#     str
#         Confirmation message.
#     """
#     import mlflow
#     mlflow.log_metrics(metrics, step=step)
#     return f"Logged metrics: {metrics} at step {step}"


# @tool("mlflow_log_artifact", return_direct=True)
# def mlflow_log_artifact(artifact_path: str, artifact_folder_name: Optional[str] = None) -> str:
#     """
#     Log a local file or directory as an MLflow artifact.
#
#     Parameters
#     ----------
#     artifact_path : str
#         The local path to the file/directory to be logged.
#     artifact_folder_name : str, optional
#         Subfolder within the run's artifact directory.
#
#     Returns
#     -------
#     str
#         Confirmation message.
#     """
#     import mlflow
#     if artifact_folder_name:
#         mlflow.log_artifact(artifact_path, artifact_folder_name)
#         return f"Artifact logged from {artifact_path} into folder '{artifact_folder_name}'"
#     else:
#         mlflow.log_artifact(artifact_path)
#         return f"Artifact logged from {artifact_path}"


# @tool("mlflow_log_model", return_direct=True)
# def mlflow_log_model(model_path: str, registered_model_name: Optional[str] = None) -> str:
#     """
#     Log a model artifact (e.g., an H2O-saved model directory) to MLflow.
#
#     Parameters
#     ----------
#     model_path : str
#         The local filesystem path containing the model artifacts.
#     registered_model_name : str, optional
#         If provided, will also attempt to register the model under this name.
#
#     Returns
#     -------
#     str
#         Confirmation message with any relevant registration details.
#     """
#     import mlflow
#     if registered_model_name:
#         mlflow.pyfunc.log_model(
#             artifact_path="model",
#             python_model=None,  # if you have a pyfunc wrapper, specify it
#             registered_model_name=registered_model_name,
#             code_path=None,
#             conda_env=None,
#             model_path=model_path  # for certain model flavors, or use flavors
#         )
#         return f"Model logged and registered under '{registered_model_name}' from path {model_path}"
#     else:
#         # Simple log as generic artifact
#         mlflow.pyfunc.log_model(
#             artifact_path="model",
#             python_model=None,
#             code_path=None,
#             conda_env=None,
#             model_path=model_path
#         )
#         return f"Model logged (no registration) from path {model_path}"


# @tool("mlflow_end_run", return_direct=True)
# def mlflow_end_run() -> str:
#     """
#     End the current MLflow run (if one is active).
#
#     Returns
#     -------
#     str
#         Confirmation message.
#     """
#     import mlflow
#     mlflow.end_run()
#     return "MLflow run ended."


# @tool("mlflow_search_runs", return_direct=True)
# def mlflow_search_runs(
#     experiment_names_or_ids: Optional[Union[List[str], List[int], str, int]] = None,
#     filter_string: Optional[str] = None
# ) -> str:
#     """
#     Search runs within one or more MLflow experiments, optionally filtering by a filter_string.
#
#     Parameters
#     ----------
#     experiment_names_or_ids : list or str or int, optional
#         Experiment IDs or names.
#     filter_string : str, optional
#         MLflow filter expression, e.g. "metrics.rmse < 1.0".
#
#     Returns
#     -------
#     str
#         JSON-formatted list of runs that match the query.
#     """
#     import mlflow
#     import json
#     if experiment_names_or_ids is None:
#         experiment_names_or_ids = []
#     if isinstance(experiment_names_or_ids, (str, int)):
#         experiment_names_or_ids = [experiment_names_or_ids]
#
#     df = mlflow.search_runs(
#         experiment_names=experiment_names_or_ids if all(isinstance(e, str) for e in experiment_names_or_ids) else None,
#         experiment_ids=experiment_names_or_ids if all(isinstance(e, int) for e in experiment_names_or_ids) else None,
#         filter_string=filter_string
#     )
#     return df.to_json(orient="records")


# @tool("mlflow_get_run", return_direct=True)
# def mlflow_get_run(run_id: str) -> str:
#     """
#     Retrieve details (params, metrics, etc.) for a specific MLflow run by ID.
#
#     Parameters
#     ----------
#     run_id : str
#         The ID of the MLflow run to retrieve.
#
#     Returns
#     -------
#     str
#         JSON-formatted data containing run info, params, and metrics.
#     """
#     from mlflow.tracking import MlflowClient
#     import json
#
#     client = MlflowClient()
#     run = client.get_run(run_id)
#     data = {
#         "run_id": run.info.run_id,
#         "experiment_id": run.info.experiment_id,
#         "status": run.info.status,
#         "start_time": run.info.start_time,
#         "end_time": run.info.end_time,
#         "artifact_uri": run.info.artifact_uri,
#         "params": run.data.params,
#         "metrics": run.data.metrics,
#         "tags": run.data.tags
#     }
#     return json.dumps(data)


# @tool("mlflow_load_model", return_direct=True)
# def mlflow_load_model(model_uri: str) -> str:
#     """
#     Load an MLflow-model (PyFunc flavor or other) into memory, returning a handle reference.
#     For demonstration, we store the loaded model globally in a registry dict.
#
#     Parameters
#     ----------
#     model_uri : str
#         The URI of the model to load, e.g. "runs:/<RUN_ID>/model" or "models:/MyModel/Production".
#
#     Returns
#     -------
#     str
#         A reference key identifying the loaded model (for subsequent predictions),
#         or a direct message if you prefer to store it differently.
#     """
#     import mlflow.pyfunc
#     from uuid import uuid4
#
#     # For demonstration, create a global registry:
#     global _LOADED_MODELS
#     if "_LOADED_MODELS" not in globals():
#         _LOADED_MODELS = {}
#
#     loaded_model = mlflow.pyfunc.load_model(model_uri)
#     model_key = f"model_{uuid4().hex}"
#     _LOADED_MODELS[model_key] = loaded_model
#
#     return f"Model loaded with reference key: {model_key}"


# @tool("mlflow_predict", return_direct=True)
# def mlflow_predict(model_key: str, data: List[Dict[str, Any]]) -> str:
#     """
#     Predict using a previously loaded MLflow model (PyFunc), identified by its reference key.
#
#     Parameters
#     ----------
#     model_key : str
#         The reference key for the loaded model (returned by mlflow_load_model).
#     data : List[Dict[str, Any]]
#         The data rows for which predictions should be made.
#
#     Returns
#     -------
#     str
#         JSON-formatted prediction results.
#     """
#     import pandas as pd
#     import json
#
#     global _LOADED_MODELS
#     if model_key not in _LOADED_MODELS:
#         return f"No model found for key: {model_key}"
#
#     model = _LOADED_MODELS[model_key]
#     df = pd.DataFrame(data)
#     preds = model.predict(df)
#     # Convert to JSON (DataFrame or Series)
#     if hasattr(preds, "to_json"):
#         return preds.to_json(orient="records")
#     else:
#         # If preds is just a numpy array or list
#         return json.dumps(preds.tolist())