genesis-flow 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -1,11 +1,14 @@
1
1
  from typing import TYPE_CHECKING, Any, Optional, Union
2
2
 
3
3
  from mlflow.data import Dataset
4
+ from mlflow.data.dataset_source import DatasetSource
4
5
  from mlflow.data.digest_utils import compute_pandas_digest
5
6
  from mlflow.data.evaluation_dataset import EvaluationDataset as LegacyEvaluationDataset
6
7
  from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin
7
- from mlflow.data.spark_dataset_source import SparkDatasetSource
8
8
  from mlflow.entities import Dataset as DatasetEntity
9
+ from mlflow.genai.datasets.databricks_evaluation_dataset_source import (
10
+ DatabricksEvaluationDatasetSource,
11
+ )
9
12
 
10
13
  if TYPE_CHECKING:
11
14
  import pandas as pd
@@ -56,11 +59,11 @@ class EvaluationDataset(Dataset, PyFuncConvertibleDatasetMixin):
56
59
  return self._dataset.profile
57
60
 
58
61
  @property
59
- def source(self) -> Optional[str]:
62
+ def source(self) -> DatasetSource:
60
63
  """Source information for the dataset."""
61
- # NB: The managed Dataset entity in Agent SDK doesn't propagate the source
62
- # information. So we use the table name as the fallback source.
63
- return self._dataset.source or SparkDatasetSource(table_name=self.name)
64
+ if isinstance(self._dataset.source, DatasetSource):
65
+ return self._dataset.source
66
+ return DatabricksEvaluationDatasetSource(table_name=self.name, dataset_id=self.dataset_id)
64
67
 
65
68
  @property
66
69
  def source_type(self) -> Optional[str]:
@@ -35,7 +35,7 @@ class SerializedScorer:
35
35
 
36
36
  # Builtin scorer fields (for scorers from mlflow.genai.scorers.builtin_scorers)
37
37
  builtin_scorer_class: Optional[str] = None
38
- builtin_scorer_pydantic_data: Optional[dict[str, Any]] = None
38
+ builtin_scorer_pydantic_data: Optional[dict] = None
39
39
 
40
40
  # Decorator scorer fields (for @scorer decorated functions)
41
41
  call_source: Optional[str] = None
@@ -43,14 +43,21 @@ class SerializedScorer:
43
43
  original_func_name: Optional[str] = None
44
44
 
45
45
 
46
- @experimental(version="3.0.0")
46
+ @experimental
47
47
  class Scorer(BaseModel):
48
48
  name: str
49
- aggregations: Optional[list[str]] = None
49
+ aggregations: Optional[list] = None
50
+
51
+ _cached_dump: Optional[dict[str, Any]] = PrivateAttr(default=None)
50
52
 
51
- def model_dump(self, **kwargs) -> dict[str, Any]:
53
+ def model_dump(self, **kwargs) -> dict:
52
54
  """Override model_dump to include source code."""
53
55
  # Create serialized scorer with core fields
56
+
57
+ # Return cached dump if available (prevents re-serialization issues with dynamic functions)
58
+ if self._cached_dump is not None:
59
+ return self._cached_dump
60
+
54
61
  serialized = SerializedScorer(
55
62
  name=self.name,
56
63
  aggregations=self.aggregations,
@@ -79,7 +86,7 @@ class Scorer(BaseModel):
79
86
 
80
87
  return asdict(serialized)
81
88
 
82
- def _extract_source_code_info(self) -> dict[str, Optional[str]]:
89
+ def _extract_source_code_info(self) -> dict:
83
90
  """Extract source code information for the original decorated function."""
84
91
  from mlflow.genai.scorers.scorer_utils import extract_function_body
85
92
 
@@ -165,7 +172,13 @@ class Scorer(BaseModel):
165
172
  # Rather than serializing and deserializing the `run` method of `Scorer`, we recreate the
166
173
  # Scorer using the original function and the `@scorer` decorator. This should be safe so
167
174
  # long as `@scorer` is a stable API.
168
- return scorer(recreated_func, name=serialized.name, aggregations=serialized.aggregations)
175
+ scorer_instance = scorer(
176
+ recreated_func, name=serialized.name, aggregations=serialized.aggregations
177
+ )
178
+ # Cache the serialized data to prevent re-serialization issues with dynamic functions
179
+ original_serialized_data = asdict(serialized)
180
+ object.__setattr__(scorer_instance, "_cached_dump", original_serialized_data)
181
+ return scorer_instance
169
182
 
170
183
  def run(self, *, inputs=None, outputs=None, expectations=None, trace=None):
171
184
  from mlflow.evaluation import Assessment as LegacyAssessment
@@ -317,18 +330,13 @@ class Scorer(BaseModel):
317
330
  raise NotImplementedError("Implementation of __call__ is required for Scorer class")
318
331
 
319
332
 
320
- @experimental(version="3.0.0")
333
+ @experimental
321
334
  def scorer(
322
335
  func=None,
323
336
  *,
324
337
  name: Optional[str] = None,
325
338
  aggregations: Optional[
326
- list[
327
- Union[
328
- Literal["min", "max", "mean", "median", "variance", "p90", "p99"],
329
- Callable[[list[Union[int, float]]], Union[int, float]],
330
- ]
331
- ]
339
+ list[Union[Literal["min", "max", "mean", "median", "variance", "p90", "p99"], Callable]]
332
340
  ] = None,
333
341
  ):
334
342
  """
@@ -464,7 +472,7 @@ def scorer(
464
472
 
465
473
  class CustomScorer(Scorer):
466
474
  # Store reference to the original function
467
- _original_func: Optional[Callable[..., Any]] = PrivateAttr(default=None)
475
+ _original_func: Optional[Callable] = PrivateAttr(default=None)
468
476
 
469
477
  def __init__(self, **data):
470
478
  super().__init__(**data)
@@ -283,6 +283,16 @@ def _get_lc_model_input_fields(lc_model) -> set[str]:
283
283
 
284
284
 
285
285
  def _should_transform_request_json_for_chat(lc_model):
286
+ # Don't convert the request to LangChain's Message format for LangGraph models.
287
+ # Inputs may have key like "messages", but they are graph state fields, not OAI chat format.
288
+ try:
289
+ from langgraph.graph.state import CompiledStateGraph
290
+
291
+ if isinstance(lc_model, CompiledStateGraph):
292
+ return False
293
+ except ImportError:
294
+ pass
295
+
286
296
  # Avoid converting the request to LangChain's Message format if the chain
287
297
  # is an AgentExecutor, as LangChainChatMessage might not be accepted by the chain
288
298
  from langchain.agents import AgentExecutor
@@ -91,8 +91,8 @@ def _install_pyfunc_deps(
91
91
  server_deps = ["gunicorn[gevent]"]
92
92
  if enable_mlserver:
93
93
  server_deps = [
94
- "'mlserver>=1.2.0,!=1.3.1,<1.4.0'",
95
- "'mlserver-mlflow>=1.2.0,!=1.3.1,<1.4.0'",
94
+ "'mlserver>=1.2.0,!=1.3.1,<2.0.0'",
95
+ "'mlserver-mlflow>=1.2.0,!=1.3.1,<2.0.0'",
96
96
  ]
97
97
 
98
98
  install_server_deps = [f"pip install {' '.join(server_deps)}"]