wandb 0.18.0rc1__py3-none-win_amd64.whl → 0.18.1__py3-none-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (62) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/wandb-core +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/METADATA +5 -5
  53. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
wandb/env.py CHANGED
@@ -62,7 +62,6 @@ SAVE_CODE = "WANDB_SAVE_CODE"
62
62
  TAGS = "WANDB_TAGS"
63
63
  IGNORE = "WANDB_IGNORE_GLOBS"
64
64
  ERROR_REPORTING = "WANDB_ERROR_REPORTING"
65
- CORE_ERROR_REPORTING = "WANDB_CORE_ERROR_REPORTING"
66
65
  CORE_DEBUG = "WANDB_CORE_DEBUG"
67
66
  DOCKER = "WANDB_DOCKER"
68
67
  AGENT_REPORT_INTERVAL = "WANDB_AGENT_REPORT_INTERVAL"
@@ -172,10 +171,6 @@ def error_reporting_enabled() -> bool:
172
171
  return _env_as_bool(ERROR_REPORTING, default="True")
173
172
 
174
173
 
175
- def core_error_reporting_enabled(default: Optional[str] = None) -> bool:
176
- return _env_as_bool(CORE_ERROR_REPORTING, default=default)
177
-
178
-
179
174
  def core_debug(default: Optional[str] = None) -> bool:
180
175
  return _env_as_bool(CORE_DEBUG, default=default)
181
176
 
@@ -7,18 +7,18 @@ from sklearn.calibration import CalibratedClassifierCV
7
7
  from sklearn.linear_model import LogisticRegression
8
8
 
9
9
  import wandb
10
- from wandb.sklearn import utils
10
+ from wandb.integration.sklearn import utils
11
11
 
12
12
  # ignore all future warnings
13
13
  simplefilter(action="ignore", category=FutureWarning)
14
14
 
15
15
 
16
- def calibration_curves(clf, X, y, clf_name):
16
+ def calibration_curves(clf, X, y, clf_name): # noqa: N803
17
17
  # ComplementNB (introduced in 0.20.0) requires non-negative features
18
18
  if int(sklearn.__version__.split(".")[1]) >= 20 and isinstance(
19
19
  clf, naive_bayes.ComplementNB
20
20
  ):
21
- X = X - X.min()
21
+ X = X - X.min() # noqa:N806
22
22
 
23
23
  # Calibrated with isotonic calibration
24
24
  isotonic = CalibratedClassifierCV(clf, cv=2, method="isotonic")
@@ -48,7 +48,7 @@ def calibration_curves(clf, X, y, clf_name):
48
48
  frac_positives_column.append(1)
49
49
  mean_pred_value_column.append(1)
50
50
 
51
- X_train, X_test, y_train, y_test = model_selection.train_test_split(
51
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
52
52
  X, y, test_size=0.9, random_state=42
53
53
  )
54
54
 
@@ -58,11 +58,11 @@ def calibration_curves(clf, X, y, clf_name):
58
58
  names = ["Logistic", f"{clf_name} Isotonic", f"{clf_name} Sigmoid"]
59
59
 
60
60
  for model, name in zip(models, names):
61
- model.fit(X_train, y_train)
61
+ model.fit(x_train, y_train)
62
62
  if hasattr(model, "predict_proba"):
63
- prob_pos = model.predict_proba(X_test)[:, 1]
63
+ prob_pos = model.predict_proba(x_test)[:, 1]
64
64
  else: # use decision function
65
- prob_pos = model.decision_function(X_test)
65
+ prob_pos = model.decision_function(x_test)
66
66
  prob_pos = (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min())
67
67
 
68
68
  hist, edges = np.histogram(prob_pos, bins=10, density=False)
@@ -4,7 +4,7 @@ import numpy as np
4
4
  from sklearn.utils.multiclass import unique_labels
5
5
 
6
6
  import wandb
7
- from wandb.sklearn import utils
7
+ from wandb.integration.sklearn import utils
8
8
 
9
9
  # ignore all future warnings
10
10
  simplefilter(action="ignore", category=FutureWarning)
@@ -6,14 +6,15 @@ from sklearn import metrics
6
6
  from sklearn.utils.multiclass import unique_labels
7
7
 
8
8
  import wandb
9
- from wandb.sklearn import utils
9
+
10
+ from .. import utils
10
11
 
11
12
  # ignore all future warnings
12
13
  simplefilter(action="ignore", category=FutureWarning)
13
14
 
14
15
 
15
16
  def validate_labels(*args, **kwargs): # FIXME
16
- assert False
17
+ raise AssertionError()
17
18
 
18
19
 
19
20
  def confusion_matrix(
@@ -11,7 +11,7 @@ import wandb
11
11
  simplefilter(action="ignore", category=FutureWarning)
12
12
 
13
13
 
14
- def elbow_curve(clusterer, X, cluster_ranges, n_jobs, show_cluster_time):
14
+ def elbow_curve(clusterer, X, cluster_ranges, n_jobs, show_cluster_time): # noqa: N803
15
15
  if cluster_ranges is None:
16
16
  cluster_ranges = range(1, 10, 2)
17
17
  else:
@@ -37,19 +37,19 @@ def make_table(cluster_ranges, clfs, times):
37
37
  return table
38
38
 
39
39
 
40
- def _compute_results_parallel(n_jobs, clusterer, X, cluster_ranges):
40
+ def _compute_results_parallel(n_jobs, clusterer, x, cluster_ranges):
41
41
  parallel_runner = Parallel(n_jobs=n_jobs)
42
42
  _cluster_scorer = delayed(_clone_and_score_clusterer)
43
- results = parallel_runner(_cluster_scorer(clusterer, X, i) for i in cluster_ranges)
43
+ results = parallel_runner(_cluster_scorer(clusterer, x, i) for i in cluster_ranges)
44
44
 
45
45
  clfs, times = zip(*results)
46
46
 
47
47
  return clfs, times
48
48
 
49
49
 
50
- def _clone_and_score_clusterer(clusterer, X, n_clusters):
50
+ def _clone_and_score_clusterer(clusterer, x, n_clusters):
51
51
  start = time.time()
52
52
  clusterer = clone(clusterer)
53
- setattr(clusterer, "n_clusters", n_clusters)
53
+ clusterer.n_clusters = n_clusters
54
54
 
55
- return clusterer.fit(X).score(X), time.time() - start
55
+ return clusterer.fit(x).score(x), time.time() - start
@@ -4,7 +4,7 @@ import numpy as np
4
4
  from sklearn import model_selection
5
5
 
6
6
  import wandb
7
- from wandb.sklearn import utils
7
+ from wandb.integration.sklearn import utils
8
8
 
9
9
  # ignore all future warnings
10
10
  simplefilter(action="ignore", category=FutureWarning)
@@ -12,7 +12,7 @@ simplefilter(action="ignore", category=FutureWarning)
12
12
 
13
13
  def learning_curve(
14
14
  model,
15
- X,
15
+ X, # noqa: N803
16
16
  y,
17
17
  cv=None,
18
18
  shuffle=False,
@@ -3,13 +3,13 @@ from warnings import simplefilter
3
3
  import numpy as np
4
4
 
5
5
  import wandb
6
- from wandb.sklearn import utils
6
+ from wandb.integration.sklearn import utils
7
7
 
8
8
  # ignore all future warnings
9
9
  simplefilter(action="ignore", category=FutureWarning)
10
10
 
11
11
 
12
- def outlier_candidates(regressor, X, y):
12
+ def outlier_candidates(regressor, X, y): # noqa: N803
13
13
  # Fit a linear model to X and y to compute MSE
14
14
  regressor.fit(X, y)
15
15
 
@@ -3,27 +3,27 @@ from warnings import simplefilter
3
3
  from sklearn import model_selection
4
4
 
5
5
  import wandb
6
- from wandb.sklearn import utils
6
+ from wandb.integration.sklearn import utils
7
7
 
8
8
  # ignore all future warnings
9
9
  simplefilter(action="ignore", category=FutureWarning)
10
10
 
11
11
 
12
- def residuals(regressor, X, y):
12
+ def residuals(regressor, X, y): # noqa: N803
13
13
  # Create the train and test splits
14
- X_train, X_test, y_train, y_test = model_selection.train_test_split(
14
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
15
15
  X, y, test_size=0.2
16
16
  )
17
17
 
18
18
  # Store labels and colors for the legend ordered by call
19
- regressor.fit(X_train, y_train)
20
- train_score_ = regressor.score(X_train, y_train)
21
- test_score_ = regressor.score(X_test, y_test)
19
+ regressor.fit(x_train, y_train)
20
+ train_score_ = regressor.score(x_train, y_train)
21
+ test_score_ = regressor.score(x_test, y_test)
22
22
 
23
- y_pred_train = regressor.predict(X_train)
23
+ y_pred_train = regressor.predict(x_train)
24
24
  residuals_train = y_pred_train - y_train
25
25
 
26
- y_pred_test = regressor.predict(X_test)
26
+ y_pred_test = regressor.predict(x_test)
27
27
  residuals_test = y_pred_test - y_test
28
28
 
29
29
  table = make_table(
@@ -5,13 +5,13 @@ from sklearn.metrics import silhouette_samples, silhouette_score
5
5
  from sklearn.preprocessing import LabelEncoder
6
6
 
7
7
  import wandb
8
- from wandb.sklearn import utils
8
+ from wandb.integration.sklearn import utils
9
9
 
10
10
  # ignore all future warnings
11
11
  simplefilter(action="ignore", category=FutureWarning)
12
12
 
13
13
 
14
- def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans):
14
+ def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans): # noqa: N803
15
15
  # Run clusterer for n_clusters in range(len(cluster_ranges), get cluster labels
16
16
  # TODO - keep/delete once we decide if we should train clusterers
17
17
  # or ask for trained models
@@ -4,13 +4,13 @@ import numpy as np
4
4
  import sklearn
5
5
 
6
6
  import wandb
7
- from wandb.sklearn import utils
7
+ from wandb.integration.sklearn import utils
8
8
 
9
9
  # ignore all future warnings
10
10
  simplefilter(action="ignore", category=FutureWarning)
11
11
 
12
12
 
13
- def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
13
+ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803
14
14
  """Calculate summary metrics for both regressors and classifiers.
15
15
 
16
16
  Called by plot_summary_metrics to visualize metrics. Please use the function
@@ -7,7 +7,7 @@ from sklearn import naive_bayes
7
7
 
8
8
  import wandb
9
9
  import wandb.plot
10
- from wandb.sklearn import calculate, utils
10
+ from wandb.integration.sklearn import calculate, utils
11
11
 
12
12
  from . import shared
13
13
 
@@ -17,8 +17,8 @@ simplefilter(action="ignore", category=FutureWarning)
17
17
 
18
18
  def classifier(
19
19
  model,
20
- X_train,
21
- X_test,
20
+ X_train, # noqa: N803
21
+ X_test, # noqa: N803
22
22
  y_train,
23
23
  y_test,
24
24
  y_pred,
@@ -77,7 +77,7 @@ def classifier(
77
77
  )
78
78
  ```
79
79
  """
80
- wandb.termlog("\nPlotting %s." % model_name)
80
+ wandb.termlog(f"\nPlotting {model_name}.")
81
81
 
82
82
  if not isinstance(model, naive_bayes.MultinomialNB):
83
83
  feature_importances(model, feature_names)
@@ -280,7 +280,7 @@ def class_proportions(y_train=None, y_test=None, labels=None):
280
280
  wandb.log({"class_proportions": class_proportions_chart})
281
281
 
282
282
 
283
- def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"):
283
+ def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"): # noqa: N803
284
284
  """Log a plot depicting how well-calibrated the predicted probabilities of a classifier are.
285
285
 
286
286
  Also suggests how to calibrate an uncalibrated classifier. Compares estimated predicted
@@ -6,13 +6,13 @@ import pandas as pd
6
6
  import sklearn
7
7
 
8
8
  import wandb
9
- from wandb.sklearn import calculate, utils
9
+ from wandb.integration.sklearn import calculate, utils
10
10
 
11
11
  # ignore all future warnings
12
12
  simplefilter(action="ignore", category=FutureWarning)
13
13
 
14
14
 
15
- def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer"):
15
+ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer"): # noqa: N803
16
16
  """Generates all sklearn clusterer plots supported by W&B.
17
17
 
18
18
  The following plots are generated:
@@ -40,7 +40,7 @@ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer
40
40
  wandb.sklearn.plot_clusterer(kmeans, X, cluster_labels, labels, "KMeans")
41
41
  ```
42
42
  """
43
- wandb.termlog("\nPlotting %s." % model_name)
43
+ wandb.termlog(f"\nPlotting {model_name}.")
44
44
  if isinstance(model, sklearn.cluster.KMeans):
45
45
  elbow_curve(model, X_train)
46
46
  wandb.termlog("Logged elbow curve.")
@@ -54,7 +54,11 @@ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer
54
54
 
55
55
 
56
56
  def elbow_curve(
57
- clusterer=None, X=None, cluster_ranges=None, n_jobs=1, show_cluster_time=True
57
+ clusterer=None,
58
+ X=None, # noqa: N803
59
+ cluster_ranges=None,
60
+ n_jobs=1,
61
+ show_cluster_time=True,
58
62
  ):
59
63
  """Measures and plots variance explained as a function of the number of clusters.
60
64
 
@@ -97,7 +101,7 @@ def elbow_curve(
97
101
 
98
102
  def silhouette(
99
103
  clusterer=None,
100
- X=None,
104
+ X=None, # noqa: N803
101
105
  cluster_labels=None,
102
106
  labels=None,
103
107
  metric="euclidean",
@@ -135,7 +139,7 @@ def silhouette(
135
139
 
136
140
  if not_missing and correct_types and is_fitted:
137
141
  if isinstance(X, (pd.DataFrame)):
138
- X = X.values
142
+ X = X.values # noqa: N806
139
143
  silhouette_chart = calculate.silhouette(
140
144
  clusterer, X, cluster_labels, labels, metric, kmeans
141
145
  )
@@ -5,7 +5,7 @@ from warnings import simplefilter
5
5
  import numpy as np
6
6
 
7
7
  import wandb
8
- from wandb.sklearn import calculate, utils
8
+ from wandb.integration.sklearn import calculate, utils
9
9
 
10
10
  from . import shared
11
11
 
@@ -13,7 +13,7 @@ from . import shared
13
13
  simplefilter(action="ignore", category=FutureWarning)
14
14
 
15
15
 
16
- def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
16
+ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"): # noqa: N803
17
17
  """Generates all sklearn regressor plots supported by W&B.
18
18
 
19
19
  The following plots are generated:
@@ -38,7 +38,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
38
38
  wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test, "Ridge")
39
39
  ```
40
40
  """
41
- wandb.termlog("\nPlotting %s." % model_name)
41
+ wandb.termlog(f"\nPlotting {model_name}.")
42
42
 
43
43
  shared.summary_metrics(model, X_train, y_train, X_test, y_test)
44
44
  wandb.termlog("Logged summary metrics.")
@@ -53,7 +53,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
53
53
  wandb.termlog("Logged residuals.")
54
54
 
55
55
 
56
- def outlier_candidates(regressor=None, X=None, y=None):
56
+ def outlier_candidates(regressor=None, X=None, y=None): # noqa: N803
57
57
  """Measures a datapoint's influence on regression model via cook's distance.
58
58
 
59
59
  Instances with high influences could potentially be outliers.
@@ -87,7 +87,7 @@ def outlier_candidates(regressor=None, X=None, y=None):
87
87
  wandb.log({"outlier_candidates": outliers_chart})
88
88
 
89
89
 
90
- def residuals(regressor=None, X=None, y=None):
90
+ def residuals(regressor=None, X=None, y=None): # noqa: N803
91
91
  """Measures and plots the regressor's predicted value against the residual.
92
92
 
93
93
  The marginal distribution of residuals is also calculated and plotted.
@@ -5,13 +5,13 @@ from warnings import simplefilter
5
5
  import numpy as np
6
6
 
7
7
  import wandb
8
- from wandb.sklearn import calculate, utils
8
+ from wandb.integration.sklearn import calculate, utils
9
9
 
10
10
  # ignore all future warnings
11
11
  simplefilter(action="ignore", category=FutureWarning)
12
12
 
13
13
 
14
- def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
14
+ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803
15
15
  """Logs a chart depicting summary metrics for a model.
16
16
 
17
17
  Should only be called with a fitted model (otherwise an error is thrown).
@@ -47,7 +47,7 @@ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):
47
47
 
48
48
  def learning_curve(
49
49
  model=None,
50
- X=None,
50
+ X=None, # noqa: N803
51
51
  y=None,
52
52
  cv=None,
53
53
  shuffle=False,
@@ -61,7 +61,7 @@ def test_types(**kwargs):
61
61
  list,
62
62
  ),
63
63
  ):
64
- wandb.termerror("%s is not an array. Please try again." % (k))
64
+ wandb.termerror(f"{k} is not an array. Please try again.")
65
65
  test_passed = False
66
66
  # check for classifier types
67
67
  if k == "model":
@@ -69,20 +69,20 @@ def test_types(**kwargs):
69
69
  not sklearn.base.is_regressor(v)
70
70
  ):
71
71
  wandb.termerror(
72
- "%s is not a classifier or regressor. Please try again." % (k)
72
+ f"{k} is not a classifier or regressor. Please try again."
73
73
  )
74
74
  test_passed = False
75
75
  elif k == "clf" or k == "binary_clf":
76
76
  if not (sklearn.base.is_classifier(v)):
77
- wandb.termerror("%s is not a classifier. Please try again." % (k))
77
+ wandb.termerror(f"{k} is not a classifier. Please try again.")
78
78
  test_passed = False
79
79
  elif k == "regressor":
80
80
  if not sklearn.base.is_regressor(v):
81
- wandb.termerror("%s is not a regressor. Please try again." % (k))
81
+ wandb.termerror(f"{k} is not a regressor. Please try again.")
82
82
  test_passed = False
83
83
  elif k == "clusterer":
84
84
  if not (getattr(v, "_estimator_type", None) == "clusterer"):
85
- wandb.termerror("%s is not a clusterer. Please try again." % (k))
85
+ wandb.termerror(f"{k} is not a clusterer. Please try again.")
86
86
  test_passed = False
87
87
  return test_passed
88
88
 
@@ -129,7 +129,7 @@ def test_missing(**kwargs):
129
129
  for k, v in kwargs.items():
130
130
  # Missing/empty params/datapoint arrays
131
131
  if v is None:
132
- wandb.termerror("%s is None. Please try again." % (k))
132
+ wandb.termerror(f"{k} is None. Please try again.")
133
133
  test_passed = False
134
134
  if (k == "X") or (k == "X_test"):
135
135
  if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -168,8 +168,8 @@ def test_missing(**kwargs):
168
168
  )
169
169
  if non_nums > 0:
170
170
  wandb.termerror(
171
- "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
172
- % (k, k)
171
+ f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
172
+ "and call the plotting function again."
173
173
  )
174
174
  test_passed = False
175
175
  return test_passed
@@ -1,11 +1,9 @@
1
- #!/usr/bin/env python
2
-
3
- """PyTorch-specific functionality"""
1
+ """PyTorch-specific functionality."""
4
2
 
5
3
  import itertools
6
4
  from functools import reduce
7
5
  from operator import mul
8
- from typing import List
6
+ from typing import TYPE_CHECKING, List
9
7
 
10
8
  import wandb
11
9
  from wandb import util
@@ -13,13 +11,18 @@ from wandb.data_types import Node
13
11
 
14
12
  torch = None
15
13
 
14
+ if TYPE_CHECKING:
15
+ from torch import Tensor
16
+ from torch.nn import Module
17
+
16
18
 
17
19
  def nested_shape(array_or_tuple, seen=None):
18
- """Figure out the shape of tensors possibly embedded in tuples
19
- i.e
20
- [0,0] returns (2)
21
- ([0,0], [0,0]) returns (2,2)
22
- (([0,0], [0,0]),[0,0]) returns ((2,2),2)
20
+ """Figure out the shape of tensors possibly embedded in tuples.
21
+
22
+ For example:
23
+ - [0,0] returns (2)
24
+ - ([0,0], [0,0]) returns (2,2)
25
+ - (([0,0], [0,0]),[0,0]) returns ((2,2),2).
23
26
  """
24
27
  if seen is None:
25
28
  seen = set()
@@ -49,14 +52,14 @@ LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2)
49
52
 
50
53
 
51
54
  def log_track_init(log_freq: int) -> List[int]:
52
- """create tracking structure used by log_track_update"""
53
- l = [0] * 2
54
- l[LOG_TRACK_THRESHOLD] = log_freq
55
- return l
55
+ """Create tracking structure used by log_track_update."""
56
+ log_track = [0, 0]
57
+ log_track[LOG_TRACK_THRESHOLD] = log_freq
58
+ return log_track
56
59
 
57
60
 
58
61
  def log_track_update(log_track: int) -> bool:
59
- """count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached"""
62
+ """Count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached."""
60
63
  log_track[LOG_TRACK_COUNT] += 1
61
64
  if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]:
62
65
  return False
@@ -65,7 +68,7 @@ def log_track_update(log_track: int) -> bool:
65
68
 
66
69
 
67
70
  class TorchHistory:
68
- """History methods specific to PyTorch"""
71
+ """History methods specific to PyTorch."""
69
72
 
70
73
  def __init__(self):
71
74
  global torch
@@ -77,14 +80,15 @@ class TorchHistory:
77
80
 
78
81
  def add_log_parameters_hook(
79
82
  self,
80
- module: "torch.nn.Module",
83
+ module: "Module",
81
84
  name: str = "",
82
85
  prefix: str = "",
83
86
  log_freq: int = 0,
84
87
  ) -> None:
85
- """This instruments hooks into the pytorch module
88
+ """This instruments hooks into the pytorch module.
89
+
86
90
  log parameters after a forward pass
87
- log_freq - log gradients/parameters every N batches
91
+ log_freq - log gradients/parameters every N batches.
88
92
  """
89
93
  # if name is not None:
90
94
  prefix = prefix + name
@@ -119,16 +123,19 @@ class TorchHistory:
119
123
 
120
124
  def add_log_gradients_hook(
121
125
  self,
122
- module: "torch.nn.Module",
126
+ module: "Module",
123
127
  name: str = "",
124
128
  prefix: str = "",
125
129
  log_freq: int = 0,
126
130
  ) -> None:
127
- """This instruments hooks into the pytorch module
128
- log gradients after a backward pass
129
- log_freq - log gradients/parameters every N batches
130
- """
131
+ """This instruments hooks into the PyTorch module to log gradients after a backward pass.
131
132
 
133
+ Args:
134
+ module: torch.nn.Module - the module to instrument
135
+ name: str - the name of the module
136
+ prefix: str - the prefix to add to the name
137
+ log_freq: log gradients/parameters every N batches
138
+ """
132
139
  # if name is not None:
133
140
  prefix = prefix + name
134
141
 
@@ -143,8 +150,8 @@ class TorchHistory:
143
150
  parameter, "gradients/" + prefix + name, log_track_grad
144
151
  )
145
152
 
146
- def log_tensor_stats(self, tensor, name):
147
- """Add distribution statistics on a tensor's elements to the current History entry"""
153
+ def log_tensor_stats(self, tensor, name): # noqa: C901
154
+ """Add distribution statistics on a tensor's elements to the current History entry."""
148
155
  # TODO Handle the case of duplicate names.
149
156
  if isinstance(tensor, (tuple, list)):
150
157
  while isinstance(tensor, (tuple, list)) and isinstance(
@@ -250,9 +257,7 @@ class TorchHistory:
250
257
  )
251
258
 
252
259
  def _hook_variable_gradient_stats(self, var, name, log_track):
253
- """Logs a Variable's gradient's distribution statistics next time backward()
254
- is called on it.
255
- """
260
+ """Logs a Variable's gradient's distribution statistics next time backward() is called on it."""
256
261
  if not isinstance(var, torch.autograd.Variable):
257
262
  cls = type(var)
258
263
  raise TypeError(
@@ -288,10 +293,10 @@ class TorchHistory:
288
293
  else:
289
294
  return handle.id in d
290
295
 
291
- def _no_finite_values(self, tensor: "torch.Tensor") -> bool:
296
+ def _no_finite_values(self, tensor: "Tensor") -> bool:
292
297
  return tensor.shape == torch.Size([0]) or (~torch.isfinite(tensor)).all().item()
293
298
 
294
- def _remove_infs_nans(self, tensor: "torch.Tensor") -> "torch.Tensor":
299
+ def _remove_infs_nans(self, tensor: "Tensor") -> "Tensor":
295
300
  if not torch.isfinite(tensor).all():
296
301
  tensor = tensor[torch.isfinite(tensor)]
297
302
 
@@ -420,8 +425,7 @@ class TorchGraph(wandb.data_types.Graph):
420
425
 
421
426
  @classmethod
422
427
  def from_torch_layers(cls, module_graph, variable):
423
- """Recover something like neural net layers from PyTorch Module's and the
424
- compute graph from a Variable.
428
+ """Recover something like neural net layers from PyTorch Module's and the compute graph from a Variable.
425
429
 
426
430
  Example output for a multi-layer RNN. We confusingly assign shared embedding values
427
431
  to the encoder, but ordered next to the decoder.
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cwandb/proto/wandb_base.proto\x12\x0ewandb_internal\"6\n\x0b_RecordInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\t\"!\n\x0c_RequestInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\"#\n\x0b_ResultInfo\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\tb\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cwandb/proto/wandb_base.proto\x12\x0ewandb_internal\"6\n\x0b_RecordInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\t\"!\n\x0c_RequestInfo\x12\x11\n\tstream_id\x18\x01 \x01(\t\"#\n\x0b_ResultInfo\x12\x14\n\x0c_tracelog_id\x18\x64 \x01(\tB\x1bZ\x19\x63ore/pkg/service_go_protob\x06proto3')
18
18
 
19
19
 
20
20
 
@@ -45,6 +45,7 @@ _sym_db.RegisterMessage(_ResultInfo)
45
45
  if _descriptor._USE_C_DESCRIPTORS == False:
46
46
 
47
47
  DESCRIPTOR._options = None
48
+ DESCRIPTOR._serialized_options = b'Z\031core/pkg/service_go_proto'
48
49
  __RECORDINFO._serialized_start=48
49
50
  __RECORDINFO._serialized_end=102
50
51
  __REQUESTINFO._serialized_start=104