fairo 25.7.2__tar.gz → 25.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {fairo-25.7.2 → fairo-25.12.1}/PKG-INFO +4 -6
  2. fairo-25.12.1/fairo/__init__.py +1 -0
  3. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/chat/chat.py +50 -12
  4. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/client/client.py +9 -2
  5. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/agent_serializer.py +185 -21
  6. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/executor.py +44 -7
  7. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/model_log_helper.py +106 -25
  8. fairo-25.12.1/fairo/core/tools/__init__.py +2 -0
  9. fairo-25.12.1/fairo/core/tools/plot.py +250 -0
  10. fairo-25.12.1/fairo/core/tools/suggestion.py +43 -0
  11. fairo-25.12.1/fairo/core/utils.py +320 -0
  12. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/dependency.py +19 -156
  13. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/utils.py +225 -60
  14. {fairo-25.7.2 → fairo-25.12.1}/fairo/settings.py +1 -1
  15. {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/PKG-INFO +4 -6
  16. {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/SOURCES.txt +4 -0
  17. fairo-25.12.1/fairo.egg-info/requires.txt +8 -0
  18. {fairo-25.7.2 → fairo-25.12.1}/pyproject.toml +3 -5
  19. fairo-25.7.2/fairo/__init__.py +0 -1
  20. fairo-25.7.2/fairo.egg-info/requires.txt +0 -10
  21. {fairo-25.7.2 → fairo-25.12.1}/README.md +0 -0
  22. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/__init__.py +0 -0
  23. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/__init__.py +0 -0
  24. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/base_agent.py +0 -0
  25. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/code_analysis_agent.py +0 -0
  26. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/__init__.py +0 -0
  27. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/base_output.py +0 -0
  28. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/google_drive.py +0 -0
  29. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/__init__.py +0 -0
  30. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/base_tools.py +0 -0
  31. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/code_analysis.py +0 -0
  32. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/utils.py +0 -0
  33. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/utils.py +0 -0
  34. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/chat/__init__.py +0 -0
  35. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/client/__init__.py +0 -0
  36. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/exceptions.py +0 -0
  37. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/__init__.py +0 -0
  38. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/env_finder.py +0 -0
  39. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/__init__.py +0 -0
  40. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/custom_field_value.py +0 -0
  41. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/resources.py +0 -0
  42. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/runnable/__init__.py +0 -0
  43. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/runnable/runnable.py +0 -0
  44. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/__init__.py +0 -0
  45. {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/base_workflow.py +0 -0
  46. {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/__init__.py +0 -0
  47. {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/fairness_object.py +0 -0
  48. {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/metrics.py +0 -0
  49. {fairo-25.7.2 → fairo-25.12.1}/fairo/tests/__init__.py +0 -0
  50. {fairo-25.7.2 → fairo-25.12.1}/fairo/tests/test_metrics.py +0 -0
  51. {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/dependency_links.txt +0 -0
  52. {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/top_level.txt +0 -0
  53. {fairo-25.7.2 → fairo-25.12.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fairo
- Version: 25.7.2
+ Version: 25.12.1
  Summary: SDK for interfacing with Fairo SaaS platform.
  Author-email: "Fairo Systems, Inc." <support@fairo.ai>
  License: Apache-2.0
@@ -12,13 +12,11 @@ Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Operating System :: OS Independent
  Description-Content-Type: text/markdown
  Requires-Dist: mlflow<=3.1.1,>=3.1.0
- Requires-Dist: langchain<0.4.0,>=0.3.20
+ Requires-Dist: langchain<0.4.0,>=0.3.27
  Requires-Dist: langchain-aws<0.3.0,>=0.2.18
- Requires-Dist: langchain-community<0.4.0,>=0.3.20
+ Requires-Dist: langchain-community<0.4.0,>=0.3.27
  Requires-Dist: langchain-core<0.4.0,>=0.3.49
- Requires-Dist: langchain-text-splitters<0.4.0,>=0.3.7
- Requires-Dist: psycopg2-binary<3.0.0,>=2.9.0
- Requires-Dist: langchain-postgres<0.1.0,>=0.0.14
+ Requires-Dist: langchain-text-splitters<0.4.0,>=0.3.11
  Requires-Dist: setuptools>=79.0.0
  Requires-Dist: pandas<3.0.0,>=2.0.0
@@ -0,0 +1 @@
+ __version__ = "25.12.1"
@@ -31,10 +31,10 @@ class FairoDeploymentClient(BaseDeploymentClient):

  # Add authentication if needed
  auth = None
- if os.environ.get('MLFLOW_TRACKING_USERNAME') and os.environ.get('MLFLOW_TRACKING_PASSWORD'):
+ if os.environ.get('FAIRO_API_ACCESS_KEY_ID') and os.environ.get('FAIRO_API_SECRET'):
  auth = HTTPBasicAuth(
- os.environ.get('MLFLOW_TRACKING_USERNAME'),
- os.environ.get('MLFLOW_TRACKING_PASSWORD')
+ os.environ.get('FAIRO_API_ACCESS_KEY_ID'),
+ os.environ.get('FAIRO_API_SECRET')
  )

  # Make streaming request
@@ -116,12 +116,15 @@ class FairoDeploymentClient(BaseDeploymentClient):

  # Add authentication if needed
  auth = None
- if os.environ.get('MLFLOW_TRACKING_USERNAME') and os.environ.get('MLFLOW_TRACKING_PASSWORD'):
+ if os.environ.get('FAIRO_API_ACCESS_KEY_ID') and os.environ.get('FAIRO_API_SECRET'):
  auth = HTTPBasicAuth(
- os.environ.get('MLFLOW_TRACKING_USERNAME'),
- os.environ.get('MLFLOW_TRACKING_PASSWORD')
+ os.environ.get('FAIRO_API_ACCESS_KEY_ID'),
+ os.environ.get('FAIRO_API_SECRET')
  )

+ if os.environ.get('MLFLOW_TRACKING_TOKEN'):
+ headers['Authorization'] = f"Bearer {os.environ.get('MLFLOW_TRACKING_TOKEN')}"
+
  # Make request
  response = requests.post(
  gateway_url,
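
To illustrate the credential switch above: callers that previously exported MLFLOW_TRACKING_USERNAME/MLFLOW_TRACKING_PASSWORD would now set the Fairo key pair, and may optionally supply a bearer token. A minimal sketch with placeholder values; only the variable names come from the diff:

    # Hedged sketch: credentials for the new gateway auth path (values are placeholders).
    import os

    os.environ["FAIRO_API_ACCESS_KEY_ID"] = "<your-access-key-id>"   # used as the basic-auth username
    os.environ["FAIRO_API_SECRET"] = "<your-secret>"                 # used as the basic-auth password
    # Optional: forwarded as an Authorization: Bearer header when present
    os.environ["MLFLOW_TRACKING_TOKEN"] = "<optional-bearer-token>"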
@@ -180,10 +183,6 @@ class ChatFairo(ChatMlflow):

  def __init__(self, **kwargs):

- # # TODO <- see if this can be improved
- # os.environ["MLFLOW_TRACKING_USERNAME"] = get_mlflow_user()
- # os.environ["MLFLOW_TRACKING_PASSWORD"] = get_mlflow_password()
-
  super().__init__(
  target_uri=os.environ.get('MLFLOW_GATEWAY_URI', get_mlflow_gateway_uri()),
  endpoint=os.environ.get('MLFLOW_GATEWAY_ROUTE', get_mlflow_gateway_chat_route()),
@@ -194,17 +193,56 @@

  @property
  def _target_uri(self):
- return os.environ.get("MLFLOW_GATEWAY_URI", None)
+ return os.environ.get("MLFLOW_GATEWAY_URI", get_mlflow_gateway_uri())

  @property
  def _endpoint(self):
- return os.environ.get("MLFLOW_GATEWAY_ROUTE", None)
+ return os.environ.get("MLFLOW_GATEWAY_ROUTE", get_mlflow_gateway_chat_route())

  def invoke(self, *args, **kwargs):
  # Override invoke to use dynamic target_uri
  self.target_uri = self._target_uri
  self._client = FairoDeploymentClient(self.target_uri, self.endpoint)
  return super().invoke(*args, **kwargs)
+
+ def stream(self, *args, **kwargs):
+ # Override stream to use dynamic target_uri
+ self.target_uri = self._target_uri
+ self._client = FairoDeploymentClient(self.target_uri, self.endpoint)
+ return super().stream(*args, **kwargs)
+
+ def bind_tools(self, tools, **kwargs):
+ result = super().bind_tools(tools, **kwargs)
+ result._uses_tools = True
+ return result
+
+ def _stream(self, *args, **kwargs):
+ response = self.invoke(*args, **kwargs)
+
+ from langchain_core.messages import AIMessage, AIMessageChunk
+ from langchain_core.outputs import ChatGenerationChunk
+
+ if isinstance(response, AIMessage):
+ initial_chunk = AIMessageChunk(content="", role="assistant")
+ yield ChatGenerationChunk(message=initial_chunk)
+
+ if response.content:
+ content_chunk = AIMessageChunk(content=response.content, role="assistant")
+ yield ChatGenerationChunk(message=content_chunk)
+
+ if hasattr(response, 'tool_calls') and response.tool_calls:
+ tool_chunk = AIMessageChunk(
+ content="",
+ role="assistant",
+ tool_calls=response.tool_calls
+ )
+ yield ChatGenerationChunk(message=tool_chunk)
+ else:
+ final_chunk = AIMessageChunk(content="", role="assistant")
+ yield ChatGenerationChunk(message=final_chunk)
+ else:
+ chunk = AIMessageChunk(content=str(response), role="assistant")
+ yield ChatGenerationChunk(message=chunk)


  class FairoChat(ChatMlflow):
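
A rough usage sketch for the new stream()/bind_tools() overrides, assuming ChatFairo is importable from fairo.core.chat.chat (the path in the file list); the gateway environment values below are placeholders and a reachable gateway is assumed:

    # Hedged sketch: streaming through ChatFairo with a dynamically resolved gateway.
    import os
    from fairo.core.chat.chat import ChatFairo

    os.environ["MLFLOW_GATEWAY_URI"] = "https://example-gateway.invalid"   # placeholder
    os.environ["MLFLOW_GATEWAY_ROUTE"] = "chat"                            # placeholder

    llm = ChatFairo()
    for chunk in llm.stream("Summarize the latest workflow run"):
        # Each chunk is an AIMessageChunk emitted by the _stream generator shown above.
        print(chunk.content, end="", flush=True)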
@@ -3,10 +3,17 @@ import requests
  from requests.auth import HTTPBasicAuth

  class BaseClient:
- def __init__(self, base_url: str, username: str, password: str):
+ def __init__(self, base_url: str, username: str = None, password: str = None, fairo_auth_token: str = None):
  self.base_url = base_url
  self.session = requests.Session()
- self.session.auth = HTTPBasicAuth(username, password)
+ if username is not None and password is not None:
+ self.session.auth = HTTPBasicAuth(username, password)
+ elif fairo_auth_token is not None:
+ self.session.headers.update({
+ "Authorization": f"Bearer {fairo_auth_token}"
+ })
+ else:
+ raise ValueError("Must provide either username/password or fairo_auth_token")
  self.session.headers.update({
  "Content-Type": "application/json",
  })
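
The widened BaseClient signature now allows two construction paths. A minimal sketch with placeholder URL and credentials (the import path follows the file list):

    # Hedged sketch: the two ways the new BaseClient __init__ can be called.
    from fairo.core.client.client import BaseClient

    # Basic auth, as before:
    client = BaseClient("https://api.fairo.example", username="key-id", password="secret")

    # Or bearer-token auth via the new fairo_auth_token parameter:
    client = BaseClient("https://api.fairo.example", fairo_auth_token="<jwt-or-token>")

    # Passing neither raises:
    # ValueError("Must provide either username/password or fairo_auth_token")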
@@ -1,17 +1,42 @@
- from typing import Any, Dict
+ from typing import Any, Dict, Optional
  import mlflow
- import cloudpickle
- import os
- import sys
+ import json
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain_core.callbacks import CallbackManagerForChainRun
  from pathlib import Path
  from langchain_core.runnables import RunnableLambda, Runnable
  from langchain.chains import SimpleSequentialChain
  import logging
  import types
  import threading
+ import inspect
  import pandas as pd
  logger = logging.getLogger(__name__)

+ # Thread-local context for S3 client and bucket path to prevent cross-execution contamination
+ import threading
+ _agent_context = threading.local()
+
+ def set_agent_context(s3_client, bucket_path, execution_id=None):
+ """Set thread-local context for agent execution"""
+ _agent_context.s3_client = s3_client
+ _agent_context.bucket_path = bucket_path
+ _agent_context.execution_id = execution_id
+
+ def get_agent_context():
+ """Get thread-local context for agent execution"""
+ return {
+ 's3_client': getattr(_agent_context, 's3_client', None),
+ 'bucket_path': getattr(_agent_context, 'bucket_path', None),
+ 'execution_id': getattr(_agent_context, 'execution_id', None)
+ }
+
+ def clear_agent_context():
+ """Clear thread-local context to prevent memory leaks"""
+ for attr in ['s3_client', 'bucket_path', 'execution_id']:
+ if hasattr(_agent_context, attr):
+ delattr(_agent_context, attr)
+
  class CustomPythonModel(mlflow.pyfunc.PythonModel):
  def __init__(self):
  self.agent = None
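
A hedged sketch of how the new thread-local helpers might bracket a single execution; the boto3 client and bucket path are illustrative assumptions, not taken from the diff:

    # Hedged sketch: setting and clearing the per-thread agent context around a run.
    import boto3
    from fairo.core.execution.agent_serializer import (
        set_agent_context, get_agent_context, clear_agent_context,
    )

    s3 = boto3.client("s3")                                   # illustrative S3 client
    set_agent_context(s3, "deployments/exec-123", execution_id="exec-123")
    try:
        ctx = get_agent_context()                             # dict with s3_client / bucket_path / execution_id
        assert ctx["bucket_path"] == "deployments/exec-123"
        # ... run the agent; OutputAgentStatus (defined later in this file) reads this context ...
    finally:
        clear_agent_context()                                 # avoid leaking state into the next run on this thread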
@@ -53,27 +78,53 @@ CustomPythonModel(mlflow.pyfunc.PythonModel):
  except ImportError as e:
  raise ImportError(f"Failed to import agent_code: {e}")

- def predict(self, context, model_input):
+ def predict(self, context, model_input: list[str]):
  if isinstance(model_input, list):
- return [self.agent.run(query) for query in model_input]
+ parsed_data = json.loads(model_input[0])
+ return self.run(parsed_data, callback_enabled=True)
  else:
- return self.agent.run(model_input)
-
+ return self.run(model_input)
  class AgentChainWrapper:
- def __init__(self, chain_class = SimpleSequentialChain, agent_functions_list = []):
+ def __init__(self, chain_class = SimpleSequentialChain, agent_functions_list = [], callback_enabled = False):
  self.chain_class = chain_class
  self.agents = [func() for func in agent_functions_list]
  self.agent_functions = agent_functions_list
+ self.callback_enabled = callback_enabled

  def _wrap_agent_runnable(self, agent) -> RunnableLambda:
  """
  Wraps the agent's .run() method into a RunnableLambda with a custom function name.
  Properly propagates errors instead of continuing to the next agent.
  """
- def base_fn(inputs: Dict[str, Any]) -> Dict[str, Any]:
+ def base_fn(
+ x: Dict[str, Any],
+ *,
+ run_manager: CallbackManagerForChainRun = None,
+ ):
  # Run the agent, but don't catch exceptions - let them propagate
  # This will stop the entire pipeline on agent failure
- return agent.invoke(inputs)
+ if run_manager:
+ run_manager.on_text(f"[{agent.__class__.__name__}] starting…")
+
+ # If your agent supports .invoke, prefer it; otherwise fall back to .run
+ try:
+ # Propagate callbacks to the inner agent call too (if it’s a Runnable)
+ if hasattr(agent, "invoke"):
+ sig = inspect.signature(agent.invoke)
+ if "config" in sig.parameters and self.callback_enabled:
+ out = agent.invoke(
+ x,
+ config={"callbacks": [OutputAgentStatus()]}
+ )
+ else:
+ out = agent.invoke(x)
+ else:
+ out = agent.run(x) # legacy agents
+ finally:
+ if run_manager:
+ run_manager.on_text(f"[{agent.__class__.__name__}] finished.")
+
+ return out

  # Check if result starts with "An error occurred" which indicates agent failure
  # if isinstance(result, str) and result.startswith("An error occurred during execution:"):
@@ -94,7 +145,9 @@ AgentChainWrapper:

  return RunnableLambda(runnable_fn)

- def run(self, query):
+ def run(self, query, callback_enabled: Optional[bool] = False):
+ if callback_enabled:
+ self.callback_enabled = callback_enabled
  result = query
  def is_dataframe(obj) -> bool:
  try:
@@ -106,7 +159,14 @@ AgentChainWrapper:
  runnables = []
  for agent in self.agents:
  if isinstance(agent, Runnable):
- runnables.append(agent)
+ # Check if agent supports with_config (Runnable style)
+ if hasattr(agent, "with_config") and self.callback_enabled:
+ # Inject default callbacks on the agent itself
+ enhanced = agent.with_config({"callbacks": [OutputAgentStatus()]})
+ runnables.append(enhanced)
+ else:
+ # Not a Runnable — wrap with your fallback wrapper
+ runnables.append(agent)
  else:
  runnables.append(
  self._wrap_agent_runnable(agent)
@@ -123,8 +183,12 @@ AgentChainWrapper:
  )
  return chain.run(result)

- def predict(self, context = "", model_input = ""):
- return self.run(model_input)
+ def predict(self, context = "", model_input: list[str] = [""]):
+ if isinstance(model_input, list):
+ parsed_data = json.loads(model_input[0])
+ return self.run(parsed_data, callback_enabled=True)
+ else:
+ return self.run(model_input)

  class CustomChainModel(mlflow.pyfunc.PythonModel):
  def __init__(self):
@@ -191,11 +255,12 @@ CustomChainModel(mlflow.pyfunc.PythonModel):
  agent_functions.append(agent_function)

  # Create the agent chain
- self.agent_chain = AgentChainWrapper(agent_functions_list=agent_functions)
+ self.agent_chain = AgentChainWrapper(agent_functions_list=agent_functions, callback_enabled=True)

- def predict(self, context, model_input):
+ def predict(self, context, model_input: list[str]):
  if isinstance(model_input, list):
- return [self.agent_chain.run(query) for query in model_input]
+ parsed_data = json.loads(model_input[0])
+ return self.agent_chain.run(parsed_data)
  else:
  return self.agent_chain.run(model_input)

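
Under the new convention, predict() receives a one-element list whose first item is a JSON-encoded dict (the json.loads(model_input[0]) calls above). A hedged sketch of building such a payload for a model logged with these wrappers; the model URI and field name are placeholders, and MLflow schema enforcement details may vary by version:

    # Hedged sketch: calling a served/loaded model with the JSON-in-a-list payload format.
    import json
    import mlflow

    model = mlflow.pyfunc.load_model("models:/fairo-agent/1")                 # placeholder URI
    payload = json.dumps({"question": "What changed in the last release?"})   # field name illustrative
    result = model.predict([payload])   # the wrapper parses model_input[0] with json.loads internally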
@@ -238,7 +303,7 @@ CrewAgentWrapper:
  print(f"Error running CrewAI crew: {e}")
  return f"Error executing query '{query}': {str(e)}"

- def predict(self, context, model_input):
+ def predict(self, context, model_input: list[str]):
  return self.run(model_input)

  class CustomCrewModel(mlflow.pyfunc.PythonModel):
@@ -281,8 +346,107 @@ CustomCrewModel(mlflow.pyfunc.PythonModel):
  except ImportError as e:
  raise ImportError(f"Failed to import CrewAI agent_code: {e}")

- def predict(self, context, model_input):
+ def predict(self, context, model_input: list[str]):
  if isinstance(model_input, list):
  return [self.agent.run(query) for query in model_input]
  else:
- return self.agent.run(model_input)
+ return self.agent.run(model_input)
+
+
+ class OutputAgentStatus(BaseCallbackHandler):
+ def __init__(self, s3_client=None, bucket_path=None):
+ super().__init__()
+ self.s3_client = s3_client
+ self.bucket_path = bucket_path
+
+ # If not provided, try to get from global context
+ if not self.s3_client or not self.bucket_path:
+ context = get_agent_context()
+ self.s3_client = self.s3_client or context.get('s3_client')
+ self.bucket_path = self.bucket_path or context.get('bucket_path')
+
+ def save_to_s3(self, status, message):
+ if not self.s3_client or not self.bucket_path:
+ return
+
+ # Validate execution_id is in bucket_path to prevent cross-execution contamination
+ import os
+ execution_id = os.environ.get('EXECUTION_ID')
+ if execution_id and execution_id not in self.bucket_path:
+ print(f"Warning: Execution ID {execution_id} not found in bucket path {self.bucket_path}. Skipping S3 write.")
+ return
+
+ try:
+ import os
+ import json
+ from datetime import datetime
+
+ bucket_name = os.environ.get('DEPLOYMENTS_BUCKET_NAME', 'local-development-deployments')
+ status_key = f"{self.bucket_path}/last_output.json"
+
+ # Try to read existing last_output.json
+ existing_data = {}
+ try:
+ response = self.s3_client.get_object(Bucket=bucket_name, Key=status_key)
+ existing_data = json.loads(response['Body'].read().decode('utf-8'))
+ except self.s3_client.exceptions.NoSuchKey:
+ # File doesn't exist yet, start with empty data
+ pass
+ except Exception as e:
+ print(f"Warning: Could not read existing last_output.json: {e}")
+
+ # Update the status and output fields
+ existing_data.update({
+ "status": status,
+ "output": message
+ })
+
+ # Save updated last_output.json
+ self.s3_client.put_object(
+ Bucket=bucket_name,
+ Key=status_key,
+ Body=json.dumps(existing_data),
+ ContentType='application/json'
+ )
+ except Exception as e:
+ print(f"Error saving status to S3: {e}")
+
+ def on_text(self, text: str, **kwargs):
+ self.save_to_s3("text_output", f"Agent generated text: {text}")
+
+ def on_llm_start(self, serialized, prompts, **kwargs):
+ model_name = serialized.get('name', 'Unknown')
+ self.save_to_s3("llm_start", f"Thinking")
+
+ def on_llm_new_token(self, token: str, **kwargs):
+ self.save_to_s3("llm_streaming", f"LLM generating response token: {token}")
+
+ def on_llm_end(self, response, **kwargs):
+ token_count = getattr(response, 'llm_output', {}).get('token_usage', {}).get('total_tokens', 'unknown')
+ self.save_to_s3("llm_complete", f"LLM completed response generation (tokens: {token_count})")
+
+ def on_tool_start(self, serialized, input_str: str, **kwargs):
+ tool_name = serialized.get('name', 'Unknown Tool')
+ self.save_to_s3("tool_start", f"Executing tool: {tool_name} with input: {input_str[:100]}")
+
+ def on_tool_end(self, output: str, **kwargs):
+ output_preview = str(output)[:100] if len(str(output)) > 100 else str(output)
+ self.save_to_s3("tool_complete", f"Tool execution completed with output: {output_preview}")
+
+ def on_chain_start(self, serialized, inputs, **kwargs):
+ chain_id = serialized.get('id', 'Unknown Chain')
+ self.save_to_s3("chain_start", f"Starting chain execution: {chain_id}")
+
+ def on_chain_end(self, outputs, **kwargs):
+ output_preview = str(outputs)[:100] if len(str(outputs)) > 100 else str(outputs)
+ self.save_to_s3("chain_complete", f"Chain execution completed with outputs: {output_preview}")
+
+ def on_agent_action(self, action, **kwargs):
+ action_tool = getattr(action, 'tool', 'Unknown')
+ action_input = getattr(action, 'tool_input', '')
+ self.save_to_s3("agent_action", f"Agent taking action with tool: {action_tool}, input: {str(action_input)[:100]}")
+
+ def on_agent_finish(self, finish, **kwargs):
+ return_values = getattr(finish, 'return_values', {})
+ output_preview = str(return_values)[:100] if len(str(return_values)) > 100 else str(return_values)
+ self.save_to_s3("agent_complete", f"Agent execution finished with result: {output_preview}")
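
A hedged sketch of wiring OutputAgentStatus in as a LangChain callback so lifecycle events are persisted to last_output.json; the RunnableLambda stand-in, bucket path, and boto3 client are illustrative assumptions:

    # Hedged sketch: attaching the status callback to a runnable invocation.
    import boto3
    from langchain_core.runnables import RunnableLambda
    from fairo.core.execution.agent_serializer import OutputAgentStatus

    handler = OutputAgentStatus(
        s3_client=boto3.client("s3"),          # or omit and rely on set_agent_context()
        bucket_path="deployments/exec-123",    # placeholder path
    )
    echo = RunnableLambda(lambda x: x)         # stand-in for a real agent chain
    echo.invoke({"input": "hello"}, config={"callbacks": [handler]})
    # Each callback (chain/LLM/tool/agent events) overwrites "status" and "output"
    # in <bucket_path>/last_output.json via save_to_s3().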
@@ -1,10 +1,12 @@
  import json
  import os
- from typing import List, Any, Callable, Dict, Union
+ from typing import List, Any, Callable, Dict, Optional, Type, Union
  from langchain_core.runnables import RunnableLambda, RunnableSequence
  from langchain.chains import SimpleSequentialChain
  import logging

+ from pydantic import BaseModel
+
  import mlflow

  from mlflow.models.signature import ModelSignature
@@ -16,7 +18,7 @@ from fairo.core.execution.model_log_helper import ModelLogHelper
  from fairo.core.runnable.runnable import Runnable
  from fairo.core.workflow.utils import output_langchain_process_graph
  from fairo.settings import get_fairo_api_key, get_fairo_api_secret, get_mlflow_experiment_name, get_mlflow_server, get_fairo_base_url
-
+ from fairo.core.tools import ChatSuggestions

  logger = logging.getLogger(__name__)

@@ -32,18 +34,24 @@ FairoExecutor:
  experiment_name: str = None,
  chain_class = SimpleSequentialChain,
  input_fields: List[str] = [],
+ input_schema: Optional[Type[BaseModel]] = None,
+ chat_suggestions: Optional[ChatSuggestions] = None,
+ debug_mode: bool = False
  ):
  if agents and runnable:
  raise ValueError("FairoExecutor cannot be initialized with both 'agents' and 'runnable'. Please provide only one.")
- if not input_fields:
- raise ValueError("Missing input_fields")
+ if not input_fields and not input_schema:
+ raise ValueError("Missing required parameters: please provide at least one of 'input_fields' or 'input_schema'")
+ self.input_schema = input_schema
  self.agents = agents
  self.agent_type = agent_type
  self.verbose = verbose
+ self.debug_mode = debug_mode
  self.patch_run_output_json = patch_run_output_json
  self.workflow_run_id = workflow_run_id
  self.runnable = runnable
  self.experiment_name = experiment_name if experiment_name else get_mlflow_experiment_name()
+ self._setup_logging()
  self.setup_mlflow()
  self.chain_class = chain_class
  self.client = BaseClient(
@@ -51,6 +59,7 @@ FairoExecutor:
  password=get_fairo_api_secret(),
  username=get_fairo_api_key()
  )
+ self.chat_suggestions = chat_suggestions
  self.input_fields = input_fields
  # Inject shared attributes into agents
  for agent in self.agents:
@@ -59,6 +68,14 @@ FairoExecutor:
  if hasattr(agent, 'verbose'):
  agent.verbose = self.verbose

+ def _setup_logging(self):
+ """Configure MLflow logging level based on debug_mode."""
+ mlflow_logger = logging.getLogger('mlflow')
+ if not self.debug_mode:
+ import warnings
+ warnings.filterwarnings('ignore', category=UserWarning)
+ mlflow_logger.setLevel(logging.ERROR)
+
  def _build_pipeline(self) -> RunnableSequence:
  if not self.agents and not self.runnable:
  raise ValueError("At least one agent or runnable must be provided.")
@@ -67,8 +84,26 @@ FairoExecutor:
  pipeline = mlflow.pyfunc.load_model(self.runnable.artifact_path)
  else:
  pipeline = AgentChainWrapper(chain_class=self.chain_class, agent_functions_list=self.agents)
- cols = [ColSpec(type="string", name=field) for field in self.input_fields]
- input_schema = Schema(cols)
+ # Convert Pydantic schema to MLflow Schema
+ if hasattr(self.input_schema, 'model_json_schema'):
+ # Extract field names from Pydantic schema
+ pydantic_schema = self.input_schema.model_json_schema()
+ properties = pydantic_schema.get('properties', {})
+ cols = []
+ for field_name, field_info in properties.items():
+ field_type = field_info.get('type', 'string')
+ # Map Pydantic types to MLflow types
+ mlflow_type = 'string' # Default to string
+ if field_type in ['integer', 'number']:
+ mlflow_type = 'double'
+ elif field_type == 'boolean':
+ mlflow_type = 'boolean'
+ cols.append(ColSpec(type=mlflow_type, name=field_name))
+ input_schema = Schema(cols)
+ else:
+ # Fallback to input_fields if schema is not Pydantic
+ cols = [ColSpec(type="string", name=field) for field in self.input_fields]
+ input_schema = Schema(cols)

  output_schema = Schema([
  ColSpec(type="string", name="output"),
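
As a hedged illustration of the mapping above, a Pydantic input_schema like the one below (illustrative, not part of the package) yields string/double/boolean ColSpecs:

    # Hedged sketch: what _build_pipeline derives from a Pydantic model's JSON schema.
    from pydantic import BaseModel
    from mlflow.types.schema import ColSpec, Schema

    class TicketInput(BaseModel):          # illustrative model
        question: str                      # "string"  -> string
        max_results: int                   # "integer" -> double
        include_closed: bool               # "boolean" -> boolean

    expected = Schema([
        ColSpec(type="string", name="question"),
        ColSpec(type="double", name="max_results"),
        ColSpec(type="boolean", name="include_closed"),
    ])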
@@ -97,7 +132,9 @@ FairoExecutor:
  fairo_settings = {
  "type": type,
  "process_graph": process_graph,
- "input_schema": self.input_fields
+ "schema": self.input_schema.model_json_schema() if self.input_schema else None,
+ "input_fields": list(self.input_schema.model_fields.keys()) if self.input_schema else self.input_fields,
+ "chat_suggestions": self.chat_suggestions.model_dump() if self.chat_suggestions else None,
  }
  if process_graph:
  mlflow.log_text(json.dumps(fairo_settings, ensure_ascii=False, indent=2), artifact_file="fairo_settings.txt")
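
Putting the executor changes together, a hedged construction sketch; the agent factory, suggestion text, and the ChatSuggestions constructor field are assumptions, not taken from the diff:

    # Hedged sketch: constructing FairoExecutor with the new input_schema,
    # chat_suggestions, and debug_mode parameters shown in this diff.
    from pydantic import BaseModel
    from fairo.core.execution.executor import FairoExecutor
    from fairo.core.tools import ChatSuggestions

    class TicketInput(BaseModel):                     # illustrative input schema
        question: str

    def build_triage_agent():                         # placeholder agent factory
        class TriageAgent:
            def run(self, inputs):
                return f"triage: {inputs}"
        return TriageAgent()

    executor = FairoExecutor(
        agents=[build_triage_agent],
        input_schema=TicketInput,                     # new: replaces/augments input_fields
        chat_suggestions=ChatSuggestions(suggestions=["Summarize open tickets"]),  # field name assumed
        debug_mode=False,                             # new: silences MLflow warnings via _setup_logging()
    )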