wandb 0.18.0rc1__py3-none-win32.whl → 0.18.1__py3-none-win32.whl

Files changed (62)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/wandb-core +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/METADATA +5 -5
  53. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0rc1.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
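Note (not part of the upstream release notes): the file list above shows two refactors in this release: the scikit-learn integration moves from wandb/sklearn/ to wandb/integration/sklearn/ (with a new 35-line wandb/sklearn.py, presumably a backward-compatibility shim), and the monolithic wandb/data_types.py is split into modules under wandb/sdk/data_types/ (audio.py, bokeh.py, graph.py, image.py, table.py). The snippet below is a sketch of the resulting import paths, inferred from that list and from the re-exports visible in the data_types.py diff; it is an illustration, not verified API.

```
# Illustrative only: import paths after the 0.18.1 refactor, inferred from the
# file list above. The shim behavior of the new wandb/sklearn.py is an assumption.
from wandb.data_types import Table, Audio, Graph          # still re-exported at the old location
from wandb.sdk.data_types.table import PartitionedTable   # implementation's new home
import wandb.integration.sklearn                          # moved from wandb/sklearn/
```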
wandb/data_types.py CHANGED
@@ -13,30 +13,10 @@ serialize to JSON, since that is what wandb uses to save the objects locally
  and upload them to the W&B server.
  """

- import base64
- import binascii
- import codecs
- import datetime
- import hashlib
- import json
- import logging
- import os
- import pprint
- from decimal import Decimal
- from typing import Optional
-
- import wandb
- from wandb import util
- from wandb.sdk.lib import filesystem
-
- from .sdk.data_types import _dtypes
- from .sdk.data_types._private import MEDIA_TMP
- from .sdk.data_types.base_types.media import (
-     BatchableMedia,
-     Media,
-     _numpy_arrays_to_lists,
- )
+ from .sdk.data_types.audio import Audio
  from .sdk.data_types.base_types.wb_value import WBValue
+ from .sdk.data_types.bokeh import Bokeh
+ from .sdk.data_types.graph import Graph, Node
  from .sdk.data_types.helper_types.bounding_boxes_2d import BoundingBoxes2D
  from .sdk.data_types.helper_types.classes import Classes
  from .sdk.data_types.helper_types.image_mask import ImageMask
@@ -47,9 +27,9 @@ from .sdk.data_types.molecule import Molecule
  from .sdk.data_types.object_3d import Object3D, box3d
  from .sdk.data_types.plotly import Plotly
  from .sdk.data_types.saved_model import _SavedModel
+ from .sdk.data_types.table import JoinedTable, PartitionedTable, Table
  from .sdk.data_types.trace_tree import WBTraceTree
  from .sdk.data_types.video import Video
- from .sdk.lib import runid

  # Note: we are importing everything from the sdk/data_types to maintain a namespace for now.
  # Once we fully type this file and move it all into sdk, then we will need to clean up the
@@ -59,7 +39,11 @@ __all__ = [
  # Untyped Exports
  "Audio",
  "Table",
+ "JoinedTable",
+ "PartitionedTable",
  "Bokeh",
+ "Node",
+ "Graph",
  # Typed Exports
  "Histogram",
  "Html",
@@ -71,2003 +55,9 @@ __all__ = [
  "Video",
  "WBTraceTree",
  "_SavedModel",
+ "WBValue",
  # Typed Legacy Exports (I'd like to remove these)
  "ImageMask",
  "BoundingBoxes2D",
  "Classes",
  ]
-
-
- class _TableLinkMixin:
82
- def set_table(self, table):
83
- self._table = table
84
-
85
-
86
- class _TableKey(str, _TableLinkMixin):
87
- def set_table(self, table, col_name):
88
- assert col_name in table.columns
89
- self._table = table
90
- self._col_name = col_name
91
-
92
-
93
- class _TableIndex(int, _TableLinkMixin):
94
- def get_row(self):
95
- row = {}
96
- if self._table:
97
- row = {
98
- c: self._table.data[self][i] for i, c in enumerate(self._table.columns)
99
- }
100
-
101
- return row
102
-
103
-
104
- def _json_helper(val, artifact):
105
- if isinstance(val, WBValue):
106
- return val.to_json(artifact)
107
- elif val.__class__ is dict:
108
- res = {}
109
- for key in val:
110
- res[key] = _json_helper(val[key], artifact)
111
- return res
112
-
113
- if hasattr(val, "tolist"):
114
- py_val = val.tolist()
115
- if val.__class__.__name__ == "datetime64" and isinstance(py_val, int):
116
- # when numpy datetime64 .tolist() returns an int, it is nanoseconds.
117
- # need to convert to milliseconds
118
- return _json_helper(py_val / int(1e6), artifact)
119
- return _json_helper(py_val, artifact)
120
- elif hasattr(val, "item"):
121
- return _json_helper(val.item(), artifact)
122
-
123
- if isinstance(val, datetime.datetime):
124
- if val.tzinfo is None:
125
- val = datetime.datetime(
126
- val.year,
127
- val.month,
128
- val.day,
129
- val.hour,
130
- val.minute,
131
- val.second,
132
- val.microsecond,
133
- tzinfo=datetime.timezone.utc,
134
- )
135
- return int(val.timestamp() * 1000)
136
- elif isinstance(val, datetime.date):
137
- return int(
138
- datetime.datetime(
139
- val.year, val.month, val.day, tzinfo=datetime.timezone.utc
140
- ).timestamp()
141
- * 1000
142
- )
143
- elif isinstance(val, (list, tuple)):
144
- return [_json_helper(i, artifact) for i in val]
145
- elif isinstance(val, Decimal):
146
- return float(val)
147
- else:
148
- return util.json_friendly(val)[0]
149
-
150
-
151
- class Table(Media):
152
- """The Table class used to display and analyze tabular data.
153
-
154
- Unlike traditional spreadsheets, Tables support numerous types of data:
155
- scalar values, strings, numpy arrays, and most subclasses of `wandb.data_types.Media`.
156
- This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
157
- directly in Tables, alongside other traditional scalar values.
158
-
159
- This class is the primary class used to generate the Table Visualizer
160
- in the UI: https://docs.wandb.ai/guides/data-vis/tables.
161
-
162
- Arguments:
163
- columns: (List[str]) Names of the columns in the table.
164
- Defaults to ["Input", "Output", "Expected"].
165
- data: (List[List[any]]) 2D row-oriented array of values.
166
- dataframe: (pandas.DataFrame) DataFrame object used to create the table.
167
- When set, `data` and `columns` arguments are ignored.
168
- optional: (Union[bool,List[bool]]) Determines if `None` values are allowed. Default to True
169
- - If a singular bool value, then the optionality is enforced for all
170
- columns specified at construction time
171
- - If a list of bool values, then the optionality is applied to each
172
- column - should be the same length as `columns`
173
- applies to all columns. A list of bool values applies to each respective column.
174
- allow_mixed_types: (bool) Determines if columns are allowed to have mixed types
175
- (disables type validation). Defaults to False
176
- """
177
-
178
- MAX_ROWS = 10000
179
- MAX_ARTIFACT_ROWS = 200000
180
- _MAX_EMBEDDING_DIMENSIONS = 150
181
- _log_type = "table"
182
-
183
- def __init__(
184
- self,
185
- columns=None,
186
- data=None,
187
- rows=None,
188
- dataframe=None,
189
- dtype=None,
190
- optional=True,
191
- allow_mixed_types=False,
192
- ):
193
- """Initializes a Table object.
194
-
195
- The rows is available for legacy reasons and should not be used.
196
- The Table class uses data to mimic the Pandas API.
197
- """
198
- super().__init__()
199
- self._pk_col = None
200
- self._fk_cols = set()
201
- if allow_mixed_types:
202
- dtype = _dtypes.AnyType
203
-
204
- # This is kept for legacy reasons (tss: personally, I think we should remove this)
205
- if columns is None:
206
- columns = ["Input", "Output", "Expected"]
207
-
208
- # Explicit dataframe option
209
- if dataframe is not None:
210
- self._init_from_dataframe(dataframe, columns, optional, dtype)
211
- else:
212
- # Expected pattern
213
- if data is not None:
214
- if util.is_numpy_array(data):
215
- self._init_from_ndarray(data, columns, optional, dtype)
216
- elif util.is_pandas_data_frame(data):
217
- self._init_from_dataframe(data, columns, optional, dtype)
218
- else:
219
- self._init_from_list(data, columns, optional, dtype)
220
-
221
- # legacy
222
- elif rows is not None:
223
- self._init_from_list(rows, columns, optional, dtype)
224
-
225
- # Default empty case
226
- else:
227
- self._init_from_list([], columns, optional, dtype)
228
-
229
- @staticmethod
230
- def _assert_valid_columns(columns):
231
- valid_col_types = [str, int]
232
- assert isinstance(columns, list), "columns argument expects a `list` object"
233
- assert len(columns) == 0 or all(
234
- [type(col) in valid_col_types for col in columns]
235
- ), "columns argument expects list of strings or ints"
236
-
237
- def _init_from_list(self, data, columns, optional=True, dtype=None):
238
- assert isinstance(data, list), "data argument expects a `list` object"
239
- self.data = []
240
- self._assert_valid_columns(columns)
241
- self.columns = columns
242
- self._make_column_types(dtype, optional)
243
- for row in data:
244
- self.add_data(*row)
245
-
246
- def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
247
- assert util.is_numpy_array(
248
- ndarray
249
- ), "ndarray argument expects a `numpy.ndarray` object"
250
- self.data = []
251
- self._assert_valid_columns(columns)
252
- self.columns = columns
253
- self._make_column_types(dtype, optional)
254
- for row in ndarray:
255
- self.add_data(*row)
256
-
257
- def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
258
- assert util.is_pandas_data_frame(
259
- dataframe
260
- ), "dataframe argument expects a `pandas.core.frame.DataFrame` object"
261
- self.data = []
262
- columns = list(dataframe.columns)
263
- self._assert_valid_columns(columns)
264
- self.columns = columns
265
- self._make_column_types(dtype, optional)
266
- for row in range(len(dataframe)):
267
- self.add_data(*tuple(dataframe[col].values[row] for col in self.columns))
268
-
269
- def _make_column_types(self, dtype=None, optional=True):
270
- if dtype is None:
271
- dtype = _dtypes.UnknownType()
272
-
273
- if optional.__class__ is not list:
274
- optional = [optional for _ in range(len(self.columns))]
275
-
276
- if dtype.__class__ is not list:
277
- dtype = [dtype for _ in range(len(self.columns))]
278
-
279
- self._column_types = _dtypes.TypedDictType({})
280
- for col_name, opt, dt in zip(self.columns, optional, dtype):
281
- self.cast(col_name, dt, opt)
282
-
283
- def cast(self, col_name, dtype, optional=False):
284
- """Casts a column to a specific data type.
285
-
286
- This can be one of the normal python classes, an internal W&B type, or an
287
- example object, like an instance of wandb.Image or wandb.Classes.
288
-
289
- Arguments:
290
- col_name: (str) - The name of the column to cast.
291
- dtype: (class, wandb.wandb_sdk.interface._dtypes.Type, any) - The target dtype.
292
- optional: (bool) - If the column should allow Nones.
293
- """
294
- assert col_name in self.columns
295
-
296
- wbtype = _dtypes.TypeRegistry.type_from_dtype(dtype)
297
-
298
- if optional:
299
- wbtype = _dtypes.OptionalType(wbtype)
300
-
301
- # Cast each value in the row, raising an error if there are invalid entries.
302
- col_ndx = self.columns.index(col_name)
303
- for row in self.data:
304
- result_type = wbtype.assign(row[col_ndx])
305
- if isinstance(result_type, _dtypes.InvalidType):
306
- raise TypeError(
307
- "Existing data {}, of type {} cannot be cast to {}".format(
308
- row[col_ndx],
309
- _dtypes.TypeRegistry.type_of(row[col_ndx]),
310
- wbtype,
311
- )
312
- )
313
- wbtype = result_type
314
-
315
- # Assert valid options
316
- is_pk = isinstance(wbtype, _PrimaryKeyType)
317
- is_fk = isinstance(wbtype, _ForeignKeyType)
318
- is_fi = isinstance(wbtype, _ForeignIndexType)
319
- if is_pk or is_fk or is_fi:
320
- assert (
321
- not optional
322
- ), "Primary keys, foreign keys, and foreign indexes cannot be optional."
323
-
324
- if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
325
- raise AssertionError("Cannot set a foreign table reference to same table.")
326
-
327
- if is_pk:
328
- assert (
329
- self._pk_col is None
330
- ), "Cannot have multiple primary keys - {} is already set as the primary key.".format(
331
- self._pk_col
332
- )
333
-
334
- # Update the column type
335
- self._column_types.params["type_map"][col_name] = wbtype
336
-
337
- # Wrap the data if needed
338
- self._update_keys()
339
- return wbtype
340
-
341
- def __ne__(self, other):
342
- return not self.__eq__(other)
343
-
344
- def _eq_debug(self, other, should_assert=False):
345
- eq = isinstance(other, Table)
346
- assert not should_assert or eq, "Found type {}, expected {}".format(
347
- other.__class__, Table
348
- )
349
- eq = eq and len(self.data) == len(other.data)
350
- assert not should_assert or eq, "Found {} rows, expected {}".format(
351
- len(other.data), len(self.data)
352
- )
353
- eq = eq and self.columns == other.columns
354
- assert not should_assert or eq, "Found columns {}, expected {}".format(
355
- other.columns, self.columns
356
- )
357
- eq = eq and self._column_types == other._column_types
358
- assert (
359
- not should_assert or eq
360
- ), "Found column type {}, expected column type {}".format(
361
- other._column_types, self._column_types
362
- )
363
- if eq:
364
- for row_ndx in range(len(self.data)):
365
- for col_ndx in range(len(self.data[row_ndx])):
366
- _eq = self.data[row_ndx][col_ndx] == other.data[row_ndx][col_ndx]
367
- # equal if all are equal
368
- if util.is_numpy_array(_eq):
369
- _eq = ((_eq * -1) + 1).sum() == 0
370
- eq = eq and _eq
371
- assert (
372
- not should_assert or eq
373
- ), "Unequal data at row_ndx {} col_ndx {}: found {}, expected {}".format(
374
- row_ndx,
375
- col_ndx,
376
- other.data[row_ndx][col_ndx],
377
- self.data[row_ndx][col_ndx],
378
- )
379
- if not eq:
380
- return eq
381
- return eq
382
-
383
- def __eq__(self, other):
384
- return self._eq_debug(other)
385
-
386
- def add_row(self, *row):
387
- """Deprecated; use add_data instead."""
388
- logging.warning("add_row is deprecated, use add_data")
389
- self.add_data(*row)
390
-
391
- def add_data(self, *data):
392
- """Adds a new row of data to the table. The maximum amount of rows in a table is determined by `wandb.Table.MAX_ARTIFACT_ROWS`.
393
-
394
- The length of the data should match the length of the table column.
395
- """
396
- if len(data) != len(self.columns):
397
- raise ValueError(
398
- "This table expects {} columns: {}, found {}".format(
399
- len(self.columns), self.columns, len(data)
400
- )
401
- )
402
-
403
- # Special case to pre-emptively cast a column as a key.
404
- # Needed as String.assign(Key) is invalid
405
- for ndx, item in enumerate(data):
406
- if isinstance(item, _TableLinkMixin):
407
- self.cast(
408
- self.columns[ndx],
409
- _dtypes.TypeRegistry.type_of(item),
410
- optional=False,
411
- )
412
-
413
- # Update the table's column types
414
- result_type = self._get_updated_result_type(data)
415
- self._column_types = result_type
416
-
417
- # rows need to be mutable
418
- if isinstance(data, tuple):
419
- data = list(data)
420
- # Add the new data
421
- self.data.append(data)
422
-
423
- # Update the wrapper values if needed
424
- self._update_keys(force_last=True)
425
-
426
- def _get_updated_result_type(self, row):
427
- """Returns the updated result type based on the inputted row.
428
-
429
- Raises:
430
- TypeError: if the assignment is invalid.
431
- """
432
- incoming_row_dict = {
433
- col_key: row[ndx] for ndx, col_key in enumerate(self.columns)
434
- }
435
- current_type = self._column_types
436
- result_type = current_type.assign(incoming_row_dict)
437
- if isinstance(result_type, _dtypes.InvalidType):
438
- raise TypeError(
439
- "Data row contained incompatible types:\n{}".format(
440
- current_type.explain(incoming_row_dict)
441
- )
442
- )
443
- return result_type
444
-
445
- def _to_table_json(self, max_rows=None, warn=True):
446
- # separate this method for easier testing
447
- if max_rows is None:
448
- max_rows = Table.MAX_ROWS
449
- n_rows = len(self.data)
450
- if n_rows > max_rows and warn:
451
- if wandb.run and (
452
- wandb.run.settings.table_raise_on_max_row_limit_exceeded
453
- or wandb.run.settings.strict
454
- ):
455
- raise ValueError(
456
- f"Table row limit exceeded: table has {n_rows} rows, limit is {max_rows}. "
457
- f"To increase the maximum number of allowed rows in a wandb.Table, override "
458
- f"the limit with `wandb.Table.MAX_ARTIFACT_ROWS = X` and try again. Note: "
459
- f"this may cause slower queries in the W&B UI."
460
- )
461
- logging.warning("Truncating wandb.Table object to %i rows." % max_rows)
462
- return {"columns": self.columns, "data": self.data[:max_rows]}
463
-
464
- def bind_to_run(self, *args, **kwargs):
465
- # We set `warn=False` since Tables will now always be logged to both
466
- # files and artifacts. The file limit will never practically matter and
467
- # this code path will be ultimately removed. The 10k limit warning confuses
468
- # users given that we publicly say 200k is the limit.
469
- data = self._to_table_json(warn=False)
470
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".table.json")
471
- data = _numpy_arrays_to_lists(data)
472
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
473
- util.json_dump_safer(data, fp)
474
- self._set_file(tmp_path, is_tmp=True, extension=".table.json")
475
- super().bind_to_run(*args, **kwargs)
476
-
477
- @classmethod
478
- def get_media_subdir(cls):
479
- return os.path.join("media", "table")
480
-
481
- @classmethod
482
- def from_json(cls, json_obj, source_artifact):
483
- data = []
484
- column_types = None
485
- np_deserialized_columns = {}
486
- timestamp_column_indices = set()
487
- if json_obj.get("column_types") is not None:
488
- column_types = _dtypes.TypeRegistry.type_from_dict(
489
- json_obj["column_types"], source_artifact
490
- )
491
- for col_name in column_types.params["type_map"]:
492
- col_type = column_types.params["type_map"][col_name]
493
- ndarray_type = None
494
- if isinstance(col_type, _dtypes.NDArrayType):
495
- ndarray_type = col_type
496
- elif isinstance(col_type, _dtypes.UnionType):
497
- for t in col_type.params["allowed_types"]:
498
- if isinstance(t, _dtypes.NDArrayType):
499
- ndarray_type = t
500
- elif isinstance(t, _dtypes.TimestampType):
501
- timestamp_column_indices.add(
502
- json_obj["columns"].index(col_name)
503
- )
504
-
505
- elif isinstance(col_type, _dtypes.TimestampType):
506
- timestamp_column_indices.add(json_obj["columns"].index(col_name))
507
-
508
- if (
509
- ndarray_type is not None
510
- and ndarray_type._get_serialization_path() is not None
511
- ):
512
- serialization_path = ndarray_type._get_serialization_path()
513
- np = util.get_module(
514
- "numpy",
515
- required="Deserializing NumPy columns requires NumPy to be installed.",
516
- )
517
- deserialized = np.load(
518
- source_artifact.get_entry(serialization_path["path"]).download()
519
- )
520
- np_deserialized_columns[json_obj["columns"].index(col_name)] = (
521
- deserialized[serialization_path["key"]]
522
- )
523
- ndarray_type._clear_serialization_path()
524
-
525
- for r_ndx, row in enumerate(json_obj["data"]):
526
- row_data = []
527
- for c_ndx, item in enumerate(row):
528
- cell = item
529
- if c_ndx in timestamp_column_indices and isinstance(item, (int, float)):
530
- cell = datetime.datetime.fromtimestamp(
531
- item / 1000, tz=datetime.timezone.utc
532
- )
533
- elif c_ndx in np_deserialized_columns:
534
- cell = np_deserialized_columns[c_ndx][r_ndx]
535
- elif isinstance(item, dict) and "_type" in item:
536
- obj = WBValue.init_from_json(item, source_artifact)
537
- if obj is not None:
538
- cell = obj
539
- row_data.append(cell)
540
- data.append(row_data)
541
-
542
- # construct Table with dtypes for each column if type information exists
543
- dtypes = None
544
- if column_types is not None:
545
- dtypes = [
546
- column_types.params["type_map"][str(col)] for col in json_obj["columns"]
547
- ]
548
-
549
- new_obj = cls(columns=json_obj["columns"], data=data, dtype=dtypes)
550
-
551
- if column_types is not None:
552
- new_obj._column_types = column_types
553
-
554
- new_obj._update_keys()
555
- return new_obj
556
-
557
- def to_json(self, run_or_artifact):
558
- json_dict = super().to_json(run_or_artifact)
559
-
560
- if isinstance(run_or_artifact, wandb.wandb_sdk.wandb_run.Run):
561
- json_dict.update(
562
- {
563
- "_type": "table-file",
564
- "ncols": len(self.columns),
565
- "nrows": len(self.data),
566
- }
567
- )
568
-
569
- elif isinstance(run_or_artifact, wandb.Artifact):
570
- artifact = run_or_artifact
571
- mapped_data = []
572
- data = self._to_table_json(Table.MAX_ARTIFACT_ROWS)["data"]
573
-
574
- ndarray_col_ndxs = set()
575
- for col_ndx, col_name in enumerate(self.columns):
576
- col_type = self._column_types.params["type_map"][col_name]
577
- ndarray_type = None
578
- if isinstance(col_type, _dtypes.NDArrayType):
579
- ndarray_type = col_type
580
- elif isinstance(col_type, _dtypes.UnionType):
581
- for t in col_type.params["allowed_types"]:
582
- if isinstance(t, _dtypes.NDArrayType):
583
- ndarray_type = t
584
-
585
- # Do not serialize 1d arrays - these are likely embeddings and
586
- # will not have the same cost as higher dimensional arrays
587
- is_1d_array = (
588
- ndarray_type is not None
589
- and "shape" in ndarray_type._params
590
- and isinstance(ndarray_type._params["shape"], list)
591
- and len(ndarray_type._params["shape"]) == 1
592
- and ndarray_type._params["shape"][0]
593
- <= self._MAX_EMBEDDING_DIMENSIONS
594
- )
595
- if is_1d_array:
596
- self._column_types.params["type_map"][col_name] = _dtypes.ListType(
597
- _dtypes.NumberType, ndarray_type._params["shape"][0]
598
- )
599
- elif ndarray_type is not None:
600
- np = util.get_module(
601
- "numpy",
602
- required="Serializing NumPy requires NumPy to be installed.",
603
- )
604
- file_name = f"{str(col_name)}_{runid.generate_id()}.npz"
605
- npz_file_name = os.path.join(MEDIA_TMP.name, file_name)
606
- np.savez_compressed(
607
- npz_file_name,
608
- **{
609
- str(col_name): self.get_column(col_name, convert_to="numpy")
610
- },
611
- )
612
- entry = artifact.add_file(
613
- npz_file_name, "media/serialized_data/" + file_name, is_tmp=True
614
- )
615
- ndarray_type._set_serialization_path(entry.path, str(col_name))
616
- ndarray_col_ndxs.add(col_ndx)
617
-
618
- for row in data:
619
- mapped_row = []
620
- for ndx, v in enumerate(row):
621
- if ndx in ndarray_col_ndxs:
622
- mapped_row.append(None)
623
- else:
624
- mapped_row.append(_json_helper(v, artifact))
625
- mapped_data.append(mapped_row)
626
-
627
- json_dict.update(
628
- {
629
- "_type": Table._log_type,
630
- "columns": self.columns,
631
- "data": mapped_data,
632
- "ncols": len(self.columns),
633
- "nrows": len(mapped_data),
634
- "column_types": self._column_types.to_json(artifact),
635
- }
636
- )
637
- else:
638
- raise ValueError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
639
-
640
- return json_dict
641
-
642
- def iterrows(self):
643
- """Returns the table data by row, showing the index of the row and the relevant data.
644
-
645
- Yields:
646
- ------
647
- index : int
648
- The index of the row. Using this value in other W&B tables
649
- will automatically build a relationship between the tables
650
- row : List[any]
651
- The data of the row.
652
- """
653
- for ndx in range(len(self.data)):
654
- index = _TableIndex(ndx)
655
- index.set_table(self)
656
- yield index, self.data[ndx]
657
-
658
- def set_pk(self, col_name):
659
- # TODO: Docs
660
- assert col_name in self.columns
661
- self.cast(col_name, _PrimaryKeyType())
662
-
663
- def set_fk(self, col_name, table, table_col):
664
- # TODO: Docs
665
- assert col_name in self.columns
666
- assert col_name != self._pk_col
667
- self.cast(col_name, _ForeignKeyType(table, table_col))
668
-
669
- def _update_keys(self, force_last=False):
670
- """Updates the known key-like columns based on current column types.
671
-
672
- If the state has been updated since the last update, wraps the data
673
- appropriately in the Key classes.
674
-
675
- Arguments:
676
- force_last: (bool) Wraps the last column of data even if there
677
- are no key updates.
678
- """
679
- _pk_col = None
680
- _fk_cols = set()
681
-
682
- # Buildup the known keys from column types
683
- c_types = self._column_types.params["type_map"]
684
- for t in c_types:
685
- if isinstance(c_types[t], _PrimaryKeyType):
686
- _pk_col = t
687
- elif isinstance(c_types[t], _ForeignKeyType) or isinstance(
688
- c_types[t], _ForeignIndexType
689
- ):
690
- _fk_cols.add(t)
691
-
692
- # If there are updates to perform, safely update them
693
- has_update = _pk_col != self._pk_col or _fk_cols != self._fk_cols
694
- if has_update:
695
- # If we removed the PK
696
- if _pk_col is None and self._pk_col is not None:
697
- raise AssertionError(
698
- f"Cannot unset primary key (column {self._pk_col})"
699
- )
700
- # If there is a removed FK
701
- if len(self._fk_cols - _fk_cols) > 0:
702
- raise AssertionError(
703
- "Cannot unset foreign key. Attempted to unset ({})".format(
704
- self._fk_cols - _fk_cols
705
- )
706
- )
707
-
708
- self._pk_col = _pk_col
709
- self._fk_cols = _fk_cols
710
-
711
- # Apply updates to data only if there are update or the caller
712
- # requested the final row to be updated
713
- if has_update or force_last:
714
- self._apply_key_updates(not has_update)
715
-
716
- def _apply_key_updates(self, only_last=False):
717
- """Appropriately wraps the underlying data in special Key classes.
718
-
719
- Arguments:
720
- only_last: only apply the updates to the last row (used for performance when
721
- the caller knows that the only new data is the last row and no updates were
722
- applied to the column types)
723
- """
724
- c_types = self._column_types.params["type_map"]
725
-
726
- # Define a helper function which will wrap the data of a single row
727
- # in the appropriate class wrapper.
728
- def update_row(row_ndx):
729
- for fk_col in self._fk_cols:
730
- col_ndx = self.columns.index(fk_col)
731
-
732
- # Wrap the Foreign Keys
733
- if isinstance(c_types[fk_col], _ForeignKeyType) and not isinstance(
734
- self.data[row_ndx][col_ndx], _TableKey
735
- ):
736
- self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
737
- self.data[row_ndx][col_ndx].set_table(
738
- c_types[fk_col].params["table"],
739
- c_types[fk_col].params["col_name"],
740
- )
741
-
742
- # Wrap the Foreign Indexes
743
- elif isinstance(c_types[fk_col], _ForeignIndexType) and not isinstance(
744
- self.data[row_ndx][col_ndx], _TableIndex
745
- ):
746
- self.data[row_ndx][col_ndx] = _TableIndex(
747
- self.data[row_ndx][col_ndx]
748
- )
749
- self.data[row_ndx][col_ndx].set_table(
750
- c_types[fk_col].params["table"]
751
- )
752
-
753
- # Wrap the Primary Key
754
- if self._pk_col is not None:
755
- col_ndx = self.columns.index(self._pk_col)
756
- self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
757
- self.data[row_ndx][col_ndx].set_table(self, self._pk_col)
758
-
759
- if only_last:
760
- update_row(len(self.data) - 1)
761
- else:
762
- for row_ndx in range(len(self.data)):
763
- update_row(row_ndx)
764
-
765
- def add_column(self, name, data, optional=False):
766
- """Adds a column of data to the table.
767
-
768
- Arguments:
769
- name: (str) - the unique name of the column
770
- data: (list | np.array) - a column of homogeneous data
771
- optional: (bool) - if null-like values are permitted
772
- """
773
- assert isinstance(name, str) and name not in self.columns
774
- is_np = util.is_numpy_array(data)
775
- assert isinstance(data, list) or is_np
776
- assert isinstance(optional, bool)
777
- is_first_col = len(self.columns) == 0
778
- assert is_first_col or len(data) == len(
779
- self.data
780
- ), f"Expected length {len(self.data)}, found {len(data)}"
781
-
782
- # Add the new data
783
- for ndx in range(max(len(data), len(self.data))):
784
- if is_first_col:
785
- self.data.append([])
786
- if is_np:
787
- self.data[ndx].append(data[ndx])
788
- else:
789
- self.data[ndx].append(data[ndx])
790
- # add the column
791
- self.columns.append(name)
792
-
793
- try:
794
- self.cast(name, _dtypes.UnknownType(), optional=optional)
795
- except TypeError as err:
796
- # Undo the changes
797
- if is_first_col:
798
- self.data = []
799
- self.columns = []
800
- else:
801
- for ndx in range(len(self.data)):
802
- self.data[ndx] = self.data[ndx][:-1]
803
- self.columns = self.columns[:-1]
804
- raise err
805
-
806
- def get_column(self, name, convert_to=None):
807
- """Retrieves a column from the table and optionally converts it to a NumPy object.
808
-
809
- Arguments:
810
- name: (str) - the name of the column
811
- convert_to: (str, optional)
812
- - "numpy": will convert the underlying data to numpy object
813
- """
814
- assert name in self.columns
815
- assert convert_to is None or convert_to == "numpy"
816
- if convert_to == "numpy":
817
- np = util.get_module(
818
- "numpy", required="Converting to NumPy requires installing NumPy"
819
- )
820
- col = []
821
- col_ndx = self.columns.index(name)
822
- for row in self.data:
823
- item = row[col_ndx]
824
- if convert_to is not None and isinstance(item, WBValue):
825
- item = item.to_data_array()
826
- col.append(item)
827
- if convert_to == "numpy":
828
- col = np.array(col)
829
- return col
830
-
831
- def get_index(self):
832
- """Returns an array of row indexes for use in other tables to create links."""
833
- ndxs = []
834
- for ndx in range(len(self.data)):
835
- index = _TableIndex(ndx)
836
- index.set_table(self)
837
- ndxs.append(index)
838
- return ndxs
839
-
840
- def get_dataframe(self):
841
- """Returns a `pandas.DataFrame` of the table."""
842
- pd = util.get_module(
843
- "pandas",
844
- required="Converting to pandas.DataFrame requires installing pandas",
845
- )
846
- return pd.DataFrame.from_records(self.data, columns=self.columns)
847
-
848
- def index_ref(self, index):
849
- """Gets a reference of the index of a row in the table."""
850
- assert index < len(self.data)
851
- _index = _TableIndex(index)
852
- _index.set_table(self)
853
- return _index
854
-
855
- def add_computed_columns(self, fn):
856
- """Adds one or more computed columns based on existing data.
857
-
858
- Args:
859
- fn: A function which accepts one or two parameters, ndx (int) and row (dict),
860
- which is expected to return a dict representing new columns for that row, keyed
861
- by the new column names.
862
-
863
- `ndx` is an integer representing the index of the row. Only included if `include_ndx`
864
- is set to `True`.
865
-
866
- `row` is a dictionary keyed by existing columns
867
- """
868
- new_columns = {}
869
- for ndx, row in self.iterrows():
870
- row_dict = {self.columns[i]: row[i] for i in range(len(self.columns))}
871
- new_row_dict = fn(ndx, row_dict)
872
- assert isinstance(new_row_dict, dict)
873
- for key in new_row_dict:
874
- new_columns[key] = new_columns.get(key, [])
875
- new_columns[key].append(new_row_dict[key])
876
- for new_col_name in new_columns:
877
- self.add_column(new_col_name, new_columns[new_col_name])
878
-
879
-
880
- class _PartitionTablePartEntry:
881
- """Helper class for PartitionTable to track its parts."""
882
-
883
- def __init__(self, entry, source_artifact):
884
- self.entry = entry
885
- self.source_artifact = source_artifact
886
- self._part = None
887
-
888
- def get_part(self):
889
- if self._part is None:
890
- self._part = self.source_artifact.get(self.entry.path)
891
- return self._part
892
-
893
- def free(self):
894
- self._part = None
895
-
896
-
897
- class PartitionedTable(Media):
898
- """A table which is composed of multiple sub-tables.
899
-
900
- Currently, PartitionedTable is designed to point to a directory within an artifact.
901
- """
902
-
903
- _log_type = "partitioned-table"
904
-
905
- def __init__(self, parts_path):
906
- """Initialize a PartitionedTable.
907
-
908
- Args:
909
- parts_path (str): path to a directory of tables in the artifact.
910
- """
911
- super().__init__()
912
- self.parts_path = parts_path
913
- self._loaded_part_entries = {}
914
-
915
- def to_json(self, artifact_or_run):
916
- json_obj = {
917
- "_type": PartitionedTable._log_type,
918
- }
919
- if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
920
- artifact_entry_url = self._get_artifact_entry_ref_url()
921
- if artifact_entry_url is None:
922
- raise ValueError(
923
- "PartitionedTables must first be added to an Artifact before logging to a Run"
924
- )
925
- json_obj["artifact_path"] = artifact_entry_url
926
- else:
927
- json_obj["parts_path"] = self.parts_path
928
- return json_obj
929
-
930
- @classmethod
931
- def from_json(cls, json_obj, source_artifact):
932
- instance = cls(json_obj["parts_path"])
933
- entries = source_artifact.manifest.get_entries_in_directory(
934
- json_obj["parts_path"]
935
- )
936
- for entry in entries:
937
- instance._add_part_entry(entry, source_artifact)
938
- return instance
939
-
940
- def iterrows(self):
941
- """Iterate over rows as (ndx, row).
942
-
943
- Yields:
944
- ------
945
- index : int
946
- The index of the row.
947
- row : List[any]
948
- The data of the row.
949
- """
950
- columns = None
951
- ndx = 0
952
- for entry_path in self._loaded_part_entries:
953
- part = self._loaded_part_entries[entry_path].get_part()
954
- if columns is None:
955
- columns = part.columns
956
- elif columns != part.columns:
957
- raise ValueError(
958
- "Table parts have non-matching columns. {} != {}".format(
959
- columns, part.columns
960
- )
961
- )
962
- for _, row in part.iterrows():
963
- yield ndx, row
964
- ndx += 1
965
-
966
- self._loaded_part_entries[entry_path].free()
967
-
968
- def _add_part_entry(self, entry, source_artifact):
969
- self._loaded_part_entries[entry.path] = _PartitionTablePartEntry(
970
- entry, source_artifact
971
- )
972
-
973
- def __ne__(self, other):
974
- return not self.__eq__(other)
975
-
976
- def __eq__(self, other):
977
- return isinstance(other, self.__class__) and self.parts_path == other.parts_path
978
-
979
- def bind_to_run(self, *args, **kwargs):
980
- raise ValueError("PartitionedTables cannot be bound to runs")
981
-
982
-
983
- class Audio(BatchableMedia):
984
- """Wandb class for audio clips.
985
-
986
- Arguments:
987
- data_or_path: (string or numpy array) A path to an audio file
988
- or a numpy array of audio data.
989
- sample_rate: (int) Sample rate, required when passing in raw
990
- numpy array of audio data.
991
- caption: (string) Caption to display with audio.
992
- """
993
-
994
- _log_type = "audio-file"
995
-
996
- def __init__(self, data_or_path, sample_rate=None, caption=None):
997
- """Accept a path to an audio file or a numpy array of audio data."""
998
- super().__init__()
999
- self._duration = None
1000
- self._sample_rate = sample_rate
1001
- self._caption = caption
1002
-
1003
- if isinstance(data_or_path, str):
1004
- if self.path_is_reference(data_or_path):
1005
- self._path = data_or_path
1006
- self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
1007
- self._is_tmp = False
1008
- else:
1009
- self._set_file(data_or_path, is_tmp=False)
1010
- else:
1011
- if sample_rate is None:
1012
- raise ValueError(
1013
- 'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
1014
- )
1015
-
1016
- soundfile = util.get_module(
1017
- "soundfile",
1018
- required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
1019
- )
1020
-
1021
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
1022
- soundfile.write(tmp_path, data_or_path, sample_rate)
1023
- self._duration = len(data_or_path) / float(sample_rate)
1024
-
1025
- self._set_file(tmp_path, is_tmp=True)
1026
-
1027
- @classmethod
1028
- def get_media_subdir(cls):
1029
- return os.path.join("media", "audio")
1030
-
1031
- @classmethod
1032
- def from_json(cls, json_obj, source_artifact):
1033
- return cls(
1034
- source_artifact.get_entry(json_obj["path"]).download(),
1035
- caption=json_obj["caption"],
1036
- )
1037
-
1038
- def bind_to_run(
1039
- self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
1040
- ):
1041
- if self.path_is_reference(self._path):
1042
- raise ValueError(
1043
- "Audio media created by a reference to external storage cannot currently be added to a run"
1044
- )
1045
-
1046
- return super().bind_to_run(run, key, step, id_, ignore_copy_err)
1047
-
1048
- def to_json(self, run):
1049
- json_dict = super().to_json(run)
1050
- json_dict.update(
1051
- {
1052
- "_type": self._log_type,
1053
- "caption": self._caption,
1054
- }
1055
- )
1056
- return json_dict
1057
-
1058
- @classmethod
1059
- def seq_to_json(cls, seq, run, key, step):
1060
- audio_list = list(seq)
1061
-
1062
- util.get_module(
1063
- "soundfile",
1064
- required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
1065
- )
1066
- base_path = os.path.join(run.dir, "media", "audio")
1067
- filesystem.mkdir_exists_ok(base_path)
1068
- meta = {
1069
- "_type": "audio",
1070
- "count": len(audio_list),
1071
- "audio": [a.to_json(run) for a in audio_list],
1072
- }
1073
- sample_rates = cls.sample_rates(audio_list)
1074
- if sample_rates:
1075
- meta["sampleRates"] = sample_rates
1076
- durations = cls.durations(audio_list)
1077
- if durations:
1078
- meta["durations"] = durations
1079
- captions = cls.captions(audio_list)
1080
- if captions:
1081
- meta["captions"] = captions
1082
-
1083
- return meta
1084
-
1085
- @classmethod
1086
- def durations(cls, audio_list):
1087
- return [a._duration for a in audio_list]
1088
-
1089
- @classmethod
1090
- def sample_rates(cls, audio_list):
1091
- return [a._sample_rate for a in audio_list]
1092
-
1093
- @classmethod
1094
- def captions(cls, audio_list):
1095
- captions = [a._caption for a in audio_list]
1096
- if all(c is None for c in captions):
1097
- return False
1098
- else:
1099
- return ["" if c is None else c for c in captions]
1100
-
1101
- def resolve_ref(self):
1102
- if self.path_is_reference(self._path):
1103
- # this object was already created using a ref:
1104
- return self._path
1105
- source_artifact = self._artifact_source.artifact
1106
-
1107
- resolved_name = source_artifact._local_path_to_name(self._path)
1108
- if resolved_name is not None:
1109
- target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
1110
- if target_entry is not None:
1111
- return target_entry.ref
1112
-
1113
- return None
1114
-
1115
- def __eq__(self, other):
1116
- if self.path_is_reference(self._path) or self.path_is_reference(other._path):
1117
- # one or more of these objects is an unresolved reference -- we'll compare
1118
- # their reference paths instead of their SHAs:
1119
- return (
1120
- self.resolve_ref() == other.resolve_ref()
1121
- and self._caption == other._caption
1122
- )
1123
-
1124
- return super().__eq__(other) and self._caption == other._caption
1125
-
1126
- def __ne__(self, other):
1127
- return not self.__eq__(other)
1128
-
1129
-
1130
- class JoinedTable(Media):
1131
- """Join two tables for visualization in the Artifact UI.
1132
-
1133
- Arguments:
1134
- table1 (str, wandb.Table, ArtifactManifestEntry):
1135
- the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
1136
- table2 (str, wandb.Table):
1137
- the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
1138
- join_key (str, [str, str]):
1139
- key or keys to perform the join
1140
- """
1141
-
1142
- _log_type = "joined-table"
1143
-
1144
- def __init__(self, table1, table2, join_key):
1145
- super().__init__()
1146
-
1147
- if not isinstance(join_key, str) and (
1148
- not isinstance(join_key, list) or len(join_key) != 2
1149
- ):
1150
- raise ValueError(
1151
- "JoinedTable join_key should be a string or a list of two strings"
1152
- )
1153
-
1154
- if not self._validate_table_input(table1):
1155
- raise ValueError(
1156
- "JoinedTable table1 should be an artifact path to a table or wandb.Table object"
1157
- )
1158
-
1159
- if not self._validate_table_input(table2):
1160
- raise ValueError(
1161
- "JoinedTable table2 should be an artifact path to a table or wandb.Table object"
1162
- )
1163
-
1164
- self._table1 = table1
1165
- self._table2 = table2
1166
- self._join_key = join_key
1167
-
1168
- @classmethod
1169
- def from_json(cls, json_obj, source_artifact):
1170
- t1 = source_artifact.get(json_obj["table1"])
1171
- if t1 is None:
1172
- t1 = json_obj["table1"]
1173
-
1174
- t2 = source_artifact.get(json_obj["table2"])
1175
- if t2 is None:
1176
- t2 = json_obj["table2"]
1177
-
1178
- return cls(
1179
- t1,
1180
- t2,
1181
- json_obj["join_key"],
1182
- )
1183
-
1184
- @staticmethod
1185
- def _validate_table_input(table):
1186
- """Helper method to validate that the table input is one of the 3 supported types."""
1187
- return (
1188
- (isinstance(table, str) and table.endswith(".table.json"))
1189
- or isinstance(table, Table)
1190
- or isinstance(table, PartitionedTable)
1191
- or (hasattr(table, "ref_url") and table.ref_url().endswith(".table.json"))
1192
- )
1193
-
1194
- def _ensure_table_in_artifact(self, table, artifact, table_ndx):
1195
- """Helper method to add the table to the incoming artifact. Returns the path."""
1196
- if isinstance(table, Table) or isinstance(table, PartitionedTable):
1197
- table_name = f"t{table_ndx}_{str(id(self))}"
1198
- if (
1199
- table._artifact_source is not None
1200
- and table._artifact_source.name is not None
1201
- ):
1202
- table_name = os.path.basename(table._artifact_source.name)
1203
- entry = artifact.add(table, table_name)
1204
- table = entry.path
1205
- # Check if this is an ArtifactManifestEntry
1206
- elif hasattr(table, "ref_url"):
1207
- # Give the new object a unique, yet deterministic name
1208
- name = binascii.hexlify(base64.standard_b64decode(table.digest)).decode(
1209
- "ascii"
1210
- )[:20]
1211
- entry = artifact.add_reference(
1212
- table.ref_url(), "{}.{}.json".format(name, table.name.split(".")[-2])
1213
- )[0]
1214
- table = entry.path
1215
-
1216
- err_str = "JoinedTable table:{} not found in artifact. Add a table to the artifact using Artifact#add(<table>, {}) before adding this JoinedTable"
1217
- if table not in artifact._manifest.entries:
1218
- raise ValueError(err_str.format(table, table))
1219
-
1220
- return table
1221
-
1222
- def to_json(self, artifact_or_run):
1223
- json_obj = {
1224
- "_type": JoinedTable._log_type,
1225
- }
1226
- if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
1227
- artifact_entry_url = self._get_artifact_entry_ref_url()
1228
- if artifact_entry_url is None:
1229
- raise ValueError(
1230
- "JoinedTables must first be added to an Artifact before logging to a Run"
1231
- )
1232
- json_obj["artifact_path"] = artifact_entry_url
1233
- else:
1234
- table1 = self._ensure_table_in_artifact(self._table1, artifact_or_run, 1)
1235
- table2 = self._ensure_table_in_artifact(self._table2, artifact_or_run, 2)
1236
- json_obj.update(
1237
- {
1238
- "table1": table1,
1239
- "table2": table2,
1240
- "join_key": self._join_key,
1241
- }
1242
- )
1243
- return json_obj
1244
-
1245
- def __ne__(self, other):
1246
- return not self.__eq__(other)
1247
-
1248
- def _eq_debug(self, other, should_assert=False):
1249
- eq = isinstance(other, JoinedTable)
1250
- assert not should_assert or eq, "Found type {}, expected {}".format(
1251
- other.__class__, JoinedTable
1252
- )
1253
- eq = eq and self._join_key == other._join_key
1254
- assert not should_assert or eq, "Found {} join key, expected {}".format(
1255
- other._join_key, self._join_key
1256
- )
1257
- eq = eq and self._table1._eq_debug(other._table1, should_assert)
1258
- eq = eq and self._table2._eq_debug(other._table2, should_assert)
1259
- return eq
1260
-
1261
- def __eq__(self, other):
1262
- return self._eq_debug(other, False)
1263
-
1264
- def bind_to_run(self, *args, **kwargs):
1265
- raise ValueError("JoinedTables cannot be bound to runs")
1266
-
1267
-
1268
- class Bokeh(Media):
1269
- """Wandb class for Bokeh plots.
1270
-
1271
- Arguments:
1272
- val: Bokeh plot
1273
- """
1274
-
1275
- _log_type = "bokeh-file"
1276
-
1277
- def __init__(self, data_or_path):
1278
- super().__init__()
1279
- bokeh = util.get_module("bokeh", required=True)
1280
- if isinstance(data_or_path, str) and os.path.exists(data_or_path):
1281
- with open(data_or_path) as file:
1282
- b_json = json.load(file)
1283
- self.b_obj = bokeh.document.Document.from_json(b_json)
1284
- self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
1285
- elif isinstance(data_or_path, bokeh.model.Model):
1286
- _data = bokeh.document.Document()
1287
- _data.add_root(data_or_path)
1288
- # serialize/deserialize pairing followed by sorting attributes ensures
1289
- # that the file's sha's are equivalent in subsequent calls
1290
- self.b_obj = bokeh.document.Document.from_json(_data.to_json())
1291
- b_json = self.b_obj.to_json()
1292
- if "references" in b_json["roots"]:
1293
- b_json["roots"]["references"].sort(key=lambda x: x["id"])
1294
-
1295
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
1296
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
1297
- util.json_dump_safer(b_json, fp)
1298
- self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
1299
- elif not isinstance(data_or_path, bokeh.document.Document):
1300
- raise TypeError(
1301
- "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
1302
- )
1303
-
1304
- def get_media_subdir(self):
1305
- return os.path.join("media", "bokeh")
1306
-
1307
- def to_json(self, run):
1308
- # TODO: (tss) this is getting redundant for all the media objects. We can probably
1309
- # pull this into Media#to_json and remove this type override for all the media types.
1310
- # There are only a few cases where the type is different between artifacts and runs.
1311
- json_dict = super().to_json(run)
1312
- json_dict["_type"] = self._log_type
1313
- return json_dict
1314
-
1315
- @classmethod
1316
- def from_json(cls, json_obj, source_artifact):
1317
- return cls(source_artifact.get_entry(json_obj["path"]).download())
1318
-
1319
-
1320
- def _nest(thing):
1321
- # Use tensorflows nest function if available, otherwise just wrap object in an array"""
1322
-
1323
- tfutil = util.get_module("tensorflow.python.util")
1324
- if tfutil:
1325
- return tfutil.nest.flatten(thing)
1326
- else:
1327
- return [thing]
1328
-
1329
-
1330
- class Graph(Media):
1331
- """Wandb class for graphs.
1332
-
1333
- This class is typically used for saving and displaying neural net models. It
1334
- represents the graph as an array of nodes and edges. The nodes can have
1335
- labels that can be visualized by wandb.
1336
-
1337
- Examples:
1338
- Import a keras model:
1339
- ```
1340
- Graph.from_keras(keras_model)
1341
- ```
1342
-
1343
- Attributes:
1344
- format (string): Format to help wandb display the graph nicely.
1345
- nodes ([wandb.Node]): List of wandb.Nodes
1346
- nodes_by_id (dict): dict of ids -> nodes
1347
- edges ([(wandb.Node, wandb.Node)]): List of pairs of nodes interpreted as edges
1348
- loaded (boolean): Flag to tell whether the graph is completely loaded
1349
- root (wandb.Node): root node of the graph
1350
- """
1351
-
1352
- _log_type = "graph-file"
1353
-
1354
- def __init__(self, format="keras"):
1355
- super().__init__()
1356
- # LB: TODO: I think we should factor criterion and criterion_passed out
1357
- self.format = format
1358
- self.nodes = []
1359
- self.nodes_by_id = {}
1360
- self.edges = []
1361
- self.loaded = False
1362
- self.criterion = None
1363
- self.criterion_passed = False
1364
- self.root = None # optional root Node if applicable
1365
-
1366
- def _to_graph_json(self, run=None):
1367
- # Needs to be its own function for tests
1368
- return {
1369
- "format": self.format,
1370
- "nodes": [node.to_json() for node in self.nodes],
1371
- "edges": [edge.to_json() for edge in self.edges],
1372
- }
1373
-
1374
- def bind_to_run(self, *args, **kwargs):
1375
- data = self._to_graph_json()
1376
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
1377
- data = _numpy_arrays_to_lists(data)
1378
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
1379
- util.json_dump_safer(data, fp)
1380
- self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
1381
- if self.is_bound():
1382
- return
1383
- super().bind_to_run(*args, **kwargs)
1384
-
1385
- @classmethod
1386
- def get_media_subdir(cls):
1387
- return os.path.join("media", "graph")
1388
-
1389
- def to_json(self, run):
1390
- json_dict = super().to_json(run)
1391
- json_dict["_type"] = self._log_type
1392
- return json_dict
1393
-
1394
- def __getitem__(self, nid):
1395
- return self.nodes_by_id[nid]
1396
-
1397
- def pprint(self):
1398
- for edge in self.edges:
1399
- pprint.pprint(edge.attributes)
1400
- for node in self.nodes:
1401
- pprint.pprint(node.attributes)
1402
-
1403
- def add_node(self, node=None, **node_kwargs):
1404
- if node is None:
1405
- node = Node(**node_kwargs)
1406
- elif node_kwargs:
1407
- raise ValueError(
1408
- f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
1409
- )
1410
- self.nodes.append(node)
1411
- self.nodes_by_id[node.id] = node
1412
-
1413
- return node
1414
-
1415
- def add_edge(self, from_node, to_node):
1416
- edge = Edge(from_node, to_node)
1417
- self.edges.append(edge)
1418
-
1419
- return edge
1420
-
1421
- @classmethod
1422
- def from_keras(cls, model):
1423
- # TODO: his method requires a refactor to work with the keras 3.
1424
- graph = cls()
1425
- # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
1426
- sequential_like = cls._is_sequential(model)
1427
-
1428
- relevant_nodes = None
1429
- if not sequential_like:
1430
- relevant_nodes = []
1431
- for v in model._nodes_by_depth.values():
1432
- relevant_nodes += v
1433
-
1434
- layers = model.layers
1435
- for i in range(len(layers)):
1436
- node = Node.from_keras(layers[i])
1437
- if hasattr(layers[i], "_inbound_nodes"):
1438
- for in_node in layers[i]._inbound_nodes:
1439
- if relevant_nodes and in_node not in relevant_nodes:
1440
- # node is not part of the current network
1441
- continue
1442
- for in_layer in _nest(in_node.inbound_layers):
1443
- inbound_keras_node = Node.from_keras(in_layer)
1444
-
1445
- if inbound_keras_node.id not in graph.nodes_by_id:
1446
- graph.add_node(inbound_keras_node)
1447
- inbound_node = graph.nodes_by_id[inbound_keras_node.id]
1448
-
1449
- graph.add_edge(inbound_node, node)
1450
- graph.add_node(node)
1451
- return graph
1452
-
1453
- @classmethod
1454
- def _is_sequential(cls, model):
1455
- sequential_like = True
1456
-
1457
- if (
1458
- model.__class__.__name__ != "Sequential"
1459
- and hasattr(model, "_is_graph_network")
1460
- and model._is_graph_network
1461
- ):
1462
- nodes_by_depth = model._nodes_by_depth.values()
1463
- nodes = []
1464
- for v in nodes_by_depth:
1465
- # TensorFlow2 doesn't insure inbound is always a list
1466
- inbound = v[0].inbound_layers
1467
- if not hasattr(inbound, "__len__"):
1468
- inbound = [inbound]
1469
- if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
1470
- # if the model has multiple nodes
1471
- # or if the nodes have multiple inbound_layers
1472
- # the model is no longer sequential
1473
- sequential_like = False
1474
- break
1475
- nodes += v
1476
- if sequential_like:
1477
- # search for shared layers
1478
- for layer in model.layers:
1479
- flag = False
1480
- if hasattr(layer, "_inbound_nodes"):
1481
- for node in layer._inbound_nodes:
1482
- if node in nodes:
1483
- if flag:
1484
- sequential_like = False
1485
- break
1486
- else:
1487
- flag = True
1488
- if not sequential_like:
1489
- break
1490
- return sequential_like
1491
-
1492
-
1493
- class Node(WBValue):
1494
- """Node used in `Graph`."""
1495
-
1496
- def __init__(
1497
- self,
1498
- id=None,
1499
- name=None,
1500
- class_name=None,
1501
- size=None,
1502
- parameters=None,
1503
- output_shape=None,
1504
- is_output=None,
1505
- num_parameters=None,
1506
- node=None,
1507
- ):
1508
- self._attributes = {"name": None}
1509
- self.in_edges = {} # indexed by source node id
1510
- self.out_edges = {} # indexed by dest node id
1511
- # optional object (e.g. PyTorch Parameter or Module) that this Node represents
1512
- self.obj = None
1513
-
1514
- if node is not None:
1515
- self._attributes.update(node._attributes)
1516
- del self._attributes["id"]
1517
- self.obj = node.obj
1518
-
1519
- if id is not None:
1520
- self.id = id
1521
- if name is not None:
1522
- self.name = name
1523
- if class_name is not None:
1524
- self.class_name = class_name
1525
- if size is not None:
1526
- self.size = size
1527
- if parameters is not None:
1528
- self.parameters = parameters
1529
- if output_shape is not None:
1530
- self.output_shape = output_shape
1531
- if is_output is not None:
1532
- self.is_output = is_output
1533
- if num_parameters is not None:
1534
- self.num_parameters = num_parameters
1535
-
1536
- def to_json(self, run=None):
1537
- return self._attributes
1538
-
1539
- def __repr__(self):
1540
- return repr(self._attributes)
1541
-
1542
- @property
1543
- def id(self):
1544
- """Must be unique in the graph."""
1545
- return self._attributes.get("id")
1546
-
1547
- @id.setter
1548
- def id(self, val):
1549
- self._attributes["id"] = val
1550
- return val
1551
-
1552
- @property
1553
- def name(self):
1554
- """Usually the type of layer or sublayer."""
1555
- return self._attributes.get("name")
1556
-
1557
- @name.setter
1558
- def name(self, val):
1559
- self._attributes["name"] = val
1560
- return val
1561
-
1562
- @property
1563
- def class_name(self):
1564
- """Usually the type of layer or sublayer."""
1565
- return self._attributes.get("class_name")
1566
-
1567
- @class_name.setter
1568
- def class_name(self, val):
1569
- self._attributes["class_name"] = val
1570
- return val
1571
-
1572
- @property
1573
- def functions(self):
1574
- return self._attributes.get("functions", [])
1575
-
1576
- @functions.setter
1577
- def functions(self, val):
1578
- self._attributes["functions"] = val
1579
- return val
1580
-
1581
- @property
1582
- def parameters(self):
1583
- return self._attributes.get("parameters", [])
1584
-
1585
- @parameters.setter
1586
- def parameters(self, val):
1587
- self._attributes["parameters"] = val
1588
- return val
1589
-
1590
- @property
1591
- def size(self):
1592
- return self._attributes.get("size")
1593
-
1594
- @size.setter
1595
- def size(self, val):
1596
- """Tensor size."""
1597
- self._attributes["size"] = tuple(val)
1598
- return val
1599
-
1600
- @property
1601
- def output_shape(self):
1602
- return self._attributes.get("output_shape")
1603
-
1604
- @output_shape.setter
1605
- def output_shape(self, val):
1606
- """Tensor output_shape."""
1607
- self._attributes["output_shape"] = val
1608
- return val
1609
-
1610
- @property
1611
- def is_output(self):
1612
- return self._attributes.get("is_output")
1613
-
1614
- @is_output.setter
1615
- def is_output(self, val):
1616
- """Tensor is_output."""
1617
- self._attributes["is_output"] = val
1618
- return val
1619
-
1620
- @property
1621
- def num_parameters(self):
1622
- return self._attributes.get("num_parameters")
1623
-
1624
- @num_parameters.setter
1625
- def num_parameters(self, val):
1626
- """Tensor num_parameters."""
1627
- self._attributes["num_parameters"] = val
1628
- return val
1629
-
1630
- @property
1631
- def child_parameters(self):
1632
- return self._attributes.get("child_parameters")
1633
-
1634
- @child_parameters.setter
1635
- def child_parameters(self, val):
1636
- """Tensor child_parameters."""
1637
- self._attributes["child_parameters"] = val
1638
- return val
1639
-
1640
- @property
1641
- def is_constant(self):
1642
- return self._attributes.get("is_constant")
1643
-
1644
- @is_constant.setter
1645
- def is_constant(self, val):
1646
- """Tensor is_constant."""
1647
- self._attributes["is_constant"] = val
1648
- return val
1649
-
1650
- @classmethod
1651
- def from_keras(cls, layer):
1652
- node = cls()
1653
-
1654
- try:
1655
- output_shape = layer.output_shape
1656
- except AttributeError:
1657
- output_shape = ["multiple"]
1658
-
1659
- node.id = layer.name
1660
- node.name = layer.name
1661
- node.class_name = layer.__class__.__name__
1662
- node.output_shape = output_shape
1663
- node.num_parameters = layer.count_params()
1664
-
1665
- return node
1666
-
1667
-
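`Node.from_keras` records only a handful of attributes: id, name, class name, output shape, and parameter count. A hand-built equivalent using the constructor defined above (values are illustrative):

    # Node as defined in this module; mirrors what from_keras(layer) produces.
    node = Node(
        id="dense_1",
        name="dense_1",
        class_name="Dense",
        output_shape=(None, 32),
        num_parameters=544,
    )
    print(node)  # __repr__ shows the underlying _attributes dict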
1668
- class Edge(WBValue):
1669
- """Edge used in `Graph`."""
1670
-
1671
- def __init__(self, from_node, to_node):
1672
- self._attributes = {}
1673
- self.from_node = from_node
1674
- self.to_node = to_node
1675
-
1676
- def __repr__(self):
1677
- temp_attr = dict(self._attributes)
1678
- del temp_attr["from_node"]
1679
- del temp_attr["to_node"]
1680
- temp_attr["from_id"] = self.from_node.id
1681
- temp_attr["to_id"] = self.to_node.id
1682
- return str(temp_attr)
1683
-
1684
- def to_json(self, run=None):
1685
- return [self.from_node.id, self.to_node.id]
1686
-
1687
- @property
1688
- def name(self):
1689
- """Optional, not necessarily unique."""
1690
- return self._attributes.get("name")
1691
-
1692
- @name.setter
1693
- def name(self, val):
1694
- self._attributes["name"] = val
1695
- return val
1696
-
1697
- @property
1698
- def from_node(self):
1699
- return self._attributes.get("from_node")
1700
-
1701
- @from_node.setter
1702
- def from_node(self, val):
1703
- self._attributes["from_node"] = val
1704
- return val
1705
-
1706
- @property
1707
- def to_node(self):
1708
- return self._attributes.get("to_node")
1709
-
1710
- @to_node.setter
1711
- def to_node(self, val):
1712
- self._attributes["to_node"] = val
1713
- return val
1714
-
1715
-
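An `Edge` serializes to nothing more than a `[from_id, to_id]` pair, so a graph's JSON is essentially a list of node attribute dicts plus these id pairs. A small assembly sketch, assuming `Graph.add_edge` as the counterpart of the `add_node` call seen earlier in this hunk:

    # Graph, Node as defined in this module.
    graph = Graph()
    inp = Node(id="input", name="input", class_name="InputLayer")
    dense = Node(id="dense", name="dense", class_name="Dense", num_parameters=544)
    graph.add_node(inp)
    graph.add_node(dense)
    graph.add_edge(inp, dense)  # serialized as ["input", "dense"]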
1716
- # Custom dtypes for typing system
1717
- class _ImageFileType(_dtypes.Type):
1718
- name = "image-file"
1719
- legacy_names = ["wandb.Image"]
1720
- types = [Image]
1721
-
1722
- def __init__(
1723
- self,
1724
- box_layers=None,
1725
- box_score_keys=None,
1726
- mask_layers=None,
1727
- class_map=None,
1728
- **kwargs,
1729
- ):
1730
- box_layers = box_layers or {}
1731
- box_score_keys = box_score_keys or []
1732
- mask_layers = mask_layers or {}
1733
- class_map = class_map or {}
1734
-
1735
- if isinstance(box_layers, _dtypes.ConstType):
1736
- box_layers = box_layers._params["val"]
1737
- if not isinstance(box_layers, dict):
1738
- raise TypeError("box_layers must be a dict")
1739
- else:
1740
- box_layers = _dtypes.ConstType(
1741
- {layer_key: set(box_layers[layer_key]) for layer_key in box_layers}
1742
- )
1743
-
1744
- if isinstance(mask_layers, _dtypes.ConstType):
1745
- mask_layers = mask_layers._params["val"]
1746
- if not isinstance(mask_layers, dict):
1747
- raise TypeError("mask_layers must be a dict")
1748
- else:
1749
- mask_layers = _dtypes.ConstType(
1750
- {layer_key: set(mask_layers[layer_key]) for layer_key in mask_layers}
1751
- )
1752
-
1753
- if isinstance(box_score_keys, _dtypes.ConstType):
1754
- box_score_keys = box_score_keys._params["val"]
1755
- if not isinstance(box_score_keys, list) and not isinstance(box_score_keys, set):
1756
- raise TypeError("box_score_keys must be a list or a set")
1757
- else:
1758
- box_score_keys = _dtypes.ConstType(set(box_score_keys))
1759
-
1760
- if isinstance(class_map, _dtypes.ConstType):
1761
- class_map = class_map._params["val"]
1762
- if not isinstance(class_map, dict):
1763
- raise TypeError("class_map must be a dict")
1764
- else:
1765
- class_map = _dtypes.ConstType(class_map)
1766
-
1767
- self.params.update(
1768
- {
1769
- "box_layers": box_layers,
1770
- "box_score_keys": box_score_keys,
1771
- "mask_layers": mask_layers,
1772
- "class_map": class_map,
1773
- }
1774
- )
1775
-
1776
- def assign_type(self, wb_type=None):
1777
- if isinstance(wb_type, _ImageFileType):
1778
- box_layers_self = self.params["box_layers"].params["val"] or {}
1779
- box_score_keys_self = self.params["box_score_keys"].params["val"] or []
1780
- mask_layers_self = self.params["mask_layers"].params["val"] or {}
1781
- class_map_self = self.params["class_map"].params["val"] or {}
1782
-
1783
- box_layers_other = wb_type.params["box_layers"].params["val"] or {}
1784
- box_score_keys_other = wb_type.params["box_score_keys"].params["val"] or []
1785
- mask_layers_other = wb_type.params["mask_layers"].params["val"] or {}
1786
- class_map_other = wb_type.params["class_map"].params["val"] or {}
1787
-
1788
- # Merge the class_ids from each set of box_layers
1789
- box_layers = {
1790
- str(key): set(
1791
- list(box_layers_self.get(key, []))
1792
- + list(box_layers_other.get(key, []))
1793
- )
1794
- for key in set(
1795
- list(box_layers_self.keys()) + list(box_layers_other.keys())
1796
- )
1797
- }
1798
-
1799
- # Merge the class_ids from each set of mask_layers
1800
- mask_layers = {
1801
- str(key): set(
1802
- list(mask_layers_self.get(key, []))
1803
- + list(mask_layers_other.get(key, []))
1804
- )
1805
- for key in set(
1806
- list(mask_layers_self.keys()) + list(mask_layers_other.keys())
1807
- )
1808
- }
1809
-
1810
- # Merge the box score keys
1811
- box_score_keys = set(list(box_score_keys_self) + list(box_score_keys_other))
1812
-
1813
- # Merge the class_map
1814
- class_map = {
1815
- str(key): class_map_self.get(key, class_map_other.get(key, None))
1816
- for key in set(
1817
- list(class_map_self.keys()) + list(class_map_other.keys())
1818
- )
1819
- }
1820
-
1821
- return _ImageFileType(box_layers, box_score_keys, mask_layers, class_map)
1822
-
1823
- return _dtypes.InvalidType()
1824
-
1825
- @classmethod
1826
- def from_obj(cls, py_obj):
1827
- if not isinstance(py_obj, Image):
1828
- raise TypeError("py_obj must be a wandb.Image")
1829
- else:
1830
- if hasattr(py_obj, "_boxes") and py_obj._boxes:
1831
- box_layers = {
1832
- str(key): set(py_obj._boxes[key]._class_labels.keys())
1833
- for key in py_obj._boxes.keys()
1834
- }
1835
- box_score_keys = {
1836
- key
1837
- for val in py_obj._boxes.values()
1838
- for box in val._val
1839
- for key in box.get("scores", {}).keys()
1840
- }
1841
-
1842
- else:
1843
- box_layers = {}
1844
- box_score_keys = set()
1845
-
1846
- if hasattr(py_obj, "_masks") and py_obj._masks:
1847
- mask_layers = {
1848
- str(key): set(
1849
- py_obj._masks[key]._val["class_labels"].keys()
1850
- if hasattr(py_obj._masks[key], "_val")
1851
- else []
1852
- )
1853
- for key in py_obj._masks.keys()
1854
- }
1855
- else:
1856
- mask_layers = {}
1857
-
1858
- if hasattr(py_obj, "_classes") and py_obj._classes:
1859
- class_set = {
1860
- str(item["id"]): item["name"] for item in py_obj._classes._class_set
1861
- }
1862
- else:
1863
- class_set = {}
1864
-
1865
- return cls(box_layers, box_score_keys, mask_layers, class_set)
1866
-
1867
-
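`_ImageFileType.from_obj` reads its parameters straight off the overlays attached to a `wandb.Image`: box layer names map to the class ids they use, score keys are collected across all boxes, and the class map comes from an attached class set. An image whose inferred type carries one box layer and one score key (the overlay schema below is the standard bounding-box format; the concrete keys are illustrative):

    import numpy as np
    import wandb

    img = wandb.Image(
        np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8),
        boxes={
            "predictions": {
                "box_data": [{
                    "position": {"minX": 0.1, "maxX": 0.4, "minY": 0.2, "maxY": 0.6},
                    "class_id": 1,
                    "scores": {"confidence": 0.87},
                }],
                "class_labels": {1: "cat"},
            }
        },
    )
    # Inferred type params: box_layers={"predictions": {1}}, box_score_keys={"confidence"}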
1868
- class _TableType(_dtypes.Type):
1869
- name = "table"
1870
- legacy_names = ["wandb.Table"]
1871
- types = [Table]
1872
-
1873
- def __init__(self, column_types=None):
1874
- if column_types is None:
1875
- column_types = _dtypes.UnknownType()
1876
- if isinstance(column_types, dict):
1877
- column_types = _dtypes.TypedDictType(column_types)
1878
- elif not (
1879
- isinstance(column_types, _dtypes.TypedDictType)
1880
- or isinstance(column_types, _dtypes.UnknownType)
1881
- ):
1882
- raise TypeError("column_types must be a dict or TypedDictType")
1883
-
1884
- self.params.update({"column_types": column_types})
1885
-
1886
- def assign_type(self, wb_type=None):
1887
- if isinstance(wb_type, _TableType):
1888
- column_types = self.params["column_types"].assign_type(
1889
- wb_type.params["column_types"]
1890
- )
1891
- if not isinstance(column_types, _dtypes.InvalidType):
1892
- return _TableType(column_types)
1893
-
1894
- return _dtypes.InvalidType()
1895
-
1896
- @classmethod
1897
- def from_obj(cls, py_obj):
1898
- if not isinstance(py_obj, Table):
1899
- raise TypeError("py_obj must be a wandb.Table")
1900
- else:
1901
- return cls(py_obj._column_types)
1902
-
1903
-
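`_TableType` wraps a table's per-column types, so two tables are compatible only when each column's types can be merged. A sketch of the check this enables, using the classes from this module (`_column_types` is an internal attribute of `Table`):

    import wandb

    t1 = wandb.Table(columns=["id", "score"], data=[[1, 0.5]])
    t2 = wandb.Table(columns=["id", "score"], data=[[2, 0.9]])

    merged = _TableType.from_obj(t1).assign_type(_TableType.from_obj(t2))
    # merged is a _TableType when the column types line up,
    # and _dtypes.InvalidType() when they conflict.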
1904
- class _ForeignKeyType(_dtypes.Type):
1905
- name = "foreignKey"
1906
- legacy_names = ["wandb.TableForeignKey"]
1907
- types = [_TableKey]
1908
-
1909
- def __init__(self, table, col_name):
1910
- assert isinstance(table, Table)
1911
- assert isinstance(col_name, str)
1912
- assert col_name in table.columns
1913
- self.params.update({"table": table, "col_name": col_name})
1914
-
1915
- def assign_type(self, wb_type=None):
1916
- if isinstance(wb_type, _dtypes.StringType):
1917
- return self
1918
- elif (
1919
- isinstance(wb_type, _ForeignKeyType)
1920
- and id(self.params["table"]) == id(wb_type.params["table"])
1921
- and self.params["col_name"] == wb_type.params["col_name"]
1922
- ):
1923
- return self
1924
-
1925
- return _dtypes.InvalidType()
1926
-
1927
- @classmethod
1928
- def from_obj(cls, py_obj):
1929
- if not isinstance(py_obj, _TableKey):
1930
- raise TypeError("py_obj must be a _TableKey")
1931
- else:
1932
- return cls(py_obj._table, py_obj._col_name)
1933
-
1934
- def to_json(self, artifact=None):
1935
- res = super().to_json(artifact)
1936
- if artifact is not None:
1937
- table_name = f"media/tables/t_{runid.generate_id()}"
1938
- entry = artifact.add(self.params["table"], table_name)
1939
- res["params"]["table"] = entry.path
1940
- else:
1941
- raise AssertionError(
1942
- "_ForeignKeyType does not support serialization without an artifact"
1943
- )
1944
- return res
1945
-
1946
- @classmethod
1947
- def from_json(
1948
- cls,
1949
- json_dict,
1950
- artifact,
1951
- ):
1952
- table = None
1953
- col_name = None
1954
- if artifact is None:
1955
- raise AssertionError(
1956
- "_ForeignKeyType does not support deserialization without an artifact"
1957
- )
1958
- else:
1959
- table = artifact.get(json_dict["params"]["table"])
1960
- col_name = json_dict["params"]["col_name"]
1961
-
1962
- if table is None:
1963
- raise AssertionError("Unable to deserialize referenced table")
1964
-
1965
- return cls(table, col_name)
1966
-
1967
-
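A foreign key lets a cell in one table reference the primary-key column of another table; serializing the type requires an artifact because the referenced table has to travel with it. The user-facing hooks are `Table.set_pk` and `Table.set_fk`; the signatures below reflect my reading of this API surface and may differ slightly by version:

    import wandb

    users = wandb.Table(columns=["user_id", "name"], data=[["u1", "Ada"]])
    users.set_pk("user_id")                      # designates the primary-key column

    events = wandb.Table(columns=["user_id", "event"], data=[["u1", "login"]])
    events.set_fk("user_id", users, "user_id")   # column values become _TableKey references

    art = wandb.Artifact("linked-tables", type="dataset")
    art.add(users, "users")
    art.add(events, "events")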
1968
- class _ForeignIndexType(_dtypes.Type):
1969
- name = "foreignIndex"
1970
- legacy_names = ["wandb.TableForeignIndex"]
1971
- types = [_TableIndex]
1972
-
1973
- def __init__(self, table):
1974
- assert isinstance(table, Table)
1975
- self.params.update({"table": table})
1976
-
1977
- def assign_type(self, wb_type=None):
1978
- if isinstance(wb_type, _dtypes.NumberType):
1979
- return self
1980
- elif isinstance(wb_type, _ForeignIndexType) and id(self.params["table"]) == id(
1981
- wb_type.params["table"]
1982
- ):
1983
- return self
1984
-
1985
- return _dtypes.InvalidType()
1986
-
1987
- @classmethod
1988
- def from_obj(cls, py_obj):
1989
- if not isinstance(py_obj, _TableIndex):
1990
- raise TypeError("py_obj must be a _TableIndex")
1991
- else:
1992
- return cls(py_obj._table)
1993
-
1994
- def to_json(self, artifact=None):
1995
- res = super().to_json(artifact)
1996
- if artifact is not None:
1997
- table_name = f"media/tables/t_{runid.generate_id()}"
1998
- entry = artifact.add(self.params["table"], table_name)
1999
- res["params"]["table"] = entry.path
2000
- else:
2001
- raise AssertionError(
2002
- "_ForeignIndexType does not support serialization without an artifact"
2003
- )
2004
- return res
2005
-
2006
- @classmethod
2007
- def from_json(
2008
- cls,
2009
- json_dict,
2010
- artifact,
2011
- ):
2012
- table = None
2013
- if artifact is None:
2014
- raise AssertionError(
2015
- "_ForeignIndexType does not support deserialization without an artifact"
2016
- )
2017
- else:
2018
- table = artifact.get(json_dict["params"]["table"])
2019
-
2020
- if table is None:
2021
- raise AssertionError("Unable to deserialize referenced table")
2022
-
2023
- return cls(table)
2024
-
2025
-
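`_ForeignIndexType` covers the related case where a cell holds a row index into another table rather than a key value; such row handles come from the referenced table itself. A sketch assuming `Table.get_index`, which returns one `_TableIndex` per row:

    base = wandb.Table(columns=["label"], data=[["cat"], ["dog"]])
    rows = base.get_index()           # one _TableIndex handle per row

    ref = wandb.Table(columns=["row", "score"])
    ref.add_data(rows[0], 0.9)        # the "row" column resolves to _ForeignIndexType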
2026
- class _PrimaryKeyType(_dtypes.Type):
2027
- name = "primaryKey"
2028
- legacy_names = ["wandb.TablePrimaryKey"]
2029
-
2030
- def assign_type(self, wb_type=None):
2031
- if isinstance(wb_type, _dtypes.StringType) or isinstance(
2032
- wb_type, _PrimaryKeyType
2033
- ):
2034
- return self
2035
- return _dtypes.InvalidType()
2036
-
2037
- @classmethod
2038
- def from_obj(cls, py_obj):
2039
- if not isinstance(py_obj, _TableKey):
2040
-             raise TypeError("py_obj must be a _TableKey")
2041
- else:
2042
- return cls()
2043
-
2044
-
2045
- class _AudioFileType(_dtypes.Type):
2046
- name = "audio-file"
2047
- types = [Audio]
2048
-
2049
-
2050
- class _BokehFileType(_dtypes.Type):
2051
- name = "bokeh-file"
2052
- types = [Bokeh]
2053
-
2054
-
2055
- class _JoinedTableType(_dtypes.Type):
2056
- name = "joined-table"
2057
- types = [JoinedTable]
2058
-
2059
-
2060
- class _PartitionedTableType(_dtypes.Type):
2061
- name = "partitioned-table"
2062
- types = [PartitionedTable]
2063
-
2064
-
2065
- _dtypes.TypeRegistry.add(_AudioFileType)
2066
- _dtypes.TypeRegistry.add(_BokehFileType)
2067
- _dtypes.TypeRegistry.add(_ImageFileType)
2068
- _dtypes.TypeRegistry.add(_TableType)
2069
- _dtypes.TypeRegistry.add(_JoinedTableType)
2070
- _dtypes.TypeRegistry.add(_PartitionedTableType)
2071
- _dtypes.TypeRegistry.add(_ForeignKeyType)
2072
- _dtypes.TypeRegistry.add(_PrimaryKeyType)
2073
- _dtypes.TypeRegistry.add(_ForeignIndexType)
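Registering each class with `_dtypes.TypeRegistry` is what lets the table machinery infer a column's type from a raw Python value via the `types = [...]` mappings declared above. Roughly, assuming the registry's `type_of` helper from the companion `_dtypes` module:

    import numpy as np
    import wandb

    t = _dtypes.TypeRegistry.type_of(wandb.Image(np.zeros((8, 8, 3))))
    isinstance(t, _ImageFileType)   # True: Image is listed in _ImageFileType.types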