wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -8,25 +8,27 @@ import shutil
8
8
  import threading
9
9
  from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
10
10
 
11
- from wandb.filesync import dir_watcher, step_upload
11
+ from wandb.filesync import step_upload
12
12
  from wandb.sdk.lib import filesystem, runid
13
+ from wandb.sdk.lib.paths import LogicalPath
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  import tempfile
16
17
 
17
18
  from wandb.filesync import stats
18
- from wandb.sdk.interface import artifacts
19
- from wandb.sdk.internal import artifact_saver, internal_api
19
+ from wandb.sdk.artifacts import artifact_saver
20
+ from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
21
+ from wandb.sdk.internal import internal_api
20
22
 
21
23
 
22
24
  class RequestUpload(NamedTuple):
23
25
  path: str
24
- save_name: dir_watcher.SaveName
26
+ save_name: LogicalPath
25
27
  copy: bool
26
28
 
27
29
 
28
30
  class RequestStoreManifestFiles(NamedTuple):
29
- manifest: "artifacts.ArtifactManifest"
31
+ manifest: "ArtifactManifest"
30
32
  artifact_id: str
31
33
  save_fn: "artifact_saver.SaveFn"
32
34
  save_fn_async: "artifact_saver.SaveFnAsync"
@@ -108,9 +110,7 @@ class StepChecksum:
108
110
  self._output_queue.put(
109
111
  step_upload.RequestUpload(
110
112
  entry.local_path,
111
- dir_watcher.SaveName(
112
- entry.path
113
- ), # typecast might not be legit
113
+ entry.path,
114
114
  req.artifact_id,
115
115
  entry.digest,
116
116
  False,
@@ -8,6 +8,7 @@ import time
8
8
  from typing import (
9
9
  TYPE_CHECKING,
10
10
  Callable,
11
+ Dict,
11
12
  List,
12
13
  Mapping,
13
14
  NamedTuple,
@@ -39,9 +40,12 @@ class RequestFinish(NamedTuple):
39
40
 
40
41
 
41
42
  class ResponsePrepare(NamedTuple):
43
+ birth_artifact_id: str
42
44
  upload_url: Optional[str]
43
45
  upload_headers: Sequence[str]
44
- birth_artifact_id: str
46
+ upload_id: Optional[str]
47
+ storage_path: Optional[str]
48
+ multipart_upload_urls: Optional[Dict[int, str]]
45
49
 
46
50
 
47
51
  Request = Union[RequestPrepare, RequestFinish]
@@ -88,6 +92,21 @@ def gather_batch(
88
92
  return False, batch
89
93
 
90
94
 
95
+ def prepare_response(response: "CreateArtifactFilesResponseFile") -> ResponsePrepare:
96
+ multipart_resp = response.get("uploadMultipartUrls")
97
+ part_list = multipart_resp["uploadUrlParts"] if multipart_resp else []
98
+ multipart_parts = {u["partNumber"]: u["uploadUrl"] for u in part_list} or None
99
+
100
+ return ResponsePrepare(
101
+ birth_artifact_id=response["artifact"]["id"],
102
+ upload_url=response["uploadUrl"],
103
+ upload_headers=response["uploadHeaders"],
104
+ upload_id=multipart_resp and multipart_resp.get("uploadID"),
105
+ storage_path=response.get("storagePath"),
106
+ multipart_upload_urls=multipart_parts,
107
+ )
108
+
109
+
91
110
  class StepPrepare:
92
111
  """A thread that batches requests to our file prepare API.
93
112
 
@@ -120,18 +139,12 @@ class StepPrepare:
120
139
  max_batch_size=self._max_batch_size,
121
140
  )
122
141
  if batch:
123
- prepare_response = self._prepare_batch(batch)
142
+ batch_response = self._prepare_batch(batch)
124
143
  # send responses
125
144
  for prepare_request in batch:
126
145
  name = prepare_request.file_spec["name"]
127
- response_file = prepare_response[name]
128
- upload_url = response_file["uploadUrl"]
129
- upload_headers = response_file["uploadHeaders"]
130
- birth_artifact_id = response_file["artifact"]["id"]
131
-
132
- response = ResponsePrepare(
133
- upload_url, upload_headers, birth_artifact_id
134
- )
146
+ response_file = batch_response[name]
147
+ response = prepare_response(response_file)
135
148
  if isinstance(prepare_request.response_channel, queue.Queue):
136
149
  prepare_request.response_channel.put(response)
137
150
  else:
@@ -20,9 +20,10 @@ from typing import (
20
20
 
21
21
  from wandb.errors.term import termerror
22
22
  from wandb.filesync import upload_job
23
+ from wandb.sdk.lib.paths import LogicalPath
23
24
 
24
25
  if TYPE_CHECKING:
25
- from wandb.filesync import dir_watcher, stats
26
+ from wandb.filesync import stats
26
27
  from wandb.sdk.internal import file_stream, internal_api, progress
27
28
  from wandb.sdk.internal.settings_static import SettingsStatic
28
29
 
@@ -49,7 +50,7 @@ logger = logging.getLogger(__name__)
49
50
 
50
51
  class RequestUpload(NamedTuple):
51
52
  path: str
52
- save_name: "dir_watcher.SaveName"
53
+ save_name: LogicalPath
53
54
  artifact_id: Optional[str]
54
55
  md5: Optional[str]
55
56
  copied: bool
@@ -69,9 +70,12 @@ class RequestFinish(NamedTuple):
69
70
  callback: Optional[OnRequestFinishFn]
70
71
 
71
72
 
72
- Event = Union[
73
- RequestUpload, RequestCommitArtifact, RequestFinish, upload_job.EventJobDone
74
- ]
73
+ class EventJobDone(NamedTuple):
74
+ job: RequestUpload
75
+ exc: Optional[BaseException]
76
+
77
+
78
+ Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
75
79
 
76
80
 
77
81
  class AsyncExecutor:
@@ -148,7 +152,7 @@ class StepUpload:
148
152
  )
149
153
 
150
154
  # Indexed by files' `save_name`'s, which are their ID's in the Run.
151
- self._running_jobs: MutableMapping[dir_watcher.SaveName, RequestUpload] = {}
155
+ self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
152
156
  self._pending_jobs: MutableSequence[RequestUpload] = []
153
157
 
154
158
  self._artifacts: MutableMapping[str, "ArtifactStatus"] = {}
@@ -189,7 +193,7 @@ class StepUpload:
189
193
  break
190
194
 
191
195
  def _handle_event(self, event: Event) -> None:
192
- if isinstance(event, upload_job.EventJobDone):
196
+ if isinstance(event, EventJobDone):
193
197
  job = event.job
194
198
 
195
199
  if event.exc is not None:
@@ -283,9 +287,7 @@ class StepUpload:
283
287
  try:
284
288
  self._do_upload_sync(event)
285
289
  finally:
286
- self._event_queue.put(
287
- upload_job.EventJobDone(event, exc=sys.exc_info()[1])
288
- )
290
+ self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
289
291
 
290
292
  self._pool.submit(run_and_notify)
291
293
 
@@ -307,9 +309,7 @@ class StepUpload:
307
309
  try:
308
310
  await self._do_upload_async(event)
309
311
  finally:
310
- self._event_queue.put(
311
- upload_job.EventJobDone(event, exc=sys.exc_info()[1])
312
- )
312
+ self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
313
313
 
314
314
  async_executor.submit(run_and_notify())
315
315
 
@@ -1,20 +1,16 @@
1
1
  import asyncio
2
2
  import logging
3
3
  import os
4
- from typing import TYPE_CHECKING, NamedTuple, Optional
4
+ from typing import TYPE_CHECKING, Optional
5
5
 
6
6
  import wandb
7
+ from wandb.sdk.lib.paths import LogicalPath
7
8
 
8
9
  if TYPE_CHECKING:
9
10
  from wandb.filesync import dir_watcher, stats, step_upload
10
11
  from wandb.sdk.internal import file_stream, internal_api
11
12
 
12
13
 
13
- class EventJobDone(NamedTuple):
14
- job: "step_upload.RequestUpload"
15
- exc: Optional[BaseException]
16
-
17
-
18
14
  logger = logging.getLogger(__name__)
19
15
 
20
16
 
@@ -25,7 +21,7 @@ class UploadJob:
25
21
  api: "internal_api.Api",
26
22
  file_stream: "file_stream.FileStreamApi",
27
23
  silent: bool,
28
- save_name: "dir_watcher.SaveName",
24
+ save_name: LogicalPath,
29
25
  path: "dir_watcher.PathStr",
30
26
  artifact_id: Optional[str],
31
27
  md5: Optional[str],
@@ -47,7 +43,7 @@ class UploadJob:
47
43
  self._file_stream = file_stream
48
44
  self.silent = silent
49
45
  self.save_name = save_name
50
- self.save_path = self.path = path
46
+ self.save_path = path
51
47
  self.artifact_id = artifact_id
52
48
  self.md5 = md5
53
49
  self.copied = copied
@@ -0,0 +1,3 @@
1
+ __all__ = ("autolog",)
2
+
3
+ from .cohere import autolog
@@ -0,0 +1,21 @@
1
+ import logging
2
+
3
+ from wandb.sdk.integration_utils.auto_logging import AutologAPI
4
+
5
+ from .resolver import CohereRequestResponseResolver
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ autolog = AutologAPI(
11
+ name="Cohere",
12
+ symbols=(
13
+ "Client.generate",
14
+ "Client.chat",
15
+ "Client.classify",
16
+ "Client.summarize",
17
+ "Client.rerank",
18
+ ),
19
+ resolver=CohereRequestResponseResolver(),
20
+ telemetry_feature="cohere_autolog",
21
+ )
@@ -0,0 +1,347 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
4
+
5
+ import wandb
6
+ from wandb.sdk.integration_utils.auto_logging import Response
7
+ from wandb.sdk.lib.runid import generate_id
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def subset_dict(
    original_dict: Dict[str, Any], keys_subset: Sequence[str]
) -> Dict[str, Any]:
    """Extract the entries of a dictionary whose keys appear in a subset.

    :param original_dict: The source dictionary.
    :param keys_subset: The keys to keep; missing keys are silently skipped.
    :return: A new dictionary restricted to the requested keys.
    """
    result: Dict[str, Any] = {}
    for key in keys_subset:
        if key in original_dict:
            result[key] = original_dict[key]
    return result
22
+
23
+
24
def reorder_and_convert_dict_list_to_table(
    data: List[Dict[str, Any]], order: List[str]
) -> Tuple[List[str], List[List[Any]]]:
    """Convert a list of dictionaries into (column names, row values).

    Columns listed in *order* come first; any remaining keys follow in the
    order they are first seen across the dictionaries. Missing values are
    filled with ``None``.

    :param data: A list of dictionaries.
    :param order: Keys specifying the desired leading column order.
    :return: A pair of column names and corresponding row values.
    """
    # Build an ordered, de-duplicated column list. A dict is used purely
    # for its insertion-order-preserving, duplicate-free keys.
    columns_seen: Dict[str, None] = dict.fromkeys(order)
    for record in data:
        for key in record:
            columns_seen.setdefault(key, None)
    final_columns = list(columns_seen)

    # One row per record, aligned to the final column order.
    values = [[record.get(column) for column in final_columns] for record in data]
    return final_columns, values
58
+
59
+
60
def flatten_dict(
    dictionary: Dict[str, Any], parent_key: str = "", sep: str = "-"
) -> Dict[str, Any]:
    """Flatten a nested dictionary, joining key paths with a separator.

    :param dictionary: The dictionary to flatten.
    :param parent_key: The prefix to prepend to each key.
    :param sep: The separator used when joining nested keys.
    :return: A flattened dictionary.
    """
    flat: Dict[str, Any] = {}
    for key, value in dictionary.items():
        compound_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse into nested mappings, carrying the compound key down.
            flat.update(flatten_dict(value, compound_key, sep=sep))
        else:
            flat[compound_key] = value
    return flat
78
+
79
+
80
def collect_common_keys(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Collect the keys common to every dictionary in a list.

    For each common key, its values are gathered into a list in the order
    the dictionaries appear.

    :param list_of_dicts: The list of dictionaries to inspect.
    :return: A dictionary mapping each common key to its list of values.
        Empty input yields an empty dict (previously this raised TypeError
        from ``set.intersection`` with no arguments).
    """
    if not list_of_dicts:
        return {}
    common_keys = set(list_of_dicts[0])
    for d in list_of_dicts[1:]:
        common_keys &= set(d)
    # Iterate the first dict's keys (not the set) so output key order is
    # deterministic instead of arbitrary set order.
    return {
        key: [d[key] for d in list_of_dicts]
        for key in list_of_dicts[0]
        if key in common_keys
    }
92
+
93
+
94
class CohereRequestResponseResolver:
    """Resolve a Cohere API request/response pair into a loggable dictionary.

    Instances are used as the ``resolver`` callable of an ``AutologAPI``:
    each intercepted Cohere client call is converted into a dict mapping the
    response type name to a ``wandb.Table`` of request/response data.
    """

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: "Response",
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Process the response from the Cohere API into a loggable dict.

        :param args: The arguments of the original function (``args[0]`` is
            the Cohere client instance).
        :param kwargs: The keyword arguments of the original function.
        :param response: The response from the Cohere API.
        :param start_time: The start time of the request (epoch seconds).
        :param time_elapsed: The time elapsed for the request, in seconds.
        :return: A dict containing the parsed response and timing
            information, or ``None`` for unsupported responses or on error.
        """
        try:
            # Each endpoint maps to one specific response type. We "type
            # check" the response without importing the cohere package by
            # extracting the class name from str(type(response)).
            # It may make more sense to pass the invoked symbol from the
            # AutologAPI instead.
            response_type = str(type(response)).split("'")[1].split(".")[-1]

            if response_type == "Generations":
                parsed_response = self._resolve_generate_response(response)
                # TODO: Remove hard-coded default model name
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "prompt",
                    "text",
                    "token_likelihoods",
                    "likelihood",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Chat":
                parsed_response = self._resolve_chat_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "conversation_id",
                    "response_id",
                    "query",
                    "text",
                    "prompt",
                    "preamble",
                    "chat_history",
                    "chatlog",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "command"
            elif response_type == "Classifications":
                parsed_response = self._resolve_classify_response(response)
                kwargs = self._resolve_classify_kwargs(kwargs)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "input",
                    "prediction",
                    "confidence",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "embed-english-v2.0"
            elif response_type == "SummarizeResponse":
                parsed_response = self._resolve_summarize_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "response_id",
                    "text",
                    "additional_command",
                    "summary",
                    "time_elapsed_(seconds)",
                    "end_time",
                    "length",
                    "format",
                ]
                default_model = "summarize-xlarge"
            elif response_type == "Reranking":
                parsed_response = self._resolve_rerank_response(response)
                table_column_order = [
                    "start_time",
                    "query_id",
                    "model",
                    "id",
                    "query",
                    "top_n",
                    # This is a nested dict key that got flattened
                    "document-text",
                    "relevance_score",
                    "index",
                    "time_elapsed_(seconds)",
                    "end_time",
                ]
                default_model = "rerank-english-v2.0"
            else:
                # Fix: previously this branch fell through to _resolve with
                # table_column_order/default_model unbound, raising a
                # NameError that was silently swallowed by the except below.
                logger.info(f"Unsupported Cohere response object: {response}")
                return None

            return self._resolve(
                args,
                kwargs,
                parsed_response,
                start_time,
                time_elapsed,
                response_type,
                table_column_order,
                default_model,
            )
        except Exception as e:
            logger.warning(f"Failed to resolve request/response: {e}")
            return None

    # These helper functions process the response from different endpoints of
    # the Cohere API. Since the response objects for different endpoints have
    # different structures, we need different logic to process them.

    def _resolve_generate_response(self, response: "Response") -> List[Dict[str, Any]]:
        """Convert each generation in a Generations response to a dict."""
        return_list = []
        for _response in response:
            # Built-in Cohere Generations helper that colors
            # token_likelihoods and returns a dict of response data.
            _response_dict = _response._visualize_helper()
            try:
                _response_dict["token_likelihoods"] = wandb.Html(
                    _response_dict["token_likelihoods"]
                )
            except (KeyError, ValueError):
                # No likelihoods were requested or the HTML was malformed;
                # leave the raw value in place.
                pass
            return_list.append(_response_dict)

        return return_list

    def _resolve_chat_response(self, response: "Response") -> List[Dict[str, Any]]:
        """Extract the loggable attributes of a Chat response."""
        return [
            subset_dict(
                response.__dict__,
                [
                    "response_id",
                    "generation_id",
                    "query",
                    "text",
                    "conversation_id",
                    "prompt",
                    "chatlog",
                    "preamble",
                ],
            )
        ]

    def _resolve_classify_response(self, response: "Response") -> List[Dict[str, Any]]:
        """Flatten each classification in a Classifications response."""
        # The labels key is a dict of classification probabilities per label.
        # We flatten this nested dict for ease of consumption in the wandb UI.
        return [flatten_dict(_response.__dict__) for _response in response]

    def _resolve_classify_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Split classify ``examples`` into separate text/label columns.

        Example objects render poorly in the wandb UI as a list of
        text+label pairs, so each value gets its own column.
        """
        # Work on a copy so the caller's kwargs dict is not mutated.
        kwargs = dict(kwargs)
        example_texts = []
        example_labels = []
        for example in kwargs["examples"]:
            example_texts.append(example.text)
            example_labels.append(example.label)
        kwargs.pop("examples")
        kwargs["example_texts"] = example_texts
        kwargs["example_labels"] = example_labels
        return kwargs

    def _resolve_summarize_response(self, response: "Response") -> List[Dict[str, Any]]:
        """Extract the id and summary of a SummarizeResponse."""
        return [{"response_id": response.id, "summary": response.summary}]

    def _resolve_rerank_response(self, response: "Response") -> List[Dict[str, Any]]:
        """Aggregate a Reranking response into a single loggable row."""
        # The documents key contains a dict with at least a "text" entry;
        # flatten it for ease of consumption in the wandb UI.
        flattened_response_dicts = [
            flatten_dict(_response.__dict__) for _response in response
        ]
        # ReRank returns one result per document up to top_n; aggregate them
        # into one row so users can paginate within a single row instead of
        # getting one row per top_n result.
        return_dict = collect_common_keys(flattened_response_dicts)
        return_dict["id"] = response.id
        return [return_dict]

    def _resolve(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        parsed_response: List[Dict[str, Any]],
        start_time: float,
        time_elapsed: float,
        response_type: str,
        table_column_order: List[str],
        default_model: str,
    ) -> Dict[str, Any]:
        """Pack the parsed response, kwargs, and timings into a wandb.Table.

        :param args: The arguments passed to the API client.
        :param kwargs: The keyword arguments passed to the API client.
        :param parsed_response: The parsed response from the API.
        :param start_time: The start time of the API request (epoch seconds).
        :param time_elapsed: The time elapsed during the API request.
        :param response_type: The type name of the API response.
        :param table_column_order: The desired column order of the table.
        :param default_model: The model name used when the response carries none.
        :return: A dict mapping the response type to a wandb.Table.
        """
        # args[0] is the client object, whose metadata about the underlying
        # API configuration we also log.
        query_id = generate_id(length=16)
        parsed_args = subset_dict(
            args[0].__dict__,
            ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
        )

        start_time_dt = datetime.fromtimestamp(start_time)
        end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)

        timings = {
            "start_time": start_time_dt,
            "end_time": end_time_dt,
            "time_elapsed_(seconds)": time_elapsed,
        }

        packed_data = []
        for _parsed_response in parsed_response:
            _packed_dict = {
                "query_id": query_id,
                **kwargs,
                **_parsed_response,
                **timings,
                **parsed_args,
            }
            if "model" not in _packed_dict:
                _packed_dict["model"] = default_model
            packed_data.append(_packed_dict)

        columns, data = reorder_and_convert_dict_list_to_table(
            packed_data, table_column_order
        )

        request_response_table = wandb.Table(data=data, columns=columns)

        return {response_type: request_response_table}
@@ -65,12 +65,10 @@ def monitor():
65
65
  recorder.orig_close(self)
66
66
  if not self.enabled:
67
67
  return
68
- m = re.match(r".+(video\.\d+).+", getattr(self, path))
69
- if m:
70
- key = m.group(1)
71
- else:
72
- key = "videos"
73
- wandb.log({key: wandb.Video(getattr(self, path))})
68
+ if wandb.run:
69
+ m = re.match(r".+(video\.\d+).+", getattr(self, path))
70
+ key = m.group(1) if m else "videos"
71
+ wandb.log({key: wandb.Video(getattr(self, path))})
74
72
 
75
73
  def del_(self):
76
74
  self.orig_close()
@@ -0,0 +1,3 @@
1
+ __all__ = ("autolog",)
2
+
3
+ from .huggingface import autolog
@@ -0,0 +1,18 @@
1
+ import logging
2
+
3
+ from wandb.sdk.integration_utils.auto_logging import AutologAPI
4
+
5
+ from .resolver import HuggingFacePipelineRequestResponseResolver
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ resolver = HuggingFacePipelineRequestResponseResolver()
10
+
11
+ autolog = AutologAPI(
12
+ name="transformers",
13
+ symbols=("Pipeline.__call__",),
14
+ resolver=resolver,
15
+ telemetry_feature="hf_pipeline_autolog",
16
+ )
17
+
18
+ autolog.get_latest_id = resolver.get_latest_id