genesis-flow 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/METADATA +32 -28
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/RECORD +29 -26
- mlflow/azure/config.py +13 -7
- mlflow/azure/connection_factory.py +15 -2
- mlflow/data/dataset_source_registry.py +8 -0
- mlflow/gateway/providers/bedrock.py +298 -0
- mlflow/genai/datasets/databricks_evaluation_dataset_source.py +77 -0
- mlflow/genai/datasets/evaluation_dataset.py +8 -5
- mlflow/genai/scorers/base.py +22 -14
- mlflow/langchain/utils/chat.py +10 -0
- mlflow/models/container/__init__.py +2 -2
- mlflow/spark/__init__.py +1286 -0
- mlflow/store/artifact/azure_blob_artifact_repo.py +1 -1
- mlflow/store/artifact/azure_data_lake_artifact_repo.py +1 -1
- mlflow/store/artifact/gcs_artifact_repo.py +1 -1
- mlflow/store/artifact/local_artifact_repo.py +2 -1
- mlflow/store/artifact/s3_artifact_repo.py +173 -3
- mlflow/tracing/client.py +139 -49
- mlflow/tracing/export/mlflow_v3.py +8 -11
- mlflow/tracing/provider.py +5 -1
- mlflow/tracking/_model_registry/client.py +5 -1
- mlflow/tracking/_tracking_service/utils.py +17 -5
- mlflow/utils/file_utils.py +2 -1
- mlflow/utils/rest_utils.py +4 -0
- mlflow/version.py +2 -2
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/WHEEL +0 -0
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/entry_points.txt +0 -0
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/licenses/LICENSE.txt +0 -0
- {genesis_flow-1.0.1.dist-info → genesis_flow-1.0.4.dist-info}/top_level.txt +0 -0
@@ -41,7 +41,7 @@ class AzureBlobArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         DefaultAzureCredential is configured
     """

-    def __init__(self, artifact_uri: str, tracking_uri: Optional[str] = None) -> None:
+    def __init__(self, artifact_uri: str, client=None, tracking_uri: Optional[str] = None) -> None:
         super().__init__(artifact_uri, tracking_uri)

         _DEFAULT_TIMEOUT = 600  # 10 minutes
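The new `client` parameter lets callers inject a pre-built Azure SDK client instead of having the repository construct one from `DefaultAzureCredential`. A minimal sketch of how that could be used, assuming a storage account named `mystorageaccount` and a container `mlflow-artifacts` (both hypothetical):

```python
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository

# Hypothetical account/container names; any reachable blob container works.
service_client = BlobServiceClient(
    account_url="https://mystorageaccount.blob.core.windows.net",
    credential=DefaultAzureCredential(),
)
repo = AzureBlobArtifactRepository(
    "wasbs://mlflow-artifacts@mystorageaccount.blob.core.windows.net/artifacts",
    client=service_client,
)
repo.log_artifact("/tmp/model.pkl")  # uploads through the injected client
```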
mlflow/store/artifact/azure_data_lake_artifact_repo.py
CHANGED
@@ -82,9 +82,9 @@ class AzureDataLakeArtifactRepository(CloudArtifactRepository):
     def __init__(
         self,
         artifact_uri: str,
-        tracking_uri: Optional[str] = None,
         credential=None,
         credential_refresh_def=None,
+        tracking_uri: Optional[str] = None,
     ) -> None:
         super().__init__(artifact_uri, tracking_uri)
         _DEFAULT_TIMEOUT = 600  # 10 minutes
mlflow/store/artifact/gcs_artifact_repo.py
CHANGED
@@ -43,9 +43,9 @@ class GCSArtifactRepository(ArtifactRepository, MultipartUploadMixin):
     def __init__(
         self,
         artifact_uri: str,
-        tracking_uri: Optional[str] = None,
         client=None,
         credential_refresh_def=None,
+        tracking_uri: Optional[str] = None,
     ) -> None:
         super().__init__(artifact_uri, tracking_uri)
         from google.auth.exceptions import DefaultCredentialsError
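This hunk and the Azure Data Lake one above move `tracking_uri` behind the provider-specific arguments, so any caller that passed it positionally should switch to a keyword argument. A minimal sketch, assuming a hypothetical bucket name and a local tracking server:

```python
from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository

# Pass tracking_uri by keyword; its position in the signature changed in 1.0.4.
repo = GCSArtifactRepository(
    "gs://my-bucket/mlflow-artifacts",     # hypothetical bucket
    tracking_uri="http://localhost:5000",  # hypothetical tracking server
)
repo.log_artifacts("/tmp/model_dir")
```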
mlflow/store/artifact/local_artifact_repo.py
CHANGED
@@ -14,6 +14,7 @@ from mlflow.utils.file_utils import (
     local_file_uri_to_path,
     mkdir,
     relative_path_to_artifact_path,
+    shutil_copytree_without_file_permissions,
 )
 from mlflow.utils.uri import validate_path_is_safe

@@ -64,7 +65,7 @@ class LocalArtifactRepository(ArtifactRepository):
         )
         if not os.path.exists(artifact_dir):
             mkdir(artifact_dir)
-
+        shutil_copytree_without_file_permissions(local_dir, artifact_dir)

     def download_artifacts(self, artifact_path, dst_path=None):
         """
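`shutil_copytree_without_file_permissions` is the helper from `mlflow.utils.file_utils` that copies a directory tree without carrying over source file permissions, and `log_artifacts` now routes through it. A minimal usage sketch with hypothetical paths:

```python
from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository

repo = LocalArtifactRepository("file:///tmp/mlflow-artifacts")  # hypothetical artifact root
# Copies /tmp/model_dir into the artifact root, preserving structure but not permissions.
repo.log_artifacts("/tmp/model_dir")
print([f.path for f in repo.list_artifacts()])
```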
mlflow/store/artifact/s3_artifact_repo.py
CHANGED
@@ -122,16 +122,54 @@ def _get_s3_client(


 class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
-    """
+    """
+    Stores artifacts on Amazon S3.
+
+    This repository provides MLflow artifact storage using Amazon S3 as the backend.
+    It supports both single-file uploads and multipart uploads for large files,
+    with automatic content type detection and configurable upload parameters.
+
+    The repository uses boto3 for S3 operations and supports various authentication
+    methods including AWS credentials, IAM roles, and environment variables.
+
+    Environment Variables:
+        AWS_ACCESS_KEY_ID: AWS access key ID for authentication
+        AWS_SECRET_ACCESS_KEY: AWS secret access key for authentication
+        AWS_SESSION_TOKEN: AWS session token for temporary credentials
+        AWS_DEFAULT_REGION: Default AWS region for S3 operations
+        MLFLOW_S3_ENDPOINT_URL: Custom S3 endpoint URL (for S3-compatible storage)
+        MLFLOW_S3_IGNORE_TLS: Set to 'true' to disable TLS verification
+        MLFLOW_S3_UPLOAD_EXTRA_ARGS: JSON string of extra arguments for S3 uploads
+        MLFLOW_BOTO_CLIENT_ADDRESSING_STYLE: S3 addressing style ('path' or 'virtual')
+
+    Note:
+        This class inherits from both ArtifactRepository and MultipartUploadMixin,
+        providing full artifact management capabilities including efficient large file uploads.
+    """

     def __init__(
         self,
         artifact_uri: str,
-        tracking_uri: Optional[str] = None,
         access_key_id=None,
         secret_access_key=None,
         session_token=None,
+        tracking_uri: Optional[str] = None,
     ) -> None:
+        """
+        Initialize an S3 artifact repository.
+
+        Args:
+            artifact_uri: S3 URI in the format 's3://bucket-name/path/to/artifacts'.
+                The URI must be a valid S3 URI with a bucket that exists and is accessible.
+            access_key_id: Optional AWS access key ID. If None, uses default AWS credential
+                resolution (environment variables, IAM roles, etc.).
+            secret_access_key: Optional AWS secret access key. Must be provided if
+                access_key_id is provided.
+            session_token: Optional AWS session token for temporary credentials.
+                Used with STS tokens or IAM roles.
+            tracking_uri: Optional URI for the MLflow tracking server.
+                If None, uses the current tracking URI context.
+        """
         super().__init__(artifact_uri, tracking_uri)
         self._access_key_id = access_key_id
         self._secret_access_key = secret_access_key
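The new class docstring documents the environment variables the repository reads, and the constructor change also moves `tracking_uri` to the end of the signature, so keyword use is safest. A minimal sketch, assuming a hypothetical bucket and a MinIO-style custom endpoint:

```python
import json
import os

from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository

# Hypothetical endpoint/bucket; both env vars below are read at upload time.
os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://localhost:9000"
os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"] = json.dumps({"ServerSideEncryption": "AES256"})

repo = S3ArtifactRepository(
    "s3://my-bucket/mlflow-artifacts",
    tracking_uri="http://localhost:5000",  # keyword, since its position moved in 1.0.4
)
repo.log_artifact("/tmp/model.pkl")
print([(f.path, f.is_dir, f.file_size) for f in repo.list_artifacts()])
```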
@@ -145,7 +183,17 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def parse_s3_compliant_uri(self, uri):
-        """
+        """
+        Parse an S3 URI into bucket and path components.
+
+        Args:
+            uri: S3 URI in the format 's3://bucket-name/path/to/object'
+
+        Returns:
+            A tuple containing (bucket_name, object_path) where:
+            - bucket_name: The S3 bucket name
+            - object_path: The path within the bucket (without leading slash)
+        """
         parsed = urllib.parse.urlparse(uri)
         if parsed.scheme != "s3":
             raise Exception(f"Not an S3 URI: {uri}")
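The docstring pins down the return shape; for example, reusing the `repo` object from the sketch above (URI hypothetical):

```python
bucket, key = repo.parse_s3_compliant_uri("s3://my-bucket/experiments/1/artifacts")
# bucket == "my-bucket", key == "experiments/1/artifacts" (no leading slash)
```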
@@ -156,6 +204,17 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):

     @staticmethod
     def get_s3_file_upload_extra_args():
+        """
+        Get additional S3 upload arguments from environment variables.
+
+        Returns:
+            Dictionary of extra arguments for S3 uploads, or None if not configured.
+            These arguments are passed to boto3's upload_file method.
+
+        Environment Variables:
+            MLFLOW_S3_UPLOAD_EXTRA_ARGS: JSON string containing extra arguments
+                for S3 uploads (e.g., '{"ServerSideEncryption": "AES256"}')
+        """
         s3_file_upload_extra_args = MLFLOW_S3_UPLOAD_EXTRA_ARGS.get()
         if s3_file_upload_extra_args:
             return json.loads(s3_file_upload_extra_args)
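The environment variable holds a JSON object that is parsed straight into boto3's `ExtraArgs`. Continuing the earlier sketch (the KMS key id is hypothetical):

```python
os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"] = '{"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "my-key-id"}'
extra_args = S3ArtifactRepository.get_s3_file_upload_extra_args()
# extra_args == {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "my-key-id"}
```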
@@ -175,6 +234,19 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         s3_client.upload_file(Filename=local_file, Bucket=bucket, Key=key, ExtraArgs=extra_args)

     def log_artifact(self, local_file, artifact_path=None):
+        """
+        Log a local file as an artifact to S3.
+
+        This method uploads a single file to S3 with automatic content type detection
+        and optional extra upload arguments from environment variables.
+
+        Args:
+            local_file: Absolute path to the local file to upload. The file must
+                exist and be readable.
+            artifact_path: Optional relative path within the S3 bucket where the
+                artifact should be stored. If None, the file is stored in the root
+                of the configured S3 path. Use forward slashes (/) for path separators.
+        """
         (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         if artifact_path:
             dest_path = posixpath.join(dest_path, artifact_path)
@@ -184,6 +256,20 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def log_artifacts(self, local_dir, artifact_path=None):
+        """
+        Log all files in a local directory as artifacts to S3.
+
+        This method recursively uploads all files in the specified directory,
+        preserving the directory structure in S3. Each file is uploaded with
+        automatic content type detection.
+
+        Args:
+            local_dir: Absolute path to the local directory containing files to upload.
+                The directory must exist and be readable.
+            artifact_path: Optional relative path within the S3 bucket where the
+                artifacts should be stored. If None, files are stored in the root
+                of the configured S3 path. Use forward slashes (/) for path separators.
+        """
         (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         if artifact_path:
             dest_path = posixpath.join(dest_path, artifact_path)
@@ -205,6 +291,25 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def list_artifacts(self, path=None):
+        """
+        List all artifacts directly under the specified S3 path.
+
+        This method uses S3's list_objects_v2 API with pagination to efficiently
+        list artifacts. It treats S3 prefixes as directories and returns both
+        files and directories as FileInfo objects.
+
+        Args:
+            path: Optional relative path within the S3 bucket to list. If None,
+                lists artifacts in the root of the configured S3 path. If the path
+                refers to a single file, returns an empty list per MLflow convention.
+
+        Returns:
+            A list of FileInfo objects representing artifacts directly under the
+            specified path. Each FileInfo contains:
+            - path: Relative path of the artifact from the repository root
+            - is_dir: True if the artifact represents a directory (S3 prefix)
+            - file_size: Size in bytes for files, None for directories
+        """
         (bucket, artifact_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         dest_path = artifact_path
         if path:
@@ -247,6 +352,18 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def _download_file(self, remote_file_path, local_path):
+        """
+        Download a file from S3 to the local filesystem.
+
+        This method downloads a single file from S3 to the specified local path.
+        It's used internally by the download_artifacts method.
+
+        Args:
+            remote_file_path: Relative path of the file within the S3 bucket,
+                relative to the repository's root path.
+            local_path: Absolute path where the file should be saved locally.
+                The parent directory must exist.
+        """
         (bucket, s3_root_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         s3_full_path = posixpath.join(s3_root_path, remote_file_path)
         s3_client = self._get_s3_client()
@@ -273,6 +390,28 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         s3_client.delete_objects(Bucket=bucket, Delete={"Objects": keys})

     def create_multipart_upload(self, local_file, num_parts=1, artifact_path=None):
+        """
+        Initiate a multipart upload for efficient large file uploads to S3.
+
+        This method creates a multipart upload session in S3 and generates
+        presigned URLs for uploading each part. This is more efficient than
+        single-part uploads for large files and provides better error recovery.
+
+        Args:
+            local_file: Absolute path to the local file to upload. The file must
+                exist and be readable.
+            num_parts: Number of parts to split the upload into. Must be between
+                1 and 10,000 (S3 limit). More parts allow greater parallelism
+                but increase overhead.
+            artifact_path: Optional relative path within the S3 bucket where the
+                artifact should be stored. If None, the file is stored in the root
+                of the configured S3 path.
+
+        Returns:
+            CreateMultipartUploadResponse containing:
+            - credentials: List of MultipartUploadCredential objects with presigned URLs
+            - upload_id: S3 upload ID for tracking this multipart upload
+        """
         (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         if artifact_path:
             dest_path = posixpath.join(dest_path, artifact_path)
@@ -307,6 +446,23 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def complete_multipart_upload(self, local_file, upload_id, parts=None, artifact_path=None):
+        """
+        Complete a multipart upload by combining all parts into a single S3 object.
+
+        This method should be called after all parts have been successfully uploaded
+        using the presigned URLs from create_multipart_upload. It tells S3 to combine
+        all the parts into the final object.
+
+        Args:
+            local_file: Absolute path to the local file that was uploaded. Must match
+                the local_file used in create_multipart_upload.
+            upload_id: The S3 upload ID returned by create_multipart_upload.
+            parts: List of MultipartUploadPart objects containing metadata for each
+                successfully uploaded part. Must include part_number and etag for each part.
+                Parts must be provided in order (part 1, part 2, etc.).
+            artifact_path: Optional relative path where the artifact should be stored.
+                Must match the artifact_path used in create_multipart_upload.
+        """
         (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         if artifact_path:
             dest_path = posixpath.join(dest_path, artifact_path)
@@ -318,6 +474,20 @@ class S3ArtifactRepository(ArtifactRepository, MultipartUploadMixin):
         )

     def abort_multipart_upload(self, local_file, upload_id, artifact_path=None):
+        """
+        Abort a multipart upload and clean up any uploaded parts.
+
+        This method should be called if a multipart upload fails or is cancelled.
+        It cleans up any parts that were successfully uploaded and cancels the
+        multipart upload session in S3.
+
+        Args:
+            local_file: Absolute path to the local file that was being uploaded.
+                Must match the local_file used in create_multipart_upload.
+            upload_id: The S3 upload ID returned by create_multipart_upload.
+            artifact_path: Optional relative path where the artifact would have been stored.
+                Must match the artifact_path used in create_multipart_upload.
+        """
         (bucket, dest_path) = self.parse_s3_compliant_uri(self.artifact_uri)
         if artifact_path:
             dest_path = posixpath.join(dest_path, artifact_path)
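Taken together, the three multipart methods describe a create → upload parts → complete flow, with abort as the failure path. A rough caller sketch, reusing `repo` from the earlier example and assuming each credential carries a presigned `url` and that S3 returns the part's `ETag` header (part sizing and entity field names follow the docstrings above, not a verified API):

```python
import math
import os
import requests

from mlflow.entities.multipart_upload import MultipartUploadPart

local_file = "/tmp/big_model.bin"  # hypothetical large file
num_parts = 4
create = repo.create_multipart_upload(local_file, num_parts=num_parts)

try:
    parts = []
    chunk_size = math.ceil(os.path.getsize(local_file) / num_parts)
    with open(local_file, "rb") as f:
        for credential in create.credentials:
            # Upload each chunk to its presigned URL and record the returned ETag.
            resp = requests.put(credential.url, data=f.read(chunk_size))
            resp.raise_for_status()
            parts.append(
                MultipartUploadPart(
                    part_number=credential.part_number,
                    etag=resp.headers["ETag"],
                )
            )
    repo.complete_multipart_upload(local_file, create.upload_id, parts=parts)
except Exception:
    # On any failure, cancel the session so S3 discards the uploaded parts.
    repo.abort_multipart_upload(local_file, create.upload_id)
    raise
```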
mlflow/tracing/client.py
CHANGED
@@ -2,15 +2,19 @@ import json
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union

 import mlflow
-from mlflow.entities.assessment import
+from mlflow.entities.assessment import (
+    Assessment,
+)
 from mlflow.entities.model_registry import PromptVersion
 from mlflow.entities.span import NO_OP_SPAN_TRACE_ID
 from mlflow.entities.trace import Trace
 from mlflow.entities.trace_data import TraceData
 from mlflow.entities.trace_info import TraceInfo
+from mlflow.entities.trace_info_v2 import TraceInfoV2
+from mlflow.entities.trace_status import TraceStatus
 from mlflow.environment_variables import MLFLOW_SEARCH_TRACES_MAX_THREADS
 from mlflow.exceptions import (
     MlflowException,
@@ -59,65 +63,128 @@ class TracingClient:
     def store(self):
         return _get_store(self.tracking_uri)

-    def start_trace(
+    def start_trace(
+        self,
+        experiment_id: str,
+        timestamp_ms: int,
+        request_metadata: dict[str, str],
+        tags: dict[str, str],
+    ):
         """
-
+        Start an initial TraceInfo object in the backend store.

         Args:
-
+            experiment_id: String id of the experiment for this run.
+            timestamp_ms: Start time of the trace, in milliseconds since the UNIX epoch.
+            request_metadata: Metadata of the trace.
+            tags: Tags of the trace.
+
+        Returns:
+            The created TraceInfo object.
+        """
+        tags = exclude_immutable_tags(tags or {})
+        return self.store.start_trace(
+            experiment_id=experiment_id,
+            timestamp_ms=timestamp_ms,
+            request_metadata=request_metadata,
+            tags=tags,
+        )
+
+    def start_trace_v3(self, trace: Trace) -> TraceInfo:
+        """
+        Start a trace using the V3 API format.
+        NB: This method is named "Start" for internal reason in the backend, but actually
+        should be called at the end of the trace. We will migrate this to "CreateTrace"
+        API in the future to avoid confusion.
+
+        Args:
+            trace: The Trace object to create.

         Returns:
             The returned TraceInfoV3 object from the backend.
         """
-        return self.store.
+        return self.store.start_trace_v3(trace=trace)
+
+    def end_trace(
+        self,
+        request_id: str,
+        timestamp_ms: int,
+        status: TraceStatus,
+        request_metadata: dict[str, str],
+        tags: dict[str, str],
+    ) -> TraceInfoV2:
+        """
+        Update the TraceInfo object in the backend store with the completed trace info.
+
+        Args:
+            request_id: Unique string identifier of the trace.
+            timestamp_ms: End time of the trace, in milliseconds. The execution time field
+                in the TraceInfo will be calculated by subtracting the start time from this.
+            status: Status of the trace.
+            request_metadata: Metadata of the trace. This will be merged with the existing
+                metadata logged during the start_trace call.
+            tags: Tags of the trace. This will be merged with the existing tags logged
+                during the start_trace or set_trace_tag calls.
+
+        Returns:
+            The updated TraceInfo object.
+        """
+        tags = exclude_immutable_tags(tags or {})
+        return self.store.end_trace(
+            request_id=request_id,
+            timestamp_ms=timestamp_ms,
+            status=status,
+            request_metadata=request_metadata,
+            tags=tags,
+        )

     def delete_traces(
         self,
         experiment_id: str,
         max_timestamp_millis: Optional[int] = None,
         max_traces: Optional[int] = None,
-
+        request_ids: Optional[list[str]] = None,
     ) -> int:
         return self.store.delete_traces(
             experiment_id=experiment_id,
             max_timestamp_millis=max_timestamp_millis,
             max_traces=max_traces,
-
+            request_ids=request_ids,
         )

-    def get_trace_info(self,
+    def get_trace_info(self, request_id, should_query_v3: bool = False) -> TraceInfoV2:
         """
-        Get the trace info matching the ``
+        Get the trace info matching the ``request_id``.

         Args:
-
+            request_id: String id of the trace to fetch.
+            should_query_v3: If True, the backend store will query the V3 API for the trace info.
+                TODO: Remove this flag once the V3 API is the default in OSS.

         Returns:
             TraceInfo object, of type ``mlflow.entities.trace_info.TraceInfo``.
         """
-
-        if trace is not None:
-            return trace.info
+        return self.store.get_trace_info(request_id, should_query_v3=should_query_v3)

-
-
-    def get_trace(self, trace_id: str) -> Trace:
+    def get_trace(self, request_id) -> Trace:
         """
-        Get the trace matching the ``
+        Get the trace matching the ``request_id``.

         Args:
-
+            request_id: String id of the trace to fetch.

         Returns:
             The fetched Trace object, of type ``mlflow.entities.Trace``.
         """
-        trace_info = self.get_trace_info(
+        trace_info = self.get_trace_info(
+            request_id=request_id, should_query_v3=is_databricks_uri(self.tracking_uri)
+        )
         try:
             trace_data = self._download_trace_data(trace_info)
         except MlflowTraceDataNotFound:
             raise MlflowException(
                 message=(
-                    f"Trace with ID {
+                    f"Trace with ID {request_id} cannot be loaded because it is missing span data."
                     " Please try creating or loading another trace."
                 ),
                 error_code=BAD_REQUEST,
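The restored `start_trace`/`end_trace` pair gives the v2 trace lifecycle: create a `TraceInfo` when work begins, then close it with a status. A minimal sketch against a local tracking store (the tracking URI, experiment id, and timestamps are hypothetical):

```python
import time

from mlflow.entities.trace_status import TraceStatus
from mlflow.tracing.client import TracingClient

client = TracingClient("http://localhost:5000")  # hypothetical tracking URI
started = client.start_trace(
    experiment_id="0",
    timestamp_ms=int(time.time() * 1000),
    request_metadata={},
    tags={"source": "example"},
)
# ... run the traced code, then close the trace with its final status ...
client.end_trace(
    request_id=started.request_id,
    timestamp_ms=int(time.time() * 1000),
    status=TraceStatus.OK,
    request_metadata={},
    tags={},
)
```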
@@ -125,7 +192,7 @@ class TracingClient:
         except MlflowTraceDataCorrupted:
             raise MlflowException(
                 message=(
-                    f"Trace with ID {
+                    f"Trace with ID {request_id} cannot be loaded because its span data"
                     " is corrupted. Please try creating or loading another trace."
                 ),
                 error_code=BAD_REQUEST,
@@ -253,23 +320,29 @@ class TracingClient:
         else:
             filter_string = additional_filter

-
-
-
+        def download_trace_extra_fields(
+            trace_info: Union[TraceInfoV2, TraceInfo],
+        ) -> Optional[Trace]:
             """
             Download trace data and assessments for the given trace_info and returns a Trace object.
             If the download fails (e.g., the trace data is missing or corrupted), returns None.

             The trace_info parameter can be either TraceInfo or TraceInfoV3 object.
             """
-
+            from mlflow.entities.trace_info import TraceInfo
+
+            # Determine if this is TraceInfo or TraceInfoV3
+            # Helps while transitioning to V3 traces for offline & online
+            is_v3 = isinstance(trace_info, TraceInfo)
+            trace_id = trace_info.trace_id if is_v3 else trace_info.request_id
+            is_online_trace = is_uuid(trace_id)

             # For online traces in Databricks, we need to get trace data from a different endpoint
             try:
                 if is_databricks and is_online_trace:
                     # For online traces, get data from the online API
                     trace_data = self.get_online_trace_details(
-                        trace_id=
+                        trace_id=trace_id,
                         sql_warehouse_id=sql_warehouse_id,
                         source_inference_table=trace_info.request_metadata.get(
                             "mlflow.sourceTable"
@@ -285,7 +358,7 @@ class TracingClient:
             except MlflowTraceDataException as e:
                 _logger.warning(
                     (
-                        f"Failed to download trace data for trace {
+                        f"Failed to download trace data for trace {trace_id!r} "
                         f"with {e.ctx}. For full traceback, set logging level to DEBUG."
                     ),
                     exc_info=_logger.isEnabledFor(logging.DEBUG),
@@ -299,11 +372,7 @@ class TracingClient:
         next_token = page_token

         max_workers = MLFLOW_SEARCH_TRACES_MAX_THREADS.get()
-        executor = (
-            ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="MlflowTracingSearch")
-            if include_spans
-            else nullcontext()
-        )
+        executor = ThreadPoolExecutor(max_workers=max_workers) if include_spans else nullcontext()
         with executor:
             while len(traces) < max_results:
                 trace_infos, next_token = self._search_traces(
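The simplified line keeps the useful idiom: `with` works uniformly whether or not spans are fetched in a thread pool, because `nullcontext()` stands in when no executor is needed. The same pattern in isolation:

```python
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext

include_spans = True  # toggles between a real pool and a no-op context manager
executor = ThreadPoolExecutor(max_workers=4) if include_spans else nullcontext()
with executor:
    # Submit work only when a real executor is present.
    if isinstance(executor, ThreadPoolExecutor):
        futures = [executor.submit(pow, 2, n) for n in range(5)]
        print([f.result() for f in futures])
```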
@@ -332,24 +401,24 @@ class TracingClient:

         return PagedList(traces, next_token)

-    def set_trace_tags(self,
+    def set_trace_tags(self, request_id, tags):
         """
-        Set tags on the trace with the given
+        Set tags on the trace with the given request_id.

         Args:
-
+            request_id: The ID of the trace.
             tags: A dictionary of key-value pairs.
         """
         tags = exclude_immutable_tags(tags)
         for k, v in tags.items():
-            self.set_trace_tag(
+            self.set_trace_tag(request_id, k, v)

-    def set_trace_tag(self,
+    def set_trace_tag(self, request_id, key, value):
         """
         Set a tag on the trace with the given trace ID.

         Args:
-
+            request_id: The ID of the trace to set the tag on.
             key: The string key of the tag. Must be at most 250 characters long, otherwise
                 it will be truncated when stored.
             value: The string value of the tag. Must be at most 250 characters long, otherwise
@@ -362,7 +431,7 @@ class TracingClient:
         )

         # Trying to set the tag on the active trace first
-        with InMemoryTraceManager.get_instance().get_trace(
+        with InMemoryTraceManager.get_instance().get_trace(request_id) as trace:
             if trace:
                 trace.info.tags[key] = str(value)
                 return
@@ -370,33 +439,33 @@ class TracingClient:
         if key in IMMUTABLE_TAGS:
             _logger.warning(f"Tag '{key}' is immutable and cannot be set on a trace.")
         else:
-            self.store.set_trace_tag(
+            self.store.set_trace_tag(request_id, key, str(value))

-    def delete_trace_tag(self,
+    def delete_trace_tag(self, request_id, key):
         """
         Delete a tag on the trace with the given trace ID.

         Args:
-
+            request_id: The ID of the trace to delete the tag from.
             key: The string key of the tag. Must be at most 250 characters long, otherwise
                 it will be truncated when stored.
         """
         # Trying to delete the tag on the active trace first
-        with InMemoryTraceManager.get_instance().get_trace(
+        with InMemoryTraceManager.get_instance().get_trace(request_id) as trace:
             if trace:
                 if key in trace.info.tags:
                     trace.info.tags.pop(key)
                     return
                 else:
                     raise MlflowException(
-                        f"Tag with key {key} not found in trace with ID {
+                        f"Tag with key {key} not found in trace with ID {request_id}.",
                         error_code=RESOURCE_DOES_NOT_EXIST,
                     )

         if key in IMMUTABLE_TAGS:
             _logger.warning(f"Tag '{key}' is immutable and cannot be deleted on a trace.")
         else:
-            self.store.delete_trace_tag(
+            self.store.delete_trace_tag(request_id, key)

     def get_assessment(self, trace_id: str, assessment_id: str) -> Assessment:
         """
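With the `request_id` parameters restored, tag updates go through the client directly; a short sketch reusing the `client` and `started` objects from the earlier tracing example:

```python
client.set_trace_tags(started.request_id, {"stage": "evaluation", "owner": "ml-team"})
client.set_trace_tag(started.request_id, "stage", "production")
client.delete_trace_tag(started.request_id, "owner")
```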
@@ -492,12 +561,12 @@ class TracingClient:

         self.store.delete_assessment(trace_id=trace_id, assessment_id=assessment_id)

-    def _get_artifact_repo_for_trace(self, trace_info:
+    def _get_artifact_repo_for_trace(self, trace_info: TraceInfoV2):
         artifact_uri = get_artifact_uri_for_trace(trace_info)
         artifact_uri = add_databricks_profile_info_to_artifact_uri(artifact_uri, self.tracking_uri)
         return get_artifact_repository(artifact_uri)

-    def _download_trace_data(self, trace_info: TraceInfo) -> TraceData:
+    def _download_trace_data(self, trace_info: Union[TraceInfoV2, TraceInfo]) -> TraceData:
         """
         Download trace data from artifact repository.

@@ -510,11 +579,32 @@ class TracingClient:
         artifact_repo = self._get_artifact_repo_for_trace(trace_info)
         return TraceData.from_dict(artifact_repo.download_trace_data())

-    def _upload_trace_data(self, trace_info:
+    def _upload_trace_data(self, trace_info: TraceInfoV2, trace_data: TraceData) -> None:
         artifact_repo = self._get_artifact_repo_for_trace(trace_info)
         trace_data_json = json.dumps(trace_data.to_dict(), cls=TraceJSONEncoder, ensure_ascii=False)
         return artifact_repo.upload_trace_data(trace_data_json)

+    def _upload_ended_trace_info(
+        self,
+        trace_info: TraceInfoV2,
+    ) -> TraceInfoV2:
+        """
+        Update the TraceInfo object in the backend store with the completed trace info.
+
+        Args:
+            trace_info: Updated TraceInfo object to be stored in the backend store.
+
+        Returns:
+            The updated TraceInfo object.
+        """
+        return self.end_trace(
+            request_id=trace_info.request_id,
+            timestamp_ms=trace_info.timestamp_ms + trace_info.execution_time_ms,
+            status=trace_info.status,
+            request_metadata=trace_info.request_metadata,
+            tags=trace_info.tags or {},
+        )
+
     def link_prompt_versions_to_trace(
         self, trace_id: str, prompts: Sequence[PromptVersion]
     ) -> None: