ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +0 -1
  3. ai_data_science_team/agents/data_cleaning_agent.py +50 -39
  4. ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +45 -50
  6. ai_data_science_team/agents/data_wrangling_agent.py +50 -49
  7. ai_data_science_team/agents/feature_engineering_agent.py +48 -67
  8. ai_data_science_team/agents/sql_database_agent.py +130 -76
  9. ai_data_science_team/ml_agents/__init__.py +2 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
  11. ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
  12. ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
  13. ai_data_science_team/parsers/__init__.py +0 -0
  14. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  15. ai_data_science_team/templates/__init__.py +1 -0
  16. ai_data_science_team/templates/agent_templates.py +78 -7
  17. ai_data_science_team/tools/data_loader.py +378 -0
  18. ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/sql.py +126 -0
  22. ai_data_science_team/{tools → utils}/regex.py +59 -1
  23. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
  24. ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
  25. ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
  26. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  27. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,961 @@
1
+
2
+
3
+ from typing import Optional, Dict, Any, Union, List, Annotated
4
+ from langgraph.prebuilt import InjectedState
5
+ from langchain.tools import tool
6
+
7
+
8
@tool(response_format='content_and_artifact')
def mlflow_search_experiments(
    filter_string: Optional[str] = None,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search and list existing MLflow experiments.

    Parameters
    ----------
    filter_string : str, optional
        Filter query string (e.g., "name = 'my_experiment'"), defaults to
        searching for all experiments.
    tracking_uri : str, optional
        Address of local or remote tracking server. If not provided, defaults
        to the service set by mlflow.tracking.set_tracking_uri. See
        `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
        for more info.
    registry_uri : str, optional
        Address of local or remote model registry server. If not provided,
        defaults to the service set by mlflow.tracking.set_registry_uri. If no
        such service was set, defaults to the tracking uri of the client.

    Returns
    -------
    tuple
        - Dict of experiment metadata (ID, name, etc.) shown to the agent.
        - Dict of the same experiment metadata kept as the artifact.
    """
    print(" * Tool: mlflow_search_experiments")
    from mlflow.tracking import MlflowClient
    import pandas as pd

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    experiments = client.search_experiments(filter_string=filter_string)

    # Convert each Experiment entity to a plain dict, then to a DataFrame.
    experiments_df = pd.DataFrame([dict(e) for e in experiments])

    # Fix: the original indexed the timestamp columns unconditionally, which
    # raises KeyError when the search matches nothing (empty DataFrame has
    # no columns at all).
    if experiments_df.empty:
        return ("No experiments found.", {})

    # Convert epoch-millisecond timestamps to datetime objects.
    for col in ("last_update_time", "creation_time"):
        if col in experiments_df.columns:
            experiments_df[col] = pd.to_datetime(experiments_df[col], unit="ms")

    return (experiments_df.to_dict(), experiments_df.to_dict())
56
+
57
+
58
@tool(response_format='content_and_artifact')
def mlflow_search_runs(
    experiment_ids: Optional[Union[List[str], List[int], str, int]] = None,
    filter_string: Optional[str] = None,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search runs within one or more MLflow experiments, optionally filtering
    by a filter_string.

    Parameters
    ----------
    experiment_ids : list or str or int, optional
        One or more Experiment IDs.
    filter_string : str, optional
        MLflow filter expression, e.g. "metrics.rmse < 1.0".
    tracking_uri : str, optional
        Address of local or remote tracking server. If not provided, defaults
        to the service set by mlflow.tracking.set_tracking_uri. See
        `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
        for more info.
    registry_uri : str, optional
        Address of local or remote model registry server. If not provided,
        defaults to the service set by mlflow.tracking.set_registry_uri. If no
        such service was set, defaults to the tracking uri of the client.

    Returns
    -------
    tuple
        - Dict of the first 15 columns of run metadata (shown to the agent).
        - Dict of the full run metadata table (kept as the artifact).
    """
    print(" * Tool: mlflow_search_runs")
    from mlflow.tracking import MlflowClient
    import pandas as pd

    client = MlflowClient(
        tracking_uri=tracking_uri,
        registry_uri=registry_uri
    )

    # Normalize experiment_ids to the list form search_runs() expects.
    if experiment_ids is None:
        experiment_ids = []
    if isinstance(experiment_ids, (str, int)):
        experiment_ids = [experiment_ids]

    runs = client.search_runs(
        experiment_ids=experiment_ids,
        filter_string=filter_string
    )

    # Fix: keep the artifact type consistent with the success path (a dict)
    # instead of returning a raw empty DataFrame here.
    if not runs:
        return ("No runs found.", {})

    # Extract relevant information from each run.
    data = []
    for run in runs:
        run_info = {
            "run_id": run.info.run_id,
            "run_name": run.info.run_name,
            "status": run.info.status,
            "start_time": pd.to_datetime(run.info.start_time, unit="ms"),
            "end_time": pd.to_datetime(run.info.end_time, unit="ms"),
            "experiment_id": run.info.experiment_id,
            "user_id": run.info.user_id
        }

        # Flatten metrics, parameters, and tags into prefixed columns.
        run_info.update(run.data.metrics)
        run_info.update({f"param_{k}": v for k, v in run.data.params.items()})
        run_info.update({f"tag_{k}": v for k, v in run.data.tags.items()})

        data.append(run_info)

    df = pd.DataFrame(data)

    # Content is truncated to the first 15 columns to keep the LLM payload
    # small; the artifact preserves the full table.
    return (df.iloc[:, 0:15].to_dict(), df.to_dict())
135
+
136
+
137
+
138
@tool(response_format='content')
def mlflow_create_experiment(
    experiment_name: str,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> str:
    """
    Create a new MLflow experiment by name.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment to create.
    tracking_uri : str, optional
        Address of local or remote tracking server. Added with a default of
        None for consistency with the other MLflow tools in this module;
        omitting it preserves the original behavior (default client).
    registry_uri : str, optional
        Address of local or remote model registry server.

    Returns
    -------
    str
        The experiment ID or an error message if creation failed.
    """
    print(" * Tool: mlflow_create_experiment")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    exp_id = client.create_experiment(experiment_name)
    return f"Experiment created with ID: {exp_id}, name: {experiment_name}"
159
+
160
+
161
+
162
+
163
@tool(response_format='content_and_artifact')
def mlflow_predict_from_run_id(
    run_id: str,
    data_raw: Annotated[dict, InjectedState("data_raw")],
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    Predict using an MLflow model (PyFunc) directly from a given run ID.

    Parameters
    ----------
    run_id : str
        The ID of the MLflow run that logged the model.
    data_raw : dict
        The incoming data as a dictionary (injected from agent state).
    tracking_uri : str, optional
        Address of local or remote tracking server.

    Returns
    -------
    tuple
        (user_facing_message, artifact_dict)
    """
    print(" * Tool: mlflow_predict_from_run_id")
    import mlflow
    import mlflow.pyfunc
    import pandas as pd

    # Fix: the original accepted tracking_uri but never used it, so the
    # model was always loaded from the default tracking server.
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)

    # 1. Check if data is loaded
    if not data_raw:
        return "No data provided for prediction. Please use `data_raw` parameter inside of `invoke_agent()` or `ainvoke_agent()`.", {}
    df = pd.DataFrame(data_raw)

    # 2. Prepare model URI
    model_uri = f"runs:/{run_id}/model"

    # 3. Load the MLflow model; report failures as a message instead of
    #    raising, matching the error style of the inference step below.
    try:
        model = mlflow.pyfunc.load_model(model_uri)
    except Exception as e:
        return f"Error loading model for run_id='{run_id}': {str(e)}", {}

    # 4. Make predictions
    try:
        preds = model.predict(df)
    except Exception as e:
        return f"Error during inference: {str(e)}", {}

    # 5. Convert predictions to a user-friendly summary + artifact
    if isinstance(preds, pd.DataFrame):
        sample_json = preds.head().to_json(orient='records')
        artifact_dict = preds.to_dict(orient='records')  # entire DF
        message = f"Predictions returned. Sample: {sample_json}"
    elif hasattr(preds, "to_json"):
        # e.g., pd.Series
        sample_json = preds[:5].to_json(orient='records')
        artifact_dict = preds.to_dict()
        message = f"Predictions returned. Sample: {sample_json}"
    elif hasattr(preds, "tolist"):
        # e.g., a NumPy array
        preds_list = preds.tolist()
        artifact_dict = {"predictions": preds_list}
        message = f"Predictions returned. First 5: {preds_list[:5]}"
    else:
        # Fallback for unrecognized prediction containers.
        preds_str = str(preds)
        artifact_dict = {"predictions": preds_str}
        message = f"Predictions returned (unrecognized type). Example: {preds_str[:100]}..."

    return (message, artifact_dict)
230
+
231
+
232
+ # MLflow tool to launch gui for mlflow
233
@tool(response_format='content')
def mlflow_launch_ui(
    port: int = 5000,
    host: str = "localhost",
    tracking_uri: Optional[str] = None
) -> str:
    """
    Launch the MLflow UI as a background process.

    Parameters
    ----------
    port : int, optional
        Preferred port for the UI; the next free port is used when taken.
    host : str, optional
        The host address to bind the UI to.
    tracking_uri : str, optional
        Address of local or remote tracking server.

    Returns
    -------
    str
        Confirmation message with the UI address and process ID.
    """
    print(" * Tool: mlflow_launch_ui")
    import subprocess

    # Prefer the user-specified port, falling back to the next free one.
    allocated_port = _find_free_port(start_port=port, host=host)

    command = ["mlflow", "ui", "--host", host, "--port", str(allocated_port)]
    if tracking_uri:
        command += ["--backend-store-uri", tracking_uri]

    proc = subprocess.Popen(command)
    return (f"MLflow UI launched at http://{host}:{allocated_port}. "
            f"(PID: {proc.pid})")
269
+
270
+ def _find_free_port(start_port: int, host: str) -> int:
271
+ """
272
+ Find a free port >= start_port on the specified host.
273
+ If the start_port is free, returns start_port, else tries subsequent ports.
274
+ """
275
+ import socket
276
+ for port_candidate in range(start_port, start_port + 1000):
277
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
278
+ try:
279
+ sock.bind((host, port_candidate))
280
+ except OSError:
281
+ # Port is in use, try the next one
282
+ continue
283
+ # If bind succeeds, it's free
284
+ return port_candidate
285
+
286
+ raise OSError("No available ports found in the range "
287
+ f"{start_port}-{start_port + 999}")
288
+
289
+
290
@tool(response_format='content')
def mlflow_stop_ui(port: int = 5000) -> str:
    """
    Kill any process currently listening on the given MLflow UI port.
    Requires `pip install psutil`.

    Parameters
    ----------
    port : int, optional
        The port on which the UI is running.
    """
    print(" * Tool: mlflow_stop_ui")
    import psutil

    # Scan all system-wide inet connections for a listener on `port`.
    for connection in psutil.net_connections(kind="inet"):
        # Skip connections without a local address or on other ports.
        if not connection.laddr or connection.laddr.port != port:
            continue
        # Some connections carry no associated PID; nothing to kill there.
        if connection.pid is None:
            continue
        try:
            proc = psutil.Process(connection.pid)
            proc_name = proc.name()  # capture the name for the message
            proc.kill()              # forcibly terminate the process
        except psutil.NoSuchProcess:
            return (
                "Process was already terminated before we could kill it."
            )
        return (
            f"Killed process {connection.pid} ({proc_name}) listening on port {port}."
        )
    return f"No process found listening on port {port}."
323
+
324
+
325
@tool(response_format='content_and_artifact')
def mlflow_list_artifacts(
    run_id: str,
    path: Optional[str] = None,
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    List artifacts under a given MLflow run.

    Parameters
    ----------
    run_id : str
        The ID of the run whose artifacts to list.
    path : str, optional
        Path within the run's artifact directory to list. Defaults to the root.
    tracking_uri : str, optional
        Custom tracking server URI.

    Returns
    -------
    tuple
        (summary_message, artifact_listing)
    """
    print(" * Tool: mlflow_list_artifacts")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri)
    # An empty string means the root of the run's artifact directory.
    entries = client.list_artifacts(run_id, path or "")

    # Flatten each FileInfo entity into a plain dict for the artifact.
    artifacts_data = [
        {
            "path": entry.path,
            "is_dir": entry.is_dir,
            "file_size": entry.file_size,
        }
        for entry in entries
    ]

    return (
        f"Found {len(artifacts_data)} artifacts.",
        artifacts_data
    )
368
+
369
+
370
@tool(response_format='content_and_artifact')
def mlflow_download_artifacts(
    run_id: str,
    path: Optional[str] = None,
    dst_path: Optional[str] = "./downloaded_artifacts",
    tracking_uri: Optional[str] = None
) -> tuple:
    """
    Download artifacts from MLflow to a local directory.

    Parameters
    ----------
    run_id : str
        The ID of the run whose artifacts to download.
    path : str, optional
        Path within the run's artifact directory to download. Defaults to the root.
    dst_path : str, optional
        Local destination path to store artifacts.
    tracking_uri : str, optional
        MLflow tracking server URI.

    Returns
    -------
    tuple
        (summary_message, artifact_dict)
    """
    print(" * Tool: mlflow_download_artifacts")
    from mlflow.tracking import MlflowClient
    import os

    client = MlflowClient(tracking_uri=tracking_uri)
    local_path = client.download_artifacts(run_id, path or "", dst_path)

    # Walk the destination recursively to report every downloaded file.
    downloaded_files = [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(local_path)
        for filename in files
    ]

    message = (
        f"Artifacts for run_id='{run_id}' have been downloaded to: {local_path}. "
        f"Total files: {len(downloaded_files)}."
    )

    return (
        message,
        {"downloaded_files": downloaded_files}
    )
418
+
419
+
420
@tool(response_format='content_and_artifact')
def mlflow_list_registered_models(
    max_results: int = 100,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    List all registered models in MLflow's model registry.

    Parameters
    ----------
    max_results : int, optional
        Maximum number of models to return.
    tracking_uri : str, optional
        Address of local or remote tracking server.
    registry_uri : str, optional
        Address of local or remote model registry server.

    Returns
    -------
    tuple
        (summary_message, model_list)
    """
    print(" * Tool: mlflow_list_registered_models")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    # Fix: MlflowClient.list_registered_models() was deprecated and removed
    # in MLflow 2.x; search_registered_models() with no filter is the
    # supported equivalent. The call can be paginated; for simplicity we
    # just pass max_results.
    models = client.search_registered_models(max_results=max_results)

    models_data = []
    for m in models:
        models_data.append({
            "name": m.name,
            "latest_versions": [
                {
                    "version": v.version,
                    "run_id": v.run_id,
                    "current_stage": v.current_stage,
                }
                for v in m.latest_versions
            ]
        })

    return (
        f"Found {len(models_data)} registered models.",
        models_data
    )
466
+
467
+
468
@tool(response_format='content_and_artifact')
def mlflow_search_registered_models(
    filter_string: Optional[str] = None,
    order_by: Optional[List[str]] = None,
    max_results: int = 100,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Search registered models in MLflow's registry using optional filters.

    Parameters
    ----------
    filter_string : str, optional
        e.g. "name LIKE 'my_model%'" or "tags.stage = 'production'".
    order_by : list, optional
        e.g. ["name ASC"] or ["timestamp DESC"].
    max_results : int, optional
        Max number of results.
    tracking_uri : str, optional
        Address of local or remote tracking server.
    registry_uri : str, optional
        Address of local or remote model registry server.

    Returns
    -------
    tuple
        (summary_message, model_dict_list)
    """
    print(" * Tool: mlflow_search_registered_models")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    results = client.search_registered_models(
        filter_string=filter_string,
        order_by=order_by,
        max_results=max_results
    )

    # Flatten each RegisteredModel entity (plus its latest versions)
    # into a plain dict for the artifact.
    models_data = [
        {
            "name": model.name,
            "description": model.description,
            "creation_timestamp": model.creation_timestamp,
            "last_updated_timestamp": model.last_updated_timestamp,
            "latest_versions": [
                {
                    "version": ver.version,
                    "run_id": ver.run_id,
                    "current_stage": ver.current_stage
                }
                for ver in model.latest_versions
            ],
        }
        for model in results
    ]

    return (
        f"Found {len(models_data)} models matching filter={filter_string}.",
        models_data
    )
526
+
527
+
528
@tool(response_format='content_and_artifact')
def mlflow_get_model_version_details(
    name: str,
    version: str,
    tracking_uri: Optional[str] = None,
    registry_uri: Optional[str] = None
) -> tuple:
    """
    Retrieve details about a specific model version in the MLflow registry.

    Parameters
    ----------
    name : str
        Name of the registered model.
    version : str
        Version number of that model.
    tracking_uri : str, optional
        Address of local or remote tracking server.
    registry_uri : str, optional
        Address of local or remote model registry server.

    Returns
    -------
    tuple
        (summary_message, version_data_dict)
    """
    print(" * Tool: mlflow_get_model_version_details")
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri)
    mv = client.get_model_version(name, version)

    # Flatten the ModelVersion entity into a plain dict for the artifact.
    data = {
        "name": mv.name,
        "version": mv.version,
        "run_id": mv.run_id,
        "creation_timestamp": mv.creation_timestamp,
        "current_stage": mv.current_stage,
        "description": mv.description,
        "status": mv.status
    }

    return (
        f"Model version details retrieved for {name} v{version}",
        data
    )
572
+
573
+
574
+ # @tool
575
+ # def get_or_create_experiment(experiment_name):
576
+ # """
577
+ # Retrieve the ID of an existing MLflow experiment or create a new one if it doesn't exist.
578
+
579
+ # This function checks if an experiment with the given name exists within MLflow.
580
+ # If it does, the function returns its ID. If not, it creates a new experiment
581
+ # with the provided name and returns its ID.
582
+
583
+ # Parameters:
584
+ # - experiment_name (str): Name of the MLflow experiment.
585
+
586
+ # Returns:
587
+ # - str: ID of the existing or newly created MLflow experiment.
588
+ # """
589
+ # import mlflow
590
+ # if experiment := mlflow.get_experiment_by_name(experiment_name):
591
+ # return experiment.experiment_id
592
+ # else:
593
+ # return mlflow.create_experiment(experiment_name)
594
+
595
+
596
+
597
+ # @tool("mlflow_set_tracking_uri", return_direct=True)
598
+ # def mlflow_set_tracking_uri(tracking_uri: str) -> str:
599
+ # """
600
+ # Set or change the MLflow tracking URI.
601
+
602
+ # Parameters
603
+ # ----------
604
+ # tracking_uri : str
605
+ # The URI/path where MLflow logs & metrics are stored.
606
+
607
+ # Returns
608
+ # -------
609
+ # str
610
+ # Confirmation message.
611
+ # """
612
+ # import mlflow
613
+ # mlflow.set_tracking_uri(tracking_uri)
614
+ # return f"MLflow tracking URI set to: {tracking_uri}"
615
+
616
+
617
+ # @tool("mlflow_list_experiments", return_direct=True)
618
+ # def mlflow_list_experiments() -> str:
619
+ # """
620
+ # List existing MLflow experiments.
621
+
622
+ # Returns
623
+ # -------
624
+ # str
625
+ # JSON-serialized list of experiment metadata (ID, name, etc.).
626
+ # """
627
+ # from mlflow.tracking import MlflowClient
628
+ # import json
629
+
630
+ # client = MlflowClient()
631
+ # experiments = client.list_experiments()
632
+ # # Convert to a JSON-like structure
633
+ # experiments_data = [
634
+ # dict(experiment_id=e.experiment_id, name=e.name, artifact_location=e.artifact_location)
635
+ # for e in experiments
636
+ # ]
637
+
638
+ # return json.dumps(experiments_data)
639
+
640
+
641
+ # @tool("mlflow_create_experiment", return_direct=True)
642
+ # def mlflow_create_experiment(experiment_name: str) -> str:
643
+ # """
644
+ # Create a new MLflow experiment by name.
645
+
646
+ # Parameters
647
+ # ----------
648
+ # experiment_name : str
649
+ # The name of the experiment to create.
650
+
651
+ # Returns
652
+ # -------
653
+ # str
654
+ # The experiment ID or an error message if creation failed.
655
+ # """
656
+ # from mlflow.tracking import MlflowClient
657
+
658
+ # client = MlflowClient()
659
+ # exp_id = client.create_experiment(experiment_name)
660
+ # return f"Experiment created with ID: {exp_id}"
661
+
662
+
663
+ # @tool("mlflow_set_experiment", return_direct=True)
664
+ # def mlflow_set_experiment(experiment_name: str) -> str:
665
+ # """
666
+ # Set or create an MLflow experiment for subsequent logging.
667
+
668
+ # Parameters
669
+ # ----------
670
+ # experiment_name : str
671
+ # The name of the experiment to set.
672
+
673
+ # Returns
674
+ # -------
675
+ # str
676
+ # Confirmation of the chosen experiment name.
677
+ # """
678
+ # import mlflow
679
+ # mlflow.set_experiment(experiment_name)
680
+ # return f"Active MLflow experiment set to: {experiment_name}"
681
+
682
+
683
+ # @tool("mlflow_start_run", return_direct=True)
684
+ # def mlflow_start_run(run_name: Optional[str] = None) -> str:
685
+ # """
686
+ # Start a new MLflow run under the current experiment.
687
+
688
+ # Parameters
689
+ # ----------
690
+ # run_name : str, optional
691
+ # Optional run name.
692
+
693
+ # Returns
694
+ # -------
695
+ # str
696
+ # The run_id of the newly started MLflow run.
697
+ # """
698
+ # import mlflow
699
+ # with mlflow.start_run(run_name=run_name) as run:
700
+ # run_id = run.info.run_id
701
+ # return f"MLflow run started with run_id: {run_id}"
702
+
703
+
704
+ # @tool("mlflow_log_params", return_direct=True)
705
+ # def mlflow_log_params(params: Dict[str, Any]) -> str:
706
+ # """
707
+ # Log a batch of parameters to the current MLflow run.
708
+
709
+ # Parameters
710
+ # ----------
711
+ # params : dict
712
+ # A dictionary of parameter name -> parameter value.
713
+
714
+ # Returns
715
+ # -------
716
+ # str
717
+ # Confirmation message.
718
+ # """
719
+ # import mlflow
720
+ # mlflow.log_params(params)
721
+ # return f"Logged parameters: {params}"
722
+
723
+
724
+ # @tool("mlflow_log_metrics", return_direct=True)
725
+ # def mlflow_log_metrics(metrics: Dict[str, float], step: Optional[int] = None) -> str:
726
+ # """
727
+ # Log a dictionary of metrics to the current MLflow run.
728
+
729
+ # Parameters
730
+ # ----------
731
+ # metrics : dict
732
+ # Metric name -> numeric value.
733
+ # step : int, optional
734
+ # The training step or iteration number.
735
+
736
+ # Returns
737
+ # -------
738
+ # str
739
+ # Confirmation message.
740
+ # """
741
+ # import mlflow
742
+ # mlflow.log_metrics(metrics, step=step)
743
+ # return f"Logged metrics: {metrics} at step {step}"
744
+
745
+
746
+ # @tool("mlflow_log_artifact", return_direct=True)
747
+ # def mlflow_log_artifact(artifact_path: str, artifact_folder_name: Optional[str] = None) -> str:
748
+ # """
749
+ # Log a local file or directory as an MLflow artifact.
750
+
751
+ # Parameters
752
+ # ----------
753
+ # artifact_path : str
754
+ # The local path to the file/directory to be logged.
755
+ # artifact_folder_name : str, optional
756
+ # Subfolder within the run's artifact directory.
757
+
758
+ # Returns
759
+ # -------
760
+ # str
761
+ # Confirmation message.
762
+ # """
763
+ # import mlflow
764
+ # if artifact_folder_name:
765
+ # mlflow.log_artifact(artifact_path, artifact_folder_name)
766
+ # return f"Artifact logged from {artifact_path} into folder '{artifact_folder_name}'"
767
+ # else:
768
+ # mlflow.log_artifact(artifact_path)
769
+ # return f"Artifact logged from {artifact_path}"
770
+
771
+
772
+ # @tool("mlflow_log_model", return_direct=True)
773
+ # def mlflow_log_model(model_path: str, registered_model_name: Optional[str] = None) -> str:
774
+ # """
775
+ # Log a model artifact (e.g., an H2O-saved model directory) to MLflow.
776
+
777
+ # Parameters
778
+ # ----------
779
+ # model_path : str
780
+ # The local filesystem path containing the model artifacts.
781
+ # registered_model_name : str, optional
782
+ # If provided, will also attempt to register the model under this name.
783
+
784
+ # Returns
785
+ # -------
786
+ # str
787
+ # Confirmation message with any relevant registration details.
788
+ # """
789
+ # import mlflow
790
+ # if registered_model_name:
791
+ # mlflow.pyfunc.log_model(
792
+ # artifact_path="model",
793
+ # python_model=None, # if you have a pyfunc wrapper, specify it
794
+ # registered_model_name=registered_model_name,
795
+ # code_path=None,
796
+ # conda_env=None,
797
+ # model_path=model_path # for certain model flavors, or use flavors
798
+ # )
799
+ # return f"Model logged and registered under '{registered_model_name}' from path {model_path}"
800
+ # else:
801
+ # # Simple log as generic artifact
802
+ # mlflow.pyfunc.log_model(
803
+ # artifact_path="model",
804
+ # python_model=None,
805
+ # code_path=None,
806
+ # conda_env=None,
807
+ # model_path=model_path
808
+ # )
809
+ # return f"Model logged (no registration) from path {model_path}"
810
+
811
+
812
+ # @tool("mlflow_end_run", return_direct=True)
813
+ # def mlflow_end_run() -> str:
814
+ # """
815
+ # End the current MLflow run (if one is active).
816
+
817
+ # Returns
818
+ # -------
819
+ # str
820
+ # Confirmation message.
821
+ # """
822
+ # import mlflow
823
+ # mlflow.end_run()
824
+ # return "MLflow run ended."
825
+
826
+
827
+ # @tool("mlflow_search_runs", return_direct=True)
828
+ # def mlflow_search_runs(
829
+ # experiment_names_or_ids: Optional[Union[List[str], List[int], str, int]] = None,
830
+ # filter_string: Optional[str] = None
831
+ # ) -> str:
832
+ # """
833
+ # Search runs within one or more MLflow experiments, optionally filtering by a filter_string.
834
+
835
+ # Parameters
836
+ # ----------
837
+ # experiment_names_or_ids : list or str or int, optional
838
+ # Experiment IDs or names.
839
+ # filter_string : str, optional
840
+ # MLflow filter expression, e.g. "metrics.rmse < 1.0".
841
+
842
+ # Returns
843
+ # -------
844
+ # str
845
+ # JSON-formatted list of runs that match the query.
846
+ # """
847
+ # import mlflow
848
+ # import json
849
+ # if experiment_names_or_ids is None:
850
+ # experiment_names_or_ids = []
851
+ # if isinstance(experiment_names_or_ids, (str, int)):
852
+ # experiment_names_or_ids = [experiment_names_or_ids]
853
+
854
+ # df = mlflow.search_runs(
855
+ # experiment_names=experiment_names_or_ids if all(isinstance(e, str) for e in experiment_names_or_ids) else None,
856
+ # experiment_ids=experiment_names_or_ids if all(isinstance(e, int) for e in experiment_names_or_ids) else None,
857
+ # filter_string=filter_string
858
+ # )
859
+ # return df.to_json(orient="records")
860
+
861
+
862
+ # @tool("mlflow_get_run", return_direct=True)
863
+ # def mlflow_get_run(run_id: str) -> str:
864
+ # """
865
+ # Retrieve details (params, metrics, etc.) for a specific MLflow run by ID.
866
+
867
+ # Parameters
868
+ # ----------
869
+ # run_id : str
870
+ # The ID of the MLflow run to retrieve.
871
+
872
+ # Returns
873
+ # -------
874
+ # str
875
+ # JSON-formatted data containing run info, params, and metrics.
876
+ # """
877
+ # from mlflow.tracking import MlflowClient
878
+ # import json
879
+
880
+ # client = MlflowClient()
881
+ # run = client.get_run(run_id)
882
+ # data = {
883
+ # "run_id": run.info.run_id,
884
+ # "experiment_id": run.info.experiment_id,
885
+ # "status": run.info.status,
886
+ # "start_time": run.info.start_time,
887
+ # "end_time": run.info.end_time,
888
+ # "artifact_uri": run.info.artifact_uri,
889
+ # "params": run.data.params,
890
+ # "metrics": run.data.metrics,
891
+ # "tags": run.data.tags
892
+ # }
893
+ # return json.dumps(data)
894
+
895
+
896
+ # @tool("mlflow_load_model", return_direct=True)
897
+ # def mlflow_load_model(model_uri: str) -> str:
898
+ # """
899
+ # Load an MLflow-model (PyFunc flavor or other) into memory, returning a handle reference.
900
+ # For demonstration, we store the loaded model globally in a registry dict.
901
+
902
+ # Parameters
903
+ # ----------
904
+ # model_uri : str
905
+ # The URI of the model to load, e.g. "runs:/<RUN_ID>/model" or "models:/MyModel/Production".
906
+
907
+ # Returns
908
+ # -------
909
+ # str
910
+ # A reference key identifying the loaded model (for subsequent predictions),
911
+ # or a direct message if you prefer to store it differently.
912
+ # """
913
+ # import mlflow.pyfunc
914
+ # from uuid import uuid4
915
+
916
+ # # For demonstration, create a global registry:
917
+ # global _LOADED_MODELS
918
+ # if "_LOADED_MODELS" not in globals():
919
+ # _LOADED_MODELS = {}
920
+
921
+ # loaded_model = mlflow.pyfunc.load_model(model_uri)
922
+ # model_key = f"model_{uuid4().hex}"
923
+ # _LOADED_MODELS[model_key] = loaded_model
924
+
925
+ # return f"Model loaded with reference key: {model_key}"
926
+
927
+
928
+ # @tool("mlflow_predict", return_direct=True)
929
+ # def mlflow_predict(model_key: str, data: List[Dict[str, Any]]) -> str:
930
+ # """
931
+ # Predict using a previously loaded MLflow model (PyFunc), identified by its reference key.
932
+
933
+ # Parameters
934
+ # ----------
935
+ # model_key : str
936
+ # The reference key for the loaded model (returned by mlflow_load_model).
937
+ # data : List[Dict[str, Any]]
938
+ # The data rows for which predictions should be made.
939
+
940
+ # Returns
941
+ # -------
942
+ # str
943
+ # JSON-formatted prediction results.
944
+ # """
945
+ # import pandas as pd
946
+ # import json
947
+
948
+ # global _LOADED_MODELS
949
+ # if model_key not in _LOADED_MODELS:
950
+ # return f"No model found for key: {model_key}"
951
+
952
+ # model = _LOADED_MODELS[model_key]
953
+ # df = pd.DataFrame(data)
954
+ # preds = model.predict(df)
955
+ # # Convert to JSON (DataFrame or Series)
956
+ # if hasattr(preds, "to_json"):
957
+ # return preds.to_json(orient="records")
958
+ # else:
959
+ # # If preds is just a numpy array or list
960
+ # return json.dumps(preds.tolist())
961
+