wandb-0.18.0-py3-none-any.whl → wandb-0.18.1-py3-none-any.whl

Files changed (62)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/nvidia_gpu_stats +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/METADATA +1 -1
  53. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
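
The headline change is the move of the scikit-learn integration from wandb/sklearn/ to wandb/integration/sklearn/, with a new top-level wandb/sklearn.py (+35 -0) keeping the old import path alive. Its contents are not shown in this diff; the following is only a hypothetical sketch of such a compatibility shim, re-exporting the two plot functions whose names appear in docstrings below:

```
# Hypothetical shim in the spirit of the new wandb/sklearn.py (the real file
# may differ): re-export the moved integration so existing
# `wandb.sklearn.plot_*` calls keep working after the package move.
from wandb.integration.sklearn import plot_clusterer, plot_regressor  # noqa: F401
```
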
wandb/env.py CHANGED
@@ -62,7 +62,6 @@ SAVE_CODE = "WANDB_SAVE_CODE"
 TAGS = "WANDB_TAGS"
 IGNORE = "WANDB_IGNORE_GLOBS"
 ERROR_REPORTING = "WANDB_ERROR_REPORTING"
-CORE_ERROR_REPORTING = "WANDB_CORE_ERROR_REPORTING"
 CORE_DEBUG = "WANDB_CORE_DEBUG"
 DOCKER = "WANDB_DOCKER"
 AGENT_REPORT_INTERVAL = "WANDB_AGENT_REPORT_INTERVAL"
@@ -172,10 +171,6 @@ def error_reporting_enabled() -> bool:
     return _env_as_bool(ERROR_REPORTING, default="True")
 
 
-def core_error_reporting_enabled(default: Optional[str] = None) -> bool:
-    return _env_as_bool(CORE_ERROR_REPORTING, default=default)
-
-
 def core_debug(default: Optional[str] = None) -> bool:
     return _env_as_bool(CORE_DEBUG, default=default)
 
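
With WANDB_CORE_ERROR_REPORTING and its helper removed, WANDB_ERROR_REPORTING is the single switch. A minimal sketch of how a caller would check it now, using the error_reporting_enabled() helper visible in the hunk above:

```
import wandb.env

# error_reporting_enabled() reads WANDB_ERROR_REPORTING and defaults to True;
# code that called the removed core_error_reporting_enabled() would use this.
if wandb.env.error_reporting_enabled():
    pass  # error reporting stays on
```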

wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py RENAMED
@@ -7,18 +7,18 @@ from sklearn.calibration import CalibratedClassifierCV
 from sklearn.linear_model import LogisticRegression
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def calibration_curves(clf, X, y, clf_name):
+def calibration_curves(clf, X, y, clf_name):  # noqa: N803
     # ComplementNB (introduced in 0.20.0) requires non-negative features
     if int(sklearn.__version__.split(".")[1]) >= 20 and isinstance(
         clf, naive_bayes.ComplementNB
     ):
-        X = X - X.min()
+        X = X - X.min()  # noqa:N806
 
     # Calibrated with isotonic calibration
     isotonic = CalibratedClassifierCV(clf, cv=2, method="isotonic")
@@ -48,7 +48,7 @@ def calibration_curves(clf, X, y, clf_name):
         frac_positives_column.append(1)
         mean_pred_value_column.append(1)
 
-    X_train, X_test, y_train, y_test = model_selection.train_test_split(
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(
         X, y, test_size=0.9, random_state=42
     )
 
@@ -58,11 +58,11 @@ def calibration_curves(clf, X, y, clf_name):
     names = ["Logistic", f"{clf_name} Isotonic", f"{clf_name} Sigmoid"]
 
     for model, name in zip(models, names):
-        model.fit(X_train, y_train)
+        model.fit(x_train, y_train)
         if hasattr(model, "predict_proba"):
-            prob_pos = model.predict_proba(X_test)[:, 1]
+            prob_pos = model.predict_proba(x_test)[:, 1]
         else:  # use decision function
-            prob_pos = model.decision_function(X_test)
+            prob_pos = model.decision_function(x_test)
         prob_pos = (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min())
 
         hist, edges = np.histogram(prob_pos, bins=10, density=False)
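
The # noqa: N803 / # noqa: N806 comments added throughout these files silence the pep8-naming checks that flag scikit-learn's conventional capitalized X (N803 for argument names, N806 for local variables); renaming locals such as X_train to x_train avoids N806 outright. A minimal illustration (the function is hypothetical):

```
# Sketch only: N803 fires on uppercase argument names, N806 on uppercase
# locals; lowercase locals need no suppression at all.
def fit(X, y):  # noqa: N803
    x_shifted = X - X.min()  # lowercase local, no N806 needed
    return x_shifted, y
```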

wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py RENAMED
@@ -4,7 +4,7 @@ import numpy as np
 from sklearn.utils.multiclass import unique_labels
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)

wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py RENAMED
@@ -6,14 +6,15 @@ from sklearn import metrics
 from sklearn.utils.multiclass import unique_labels
 
 import wandb
-from wandb.sklearn import utils
+
+from .. import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
 def validate_labels(*args, **kwargs):  # FIXME
-    assert False
+    raise AssertionError()
 
 
 def confusion_matrix(
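
The switch from `assert False` to `raise AssertionError()` matters because bare asserts are compiled away under optimization, so the placeholder would silently do nothing when Python runs with -O:

```
# python -O -c "assert False"            # exits cleanly: assert stripped
# python -O -c "raise AssertionError()"  # still raises AssertionError
```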

wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py RENAMED
@@ -11,7 +11,7 @@ import wandb
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def elbow_curve(clusterer, X, cluster_ranges, n_jobs, show_cluster_time):
+def elbow_curve(clusterer, X, cluster_ranges, n_jobs, show_cluster_time):  # noqa: N803
     if cluster_ranges is None:
         cluster_ranges = range(1, 10, 2)
     else:
@@ -37,19 +37,19 @@ def make_table(cluster_ranges, clfs, times):
     return table
 
 
-def _compute_results_parallel(n_jobs, clusterer, X, cluster_ranges):
+def _compute_results_parallel(n_jobs, clusterer, x, cluster_ranges):
     parallel_runner = Parallel(n_jobs=n_jobs)
     _cluster_scorer = delayed(_clone_and_score_clusterer)
-    results = parallel_runner(_cluster_scorer(clusterer, X, i) for i in cluster_ranges)
+    results = parallel_runner(_cluster_scorer(clusterer, x, i) for i in cluster_ranges)
 
     clfs, times = zip(*results)
 
     return clfs, times
 
 
-def _clone_and_score_clusterer(clusterer, X, n_clusters):
+def _clone_and_score_clusterer(clusterer, x, n_clusters):
     start = time.time()
     clusterer = clone(clusterer)
-    setattr(clusterer, "n_clusters", n_clusters)
+    clusterer.n_clusters = n_clusters
 
-    return clusterer.fit(X).score(X), time.time() - start
+    return clusterer.fit(x).score(x), time.time() - start
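
Besides lowercasing X, the last hunk replaces setattr(clusterer, "n_clusters", n_clusters) with a plain attribute assignment; the two are equivalent here. A standalone sketch of the clone-and-score pattern used above (the data and estimator are illustrative, not wandb code):

```
import time

from sklearn.base import clone
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=100, centers=3, random_state=0)
base = KMeans(n_init=10, random_state=0)
for n_clusters in range(1, 10, 2):  # the default cluster_ranges above
    clusterer = clone(base)         # fresh, unfitted copy per setting
    clusterer.n_clusters = n_clusters
    start = time.time()
    score = clusterer.fit(x).score(x)
    print(n_clusters, score, time.time() - start)
```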

wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py RENAMED
@@ -4,7 +4,7 @@ import numpy as np
 from sklearn import model_selection
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
@@ -12,7 +12,7 @@ simplefilter(action="ignore", category=FutureWarning)
 
 def learning_curve(
     model,
-    X,
+    X,  # noqa: N803
     y,
     cv=None,
     shuffle=False,

wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py RENAMED
@@ -3,13 +3,13 @@ from warnings import simplefilter
 import numpy as np
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def outlier_candidates(regressor, X, y):
+def outlier_candidates(regressor, X, y):  # noqa: N803
     # Fit a linear model to X and y to compute MSE
     regressor.fit(X, y)
 

wandb/{sklearn → integration/sklearn}/calculate/residuals.py RENAMED
@@ -3,27 +3,27 @@ from warnings import simplefilter
 from sklearn import model_selection
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def residuals(regressor, X, y):
+def residuals(regressor, X, y):  # noqa: N803
     # Create the train and test splits
-    X_train, X_test, y_train, y_test = model_selection.train_test_split(
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(
         X, y, test_size=0.2
     )
 
     # Store labels and colors for the legend ordered by call
-    regressor.fit(X_train, y_train)
-    train_score_ = regressor.score(X_train, y_train)
-    test_score_ = regressor.score(X_test, y_test)
+    regressor.fit(x_train, y_train)
+    train_score_ = regressor.score(x_train, y_train)
+    test_score_ = regressor.score(x_test, y_test)
 
-    y_pred_train = regressor.predict(X_train)
+    y_pred_train = regressor.predict(x_train)
     residuals_train = y_pred_train - y_train
 
-    y_pred_test = regressor.predict(X_test)
+    y_pred_test = regressor.predict(x_test)
     residuals_test = y_pred_test - y_test
 
     table = make_table(

wandb/{sklearn → integration/sklearn}/calculate/silhouette.py RENAMED
@@ -5,13 +5,13 @@ from sklearn.metrics import silhouette_samples, silhouette_score
 from sklearn.preprocessing import LabelEncoder
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans):
+def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans):  # noqa: N803
     # Run clusterer for n_clusters in range(len(cluster_ranges), get cluster labels
     # TODO - keep/delete once we decide if we should train clusterers
     # or ask for trained models

wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py RENAMED
@@ -4,13 +4,13 @@ import numpy as np
 import sklearn
 
 import wandb
-from wandb.sklearn import utils
+from wandb.integration.sklearn import utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
+def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):  # noqa: N803
     """Calculate summary metrics for both regressors and classifiers.
 
     Called by plot_summary_metrics to visualize metrics. Please use the function

wandb/{sklearn → integration/sklearn}/plot/classifier.py RENAMED
@@ -7,7 +7,7 @@ from sklearn import naive_bayes
 
 import wandb
 import wandb.plot
-from wandb.sklearn import calculate, utils
+from wandb.integration.sklearn import calculate, utils
 
 from . import shared
 
@@ -17,8 +17,8 @@ simplefilter(action="ignore", category=FutureWarning)
 
 def classifier(
     model,
-    X_train,
-    X_test,
+    X_train,  # noqa: N803
+    X_test,  # noqa: N803
     y_train,
     y_test,
     y_pred,
@@ -77,7 +77,7 @@ def classifier(
     )
     ```
     """
-    wandb.termlog("\nPlotting %s." % model_name)
+    wandb.termlog(f"\nPlotting {model_name}.")
 
     if not isinstance(model, naive_bayes.MultinomialNB):
         feature_importances(model, feature_names)
@@ -280,7 +280,7 @@ def class_proportions(y_train=None, y_test=None, labels=None):
     wandb.log({"class_proportions": class_proportions_chart})
 
 
-def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"):
+def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"):  # noqa: N803
     """Log a plot depicting how well-calibrated the predicted probabilities of a classifier are.
 
     Also suggests how to calibrate an uncalibrated classifier. Compares estimated predicted

wandb/{sklearn → integration/sklearn}/plot/clusterer.py RENAMED
@@ -6,13 +6,13 @@ import pandas as pd
 import sklearn
 
 import wandb
-from wandb.sklearn import calculate, utils
+from wandb.integration.sklearn import calculate, utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer"):
+def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer"):  # noqa: N803
     """Generates all sklearn clusterer plots supported by W&B.
 
     The following plots are generated:
@@ -40,7 +40,7 @@ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer
     wandb.sklearn.plot_clusterer(kmeans, X, cluster_labels, labels, "KMeans")
     ```
     """
-    wandb.termlog("\nPlotting %s." % model_name)
+    wandb.termlog(f"\nPlotting {model_name}.")
     if isinstance(model, sklearn.cluster.KMeans):
         elbow_curve(model, X_train)
         wandb.termlog("Logged elbow curve.")
@@ -54,7 +54,11 @@ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer
 
 
 def elbow_curve(
-    clusterer=None, X=None, cluster_ranges=None, n_jobs=1, show_cluster_time=True
+    clusterer=None,
+    X=None,  # noqa: N803
+    cluster_ranges=None,
+    n_jobs=1,
+    show_cluster_time=True,
 ):
     """Measures and plots variance explained as a function of the number of clusters.
 
@@ -97,7 +101,7 @@ def elbow_curve(
 
 def silhouette(
     clusterer=None,
-    X=None,
+    X=None,  # noqa: N803
     cluster_labels=None,
     labels=None,
     metric="euclidean",
@@ -135,7 +139,7 @@ def silhouette(
 
     if not_missing and correct_types and is_fitted:
         if isinstance(X, (pd.DataFrame)):
-            X = X.values
+            X = X.values  # noqa: N806
         silhouette_chart = calculate.silhouette(
             clusterer, X, cluster_labels, labels, metric, kmeans
        )
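
A usage sketch mirroring the docstring example above; the dataset, project name, and the optional labels argument are illustrative, and the keyword names follow the clusterer() signature shown in this file:

```
import wandb
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
cluster_labels = kmeans.fit_predict(X)

with wandb.init(project="sklearn-demo"):  # hypothetical project name
    # signature per this file: clusterer(model, X_train, cluster_labels,
    # labels=None, model_name="Clusterer")
    wandb.sklearn.plot_clusterer(
        kmeans, X, cluster_labels, labels=y, model_name="KMeans"
    )
```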

wandb/{sklearn → integration/sklearn}/plot/regressor.py RENAMED
@@ -5,7 +5,7 @@ from warnings import simplefilter
 import numpy as np
 
 import wandb
-from wandb.sklearn import calculate, utils
+from wandb.integration.sklearn import calculate, utils
 
 from . import shared
 
@@ -13,7 +13,7 @@ from . import shared
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
+def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):  # noqa: N803
     """Generates all sklearn regressor plots supported by W&B.
 
     The following plots are generated:
@@ -38,7 +38,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
     wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test, "Ridge")
     ```
     """
-    wandb.termlog("\nPlotting %s." % model_name)
+    wandb.termlog(f"\nPlotting {model_name}.")
 
     shared.summary_metrics(model, X_train, y_train, X_test, y_test)
     wandb.termlog("Logged summary metrics.")
@@ -53,7 +53,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
     wandb.termlog("Logged residuals.")
 
 
-def outlier_candidates(regressor=None, X=None, y=None):
+def outlier_candidates(regressor=None, X=None, y=None):  # noqa: N803
     """Measures a datapoint's influence on regression model via cook's distance.
 
     Instances with high influences could potentially be outliers.
@@ -87,7 +87,7 @@ def outlier_candidates(regressor=None, X=None, y=None):
     wandb.log({"outlier_candidates": outliers_chart})
 
 
-def residuals(regressor=None, X=None, y=None):
+def residuals(regressor=None, X=None, y=None):  # noqa: N803
     """Measures and plots the regressor's predicted value against the residual.
 
     The marginal distribution of residuals is also calculated and plotted.

wandb/{sklearn → integration/sklearn}/plot/shared.py RENAMED
@@ -5,13 +5,13 @@ from warnings import simplefilter
 import numpy as np
 
 import wandb
-from wandb.sklearn import calculate, utils
+from wandb.integration.sklearn import calculate, utils
 
 # ignore all future warnings
 simplefilter(action="ignore", category=FutureWarning)
 
 
-def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
+def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):  # noqa: N803
     """Logs a chart depicting summary metrics for a model.
 
     Should only be called with a fitted model (otherwise an error is thrown).
@@ -47,7 +47,7 @@ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
 
 def learning_curve(
     model=None,
-    X=None,
+    X=None,  # noqa: N803
     y=None,
     cv=None,
     shuffle=False,

wandb/{sklearn → integration/sklearn}/utils.py RENAMED
@@ -61,7 +61,7 @@ def test_types(**kwargs):
             list,
         ),
     ):
-        wandb.termerror("%s is not an array. Please try again." % (k))
+        wandb.termerror(f"{k} is not an array. Please try again.")
         test_passed = False
     # check for classifier types
     if k == "model":
@@ -69,20 +69,20 @@ def test_types(**kwargs):
             not sklearn.base.is_regressor(v)
         ):
             wandb.termerror(
-                "%s is not a classifier or regressor. Please try again." % (k)
+                f"{k} is not a classifier or regressor. Please try again."
             )
             test_passed = False
     elif k == "clf" or k == "binary_clf":
         if not (sklearn.base.is_classifier(v)):
-            wandb.termerror("%s is not a classifier. Please try again." % (k))
+            wandb.termerror(f"{k} is not a classifier. Please try again.")
             test_passed = False
     elif k == "regressor":
         if not sklearn.base.is_regressor(v):
-            wandb.termerror("%s is not a regressor. Please try again." % (k))
+            wandb.termerror(f"{k} is not a regressor. Please try again.")
             test_passed = False
     elif k == "clusterer":
         if not (getattr(v, "_estimator_type", None) == "clusterer"):
-            wandb.termerror("%s is not a clusterer. Please try again." % (k))
+            wandb.termerror(f"{k} is not a clusterer. Please try again.")
             test_passed = False
     return test_passed
 
@@ -129,7 +129,7 @@ def test_missing(**kwargs):
     for k, v in kwargs.items():
         # Missing/empty params/datapoint arrays
         if v is None:
-            wandb.termerror("%s is None. Please try again." % (k))
+            wandb.termerror(f"{k} is None. Please try again.")
             test_passed = False
         if (k == "X") or (k == "X_test"):
             if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -168,8 +168,8 @@ def test_missing(**kwargs):
             )
             if non_nums > 0:
                 wandb.termerror(
-                    "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
-                    % (k, k)
+                    f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
+                    "and call the plotting function again."
                 )
                 test_passed = False
     return test_passed

wandb/{wandb_torch.py → integration/torch/wandb_torch.py} RENAMED
@@ -1,11 +1,9 @@
-#!/usr/bin/env python
-
-"""PyTorch-specific functionality"""
+"""PyTorch-specific functionality."""
 
 import itertools
 from functools import reduce
 from operator import mul
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import wandb
 from wandb import util
@@ -13,13 +11,18 @@ from wandb.data_types import Node
 
 torch = None
 
+if TYPE_CHECKING:
+    from torch import Tensor
+    from torch.nn import Module
+
 
 def nested_shape(array_or_tuple, seen=None):
-    """Figure out the shape of tensors possibly embedded in tuples
-    i.e
-    [0,0] returns (2)
-    ([0,0], [0,0]) returns (2,2)
-    (([0,0], [0,0]),[0,0]) returns ((2,2),2)
+    """Figure out the shape of tensors possibly embedded in tuples.
+
+    for example:
+    - [0,0] returns (2)
+    - ([0,0], [0,0]) returns (2,2)
+    - (([0,0], [0,0]),[0,0]) returns ((2,2),2).
     """
     if seen is None:
         seen = set()
@@ -49,14 +52,14 @@ LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2)
 
 
 def log_track_init(log_freq: int) -> List[int]:
-    """create tracking structure used by log_track_update"""
-    l = [0] * 2
-    l[LOG_TRACK_THRESHOLD] = log_freq
-    return l
+    """Create tracking structure used by log_track_update."""
+    log_track = [0, 0]
+    log_track[LOG_TRACK_THRESHOLD] = log_freq
+    return log_track
 
 
 def log_track_update(log_track: int) -> bool:
-    """count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached"""
+    """Count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached."""
     log_track[LOG_TRACK_COUNT] += 1
     if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]:
         return False
@@ -65,7 +68,7 @@ def log_track_update(log_track: int) -> bool:
 
 
 class TorchHistory:
-    """History methods specific to PyTorch"""
+    """History methods specific to PyTorch."""
 
     def __init__(self):
         global torch
@@ -77,14 +80,15 @@ class TorchHistory:
 
     def add_log_parameters_hook(
         self,
-        module: "torch.nn.Module",
+        module: "Module",
         name: str = "",
         prefix: str = "",
        log_freq: int = 0,
     ) -> None:
-        """This instruments hooks into the pytorch module
+        """This instruments hooks into the pytorch module.
+
         log parameters after a forward pass
-        log_freq - log gradients/parameters every N batches
+        log_freq - log gradients/parameters every N batches.
         """
         # if name is not None:
         prefix = prefix + name
@@ -119,16 +123,19 @@ class TorchHistory:
 
     def add_log_gradients_hook(
         self,
-        module: "torch.nn.Module",
+        module: "Module",
         name: str = "",
        prefix: str = "",
         log_freq: int = 0,
     ) -> None:
-        """This instruments hooks into the pytorch module
-        log gradients after a backward pass
-        log_freq - log gradients/parameters every N batches
-        """
+        """This instruments hooks into the pytorch module to log gradients after a backward pass.
 
+        Args:
+            module: torch.nn.Module - the module to instrument
+            name: str - the name of the module
+            prefix: str - the prefix to add to the name
+            log_freq: log gradients/parameters every N batches
+        """
         # if name is not None:
         prefix = prefix + name
 
@@ -143,8 +150,8 @@ class TorchHistory:
                 parameter, "gradients/" + prefix + name, log_track_grad
             )
 
-    def log_tensor_stats(self, tensor, name):
-        """Add distribution statistics on a tensor's elements to the current History entry"""
+    def log_tensor_stats(self, tensor, name):  # noqa: C901
+        """Add distribution statistics on a tensor's elements to the current History entry."""
         # TODO Handle the case of duplicate names.
         if isinstance(tensor, (tuple, list)):
             while isinstance(tensor, (tuple, list)) and isinstance(
@@ -250,9 +257,7 @@ class TorchHistory:
             )
 
     def _hook_variable_gradient_stats(self, var, name, log_track):
-        """Logs a Variable's gradient's distribution statistics next time backward()
-        is called on it.
-        """
+        """Logs a Variable's gradient's distribution statistics next time backward() is called on it."""
         if not isinstance(var, torch.autograd.Variable):
             cls = type(var)
             raise TypeError(
@@ -288,10 +293,10 @@ class TorchHistory:
         else:
             return handle.id in d
 
-    def _no_finite_values(self, tensor: "torch.Tensor") -> bool:
+    def _no_finite_values(self, tensor: "Tensor") -> bool:
         return tensor.shape == torch.Size([0]) or (~torch.isfinite(tensor)).all().item()
 
-    def _remove_infs_nans(self, tensor: "torch.Tensor") -> "torch.Tensor":
+    def _remove_infs_nans(self, tensor: "Tensor") -> "Tensor":
         if not torch.isfinite(tensor).all():
             tensor = tensor[torch.isfinite(tensor)]
 
@@ -420,8 +425,7 @@ class TorchGraph(wandb.data_types.Graph):
 
     @classmethod
     def from_torch_layers(cls, module_graph, variable):
-        """Recover something like neural net layers from PyTorch Module's and the
-        compute graph from a Variable.
+        """Recover something like neural net layers from PyTorch Module's and the compute graph from a Variable.
 
         Example output for a multi-layer RNN. We confusingly assign shared embedding values
         to the encoder, but ordered next to the decoder.
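
The log_track helpers above are small enough to exercise directly. A sketch, assuming they are imported from the module's new location and that the counter resets once the threshold is reached (as the truncated hunk implies), so updates fire on every log_freq-th call:

```
from wandb.integration.torch.wandb_torch import log_track_init, log_track_update

log_track = log_track_init(log_freq=3)  # -> [0, 3]: [count, threshold]
fired = [log_track_update(log_track) for _ in range(7)]
print(fired)  # expected: [False, False, True, False, False, True, False]
```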

wandb/proto/v3/wandb_base_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cwandb/proto/wandb_base.proto\x12\x0ewandb_internal\"6\n\x0b_RecordInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\t\"!\n\x0c_RequestInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\"#\n\x0b_ResultInfo\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\tb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cwandb/proto/wandb_base.proto\x12\x0ewandb_internal\"6\n\x0b_RecordInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\t\"!\n\x0c_RequestInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\"#\n\x0b_ResultInfo\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\tB\x1bZ\x19\x63ore/pkg/service_go_protob\x06proto3')
 
 
 
@@ -45,6 +45,7 @@ _sym_db.RegisterMessage(_ResultInfo)
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
+  DESCRIPTOR._serialized_options = b'Z\031core/pkg/service_go_proto'
   __RECORDINFO._serialized_start=48
   __RECORDINFO._serialized_end=102
   __REQUESTINFO._serialized_start=104
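
The change to the regenerated descriptor is the tail appended before the trailing `b\x06proto3`: `B\x1bZ\x19core/pkg/service_go_proto` is field 8 (options) of the FileDescriptorProto, containing a FileOptions message whose field 11 (go_package) is set. In .proto terms it corresponds to adding `option go_package = "core/pkg/service_go_proto";` to the source files. A sketch decoding the bytes with the protobuf runtime:

```
from google.protobuf import descriptor_pb2

# b'Z\x19...' is the FileOptions payload shown in _serialized_options above:
# tag 0x5A = field 11 (go_package), length 0x19 = 25 bytes.
opts = descriptor_pb2.FileOptions()
opts.ParseFromString(b'Z\x19core/pkg/service_go_proto')
print(opts.go_package)  # -> core/pkg/service_go_proto
```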