wandb 0.19.0__py3-none-win32.whl → 0.19.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. wandb/__init__.py +1 -7
  2. wandb/__init__.pyi +211 -209
  3. wandb/apis/attrs.py +15 -4
  4. wandb/apis/public/api.py +8 -4
  5. wandb/apis/public/files.py +65 -12
  6. wandb/apis/public/runs.py +52 -7
  7. wandb/apis/public/sweeps.py +1 -1
  8. wandb/bin/gpu_stats.exe +0 -0
  9. wandb/bin/wandb-core +0 -0
  10. wandb/cli/cli.py +2 -1
  11. wandb/env.py +1 -1
  12. wandb/errors/term.py +60 -1
  13. wandb/integration/keras/callbacks/tables_builder.py +3 -1
  14. wandb/integration/kfp/kfp_patch.py +25 -15
  15. wandb/integration/lightning/fabric/logger.py +3 -1
  16. wandb/integration/tensorboard/monkeypatch.py +3 -2
  17. wandb/jupyter.py +4 -5
  18. wandb/plot/bar.py +5 -6
  19. wandb/plot/histogram.py +1 -1
  20. wandb/plot/line_series.py +3 -3
  21. wandb/plot/pr_curve.py +7 -3
  22. wandb/plot/scatter.py +2 -1
  23. wandb/proto/v3/wandb_settings_pb2.py +25 -15
  24. wandb/proto/v4/wandb_settings_pb2.py +17 -15
  25. wandb/proto/v5/wandb_settings_pb2.py +17 -15
  26. wandb/sdk/artifacts/_validators.py +1 -3
  27. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
  28. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +12 -2
  29. wandb/sdk/data_types/helper_types/image_mask.py +8 -2
  30. wandb/sdk/data_types/histogram.py +3 -3
  31. wandb/sdk/data_types/image.py +3 -1
  32. wandb/sdk/interface/interface.py +34 -5
  33. wandb/sdk/interface/interface_sock.py +2 -2
  34. wandb/sdk/internal/file_stream.py +4 -1
  35. wandb/sdk/internal/sender.py +4 -1
  36. wandb/sdk/internal/settings_static.py +17 -4
  37. wandb/sdk/launch/utils.py +1 -0
  38. wandb/sdk/lib/ipython.py +5 -27
  39. wandb/sdk/lib/printer.py +33 -20
  40. wandb/sdk/lib/progress.py +7 -1
  41. wandb/sdk/lib/sparkline.py +1 -2
  42. wandb/sdk/wandb_config.py +2 -2
  43. wandb/sdk/wandb_init.py +236 -243
  44. wandb/sdk/wandb_run.py +172 -231
  45. wandb/sdk/wandb_settings.py +104 -15
  46. {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/METADATA +1 -1
  47. {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/RECORD +50 -50
  48. {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/WHEEL +0 -0
  49. {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/entry_points.txt +0 -0
  50. {wandb-0.19.0.dist-info → wandb-0.19.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/files.py CHANGED
@@ -49,6 +49,7 @@ class Files(Paginator):
49
49
  query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
50
50
  $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
51
51
  project(name: $project, entityName: $entity) {{
52
+ internalId
52
53
  run(name: $name) {{
53
54
  fileCount
54
55
  ...RunFilesFragment
@@ -98,7 +99,7 @@ class Files(Paginator):
98
99
 
99
100
  def convert_objects(self):
100
101
  return [
101
- File(self.client, r["node"])
102
+ File(self.client, r["node"], self.run)
102
103
  for r in self.last_response["project"]["run"]["files"]["edges"]
103
104
  ]
104
105
 
@@ -120,9 +121,11 @@ class File(Attrs):
120
121
  path_uri (str): path to file in the bucket, currently only available for files stored in S3
121
122
  """
122
123
 
123
- def __init__(self, client, attrs):
124
+ def __init__(self, client, attrs, run=None):
124
125
  self.client = client
125
126
  self._attrs = attrs
127
+ self.run = run
128
+ self.server_supports_delete_file_with_project_id: Optional[bool] = None
126
129
  super().__init__(dict(attrs))
127
130
 
128
131
  @property
@@ -189,18 +192,35 @@ class File(Attrs):
189
192
 
190
193
  @normalize_exceptions
191
194
  def delete(self):
192
- mutation = gql(
193
- """
194
- mutation deleteFiles($files: [ID!]!) {
195
- deleteFiles(input: {
196
- files: $files
197
- }) {
198
- success
199
- }
195
+ project_id_mutation_fragment = ""
196
+ project_id_variable_fragment = ""
197
+ variable_values = {
198
+ "files": [self.id],
200
199
  }
201
- """
200
+
201
+ # Add projectId to mutation and variables if the server supports it.
202
+ # Otherwise, do not include projectId in mutation for older server versions which do not support it.
203
+ if self._server_accepts_project_id_for_delete_file():
204
+ variable_values["projectId"] = self.run._project_internal_id
205
+ project_id_variable_fragment = ", $projectId: Int"
206
+ project_id_mutation_fragment = "projectId: $projectId"
207
+
208
+ mutation_string = """
209
+ mutation deleteFiles($files: [ID!]!{}) {{
210
+ deleteFiles(input: {{
211
+ files: $files
212
+ {}
213
+ }}) {{
214
+ success
215
+ }}
216
+ }}
217
+ """.format(project_id_variable_fragment, project_id_mutation_fragment)
218
+ mutation = gql(mutation_string)
219
+
220
+ self.client.execute(
221
+ mutation,
222
+ variable_values=variable_values,
202
223
  )
203
- self.client.execute(mutation, variable_values={"files": [self.id]})
204
224
 
205
225
  def __repr__(self):
206
226
  return "<File {} ({}) {}>".format(
@@ -208,3 +228,36 @@ class File(Attrs):
208
228
  self.mimetype,
209
229
  util.to_human_size(self.size, units=util.POW_2_BYTES),
210
230
  )
231
+
232
+ @normalize_exceptions
233
+ def _server_accepts_project_id_for_delete_file(self) -> bool:
234
+ """Returns True if the server supports deleting files with a projectId.
235
+
236
+ This check is done by utilizing GraphQL introspection in the avaiable fields on the DeleteFiles API.
237
+ """
238
+ query_string = """
239
+ query ProbeDeleteFilesProjectIdInput {
240
+ DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
241
+ inputFields{
242
+ name
243
+ }
244
+ }
245
+ }
246
+ """
247
+
248
+ # Only perform the query once to avoid extra network calls
249
+ if self.server_supports_delete_file_with_project_id is None:
250
+ query = gql(query_string)
251
+ res = self.client.execute(query)
252
+
253
+ # If projectId is in the inputFields, the server supports deleting files with a projectId
254
+ self.server_supports_delete_file_with_project_id = "projectId" in [
255
+ x["name"]
256
+ for x in (
257
+ res.get("DeleteFilesProjectIdInputType", {}).get(
258
+ "inputFields", [{}]
259
+ )
260
+ )
261
+ ]
262
+
263
+ return self.server_supports_delete_file_with_project_id
wandb/apis/public/runs.py CHANGED
@@ -71,6 +71,7 @@ class Runs(Paginator):
71
71
  """
72
72
  query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
73
73
  project(name: $project, entityName: $entity) {{
74
+ internalId
74
75
  runCount(filters: $filters)
75
76
  readOnly
76
77
  runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
@@ -103,6 +104,7 @@ class Runs(Paginator):
103
104
  ):
104
105
  self.entity = entity
105
106
  self.project = project
107
+ self._project_internal_id = None
106
108
  self.filters = filters or {}
107
109
  self.order = order
108
110
  self._sweeps = {}
@@ -287,6 +289,7 @@ class Run(Attrs):
287
289
  Calling update will persist any changes.
288
290
  project (str): the project associated with the run
289
291
  entity (str): the name of the entity associated with the run
292
+ project_internal_id (int): the internal id of the project
290
293
  user (str): the name of the user who created the run
291
294
  path (str): Unique identifier [entity]/[project]/[run_id]
292
295
  notes (str): Notes about the run
@@ -328,6 +331,7 @@ class Run(Attrs):
328
331
  self._summary = None
329
332
  self._metadata: Optional[Dict[str, Any]] = None
330
333
  self._state = _attrs.get("state", "not found")
334
+ self.server_provides_internal_id_field: Optional[bool] = None
331
335
 
332
336
  self.load(force=not _attrs)
333
337
 
@@ -417,13 +421,18 @@ class Run(Attrs):
417
421
  """
418
422
  query Run($project: String!, $entity: String!, $name: String!) {{
419
423
  project(name: $project, entityName: $entity) {{
424
+ {}
420
425
  run(name: $name) {{
421
426
  ...RunFragment
422
427
  }}
423
428
  }}
424
429
  }}
425
430
  {}
426
- """.format(RUN_FRAGMENT)
431
+ """.format(
432
+ # Only query internalId if the server supports it
433
+ "internalId" if self._server_provides_internal_id_for_project() else "",
434
+ RUN_FRAGMENT,
435
+ )
427
436
  )
428
437
  if force or not self._attrs:
429
438
  response = self._exec(query)
@@ -435,7 +444,7 @@ class Run(Attrs):
435
444
  raise ValueError("Could not find run {}".format(self))
436
445
  self._attrs = response["project"]["run"]
437
446
  self._state = self._attrs["state"]
438
-
447
+ self._project_internal_id = response["project"].get("internalId", None)
439
448
  if self._include_sweeps and self.sweep_name and not self.sweep:
440
449
  # There may be a lot of runs. Don't bother pulling them all
441
450
  # just for the sake of this one.
@@ -495,7 +504,6 @@ class Run(Attrs):
495
504
  res = self._exec(query)
496
505
  state = res["project"]["run"]["state"]
497
506
  if state in ["finished", "crashed", "failed"]:
498
- print(f"Run finished with status: {state}")
499
507
  self._attrs["state"] = state
500
508
  self._state = state
501
509
  return
@@ -506,8 +514,8 @@ class Run(Attrs):
506
514
  """Persist changes to the run object to the wandb backend."""
507
515
  mutation = gql(
508
516
  """
509
- mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {{
510
- upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}}) {{
517
+ mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
518
+ upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
511
519
  bucket {{
512
520
  ...RunFragment
513
521
  }}
@@ -525,6 +533,7 @@ class Run(Attrs):
525
533
  display_name=self.display_name,
526
534
  config=self.json_config,
527
535
  groupName=self.group,
536
+ jobType=self.job_type,
528
537
  )
529
538
  self.summary.update()
530
539
 
@@ -770,7 +779,9 @@ class Run(Attrs):
770
779
  Example:
771
780
  >>> import wandb
772
781
  >>> import tempfile
773
- >>> with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp:
782
+ >>> with tempfile.NamedTemporaryFile(
783
+ ... mode="w", delete=False, suffix=".txt"
784
+ ... ) as tmp:
774
785
  ... tmp.write("This is a test artifact")
775
786
  ... tmp_path = tmp.name
776
787
  >>> run = wandb.init(project="artifact-example")
@@ -893,6 +904,37 @@ class Run(Attrs):
893
904
  )
894
905
  return artifact
895
906
 
907
+ @normalize_exceptions
908
+ def _server_provides_internal_id_for_project(self) -> bool:
909
+ """Returns True if the server allows us to query the internalId field for a project.
910
+
911
+ This check is done by utilizing GraphQL introspection in the avaiable fields on the Project type.
912
+ """
913
+ query_string = """
914
+ query ProbeProjectInput {
915
+ ProjectType: __type(name:"Project") {
916
+ fields {
917
+ name
918
+ }
919
+ }
920
+ }
921
+ """
922
+
923
+ # Only perform the query once to avoid extra network calls
924
+ if self.server_provides_internal_id_field is None:
925
+ query = gql(query_string)
926
+ res = self.client.execute(query)
927
+ print(
928
+ "internalId"
929
+ in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
930
+ )
931
+
932
+ self.server_provides_internal_id_field = "internalId" in [
933
+ x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
934
+ ]
935
+
936
+ return self.server_provides_internal_id_field
937
+
896
938
  @property
897
939
  def summary(self):
898
940
  if self._summary is None:
@@ -921,7 +963,10 @@ class Run(Attrs):
921
963
  if self._metadata is None:
922
964
  try:
923
965
  f = self.file("wandb-metadata.json")
924
- contents = util.download_file_into_memory(f.url, wandb.Api().api_key)
966
+ session = self.client._client.transport.session
967
+ response = session.get(f.url, timeout=5)
968
+ response.raise_for_status()
969
+ contents = response.content
925
970
  self._metadata = json_util.loads(contents)
926
971
  except: # noqa: E722
927
972
  # file doesn't exist, or can't be downloaded, or can't be parsed
wandb/apis/public/sweeps.py CHANGED
@@ -34,7 +34,7 @@ class Sweep(Attrs):
34
34
  Instantiate with:
35
35
  ```
36
36
  api = wandb.Api()
37
- sweep = api.sweep(path/to/sweep)
37
+ sweep = api.sweep(path / to / sweep)
38
38
  ```
39
39
 
40
40
  Attributes:
wandb/bin/gpu_stats.exe CHANGED
Binary file
wandb/bin/wandb-core CHANGED
Binary file
wandb/cli/cli.py CHANGED
@@ -33,6 +33,7 @@ from wandb.apis.public import RunQueue
33
33
  from wandb.errors.links import url_registry
34
34
  from wandb.sdk.artifacts._validators import is_artifact_registry_project
35
35
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
36
+ from wandb.sdk.internal.internal_api import Api as SDKInternalApi
36
37
  from wandb.sdk.launch import utils as launch_utils
37
38
  from wandb.sdk.launch._launch_add import _launch_add
38
39
  from wandb.sdk.launch.errors import ExecutionError, LaunchError
@@ -2399,7 +2400,7 @@ def get(path, root, type):
2399
2400
  settings_entity = public_api.settings["entity"] or public_api.default_entity
2400
2401
  # Registry artifacts are under the org entity. Because we offer a shorthand and alias for this path,
2401
2402
  # we need to fetch the org entity to for the user behind the scenes.
2402
- entity = InternalApi()._resolve_org_entity_name(
2403
+ entity = SDKInternalApi()._resolve_org_entity_name(
2403
2404
  entity=settings_entity, organization=organization
2404
2405
  )
2405
2406
  full_path = f"{entity}/{project}/{artifact_name}:{version}"
wandb/env.py CHANGED
@@ -169,7 +169,7 @@ def error_reporting_enabled() -> bool:
169
169
 
170
170
 
171
171
  def core_debug(default: str | None = None) -> bool:
172
- return _env_as_bool(CORE_DEBUG, default=default)
172
+ return _env_as_bool(CORE_DEBUG, default=default) or is_debug()
173
173
 
174
174
 
175
175
  def ssl_disabled() -> bool:
wandb/errors/term.py CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
5
5
  import contextlib
6
6
  import logging
7
7
  import os
8
+ import re
9
+ import shutil
8
10
  import sys
9
11
  import threading
10
12
  from typing import TYPE_CHECKING, Iterator, Protocol
@@ -271,10 +273,67 @@ class DynamicBlock:
271
273
  The lock must be held.
272
274
  """
273
275
  if self._lines_to_print:
274
- click.echo("\n".join(self._lines_to_print), file=sys.stderr)
276
+ # Trim lines before printing. This is crucial because the \x1b[Am
277
+ # (cursor up) sequence used when clearing the text moves up by one
278
+ # visual line, and the terminal may be wrapping long lines onto
279
+ # multiple visual lines.
280
+ #
281
+ # There is no ANSI escape sequence that moves the cursor up by one
282
+ # "physical" line instead. Note that the user may resize their
283
+ # terminal.
284
+ term_width = _shutil_get_terminal_width()
285
+ click.echo(
286
+ "\n".join(
287
+ _ansi_shorten(line, term_width) #
288
+ for line in self._lines_to_print
289
+ ),
290
+ file=sys.stderr,
291
+ )
292
+
275
293
  self._num_lines_printed += len(self._lines_to_print)
276
294
 
277
295
 
296
+ def _shutil_get_terminal_width() -> int:
297
+ """Returns the width of the terminal.
298
+
299
+ Defined here for patching in tests.
300
+ """
301
+ columns, _ = shutil.get_terminal_size()
302
+ return columns
303
+
304
+
305
+ _ANSI_RE = re.compile("\x1b\\[(K|.*?m)")
306
+
307
+
308
+ def _ansi_shorten(text: str, width: int) -> str:
309
+ """Shorten text potentially containing ANSI sequences to fit a width."""
310
+ first_ansi = _ANSI_RE.search(text)
311
+
312
+ if not first_ansi:
313
+ return _raw_shorten(text, width)
314
+
315
+ if first_ansi.start() > width - 3:
316
+ return _raw_shorten(text[: first_ansi.start()], width)
317
+
318
+ return text[: first_ansi.end()] + _ansi_shorten(
319
+ text[first_ansi.end() :],
320
+ # Key part: the ANSI sequence doesn't reduce the remaining width.
321
+ width - first_ansi.start(),
322
+ )
323
+
324
+
325
+ def _raw_shorten(text: str, width: int) -> str:
326
+ """Shorten text to fit a width, replacing the end with "...".
327
+
328
+ Unlike textwrap.shorten(), this does not drop whitespace or do anything
329
+ smart.
330
+ """
331
+ if len(text) <= width:
332
+ return text
333
+
334
+ return text[: width - 3] + "..."
335
+
336
+
278
337
  def _log(
279
338
  string="",
280
339
  newline=True,
wandb/integration/keras/callbacks/tables_builder.py CHANGED
@@ -145,7 +145,9 @@ class WandbEvalCallback(Callback, abc.ABC):
145
145
  for idx, data in enumerate(dataloader):
146
146
  preds = model.predict(data)
147
147
  self.pred_table.add_data(
148
- self.data_table_ref.data[idx][0], self.data_table_ref.data[idx][1], preds
148
+ self.data_table_ref.data[idx][0],
149
+ self.data_table_ref.data[idx][1],
150
+ preds,
149
151
  )
150
152
  ```
151
153
  This method is called `on_epoch_end` or equivalent hook.
wandb/integration/kfp/kfp_patch.py CHANGED
@@ -208,23 +208,24 @@ def create_component_from_func(
208
208
  """Return sum of two arguments"""
209
209
  return a + b
210
210
 
211
+
211
212
  # add_op is a task factory function that creates a task object when given arguments
212
213
  add_op = create_component_from_func(
213
214
  func=add,
214
- base_image='python:3.7', # Optional
215
- output_component_file='add.component.yaml', # Optional
216
- packages_to_install=['pandas==0.24'], # Optional
215
+ base_image="python:3.7", # Optional
216
+ output_component_file="add.component.yaml", # Optional
217
+ packages_to_install=["pandas==0.24"], # Optional
217
218
  )
218
219
 
219
220
  # The component spec can be accessed through the .component_spec attribute:
220
- add_op.component_spec.save('add.component.yaml')
221
+ add_op.component_spec.save("add.component.yaml")
221
222
 
222
223
  # The component function can be called with arguments to create a task:
223
224
  add_task = add_op(1, 3)
224
225
 
225
226
  # The resulting task has output references, corresponding to the component outputs.
226
227
  # When the function only has a single anonymous return value, the output name is "Output":
227
- sum_output_ref = add_task.outputs['Output']
228
+ sum_output_ref = add_task.outputs["Output"]
228
229
 
229
230
  # These task output references can be passed to other component functions, constructing a computation graph:
230
231
  task2 = add_op(sum_output_ref, 5)
@@ -241,17 +242,21 @@ def create_component_from_func(
241
242
 
242
243
  from typing import NamedTuple
243
244
 
244
- def add_multiply_two_numbers(a: float, b: float) -> NamedTuple('Outputs', [('sum', float), ('product', float)]):
245
+
246
+ def add_multiply_two_numbers(a: float, b: float) -> NamedTuple(
247
+ "Outputs", [("sum", float), ("product", float)]
248
+ ):
245
249
  """Return sum and product of two arguments"""
246
250
  return (a + b, a * b)
247
251
 
252
+
248
253
  add_multiply_op = create_component_from_func(add_multiply_two_numbers)
249
254
 
250
255
  # The component function can be called with arguments to create a task:
251
256
  add_multiply_task = add_multiply_op(1, 3)
252
257
 
253
258
  # The resulting task has output references, corresponding to the component outputs:
254
- sum_output_ref = add_multiply_task.outputs['sum']
259
+ sum_output_ref = add_multiply_task.outputs["sum"]
255
260
 
256
261
  # These task output references can be passed to other component functions, constructing a computation graph:
257
262
  task2 = add_multiply_op(sum_output_ref, 5)
@@ -266,14 +271,19 @@ def create_component_from_func(
266
271
  Example of a component function declaring file input and output::
267
272
 
268
273
  def catboost_train_classifier(
269
- training_data_path: InputPath('CSV'), # Path to input data file of type "CSV"
270
- trained_model_path: OutputPath('CatBoostModel'), # Path to output data file of type "CatBoostModel"
271
- number_of_trees: int = 100, # Small output of type "Integer"
272
- ) -> NamedTuple('Outputs', [
273
- ('Accuracy', float), # Small output of type "Float"
274
- ('Precision', float), # Small output of type "Float"
275
- ('JobUri', 'URI'), # Small output of type "URI"
276
- ]):
274
+ training_data_path: InputPath("CSV"), # Path to input data file of type "CSV"
275
+ trained_model_path: OutputPath(
276
+ "CatBoostModel"
277
+ ), # Path to output data file of type "CatBoostModel"
278
+ number_of_trees: int = 100, # Small output of type "Integer"
279
+ ) -> NamedTuple(
280
+ "Outputs",
281
+ [
282
+ ("Accuracy", float), # Small output of type "Float"
283
+ ("Precision", float), # Small output of type "Float"
284
+ ("JobUri", "URI"), # Small output of type "URI"
285
+ ],
286
+ ):
277
287
  """Train CatBoost classification model"""
278
288
  ...
279
289
 
wandb/integration/lightning/fabric/logger.py CHANGED
@@ -192,7 +192,9 @@ class WandbLogger(Logger):
192
192
  wandb_logger.log_image(key="samples", images=[img1, img2])
193
193
 
194
194
  # adding captions
195
- wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])
195
+ wandb_logger.log_image(
196
+ key="samples", images=[img1, img2], caption=["tree", "person"]
197
+ )
196
198
 
197
199
  # using file path
198
200
  wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
wandb/integration/tensorboard/monkeypatch.py CHANGED
@@ -30,8 +30,9 @@ def patch(
30
30
  ) -> None:
31
31
  if len(wandb.patched["tensorboard"]) > 0:
32
32
  raise ValueError(
33
- "Tensorboard already patched, remove `sync_tensorboard=True` "
34
- "from `wandb.init` or only call `wandb.tensorboard.patch` once."
33
+ "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
34
+ "remove `sync_tensorboard=True` from `wandb.init`; "
35
+ "or only call `wandb.tensorboard.patch` once."
35
36
  )
36
37
 
37
38
  # TODO: Some older versions of tensorflow don't require tensorboard to be present.
wandb/jupyter.py CHANGED
@@ -15,10 +15,9 @@ import wandb.util
15
15
  from wandb.sdk.lib import filesystem
16
16
 
17
17
  try:
18
- from IPython.core.getipython import get_ipython
18
+ import IPython
19
19
  from IPython.core.magic import Magics, line_cell_magic, magics_class
20
20
  from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
21
- from IPython.display import display
22
21
  except ImportError:
23
22
  wandb.termwarn("ipython is not supported in python 2.7, upgrade to 3.x")
24
23
 
@@ -66,7 +65,7 @@ class IFrame:
66
65
 
67
66
  def maybe_display(self) -> bool:
68
67
  if not self.displayed and (self.path or wandb.run):
69
- display(self)
68
+ IPython.display.display(self)
70
69
  return self.displayed
71
70
 
72
71
  def _repr_html_(self):
@@ -155,7 +154,7 @@ class WandBMagics(Magics):
155
154
  + cell
156
155
  + "\nwandb.jupyter.__IFrame = None"
157
156
  )
158
- get_ipython().run_cell(cell)
157
+ IPython.get_ipython().run_cell(cell)
159
158
 
160
159
 
161
160
  def notebook_metadata_from_jupyter_servers_and_kernel_id():
@@ -343,7 +342,7 @@ class Notebook:
343
342
  def __init__(self, settings):
344
343
  self.outputs = {}
345
344
  self.settings = settings
346
- self.shell = get_ipython()
345
+ self.shell = IPython.get_ipython()
347
346
 
348
347
  def save_display(self, exc_count, data_with_metadata):
349
348
  self.outputs[exc_count] = self.outputs.get(exc_count, [])
wandb/plot/bar.py CHANGED
@@ -38,10 +38,10 @@ def bar(
38
38
 
39
39
  # Generate random data for the table
40
40
  data = [
41
- ['car', random.uniform(0, 1)],
42
- ['bus', random.uniform(0, 1)],
43
- ['road', random.uniform(0, 1)],
44
- ['person', random.uniform(0, 1)],
41
+ ["car", random.uniform(0, 1)],
42
+ ["bus", random.uniform(0, 1)],
43
+ ["road", random.uniform(0, 1)],
44
+ ["person", random.uniform(0, 1)],
45
45
  ]
46
46
 
47
47
  # Create a table with the data
@@ -49,7 +49,6 @@ def bar(
49
49
 
50
50
  # Initialize a W&B run and log the bar plot
51
51
  with wandb.init(project="bar_chart") as run:
52
-
53
52
  # Create a bar plot from the table
54
53
  bar_plot = wandb.plot.bar(
55
54
  table=table,
@@ -59,7 +58,7 @@ def bar(
59
58
  )
60
59
 
61
60
  # Log the bar chart to W&B
62
- run.log({'bar_plot': bar_plot})
61
+ run.log({"bar_plot": bar_plot})
63
62
  ```
64
63
  """
65
64
  return plot_table(
wandb/plot/histogram.py CHANGED
@@ -53,7 +53,7 @@ def histogram(
53
53
 
54
54
  # Log the histogram plot to W&B
55
55
  with wandb.init(...) as run:
56
- run.log({'histogram-plot1': histogram})
56
+ run.log({"histogram-plot1": histogram})
57
57
  ```
58
58
  """
59
59
  return plot_table(
wandb/plot/line_series.py CHANGED
@@ -53,9 +53,9 @@ def line_series(
53
53
 
54
54
  # Multiple y series to plot
55
55
  ys = [
56
- [i for i in range(10)], # y = x
57
- [i**2 for i in range(10)], # y = x^2
58
- [i**3 for i in range(10)], # y = x^3
56
+ [i for i in range(10)], # y = x
57
+ [i**2 for i in range(10)], # y = x^2
58
+ [i**3 for i in range(10)], # y = x^3
59
59
  ]
60
60
 
61
61
  # Generate and log the line series chart
wandb/plot/pr_curve.py CHANGED
@@ -74,10 +74,10 @@ def pr_curve(
74
74
  [0.2, 0.8], # Second sample (spam), and so on
75
75
  [0.1, 0.9],
76
76
  [0.8, 0.2],
77
- [0.3, 0.7]
77
+ [0.3, 0.7],
78
78
  ]
79
79
 
80
- labels = ['not spam', 'spam'] # Optional class names for readability
80
+ labels = ["not spam", "spam"] # Optional class names for readability
81
81
 
82
82
  with wandb.init(project="spam-detection") as run:
83
83
  pr_curve = wandb.plot.pr_curve(
@@ -176,6 +176,10 @@ def pr_curve(
176
176
  "y": "precision",
177
177
  "class": "class",
178
178
  },
179
- string_fields={"title": title},
179
+ string_fields={
180
+ "title": title,
181
+ "x-axis-title": "Recall",
182
+ "y-axis-title": "Precision",
183
+ },
180
184
  split_table=split_table,
181
185
  )
wandb/plot/scatter.py CHANGED
@@ -38,7 +38,8 @@ def scatter(
38
38
 
39
39
  # Simulate temperature variations at different altitudes over time
40
40
  data = [
41
- [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)] for i in range(300)
41
+ [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
42
+ for i in range(300)
42
43
  ]
43
44
 
44
45
  # Create W&B table with altitude (m) and temperature (°C) columns