wandb 0.18.0rc1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. wandb/__init__.py +4 -4
  2. wandb/__init__.pyi +67 -12
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public/api.py +128 -2
  5. wandb/apis/public/artifacts.py +11 -7
  6. wandb/apis/public/jobs.py +8 -0
  7. wandb/apis/public/runs.py +18 -5
  8. wandb/bin/nvidia_gpu_stats +0 -0
  9. wandb/cli/cli.py +0 -5
  10. wandb/data_types.py +9 -2019
  11. wandb/env.py +0 -5
  12. wandb/errors/__init__.py +11 -40
  13. wandb/errors/errors.py +37 -0
  14. wandb/errors/warnings.py +2 -0
  15. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  16. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  17. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  18. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  19. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  20. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  21. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  22. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  23. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  24. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  25. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  26. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  27. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  28. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  29. wandb/integration/tensorboard/log.py +1 -1
  30. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  31. wandb/old/core.py +2 -80
  32. wandb/plot/bar.py +7 -4
  33. wandb/plot/confusion_matrix.py +5 -4
  34. wandb/plot/histogram.py +7 -4
  35. wandb/plot/line.py +7 -4
  36. wandb/proto/v3/wandb_base_pb2.py +2 -1
  37. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  38. wandb/proto/v3/wandb_server_pb2.py +2 -1
  39. wandb/proto/v3/wandb_settings_pb2.py +3 -2
  40. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  41. wandb/proto/v4/wandb_base_pb2.py +2 -1
  42. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  43. wandb/proto/v4/wandb_server_pb2.py +2 -1
  44. wandb/proto/v4/wandb_settings_pb2.py +3 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  46. wandb/proto/v5/wandb_base_pb2.py +3 -2
  47. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  48. wandb/proto/v5/wandb_server_pb2.py +3 -2
  49. wandb/proto/v5/wandb_settings_pb2.py +4 -3
  50. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  51. wandb/sdk/artifacts/_validators.py +48 -3
  52. wandb/sdk/artifacts/artifact.py +157 -183
  53. wandb/sdk/artifacts/artifact_file_cache.py +13 -11
  54. wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
  55. wandb/sdk/artifacts/artifact_manifest.py +13 -11
  56. wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
  57. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
  58. wandb/sdk/artifacts/artifact_saver.py +27 -25
  59. wandb/sdk/artifacts/exceptions.py +26 -25
  60. wandb/sdk/artifacts/storage_handler.py +11 -9
  61. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
  65. wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
  66. wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
  67. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
  68. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
  69. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
  70. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
  71. wandb/sdk/artifacts/storage_policy.py +20 -20
  72. wandb/sdk/backend/backend.py +8 -26
  73. wandb/sdk/data_types/audio.py +165 -0
  74. wandb/sdk/data_types/base_types/wb_value.py +1 -3
  75. wandb/sdk/data_types/bokeh.py +70 -0
  76. wandb/sdk/data_types/graph.py +405 -0
  77. wandb/sdk/data_types/image.py +156 -0
  78. wandb/sdk/data_types/table.py +1204 -0
  79. wandb/sdk/data_types/trace_tree.py +2 -2
  80. wandb/sdk/data_types/utils.py +49 -0
  81. wandb/sdk/data_types/video.py +2 -2
  82. wandb/sdk/interface/interface.py +0 -24
  83. wandb/sdk/interface/interface_shared.py +0 -12
  84. wandb/sdk/internal/handler.py +0 -10
  85. wandb/sdk/internal/internal_api.py +71 -0
  86. wandb/sdk/internal/sender.py +0 -43
  87. wandb/sdk/internal/tb_watcher.py +1 -1
  88. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  89. wandb/sdk/lib/hashutil.py +34 -12
  90. wandb/sdk/lib/service_connection.py +216 -0
  91. wandb/sdk/lib/service_token.py +94 -0
  92. wandb/sdk/lib/sock_client.py +7 -3
  93. wandb/sdk/service/server.py +2 -5
  94. wandb/sdk/service/service.py +2 -31
  95. wandb/sdk/service/streams.py +0 -7
  96. wandb/sdk/wandb_init.py +42 -25
  97. wandb/sdk/wandb_run.py +18 -159
  98. wandb/sdk/wandb_settings.py +2 -0
  99. wandb/sdk/wandb_setup.py +25 -16
  100. wandb/sdk/wandb_sync.py +9 -3
  101. wandb/sdk/wandb_watch.py +31 -15
  102. wandb/sklearn.py +35 -0
  103. wandb/util.py +14 -3
  104. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
  105. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +114 -110
  106. wandb/sdk/internal/update.py +0 -113
  107. wandb/sdk/lib/console.py +0 -39
  108. wandb/sdk/service/service_base.py +0 -50
  109. wandb/sdk/service/service_sock.py +0 -70
  110. wandb/sdk/wandb_manager.py +0 -232
  111. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  112. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  113. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  114. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  115. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  116. /wandb/{sdk/lib → plot}/viz.py +0 -0
  117. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
  118. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
  119. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,405 @@
1
+ import codecs
2
+ import os
3
+ import pprint
4
+
5
+ from wandb import util
6
+ from wandb.sdk.data_types._private import MEDIA_TMP
7
+ from wandb.sdk.data_types.base_types.media import Media, _numpy_arrays_to_lists
8
+ from wandb.sdk.data_types.base_types.wb_value import WBValue
9
+ from wandb.sdk.lib import runid
10
+
11
+
12
def _nest(thing):
    """Flatten *thing* with TensorFlow's ``nest`` utility when it is available.

    If ``util.get_module("tensorflow.python.util")`` yields nothing, the object
    is simply wrapped in a one-element list instead.
    """
    tf_util = util.get_module("tensorflow.python.util")
    if not tf_util:
        return [thing]
    return tf_util.nest.flatten(thing)
20
+
21
+
22
class Edge(WBValue):
    """Edge used in `Graph`, linking ``from_node`` to ``to_node``."""

    def __init__(self, from_node, to_node):
        # All fields live in ``_attributes``; the properties below are accessors.
        self._attributes = {}
        self.from_node = from_node
        self.to_node = to_node

    def __repr__(self):
        # Show the endpoints by id rather than dumping the full Node objects.
        summary = dict(self._attributes)
        summary.pop("from_node")
        summary.pop("to_node")
        summary["from_id"] = self.from_node.id
        summary["to_id"] = self.to_node.id
        return str(summary)

    def to_json(self, run=None):
        """Serialize as a ``[from_id, to_id]`` pair."""
        source, target = self.from_node, self.to_node
        return [source.id, target.id]

    @property
    def name(self):
        """Optional, not necessarily unique."""
        return self._attributes.get("name")

    @name.setter
    def name(self, value):
        self._attributes["name"] = value
        return value

    @property
    def from_node(self):
        """Node the edge starts at."""
        return self._attributes.get("from_node")

    @from_node.setter
    def from_node(self, value):
        self._attributes["from_node"] = value
        return value

    @property
    def to_node(self):
        """Node the edge ends at."""
        return self._attributes.get("to_node")

    @to_node.setter
    def to_node(self, value):
        self._attributes["to_node"] = value
        return value
68
+
69
+
70
class Node(WBValue):
    """Node used in `Graph`.

    Every field is kept in the ``_attributes`` dict, which is also what
    ``to_json`` returns; the properties below read and write that dict.

    Args:
        id: Identifier; must be unique in the graph.
        name: Display name of the node.
        class_name: Class name of the object the node represents.
        size: Tensor size; stored as a tuple.
        parameters: Parameters of the node.
        output_shape: Output shape of the node.
        is_output: Whether the node is an output.
        num_parameters: Number of parameters of the node.
        node: Optional existing Node whose attributes (all except ``id``) and
            ``obj`` are copied first; explicit keyword arguments then
            override the copied values.
    """

    def __init__(
        self,
        id=None,
        name=None,
        class_name=None,
        size=None,
        parameters=None,
        output_shape=None,
        is_output=None,
        num_parameters=None,
        node=None,
    ):
        self._attributes = {"name": None}
        self.in_edges = {}  # indexed by source node id
        self.out_edges = {}  # indexed by dest node id
        # optional object (e.g. PyTorch Parameter or Module) that this Node represents
        self.obj = None

        if node is not None:
            # Copy everything except the id, which must stay unique per graph.
            self._attributes.update(node._attributes)
            # Use pop with a default: the source node may never have been given
            # an id, in which case ``del`` would raise KeyError.
            self._attributes.pop("id", None)
            self.obj = node.obj

        if id is not None:
            self.id = id
        if name is not None:
            self.name = name
        if class_name is not None:
            self.class_name = class_name
        if size is not None:
            self.size = size
        if parameters is not None:
            self.parameters = parameters
        if output_shape is not None:
            self.output_shape = output_shape
        if is_output is not None:
            self.is_output = is_output
        if num_parameters is not None:
            self.num_parameters = num_parameters

    def to_json(self, run=None):
        """Return the node's attribute dict (not a copy)."""
        return self._attributes

    def __repr__(self):
        return repr(self._attributes)

    @property
    def id(self):
        """Must be unique in the graph."""
        return self._attributes.get("id")

    @id.setter
    def id(self, val):
        self._attributes["id"] = val
        return val

    @property
    def name(self):
        """Display name of the node (for Keras layers, ``layer.name``)."""
        return self._attributes.get("name")

    @name.setter
    def name(self, val):
        self._attributes["name"] = val
        return val

    @property
    def class_name(self):
        """Class name of the represented object, e.g. the layer type."""
        return self._attributes.get("class_name")

    @class_name.setter
    def class_name(self, val):
        self._attributes["class_name"] = val
        return val

    @property
    def functions(self):
        """Functions attached to the node; ``[]`` when unset."""
        return self._attributes.get("functions", [])

    @functions.setter
    def functions(self, val):
        self._attributes["functions"] = val
        return val

    @property
    def parameters(self):
        """Parameters of the node; ``[]`` when unset."""
        return self._attributes.get("parameters", [])

    @parameters.setter
    def parameters(self, val):
        self._attributes["parameters"] = val
        return val

    @property
    def size(self):
        """Tensor size, stored as a tuple."""
        return self._attributes.get("size")

    @size.setter
    def size(self, val):
        self._attributes["size"] = tuple(val)
        return val

    @property
    def output_shape(self):
        """Tensor output_shape."""
        return self._attributes.get("output_shape")

    @output_shape.setter
    def output_shape(self, val):
        self._attributes["output_shape"] = val
        return val

    @property
    def is_output(self):
        """Tensor is_output."""
        return self._attributes.get("is_output")

    @is_output.setter
    def is_output(self, val):
        self._attributes["is_output"] = val
        return val

    @property
    def num_parameters(self):
        """Tensor num_parameters."""
        return self._attributes.get("num_parameters")

    @num_parameters.setter
    def num_parameters(self, val):
        self._attributes["num_parameters"] = val
        return val

    @property
    def child_parameters(self):
        """Tensor child_parameters."""
        return self._attributes.get("child_parameters")

    @child_parameters.setter
    def child_parameters(self, val):
        self._attributes["child_parameters"] = val
        return val

    @property
    def is_constant(self):
        """Tensor is_constant."""
        return self._attributes.get("is_constant")

    @is_constant.setter
    def is_constant(self, val):
        self._attributes["is_constant"] = val
        return val

    @classmethod
    def from_keras(cls, layer):
        """Build a Node describing a Keras layer.

        ``id`` and ``name`` are both the layer's name. ``output_shape`` falls
        back to ``["multiple"]`` when the layer has no ``output_shape``
        attribute.
        """
        node = cls()

        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = ["multiple"]

        node.id = layer.name
        node.name = layer.name
        node.class_name = layer.__class__.__name__
        node.output_shape = output_shape
        node.num_parameters = layer.count_params()

        return node
243
+
244
+
245
class Graph(Media):
    """Wandb class for graphs.

    This class is typically used for saving and displaying neural net models. It
    represents the graph as an array of nodes and edges. The nodes can have
    labels that can be visualized by wandb.

    Examples:
        Import a keras model:
        ```
        Graph.from_keras(keras_model)
        ```

    Attributes:
        format (string): Format to help wandb display the graph nicely.
        nodes ([wandb.Node]): List of wandb.Nodes
        nodes_by_id (dict): dict of ids -> nodes
        edges ([Edge]): List of Edge objects, each linking a from_node to a to_node
        loaded (boolean): Flag to tell whether the graph is completely loaded
        root (wandb.Node): root node of the graph
    """

    _log_type = "graph-file"

    def __init__(self, format="keras"):
        super().__init__()
        # LB: TODO: I think we should factor criterion and criterion_passed out
        self.format = format
        self.nodes = []
        self.nodes_by_id = {}
        self.edges = []
        self.loaded = False
        self.criterion = None
        self.criterion_passed = False
        self.root = None  # optional root Node if applicable

    def _to_graph_json(self, run=None):
        """Return the ``{format, nodes, edges}`` dict that ``bind_to_run`` writes."""
        # Needs to be its own function for tests
        return {
            "format": self.format,
            "nodes": [node.to_json() for node in self.nodes],
            "edges": [edge.to_json() for edge in self.edges],
        }

    def bind_to_run(self, *args, **kwargs):
        """Write the graph to a temporary ``.graph.json`` file, then bind it.

        The temporary file is written on every call; the call to the parent
        ``bind_to_run`` is skipped when this object is already bound.
        """
        data = self._to_graph_json()
        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
        # Node attributes may hold numpy arrays; convert them before dumping.
        data = _numpy_arrays_to_lists(data)
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(data, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
        if self.is_bound():
            return
        super().bind_to_run(*args, **kwargs)

    @classmethod
    def get_media_subdir(cls):
        return os.path.join("media", "graph")

    def to_json(self, run):
        """Return the Media JSON with ``_type`` set to ``graph-file``."""
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    def __getitem__(self, nid):
        """Return the node with id *nid* (KeyError if absent)."""
        return self.nodes_by_id[nid]

    def pprint(self):
        """Pretty-print the attribute dict of every edge, then every node."""
        # Edge and Node store their fields in ``_attributes``; neither defines
        # an ``attributes`` member, so reading that raised AttributeError.
        for edge in self.edges:
            pprint.pprint(edge._attributes)
        for node in self.nodes:
            pprint.pprint(node._attributes)

    def add_node(self, node=None, **node_kwargs):
        """Add *node* (or a new Node built from ``node_kwargs``) and index it by id.

        A node whose id is already indexed replaces the index entry, but the
        earlier node stays in ``self.nodes``.

        Raises:
            ValueError: if both ``node`` and keyword arguments are given.
        """
        if node is None:
            node = Node(**node_kwargs)
        elif node_kwargs:
            raise ValueError(
                f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
            )
        self.nodes.append(node)
        self.nodes_by_id[node.id] = node

        return node

    def add_edge(self, from_node, to_node):
        """Create an Edge from *from_node* to *to_node*, store it, and return it."""
        edge = Edge(from_node, to_node)
        self.edges.append(edge)

        return edge

    @classmethod
    def from_keras(cls, model):
        """Build a Graph from a Keras model.

        Every layer becomes a Node. For each inbound Keras node that is part of
        this model, an Edge is added from each of its inbound layers to the layer.
        """
        # TODO: this method requires a refactor to work with Keras 3.
        graph = cls()
        # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
        sequential_like = cls._is_sequential(model)

        relevant_nodes = None
        if not sequential_like:
            relevant_nodes = []
            for v in model._nodes_by_depth.values():
                relevant_nodes += v

        for layer in model.layers:
            node = Node.from_keras(layer)
            for in_node in getattr(layer, "_inbound_nodes", ()):
                if relevant_nodes and in_node not in relevant_nodes:
                    # node is not part of the current network
                    continue
                for in_layer in _nest(in_node.inbound_layers):
                    inbound_keras_node = Node.from_keras(in_layer)

                    # Reuse an already-registered node with the same id.
                    if inbound_keras_node.id not in graph.nodes_by_id:
                        graph.add_node(inbound_keras_node)
                    inbound_node = graph.nodes_by_id[inbound_keras_node.id]

                    graph.add_edge(inbound_node, node)
            graph.add_node(node)
        return graph

    @classmethod
    def _is_sequential(cls, model):
        """Return True unless *model* is a non-Sequential graph network that
        branches (multiple nodes per depth or multiple inbound layers) or
        shares a layer across several of its nodes."""
        sequential_like = True

        if (
            model.__class__.__name__ != "Sequential"
            and hasattr(model, "_is_graph_network")
            and model._is_graph_network
        ):
            nodes_by_depth = model._nodes_by_depth.values()
            nodes = []
            for v in nodes_by_depth:
                # TensorFlow2 doesn't ensure inbound is always a list
                inbound = v[0].inbound_layers
                if not hasattr(inbound, "__len__"):
                    inbound = [inbound]
                if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
                    # if the model has multiple nodes
                    # or if the nodes have multiple inbound_layers
                    # the model is no longer sequential
                    sequential_like = False
                    break
                nodes += v
            if sequential_like:
                # search for shared layers
                for layer in model.layers:
                    flag = False
                    if hasattr(layer, "_inbound_nodes"):
                        for node in layer._inbound_nodes:
                            if node in nodes:
                                if flag:
                                    sequential_like = False
                                    break
                                else:
                                    flag = True
                    if not sequential_like:
                        break
        return sequential_like
@@ -10,6 +10,7 @@ from wandb import util
10
10
  from wandb.sdk.lib import hashutil, runid
11
11
  from wandb.sdk.lib.paths import LogicalPath
12
12
 
13
+ from . import _dtypes
13
14
  from ._private import MEDIA_TMP
14
15
  from .base_types.media import BatchableMedia, Media
15
16
  from .helper_types.bounding_boxes_2d import BoundingBoxes2D
@@ -687,3 +688,158 @@ class Image(BatchableMedia):
687
688
  self._image = pil_image.open(self._path)
688
689
  self._image.load()
689
690
  return self._image
691
+
692
+
693
# Custom dtypes for typing system
class _ImageFileType(_dtypes.Type):
    """Typing-system dtype for ``wandb.Image`` values.

    Besides the type itself it records the box layers, box score keys, mask
    layers and class map seen on images, each stored in ``self.params`` as a
    ``_dtypes.ConstType``.
    """

    name = "image-file"
    legacy_names = ["wandb.Image"]
    types = [Image]

    def __init__(
        self,
        box_layers=None,
        box_score_keys=None,
        mask_layers=None,
        class_map=None,
        **kwargs,
    ):
        def _raw(value):
            # Accept either a plain value or one already wrapped in ConstType.
            if isinstance(value, _dtypes.ConstType):
                return value._params["val"]
            return value

        raw_boxes = _raw(box_layers or {})
        if not isinstance(raw_boxes, dict):
            raise TypeError("box_layers must be a dict")
        boxes_param = _dtypes.ConstType(
            {layer: set(class_ids) for layer, class_ids in raw_boxes.items()}
        )

        raw_masks = _raw(mask_layers or {})
        if not isinstance(raw_masks, dict):
            raise TypeError("mask_layers must be a dict")
        masks_param = _dtypes.ConstType(
            {layer: set(class_ids) for layer, class_ids in raw_masks.items()}
        )

        raw_score_keys = _raw(box_score_keys or [])
        if not isinstance(raw_score_keys, (list, set)):
            raise TypeError("box_score_keys must be a list or a set")
        score_keys_param = _dtypes.ConstType(set(raw_score_keys))

        raw_class_map = _raw(class_map or {})
        if not isinstance(raw_class_map, dict):
            raise TypeError("class_map must be a dict")
        class_map_param = _dtypes.ConstType(raw_class_map)

        self.params.update(
            {
                "box_layers": boxes_param,
                "box_score_keys": score_keys_param,
                "mask_layers": masks_param,
                "class_map": class_map_param,
            }
        )

    def assign_type(self, wb_type=None):
        """Merge this image type with another ``_ImageFileType``.

        Per-layer class-id sets and box score keys are unioned; for the class
        map, this type's entry is preferred when both sides define a key. Any
        other ``wb_type`` yields ``_dtypes.InvalidType()``.
        """
        if not isinstance(wb_type, _ImageFileType):
            return _dtypes.InvalidType()

        def _const_val(type_obj, param_name, fallback):
            return type_obj.params[param_name].params["val"] or fallback

        mine_boxes = _const_val(self, "box_layers", {})
        mine_scores = _const_val(self, "box_score_keys", [])
        mine_masks = _const_val(self, "mask_layers", {})
        mine_classes = _const_val(self, "class_map", {})

        their_boxes = _const_val(wb_type, "box_layers", {})
        their_scores = _const_val(wb_type, "box_score_keys", [])
        their_masks = _const_val(wb_type, "mask_layers", {})
        their_classes = _const_val(wb_type, "class_map", {})

        def _union_layers(left, right):
            # For every layer key on either side, union the class ids.
            all_keys = set(list(left.keys()) + list(right.keys()))
            return {
                str(layer): set(list(left.get(layer, [])) + list(right.get(layer, [])))
                for layer in all_keys
            }

        merged_classes = {
            str(class_id): mine_classes.get(
                class_id, their_classes.get(class_id, None)
            )
            for class_id in set(list(mine_classes.keys()) + list(their_classes.keys()))
        }

        return _ImageFileType(
            _union_layers(mine_boxes, their_boxes),
            set(list(mine_scores) + list(their_scores)),
            _union_layers(mine_masks, their_masks),
            merged_classes,
        )

    @classmethod
    def from_obj(cls, py_obj):
        """Derive an ``_ImageFileType`` from a ``wandb.Image`` instance.

        Raises:
            TypeError: if ``py_obj`` is not a ``wandb.Image``.
        """
        if not isinstance(py_obj, Image):
            raise TypeError("py_obj must be a wandb.Image")

        boxes = getattr(py_obj, "_boxes", None)
        if boxes:
            box_layers = {
                str(layer): set(box_group._class_labels.keys())
                for layer, box_group in boxes.items()
            }
            box_score_keys = {
                score_key
                for box_group in boxes.values()
                for box in box_group._val
                for score_key in box.get("scores", {}).keys()
            }
        else:
            box_layers = {}
            box_score_keys = set()

        masks = getattr(py_obj, "_masks", None)
        mask_layers = {}
        if masks:
            for layer in masks.keys():
                mask = masks[layer]
                class_ids = (
                    mask._val["class_labels"].keys() if hasattr(mask, "_val") else []
                )
                mask_layers[str(layer)] = set(class_ids)

        classes = getattr(py_obj, "_classes", None)
        class_set = (
            {str(item["id"]): item["name"] for item in classes._class_set}
            if classes
            else {}
        )

        return cls(box_layers, box_score_keys, mask_layers, class_set)


# Register this dtype with the ``_dtypes`` type registry.
_dtypes.TypeRegistry.add(_ImageFileType)