diff --git a/src/segger/models/ist_encoder.py b/src/segger/models/ist_encoder.py index 04d4ef9..2b73c3e 100644 --- a/src/segger/models/ist_encoder.py +++ b/src/segger/models/ist_encoder.py @@ -59,10 +59,16 @@ def forward( pos: torch.Tensor, batch: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if pos.numel() == 0: + return pos.new_zeros((pos.shape[0], self.dim * 2)) + if batch is None: pos = pos - pos.min(dim=0).values pos = pos / pos.max(dim=0).values else: + if batch.numel() == 0: + return pos.new_zeros((pos.shape[0], self.dim * 2)) + batch = batch.to(torch.long) # normalize per batch mins = torch.zeros((batch.max()+1, 2), device=pos.device) maxs = torch.zeros((batch.max()+1, 2), device=pos.device) diff --git a/src/segger/models/lightning_model.py b/src/segger/models/lightning_model.py index 3578a16..48c3657 100644 --- a/src/segger/models/lightning_model.py +++ b/src/segger/models/lightning_model.py @@ -283,7 +283,7 @@ def predict_step( dim_size=batch['tx'].num_nodes, ) # Filter by similarity - valid = max_idx < dst.shape[0] + valid = (max_idx >= 0) & (max_idx < dst.shape[0]) if min_similarity is not None: valid &= max_sim >= min_similarity