entity-ent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,29 @@
1
+ """Categorical attention: type-aware attention biases.
2
+
3
+ Encodes entity-type relationships as additive query+key biases per type.
4
+ O(seq²) memory (one (batch, 1, seq, seq) tensor) but computed via fast
5
+ broadcast add: bias[b, i, j] = q[type(b, i)] + k[type(b, j)].
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class CategoricalAttentionBias(nn.Module):
13
+ def __init__(self, num_heads: int, num_types: int = 32):
14
+ super().__init__()
15
+ self.q_bias = nn.Embedding(num_types, 1)
16
+ self.k_bias = nn.Embedding(num_types, 1)
17
+
18
+ def forward(
19
+ self,
20
+ token_types: torch.LongTensor,
21
+ attention_mask: torch.Tensor | None = None,
22
+ ) -> torch.Tensor:
23
+ q = self.q_bias(token_types).squeeze(-1) # (batch, seq)
24
+ k = self.k_bias(token_types).squeeze(-1) # (batch, seq)
25
+ bias = q.unsqueeze(-1) + k.unsqueeze(-2) # (batch, seq, seq) broadcast-add
26
+ bias = bias.unsqueeze(1) # (batch, 1, seq, seq)
27
+ if attention_mask is not None:
28
+ bias = bias.masked_fill(~attention_mask.bool().unsqueeze(1), float("-inf"))
29
+ return bias
@@ -0,0 +1,64 @@
1
+ """baseline — hash-based entity decoder.
2
+
3
+ Maps compact entity hashes through a learned decoder that jointly predicts
4
+ type, class, and scope — the model discovers categorical structure itself
5
+ rather than relying on hand-crafted semantic token strings.
6
+
7
+ Contrast with semantic CEIDs (robot.recon.w11.a = 4+ tokens):
8
+ - Hash-based: single compact ID → learned structured representation
9
+ - Supports categorical attention by providing type/class/scope predictions
10
+ - Trainable end-to-end: the decoder IS the functor
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+
18
class EntityDecoder(nn.Module):
    """Learned decoder from entity ids to structural predictions.

    An embedding table maps each entity id to a ``hidden_dim`` vector.  Three
    linear heads then predict the entity's type, class, and scope from that
    vector.

    Args:
        num_entities: Size of the entity-id vocabulary.
        hidden_dim: Embedding width.
        num_types: Number of outputs of the type head.
        num_classes: Number of outputs of the class head.
        num_scopes: Number of outputs of the scope head.
    """

    def __init__(
        self,
        num_entities: int = 131072,
        hidden_dim: int = 896,
        num_types: int = 32,
        num_classes: int = 64,
        num_scopes: int = 128,
    ):
        super().__init__()
        self.entity_embed = nn.Embedding(num_entities, hidden_dim)
        self.type_head = nn.Linear(hidden_dim, num_types)
        self.class_head = nn.Linear(hidden_dim, num_classes)
        self.scope_head = nn.Linear(hidden_dim, num_scopes)
        self.num_types = num_types
        self.num_classes = num_classes
        self.num_scopes = num_scopes

    def forward(
        self, entity_ids: torch.LongTensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed ``entity_ids`` and run the three prediction heads.

        Returns:
            ``(embedding, type_logits, class_logits, scope_logits)``.  Each has
            the shape of ``entity_ids`` plus a trailing feature axis of size
            ``hidden_dim`` / ``num_types`` / ``num_classes`` / ``num_scopes``.
        """
        e = self.entity_embed(entity_ids)
        type_logits = self.type_head(e)
        class_logits = self.class_head(e)
        scope_logits = self.scope_head(e)
        return e, type_logits, class_logits, scope_logits

    def compute_structural_loss(
        self,
        type_logits: torch.Tensor,
        class_logits: torch.Tensor,
        scope_logits: torch.Tensor,
        type_labels: torch.LongTensor,
        class_labels: torch.LongTensor,
        scope_labels: torch.LongTensor,
        entity_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Sum of the three heads' cross-entropy losses over masked positions.

        Args:
            type_logits, class_logits, scope_logits: Head outputs from
                ``forward``.
            type_labels, class_labels, scope_labels: Integer targets shaped
                like the logits without their last axis.
            entity_mask: Positions that contribute to the loss.  It is
                converted to bool, so a 0/1 integer mask behaves as a mask
                rather than as an index tensor.

        Returns:
            Scalar loss.  If the mask selects no positions, this is a zero that
            still depends on the logits, so ``backward()`` works.  Without this
            guard, cross_entropy's mean over zero rows would return NaN.
        """
        mask = entity_mask.bool()
        if not bool(mask.any()):
            return (type_logits.sum() + class_logits.sum() + scope_logits.sum()) * 0.0
        type_loss = F.cross_entropy(type_logits[mask], type_labels[mask])
        class_loss = F.cross_entropy(class_logits[mask], class_labels[mask])
        scope_loss = F.cross_entropy(scope_logits[mask], scope_labels[mask])
        return type_loss + class_loss + scope_loss
@@ -0,0 +1,73 @@
1
+ """Schema-based entity hash encoding.
2
+
3
+ 17-bit layout (131072 entities):
4
+ [16:12] type (5 bits, 0-31)
5
+ [11:6] class (6 bits, 0-63)
6
+ [5:3] scope (3 bits, 0-7)
7
+ [2] arity (1 bit, low=0 high=1)
8
+ [1] role (1 bit, caller=0 callee=1)
9
+ [0] morphism (1 bit, static=0 dynamic=1)
10
+ """
11
+
12
+ import torch
13
+
14
# Bit offsets of each field inside the 17-bit entity hash (most significant
# field first: type | class | scope | arity | role | morphism).
SHIFT_TYPE = 12
SHIFT_CLASS = 6
SHIFT_SCOPE = 3
SHIFT_ARITY = 2
SHIFT_ROLE = 1
SHIFT_MORPHISM = 0

# Field masks, applied after shifting the field down to bit 0.
MASK_TYPE = 0x1F
MASK_CLASS = 0x3F
MASK_SCOPE = 0x7
MASK_ARITY = 0x1
MASK_ROLE = 0x1
MASK_MORPHISM = 0x1

# Sizes of the hash space and of each categorical field.
NUM_ENTITIES = 1 << 17
NUM_TYPES = 1 << 5
NUM_CLASSES = 1 << 6
NUM_SCOPES = 1 << 3


def encode(
    type_id: int,
    class_id: int,
    scope_id: int,
    arity: int,
    role: int,
    morphism: int,
) -> int:
    """Pack the six structural fields into a single 17-bit entity hash.

    Raises:
        ValueError: If a field does not fit its bit width, e.g. ``type_id``
            outside ``[0, 31]`` or ``scope_id`` outside ``[0, 7]``.  Without
            this check, an out-of-range value would silently overflow into the
            neighbouring fields and produce a different entity.
    """
    fields = (
        ("type_id", type_id, SHIFT_TYPE, MASK_TYPE),
        ("class_id", class_id, SHIFT_CLASS, MASK_CLASS),
        ("scope_id", scope_id, SHIFT_SCOPE, MASK_SCOPE),
        ("arity", arity, SHIFT_ARITY, MASK_ARITY),
        ("role", role, SHIFT_ROLE, MASK_ROLE),
        ("morphism", morphism, SHIFT_MORPHISM, MASK_MORPHISM),
    )
    packed = 0
    for name, value, shift, mask in fields:
        if not 0 <= value <= mask:
            raise ValueError(f"{name}={value!r} is outside [0, {mask}]")
        packed |= value << shift
    return packed


def get_type(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the type field (bits 16-12).  Works elementwise on integer tensors and on ints."""
    return (entity_hash >> SHIFT_TYPE) & MASK_TYPE


def get_class(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the class field (bits 11-6)."""
    return (entity_hash >> SHIFT_CLASS) & MASK_CLASS


def get_scope(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the scope field (bits 5-3)."""
    return (entity_hash >> SHIFT_SCOPE) & MASK_SCOPE


def get_arity(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the arity bit (bit 2)."""
    return (entity_hash >> SHIFT_ARITY) & MASK_ARITY


def get_role(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the role bit (bit 1)."""
    return (entity_hash >> SHIFT_ROLE) & MASK_ROLE


def get_morphism(entity_hash: torch.Tensor) -> torch.Tensor:
    """Extract the morphism bit (bit 0)."""
    return (entity_hash >> SHIFT_MORPHISM) & MASK_MORPHISM
@@ -0,0 +1,60 @@
1
+ """Murmuration router: lightweight functor routing.
2
+
3
+ Routes tokens to specialized functor modules based on content and
4
+ CEID scope: a learned MLP gate over hidden states plus a learned per-scope bias.
5
+
6
+ Enabled in Phase 3 — disabled during Phase 1 pretraining.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class MurmurationRouter(nn.Module):
15
+ def __init__(
16
+ self,
17
+ hidden_size: int,
18
+ routing_dim: int = 512,
19
+ num_functors: int = 8,
20
+ top_k: int = 2,
21
+ num_scopes: int = 8,
22
+ ):
23
+ super().__init__()
24
+ self.hidden_size = hidden_size
25
+ self.num_functors = num_functors
26
+ self.top_k = top_k
27
+ self.gate = nn.Sequential(
28
+ nn.Linear(hidden_size, routing_dim),
29
+ nn.SiLU(),
30
+ nn.Linear(routing_dim, num_functors),
31
+ )
32
+ self.scope_bias = nn.Embedding(num_scopes, num_functors)
33
+
34
+ def forward(
35
+ self,
36
+ hidden_states: torch.Tensor,
37
+ scope_ids: torch.LongTensor | None = None,
38
+ ) -> tuple[torch.Tensor, torch.Tensor]:
39
+ logits = self.gate(hidden_states)
40
+ if scope_ids is not None:
41
+ logits = logits + self.scope_bias(scope_ids)
42
+ routing_weights = F.softmax(logits, dim=-1)
43
+ top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
44
+ top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
45
+ return top_k_weights, top_k_indices
46
+
47
+
48
class FunctorModule(nn.Module):
    """One routed functor: a two-layer GELU feed-forward network.

    Maps the last axis ``hidden_size -> hidden_size * intermediate_mult ->
    hidden_size``, so the output has the same shape as the input.
    """

    def __init__(self, hidden_size: int, intermediate_mult: int = 1):
        super().__init__()
        inner_width = hidden_size * intermediate_mult
        # Attribute name and layer positions are kept so state_dict keys are
        # net.0.* and net.2.*.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network to the last axis of ``x``."""
        projected = self.net(x)
        return projected
ent/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ """ent — entity-conditioned language models with category-theoretic reasoning.
2
+
3
+ Key components:
4
+ category — EntityObject, EntityMorphism, EntityCategory, StructuralMorphism, EmbeddingFunctor, TypeClassifierFunctor, Subcategory
5
+ store — EntityStore (entity decoder weights, fast lookups)
6
+ reasoner — EntityReasoner (symbolic queries over entity graphs)
7
+ decoder — EntityConditionedDecoder (small LM for text generation)
8
+ inference — EntInferenceEngine (abstraction, graph reasoning, working memory, program execution)
9
+ training — EntitySmolWrapper (entity-conditioned generation with LoRA)
10
+ serving — FastAPI server with OpenAI-compatible /v1/chat/completions
11
+ """
12
+
13
+ from ent.core.category import (
14
+ EntityObject,
15
+ EntityMorphism,
16
+ EntityCategory,
17
+ StructuralMorphism,
18
+ EmbeddingFunctor,
19
+ TypeClassifierFunctor,
20
+ Subcategory,
21
+ )
22
+ from ent.core.store import EntityStore
23
+ from ent.core.reasoner import EntityGraph, EntityReasoner
24
+ from ent.models.decoder import DecoderConfig, EntityConditionedDecoder
25
+ from ent.core.inference import (
26
+ AbstractionNode,
27
+ AbstractionEdge,
28
+ AbstractionGraph,
29
+ MemoryItem,
30
+ DurableMemoryRecord,
31
+ DurableMemoryStore,
32
+ CandidateAnswer,
33
+ InferenceState,
34
+ EntityAbstractionLayer,
35
+ ProgramExecutor,
36
+ EntInferenceEngine,
37
+ )
38
+
39
# Public API re-exported from the submodules imported above.
__all__ = [
    # ent.core.category
    "EntityObject",
    "EntityMorphism",
    "EntityCategory",
    "StructuralMorphism",
    "EmbeddingFunctor",
    "TypeClassifierFunctor",
    "Subcategory",
    # ent.core.store / ent.core.reasoner
    "EntityStore",
    "EntityGraph",
    "EntityReasoner",
    # ent.models.decoder
    "DecoderConfig",
    "EntityConditionedDecoder",
    # ent.core.inference
    "AbstractionNode",
    "AbstractionEdge",
    "AbstractionGraph",
    "MemoryItem",
    "DurableMemoryRecord",
    "DurableMemoryStore",
    "CandidateAnswer",
    "InferenceState",
    "EntityAbstractionLayer",
    "ProgramExecutor",
    "EntInferenceEngine",
]

# Package version string.
__version__ = "0.1.0"
ent/core/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Core entity reasoning primitives."""
ent/core/category.py ADDED
@@ -0,0 +1,405 @@
1
+ """Category theory foundations for entity reasoning.
2
+
3
+ The entity hash encodes categorical structure without any neural
4
+ computation. This module formalizes the structure:
5
+
6
+ Objects — EntityObject: a single entity hash (17-bit)
7
+ Morphisms — EntityMorphism: a structural relationship between entities,
8
+ derived from type/class/scope compatibility
9
+ Category — EntityCategory: the thin category of all entities (at most one morphism between any two objects)
10
+ Functors — Map EntityCategory → other categories (Vect, TypeDist, etc.)
11
+ Subcategory — Filtered view of the category by structural constraints
12
+
13
+ Structural operations (field extraction, morphism checks, composition) are O(1) per object on the CPU; only TypeClassifierFunctor performs a matrix multiply.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Callable, Optional, Iterator
20
+
21
+ import torch
22
+
23
+ from architecture.entity_hash import (
24
+ get_type,
25
+ get_class,
26
+ get_scope,
27
+ get_arity,
28
+ get_role,
29
+ get_morphism,
30
+ encode,
31
+ NUM_ENTITIES,
32
+ NUM_TYPES,
33
+ NUM_CLASSES,
34
+ NUM_SCOPES,
35
+ SHIFT_TYPE,
36
+ SHIFT_CLASS,
37
+ SHIFT_SCOPE,
38
+ MASK_TYPE,
39
+ MASK_CLASS,
40
+ MASK_SCOPE,
41
+ )
42
+
43
+
44
@dataclass(frozen=True)
class EntityObject:
    """An object of the entity category: a single entity hash.

    Wraps an entity hash and exposes its structural fields (type, class,
    scope, arity, role, morphism bit).  The frozen dataclass derives
    ``__eq__`` and ``__hash__`` from ``hash`` alone, so two EntityObjects are
    equal iff their hashes are equal.
    """

    hash: int

    # Display names for the low type ids; other ids render as "type-<id>".
    # Not annotated, so the dataclass machinery does not treat it as a field.
    # Hoisted here so it is not rebuilt on every ``type_name`` access.
    _TYPE_NAMES = {
        0: "empty",
        1: "number",
        2: "proper-noun",
        3: "lowercase-alpha",
        4: "punctuation",
        5: "alphanumeric",
        6: "other",
    }

    # NOTE(review): the getters are applied to the raw int instead of a
    # freshly allocated tensor.  This assumes architecture.entity_hash's get_*
    # are plain shift-and-mask expressions (as in this package's entity_hash
    # module), which accept Python ints as well as tensors.  int() also
    # accepts a 0-d tensor result.

    @property
    def type_id(self) -> int:
        return int(get_type(self.hash))

    @property
    def class_id(self) -> int:
        return int(get_class(self.hash))

    @property
    def scope_id(self) -> int:
        return int(get_scope(self.hash))

    @property
    def arity(self) -> int:
        return int(get_arity(self.hash))

    @property
    def role(self) -> int:
        return int(get_role(self.hash))

    @property
    def morphism_bit(self) -> int:
        return int(get_morphism(self.hash))

    @property
    def type_name(self) -> str:
        """Human-readable name of ``type_id``, or ``"type-<id>"`` if unnamed."""
        type_id = self.type_id
        return self._TYPE_NAMES.get(type_id, f"type-{type_id}")

    def __repr__(self) -> str:
        return (
            f"Entity(hash={self.hash}, type={self.type_name}, "
            f"class={self.class_id}, scope={self.scope_id})"
        )
96
+
97
+
98
class EntityMorphism:
    """A morphism f: A → B between entity objects.

    Kinds used in this module:
      - "identity": A → A.
      - "class-equivalent" / "scope-equivalent" / "type-equivalent": built by
        EntityCategory.hom when A and B share that field.
      - "structural", "scope", "type": generic structural kinds.
      - "embedding": carries no structural constraint and composes freely.
      - "composed": the result of composing two non-identity morphisms.

    The category is intended to be thin (at most one morphism between any two
    objects).  Composition therefore returns the identity for a round trip
    A → B → A.
    """

    __slots__ = ("_source", "_target", "_kind", "_hash")

    # Kinds reported as structural by ``is_structural``.  This includes the
    # kinds that EntityCategory.hom produces.
    _STRUCTURAL_KINDS = frozenset(
        {
            "structural",
            "scope",
            "type",
            "identity",
            "class-equivalent",
            "scope-equivalent",
            "type-equivalent",
        }
    )

    def __init__(
        self,
        source: EntityObject,
        target: EntityObject,
        kind: str = "structural",
    ):
        self._source = source
        self._target = target
        self._kind = kind
        self._hash = hash((source.hash, target.hash, kind))

    @property
    def source(self) -> EntityObject:
        return self._source

    @property
    def target(self) -> EntityObject:
        return self._target

    @property
    def kind(self) -> str:
        return self._kind

    @property
    def is_identity(self) -> bool:
        return self._source == self._target and self._kind == "identity"

    @property
    def is_structural(self) -> bool:
        return self._kind in self._STRUCTURAL_KINDS

    def compose(self, other: EntityMorphism) -> Optional[EntityMorphism]:
        """Compose self: A→B with other: B→C to get other∘self: A→C.

        Returns:
            - ``None`` if ``self.target != other.source``;
            - the other morphism when one side is an identity;
            - ``None`` if no morphism A → C exists (A and C share no class,
              scope, or type id and neither morphism is an embedding);
            - the identity on A for a round trip A → B → A;
            - otherwise a new morphism A → C of kind ``"composed"``.
        """
        if self.target != other.source:
            return None
        if self.is_identity:
            return other
        if other.is_identity:
            return self
        if not self._can_compose_with(other):
            return None
        if self.source == other.target:
            # Thin category: the only endomorphism of A is id_A.
            return EntityMorphism(self.source, self.source, kind="identity")
        return EntityMorphism(self.source, other.target, kind="composed")

    def _can_compose_with(self, other: EntityMorphism) -> bool:
        """Whether the composite A → C exists (A = self.source, C = other.target).

        Comparing A with C rather than A with B matters: A and B are already
        related whenever ``self`` exists, so that check would accept composites
        whose endpoints are unrelated.
        """
        if self.kind == "embedding" or other.kind == "embedding":
            return True
        a, c = self.source, other.target
        return (
            a == c
            or a.class_id == c.class_id
            or a.scope_id == c.scope_id
            or a.type_id == c.type_id
        )

    def __call__(self) -> str:
        return f"{self.source} → {self.target}"

    def __repr__(self) -> str:
        return f"Morphism({self.source.hash}→{self.target.hash}, {self.kind})"

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, EntityMorphism):
            return False
        return (
            self._source == other._source
            and self._target == other._target
            and self._kind == other._kind
        )
186
+
187
+
188
class EntityCategory:
    """The category C of entity objects.

    - ob(C) = {0, 1, ..., NUM_ENTITIES - 1}: all entity hashes.
    - hom(A, B) holds a single morphism when A == B (the identity) or when A
      and B share a class id, a scope id, or a type id; otherwise it is empty.
      The category is therefore not discrete.
    - id_A = EntityMorphism(A, A, kind="identity"), cached per hash.
    - Composition delegates to EntityMorphism.compose.
    """

    def __init__(self) -> None:
        # Identity morphisms keyed by entity hash, built lazily on demand.
        self._identity_cache: dict[int, EntityMorphism] = {}

    def obj(self, hash_val: int) -> EntityObject:
        """Wrap a raw entity hash as an EntityObject."""
        return EntityObject(hash_val)

    def all_objects(self) -> range:
        """Every entity hash, as a range of ints (not EntityObjects)."""
        return range(NUM_ENTITIES)

    def objects_of_type(self, type_id: int) -> Iterator[EntityObject]:
        """Yield objects with the given type id, iterating class then scope.

        Only hashes whose arity, role, and morphism bits are all zero are
        produced (NUM_CLASSES * NUM_SCOPES objects), not every hash with this
        type id.
        """
        for class_id in range(NUM_CLASSES):
            for scope_id in range(NUM_SCOPES):
                yield EntityObject(encode(type_id, class_id, scope_id, 0, 0, 0))

    def objects_of_class(self, class_id: int) -> Iterator[EntityObject]:
        """Yield objects with the given class id, iterating type then scope.

        As with ``objects_of_type``, only hashes with zero arity/role/morphism
        bits are produced.
        """
        for type_id in range(NUM_TYPES):
            for scope_id in range(NUM_SCOPES):
                yield EntityObject(encode(type_id, class_id, scope_id, 0, 0, 0))

    def objects_of_scope(self, scope_id: int) -> Iterator[EntityObject]:
        """Yield objects with the given scope id, iterating type then class.

        As with ``objects_of_type``, only hashes with zero arity/role/morphism
        bits are produced.
        """
        for type_id in range(NUM_TYPES):
            for class_id in range(NUM_CLASSES):
                yield EntityObject(encode(type_id, class_id, scope_id, 0, 0, 0))

    def identity(self, obj: EntityObject) -> EntityMorphism:
        """Return the (cached) identity morphism on ``obj``."""
        key = obj.hash
        if key not in self._identity_cache:
            self._identity_cache[key] = EntityMorphism(obj, obj, kind="identity")
        return self._identity_cache[key]

    def has_morphism(self, source: EntityObject, target: EntityObject) -> bool:
        """Whether a morphism source → target exists.

        True for equal objects, or when the objects share a class id, a scope
        id, or a type id.  The relation is symmetric.
        """
        if source == target:
            return True
        return (
            source.class_id == target.class_id
            or source.scope_id == target.scope_id
            or source.type_id == target.type_id
        )

    def hom(
        self, source: EntityObject, target: EntityObject
    ) -> Optional[EntityMorphism]:
        """Return the unique morphism from source to target, or None.

        The kind names the first shared field, checked in the order class,
        scope, type.
        """
        if source == target:
            return self.identity(source)
        if not self.has_morphism(source, target):
            return None
        if source.class_id == target.class_id:
            kind = "class-equivalent"
        elif source.scope_id == target.scope_id:
            kind = "scope-equivalent"
        elif source.type_id == target.type_id:
            kind = "type-equivalent"
        else:
            kind = "structural"
        return EntityMorphism(source, target, kind=kind)

    def compose(
        self, f: EntityMorphism, g: EntityMorphism
    ) -> Optional[EntityMorphism]:
        """Compose f: A→B with g: B→C to get g∘f: A→C."""
        return f.compose(g)

    def __repr__(self) -> str:
        return (
            f"EntityCategory(ob={NUM_ENTITIES}, types={NUM_TYPES}, "
            f"classes={NUM_CLASSES}, scopes={NUM_SCOPES})"
        )
263
+
264
+
265
@dataclass
class StructuralMorphism:
    """The three low "morphism" bits of an entity hash, as a value object.

    In v0 these bits (arity, role, morphism) are always 0 and act as
    placeholders.  Later versions are expected to populate them from
    syntactic context.
    """

    arity: int = 0
    role: int = 0
    morphism: int = 0

    @property
    def is_identity_like(self) -> bool:
        """All three bits are zero."""
        return (self.arity, self.role, self.morphism) == (0, 0, 0)

    @property
    def is_unary(self) -> bool:
        return self.arity == 0

    @property
    def is_binary(self) -> bool:
        return self.arity == 1

    @property
    def is_static(self) -> bool:
        return self.morphism == 0

    @property
    def is_dynamic(self) -> bool:
        return self.morphism == 1

    def __repr__(self) -> str:
        arity_label = "unary" if self.arity == 0 else "binary"
        role_label = "caller" if self.role == 0 else "callee"
        kind_label = "static" if self.morphism == 0 else "dynamic"
        labels = ", ".join((arity_label, role_label, kind_label))
        return "MorphismEncoding(" + labels + ")"
304
+
305
+
306
class EmbeddingFunctor:
    """Embedding functor F_embed: EntityCategory → Vect.

    Objects map to rows of an embedding weight matrix, indexed by entity hash.
    A morphism A → B maps to the difference F(B) - F(A).
    """

    def __init__(self, embedding_weight: torch.Tensor):
        self._weight = embedding_weight

    def map_object(self, obj: EntityObject) -> torch.Tensor:
        """Return row ``obj.hash`` of the weight matrix."""
        row = obj.hash
        return self._weight[row]

    def map_morphism(self, f: EntityMorphism) -> torch.Tensor:
        """Return F(target) - F(source) for the morphism ``f``."""
        head = self.map_object(f.target)
        tail = self.map_object(f.source)
        return head - tail

    def map_objects_batch(self, hashes: torch.Tensor) -> torch.Tensor:
        """Index the weight matrix with a tensor of hashes."""
        table = self._weight
        return table[hashes]

    def __call__(self, obj: EntityObject) -> torch.Tensor:
        """Alias for ``map_object``."""
        return self.map_object(obj)
329
+
330
+
331
class TypeClassifierFunctor:
    """Type-classifier map from entity embeddings to type distributions.

    Applies an affine type head (``embedding @ W.T + b``) and then a softmax
    over the last axis.  The number of types equals the number of rows of
    ``W``.  The input is an embedding vector, not an EntityObject; compose it
    with EmbeddingFunctor to start from an entity.
    """

    def __init__(self, type_head_weight: torch.Tensor, type_head_bias: torch.Tensor):
        self._w = type_head_weight
        self._b = type_head_bias

    def _logits(self, embedding: torch.Tensor) -> torch.Tensor:
        """Raw type-head scores: embedding @ W^T + b."""
        return embedding @ self._w.T + self._b

    def map_object(self, embedding: torch.Tensor) -> torch.Tensor:
        """Softmax type distribution for ``embedding`` (over the last axis)."""
        scores = self._logits(embedding)
        return torch.softmax(scores, dim=-1)

    def predicted_type(self, embedding: torch.Tensor) -> int:
        """Index of the highest-scoring type.  Expects a single embedding vector."""
        best = self._logits(embedding).argmax(dim=-1)
        return int(best.item())
349
+
350
+
351
class Subcategory:
    """A full subcategory of a parent EntityCategory, selected by field filters.

    An object belongs when each non-None filter contains the object's
    corresponding field: ``type_ids`` for its type id, ``class_ids`` for its
    class id, ``scope_ids`` for its scope id.  A filter left as ``None`` does
    not constrain that field.  An explicitly empty list admits no values.
    Morphisms between member objects are inherited from the parent.
    """

    def __init__(
        self,
        parent: EntityCategory,
        type_ids: Optional[list[int]] = None,
        class_ids: Optional[list[int]] = None,
        scope_ids: Optional[list[int]] = None,
    ):
        self.parent = parent
        # `is None` rather than truthiness: an empty list is a real constraint
        # that matches nothing, not "unconstrained".
        self.type_ids = None if type_ids is None else set(type_ids)
        self.class_ids = None if class_ids is None else set(class_ids)
        self.scope_ids = None if scope_ids is None else set(scope_ids)

    def contains(self, obj: EntityObject) -> bool:
        """Whether ``obj`` passes every active filter (fields read lazily)."""
        if self.type_ids is not None and obj.type_id not in self.type_ids:
            return False
        if self.class_ids is not None and obj.class_id not in self.class_ids:
            return False
        if self.scope_ids is not None and obj.scope_id not in self.scope_ids:
            return False
        return True

    def objects(self) -> Iterator[EntityObject]:
        """Yield every member object by scanning all of the parent's hashes."""
        for h in self.parent.all_objects():
            obj = EntityObject(h)
            if self.contains(obj):
                yield obj

    def identity(self, obj: EntityObject) -> Optional[EntityMorphism]:
        """The parent's identity on ``obj``, or None if ``obj`` is not a member."""
        if not self.contains(obj):
            return None
        return self.parent.identity(obj)

    def hom(
        self, source: EntityObject, target: EntityObject
    ) -> Optional[EntityMorphism]:
        """The parent's hom(source, target), or None if either end is not a member."""
        if not self.contains(source) or not self.contains(target):
            return None
        return self.parent.hom(source, target)

    def __repr__(self) -> str:
        constraints = []
        if self.type_ids is not None:
            constraints.append(f"types={sorted(self.type_ids)}")
        if self.class_ids is not None:
            constraints.append(f"classes={sorted(self.class_ids)}")
        if self.scope_ids is not None:
            constraints.append(f"scopes={sorted(self.scope_ids)}")
        return f"Subcategory({', '.join(constraints)})"
+ return f"Subcategory({', '.join(constraints)})"