"""Source code for qmdesc.handler.

This module defines the ReactivityDescriptorHandler for use in TorchServe.
"""
from typing import Dict

import pkg_resources
import torch
from qmdesc.featurization import mol2graph, get_atom_fdim, get_bond_fdim
from rdkit import Chem

class ReactivityDescriptorHandler():
    """Wrap the trained atom-bond QM descriptor prediction model.

    Predict QM descriptors for a given SMILES string of an organic compound
    containing C, H, O, N, P, S, F, Cl, Br, I, B.

    Example:
        >>> from qmdesc import ReactivityDescriptorHandler
        >>> handler = ReactivityDescriptorHandler()
        >>> results = handler.predict('CCCC')
    """

    def __init__(self):
        """
        ReactivityDescriptorHandler constructor.

        Loads the checkpoint shipped as a package resource, rebuilds the
        model from the stored ``args`` and ``state_dict`` entries and switches
        it to evaluation mode.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): the resource name was empty in the text this file was
        # recovered from; 'model.pt' is the presumed checkpoint name - confirm
        # against the packaged data files.
        model_pt_path = "model.pt"

        from qmdesc.model import MoleculeModel

        # Load model and args. The map_location callable returns the storage
        # unchanged, i.e. the tensors stay on the CPU while loading.
        stream = pkg_resources.resource_stream(__name__, model_pt_path)
        try:
            state = torch.load(stream, lambda storage, loc: storage)
        finally:
            # Release the resource handle even when loading fails.
            stream.close()
        args, loaded_state_dict = state['args'], state['state_dict']

        atom_fdim = get_atom_fdim()
        bond_fdim = get_bond_fdim() + atom_fdim

        # NOTE(review): the model is not moved to self.device here, while
        # _preprocess moves its tensors there - confirm that MoleculeModel
        # takes care of device placement itself.
        self.model = MoleculeModel(args, atom_fdim, bond_fdim)
        self.model.load_state_dict(loaded_state_dict)
        self.model.eval()

        # The misspelled attribute is kept for existing callers; the correctly
        # spelled alias is provided alongside it.
        self.initalized = True
        self.initialized = True

    def _preprocess(self, smiles: str):
        """
        Preprocess SMILES

        :param smiles: SMILES input forwarded unchanged to ``mol2graph``
            (``predict`` passes a one-element list)
        :return: components of the molecular graph
        """
        mol_graph = mol2graph(smiles)
        (f_atoms, f_bonds, a2b, b2a, b2revb,
         a_scope, b_scope, b2br, bond_types) = mol_graph.get_components()

        # NOTE(review): the right-hand side of this assignment was stripped
        # from the recovered text; moving the seven components to self.device
        # is the presumed original behaviour - confirm.
        f_atoms, f_bonds, a2b, b2a, b2revb, b2br, bond_types = (
            f_atoms.to(self.device),
            f_bonds.to(self.device),
            a2b.to(self.device),
            b2a.to(self.device),
            b2revb.to(self.device),
            b2br.to(self.device),
            bond_types.to(self.device),
        )

        return f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, b2br, bond_types

    def _inference(self, data):
        """
        model prediction

        :param data: molecular graph
        :return: The output of the model
        """
        descs = self.model(data)
        return descs

    def _postprocess(self, inference_output) -> Dict:
        """
        Postprocess results

        :param inference_output: dictionary with the keys ``'smiles'`` and
            ``'descs'``; ``'descs'`` must unpack into six entries
        :return: dictionary holding the SMILES and the flattened descriptors
        """
        smiles = inference_output['smiles']
        descs = inference_output['descs']

        # NOTE(review): the conversion expression was stripped from the
        # recovered text; turning each output tensor into a NumPy array is the
        # presumed original behaviour - confirm.
        descs = [x.detach().cpu().numpy() for x in descs]

        partial_charge, partial_neu, partial_elec, NMR, bond_order, bond_distance = descs

        results = {
            'smiles': smiles,
            'partial_charge': partial_charge.flatten(),
            'fukui_neu': partial_neu.flatten(),
            'fukui_elec': partial_elec.flatten(),
            'NMR': NMR.flatten(),
            'bond_order': bond_order.flatten(),
            'bond_length': bond_distance.flatten(),
        }
        return results

    @staticmethod
    def _molecule_name(sdf: str) -> str:
        """Return the sdf file name without its trailing '.sdf' extension."""
        suffix = '.sdf'
        # str.strip('.sdf') would remove any of the characters '.', 's', 'd',
        # 'f' from both ends (e.g. 'sample.sdf' -> 'ample'), so slice instead.
        return sdf[:-len(suffix)] if sdf.endswith(suffix) else sdf

    def predict(self, smiles: str, sdf: str = None) -> Dict:
        """
        Wrap the preprocess, inference, and postprocess

        :param smiles: Input SMILES string
        :param sdf: Output .sdf file; nothing is written when it is ``None``
            or does not end with ``'.sdf'``
        :return: A dictionary containing the prediction result
        """
        outputs = self._inference(self._preprocess([smiles]))
        postprocess_inputs = {'smiles': smiles, 'descs': outputs}
        results = self._postprocess(postprocess_inputs)

        if sdf is not None:
            if not sdf.endswith('.sdf'):
                # Best effort: report the problem but still hand back the results.
                print('must provide a sdf name end up with \'.sdf\'')
                return results

            m = Chem.MolFromSmiles(smiles)
            m = Chem.AddHs(m)

            # Store every result as a molecule property under its upper-cased key.
            for p in results:
                p_upper = p.upper()
                if p == 'smiles':
                    m.SetProp(p_upper, results[p])
                else:
                    m.SetProp(p_upper, ','.join(str(x) for x in results[p]))

            m.SetProp('_Name', self._molecule_name(sdf))

            writer = Chem.SDWriter(sdf)
            try:
                writer.write(m)
            finally:
                # Close explicitly so the record is flushed to disk.
                writer.close()

        return results
def qmdesc() -> None:
    """
    This is the entry point for the command line command ``qmdesc``.

    Parses the positional SMILES string and the optional ``--sdf`` output
    path (default ``qmdesc.sdf``) from the command line, builds a
    ReactivityDescriptorHandler and runs the prediction.

    Example:
        $ qmdesc CCCC --sdf CCCC.sdf
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('smiles', type=str, help='Input smiles string')
    parser.add_argument('--sdf', default='qmdesc.sdf', type=str,
                        help='output sdf saving the qm descriptors')
    args = parser.parse_args()

    predictor = ReactivityDescriptorHandler()
    # The returned dictionary is not used here; the call is made with the
    # sdf path so that predict() handles the file output.
    predictor.predict(args.smiles, sdf=args.sdf)