import torch
def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    Selects the message features from source corresponding to the atom or bond indices in :code:`index`.

    :param source: A tensor of shape :code:`(num_bonds, hidden_size)` containing message features.
                   More generally, any tensor whose first dimension is indexed; every dimension after
                   the first is carried through to the output unchanged.
    :param index: An integer tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds)` containing the
                  atom or bond indices to select from :code:`source`. It does not need to be contiguous
                  (e.g. a transposed or sliced tensor is accepted).
    :return: A tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds, hidden_size)` containing the message
             features corresponding to the atoms/bonds specified in index.
    """
    index_size = index.size()  # (num_atoms/num_bonds, max_num_bonds)
    suffix_dim = source.size()[1:]  # (hidden_size,)
    final_size = index_size + suffix_dim  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    # reshape rather than view: .view(-1) raises a RuntimeError when `index` is non-contiguous,
    # while reshape returns a view when possible and copies only when it must.
    flat_index = index.reshape(-1)  # (num_atoms/num_bonds * max_num_bonds,)
    target = source.index_select(dim=0, index=flat_index)  # (num_atoms/num_bonds * max_num_bonds, hidden_size)

    # index_select allocates a new tensor, so a view back to the grouped shape is safe here.
    target = target.view(final_size)  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    return target