wandb-0.15.3-py3-none-any.whl → wandb-0.15.5-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,213 @@
1
+ import logging
2
+ import os
3
+ from datetime import datetime
4
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
5
+
6
+ import pytz
7
+
8
+ import wandb
9
+ from wandb.sdk.integration_utils.auto_logging import Response
10
+ from wandb.sdk.lib.runid import generate_id
11
+
12
logger = logging.getLogger(__name__)

# Tasks treated as ``top_k``-capable: for these, ``_format_data`` does not
# unwrap single-element response lists, so rows of the same task keep a
# consistent shape.
PIPELINES_WITH_TOP_K = [
    "text-classification",
    "sentiment-analysis",
    "question-answering",
]

# Every pipeline task the resolver will log: the top_k tasks above followed by
# the generation-style tasks. ``translation_x_to_y`` variants are matched
# separately by a prefix check in the resolver; "conversational" is
# deliberately not enabled here.
SUPPORTED_PIPELINE_TASKS = PIPELINES_WITH_TOP_K + [
    "summarization",
    "translation",
    "text2text-generation",
    "text-generation",
]
30
+
31
+
32
class HuggingFacePipelineRequestResponseResolver:
    """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting.

    Each call turns one pipeline invocation (the pipeline object, its inputs,
    the keyword arguments and the returned output) into a
    ``{table_name: wandb.Table}`` mapping ready to be logged.

    This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver)
    """

    # ID shared by every row of the most recently built table (returned by
    # ``get_latest_id``); None until the first table has been created.
    autolog_id: Optional[str] = None

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Optional[Dict[str, Any]]:
        """Main call method for this class.

        :param args: positional arguments of the call; ``args[0]`` is the
            pipeline and ``args[1]`` its input data
        :param kwargs: dictionary of keyword arguments
        :param response: the response from the request
        :param start_time: time when request started (not used by this resolver)
        :param time_elapsed: time elapsed for the request
        :returns: ``{table_name: wandb.Table}`` for logging to wandb; None if the
            task is unsupported, no model could be resolved, or an exception occurred
        """
        try:
            pipe, input_data = args[:2]
            task = pipe.task

            # Translation tasks are in the form of `translation_x_to_y`
            if task not in SUPPORTED_PIPELINE_TASKS and not task.startswith(
                "translation"
            ):
                logger.warning(
                    f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task"
                )
                return None

            model = self._get_model(pipe)
            if model is None:
                return None
            model_alias = model.name_or_path
            timestamp = datetime.now(pytz.utc)

            input_data, response = self._transform_task_specific_data(
                task, input_data, response
            )
            formatted_data = self._format_data(task, input_data, response, kwargs)
            packed_data = self._create_table(
                formatted_data, model_alias, timestamp, time_elapsed
            )
            # TODO: Let users decide the name in a way that does not use an environment variable
            table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}")

            return {
                table_name: wandb.Table(columns=packed_data[0], data=packed_data[1:])
            }
        except Exception as e:
            # Best effort: any failure is logged and reported as None rather
            # than propagated to the caller.
            logger.warning(e)
            return None

    # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
    # from transformers.modeling_utils import PreTrainedModel
    # We do not want this dependency explicity in our codebase so we make a very general assumption about
    # the structure of the pipeline which may have unintended consequences
    def _get_model(self, pipe) -> Optional[Any]:
        """Extracts model from the pipeline.

        Prefers ``pipe.model.model`` when the pipeline's model exposes an inner
        ``.model`` attribute, and falls back to ``pipe.model`` otherwise.

        :param pipe: the HuggingFace pipeline
        :returns: the resolved model; None only if the chosen attribute is
            itself None
        """
        model = pipe.model
        try:
            return model.model
        except AttributeError:
            logger.info(
                "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model."
            )
            return model

    @staticmethod
    def _transform_task_specific_data(
        task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any]
    ) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]:
        """Transform input and response data based on specific tasks.

        :param task: the task name
        :param input_data: the input data
        :param response: the response data
        :returns: tuple of transformed input_data and response
        """
        if task == "question-answering":
            input_data = input_data if isinstance(input_data, list) else [input_data]
            # Inputs may already be plain dicts (e.g. {"question": ..., "context": ...});
            # other inputs (presumably SquadExample objects) are converted via
            # their attribute dict. Plain dicts have no ``__dict__`` attribute.
            input_data = [
                data if isinstance(data, dict) else data.__dict__ for data in input_data
            ]
        elif task == "conversational":
            # We only grab the latest input/output pair from the conversation
            # Logging the whole conversation renders strangely.
            input_data = input_data if isinstance(input_data, list) else [input_data]
            input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data]

            response = response if isinstance(response, list) else [response]
            response = [data.__dict__["generated_responses"][-1] for data in response]
        return input_data, response

    def _format_data(
        self,
        task: str,
        input_data: Union[List[Any], Any],
        response: Union[List[Any], Any],
        kwargs: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Formats input data, response, and kwargs into a list of dictionaries.

        Inputs and responses are paired positionally; extra items on the longer
        side are dropped (``zip`` semantics).

        :param task: the task name
        :param input_data: the input data
        :param response: the response data
        :param kwargs: dictionary of keyword arguments
        :returns: list of dictionaries with ``input``, ``response`` and ``kwargs`` keys
        """
        input_data = input_data if isinstance(input_data, list) else [input_data]
        response = response if isinstance(response, list) else [response]

        formatted_data = []
        for i_text, r_text in zip(input_data, response):
            # Unpack single element responses for better rendering in wandb UI when it is a task without top_k
            # top_k = 1 would unpack the response into a single element while top_k > 1 would be a list
            # this would cause the UI to not properly concatenate the tables of the same task by omitting the elements past the first
            if (
                isinstance(r_text, list)
                and len(r_text) == 1
                and task not in PIPELINES_WITH_TOP_K
            ):
                r_text = r_text[0]
            formatted_data.append(
                {"input": i_text, "response": r_text, "kwargs": kwargs}
            )
        return formatted_data

    def _create_table(
        self,
        formatted_data: List[Dict[str, Any]],
        model_alias: str,
        timestamp: datetime,
        time_elapsed: float,
    ) -> List[List[Any]]:
        """Creates a table from formatted data, model alias, timestamp, and elapsed time.

        All rows share one freshly generated ID, which is also stored on
        ``self.autolog_id``.

        :param formatted_data: list of dictionaries containing formatted data
        :param model_alias: alias of the model
        :param timestamp: timezone-aware time at which the data was resolved
        :param time_elapsed: time elapsed for the request
        :returns: list of lists; element [0] is the header row and elements
            [1:] are the data rows
        """
        header = [
            "ID",
            "Model Alias",
            "Timestamp",
            "Elapsed Time",
            "Input",
            "Response",
            "Kwargs",
        ]
        table = [header]
        autolog_id = generate_id(length=16)

        for data in formatted_data:
            row = [
                autolog_id,
                model_alias,
                timestamp,
                time_elapsed,
                data["input"],
                data["response"],
                data["kwargs"],
            ]
            table.append(row)

        self.autolog_id = autolog_id

        return table

    def get_latest_id(self) -> Optional[str]:
        """Return the ID of the most recently created table, or None if none yet."""
        return self.autolog_id
@@ -14,22 +14,10 @@ integration will not break user code. The one exception to the rule is at import
14
14
  LangChain is not installed, or the symbols are not in the same place, the appropriate error
15
15
  will be raised when importing this module.
16
16
  """
17
- import sys
18
-
19
- if sys.version_info >= (3, 8):
20
- from typing import TypedDict
21
- else:
22
- from typing_extensions import TypedDict
23
-
24
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
25
-
26
17
  from packaging import version
27
18
 
28
- import wandb
29
19
  import wandb.util
30
- from wandb.sdk.data_types import trace_tree
31
- from wandb.sdk.lib import telemetry as wb_telemetry
32
- from wandb.sdk.lib.paths import StrPath
20
+ from wandb.sdk.lib import deprecate
33
21
 
34
22
  langchain = wandb.util.get_module(
35
23
  name="langchain",
@@ -37,174 +25,23 @@ langchain = wandb.util.get_module(
37
25
  "package installed. Please install it with `pip install langchain`.",
38
26
  )
39
27
 
40
# Fail fast at import time when the installed LangChain is older than the
# minimum version this integration supports (0.0.188).
if version.parse("0.0.188") > version.parse(langchain.__version__):
    raise ValueError(
        "The Weights & Biases Langchain integration does not support versions 0.0.187 and lower. "
        "To ensure proper functionality, please use version 0.0.188 or higher."
    )
45
33
 
46
- # We want these imports after the import_langchain() call, so that we can
47
- # catch the ImportError if langchain is not installed.
48
-
49
34
  # isort: off
50
- from langchain.callbacks.tracers.base import BaseTracer # noqa: E402, I001
51
-
52
- from .util import ( # noqa: E402
53
- print_wandb_init_message,
54
- safely_convert_lc_run_to_wb_span,
55
- # safely_convert_model_to_dict,
56
- # safely_get_span_producing_model,
57
- )
58
-
59
- if TYPE_CHECKING:
60
- from langchain.callbacks.base import BaseCallbackHandler
61
- from langchain.callbacks.tracers.schemas import Run
62
-
63
- from wandb import Settings as WBSettings
64
- from wandb.wandb_run import Run as WBRun
65
-
66
-
67
- class WandbRunArgs(TypedDict):
68
- job_type: Optional[str]
69
- dir: Optional[StrPath]
70
- config: Union[Dict, str, None]
71
- project: Optional[str]
72
- entity: Optional[str]
73
- reinit: Optional[bool]
74
- tags: Optional[Sequence]
75
- group: Optional[str]
76
- name: Optional[str]
77
- notes: Optional[str]
78
- magic: Optional[Union[dict, str, bool]]
79
- config_exclude_keys: Optional[List[str]]
80
- config_include_keys: Optional[List[str]]
81
- anonymous: Optional[str]
82
- mode: Optional[str]
83
- allow_val_change: Optional[bool]
84
- resume: Optional[Union[bool, str]]
85
- force: Optional[bool]
86
- tensorboard: Optional[bool]
87
- sync_tensorboard: Optional[bool]
88
- monitor_gym: Optional[bool]
89
- save_code: Optional[bool]
90
- id: Optional[str]
91
- settings: Union["WBSettings", Dict[str, Any], None]
92
-
93
-
94
- class WandbTracer(BaseTracer):
95
- """Callback Handler that logs to Weights and Biases.
96
-
97
- This handler will log the model architecture and run traces to Weights and Biases. This will
98
- ensure that all LangChain activity is logged to W&B.
99
- """
100
-
101
- _run: Optional["WBRun"] = None
102
- _run_args: Optional[WandbRunArgs] = None
103
-
104
- @classmethod
105
- def init(
106
- cls,
107
- run_args: Optional[WandbRunArgs] = None,
108
- include_stdout: bool = True,
109
- additional_handlers: Optional[List["BaseCallbackHandler"]] = None,
110
- ) -> None:
111
- """Method provided for backwards compatibility. Please directly construct `WandbTracer` instead."""
112
- message = """Global autologging is not currently supported for the LangChain integration.
113
- Please directly construct a `WandbTracer` and add it to the list of callbacks. For example:
114
-
115
- LLMChain(llm, callbacks=[WandbTracer()])
116
- # end of notebook / script:
117
- WandbTracer.finish()"""
118
- wandb.termlog(message)
119
-
120
- def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
121
- """Initializes the WandbTracer.
122
-
123
- Parameters:
124
- run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not provided, `wandb.init()` will be
125
- called with no arguments. Please refer to the `wandb.init` for more details.
126
-
127
- To use W&B to monitor all LangChain activity, add this tracer like any other langchain callback
128
- ```
129
- from wandb.integration.langchain import WandbTracer
130
- LLMChain(llm, callbacks=[WandbTracer()])
131
- # end of notebook / script:
132
- WandbTracer.finish()
133
- ```.
134
- """
135
- super().__init__(**kwargs)
136
- self._run_args = run_args
137
- self._ensure_run(should_print_url=(wandb.run is None))
138
-
139
- @staticmethod
140
- def finish() -> None:
141
- """Waits for all asynchronous processes to finish and data to upload.
142
-
143
- Proxy for `wandb.finish()`.
144
- """
145
- wandb.finish()
146
-
147
- def _log_trace_from_run(self, run: "Run") -> None:
148
- """Logs a LangChain Run to W*B as a W&B Trace."""
149
- self._ensure_run()
150
-
151
- root_span = safely_convert_lc_run_to_wb_span(run)
152
- if root_span is None:
153
- return
154
-
155
- model_dict = None
156
-
157
- # TODO: Uncomment this once we have a way to get the model from a run
158
- # model = safely_get_span_producing_model(run)
159
- # if model is not None:
160
- # model_dict = safely_convert_model_to_dict(model)
161
-
162
- model_trace = trace_tree.WBTraceTree(
163
- root_span=root_span,
164
- model_dict=model_dict,
35
+ from langchain.callbacks.tracers import WandbTracer # noqa: E402, I001
36
+
37
+
38
class WandbTracer(WandbTracer):
    """Deprecated W&B tracer kept for backward compatibility.

    Subclasses, and shadows, the ``WandbTracer`` imported from
    ``langchain.callbacks.tracers``, so existing imports from this module keep
    working. A deprecation notice is emitted on every construction.
    """

    def __init__(self, *args, **kwargs):
        """Build the LangChain tracer, then emit the deprecation notice.

        All positional and keyword arguments are forwarded unchanged to the
        LangChain ``WandbTracer`` constructor.
        """
        notice = (
            "This feature is deprecated and has been moved to `langchain`. Enable tracing by setting "
            "LANGCHAIN_WANDB_TRACING=true in your environment. See the documentation at "
            "https://python.langchain.com/docs/ecosystem/integrations/agent_with_wandb_tracing for guidance. "
            "Replace your current import with `from langchain.callbacks.tracers import WandbTracer`."
        )
        super().__init__(*args, **kwargs)
        # Only reached if the base tracer was constructed without raising.
        deprecate.deprecate(
            field_name=deprecate.Deprecated.langchain_tracer,
            warning_message=notice,
        )
166
- wandb.run.log({"langchain_trace": model_trace})
167
-
168
- def _ensure_run(self, should_print_url=False) -> None:
169
- """Ensures an active W&B run exists.
170
-
171
- If not, will start a new run with the provided run_args.
172
- """
173
- if wandb.run is None:
174
- # Make a shallow copy of the run args, so we don't modify the original
175
- run_args = self._run_args or {} # type: ignore
176
- run_args: dict = {**run_args} # type: ignore
177
-
178
- # Prefer to run in silent mode since W&B has a lot of output
179
- # which can be undesirable when dealing with text-based models.
180
- if "settings" not in run_args: # type: ignore
181
- run_args["settings"] = {"silent": True} # type: ignore
182
-
183
- # Start the run and add the stream table
184
- wandb.init(**run_args)
185
-
186
- if should_print_url:
187
- print_wandb_init_message(wandb.run.settings.run_url)
188
-
189
- with wb_telemetry.context(wandb.run) as tel:
190
- tel.feature.langchain_tracer = True
191
-
192
- # Start of required methods (these methods are required by the BaseCallbackHandler interface)
193
- @property
194
- def always_verbose(self) -> bool:
195
- """Whether to call verbose callbacks even if verbose is False."""
196
- return True
197
-
198
- def _generate_id(self) -> Optional[Union[int, str]]:
199
- """Generate an id for a run."""
200
- return None
201
-
202
- def _persist_run(self, run: "Run") -> None:
203
- """Persist a run."""
204
- try:
205
- self._log_trace_from_run(run)
206
- except Exception:
207
- # Silently ignore errors to not break user code
208
- pass
209
-
210
- # End of required methods
@@ -1,5 +1,3 @@
1
1
  __all__ = ("autolog",)
2
2
 
3
- from .openai import AutologOpenAI
4
-
5
- autolog = AutologOpenAI()
3
+ from .openai import autolog
@@ -1,151 +1,19 @@
1
- import functools
2
1
  import logging
3
- import sys
4
- from typing import Any, Dict, List, Optional
5
2
 
6
- import wandb.sdk
7
- import wandb.util
8
- from wandb.sdk.lib import telemetry as wb_telemetry
9
- from wandb.sdk.lib.timer import Timer
3
+ from wandb.sdk.integration_utils.auto_logging import AutologAPI
10
4
 
11
5
  from .resolver import OpenAIRequestResponseResolver
12
6
 
13
- if sys.version_info >= (3, 8):
14
- from typing import Literal
15
- else:
16
- from typing_extensions import Literal
17
-
18
-
19
7
  logger = logging.getLogger(__name__)
20
8
 
21
9
 
22
- AutologOpenAIInitArgs = Optional[Dict[str, Any]]
23
-
24
-
25
- class PatchOpenAIAPI:
26
- symbols: List[Literal["Edit", "Completion", "ChatCompletion"]] = [
27
- "Edit",
28
- "Completion",
29
- "ChatCompletion",
30
- ]
31
-
32
- def __init__(self) -> None:
33
- """Patches the OpenAI API to log traces to W&B."""
34
- self.original_methods: Dict[str, Any] = {}
35
- self.resolver = OpenAIRequestResponseResolver()
36
- self._openai = None
37
-
38
- @property
39
- def openai(self) -> Any:
40
- """Returns the openai module."""
41
- if self._openai is None:
42
- self._openai = wandb.util.get_module(
43
- name="openai",
44
- required="To use the W&B OpenAI Autolog, you need to have the `openai` python "
45
- "package installed. Please install it with `pip install openai`.",
46
- lazy=False,
47
- )
48
- return self._openai
49
-
50
- def patch(self, run: "wandb.sdk.wandb_run.Run") -> None:
51
- """Patches the OpenAI API to log traces to W&B."""
52
- for symbol in self.symbols:
53
- original = getattr(self.openai, symbol).create
54
-
55
- def method_factory(original_method: Any):
56
- @functools.wraps(original_method)
57
- def create(*args, **kwargs):
58
- with Timer() as timer:
59
- result = original_method(*args, **kwargs)
60
- try:
61
- trace = self.resolver(kwargs, result, timer.elapsed)
62
- if trace is not None:
63
- run.log({"trace": trace})
64
- except Exception:
65
- # logger.warning(e)
66
- pass
67
- return result
68
-
69
- return create
70
-
71
- # save original method
72
- self.original_methods[symbol] = original
73
- # monkeypatch
74
- getattr(self.openai, symbol).create = method_factory(original)
75
-
76
- def unpatch(self) -> None:
77
- """Unpatches the OpenAI API."""
78
- for symbol, original in self.original_methods.items():
79
- getattr(self.openai, symbol).create = original
80
-
81
-
82
- class AutologOpenAI:
83
- def __init__(self) -> None:
84
- """Autolog OpenAI API calls to W&B."""
85
- self._patch_openai_api = PatchOpenAIAPI()
86
- self._run: Optional["wandb.sdk.wandb_run.Run"] = None
87
- self.__run_created_by_autolog: bool = False
88
-
89
- @property
90
- def _is_enabled(self) -> bool:
91
- """Returns whether autologging is enabled."""
92
- return self._run is not None
93
-
94
- def __call__(self, init: AutologOpenAIInitArgs = None) -> None:
95
- """Enable OpenAI autologging."""
96
- self.enable(init=init)
97
-
98
- def _run_init(self, init: AutologOpenAIInitArgs = None) -> None:
99
- """Handle wandb run initialization."""
100
- # - autolog(init: dict = {...}) calls wandb.init(**{...})
101
- # regardless of whether there is a wandb.run or not,
102
- # we only track if the run was created by autolog
103
- # - todo: autolog(init: dict | run = run) would use the user-provided run
104
- # - autolog() uses the wandb.run if there is one, otherwise it calls wandb.init()
105
- if init:
106
- _wandb_run = wandb.run
107
- # we delegate dealing with the init dict to wandb.init()
108
- self._run = wandb.init(**init)
109
- if _wandb_run != self._run:
110
- self.__run_created_by_autolog = True
111
- elif wandb.run is None:
112
- self._run = wandb.init()
113
- self.__run_created_by_autolog = True
114
- else:
115
- self._run = wandb.run
116
-
117
- def enable(self, init: AutologOpenAIInitArgs = None) -> None:
118
- """Enable OpenAI autologging.
119
-
120
- Args:
121
- init: Optional dictionary of arguments to pass to wandb.init().
122
-
123
- """
124
- if self._is_enabled:
125
- logger.info(
126
- "OpenAI autologging is already enabled, disabling and re-enabling."
127
- )
128
- self.disable()
129
-
130
- logger.info("Enabling OpenAI autologging.")
131
- self._run_init(init=init)
132
-
133
- self._patch_openai_api.patch(self._run)
134
-
135
- with wb_telemetry.context(self._run) as tel:
136
- tel.feature.openai_autolog = True
137
-
138
- def disable(self) -> None:
139
- """Disable OpenAI autologging."""
140
- if self._run is None:
141
- return
142
-
143
- logger.info("Disabling OpenAI autologging.")
144
-
145
- if self.__run_created_by_autolog:
146
- self._run.finish()
147
- self.__run_created_by_autolog = False
148
-
149
- self._run = None
150
-
151
- self._patch_openai_api.unpatch()
10
# Module-level OpenAI autologger. This module only declares what to wrap and
# how to resolve it. The wrapping itself is presumably handled by the shared
# ``AutologAPI`` helper.
autolog = AutologAPI(
    name="OpenAI",
    # Presumably converts each intercepted request/response into loggable data.
    resolver=OpenAIRequestResponseResolver(),
    # Dotted attribute paths of the callables to wrap.
    symbols=(
        "Edit.create",
        "Completion.create",
        "ChatCompletion.create",
    ),
    telemetry_feature="openai_autolog",
)