wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/filesync/step_checksum.py
CHANGED
```diff
@@ -8,25 +8,27 @@ import shutil
 import threading
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
 
-from wandb.filesync import …
+from wandb.filesync import step_upload
 from wandb.sdk.lib import filesystem, runid
+from wandb.sdk.lib.paths import LogicalPath
 
 if TYPE_CHECKING:
     import tempfile
 
     from wandb.filesync import stats
-    from wandb.sdk.…
-    from wandb.sdk.…
+    from wandb.sdk.artifacts import artifact_saver
+    from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
+    from wandb.sdk.internal import internal_api
 
 
 class RequestUpload(NamedTuple):
     path: str
-    save_name: …
+    save_name: LogicalPath
     copy: bool
 
 
 class RequestStoreManifestFiles(NamedTuple):
-    manifest: "…
+    manifest: "ArtifactManifest"
     artifact_id: str
     save_fn: "artifact_saver.SaveFn"
     save_fn_async: "artifact_saver.SaveFnAsync"
@@ -108,9 +110,7 @@ class StepChecksum:
                 self._output_queue.put(
                     step_upload.RequestUpload(
                         entry.local_path,
-                        …
-                        entry.path…
-                        ),  # typecast might not be legit
+                        entry.path,
                         req.artifact_id,
                         entry.digest,
                         False,
```
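The visible change here is typing: the checksum step now labels files with `LogicalPath` (the `str` subclass wandb keeps in `wandb.sdk.lib.paths` for POSIX-style, run- and artifact-relative names) instead of an ad-hoc cast. A minimal sketch of constructing the retyped tuple, with made-up paths:

```python
from wandb.filesync.step_checksum import RequestUpload
from wandb.sdk.lib.paths import LogicalPath

# Hypothetical file; only the field types are the point here.
req = RequestUpload(
    path="/tmp/checkpoints/model.pkl",
    save_name=LogicalPath("checkpoints/model.pkl"),
    copy=False,
)
assert isinstance(req.save_name, str)  # LogicalPath subclasses str
```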
wandb/filesync/step_prepare.py
CHANGED
```diff
@@ -8,6 +8,7 @@ import time
 from typing import (
     TYPE_CHECKING,
     Callable,
+    Dict,
     List,
     Mapping,
     NamedTuple,
@@ -39,9 +40,12 @@ class RequestFinish(NamedTuple):
 
 
 class ResponsePrepare(NamedTuple):
+    birth_artifact_id: str
     upload_url: Optional[str]
     upload_headers: Sequence[str]
-    …
+    upload_id: Optional[str]
+    storage_path: Optional[str]
+    multipart_upload_urls: Optional[Dict[int, str]]
 
 
 Request = Union[RequestPrepare, RequestFinish]
@@ -88,6 +92,21 @@ def gather_batch(
         return False, batch
 
 
+def prepare_response(response: "CreateArtifactFilesResponseFile") -> ResponsePrepare:
+    multipart_resp = response.get("uploadMultipartUrls")
+    part_list = multipart_resp["uploadUrlParts"] if multipart_resp else []
+    multipart_parts = {u["partNumber"]: u["uploadUrl"] for u in part_list} or None
+
+    return ResponsePrepare(
+        birth_artifact_id=response["artifact"]["id"],
+        upload_url=response["uploadUrl"],
+        upload_headers=response["uploadHeaders"],
+        upload_id=multipart_resp and multipart_resp.get("uploadID"),
+        storage_path=response.get("storagePath"),
+        multipart_upload_urls=multipart_parts,
+    )
+
+
 class StepPrepare:
     """A thread that batches requests to our file prepare API.
 
@@ -120,18 +139,12 @@ class StepPrepare:
                 max_batch_size=self._max_batch_size,
             )
             if batch:
-                …
+                batch_response = self._prepare_batch(batch)
                 # send responses
                 for prepare_request in batch:
                     name = prepare_request.file_spec["name"]
-                    response_file = …
-                    …
-                    upload_headers = response_file["uploadHeaders"]
-                    birth_artifact_id = response_file["artifact"]["id"]
-
-                    response = ResponsePrepare(
-                        upload_url, upload_headers, birth_artifact_id
-                    )
+                    response_file = batch_response[name]
+                    response = prepare_response(response_file)
                     if isinstance(prepare_request.response_channel, queue.Queue):
                         prepare_request.response_channel.put(response)
                     else:
```
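The new `prepare_response` helper is pure dictionary plumbing, so its behavior can be checked with a hand-built response; every field value below is invented, and real input comes from the backend file-prepare response that `_prepare_batch` returns per file name.

```python
from wandb.filesync.step_prepare import prepare_response

# Hand-built response in the shape prepare_response expects; all values invented.
fake_response_file = {
    "artifact": {"id": "QXJ0aWZhY3Q6MTIz"},
    "uploadUrl": "https://storage.example.com/put/abc",
    "uploadHeaders": ["x-goog-hash:md5=..."],
    "storagePath": "wandb_artifacts/1/abc",
    "uploadMultipartUrls": {
        "uploadID": "upload-1",
        "uploadUrlParts": [
            {"partNumber": 1, "uploadUrl": "https://storage.example.com/part/1"},
            {"partNumber": 2, "uploadUrl": "https://storage.example.com/part/2"},
        ],
    },
}

resp = prepare_response(fake_response_file)
assert resp.birth_artifact_id == "QXJ0aWZhY3Q6MTIz"
assert resp.upload_id == "upload-1"
assert resp.multipart_upload_urls == {
    1: "https://storage.example.com/part/1",
    2: "https://storage.example.com/part/2",
}
```

When the backend sends no `uploadMultipartUrls`, both `upload_id` and `multipart_upload_urls` come back as `None`, so single-part uploads keep working unchanged.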
wandb/filesync/step_upload.py
CHANGED
```diff
@@ -20,9 +20,10 @@ from typing import (
 
 from wandb.errors.term import termerror
 from wandb.filesync import upload_job
+from wandb.sdk.lib.paths import LogicalPath
 
 if TYPE_CHECKING:
-    from wandb.filesync import …
+    from wandb.filesync import stats
     from wandb.sdk.internal import file_stream, internal_api, progress
     from wandb.sdk.internal.settings_static import SettingsStatic
 
@@ -49,7 +50,7 @@ logger = logging.getLogger(__name__)
 
 class RequestUpload(NamedTuple):
     path: str
-    save_name: …
+    save_name: LogicalPath
     artifact_id: Optional[str]
     md5: Optional[str]
     copied: bool
@@ -69,9 +70,12 @@ class RequestFinish(NamedTuple):
     callback: Optional[OnRequestFinishFn]
 
 
-…
-    RequestUpload…
-]
+class EventJobDone(NamedTuple):
+    job: RequestUpload
+    exc: Optional[BaseException]
+
+
+Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
 
 
 class AsyncExecutor:
@@ -148,7 +152,7 @@ class StepUpload:
         )
 
         # Indexed by files' `save_name`'s, which are their ID's in the Run.
-        self._running_jobs: MutableMapping[…
+        self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
         self._pending_jobs: MutableSequence[RequestUpload] = []
 
         self._artifacts: MutableMapping[str, "ArtifactStatus"] = {}
@@ -189,7 +193,7 @@ class StepUpload:
                 break
 
     def _handle_event(self, event: Event) -> None:
-        if isinstance(event, …
+        if isinstance(event, EventJobDone):
            job = event.job
 
            if event.exc is not None:
@@ -283,9 +287,7 @@ class StepUpload:
            try:
                self._do_upload_sync(event)
            finally:
-                self._event_queue.put(
-                    upload_job.EventJobDone(event, exc=sys.exc_info()[1])
-                )
+                self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
 
         self._pool.submit(run_and_notify)
 
@@ -307,9 +309,7 @@ class StepUpload:
            try:
                await self._do_upload_async(event)
            finally:
-                self._event_queue.put(
-                    upload_job.EventJobDone(event, exc=sys.exc_info()[1])
-                )
+                self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
 
         async_executor.submit(run_and_notify())
 
```
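A minimal sketch of the completion-notification pattern above: whether or not an upload raises, the worker posts an `EventJobDone` back onto the main event queue, and `_handle_event` dispatches on it. Everything except `EventJobDone` below is invented for illustration.

```python
import queue
import sys

from wandb.filesync.step_upload import EventJobDone

event_queue = queue.Queue()

def run_and_notify(job):
    try:
        pass  # stand-in for the real _do_upload_sync(job)
    finally:
        # always report back, carrying the exception (if any) along with the job
        event_queue.put(EventJobDone(job, exc=sys.exc_info()[1]))

run_and_notify(job=None)  # a real caller passes a RequestUpload
done = event_queue.get()
assert done.exc is None  # the stand-in body did not raise
```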
wandb/filesync/upload_job.py
CHANGED
```diff
@@ -1,20 +1,16 @@
 import asyncio
 import logging
 import os
-from typing import TYPE_CHECKING, …
+from typing import TYPE_CHECKING, Optional
 
 import wandb
+from wandb.sdk.lib.paths import LogicalPath
 
 if TYPE_CHECKING:
     from wandb.filesync import dir_watcher, stats, step_upload
     from wandb.sdk.internal import file_stream, internal_api
 
 
-class EventJobDone(NamedTuple):
-    job: "step_upload.RequestUpload"
-    exc: Optional[BaseException]
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -25,7 +21,7 @@ class UploadJob:
         api: "internal_api.Api",
         file_stream: "file_stream.FileStreamApi",
         silent: bool,
-        save_name: …
+        save_name: LogicalPath,
         path: "dir_watcher.PathStr",
         artifact_id: Optional[str],
         md5: Optional[str],
@@ -47,7 +43,7 @@ class UploadJob:
         self._file_stream = file_stream
         self.silent = silent
         self.save_name = save_name
-        self.save_path = …
+        self.save_path = path
         self.artifact_id = artifact_id
         self.md5 = md5
         self.copied = copied
```
wandb/integration/cohere/cohere.py
ADDED
```python
import logging

from wandb.sdk.integration_utils.auto_logging import AutologAPI

from .resolver import CohereRequestResponseResolver

logger = logging.getLogger(__name__)


autolog = AutologAPI(
    name="Cohere",
    symbols=(
        "Client.generate",
        "Client.chat",
        "Client.classify",
        "Client.summarize",
        "Client.rerank",
    ),
    resolver=CohereRequestResponseResolver(),
    telemetry_feature="cohere_autolog",
)
```
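A hedged usage sketch: `autolog` above is an `AutologAPI` instance, so enabling it should follow the same call pattern as wandb's other autologgers; the project name is made up and the Cohere key is a placeholder.

```python
import cohere

from wandb.integration.cohere.cohere import autolog

autolog(init=dict(project="cohere-autolog-demo"))  # project name is illustrative

co = cohere.Client("YOUR_COHERE_API_KEY")  # placeholder key
co.generate(prompt="Say hello to Weights & Biases.")

autolog.disable()
```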
wandb/integration/cohere/resolver.py
ADDED
```python
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple

import wandb
from wandb.sdk.integration_utils.auto_logging import Response
from wandb.sdk.lib.runid import generate_id

logger = logging.getLogger(__name__)


def subset_dict(
    original_dict: Dict[str, Any], keys_subset: Sequence[str]
) -> Dict[str, Any]:
    """Create a subset of a dictionary using a subset of keys.

    :param original_dict: The original dictionary.
    :param keys_subset: The subset of keys to extract.
    :return: A dictionary containing only the specified keys.
    """
    return {key: original_dict[key] for key in keys_subset if key in original_dict}


def reorder_and_convert_dict_list_to_table(
    data: List[Dict[str, Any]], order: List[str]
) -> Tuple[List[str], List[List[Any]]]:
    """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.

    :param data: A list of dictionaries.
    :param order: A list of keys specifying the desired order for specific dictionaries. The remaining dictionaries will be ordered based on their original order.
    :return: A pair of column names and corresponding values.
    """
    final_columns = []
    keys_present = set()

    # First, add all ordered keys to the final columns
    for key in order:
        if key not in keys_present:
            final_columns.append(key)
            keys_present.add(key)

    # Then, add any keys present in the dictionaries but not in the order
    for d in data:
        for key in d:
            if key not in keys_present:
                final_columns.append(key)
                keys_present.add(key)

    # Then, construct the table of values
    values = []
    for d in data:
        row = []
        for key in final_columns:
            row.append(d.get(key, None))
        values.append(row)

    return final_columns, values


def flatten_dict(
    dictionary: Dict[str, Any], parent_key: str = "", sep: str = "-"
) -> Dict[str, Any]:
    """Flatten a nested dictionary, joining keys using a specified separator.

    :param dictionary: The dictionary to flatten.
    :param parent_key: The base key to prepend to each key.
    :param sep: The separator to use when joining keys.
    :return: A flattened dictionary.
    """
    flattened_dict = {}
    for key, value in dictionary.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flattened_dict.update(flatten_dict(value, new_key, sep=sep))
        else:
            flattened_dict[new_key] = value
    return flattened_dict


def collect_common_keys(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries.

    :param list_of_dicts: The list of dictionaries to inspect.
    :return: A dictionary with each common key and its corresponding list of values.
    """
    common_keys = set.intersection(*map(set, list_of_dicts))
    common_dict = {key: [] for key in common_keys}
    for d in list_of_dicts:
        for key in common_keys:
            common_dict[key].append(d[key])
    return common_dict


class CohereRequestResponseResolver:
    """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged."""

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Process the response from the Cohere API and convert it to a dictionary that can be logged.

        :param args: The arguments of the original function.
        :param kwargs: The keyword arguments of the original function.
        :param response: The response from the Cohere API.
        :param start_time: The start time of the request.
        :param time_elapsed: The time elapsed for the request.
        :return: A dictionary containing the parsed response and timing information.
        """
        try:
            # Each of the different endpoints map to one specific response type
            # We want to 'type check' the response without directly importing the packages type
            # It may make more sense to pass the invoked symbol from the AutologAPI instead
            response_type = str(type(response)).split("'")[1].split(".")[-1]

            # Initialize parsed_response to None to handle the case where the response type is unsupported
            parsed_response = None
            if response_type == "Generations":
                parsed_response = self._resolve_generate_response(response)
                # TODO: Remove hard-coded default model name
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "prompt",
                    "text",
                    "token_likelihoods",
                    "likelihood",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Chat":
                parsed_response = self._resolve_chat_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "conversation_id",
                    "response_id",
                    "query",
                    "text",
                    "prompt",
                    "preamble",
                    "chat_history",
                    "chatlog",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Classifications":
                parsed_response = self._resolve_classify_response(response)
                kwargs = self._resolve_classify_kwargs(kwargs)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "input",
                    "prediction",
                    "confidence",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "embed-english-v2.0"
            elif response_type == "SummarizeResponse":
                parsed_response = self._resolve_summarize_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "response_id",
                    "text",
                    "additional_command",
                    "summary",
                    "time_elapsed_(seconds)",
                    "end_time",
                    "length",
                    "format",
                ]
                default_model = "summarize-xlarge"
            elif response_type == "Reranking":
                parsed_response = self._resolve_rerank_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "query",
                    "top_n",
                    # This is a nested dict key that got flattened
                    "document-text",
                    "relevance_score",
                    "index",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "rerank-english-v2.0"
            else:
                logger.info(f"Unsupported Cohere response object: {response}")

            return self._resolve(
                args,
                kwargs,
                parsed_response,
                start_time,
                time_elapsed,
                response_type,
                table_column_order,
                default_model,
            )
        except Exception as e:
            logger.warning(f"Failed to resolve request/response: {e}")
        return None

    # These helper functions process the response from different endpoints of the Cohere API.
    # Since the response objects for different endpoints have different structures,
    # we need different logic to process them.

    def _resolve_generate_response(self, response: Response) -> List[Dict[str, Any]]:
        return_list = []
        for _response in response:
            # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data
            _response_dict = _response._visualize_helper()
            try:
                _response_dict["token_likelihoods"] = wandb.Html(
                    _response_dict["token_likelihoods"]
                )
            except (KeyError, ValueError):
                pass
            return_list.append(_response_dict)

        return return_list

    def _resolve_chat_response(self, response: Response) -> List[Dict[str, Any]]:
        return [
            subset_dict(
                response.__dict__,
                [
                    "response_id",
                    "generation_id",
                    "query",
                    "text",
                    "conversation_id",
                    "prompt",
                    "chatlog",
                    "preamble",
                ],
            )
        ]

    def _resolve_classify_response(self, response: Response) -> List[Dict[str, Any]]:
        # The labels key is a dict returning the scores for the classification probability for each label provided
        # We flatten this nested dict for ease of consumption in the wandb UI
        return [flatten_dict(_response.__dict__) for _response in response]

    def _resolve_classify_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Example texts look strange when rendered in Wandb UI as it is a list of text and label
        # We extract each value into its own column
        example_texts = []
        example_labels = []
        for example in kwargs["examples"]:
            example_texts.append(example.text)
            example_labels.append(example.label)
        kwargs.pop("examples")
        kwargs["example_texts"] = example_texts
        kwargs["example_labels"] = example_labels
        return kwargs

    def _resolve_summarize_response(self, response: Response) -> List[Dict[str, Any]]:
        return [{"response_id": response.id, "summary": response.summary}]

    def _resolve_rerank_response(self, response: Response) -> List[Dict[str, Any]]:
        # The documents key contains a dict containing the content of the document which is at least "text"
        # We flatten this nested dict for ease of consumption in the wandb UI
        flattened_response_dicts = [
            flatten_dict(_response.__dict__) for _response in response
        ]
        # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
        # As opposed to each row being one of the top_n responses
        return_dict = collect_common_keys(flattened_response_dicts)
        return_dict["id"] = response.id
        return [return_dict]

    def _resolve(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        parsed_response: List[Dict[str, Any]],
        start_time: float,
        time_elapsed: float,
        response_type: str,
        table_column_order: List[str],
        default_model: str,
    ) -> Dict[str, Any]:
        """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.

        :param args: The arguments passed to the API client.
        :param kwargs: The keyword arguments passed to the API client.
        :param parsed_response: The parsed response from the API.
        :param start_time: The start time of the API request.
        :param time_elapsed: The time elapsed during the API request.
        :param response_type: The type of the API response.
        :param table_column_order: The desired order of columns in the resulting table.
        :param default_model: The default model to use if not specified in the response.
        :return: A dictionary containing the formatted response.
        """
        # Args[0] is the client object where we can grab specific metadata about the underlying API status
        query_id = generate_id(length=16)
        parsed_args = subset_dict(
            args[0].__dict__,
            ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
        )

        start_time_dt = datetime.fromtimestamp(start_time)
        end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)

        timings = {
            "start_time": start_time_dt,
            "end_time": end_time_dt,
            "time_elapsed_(seconds)": time_elapsed,
        }

        packed_data = []
        for _parsed_response in parsed_response:
            _packed_dict = {
                "query_id": query_id,
                **kwargs,
                **_parsed_response,
                **timings,
                **parsed_args,
            }
            if "model" not in _packed_dict:
                _packed_dict["model"] = default_model
            packed_data.append(_packed_dict)

        columns, data = reorder_and_convert_dict_list_to_table(
            packed_data, table_column_order
        )

        request_response_table = wandb.Table(data=data, columns=columns)

        return {f"{response_type}": request_response_table}
```
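The dictionary helpers above are self-contained, so their behavior is easy to verify in isolation; the sample data below is invented.

```python
from wandb.integration.cohere.resolver import (
    flatten_dict,
    reorder_and_convert_dict_list_to_table,
)

# Nested "labels" scores flatten into "labels-spam" / "labels-ham" columns,
# mirroring how classify responses are prepared for the wandb Table.
rows = [
    flatten_dict({"id": "a", "labels": {"spam": 0.9, "ham": 0.1}}),
    flatten_dict({"id": "b", "labels": {"spam": 0.2, "ham": 0.8}}),
]

columns, values = reorder_and_convert_dict_list_to_table(rows, order=["id"])
assert columns == ["id", "labels-spam", "labels-ham"]
assert values == [["a", 0.9, 0.1], ["b", 0.2, 0.8]]
```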
wandb/integration/gym/__init__.py
CHANGED
```diff
@@ -65,12 +65,10 @@ def monitor():
         recorder.orig_close(self)
         if not self.enabled:
             return
-        …
-        …
-            key = m.group(1)
-        …
-            key = "videos"
-        wandb.log({key: wandb.Video(getattr(self, path))})
+        if wandb.run:
+            m = re.match(r".+(video\.\d+).+", getattr(self, path))
+            key = m.group(1) if m else "videos"
+            wandb.log({key: wandb.Video(getattr(self, path))})
 
     def del_(self):
         self.orig_close()
```
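For reference, the key the patched `close()` derives from a recording path looks like this; the path below is made up, in the naming style of the old gym video recorder.

```python
import re

path = "/tmp/monitor/openaigym.video.3.1234.video000001.mp4"
m = re.match(r".+(video\.\d+).+", path)
key = m.group(1) if m else "videos"
assert key == "video.3"  # paths without a "video.N" segment fall back to "videos"
```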
wandb/integration/huggingface/huggingface.py
ADDED
```python
import logging

from wandb.sdk.integration_utils.auto_logging import AutologAPI

from .resolver import HuggingFacePipelineRequestResponseResolver

logger = logging.getLogger(__name__)

resolver = HuggingFacePipelineRequestResponseResolver()

autolog = AutologAPI(
    name="transformers",
    symbols=("Pipeline.__call__",),
    resolver=resolver,
    telemetry_feature="hf_pipeline_autolog",
)

autolog.get_latest_id = resolver.get_latest_id
```
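A hedged usage sketch, parallel to the Cohere one: the `transformers` pipeline call is standard Hugging Face API, while the project name and model choice are illustrative.

```python
from transformers import pipeline

from wandb.integration.huggingface.huggingface import autolog

autolog(init=dict(project="hf-pipeline-autolog-demo"))  # project name is illustrative

generator = pipeline("text-generation", model="gpt2")
generator("Weights & Biases pipeline autologging", max_new_tokens=20)

autolog.disable()
```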