fairo 25.7.2__tar.gz → 25.12.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fairo-25.7.2 → fairo-25.12.1}/PKG-INFO +4 -6
- fairo-25.12.1/fairo/__init__.py +1 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/chat/chat.py +50 -12
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/client/client.py +9 -2
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/agent_serializer.py +185 -21
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/executor.py +44 -7
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/model_log_helper.py +106 -25
- fairo-25.12.1/fairo/core/tools/__init__.py +2 -0
- fairo-25.12.1/fairo/core/tools/plot.py +250 -0
- fairo-25.12.1/fairo/core/tools/suggestion.py +43 -0
- fairo-25.12.1/fairo/core/utils.py +320 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/dependency.py +19 -156
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/utils.py +225 -60
- {fairo-25.7.2 → fairo-25.12.1}/fairo/settings.py +1 -1
- {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/PKG-INFO +4 -6
- {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/SOURCES.txt +4 -0
- fairo-25.12.1/fairo.egg-info/requires.txt +8 -0
- {fairo-25.7.2 → fairo-25.12.1}/pyproject.toml +3 -5
- fairo-25.7.2/fairo/__init__.py +0 -1
- fairo-25.7.2/fairo.egg-info/requires.txt +0 -10
- {fairo-25.7.2 → fairo-25.12.1}/README.md +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/base_agent.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/code_analysis_agent.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/base_output.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/output/google_drive.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/base_tools.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/code_analysis.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/tools/utils.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/agent/utils.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/chat/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/client/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/exceptions.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/execution/env_finder.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/custom_field_value.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/models/resources.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/runnable/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/runnable/runnable.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/core/workflow/base_workflow.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/fairness_object.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/metrics/metrics.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/tests/__init__.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo/tests/test_metrics.py +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/dependency_links.txt +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/fairo.egg-info/top_level.txt +0 -0
- {fairo-25.7.2 → fairo-25.12.1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fairo
|
|
3
|
-
Version: 25.
|
|
3
|
+
Version: 25.12.1
|
|
4
4
|
Summary: SDK for interfacing with Fairo SaaS platform.
|
|
5
5
|
Author-email: "Fairo Systems, Inc." <support@fairo.ai>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -12,13 +12,11 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
|
12
12
|
Classifier: Operating System :: OS Independent
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
Requires-Dist: mlflow<=3.1.1,>=3.1.0
|
|
15
|
-
Requires-Dist: langchain<0.4.0,>=0.3.
|
|
15
|
+
Requires-Dist: langchain<0.4.0,>=0.3.27
|
|
16
16
|
Requires-Dist: langchain-aws<0.3.0,>=0.2.18
|
|
17
|
-
Requires-Dist: langchain-community<0.4.0,>=0.3.
|
|
17
|
+
Requires-Dist: langchain-community<0.4.0,>=0.3.27
|
|
18
18
|
Requires-Dist: langchain-core<0.4.0,>=0.3.49
|
|
19
|
-
Requires-Dist: langchain-text-splitters<0.4.0,>=0.3.
|
|
20
|
-
Requires-Dist: psycopg2-binary<3.0.0,>=2.9.0
|
|
21
|
-
Requires-Dist: langchain-postgres<0.1.0,>=0.0.14
|
|
19
|
+
Requires-Dist: langchain-text-splitters<0.4.0,>=0.3.11
|
|
22
20
|
Requires-Dist: setuptools>=79.0.0
|
|
23
21
|
Requires-Dist: pandas<3.0.0,>=2.0.0
|
|
24
22
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "25.12.1"
|
|
@@ -31,10 +31,10 @@ class FairoDeploymentClient(BaseDeploymentClient):
|
|
|
31
31
|
|
|
32
32
|
# Add authentication if needed
|
|
33
33
|
auth = None
|
|
34
|
-
if os.environ.get('
|
|
34
|
+
if os.environ.get('FAIRO_API_ACCESS_KEY_ID') and os.environ.get('FAIRO_API_SECRET'):
|
|
35
35
|
auth = HTTPBasicAuth(
|
|
36
|
-
os.environ.get('
|
|
37
|
-
os.environ.get('
|
|
36
|
+
os.environ.get('FAIRO_API_ACCESS_KEY_ID'),
|
|
37
|
+
os.environ.get('FAIRO_API_SECRET')
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
# Make streaming request
|
|
@@ -116,12 +116,15 @@ class FairoDeploymentClient(BaseDeploymentClient):
|
|
|
116
116
|
|
|
117
117
|
# Add authentication if needed
|
|
118
118
|
auth = None
|
|
119
|
-
if os.environ.get('
|
|
119
|
+
if os.environ.get('FAIRO_API_ACCESS_KEY_ID') and os.environ.get('FAIRO_API_SECRET'):
|
|
120
120
|
auth = HTTPBasicAuth(
|
|
121
|
-
os.environ.get('
|
|
122
|
-
os.environ.get('
|
|
121
|
+
os.environ.get('FAIRO_API_ACCESS_KEY_ID'),
|
|
122
|
+
os.environ.get('FAIRO_API_SECRET')
|
|
123
123
|
)
|
|
124
124
|
|
|
125
|
+
if os.environ.get('MLFLOW_TRACKING_TOKEN'):
|
|
126
|
+
headers['Authorization'] = f"Bearer {os.environ.get('MLFLOW_TRACKING_TOKEN')}"
|
|
127
|
+
|
|
125
128
|
# Make request
|
|
126
129
|
response = requests.post(
|
|
127
130
|
gateway_url,
|
|
@@ -180,10 +183,6 @@ class ChatFairo(ChatMlflow):
|
|
|
180
183
|
|
|
181
184
|
def __init__(self, **kwargs):
|
|
182
185
|
|
|
183
|
-
# # TODO <- see if this can be improved
|
|
184
|
-
# os.environ["MLFLOW_TRACKING_USERNAME"] = get_mlflow_user()
|
|
185
|
-
# os.environ["MLFLOW_TRACKING_PASSWORD"] = get_mlflow_password()
|
|
186
|
-
|
|
187
186
|
super().__init__(
|
|
188
187
|
target_uri=os.environ.get('MLFLOW_GATEWAY_URI', get_mlflow_gateway_uri()),
|
|
189
188
|
endpoint=os.environ.get('MLFLOW_GATEWAY_ROUTE', get_mlflow_gateway_chat_route()),
|
|
@@ -194,17 +193,56 @@ class ChatFairo(ChatMlflow):
|
|
|
194
193
|
|
|
195
194
|
@property
|
|
196
195
|
def _target_uri(self):
|
|
197
|
-
return os.environ.get("MLFLOW_GATEWAY_URI",
|
|
196
|
+
return os.environ.get("MLFLOW_GATEWAY_URI", get_mlflow_gateway_uri())
|
|
198
197
|
|
|
199
198
|
@property
|
|
200
199
|
def _endpoint(self):
|
|
201
|
-
return os.environ.get("MLFLOW_GATEWAY_ROUTE",
|
|
200
|
+
return os.environ.get("MLFLOW_GATEWAY_ROUTE", get_mlflow_gateway_chat_route())
|
|
202
201
|
|
|
203
202
|
def invoke(self, *args, **kwargs):
|
|
204
203
|
# Override invoke to use dynamic target_uri
|
|
205
204
|
self.target_uri = self._target_uri
|
|
206
205
|
self._client = FairoDeploymentClient(self.target_uri, self.endpoint)
|
|
207
206
|
return super().invoke(*args, **kwargs)
|
|
207
|
+
|
|
208
|
+
def stream(self, *args, **kwargs):
|
|
209
|
+
# Override stream to use dynamic target_uri
|
|
210
|
+
self.target_uri = self._target_uri
|
|
211
|
+
self._client = FairoDeploymentClient(self.target_uri, self.endpoint)
|
|
212
|
+
return super().stream(*args, **kwargs)
|
|
213
|
+
|
|
214
|
+
def bind_tools(self, tools, **kwargs):
|
|
215
|
+
result = super().bind_tools(tools, **kwargs)
|
|
216
|
+
result._uses_tools = True
|
|
217
|
+
return result
|
|
218
|
+
|
|
219
|
+
def _stream(self, *args, **kwargs):
|
|
220
|
+
response = self.invoke(*args, **kwargs)
|
|
221
|
+
|
|
222
|
+
from langchain_core.messages import AIMessage, AIMessageChunk
|
|
223
|
+
from langchain_core.outputs import ChatGenerationChunk
|
|
224
|
+
|
|
225
|
+
if isinstance(response, AIMessage):
|
|
226
|
+
initial_chunk = AIMessageChunk(content="", role="assistant")
|
|
227
|
+
yield ChatGenerationChunk(message=initial_chunk)
|
|
228
|
+
|
|
229
|
+
if response.content:
|
|
230
|
+
content_chunk = AIMessageChunk(content=response.content, role="assistant")
|
|
231
|
+
yield ChatGenerationChunk(message=content_chunk)
|
|
232
|
+
|
|
233
|
+
if hasattr(response, 'tool_calls') and response.tool_calls:
|
|
234
|
+
tool_chunk = AIMessageChunk(
|
|
235
|
+
content="",
|
|
236
|
+
role="assistant",
|
|
237
|
+
tool_calls=response.tool_calls
|
|
238
|
+
)
|
|
239
|
+
yield ChatGenerationChunk(message=tool_chunk)
|
|
240
|
+
else:
|
|
241
|
+
final_chunk = AIMessageChunk(content="", role="assistant")
|
|
242
|
+
yield ChatGenerationChunk(message=final_chunk)
|
|
243
|
+
else:
|
|
244
|
+
chunk = AIMessageChunk(content=str(response), role="assistant")
|
|
245
|
+
yield ChatGenerationChunk(message=chunk)
|
|
208
246
|
|
|
209
247
|
|
|
210
248
|
class FairoChat(ChatMlflow):
|
|
@@ -3,10 +3,17 @@ import requests
|
|
|
3
3
|
from requests.auth import HTTPBasicAuth
|
|
4
4
|
|
|
5
5
|
class BaseClient:
|
|
6
|
-
def __init__(self, base_url: str, username: str, password: str):
|
|
6
|
+
def __init__(self, base_url: str, username: str = None, password: str = None, fairo_auth_token: str = None):
|
|
7
7
|
self.base_url = base_url
|
|
8
8
|
self.session = requests.Session()
|
|
9
|
-
|
|
9
|
+
if username is not None and password is not None:
|
|
10
|
+
self.session.auth = HTTPBasicAuth(username, password)
|
|
11
|
+
elif fairo_auth_token is not None:
|
|
12
|
+
self.session.headers.update({
|
|
13
|
+
"Authorization": f"Bearer {fairo_auth_token}"
|
|
14
|
+
})
|
|
15
|
+
else:
|
|
16
|
+
raise ValueError("Must provide either username/password or fairo_auth_token")
|
|
10
17
|
self.session.headers.update({
|
|
11
18
|
"Content-Type": "application/json",
|
|
12
19
|
})
|
|
@@ -1,17 +1,42 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
2
|
import mlflow
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
import
|
|
3
|
+
import json
|
|
4
|
+
from langchain.callbacks.base import BaseCallbackHandler
|
|
5
|
+
from langchain_core.callbacks import CallbackManagerForChainRun
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from langchain_core.runnables import RunnableLambda, Runnable
|
|
8
8
|
from langchain.chains import SimpleSequentialChain
|
|
9
9
|
import logging
|
|
10
10
|
import types
|
|
11
11
|
import threading
|
|
12
|
+
import inspect
|
|
12
13
|
import pandas as pd
|
|
13
14
|
logger = logging.getLogger(__name__)
|
|
14
15
|
|
|
16
|
+
# Thread-local context for S3 client and bucket path to prevent cross-execution contamination
|
|
17
|
+
import threading
|
|
18
|
+
_agent_context = threading.local()
|
|
19
|
+
|
|
20
|
+
def set_agent_context(s3_client, bucket_path, execution_id=None):
|
|
21
|
+
"""Set thread-local context for agent execution"""
|
|
22
|
+
_agent_context.s3_client = s3_client
|
|
23
|
+
_agent_context.bucket_path = bucket_path
|
|
24
|
+
_agent_context.execution_id = execution_id
|
|
25
|
+
|
|
26
|
+
def get_agent_context():
|
|
27
|
+
"""Get thread-local context for agent execution"""
|
|
28
|
+
return {
|
|
29
|
+
's3_client': getattr(_agent_context, 's3_client', None),
|
|
30
|
+
'bucket_path': getattr(_agent_context, 'bucket_path', None),
|
|
31
|
+
'execution_id': getattr(_agent_context, 'execution_id', None)
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
def clear_agent_context():
|
|
35
|
+
"""Clear thread-local context to prevent memory leaks"""
|
|
36
|
+
for attr in ['s3_client', 'bucket_path', 'execution_id']:
|
|
37
|
+
if hasattr(_agent_context, attr):
|
|
38
|
+
delattr(_agent_context, attr)
|
|
39
|
+
|
|
15
40
|
class CustomPythonModel(mlflow.pyfunc.PythonModel):
|
|
16
41
|
def __init__(self):
|
|
17
42
|
self.agent = None
|
|
@@ -53,27 +78,53 @@ class CustomPythonModel(mlflow.pyfunc.PythonModel):
|
|
|
53
78
|
except ImportError as e:
|
|
54
79
|
raise ImportError(f"Failed to import agent_code: {e}")
|
|
55
80
|
|
|
56
|
-
def predict(self, context, model_input):
|
|
81
|
+
def predict(self, context, model_input: list[str]):
|
|
57
82
|
if isinstance(model_input, list):
|
|
58
|
-
|
|
83
|
+
parsed_data = json.loads(model_input[0])
|
|
84
|
+
return self.run(parsed_data, callback_enabled=True)
|
|
59
85
|
else:
|
|
60
|
-
return self.
|
|
61
|
-
|
|
86
|
+
return self.run(model_input)
|
|
62
87
|
class AgentChainWrapper:
|
|
63
|
-
def __init__(self, chain_class = SimpleSequentialChain, agent_functions_list = []):
|
|
88
|
+
def __init__(self, chain_class = SimpleSequentialChain, agent_functions_list = [], callback_enabled = False):
|
|
64
89
|
self.chain_class = chain_class
|
|
65
90
|
self.agents = [func() for func in agent_functions_list]
|
|
66
91
|
self.agent_functions = agent_functions_list
|
|
92
|
+
self.callback_enabled = callback_enabled
|
|
67
93
|
|
|
68
94
|
def _wrap_agent_runnable(self, agent) -> RunnableLambda:
|
|
69
95
|
"""
|
|
70
96
|
Wraps the agent's .run() method into a RunnableLambda with a custom function name.
|
|
71
97
|
Properly propagates errors instead of continuing to the next agent.
|
|
72
98
|
"""
|
|
73
|
-
def base_fn(
|
|
99
|
+
def base_fn(
|
|
100
|
+
x: Dict[str, Any],
|
|
101
|
+
*,
|
|
102
|
+
run_manager: CallbackManagerForChainRun = None,
|
|
103
|
+
):
|
|
74
104
|
# Run the agent, but don't catch exceptions - let them propagate
|
|
75
105
|
# This will stop the entire pipeline on agent failure
|
|
76
|
-
|
|
106
|
+
if run_manager:
|
|
107
|
+
run_manager.on_text(f"[{agent.__class__.__name__}] starting…")
|
|
108
|
+
|
|
109
|
+
# If your agent supports .invoke, prefer it; otherwise fall back to .run
|
|
110
|
+
try:
|
|
111
|
+
# Propagate callbacks to the inner agent call too (if it’s a Runnable)
|
|
112
|
+
if hasattr(agent, "invoke"):
|
|
113
|
+
sig = inspect.signature(agent.invoke)
|
|
114
|
+
if "config" in sig.parameters and self.callback_enabled:
|
|
115
|
+
out = agent.invoke(
|
|
116
|
+
x,
|
|
117
|
+
config={"callbacks": [OutputAgentStatus()]}
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
out = agent.invoke(x)
|
|
121
|
+
else:
|
|
122
|
+
out = agent.run(x) # legacy agents
|
|
123
|
+
finally:
|
|
124
|
+
if run_manager:
|
|
125
|
+
run_manager.on_text(f"[{agent.__class__.__name__}] finished.")
|
|
126
|
+
|
|
127
|
+
return out
|
|
77
128
|
|
|
78
129
|
# Check if result starts with "An error occurred" which indicates agent failure
|
|
79
130
|
# if isinstance(result, str) and result.startswith("An error occurred during execution:"):
|
|
@@ -94,7 +145,9 @@ class AgentChainWrapper:
|
|
|
94
145
|
|
|
95
146
|
return RunnableLambda(runnable_fn)
|
|
96
147
|
|
|
97
|
-
def run(self, query):
|
|
148
|
+
def run(self, query, callback_enabled: Optional[bool] = False):
|
|
149
|
+
if callback_enabled:
|
|
150
|
+
self.callback_enabled = callback_enabled
|
|
98
151
|
result = query
|
|
99
152
|
def is_dataframe(obj) -> bool:
|
|
100
153
|
try:
|
|
@@ -106,7 +159,14 @@ class AgentChainWrapper:
|
|
|
106
159
|
runnables = []
|
|
107
160
|
for agent in self.agents:
|
|
108
161
|
if isinstance(agent, Runnable):
|
|
109
|
-
|
|
162
|
+
# Check if agent supports with_config (Runnable style)
|
|
163
|
+
if hasattr(agent, "with_config") and self.callback_enabled:
|
|
164
|
+
# Inject default callbacks on the agent itself
|
|
165
|
+
enhanced = agent.with_config({"callbacks": [OutputAgentStatus()]})
|
|
166
|
+
runnables.append(enhanced)
|
|
167
|
+
else:
|
|
168
|
+
# Not a Runnable — wrap with your fallback wrapper
|
|
169
|
+
runnables.append(agent)
|
|
110
170
|
else:
|
|
111
171
|
runnables.append(
|
|
112
172
|
self._wrap_agent_runnable(agent)
|
|
@@ -123,8 +183,12 @@ class AgentChainWrapper:
|
|
|
123
183
|
)
|
|
124
184
|
return chain.run(result)
|
|
125
185
|
|
|
126
|
-
def predict(self, context = "", model_input = ""):
|
|
127
|
-
|
|
186
|
+
def predict(self, context = "", model_input: list[str] = [""]):
|
|
187
|
+
if isinstance(model_input, list):
|
|
188
|
+
parsed_data = json.loads(model_input[0])
|
|
189
|
+
return self.run(parsed_data, callback_enabled=True)
|
|
190
|
+
else:
|
|
191
|
+
return self.run(model_input)
|
|
128
192
|
|
|
129
193
|
class CustomChainModel(mlflow.pyfunc.PythonModel):
|
|
130
194
|
def __init__(self):
|
|
@@ -191,11 +255,12 @@ class CustomChainModel(mlflow.pyfunc.PythonModel):
|
|
|
191
255
|
agent_functions.append(agent_function)
|
|
192
256
|
|
|
193
257
|
# Create the agent chain
|
|
194
|
-
self.agent_chain = AgentChainWrapper(agent_functions_list=agent_functions)
|
|
258
|
+
self.agent_chain = AgentChainWrapper(agent_functions_list=agent_functions, callback_enabled=True)
|
|
195
259
|
|
|
196
|
-
def predict(self, context, model_input):
|
|
260
|
+
def predict(self, context, model_input: list[str]):
|
|
197
261
|
if isinstance(model_input, list):
|
|
198
|
-
|
|
262
|
+
parsed_data = json.loads(model_input[0])
|
|
263
|
+
return self.agent_chain.run(parsed_data)
|
|
199
264
|
else:
|
|
200
265
|
return self.agent_chain.run(model_input)
|
|
201
266
|
|
|
@@ -238,7 +303,7 @@ class CrewAgentWrapper:
|
|
|
238
303
|
print(f"Error running CrewAI crew: {e}")
|
|
239
304
|
return f"Error executing query '{query}': {str(e)}"
|
|
240
305
|
|
|
241
|
-
def predict(self, context, model_input):
|
|
306
|
+
def predict(self, context, model_input: list[str]):
|
|
242
307
|
return self.run(model_input)
|
|
243
308
|
|
|
244
309
|
class CustomCrewModel(mlflow.pyfunc.PythonModel):
|
|
@@ -281,8 +346,107 @@ class CustomCrewModel(mlflow.pyfunc.PythonModel):
|
|
|
281
346
|
except ImportError as e:
|
|
282
347
|
raise ImportError(f"Failed to import CrewAI agent_code: {e}")
|
|
283
348
|
|
|
284
|
-
def predict(self, context, model_input):
|
|
349
|
+
def predict(self, context, model_input: list[str]):
|
|
285
350
|
if isinstance(model_input, list):
|
|
286
351
|
return [self.agent.run(query) for query in model_input]
|
|
287
352
|
else:
|
|
288
|
-
return self.agent.run(model_input)
|
|
353
|
+
return self.agent.run(model_input)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class OutputAgentStatus(BaseCallbackHandler):
|
|
357
|
+
def __init__(self, s3_client=None, bucket_path=None):
|
|
358
|
+
super().__init__()
|
|
359
|
+
self.s3_client = s3_client
|
|
360
|
+
self.bucket_path = bucket_path
|
|
361
|
+
|
|
362
|
+
# If not provided, try to get from global context
|
|
363
|
+
if not self.s3_client or not self.bucket_path:
|
|
364
|
+
context = get_agent_context()
|
|
365
|
+
self.s3_client = self.s3_client or context.get('s3_client')
|
|
366
|
+
self.bucket_path = self.bucket_path or context.get('bucket_path')
|
|
367
|
+
|
|
368
|
+
def save_to_s3(self, status, message):
|
|
369
|
+
if not self.s3_client or not self.bucket_path:
|
|
370
|
+
return
|
|
371
|
+
|
|
372
|
+
# Validate execution_id is in bucket_path to prevent cross-execution contamination
|
|
373
|
+
import os
|
|
374
|
+
execution_id = os.environ.get('EXECUTION_ID')
|
|
375
|
+
if execution_id and execution_id not in self.bucket_path:
|
|
376
|
+
print(f"Warning: Execution ID {execution_id} not found in bucket path {self.bucket_path}. Skipping S3 write.")
|
|
377
|
+
return
|
|
378
|
+
|
|
379
|
+
try:
|
|
380
|
+
import os
|
|
381
|
+
import json
|
|
382
|
+
from datetime import datetime
|
|
383
|
+
|
|
384
|
+
bucket_name = os.environ.get('DEPLOYMENTS_BUCKET_NAME', 'local-development-deployments')
|
|
385
|
+
status_key = f"{self.bucket_path}/last_output.json"
|
|
386
|
+
|
|
387
|
+
# Try to read existing last_output.json
|
|
388
|
+
existing_data = {}
|
|
389
|
+
try:
|
|
390
|
+
response = self.s3_client.get_object(Bucket=bucket_name, Key=status_key)
|
|
391
|
+
existing_data = json.loads(response['Body'].read().decode('utf-8'))
|
|
392
|
+
except self.s3_client.exceptions.NoSuchKey:
|
|
393
|
+
# File doesn't exist yet, start with empty data
|
|
394
|
+
pass
|
|
395
|
+
except Exception as e:
|
|
396
|
+
print(f"Warning: Could not read existing last_output.json: {e}")
|
|
397
|
+
|
|
398
|
+
# Update the status and output fields
|
|
399
|
+
existing_data.update({
|
|
400
|
+
"status": status,
|
|
401
|
+
"output": message
|
|
402
|
+
})
|
|
403
|
+
|
|
404
|
+
# Save updated last_output.json
|
|
405
|
+
self.s3_client.put_object(
|
|
406
|
+
Bucket=bucket_name,
|
|
407
|
+
Key=status_key,
|
|
408
|
+
Body=json.dumps(existing_data),
|
|
409
|
+
ContentType='application/json'
|
|
410
|
+
)
|
|
411
|
+
except Exception as e:
|
|
412
|
+
print(f"Error saving status to S3: {e}")
|
|
413
|
+
|
|
414
|
+
def on_text(self, text: str, **kwargs):
|
|
415
|
+
self.save_to_s3("text_output", f"Agent generated text: {text}")
|
|
416
|
+
|
|
417
|
+
def on_llm_start(self, serialized, prompts, **kwargs):
|
|
418
|
+
model_name = serialized.get('name', 'Unknown')
|
|
419
|
+
self.save_to_s3("llm_start", f"Thinking")
|
|
420
|
+
|
|
421
|
+
def on_llm_new_token(self, token: str, **kwargs):
|
|
422
|
+
self.save_to_s3("llm_streaming", f"LLM generating response token: {token}")
|
|
423
|
+
|
|
424
|
+
def on_llm_end(self, response, **kwargs):
|
|
425
|
+
token_count = getattr(response, 'llm_output', {}).get('token_usage', {}).get('total_tokens', 'unknown')
|
|
426
|
+
self.save_to_s3("llm_complete", f"LLM completed response generation (tokens: {token_count})")
|
|
427
|
+
|
|
428
|
+
def on_tool_start(self, serialized, input_str: str, **kwargs):
|
|
429
|
+
tool_name = serialized.get('name', 'Unknown Tool')
|
|
430
|
+
self.save_to_s3("tool_start", f"Executing tool: {tool_name} with input: {input_str[:100]}")
|
|
431
|
+
|
|
432
|
+
def on_tool_end(self, output: str, **kwargs):
|
|
433
|
+
output_preview = str(output)[:100] if len(str(output)) > 100 else str(output)
|
|
434
|
+
self.save_to_s3("tool_complete", f"Tool execution completed with output: {output_preview}")
|
|
435
|
+
|
|
436
|
+
def on_chain_start(self, serialized, inputs, **kwargs):
|
|
437
|
+
chain_id = serialized.get('id', 'Unknown Chain')
|
|
438
|
+
self.save_to_s3("chain_start", f"Starting chain execution: {chain_id}")
|
|
439
|
+
|
|
440
|
+
def on_chain_end(self, outputs, **kwargs):
|
|
441
|
+
output_preview = str(outputs)[:100] if len(str(outputs)) > 100 else str(outputs)
|
|
442
|
+
self.save_to_s3("chain_complete", f"Chain execution completed with outputs: {output_preview}")
|
|
443
|
+
|
|
444
|
+
def on_agent_action(self, action, **kwargs):
|
|
445
|
+
action_tool = getattr(action, 'tool', 'Unknown')
|
|
446
|
+
action_input = getattr(action, 'tool_input', '')
|
|
447
|
+
self.save_to_s3("agent_action", f"Agent taking action with tool: {action_tool}, input: {str(action_input)[:100]}")
|
|
448
|
+
|
|
449
|
+
def on_agent_finish(self, finish, **kwargs):
|
|
450
|
+
return_values = getattr(finish, 'return_values', {})
|
|
451
|
+
output_preview = str(return_values)[:100] if len(str(return_values)) > 100 else str(return_values)
|
|
452
|
+
self.save_to_s3("agent_complete", f"Agent execution finished with result: {output_preview}")
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
-
from typing import List, Any, Callable, Dict, Union
|
|
3
|
+
from typing import List, Any, Callable, Dict, Optional, Type, Union
|
|
4
4
|
from langchain_core.runnables import RunnableLambda, RunnableSequence
|
|
5
5
|
from langchain.chains import SimpleSequentialChain
|
|
6
6
|
import logging
|
|
7
7
|
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
8
10
|
import mlflow
|
|
9
11
|
|
|
10
12
|
from mlflow.models.signature import ModelSignature
|
|
@@ -16,7 +18,7 @@ from fairo.core.execution.model_log_helper import ModelLogHelper
|
|
|
16
18
|
from fairo.core.runnable.runnable import Runnable
|
|
17
19
|
from fairo.core.workflow.utils import output_langchain_process_graph
|
|
18
20
|
from fairo.settings import get_fairo_api_key, get_fairo_api_secret, get_mlflow_experiment_name, get_mlflow_server, get_fairo_base_url
|
|
19
|
-
|
|
21
|
+
from fairo.core.tools import ChatSuggestions
|
|
20
22
|
|
|
21
23
|
logger = logging.getLogger(__name__)
|
|
22
24
|
|
|
@@ -32,18 +34,24 @@ class FairoExecutor:
|
|
|
32
34
|
experiment_name: str = None,
|
|
33
35
|
chain_class = SimpleSequentialChain,
|
|
34
36
|
input_fields: List[str] = [],
|
|
37
|
+
input_schema: Optional[Type[BaseModel]] = None,
|
|
38
|
+
chat_suggestions: Optional[ChatSuggestions] = None,
|
|
39
|
+
debug_mode: bool = False
|
|
35
40
|
):
|
|
36
41
|
if agents and runnable:
|
|
37
42
|
raise ValueError("FairoExecutor cannot be initialized with both 'agents' and 'runnable'. Please provide only one.")
|
|
38
|
-
if not input_fields:
|
|
39
|
-
raise ValueError("Missing input_fields")
|
|
43
|
+
if not input_fields and not input_schema:
|
|
44
|
+
raise ValueError("Missing required parameters: please provide at least one of 'input_fields' or 'input_schema'")
|
|
45
|
+
self.input_schema = input_schema
|
|
40
46
|
self.agents = agents
|
|
41
47
|
self.agent_type = agent_type
|
|
42
48
|
self.verbose = verbose
|
|
49
|
+
self.debug_mode = debug_mode
|
|
43
50
|
self.patch_run_output_json = patch_run_output_json
|
|
44
51
|
self.workflow_run_id = workflow_run_id
|
|
45
52
|
self.runnable = runnable
|
|
46
53
|
self.experiment_name = experiment_name if experiment_name else get_mlflow_experiment_name()
|
|
54
|
+
self._setup_logging()
|
|
47
55
|
self.setup_mlflow()
|
|
48
56
|
self.chain_class = chain_class
|
|
49
57
|
self.client = BaseClient(
|
|
@@ -51,6 +59,7 @@ class FairoExecutor:
|
|
|
51
59
|
password=get_fairo_api_secret(),
|
|
52
60
|
username=get_fairo_api_key()
|
|
53
61
|
)
|
|
62
|
+
self.chat_suggestions = chat_suggestions
|
|
54
63
|
self.input_fields = input_fields
|
|
55
64
|
# Inject shared attributes into agents
|
|
56
65
|
for agent in self.agents:
|
|
@@ -59,6 +68,14 @@ class FairoExecutor:
|
|
|
59
68
|
if hasattr(agent, 'verbose'):
|
|
60
69
|
agent.verbose = self.verbose
|
|
61
70
|
|
|
71
|
+
def _setup_logging(self):
|
|
72
|
+
"""Configure MLflow logging level based on debug_mode."""
|
|
73
|
+
mlflow_logger = logging.getLogger('mlflow')
|
|
74
|
+
if not self.debug_mode:
|
|
75
|
+
import warnings
|
|
76
|
+
warnings.filterwarnings('ignore', category=UserWarning)
|
|
77
|
+
mlflow_logger.setLevel(logging.ERROR)
|
|
78
|
+
|
|
62
79
|
def _build_pipeline(self) -> RunnableSequence:
|
|
63
80
|
if not self.agents and not self.runnable:
|
|
64
81
|
raise ValueError("At least one agent or runnable must be provided.")
|
|
@@ -67,8 +84,26 @@ class FairoExecutor:
|
|
|
67
84
|
pipeline = mlflow.pyfunc.load_model(self.runnable.artifact_path)
|
|
68
85
|
else:
|
|
69
86
|
pipeline = AgentChainWrapper(chain_class=self.chain_class, agent_functions_list=self.agents)
|
|
70
|
-
|
|
71
|
-
input_schema
|
|
87
|
+
# Convert Pydantic schema to MLflow Schema
|
|
88
|
+
if hasattr(self.input_schema, 'model_json_schema'):
|
|
89
|
+
# Extract field names from Pydantic schema
|
|
90
|
+
pydantic_schema = self.input_schema.model_json_schema()
|
|
91
|
+
properties = pydantic_schema.get('properties', {})
|
|
92
|
+
cols = []
|
|
93
|
+
for field_name, field_info in properties.items():
|
|
94
|
+
field_type = field_info.get('type', 'string')
|
|
95
|
+
# Map Pydantic types to MLflow types
|
|
96
|
+
mlflow_type = 'string' # Default to string
|
|
97
|
+
if field_type in ['integer', 'number']:
|
|
98
|
+
mlflow_type = 'double'
|
|
99
|
+
elif field_type == 'boolean':
|
|
100
|
+
mlflow_type = 'boolean'
|
|
101
|
+
cols.append(ColSpec(type=mlflow_type, name=field_name))
|
|
102
|
+
input_schema = Schema(cols)
|
|
103
|
+
else:
|
|
104
|
+
# Fallback to input_fields if schema is not Pydantic
|
|
105
|
+
cols = [ColSpec(type="string", name=field) for field in self.input_fields]
|
|
106
|
+
input_schema = Schema(cols)
|
|
72
107
|
|
|
73
108
|
output_schema = Schema([
|
|
74
109
|
ColSpec(type="string", name="output"),
|
|
@@ -97,7 +132,9 @@ class FairoExecutor:
|
|
|
97
132
|
fairo_settings = {
|
|
98
133
|
"type": type,
|
|
99
134
|
"process_graph": process_graph,
|
|
100
|
-
"
|
|
135
|
+
"schema": self.input_schema.model_json_schema() if self.input_schema else None,
|
|
136
|
+
"input_fields": list(self.input_schema.model_fields.keys()) if self.input_schema else self.input_fields,
|
|
137
|
+
"chat_suggestions": self.chat_suggestions.model_dump() if self.chat_suggestions else None,
|
|
101
138
|
}
|
|
102
139
|
if process_graph:
|
|
103
140
|
mlflow.log_text(json.dumps(fairo_settings, ensure_ascii=False, indent=2), artifact_file="fairo_settings.txt")
|