Source code for qmdesc.nn_utils

import torch


def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    Selects the message features from source corresponding to the atom or bond indices in :code:`index`.

    Each entry of :code:`index` selects one row (along dim 0) of :code:`source`. Any trailing
    dimensions of :code:`source` are carried through unchanged, so the result has shape
    :code:`index.shape + source.shape[1:]`.

    :param source: A tensor of shape :code:`(num_bonds, hidden_size)` containing message features.
    :param index: A tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds)` containing the atom or bond
        indices to select from :code:`source`. The tensor does not need to be contiguous; sliced or
        transposed index tensors are accepted.
    :return: A tensor of shape :code:`(num_atoms/num_bonds, max_num_bonds, hidden_size)` containing the message
        features corresponding to the atoms/bonds specified in index.
    """
    index_size = index.size()  # (num_atoms/num_bonds, max_num_bonds)
    suffix_dim = source.size()[1:]  # (hidden_size,)
    final_size = index_size + suffix_dim  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    # reshape rather than view: view(-1) raises on a non-contiguous index tensor, while
    # reshape returns a view when possible and copies only when it has to.
    flat_index = index.reshape(-1)  # (num_atoms/num_bonds * max_num_bonds,)
    target = source.index_select(dim=0, index=flat_index)  # (num_atoms/num_bonds * max_num_bonds, hidden_size)
    target = target.view(final_size)  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    return target