diff --git a/analysis/gnina.py b/analysis/gnina.py new file mode 100644 index 000000000..65cbacbfd --- /dev/null +++ b/analysis/gnina.py @@ -0,0 +1,49 @@ +import os +from rdkit import Chem + +def parse_gnina_log(filename, multiple=False): + if not multiple: + d={'Affinity':[],'RMSD':[],'CNNscore':[],'CNNaffinity':[],'CNNvariance':[]} + with open(filename,'r') as f: + for line in f: + for key in d: + if line[:len(key)]==key: + d[key].append(float(line.strip().split(' (kcal/mol)')[0].split(' ')[-1])) + return d + else: + d={'affinity':[],'intramol':[],'CNNscore':[],'CNNaffinity':[]} + with open(filename,'r') as f: + for line in f: + if line[:5]==' 1': + for i, key in enumerate(d): + d[key].append(float(line[5+i*13:18+i*13])) + return d + +class Gnina: + def __init__(self, pdb_file): + self.gnina='/home/domain/data/prog/micromamba/envs/drugflow/bin/gnina' + self.tmp_ligand='/tmp/tmp_gnina.sdf' + self.pdb_file=pdb_file + + def calculate_metrics(self,rdmol): + + writer = Chem.SDWriter(self.tmp_ligand) + writer.write(mol=rdmol) + writer.close() + + cmd=f'{self.gnina} -r {self.pdb_file} -l {self.tmp_ligand} --minimize' + output = os.popen(cmd, 'r') + d={'Affinity':[],'RMSD':[],'CNNscore':[],'CNNaffinity':[],'CNNvariance':[]} + for line in output: + for key in d: + if line.startswith(key): + d[key].append(float(line.strip().split(' ')[1])) + return d + + def affinity(self,rdmol): + d=self.calculate_metrics(rdmol) + return -d['Affinity'][0] + + def CNNaffinity(self,rdmol): + d=self.calculate_metrics(rdmol) + return d['CNNaffinity'][0] diff --git a/inpaint.py b/inpaint.py index 8f234bc54..995ff1260 100644 --- a/inpaint.py +++ b/inpaint.py @@ -172,6 +172,7 @@ def inpaint_ligand(model, pdb_file, n_samples, ligand, fix_atoms, # Build mol objects x = xh_lig[:, :model.x_dims].detach().cpu() atom_type = xh_lig[:, model.x_dims:].argmax(1).detach().cpu() + lig_mask=lig_mask.detach().cpu() molecules = [] for mol_pc in zip(utils.batch_to_list(x, lig_mask), diff --git a/optimize.py b/optimize.py index aaa9ec717..0652f118f 100644 --- a/optimize.py +++ b/optimize.py @@ -130,6 +130,7 @@ def diversify_ligands(model, pocket, mols, timesteps, # Build mol objects x = out_lig[:, :model.x_dims].detach().cpu() atom_type = out_lig[:, model.x_dims:].argmax(1).detach().cpu() + lig_mask=lig_mask.detach().cpu() molecules = [] for mol_pc in zip(utils.batch_to_list(x, lig_mask), @@ -153,7 +154,7 @@ def diversify_ligands(model, pocket, mols, timesteps, parser.add_argument('--checkpoint', type=Path, default='checkpoints/crossdocked_fullatom_cond.ckpt') parser.add_argument('--pdbfile', type=str, default='example/5ndu.pdb') parser.add_argument('--ref_ligand', type=str, default='example/5ndu_linked_mols.sdf') - parser.add_argument('--objective', type=str, default='sa', choices={'qed', 'sa'}) + parser.add_argument('--objective', type=str, default='sa', choices={'qed', 'sa','gnina'}) parser.add_argument('--timesteps', type=int, default=100) parser.add_argument('--population_size', type=int, default=100) parser.add_argument('--evolution_steps', type=int, default=10) @@ -188,18 +189,22 @@ def diversify_ligands(model, pocket, mols, timesteps, objective_function = MoleculeProperties().calculate_qed elif args.objective == 'sa': objective_function = MoleculeProperties().calculate_sa + elif args.objective == 'gnina': + from analysis.gnina import Gnina + objective_function = Gnina(args.pdbfile).affinity else: ### IMPLEMENT YOUR OWN OBJECTIVE ### FUNCTIONS HERE raise ValueError(f"Objective function {args.objective} not recognized.") - ref_mol = Chem.SDMolSupplier(args.ref_ligand)[0] + ref_mols = Chem.SDMolSupplier(args.ref_ligand) # Store molecules in history dataframe buffer = pd.DataFrame(columns=['generation', 'score', 'fate' 'mol', 'smiles']) # Population initialization - buffer = buffer.append({'generation': 0, + for ref_mol in ref_mols: + buffer=buffer._append({'generation': 0, 'score': objective_function(ref_mol), 'fate': 'initial', 'mol': ref_mol, 'smiles': Chem.MolToSmiles(ref_mol)}, ignore_index=True) @@ -207,12 +212,15 @@ def diversify_ligands(model, pocket, mols, timesteps, for generation_idx in range(evolution_steps): if generation_idx == 0: - molecules = buffer['mol'].tolist() * population_size + top_k_molecules=buffer.sort_values(by='score', ascending=False)['mol'].tolist()[:population_size] + molecules = top_k_molecules * ((population_size-1)//len(top_k_molecules)+1) + molecules=molecules[:population_size] else: # Select top k molecules from previous generation previous_gen = buffer[buffer['generation'] == generation_idx] top_k_molecules = previous_gen.nlargest(top_k, 'score')['mol'].tolist() - molecules = top_k_molecules * (population_size // top_k) + molecules = top_k_molecules * ((population_size-1)//top_k+1) + molecules=molecules[:population_size] # Update the fate of selected top k molecules in the buffer buffer.loc[buffer['generation'] == generation_idx, 'fate'] = 'survived' @@ -235,7 +243,7 @@ def diversify_ligands(model, pocket, mols, timesteps, # Evaluate and save molecules for mol in molecules: - buffer = buffer.append({'generation': generation_idx + 1, + buffer = buffer._append({'generation': generation_idx + 1, 'score': objective_function(mol), 'fate': 'purged', 'mol': mol,