wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,213 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
5
|
+
|
6
|
+
import pytz
|
7
|
+
|
8
|
+
import wandb
|
9
|
+
from wandb.sdk.integration_utils.auto_logging import Response
|
10
|
+
from wandb.sdk.lib.runid import generate_id
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)

# Pipeline tasks that accept a ``top_k`` argument. Responses for these tasks are
# never unpacked from single-element lists when logged (see
# ``HuggingFacePipelineRequestResponseResolver._format_data``), so that tables
# logged with different ``top_k`` values keep the same response layout.
PIPELINES_WITH_TOP_K = [
    "text-classification",
    "sentiment-analysis",
    "question-answering",
]

# Every pipeline task the resolver knows how to log: the top_k tasks above plus
# the text-generation family. ``translation_x_to_y`` tasks are matched by prefix
# in the resolver itself. "conversational" is intentionally not listed yet.
SUPPORTED_PIPELINE_TASKS = PIPELINES_WITH_TOP_K + [
    "summarization",
    "translation",
    "text2text-generation",
    "text-generation",
]
|
30
|
+
|
31
|
+
|
32
|
+
class HuggingFacePipelineRequestResponseResolver:
    """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting.

    This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver)

    Calling an instance with the arguments and result of a pipeline call returns
    ``{table_name: wandb.Table}`` for supported tasks, or ``None`` when the task
    is unsupported or any error occurs while resolving (errors are logged as
    warnings, never raised).
    """

    # ID written into every row of the most recently created table (set by
    # `_create_table`); None until a table has been built.
    autolog_id = None

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Main call method for this class.

        :param args: positional arguments of the pipeline call; ``args[0]`` is the
            pipeline and ``args[1]`` its input(s)
        :param kwargs: dictionary of keyword arguments passed to the pipeline
        :param response: the response from the request
        :param start_time: time when request started (currently unused)
        :param time_elapsed: time elapsed for the request
        :returns: packed data as a dictionary for logging to wandb, None if the task is
            unsupported or an exception occurred
        """
        try:
            pipe, input_data = args[:2]
            task = pipe.task

            # Translation tasks are in the form of `translation_x_to_y`
            if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"):
                model = self._get_model(pipe)
                if model is None:
                    return None
                model_alias = model.name_or_path
                timestamp = datetime.now(pytz.utc)

                # A single (non-list) input yields a single result, which may itself be a
                # list (several top_k candidates or generated sequences). Wrap both sides
                # together so that whole result stays paired with its input; otherwise
                # `zip` in `_format_data` would keep only the first candidate.
                if not isinstance(input_data, list):
                    input_data, response = [input_data], [response]

                input_data, response = self._transform_task_specific_data(
                    task, input_data, response
                )
                formatted_data = self._format_data(task, input_data, response, kwargs)
                packed_data = self._create_table(
                    formatted_data, model_alias, timestamp, time_elapsed
                )
                # TODO: Let users decide the name in a way that does not use an environment variable
                table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}")

                return {
                    table_name: wandb.Table(
                        columns=packed_data[0], data=packed_data[1:]
                    )
                }

            logger.warning(
                f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task"
            )
        except Exception as e:
            # Best-effort logging: never let autologging break the user's pipeline call.
            logger.warning(e)
        return None

    # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
    # from transformers.modeling_utils import PreTrainedModel
    # We do not want this dependency explicitly in our codebase so we make a very general assumption about
    # the structure of the pipeline which may have unintended consequences
    def _get_model(self, pipe: Any) -> Optional[Any]:
        """Extracts model from the pipeline.

        Returns ``pipe.model.model`` when that attribute exists, otherwise falls back
        to ``pipe.model`` itself.

        :param pipe: the HuggingFace pipeline
        :returns: Model if available, None otherwise
        """
        model = pipe.model
        try:
            return model.model
        except AttributeError:
            logger.info(
                "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model."
            )
            return model

    @staticmethod
    def _transform_task_specific_data(
        task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any]
    ) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]:
        """Transform input and response data based on specific tasks.

        For question-answering, each input object is replaced by its ``__dict__``
        (attribute dictionary) so it can be shown in a table.

        :param task: the task name
        :param input_data: the input data
        :param response: the response data
        :returns: tuple of transformed input_data and response
        """
        if task == "question-answering":
            input_data = input_data if isinstance(input_data, list) else [input_data]
            input_data = [data.__dict__ for data in input_data]
        elif task == "conversational":
            # NOTE: currently unreachable from `__call__` while "conversational" is
            # commented out of SUPPORTED_PIPELINE_TASKS.
            # We only grab the latest input/output pair from the conversation
            # Logging the whole conversation renders strangely.
            input_data = input_data if isinstance(input_data, list) else [input_data]
            input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data]

            response = response if isinstance(response, list) else [response]
            response = [data.__dict__["generated_responses"][-1] for data in response]
        return input_data, response

    def _format_data(
        self,
        task: str,
        input_data: Union[List[Any], Any],
        response: Union[List[Any], Any],
        kwargs: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Formats input data, response, and kwargs into a list of dictionaries.

        Inputs and responses are paired positionally; non-list values are treated
        as single-element lists.

        :param task: the task name
        :param input_data: the input data
        :param response: the response data
        :param kwargs: dictionary of keyword arguments
        :returns: list of dictionaries with ``input``, ``response`` and ``kwargs`` keys
        """
        input_data = input_data if isinstance(input_data, list) else [input_data]
        response = response if isinstance(response, list) else [response]

        formatted_data = []
        for i_text, r_text in zip(input_data, response):
            # Unpack single element responses for better rendering in wandb UI when it is a task without top_k.
            # For top_k tasks, top_k = 1 would unpack the response into a single element while top_k > 1 would
            # be a list; this would cause the UI to not properly concatenate the tables of the same task by
            # omitting the elements past the first, so top_k tasks are always kept as lists.
            if (
                isinstance(r_text, list)
                and len(r_text) == 1
                and task not in PIPELINES_WITH_TOP_K
            ):
                r_text = r_text[0]
            formatted_data.append(
                {"input": i_text, "response": r_text, "kwargs": kwargs}
            )
        return formatted_data

    def _create_table(
        self,
        formatted_data: List[Dict[str, Any]],
        model_alias: str,
        timestamp: datetime,
        time_elapsed: float,
    ) -> List[List[Any]]:
        """Creates a table from formatted data, model alias, timestamp, and elapsed time.

        A fresh 16-character ID is generated per call, shared by every row, and
        stored on ``self.autolog_id``.

        :param formatted_data: list of dictionaries containing formatted data
        :param model_alias: alias of the model
        :param timestamp: timezone-aware time at which the call was resolved
        :param time_elapsed: time elapsed for the request
        :returns: list of lists, representing a table of data. [0]th element = columns,
            remaining elements = one row per entry of ``formatted_data``
        """
        header = [
            "ID",
            "Model Alias",
            "Timestamp",
            "Elapsed Time",
            "Input",
            "Response",
            "Kwargs",
        ]
        table = [header]
        autolog_id = generate_id(length=16)

        for data in formatted_data:
            row = [
                autolog_id,
                model_alias,
                timestamp,
                time_elapsed,
                data["input"],
                data["response"],
                data["kwargs"],
            ]
            table.append(row)

        self.autolog_id = autolog_id

        return table

    def get_latest_id(self):
        """Return the ID used for the most recently created table (None if none yet)."""
        return self.autolog_id
|
@@ -14,22 +14,10 @@ integration will not break user code. The one exception to the rule is at import
|
|
14
14
|
LangChain is not installed, or the symbols are not in the same place, the appropriate error
|
15
15
|
will be raised when importing this module.
|
16
16
|
"""
|
17
|
-
import sys
|
18
|
-
|
19
|
-
if sys.version_info >= (3, 8):
|
20
|
-
from typing import TypedDict
|
21
|
-
else:
|
22
|
-
from typing_extensions import TypedDict
|
23
|
-
|
24
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
25
|
-
|
26
17
|
from packaging import version
|
27
18
|
|
28
|
-
import wandb
|
29
19
|
import wandb.util
|
30
|
-
from wandb.sdk.
|
31
|
-
from wandb.sdk.lib import telemetry as wb_telemetry
|
32
|
-
from wandb.sdk.lib.paths import StrPath
|
20
|
+
from wandb.sdk.lib import deprecate
|
33
21
|
|
34
22
|
langchain = wandb.util.get_module(
|
35
23
|
name="langchain",
|
@@ -37,174 +25,23 @@ langchain = wandb.util.get_module(
|
|
37
25
|
"package installed. Please install it with `pip install langchain`.",
|
38
26
|
)
|
39
27
|
|
40
|
-
if version.parse(langchain.__version__) < version.parse("0.0.
|
28
|
+
if version.parse(langchain.__version__) < version.parse("0.0.188"):
|
41
29
|
raise ValueError(
|
42
|
-
"The Weights & Biases Langchain integration does not support versions 0.0.
|
43
|
-
"To ensure proper functionality, please use version 0.0.
|
30
|
+
"The Weights & Biases Langchain integration does not support versions 0.0.187 and lower. "
|
31
|
+
"To ensure proper functionality, please use version 0.0.188 or higher."
|
44
32
|
)
|
45
33
|
|
46
|
-
# We want these imports after the import_langchain() call, so that we can
|
47
|
-
# catch the ImportError if langchain is not installed.
|
48
|
-
|
49
34
|
# isort: off
|
50
|
-
from langchain.callbacks.tracers
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
from wandb import Settings as WBSettings
|
64
|
-
from wandb.wandb_run import Run as WBRun
|
65
|
-
|
66
|
-
|
67
|
-
class WandbRunArgs(TypedDict):
|
68
|
-
job_type: Optional[str]
|
69
|
-
dir: Optional[StrPath]
|
70
|
-
config: Union[Dict, str, None]
|
71
|
-
project: Optional[str]
|
72
|
-
entity: Optional[str]
|
73
|
-
reinit: Optional[bool]
|
74
|
-
tags: Optional[Sequence]
|
75
|
-
group: Optional[str]
|
76
|
-
name: Optional[str]
|
77
|
-
notes: Optional[str]
|
78
|
-
magic: Optional[Union[dict, str, bool]]
|
79
|
-
config_exclude_keys: Optional[List[str]]
|
80
|
-
config_include_keys: Optional[List[str]]
|
81
|
-
anonymous: Optional[str]
|
82
|
-
mode: Optional[str]
|
83
|
-
allow_val_change: Optional[bool]
|
84
|
-
resume: Optional[Union[bool, str]]
|
85
|
-
force: Optional[bool]
|
86
|
-
tensorboard: Optional[bool]
|
87
|
-
sync_tensorboard: Optional[bool]
|
88
|
-
monitor_gym: Optional[bool]
|
89
|
-
save_code: Optional[bool]
|
90
|
-
id: Optional[str]
|
91
|
-
settings: Union["WBSettings", Dict[str, Any], None]
|
92
|
-
|
93
|
-
|
94
|
-
class WandbTracer(BaseTracer):
|
95
|
-
"""Callback Handler that logs to Weights and Biases.
|
96
|
-
|
97
|
-
This handler will log the model architecture and run traces to Weights and Biases. This will
|
98
|
-
ensure that all LangChain activity is logged to W&B.
|
99
|
-
"""
|
100
|
-
|
101
|
-
_run: Optional["WBRun"] = None
|
102
|
-
_run_args: Optional[WandbRunArgs] = None
|
103
|
-
|
104
|
-
@classmethod
|
105
|
-
def init(
|
106
|
-
cls,
|
107
|
-
run_args: Optional[WandbRunArgs] = None,
|
108
|
-
include_stdout: bool = True,
|
109
|
-
additional_handlers: Optional[List["BaseCallbackHandler"]] = None,
|
110
|
-
) -> None:
|
111
|
-
"""Method provided for backwards compatibility. Please directly construct `WandbTracer` instead."""
|
112
|
-
message = """Global autologging is not currently supported for the LangChain integration.
|
113
|
-
Please directly construct a `WandbTracer` and add it to the list of callbacks. For example:
|
114
|
-
|
115
|
-
LLMChain(llm, callbacks=[WandbTracer()])
|
116
|
-
# end of notebook / script:
|
117
|
-
WandbTracer.finish()"""
|
118
|
-
wandb.termlog(message)
|
119
|
-
|
120
|
-
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
|
121
|
-
"""Initializes the WandbTracer.
|
122
|
-
|
123
|
-
Parameters:
|
124
|
-
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not provided, `wandb.init()` will be
|
125
|
-
called with no arguments. Please refer to the `wandb.init` for more details.
|
126
|
-
|
127
|
-
To use W&B to monitor all LangChain activity, add this tracer like any other langchain callback
|
128
|
-
```
|
129
|
-
from wandb.integration.langchain import WandbTracer
|
130
|
-
LLMChain(llm, callbacks=[WandbTracer()])
|
131
|
-
# end of notebook / script:
|
132
|
-
WandbTracer.finish()
|
133
|
-
```.
|
134
|
-
"""
|
135
|
-
super().__init__(**kwargs)
|
136
|
-
self._run_args = run_args
|
137
|
-
self._ensure_run(should_print_url=(wandb.run is None))
|
138
|
-
|
139
|
-
@staticmethod
|
140
|
-
def finish() -> None:
|
141
|
-
"""Waits for all asynchronous processes to finish and data to upload.
|
142
|
-
|
143
|
-
Proxy for `wandb.finish()`.
|
144
|
-
"""
|
145
|
-
wandb.finish()
|
146
|
-
|
147
|
-
def _log_trace_from_run(self, run: "Run") -> None:
|
148
|
-
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
149
|
-
self._ensure_run()
|
150
|
-
|
151
|
-
root_span = safely_convert_lc_run_to_wb_span(run)
|
152
|
-
if root_span is None:
|
153
|
-
return
|
154
|
-
|
155
|
-
model_dict = None
|
156
|
-
|
157
|
-
# TODO: Uncomment this once we have a way to get the model from a run
|
158
|
-
# model = safely_get_span_producing_model(run)
|
159
|
-
# if model is not None:
|
160
|
-
# model_dict = safely_convert_model_to_dict(model)
|
161
|
-
|
162
|
-
model_trace = trace_tree.WBTraceTree(
|
163
|
-
root_span=root_span,
|
164
|
-
model_dict=model_dict,
|
35
|
+
from langchain.callbacks.tracers import WandbTracer # noqa: E402, I001
|
36
|
+
|
37
|
+
|
38
|
+
class WandbTracer(WandbTracer):
    """Deprecated W&B alias for LangChain's own ``WandbTracer``.

    The base class expression is evaluated before this class name is bound, so this
    subclasses the LangChain tracer imported above and then shadows its name.
    Construction is delegated unchanged to that tracer; afterwards a deprecation
    notice is issued via ``deprecate.deprecate`` on every instantiation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        notice = (
            "This feature is deprecated and has been moved to `langchain`. Enable tracing by setting "
            "LANGCHAIN_WANDB_TRACING=true in your environment. See the documentation at "
            "https://python.langchain.com/docs/ecosystem/integrations/agent_with_wandb_tracing for guidance. "
            "Replace your current import with `from langchain.callbacks.tracers import WandbTracer`."
        )
        deprecate.deprecate(
            field_name=deprecate.Deprecated.langchain_tracer,
            warning_message=notice,
        )
|
166
|
-
wandb.run.log({"langchain_trace": model_trace})
|
167
|
-
|
168
|
-
def _ensure_run(self, should_print_url=False) -> None:
|
169
|
-
"""Ensures an active W&B run exists.
|
170
|
-
|
171
|
-
If not, will start a new run with the provided run_args.
|
172
|
-
"""
|
173
|
-
if wandb.run is None:
|
174
|
-
# Make a shallow copy of the run args, so we don't modify the original
|
175
|
-
run_args = self._run_args or {} # type: ignore
|
176
|
-
run_args: dict = {**run_args} # type: ignore
|
177
|
-
|
178
|
-
# Prefer to run in silent mode since W&B has a lot of output
|
179
|
-
# which can be undesirable when dealing with text-based models.
|
180
|
-
if "settings" not in run_args: # type: ignore
|
181
|
-
run_args["settings"] = {"silent": True} # type: ignore
|
182
|
-
|
183
|
-
# Start the run and add the stream table
|
184
|
-
wandb.init(**run_args)
|
185
|
-
|
186
|
-
if should_print_url:
|
187
|
-
print_wandb_init_message(wandb.run.settings.run_url)
|
188
|
-
|
189
|
-
with wb_telemetry.context(wandb.run) as tel:
|
190
|
-
tel.feature.langchain_tracer = True
|
191
|
-
|
192
|
-
# Start of required methods (these methods are required by the BaseCallbackHandler interface)
|
193
|
-
@property
|
194
|
-
def always_verbose(self) -> bool:
|
195
|
-
"""Whether to call verbose callbacks even if verbose is False."""
|
196
|
-
return True
|
197
|
-
|
198
|
-
def _generate_id(self) -> Optional[Union[int, str]]:
|
199
|
-
"""Generate an id for a run."""
|
200
|
-
return None
|
201
|
-
|
202
|
-
def _persist_run(self, run: "Run") -> None:
|
203
|
-
"""Persist a run."""
|
204
|
-
try:
|
205
|
-
self._log_trace_from_run(run)
|
206
|
-
except Exception:
|
207
|
-
# Silently ignore errors to not break user code
|
208
|
-
pass
|
209
|
-
|
210
|
-
# End of required methods
|
@@ -1,151 +1,19 @@
|
|
1
|
-
import functools
|
2
1
|
import logging
|
3
|
-
import sys
|
4
|
-
from typing import Any, Dict, List, Optional
|
5
2
|
|
6
|
-
|
7
|
-
import wandb.util
|
8
|
-
from wandb.sdk.lib import telemetry as wb_telemetry
|
9
|
-
from wandb.sdk.lib.timer import Timer
|
3
|
+
from wandb.sdk.integration_utils.auto_logging import AutologAPI
|
10
4
|
|
11
5
|
from .resolver import OpenAIRequestResponseResolver
|
12
6
|
|
13
|
-
if sys.version_info >= (3, 8):
|
14
|
-
from typing import Literal
|
15
|
-
else:
|
16
|
-
from typing_extensions import Literal
|
17
|
-
|
18
|
-
|
19
7
|
logger = logging.getLogger(__name__)
|
20
8
|
|
21
9
|
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
"
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
def __init__(self) -> None:
|
33
|
-
"""Patches the OpenAI API to log traces to W&B."""
|
34
|
-
self.original_methods: Dict[str, Any] = {}
|
35
|
-
self.resolver = OpenAIRequestResponseResolver()
|
36
|
-
self._openai = None
|
37
|
-
|
38
|
-
@property
|
39
|
-
def openai(self) -> Any:
|
40
|
-
"""Returns the openai module."""
|
41
|
-
if self._openai is None:
|
42
|
-
self._openai = wandb.util.get_module(
|
43
|
-
name="openai",
|
44
|
-
required="To use the W&B OpenAI Autolog, you need to have the `openai` python "
|
45
|
-
"package installed. Please install it with `pip install openai`.",
|
46
|
-
lazy=False,
|
47
|
-
)
|
48
|
-
return self._openai
|
49
|
-
|
50
|
-
def patch(self, run: "wandb.sdk.wandb_run.Run") -> None:
|
51
|
-
"""Patches the OpenAI API to log traces to W&B."""
|
52
|
-
for symbol in self.symbols:
|
53
|
-
original = getattr(self.openai, symbol).create
|
54
|
-
|
55
|
-
def method_factory(original_method: Any):
|
56
|
-
@functools.wraps(original_method)
|
57
|
-
def create(*args, **kwargs):
|
58
|
-
with Timer() as timer:
|
59
|
-
result = original_method(*args, **kwargs)
|
60
|
-
try:
|
61
|
-
trace = self.resolver(kwargs, result, timer.elapsed)
|
62
|
-
if trace is not None:
|
63
|
-
run.log({"trace": trace})
|
64
|
-
except Exception:
|
65
|
-
# logger.warning(e)
|
66
|
-
pass
|
67
|
-
return result
|
68
|
-
|
69
|
-
return create
|
70
|
-
|
71
|
-
# save original method
|
72
|
-
self.original_methods[symbol] = original
|
73
|
-
# monkeypatch
|
74
|
-
getattr(self.openai, symbol).create = method_factory(original)
|
75
|
-
|
76
|
-
def unpatch(self) -> None:
|
77
|
-
"""Unpatches the OpenAI API."""
|
78
|
-
for symbol, original in self.original_methods.items():
|
79
|
-
getattr(self.openai, symbol).create = original
|
80
|
-
|
81
|
-
|
82
|
-
class AutologOpenAI:
|
83
|
-
def __init__(self) -> None:
|
84
|
-
"""Autolog OpenAI API calls to W&B."""
|
85
|
-
self._patch_openai_api = PatchOpenAIAPI()
|
86
|
-
self._run: Optional["wandb.sdk.wandb_run.Run"] = None
|
87
|
-
self.__run_created_by_autolog: bool = False
|
88
|
-
|
89
|
-
@property
|
90
|
-
def _is_enabled(self) -> bool:
|
91
|
-
"""Returns whether autologging is enabled."""
|
92
|
-
return self._run is not None
|
93
|
-
|
94
|
-
def __call__(self, init: AutologOpenAIInitArgs = None) -> None:
|
95
|
-
"""Enable OpenAI autologging."""
|
96
|
-
self.enable(init=init)
|
97
|
-
|
98
|
-
def _run_init(self, init: AutologOpenAIInitArgs = None) -> None:
|
99
|
-
"""Handle wandb run initialization."""
|
100
|
-
# - autolog(init: dict = {...}) calls wandb.init(**{...})
|
101
|
-
# regardless of whether there is a wandb.run or not,
|
102
|
-
# we only track if the run was created by autolog
|
103
|
-
# - todo: autolog(init: dict | run = run) would use the user-provided run
|
104
|
-
# - autolog() uses the wandb.run if there is one, otherwise it calls wandb.init()
|
105
|
-
if init:
|
106
|
-
_wandb_run = wandb.run
|
107
|
-
# we delegate dealing with the init dict to wandb.init()
|
108
|
-
self._run = wandb.init(**init)
|
109
|
-
if _wandb_run != self._run:
|
110
|
-
self.__run_created_by_autolog = True
|
111
|
-
elif wandb.run is None:
|
112
|
-
self._run = wandb.init()
|
113
|
-
self.__run_created_by_autolog = True
|
114
|
-
else:
|
115
|
-
self._run = wandb.run
|
116
|
-
|
117
|
-
def enable(self, init: AutologOpenAIInitArgs = None) -> None:
|
118
|
-
"""Enable OpenAI autologging.
|
119
|
-
|
120
|
-
Args:
|
121
|
-
init: Optional dictionary of arguments to pass to wandb.init().
|
122
|
-
|
123
|
-
"""
|
124
|
-
if self._is_enabled:
|
125
|
-
logger.info(
|
126
|
-
"OpenAI autologging is already enabled, disabling and re-enabling."
|
127
|
-
)
|
128
|
-
self.disable()
|
129
|
-
|
130
|
-
logger.info("Enabling OpenAI autologging.")
|
131
|
-
self._run_init(init=init)
|
132
|
-
|
133
|
-
self._patch_openai_api.patch(self._run)
|
134
|
-
|
135
|
-
with wb_telemetry.context(self._run) as tel:
|
136
|
-
tel.feature.openai_autolog = True
|
137
|
-
|
138
|
-
def disable(self) -> None:
|
139
|
-
"""Disable OpenAI autologging."""
|
140
|
-
if self._run is None:
|
141
|
-
return
|
142
|
-
|
143
|
-
logger.info("Disabling OpenAI autologging.")
|
144
|
-
|
145
|
-
if self.__run_created_by_autolog:
|
146
|
-
self._run.finish()
|
147
|
-
self.__run_created_by_autolog = False
|
148
|
-
|
149
|
-
self._run = None
|
150
|
-
|
151
|
-
self._patch_openai_api.unpatch()
|
10
|
+
# Autologging entry point for the OpenAI integration. AutologAPI is handed the
# dotted names of the OpenAI `create` methods to hook, the resolver used to turn
# each call into loggable data, and the telemetry feature flag to record.
autolog = AutologAPI(
    resolver=OpenAIRequestResponseResolver(),
    symbols=("Edit.create", "Completion.create", "ChatCompletion.create"),
    telemetry_feature="openai_autolog",
    name="OpenAI",
)
|