wandb 0.19.0rc1__py3-none-win_amd64.whl → 0.19.1__py3-none-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50) hide show
  1. wandb/__init__.py +1 -7
  2. wandb/__init__.pyi +211 -209
  3. wandb/apis/attrs.py +15 -4
  4. wandb/apis/public/api.py +8 -4
  5. wandb/apis/public/files.py +65 -12
  6. wandb/apis/public/runs.py +52 -7
  7. wandb/apis/public/sweeps.py +1 -1
  8. wandb/bin/gpu_stats.exe +0 -0
  9. wandb/bin/wandb-core +0 -0
  10. wandb/cli/cli.py +2 -1
  11. wandb/env.py +1 -1
  12. wandb/errors/term.py +60 -1
  13. wandb/integration/keras/callbacks/tables_builder.py +3 -1
  14. wandb/integration/kfp/kfp_patch.py +25 -15
  15. wandb/integration/lightning/fabric/logger.py +3 -1
  16. wandb/integration/tensorboard/monkeypatch.py +3 -2
  17. wandb/jupyter.py +4 -5
  18. wandb/plot/bar.py +5 -6
  19. wandb/plot/histogram.py +1 -1
  20. wandb/plot/line_series.py +3 -3
  21. wandb/plot/pr_curve.py +7 -3
  22. wandb/plot/scatter.py +2 -1
  23. wandb/proto/v3/wandb_settings_pb2.py +25 -15
  24. wandb/proto/v4/wandb_settings_pb2.py +17 -15
  25. wandb/proto/v5/wandb_settings_pb2.py +17 -15
  26. wandb/sdk/artifacts/_validators.py +1 -3
  27. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
  28. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +12 -2
  29. wandb/sdk/data_types/helper_types/image_mask.py +8 -2
  30. wandb/sdk/data_types/histogram.py +3 -3
  31. wandb/sdk/data_types/image.py +3 -1
  32. wandb/sdk/interface/interface.py +34 -5
  33. wandb/sdk/interface/interface_sock.py +2 -2
  34. wandb/sdk/internal/file_stream.py +4 -1
  35. wandb/sdk/internal/sender.py +4 -1
  36. wandb/sdk/internal/settings_static.py +17 -4
  37. wandb/sdk/launch/utils.py +1 -0
  38. wandb/sdk/lib/ipython.py +5 -27
  39. wandb/sdk/lib/printer.py +33 -20
  40. wandb/sdk/lib/progress.py +7 -1
  41. wandb/sdk/lib/sparkline.py +1 -2
  42. wandb/sdk/wandb_config.py +2 -2
  43. wandb/sdk/wandb_init.py +236 -243
  44. wandb/sdk/wandb_run.py +172 -231
  45. wandb/sdk/wandb_settings.py +104 -15
  46. {wandb-0.19.0rc1.dist-info → wandb-0.19.1.dist-info}/METADATA +1 -1
  47. {wandb-0.19.0rc1.dist-info → wandb-0.19.1.dist-info}/RECORD +50 -50
  48. {wandb-0.19.0rc1.dist-info → wandb-0.19.1.dist-info}/WHEEL +0 -0
  49. {wandb-0.19.0rc1.dist-info → wandb-0.19.1.dist-info}/entry_points.txt +0 -0
  50. {wandb-0.19.0rc1.dist-info → wandb-0.19.1.dist-info}/licenses/LICENSE +0 -0
@@ -49,6 +49,7 @@ class Files(Paginator):
49
49
  query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
50
50
  $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
51
51
  project(name: $project, entityName: $entity) {{
52
+ internalId
52
53
  run(name: $name) {{
53
54
  fileCount
54
55
  ...RunFilesFragment
@@ -98,7 +99,7 @@ class Files(Paginator):
98
99
 
99
100
  def convert_objects(self):
100
101
  return [
101
- File(self.client, r["node"])
102
+ File(self.client, r["node"], self.run)
102
103
  for r in self.last_response["project"]["run"]["files"]["edges"]
103
104
  ]
104
105
 
@@ -120,9 +121,11 @@ class File(Attrs):
120
121
  path_uri (str): path to file in the bucket, currently only available for files stored in S3
121
122
  """
122
123
 
123
- def __init__(self, client, attrs):
124
+ def __init__(self, client, attrs, run=None):
124
125
  self.client = client
125
126
  self._attrs = attrs
127
+ self.run = run
128
+ self.server_supports_delete_file_with_project_id: Optional[bool] = None
126
129
  super().__init__(dict(attrs))
127
130
 
128
131
  @property
@@ -189,18 +192,35 @@ class File(Attrs):
189
192
 
190
193
  @normalize_exceptions
191
194
  def delete(self):
192
- mutation = gql(
193
- """
194
- mutation deleteFiles($files: [ID!]!) {
195
- deleteFiles(input: {
196
- files: $files
197
- }) {
198
- success
199
- }
195
+ project_id_mutation_fragment = ""
196
+ project_id_variable_fragment = ""
197
+ variable_values = {
198
+ "files": [self.id],
200
199
  }
201
- """
200
+
201
+ # Add projectId to mutation and variables if the server supports it.
202
+ # Otherwise, do not include projectId in mutation for older server versions which do not support it.
203
+ if self._server_accepts_project_id_for_delete_file():
204
+ variable_values["projectId"] = self.run._project_internal_id
205
+ project_id_variable_fragment = ", $projectId: Int"
206
+ project_id_mutation_fragment = "projectId: $projectId"
207
+
208
+ mutation_string = """
209
+ mutation deleteFiles($files: [ID!]!{}) {{
210
+ deleteFiles(input: {{
211
+ files: $files
212
+ {}
213
+ }}) {{
214
+ success
215
+ }}
216
+ }}
217
+ """.format(project_id_variable_fragment, project_id_mutation_fragment)
218
+ mutation = gql(mutation_string)
219
+
220
+ self.client.execute(
221
+ mutation,
222
+ variable_values=variable_values,
202
223
  )
203
- self.client.execute(mutation, variable_values={"files": [self.id]})
204
224
 
205
225
  def __repr__(self):
206
226
  return "<File {} ({}) {}>".format(
@@ -208,3 +228,36 @@ class File(Attrs):
208
228
  self.mimetype,
209
229
  util.to_human_size(self.size, units=util.POW_2_BYTES),
210
230
  )
231
+
232
+ @normalize_exceptions
233
+ def _server_accepts_project_id_for_delete_file(self) -> bool:
234
+ """Returns True if the server supports deleting files with a projectId.
235
+
236
+ This check is done by utilizing GraphQL introspection on the available fields of the DeleteFiles API.
237
+ """
238
+ query_string = """
239
+ query ProbeDeleteFilesProjectIdInput {
240
+ DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") {
241
+ inputFields{
242
+ name
243
+ }
244
+ }
245
+ }
246
+ """
247
+
248
+ # Only perform the query once to avoid extra network calls
249
+ if self.server_supports_delete_file_with_project_id is None:
250
+ query = gql(query_string)
251
+ res = self.client.execute(query)
252
+
253
+ # If projectId is in the inputFields, the server supports deleting files with a projectId
254
+ self.server_supports_delete_file_with_project_id = "projectId" in [
255
+ x["name"]
256
+ for x in (
257
+ res.get("DeleteFilesProjectIdInputType", {}).get(
258
+ "inputFields", [{}]
259
+ )
260
+ )
261
+ ]
262
+
263
+ return self.server_supports_delete_file_with_project_id
wandb/apis/public/runs.py CHANGED
@@ -71,6 +71,7 @@ class Runs(Paginator):
71
71
  """
72
72
  query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
73
73
  project(name: $project, entityName: $entity) {{
74
+ internalId
74
75
  runCount(filters: $filters)
75
76
  readOnly
76
77
  runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
@@ -103,6 +104,7 @@ class Runs(Paginator):
103
104
  ):
104
105
  self.entity = entity
105
106
  self.project = project
107
+ self._project_internal_id = None
106
108
  self.filters = filters or {}
107
109
  self.order = order
108
110
  self._sweeps = {}
@@ -287,6 +289,7 @@ class Run(Attrs):
287
289
  Calling update will persist any changes.
288
290
  project (str): the project associated with the run
289
291
  entity (str): the name of the entity associated with the run
292
+ project_internal_id (int): the internal id of the project
290
293
  user (str): the name of the user who created the run
291
294
  path (str): Unique identifier [entity]/[project]/[run_id]
292
295
  notes (str): Notes about the run
@@ -328,6 +331,7 @@ class Run(Attrs):
328
331
  self._summary = None
329
332
  self._metadata: Optional[Dict[str, Any]] = None
330
333
  self._state = _attrs.get("state", "not found")
334
+ self.server_provides_internal_id_field: Optional[bool] = None
331
335
 
332
336
  self.load(force=not _attrs)
333
337
 
@@ -417,13 +421,18 @@ class Run(Attrs):
417
421
  """
418
422
  query Run($project: String!, $entity: String!, $name: String!) {{
419
423
  project(name: $project, entityName: $entity) {{
424
+ {}
420
425
  run(name: $name) {{
421
426
  ...RunFragment
422
427
  }}
423
428
  }}
424
429
  }}
425
430
  {}
426
- """.format(RUN_FRAGMENT)
431
+ """.format(
432
+ # Only query internalId if the server supports it
433
+ "internalId" if self._server_provides_internal_id_for_project() else "",
434
+ RUN_FRAGMENT,
435
+ )
427
436
  )
428
437
  if force or not self._attrs:
429
438
  response = self._exec(query)
@@ -435,7 +444,7 @@ class Run(Attrs):
435
444
  raise ValueError("Could not find run {}".format(self))
436
445
  self._attrs = response["project"]["run"]
437
446
  self._state = self._attrs["state"]
438
-
447
+ self._project_internal_id = response["project"].get("internalId", None)
439
448
  if self._include_sweeps and self.sweep_name and not self.sweep:
440
449
  # There may be a lot of runs. Don't bother pulling them all
441
450
  # just for the sake of this one.
@@ -495,7 +504,6 @@ class Run(Attrs):
495
504
  res = self._exec(query)
496
505
  state = res["project"]["run"]["state"]
497
506
  if state in ["finished", "crashed", "failed"]:
498
- print(f"Run finished with status: {state}")
499
507
  self._attrs["state"] = state
500
508
  self._state = state
501
509
  return
@@ -506,8 +514,8 @@ class Run(Attrs):
506
514
  """Persist changes to the run object to the wandb backend."""
507
515
  mutation = gql(
508
516
  """
509
- mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String) {{
510
- upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName}}) {{
517
+ mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{
518
+ upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{
511
519
  bucket {{
512
520
  ...RunFragment
513
521
  }}
@@ -525,6 +533,7 @@ class Run(Attrs):
525
533
  display_name=self.display_name,
526
534
  config=self.json_config,
527
535
  groupName=self.group,
536
+ jobType=self.job_type,
528
537
  )
529
538
  self.summary.update()
530
539
 
@@ -770,7 +779,9 @@ class Run(Attrs):
770
779
  Example:
771
780
  >>> import wandb
772
781
  >>> import tempfile
773
- >>> with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp:
782
+ >>> with tempfile.NamedTemporaryFile(
783
+ ... mode="w", delete=False, suffix=".txt"
784
+ ... ) as tmp:
774
785
  ... tmp.write("This is a test artifact")
775
786
  ... tmp_path = tmp.name
776
787
  >>> run = wandb.init(project="artifact-example")
@@ -893,6 +904,37 @@ class Run(Attrs):
893
904
  )
894
905
  return artifact
895
906
 
907
+ @normalize_exceptions
908
+ def _server_provides_internal_id_for_project(self) -> bool:
909
+ """Returns True if the server allows us to query the internalId field for a project.
910
+
911
+ This check is done by utilizing GraphQL introspection on the available fields of the Project type.
912
+ """
913
+ query_string = """
914
+ query ProbeProjectInput {
915
+ ProjectType: __type(name:"Project") {
916
+ fields {
917
+ name
918
+ }
919
+ }
920
+ }
921
+ """
922
+
923
+ # Only perform the query once to avoid extra network calls
924
+ if self.server_provides_internal_id_field is None:
925
+ query = gql(query_string)
926
+ res = self.client.execute(query)
927
+ print(
928
+ "internalId"
929
+ in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
930
+ )
931
+
932
+ self.server_provides_internal_id_field = "internalId" in [
933
+ x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
934
+ ]
935
+
936
+ return self.server_provides_internal_id_field
937
+
896
938
  @property
897
939
  def summary(self):
898
940
  if self._summary is None:
@@ -921,7 +963,10 @@ class Run(Attrs):
921
963
  if self._metadata is None:
922
964
  try:
923
965
  f = self.file("wandb-metadata.json")
924
- contents = util.download_file_into_memory(f.url, wandb.Api().api_key)
966
+ session = self.client._client.transport.session
967
+ response = session.get(f.url, timeout=5)
968
+ response.raise_for_status()
969
+ contents = response.content
925
970
  self._metadata = json_util.loads(contents)
926
971
  except: # noqa: E722
927
972
  # file doesn't exist, or can't be downloaded, or can't be parsed
@@ -34,7 +34,7 @@ class Sweep(Attrs):
34
34
  Instantiate with:
35
35
  ```
36
36
  api = wandb.Api()
37
- sweep = api.sweep(path/to/sweep)
37
+ sweep = api.sweep(path / to / sweep)
38
38
  ```
39
39
 
40
40
  Attributes:
wandb/bin/gpu_stats.exe CHANGED
Binary file
wandb/bin/wandb-core CHANGED
Binary file
wandb/cli/cli.py CHANGED
@@ -33,6 +33,7 @@ from wandb.apis.public import RunQueue
33
33
  from wandb.errors.links import url_registry
34
34
  from wandb.sdk.artifacts._validators import is_artifact_registry_project
35
35
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
36
+ from wandb.sdk.internal.internal_api import Api as SDKInternalApi
36
37
  from wandb.sdk.launch import utils as launch_utils
37
38
  from wandb.sdk.launch._launch_add import _launch_add
38
39
  from wandb.sdk.launch.errors import ExecutionError, LaunchError
@@ -2399,7 +2400,7 @@ def get(path, root, type):
2399
2400
  settings_entity = public_api.settings["entity"] or public_api.default_entity
2400
2401
  # Registry artifacts are under the org entity. Because we offer a shorthand and alias for this path,
2401
2402
  # we need to fetch the org entity for the user behind the scenes.
2402
- entity = InternalApi()._resolve_org_entity_name(
2403
+ entity = SDKInternalApi()._resolve_org_entity_name(
2403
2404
  entity=settings_entity, organization=organization
2404
2405
  )
2405
2406
  full_path = f"{entity}/{project}/{artifact_name}:{version}"
wandb/env.py CHANGED
@@ -169,7 +169,7 @@ def error_reporting_enabled() -> bool:
169
169
 
170
170
 
171
171
  def core_debug(default: str | None = None) -> bool:
172
- return _env_as_bool(CORE_DEBUG, default=default)
172
+ return _env_as_bool(CORE_DEBUG, default=default) or is_debug()
173
173
 
174
174
 
175
175
  def ssl_disabled() -> bool:
wandb/errors/term.py CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
5
5
  import contextlib
6
6
  import logging
7
7
  import os
8
+ import re
9
+ import shutil
8
10
  import sys
9
11
  import threading
10
12
  from typing import TYPE_CHECKING, Iterator, Protocol
@@ -271,10 +273,67 @@ class DynamicBlock:
271
273
  The lock must be held.
272
274
  """
273
275
  if self._lines_to_print:
274
- click.echo("\n".join(self._lines_to_print), file=sys.stderr)
276
+ # Trim lines before printing. This is crucial because the \x1b[Am
277
+ # (cursor up) sequence used when clearing the text moves up by one
278
+ # visual line, and the terminal may be wrapping long lines onto
279
+ # multiple visual lines.
280
+ #
281
+ # There is no ANSI escape sequence that moves the cursor up by one
282
+ # "physical" line instead. Note that the user may resize their
283
+ # terminal.
284
+ term_width = _shutil_get_terminal_width()
285
+ click.echo(
286
+ "\n".join(
287
+ _ansi_shorten(line, term_width) #
288
+ for line in self._lines_to_print
289
+ ),
290
+ file=sys.stderr,
291
+ )
292
+
275
293
  self._num_lines_printed += len(self._lines_to_print)
276
294
 
277
295
 
296
+ def _shutil_get_terminal_width() -> int:
297
+ """Returns the width of the terminal.
298
+
299
+ Defined here for patching in tests.
300
+ """
301
+ columns, _ = shutil.get_terminal_size()
302
+ return columns
303
+
304
+
305
+ _ANSI_RE = re.compile("\x1b\\[(K|.*?m)")
306
+
307
+
308
+ def _ansi_shorten(text: str, width: int) -> str:
309
+ """Shorten text potentially containing ANSI sequences to fit a width."""
310
+ first_ansi = _ANSI_RE.search(text)
311
+
312
+ if not first_ansi:
313
+ return _raw_shorten(text, width)
314
+
315
+ if first_ansi.start() > width - 3:
316
+ return _raw_shorten(text[: first_ansi.start()], width)
317
+
318
+ return text[: first_ansi.end()] + _ansi_shorten(
319
+ text[first_ansi.end() :],
320
+ # Key part: the ANSI sequence doesn't reduce the remaining width.
321
+ width - first_ansi.start(),
322
+ )
323
+
324
+
325
+ def _raw_shorten(text: str, width: int) -> str:
326
+ """Shorten text to fit a width, replacing the end with "...".
327
+
328
+ Unlike textwrap.shorten(), this does not drop whitespace or do anything
329
+ smart.
330
+ """
331
+ if len(text) <= width:
332
+ return text
333
+
334
+ return text[: width - 3] + "..."
335
+
336
+
278
337
  def _log(
279
338
  string="",
280
339
  newline=True,
@@ -145,7 +145,9 @@ class WandbEvalCallback(Callback, abc.ABC):
145
145
  for idx, data in enumerate(dataloader):
146
146
  preds = model.predict(data)
147
147
  self.pred_table.add_data(
148
- self.data_table_ref.data[idx][0], self.data_table_ref.data[idx][1], preds
148
+ self.data_table_ref.data[idx][0],
149
+ self.data_table_ref.data[idx][1],
150
+ preds,
149
151
  )
150
152
  ```
151
153
  This method is called `on_epoch_end` or equivalent hook.
@@ -208,23 +208,24 @@ def create_component_from_func(
208
208
  """Return sum of two arguments"""
209
209
  return a + b
210
210
 
211
+
211
212
  # add_op is a task factory function that creates a task object when given arguments
212
213
  add_op = create_component_from_func(
213
214
  func=add,
214
- base_image='python:3.7', # Optional
215
- output_component_file='add.component.yaml', # Optional
216
- packages_to_install=['pandas==0.24'], # Optional
215
+ base_image="python:3.7", # Optional
216
+ output_component_file="add.component.yaml", # Optional
217
+ packages_to_install=["pandas==0.24"], # Optional
217
218
  )
218
219
 
219
220
  # The component spec can be accessed through the .component_spec attribute:
220
- add_op.component_spec.save('add.component.yaml')
221
+ add_op.component_spec.save("add.component.yaml")
221
222
 
222
223
  # The component function can be called with arguments to create a task:
223
224
  add_task = add_op(1, 3)
224
225
 
225
226
  # The resulting task has output references, corresponding to the component outputs.
226
227
  # When the function only has a single anonymous return value, the output name is "Output":
227
- sum_output_ref = add_task.outputs['Output']
228
+ sum_output_ref = add_task.outputs["Output"]
228
229
 
229
230
  # These task output references can be passed to other component functions, constructing a computation graph:
230
231
  task2 = add_op(sum_output_ref, 5)
@@ -241,17 +242,21 @@ def create_component_from_func(
241
242
 
242
243
  from typing import NamedTuple
243
244
 
244
- def add_multiply_two_numbers(a: float, b: float) -> NamedTuple('Outputs', [('sum', float), ('product', float)]):
245
+
246
+ def add_multiply_two_numbers(a: float, b: float) -> NamedTuple(
247
+ "Outputs", [("sum", float), ("product", float)]
248
+ ):
245
249
  """Return sum and product of two arguments"""
246
250
  return (a + b, a * b)
247
251
 
252
+
248
253
  add_multiply_op = create_component_from_func(add_multiply_two_numbers)
249
254
 
250
255
  # The component function can be called with arguments to create a task:
251
256
  add_multiply_task = add_multiply_op(1, 3)
252
257
 
253
258
  # The resulting task has output references, corresponding to the component outputs:
254
- sum_output_ref = add_multiply_task.outputs['sum']
259
+ sum_output_ref = add_multiply_task.outputs["sum"]
255
260
 
256
261
  # These task output references can be passed to other component functions, constructing a computation graph:
257
262
  task2 = add_multiply_op(sum_output_ref, 5)
@@ -266,14 +271,19 @@ def create_component_from_func(
266
271
  Example of a component function declaring file input and output::
267
272
 
268
273
  def catboost_train_classifier(
269
- training_data_path: InputPath('CSV'), # Path to input data file of type "CSV"
270
- trained_model_path: OutputPath('CatBoostModel'), # Path to output data file of type "CatBoostModel"
271
- number_of_trees: int = 100, # Small output of type "Integer"
272
- ) -> NamedTuple('Outputs', [
273
- ('Accuracy', float), # Small output of type "Float"
274
- ('Precision', float), # Small output of type "Float"
275
- ('JobUri', 'URI'), # Small output of type "URI"
276
- ]):
274
+ training_data_path: InputPath("CSV"), # Path to input data file of type "CSV"
275
+ trained_model_path: OutputPath(
276
+ "CatBoostModel"
277
+ ), # Path to output data file of type "CatBoostModel"
278
+ number_of_trees: int = 100, # Small output of type "Integer"
279
+ ) -> NamedTuple(
280
+ "Outputs",
281
+ [
282
+ ("Accuracy", float), # Small output of type "Float"
283
+ ("Precision", float), # Small output of type "Float"
284
+ ("JobUri", "URI"), # Small output of type "URI"
285
+ ],
286
+ ):
277
287
  """Train CatBoost classification model"""
278
288
  ...
279
289
 
@@ -192,7 +192,9 @@ class WandbLogger(Logger):
192
192
  wandb_logger.log_image(key="samples", images=[img1, img2])
193
193
 
194
194
  # adding captions
195
- wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])
195
+ wandb_logger.log_image(
196
+ key="samples", images=[img1, img2], caption=["tree", "person"]
197
+ )
196
198
 
197
199
  # using file path
198
200
  wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
@@ -30,8 +30,9 @@ def patch(
30
30
  ) -> None:
31
31
  if len(wandb.patched["tensorboard"]) > 0:
32
32
  raise ValueError(
33
- "Tensorboard already patched, remove `sync_tensorboard=True` "
34
- "from `wandb.init` or only call `wandb.tensorboard.patch` once."
33
+ "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
34
+ "remove `sync_tensorboard=True` from `wandb.init`; "
35
+ "or only call `wandb.tensorboard.patch` once."
35
36
  )
36
37
 
37
38
  # TODO: Some older versions of tensorflow don't require tensorboard to be present.
wandb/jupyter.py CHANGED
@@ -15,10 +15,9 @@ import wandb.util
15
15
  from wandb.sdk.lib import filesystem
16
16
 
17
17
  try:
18
- from IPython.core.getipython import get_ipython
18
+ import IPython
19
19
  from IPython.core.magic import Magics, line_cell_magic, magics_class
20
20
  from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
21
- from IPython.display import display
22
21
  except ImportError:
23
22
  wandb.termwarn("ipython is not supported in python 2.7, upgrade to 3.x")
24
23
 
@@ -66,7 +65,7 @@ class IFrame:
66
65
 
67
66
  def maybe_display(self) -> bool:
68
67
  if not self.displayed and (self.path or wandb.run):
69
- display(self)
68
+ IPython.display.display(self)
70
69
  return self.displayed
71
70
 
72
71
  def _repr_html_(self):
@@ -155,7 +154,7 @@ class WandBMagics(Magics):
155
154
  + cell
156
155
  + "\nwandb.jupyter.__IFrame = None"
157
156
  )
158
- get_ipython().run_cell(cell)
157
+ IPython.get_ipython().run_cell(cell)
159
158
 
160
159
 
161
160
  def notebook_metadata_from_jupyter_servers_and_kernel_id():
@@ -343,7 +342,7 @@ class Notebook:
343
342
  def __init__(self, settings):
344
343
  self.outputs = {}
345
344
  self.settings = settings
346
- self.shell = get_ipython()
345
+ self.shell = IPython.get_ipython()
347
346
 
348
347
  def save_display(self, exc_count, data_with_metadata):
349
348
  self.outputs[exc_count] = self.outputs.get(exc_count, [])
wandb/plot/bar.py CHANGED
@@ -38,10 +38,10 @@ def bar(
38
38
 
39
39
  # Generate random data for the table
40
40
  data = [
41
- ['car', random.uniform(0, 1)],
42
- ['bus', random.uniform(0, 1)],
43
- ['road', random.uniform(0, 1)],
44
- ['person', random.uniform(0, 1)],
41
+ ["car", random.uniform(0, 1)],
42
+ ["bus", random.uniform(0, 1)],
43
+ ["road", random.uniform(0, 1)],
44
+ ["person", random.uniform(0, 1)],
45
45
  ]
46
46
 
47
47
  # Create a table with the data
@@ -49,7 +49,6 @@ def bar(
49
49
 
50
50
  # Initialize a W&B run and log the bar plot
51
51
  with wandb.init(project="bar_chart") as run:
52
-
53
52
  # Create a bar plot from the table
54
53
  bar_plot = wandb.plot.bar(
55
54
  table=table,
@@ -59,7 +58,7 @@ def bar(
59
58
  )
60
59
 
61
60
  # Log the bar chart to W&B
62
- run.log({'bar_plot': bar_plot})
61
+ run.log({"bar_plot": bar_plot})
63
62
  ```
64
63
  """
65
64
  return plot_table(
wandb/plot/histogram.py CHANGED
@@ -53,7 +53,7 @@ def histogram(
53
53
 
54
54
  # Log the histogram plot to W&B
55
55
  with wandb.init(...) as run:
56
- run.log({'histogram-plot1': histogram})
56
+ run.log({"histogram-plot1": histogram})
57
57
  ```
58
58
  """
59
59
  return plot_table(
wandb/plot/line_series.py CHANGED
@@ -53,9 +53,9 @@ def line_series(
53
53
 
54
54
  # Multiple y series to plot
55
55
  ys = [
56
- [i for i in range(10)], # y = x
57
- [i**2 for i in range(10)], # y = x^2
58
- [i**3 for i in range(10)], # y = x^3
56
+ [i for i in range(10)], # y = x
57
+ [i**2 for i in range(10)], # y = x^2
58
+ [i**3 for i in range(10)], # y = x^3
59
59
  ]
60
60
 
61
61
  # Generate and log the line series chart
wandb/plot/pr_curve.py CHANGED
@@ -74,10 +74,10 @@ def pr_curve(
74
74
  [0.2, 0.8], # Second sample (spam), and so on
75
75
  [0.1, 0.9],
76
76
  [0.8, 0.2],
77
- [0.3, 0.7]
77
+ [0.3, 0.7],
78
78
  ]
79
79
 
80
- labels = ['not spam', 'spam'] # Optional class names for readability
80
+ labels = ["not spam", "spam"] # Optional class names for readability
81
81
 
82
82
  with wandb.init(project="spam-detection") as run:
83
83
  pr_curve = wandb.plot.pr_curve(
@@ -176,6 +176,10 @@ def pr_curve(
176
176
  "y": "precision",
177
177
  "class": "class",
178
178
  },
179
- string_fields={"title": title},
179
+ string_fields={
180
+ "title": title,
181
+ "x-axis-title": "Recall",
182
+ "y-axis-title": "Precision",
183
+ },
180
184
  split_table=split_table,
181
185
  )
wandb/plot/scatter.py CHANGED
@@ -38,7 +38,8 @@ def scatter(
38
38
 
39
39
  # Simulate temperature variations at different altitudes over time
40
40
  data = [
41
- [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)] for i in range(300)
41
+ [i, random.uniform(-10, 20) - 0.005 * i + 5 * math.sin(i / 50)]
42
+ for i in range(300)
42
43
  ]
43
44
 
44
45
  # Create W&B table with altitude (m) and temperature (°C) columns