fairo 25.6.5__tar.gz → 25.7.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {fairo-25.6.5 → fairo-25.7.2}/PKG-INFO +4 -2
  2. fairo-25.7.2/fairo/__init__.py +1 -0
  3. fairo-25.7.2/fairo/core/chat/__init__.py +1 -0
  4. fairo-25.7.2/fairo/core/chat/chat.py +227 -0
  5. fairo-25.7.2/fairo/core/execution/agent_serializer.py +288 -0
  6. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/execution/executor.py +50 -97
  7. fairo-25.7.2/fairo/core/execution/model_log_helper.py +409 -0
  8. fairo-25.7.2/fairo/core/workflow/utils.py +460 -0
  9. {fairo-25.6.5 → fairo-25.7.2}/fairo/settings.py +0 -20
  10. {fairo-25.6.5 → fairo-25.7.2}/fairo.egg-info/PKG-INFO +4 -2
  11. {fairo-25.6.5 → fairo-25.7.2}/fairo.egg-info/SOURCES.txt +2 -0
  12. {fairo-25.6.5 → fairo-25.7.2}/fairo.egg-info/requires.txt +3 -1
  13. {fairo-25.6.5 → fairo-25.7.2}/pyproject.toml +4 -2
  14. fairo-25.6.5/fairo/__init__.py +0 -1
  15. fairo-25.6.5/fairo/core/chat/chat.py +0 -23
  16. fairo-25.6.5/fairo/core/workflow/utils.py +0 -191
  17. fairo-25.6.5/fairo/tests/__init__.py +0 -0
  18. {fairo-25.6.5 → fairo-25.7.2}/README.md +0 -0
  19. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/__init__.py +0 -0
  20. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/__init__.py +0 -0
  21. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/base_agent.py +0 -0
  22. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/code_analysis_agent.py +0 -0
  23. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/output/__init__.py +0 -0
  24. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/output/base_output.py +0 -0
  25. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/output/google_drive.py +0 -0
  26. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/tools/__init__.py +0 -0
  27. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/tools/base_tools.py +0 -0
  28. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/tools/code_analysis.py +0 -0
  29. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/tools/utils.py +0 -0
  30. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/agent/utils.py +0 -0
  31. {fairo-25.6.5/fairo/core/chat → fairo-25.7.2/fairo/core/client}/__init__.py +0 -0
  32. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/client/client.py +0 -0
  33. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/exceptions.py +0 -0
  34. {fairo-25.6.5/fairo/core/client → fairo-25.7.2/fairo/core/execution}/__init__.py +0 -0
  35. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/execution/env_finder.py +0 -0
  36. {fairo-25.6.5/fairo/core/execution → fairo-25.7.2/fairo/core/models}/__init__.py +0 -0
  37. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/models/custom_field_value.py +0 -0
  38. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/models/resources.py +0 -0
  39. {fairo-25.6.5/fairo/core/models → fairo-25.7.2/fairo/core/runnable}/__init__.py +0 -0
  40. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/runnable/runnable.py +0 -0
  41. {fairo-25.6.5/fairo/core/runnable → fairo-25.7.2/fairo/core/workflow}/__init__.py +0 -0
  42. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/workflow/base_workflow.py +0 -0
  43. {fairo-25.6.5 → fairo-25.7.2}/fairo/core/workflow/dependency.py +0 -0
  44. {fairo-25.6.5/fairo/core/workflow → fairo-25.7.2/fairo/metrics}/__init__.py +0 -0
  45. {fairo-25.6.5 → fairo-25.7.2}/fairo/metrics/fairness_object.py +0 -0
  46. {fairo-25.6.5 → fairo-25.7.2}/fairo/metrics/metrics.py +0 -0
  47. {fairo-25.6.5/fairo/metrics → fairo-25.7.2/fairo/tests}/__init__.py +0 -0
  48. {fairo-25.6.5 → fairo-25.7.2}/fairo/tests/test_metrics.py +0 -0
  49. {fairo-25.6.5 → fairo-25.7.2}/fairo.egg-info/dependency_links.txt +0 -0
  50. {fairo-25.6.5 → fairo-25.7.2}/fairo.egg-info/top_level.txt +0 -0
  51. {fairo-25.6.5 → fairo-25.7.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fairo
3
- Version: 25.6.5
3
+ Version: 25.7.2
4
4
  Summary: SDK for interfacing with Fairo SaaS platform.
5
5
  Author-email: "Fairo Systems, Inc." <support@fairo.ai>
6
6
  License: Apache-2.0
@@ -11,7 +11,7 @@ Classifier: Programming Language :: Python :: 3.10
11
11
  Classifier: License :: OSI Approved :: Apache Software License
12
12
  Classifier: Operating System :: OS Independent
13
13
  Description-Content-Type: text/markdown
14
- Requires-Dist: mlflow<3.0.0,>=2.21.0
14
+ Requires-Dist: mlflow<=3.1.1,>=3.1.0
15
15
  Requires-Dist: langchain<0.4.0,>=0.3.20
16
16
  Requires-Dist: langchain-aws<0.3.0,>=0.2.18
17
17
  Requires-Dist: langchain-community<0.4.0,>=0.3.20
@@ -19,6 +19,8 @@ Requires-Dist: langchain-core<0.4.0,>=0.3.49
19
19
  Requires-Dist: langchain-text-splitters<0.4.0,>=0.3.7
20
20
  Requires-Dist: psycopg2-binary<3.0.0,>=2.9.0
21
21
  Requires-Dist: langchain-postgres<0.1.0,>=0.0.14
22
+ Requires-Dist: setuptools>=79.0.0
23
+ Requires-Dist: pandas<3.0.0,>=2.0.0
22
24
 
23
25
  # Fairo SDK
24
26
 
@@ -0,0 +1 @@
1
+ __version__ = "25.7.2"
@@ -0,0 +1 @@
1
+ from .chat import ChatFairo
@@ -0,0 +1,227 @@
1
+
2
+ from langchain_community.chat_models.mlflow import ChatMlflow
3
+ from mlflow.deployments import get_deploy_client
4
+ from mlflow.deployments.base import BaseDeploymentClient
5
+ from fairo.settings import get_mlflow_gateway_chat_route, get_mlflow_gateway_uri, get_mlflow_user, get_mlflow_password
6
+ import requests
7
+ from requests.auth import HTTPBasicAuth
8
+ import json
9
+ import os
10
+
11
class FairoDeploymentClient(BaseDeploymentClient):
    """Deployment client that calls Fairo gateway endpoints over HTTP.

    Only ``predict`` and ``predict_stream`` are implemented; every other
    ``BaseDeploymentClient`` operation raises ``NotImplementedError``.

    Requests are sent to ``<target_uri>/gateway/<endpoint>/invocations``.
    When both ``MLFLOW_TRACKING_USERNAME`` and ``MLFLOW_TRACKING_PASSWORD`` are
    set in the environment they are sent as HTTP basic auth.

    Args:
        target_uri: Base URI of the gateway.
        endpoint: Default endpoint name, used when a call does not pass one.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). The default ``None`` waits
            indefinitely, matching the previous behaviour.
    """

    def __init__(self, target_uri: str, endpoint: str, timeout=None):
        super().__init__(target_uri)
        self.target_uri = target_uri
        self.endpoint = endpoint
        self.timeout = timeout

    def _gateway_url(self, endpoint):
        """Return the invocation URL of ``endpoint`` on this gateway."""
        return f"{self.target_uri.rstrip('/')}/gateway/{endpoint}/invocations"

    @staticmethod
    def _auth():
        """Return basic auth built from the MLflow tracking env vars, or None."""
        username = os.environ.get('MLFLOW_TRACKING_USERNAME')
        password = os.environ.get('MLFLOW_TRACKING_PASSWORD')
        if username and password:
            return HTTPBasicAuth(username, password)
        return None

    @staticmethod
    def _non_streaming_fallback(lines):
        """Parse ``lines`` as one JSON document, or return an empty stop chunk."""
        try:
            return json.loads("\n".join(lines))
        except ValueError:
            # Empty or non-JSON body: emit a minimal terminal chunk so callers
            # always receive at least one chunk.
            return {
                "choices": [{
                    "delta": {"content": "", "role": "assistant"},
                    "finish_reason": "stop"
                }]
            }

    def predict_stream(self, deployment_name=None, inputs=None, endpoint=None):
        """Stream prediction chunks from the gateway.

        Sends ``inputs`` with ``"stream": True`` and yields every parsed JSON
        chunk. Both server-sent-event lines (``data: <json>``) and bare JSON
        lines are accepted. A ``data: [DONE]`` line ends the stream, and lines
        that are not valid JSON are skipped.

        If no chunk could be parsed, the lines received are parsed as a single
        JSON document and yielded. If that also fails, a minimal empty
        assistant chunk with ``finish_reason="stop"`` is yielded instead.

        Args:
            deployment_name: Unused; accepted for interface compatibility.
            inputs: Mapping merged into the JSON request body. ``None`` is
                treated as an empty mapping.
            endpoint: Endpoint name; defaults to ``self.endpoint``.

        Yields:
            Parsed response chunks.

        Raises:
            Exception: On a non-200 response, or when ``requests`` raises a
                ``RequestException`` (kept as ``__cause__``).
        """
        endpoint = endpoint or self.endpoint
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'text/event-stream'
        }
        try:
            # stream=True makes iter_lines() yield lines as they arrive instead
            # of after the whole body is downloaded. The ``with`` releases the
            # connection even if the consumer stops iterating early.
            with requests.post(
                self._gateway_url(endpoint),
                json={**(inputs or {}), "stream": True},
                headers=headers,
                auth=self._auth(),
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    raise Exception(f"HTTP {response.status_code}: {response.text}")

                chunk_count = 0
                # Lines seen before the first parsed chunk; they feed the
                # whole-body fallback. With stream=True the body cannot be
                # re-read via response.json() after iteration.
                pending_lines = []
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue
                    line = raw_line.decode('utf-8')
                    if chunk_count == 0:
                        pending_lines.append(line)

                    if line.startswith('data: '):
                        payload = line[len('data: '):]
                        if payload.strip() == '[DONE]':
                            break
                    else:
                        payload = line

                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue
                    chunk_count += 1
                    # Yield outside the try so a consumer closing the generator
                    # (GeneratorExit) is never swallowed.
                    yield data

                if chunk_count == 0:
                    yield self._non_streaming_fallback(pending_lines)

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request failed: {e}") from e

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """Send a synchronous prediction request and return the JSON response.

        Args:
            deployment_name: Unused; accepted for interface compatibility.
            inputs: JSON request body.
            endpoint: Endpoint name; defaults to ``self.endpoint``.

        Raises:
            Exception: On a non-200 response. Transport errors from
                ``requests`` propagate unchanged.
        """
        endpoint = endpoint or self.endpoint
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        response = requests.post(
            self._gateway_url(endpoint),
            json=inputs,
            headers=headers,
            auth=self._auth(),
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"HTTP {response.status_code}: {response.text}")
        return response.json()

    def get_deployment(self, name, endpoint=None):
        """Get deployment information."""
        raise NotImplementedError("get_deployment not implemented")

    def list_deployments(self, endpoint=None):
        """List available deployments."""
        raise NotImplementedError("list_deployments not implemented")

    def get_endpoint(self, endpoint):
        """Get endpoint information."""
        raise NotImplementedError("get_endpoint not implemented")

    def list_endpoints(self):
        """List available endpoints."""
        raise NotImplementedError("list_endpoints not implemented")

    def create_deployment(self, name, config, endpoint=None):
        """Create a new deployment."""
        raise NotImplementedError("create_deployment not implemented")

    def update_deployment(self, name, config, endpoint=None):
        """Update an existing deployment."""
        raise NotImplementedError("update_deployment not implemented")

    def delete_deployment(self, name, endpoint=None):
        """Delete a deployment."""
        raise NotImplementedError("delete_deployment not implemented")

    def create_endpoint(self, name, config):
        """Create a new endpoint."""
        raise NotImplementedError("create_endpoint not implemented")

    def update_endpoint(self, name, config):
        """Update an existing endpoint."""
        raise NotImplementedError("update_endpoint not implemented")

    def delete_endpoint(self, name):
        """Delete an endpoint."""
        raise NotImplementedError("delete_endpoint not implemented")
177
+
178
+
179
class ChatFairo(ChatMlflow):
    """LangChain chat model whose calls go through ``FairoDeploymentClient``.

    Defaults for the connection settings:

    * ``target_uri`` comes from the ``MLFLOW_GATEWAY_URI`` environment
      variable, falling back to ``get_mlflow_gateway_uri()``.
    * ``endpoint`` comes from ``MLFLOW_GATEWAY_ROUTE``, falling back to
      ``get_mlflow_gateway_chat_route()``.

    Either may also be passed explicitly. All other keyword arguments are
    forwarded to ``ChatMlflow``.
    """

    def __init__(self, **kwargs):
        # Fill in defaults only when the caller did not supply a value.
        # Supplying one previously caused a duplicate-keyword TypeError.
        if "target_uri" not in kwargs:
            kwargs["target_uri"] = os.environ.get('MLFLOW_GATEWAY_URI', get_mlflow_gateway_uri())
        if "endpoint" not in kwargs:
            kwargs["endpoint"] = os.environ.get('MLFLOW_GATEWAY_ROUTE', get_mlflow_gateway_chat_route())
        super().__init__(**kwargs)

        # Route all requests through the Fairo-specific client.
        self._client = FairoDeploymentClient(self.target_uri, self.endpoint)

    @property
    def _target_uri(self):
        """Current value of ``MLFLOW_GATEWAY_URI``, or ``None`` if unset."""
        return os.environ.get("MLFLOW_GATEWAY_URI", None)

    @property
    def _endpoint(self):
        """Current value of ``MLFLOW_GATEWAY_ROUTE``, or ``None`` if unset."""
        return os.environ.get("MLFLOW_GATEWAY_ROUTE", None)

    def invoke(self, *args, **kwargs):
        """Invoke the model, re-reading the gateway URI from the environment first."""
        # Pick up MLFLOW_GATEWAY_URI changes made after construction. When the
        # variable is unset, keep the URI resolved in __init__ rather than
        # replacing it with None, which FairoDeploymentClient cannot build a
        # URL from.
        self.target_uri = self._target_uri or self.target_uri
        self._client = FairoDeploymentClient(self.target_uri, self.endpoint)
        return super().invoke(*args, **kwargs)
208
+
209
+
210
class FairoChat(ChatMlflow):
    """``ChatMlflow`` bound to a specific gateway ``endpoint``.

    ``target_uri`` defaults to the ``MLFLOW_GATEWAY_URI`` environment variable
    and may also be passed explicitly. The variable is re-read on every
    ``invoke``, and the client is rebuilt with MLflow's ``get_deploy_client``.
    """

    def __init__(self, endpoint, **kwargs):
        # Only default target_uri when the caller did not pass one, so an
        # explicit value no longer raises a duplicate-keyword TypeError.
        kwargs.setdefault("target_uri", os.environ.get('MLFLOW_GATEWAY_URI', None))
        super().__init__(endpoint=endpoint, **kwargs)

    @property
    def _target_uri(self):
        """Current value of ``MLFLOW_GATEWAY_URI``, or ``None`` if unset."""
        return os.environ.get("MLFLOW_GATEWAY_URI", None)

    def invoke(self, *args, **kwargs):
        """Invoke the model, re-reading the gateway URI from the environment first."""
        # Keep the existing URI when the variable has since been unset,
        # instead of overwriting it with None.
        self.target_uri = self._target_uri or self.target_uri
        self._client = get_deploy_client(self.target_uri)
        return super().invoke(*args, **kwargs)
@@ -0,0 +1,288 @@
1
+ from typing import Any, Dict
2
+ import mlflow
3
+ import cloudpickle
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ from langchain_core.runnables import RunnableLambda, Runnable
8
+ from langchain.chains import SimpleSequentialChain
9
+ import logging
10
+ import types
11
+ import threading
12
+ import pandas as pd
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class CustomPythonModel(mlflow.pyfunc.PythonModel):
16
+ def __init__(self):
17
+ self.agent = None
18
+
19
+ def __getstate__(self):
20
+ state = self.__dict__.copy()
21
+ state.pop("lock", None)
22
+
23
+ def __setstate__(self, state):
24
+ self.__dict__.update(state)
25
+ self.lock = threading.Lock()
26
+
27
+ def load_context(self, context):
28
+ import sys
29
+ import os
30
+ import shutil
31
+
32
+ agent_code_path = context.model_config["agent_code"]
33
+ agent_code_dir = os.path.dirname(agent_code_path)
34
+
35
+ if agent_code_dir not in sys.path:
36
+ sys.path.insert(0, agent_code_dir)
37
+
38
+ for artifact_name, artifact_path in context.model_config.items():
39
+ if artifact_name.startswith("local_module_"):
40
+ module_name = artifact_name.replace("local_module_", "")
41
+ module_filename = f"{module_name}.py"
42
+ dest_path = os.path.join(agent_code_dir, module_filename)
43
+
44
+ if not os.path.exists(dest_path):
45
+ shutil.copy2(artifact_path, dest_path)
46
+ print(f"Restored local module: {module_name}")
47
+
48
+ try:
49
+ import agent_code
50
+ from agent_code import create_simple_agent
51
+ self.agent_func = create_simple_agent
52
+ self.agent = self.agent_func()
53
+ except ImportError as e:
54
+ raise ImportError(f"Failed to import agent_code: {e}")
55
+
56
+ def predict(self, context, model_input):
57
+ if isinstance(model_input, list):
58
+ return [self.agent.run(query) for query in model_input]
59
+ else:
60
+ return self.agent.run(model_input)
61
+
62
class AgentChainWrapper:
    """Run a sequence of agents as a single chain.

    Each factory in ``agent_functions_list`` is called once, at construction
    time, to build an agent. Agents that are already LangChain ``Runnable``s
    are used directly. Any other agent is wrapped so that its ``invoke``
    method is called.

    Args:
        chain_class: How the agents are combined.
            * ``SimpleSequentialChain`` (the default): the agents are composed
              with ``|`` and invoked directly.
            * Any other class: instantiated as ``chain_class(chains=runnables)``
              and executed with ``.run()``.
        agent_functions_list: Zero-argument factories returning agents.
            ``None`` (the default) means no agents.
    """

    def __init__(self, chain_class=SimpleSequentialChain, agent_functions_list=None):
        # None instead of a shared mutable [] default.
        if agent_functions_list is None:
            agent_functions_list = []
        self.chain_class = chain_class
        self.agents = [func() for func in agent_functions_list]
        self.agent_functions = agent_functions_list

    def _wrap_agent_runnable(self, agent) -> RunnableLambda:
        """Wrap ``agent.invoke`` in a ``RunnableLambda`` named after the agent's class.

        Exceptions raised by the agent are not caught. A failing agent
        therefore stops the whole pipeline instead of feeding an error to the
        next agent.
        """
        def base_fn(inputs: Dict[str, Any]) -> Dict[str, Any]:
            return agent.invoke(inputs)

        # Clone the function under a per-agent name so each wrapped step is
        # distinguishable (the clone's __name__ is what RunnableLambda receives).
        fn_name = f"runnable_{agent.__class__.__name__.lower().replace(' ', '_')}"
        runnable_fn = types.FunctionType(
            base_fn.__code__,
            base_fn.__globals__,
            name=fn_name,
            argdefs=base_fn.__defaults__,
            closure=base_fn.__closure__,
        )
        return RunnableLambda(runnable_fn)

    def run(self, query):
        """Pass ``query`` through every agent and return the final output.

        A ``pandas.DataFrame`` query is replaced by its first row as a dict
        (``to_dict(orient='records')[0]``).

        Raises:
            ValueError: If ``query`` is an empty DataFrame, or if the default
                sequential chain has no agents to run.
        """
        if isinstance(query, pd.DataFrame):
            records = query.to_dict(orient='records')
            if not records:
                raise ValueError("Cannot run the agent chain on an empty DataFrame")
            query = records[0]

        runnables = [
            agent if isinstance(agent, Runnable) else self._wrap_agent_runnable(agent)
            for agent in self.agents
        ]

        if self.chain_class is SimpleSequentialChain:
            if not runnables:
                raise ValueError("AgentChainWrapper has no agents to run")
            pipeline = runnables[0]
            for runnable in runnables[1:]:
                pipeline = pipeline | runnable
            return pipeline.invoke(query)

        chain = self.chain_class(chains=runnables)
        return chain.run(query)

    def predict(self, context="", model_input=""):
        """Pyfunc-style entry point; ``context`` is ignored."""
        return self.run(model_input)
128
+
129
class CustomChainModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that runs several agents through ``AgentChainWrapper``.

    The chain is assembled in ``load_context`` from the ``chain_config``
    artifact, a Python file defining ``CHAIN_CONFIG``.
    """

    def __init__(self):
        self.agent_chain = None
        self.agents = []

    def __getstate__(self):
        """Return the state to pickle.

        The lock is dropped (``__setstate__`` recreates it). The chain is
        reset because ``load_context`` rebuilds it on load; it may also hold
        agents that cannot be pickled.
        """
        state = self.__dict__.copy()
        state.pop("lock", None)
        state["agent_chain"] = None
        # The state must be returned: a None state makes pickle skip
        # __setstate__ on load and silently drops every attribute.
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()

    @staticmethod
    def _load_module(module_name, path):
        """Execute the Python file at ``path`` and return it as a module."""
        import importlib.util

        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module {module_name!r} from {path!r}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def load_context(self, context):
        """Restore helper modules and build the agent chain.

        Steps:

        1. Add the directory of the first artifact to ``sys.path`` and copy
           each ``local_module_<name>`` artifact there as ``<name>.py``
           (unless such a file already exists).
        2. Load ``CHAIN_CONFIG`` from the ``chain_config`` artifact.
        3. For every entry in ``CHAIN_CONFIG["agents"]``, load its code
           artifact and collect the function named by ``function_name``.

        Raises:
            KeyError: If the ``chain_config`` artifact or an agent's code
                artifact is missing.
        """
        import os
        import shutil
        import sys

        artifacts = context.artifacts
        if "chain_config" not in artifacts:
            raise KeyError("CustomChainModel requires a 'chain_config' artifact")

        # NOTE(review): assumes all artifacts live in one directory; the
        # first artifact's directory is used as the import root.
        base_dir = os.path.dirname(next(iter(artifacts.values())))
        if base_dir not in sys.path:
            sys.path.insert(0, base_dir)

        prefix = "local_module_"
        for artifact_name, artifact_path in artifacts.items():
            if not artifact_name.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences inside the module name.
            module_name = artifact_name[len(prefix):]
            dest_path = os.path.join(base_dir, f"{module_name}.py")
            if not os.path.exists(dest_path):
                shutil.copy2(artifact_path, dest_path)
                print(f"Restored local module: {module_name}")

        chain_config = self._load_module("chain_config", artifacts["chain_config"]).CHAIN_CONFIG

        agent_functions = []
        for agent_info in chain_config["agents"]:
            # The artifact key is normally the code file name without ".py".
            # Fall back to "agent_code_<last '_'-separated part of the agent name>".
            primary_key = agent_info["agent_code_file"].replace(".py", "")
            if primary_key in artifacts:
                artifact_key = primary_key
            else:
                artifact_key = f"agent_code_{agent_info['name'].split('_')[-1]}"
                if artifact_key not in artifacts:
                    raise KeyError(
                        f"No code artifact for agent {agent_info['name']!r}: "
                        f"tried {primary_key!r} and {artifact_key!r}"
                    )

            agent_module = self._load_module("agent_module", artifacts[artifact_key])
            agent_functions.append(getattr(agent_module, agent_info["function_name"]))

        self.agent_chain = AgentChainWrapper(agent_functions_list=agent_functions)

    def predict(self, context, model_input):
        """Run the chain on ``model_input``, or on each element if it is a list."""
        if isinstance(model_input, list):
            return [self.agent_chain.run(query) for query in model_input]
        return self.agent_chain.run(model_input)
201
+
202
class CrewAgentWrapper:
    """Expose a CrewAI crew through a simple ``run`` / ``predict`` interface.

    ``agent_func`` only selects where the agent factory is imported from:

    * not ``None`` (logging phase): ``crew_agent.create_crew_agent``;
    * ``None`` (model-loading phase): ``agent_code.create_crew_agent``, or
      ``crew_agent.create_crew_agent`` if that import or call raises
      ImportError.

    The callable passed as ``agent_func`` is never called itself.
    """

    def __init__(self, agent_func=None):
        if agent_func is None:
            # Loading a saved model: prefer the packaged agent_code module.
            try:
                from agent_code import create_crew_agent as make_agent
                self.base_agent = make_agent()
            except ImportError:
                try:
                    from crew_agent import create_crew_agent as make_agent
                    self.base_agent = make_agent()
                except ImportError:
                    raise ImportError("Could not import CrewAI agent")
            return

        # Logging a model: the agent comes from the local crew_agent module.
        try:
            from crew_agent import create_crew_agent as make_agent
            self.base_agent = make_agent()
        except ImportError:
            raise ImportError("Could not import CrewAI agent functions")

    def run(self, query):
        """Kick off a crew for ``query`` and return its result as a string.

        Errors derived from ``Exception`` are not raised. They are printed and
        returned as an ``"Error executing query ..."`` string.
        """
        try:
            if not hasattr(self, 'base_agent'):
                return "Error: Agent not properly initialized"
            # Same lookup order as model loading: agent_code, then crew_agent.
            try:
                from agent_code import create_crew_with_task as build_crew
            except ImportError:
                from crew_agent import create_crew_with_task as build_crew
            crew_output = build_crew(query).kickoff()
            return str(crew_output)
        except Exception as exc:
            print(f"Error running CrewAI crew: {exc}")
            return f"Error executing query '{query}': {str(exc)}"

    def predict(self, context, model_input):
        """Pyfunc-style entry point; ``context`` is ignored."""
        return self.run(model_input)
243
+
244
class CustomCrewModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model wrapping a CrewAI agent.

    The agent is a ``CrewAgentWrapper`` imported from the packaged
    ``agent_code`` module in ``load_context``.
    """

    def __init__(self):
        self.agent = None

    def __getstate__(self):
        """Return the state to pickle.

        The lock is dropped (``__setstate__`` recreates it). The agent is
        reset because ``load_context`` rebuilds it on load; it may also wrap
        objects that cannot be pickled.
        """
        state = self.__dict__.copy()
        state.pop("lock", None)
        state["agent"] = None
        # The state must be returned: a None state makes pickle skip
        # __setstate__ on load and silently drops every attribute.
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()

    def load_context(self, context):
        """Make ``agent_code`` importable and build the CrewAI agent wrapper.

        Steps:

        1. Add the directory of ``context.model_config["agent_code"]`` to
           ``sys.path``.
        2. Copy each ``local_module_<name>`` entry of ``model_config`` to
           ``<name>.py`` in that directory, unless such a file already exists.
        3. Instantiate ``agent_code.CrewAgentWrapper``.

        Raises:
            ImportError: If ``agent_code.CrewAgentWrapper`` cannot be imported,
                or constructing it raises ImportError.
        """
        import os
        import shutil
        import sys

        agent_code_dir = os.path.dirname(context.model_config["agent_code"])
        if agent_code_dir not in sys.path:
            sys.path.insert(0, agent_code_dir)

        prefix = "local_module_"
        for artifact_name, artifact_path in context.model_config.items():
            if not artifact_name.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences inside the module name.
            module_name = artifact_name[len(prefix):]
            dest_path = os.path.join(agent_code_dir, f"{module_name}.py")
            if not os.path.exists(dest_path):
                shutil.copy2(artifact_path, dest_path)
                print(f"Restored local module: {module_name}")

        try:
            # NOTE(review): this imports CrewAgentWrapper from the user's
            # agent_code module, not the class of the same name defined in
            # this file. Presumably agent_code carries a copy; confirm.
            from agent_code import CrewAgentWrapper
            self.agent = CrewAgentWrapper()
        except ImportError as e:
            raise ImportError(f"Failed to import CrewAI agent_code: {e}") from e

    def predict(self, context, model_input):
        """Run the agent on ``model_input``, or on each element if it is a list."""
        if isinstance(model_input, list):
            return [self.agent.run(query) for query in model_input]
        return self.agent.run(model_input)