diff --git a/apps/drug_drug_synergy/RGCN/train.py b/apps/drug_drug_synergy/RGCN/train.py index 480fcc6e..8edc1ab0 100644 --- a/apps/drug_drug_synergy/RGCN/train.py +++ b/apps/drug_drug_synergy/RGCN/train.py @@ -87,7 +87,7 @@ def train(num_subgraph, graph, label_idx, epochs, sub_neighbours=[10, 10], init= fpr, tpr, _ = roc_curve(y_true=ground_truth, y_score=pred_prob) auc_v = auc(fpr, tpr) print("sub_graph index : {} | epoch: {} | training loss: {:.4f} | AUC: {:.3f}".format( - sub_g, epoch, train_loss.numpy()[0], auc_v)) + sub_g, epoch, float(train_loss), auc_v)) return model diff --git a/apps/drug_target_interaction/batchdta/pairwise/DeepDTA/utils.py b/apps/drug_target_interaction/batchdta/pairwise/DeepDTA/utils.py index aee31d42..c1b2fdd6 100644 --- a/apps/drug_target_interaction/batchdta/pairwise/DeepDTA/utils.py +++ b/apps/drug_target_interaction/batchdta/pairwise/DeepDTA/utils.py @@ -312,7 +312,7 @@ def model_eval(model,val_dataloader): for i_target_score in range(batch_smiles.shape[0]): - i_target_len = int(batch_len[i_target_score].numpy()[0]) + i_target_len = int(batch_len[i_target_score]) smiles = batch_smiles[i_target_score][0:i_target_len] target = batch_protein[i_target_score][0:i_target_len] y_label = batch_y[i_target_score][0:i_target_len].numpy() diff --git a/apps/drug_target_interaction/batchdta/pairwise/GraphDTA/run_pairwise_GraphDTA_CV.py b/apps/drug_target_interaction/batchdta/pairwise/GraphDTA/run_pairwise_GraphDTA_CV.py index e1c2f2a3..374532a1 100644 --- a/apps/drug_target_interaction/batchdta/pairwise/GraphDTA/run_pairwise_GraphDTA_CV.py +++ b/apps/drug_target_interaction/batchdta/pairwise/GraphDTA/run_pairwise_GraphDTA_CV.py @@ -195,9 +195,9 @@ def model_eval(model,val_dataloader,device): i_data = i_data.to(device) pred_scores = model.forward_single(i_data) # get the predicted labels - i_target_pred_scores.append(pred_scores.cpu().numpy()[0]) + i_target_pred_scores.append(float(pred_scores)) # get the true labels - i_target_y_label.append(i_data.y.cpu().numpy()[0]) + i_target_y_label.append(float(i_data.y.cpu())) i_target_pred_scores = np.array(i_target_pred_scores) i_target_y_label = np.array(i_target_y_label) diff --git a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/helper/utils/paddle_tensor.py b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/helper/utils/paddle_tensor.py index c0a9aa51..8c84c0f0 100644 --- a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/helper/utils/paddle_tensor.py +++ b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/helper/utils/paddle_tensor.py @@ -32,7 +32,7 @@ def item(self): """ Item function """ - return self.numpy()[0] + return float(self) @add_tensor_function diff --git a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_CV.py b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_CV.py index 127c091a..ea7fc11b 100644 --- a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_CV.py +++ b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_CV.py @@ -297,7 +297,7 @@ def model_eval(model,val_dataloader,len_SMILES,len_target): for i_target_score in range(batch_x.shape[0]): - i_target_len = int(batch_len[i_target_score].numpy()[0]) + i_target_len = int(batch_len[i_target_score]) smiles = batch_x_smiles[i_target_score][0:i_target_len] target = batch_x_protein[i_target_score][0:i_target_len] smiles_mask = batch_x_smiles_mask[i_target_score][0:i_target_len] diff --git 
a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_bindingDB.py b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_bindingDB.py index 7825352d..91ef20db 100644 --- a/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_bindingDB.py +++ b/apps/drug_target_interaction/batchdta/pairwise/Moltrans/run_pairwise_Moltrans_bindingDB.py @@ -282,7 +282,7 @@ def model_eval(model,val_dataloader,len_SMILES,len_target): for i_target_score in range(batch_x.shape[0]): - i_target_len = int(batch_len[i_target_score].numpy()[0]) + i_target_len = int(batch_len[i_target_score]) smiles = batch_x_smiles[i_target_score][0:i_target_len] target = batch_x_protein[i_target_score][0:i_target_len] smiles_mask = batch_x_smiles_mask[i_target_score][0:i_target_len] diff --git a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_bindingdb.py b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_bindingdb.py index 8eec40fe..fbe85462 100644 --- a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_bindingdb.py +++ b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_bindingdb.py @@ -60,7 +60,7 @@ def training(model, training_loader, optim): optim.clear_grad() loss.backward() optim.step() - res_loss = loss.numpy()[0] + res_loss = float(loss) return res_loss diff --git a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_davis.py b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_davis.py index fa59c086..dd086cfe 100644 --- a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_davis.py +++ b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_davis.py @@ -60,7 +60,7 @@ def training(model, training_loader, optim): optim.clear_grad() loss.backward() optim.step() - res_loss = loss.numpy()[0] + res_loss = float(loss) return res_loss diff --git a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_kiba.py b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_kiba.py index e57a5d1c..28864824 100644 --- a/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_kiba.py +++ b/apps/drug_target_interaction/batchdta/pointwise/DeepDTA/train_kiba.py @@ -63,7 +63,7 @@ def training(model, training_loader, optim): optim.clear_grad() loss.backward() optim.step() - res_loss = loss.numpy()[0] + res_loss = float(loss.numpy()) return res_loss diff --git a/apps/drug_target_interaction/batchdta/pointwise/Moltrans/helper/utils/paddle_tensor.py b/apps/drug_target_interaction/batchdta/pointwise/Moltrans/helper/utils/paddle_tensor.py index c0a9aa51..8c84c0f0 100644 --- a/apps/drug_target_interaction/batchdta/pointwise/Moltrans/helper/utils/paddle_tensor.py +++ b/apps/drug_target_interaction/batchdta/pointwise/Moltrans/helper/utils/paddle_tensor.py @@ -32,7 +32,7 @@ def item(self): """ Item function """ - return self.numpy()[0] + return float(self) @add_tensor_function diff --git a/apps/drug_target_interaction/moltrans_dti/helper/utils/paddle_tensor.py b/apps/drug_target_interaction/moltrans_dti/helper/utils/paddle_tensor.py index c0a9aa51..45ecbeac 100644 --- a/apps/drug_target_interaction/moltrans_dti/helper/utils/paddle_tensor.py +++ b/apps/drug_target_interaction/moltrans_dti/helper/utils/paddle_tensor.py @@ -32,7 +32,7 @@ def item(self): """ Item function """ - return self.numpy()[0] + return float(self.numpy()) @add_tensor_function diff --git a/apps/fewshot_molecular_property/chem_lib/models/trainer.py 
b/apps/fewshot_molecular_property/chem_lib/models/trainer.py index 05997cf4..ba3681d3 100644 --- a/apps/fewshot_molecular_property/chem_lib/models/trainer.py +++ b/apps/fewshot_molecular_property/chem_lib/models/trainer.py @@ -294,7 +294,7 @@ def train_step(self): losses_eval.backward() self.optimizer.step() - print('Train Epoch:',self.train_epoch,', train update step:', k, ', loss_eval:', losses_eval.numpy()[0]) + print('Train Epoch:',self.train_epoch,', train update step:', k, ', loss_eval:', float(losses_eval)) return self.model.layers diff --git a/apps/molecular_generation/SD_VAE/train_zinc.py b/apps/molecular_generation/SD_VAE/train_zinc.py index 8908c178..97b9fb42 100755 --- a/apps/molecular_generation/SD_VAE/train_zinc.py +++ b/apps/molecular_generation/SD_VAE/train_zinc.py @@ -122,9 +122,9 @@ def _train_epoch(model, data_loader, epoch, kl_weight, optimizer=None): optimizer.clear_grad() # Log - kl_loss_values.append(kl_loss.numpy()[0]) - perplexity_loss_values.append(perplexity.numpy()[0]) - loss_values.append(loss.numpy()[0]) + kl_loss_values.append(float(kl_loss)) + perplexity_loss_values.append(float(perplexity)) + loss_values.append(float(loss)) lr = (optimizer.get_lr() if optimizer is not None else 0) diff --git a/apps/pretrained_compound/ChemRL/GEM-2/src/paddle_utils.py b/apps/pretrained_compound/ChemRL/GEM-2/src/paddle_utils.py index b4a24f7d..c851aa06 100644 --- a/apps/pretrained_compound/ChemRL/GEM-2/src/paddle_utils.py +++ b/apps/pretrained_compound/ChemRL/GEM-2/src/paddle_utils.py @@ -37,8 +37,8 @@ def dist_mean(array, distributed=False): n = len(array) x_sum = 0 if n == 0 else np.sum(array) if distributed: - n = dist_all_reduce(paddle.to_tensor(n, dtype='int64')).numpy()[0] - x_sum = dist_all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + n = int(dist_all_reduce(paddle.to_tensor(n, dtype='int64'))) + x_sum = float(dist_all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) x_mean = 0 if n == 0 else x_sum / n return x_mean @@ -47,14 +47,14 @@ def dist_sum(array, distributed=False): n = len(array) x_sum = 0 if n == 0 else np.sum(array) if distributed: - x_sum = dist_all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + x_sum = float(dist_all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) return x_sum def dist_length(array, distributed=False): n = len(array) if distributed: - n = dist_all_reduce(paddle.to_tensor(n, dtype='int64')).numpy()[0] + n = int(dist_all_reduce(paddle.to_tensor(n, dtype='int64'))) return n diff --git a/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py b/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py index fb9e878f..bfaaa9fa 100644 --- a/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py +++ b/apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py @@ -80,7 +80,7 @@ def get_train_steps_per_epoch(dataset_len, args): min_data_len = paddle.to_tensor(dataset_len) from paddle.distributed import ReduceOp dist.all_reduce(min_data_len, ReduceOp.MIN) - dataset_len = min_data_len.numpy()[0] + dataset_len = int(min_data_len) logging.info(f'min dataset len: {dataset_len}') return int(dataset_len / args.batch_size) - 5 diff --git a/apps/protein_folding/helixfold-single/helixfold_single_inference.py b/apps/protein_folding/helixfold-single/helixfold_single_inference.py index a789da14..8b368f04 100644 --- a/apps/protein_folding/helixfold-single/helixfold_single_inference.py +++ b/apps/protein_folding/helixfold-single/helixfold_single_inference.py @@ -113,7 +113,7 @@ def main(args): if __name__ == '__main__': parser = 
argparse.ArgumentParser() - parser.add_argument("--init_model", type=str, help='tape + af2 stacked model') + parser.add_argument("--init_model", type=str, help='path to pretrained model') parser.add_argument("--fasta_file", type=str, help='path to fasta file to be predicted') parser.add_argument("--output_dir", type=str, help='path to prediction outputs') args = parser.parse_args() diff --git a/apps/protein_folding/helixfold-single/tape/others/utils.py b/apps/protein_folding/helixfold-single/tape/others/utils.py index f3cd392a..21272a38 100644 --- a/apps/protein_folding/helixfold-single/tape/others/utils.py +++ b/apps/protein_folding/helixfold-single/tape/others/utils.py @@ -48,8 +48,8 @@ def dist_all_reduce(x, return_num=False, distributed=False): n = len(x) x_sum = 0 if n == 0 else np.sum(x) if distributed: - n = dist.all_reduce(paddle.to_tensor(n, dtype='int64')).numpy()[0] - x_sum = dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + n = int(dist.all_reduce(paddle.to_tensor(n, dtype='int64'))) + x_sum = float(dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) x_mean = 0 if n == 0 else x_sum / n if return_num: return x_mean, n @@ -62,8 +62,8 @@ def dist_mean(x, distributed=False): n = len(x) x_sum = 0 if n == 0 else np.sum(x) if distributed: - n = dist.all_reduce(paddle.to_tensor(n, dtype='int64')).numpy()[0] - x_sum = dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + n = int(dist.all_reduce(paddle.to_tensor(n, dtype='int64'))) + x_sum = float(dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) x_mean = 0 if n == 0 else x_sum / n return x_mean @@ -73,7 +73,7 @@ def dist_sum(x, distributed=False): n = len(x) x_sum = 0 if n == 0 else np.sum(x) if distributed: - x_sum = dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + x_sum = float(dist.all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) return x_sum @@ -81,7 +81,7 @@ def dist_length(x, distributed=False): """tbd""" n = len(x) if distributed: - n = dist.all_reduce(paddle.to_tensor(n, dtype='int64')).numpy()[0] + n = int(dist.all_reduce(paddle.to_tensor(n, dtype='int64'))) return n diff --git a/apps/protein_folding/helixfold/README_inference.md b/apps/protein_folding/helixfold/README_inference.md index ecb1b8b2..8cf49b60 100644 --- a/apps/protein_folding/helixfold/README_inference.md +++ b/apps/protein_folding/helixfold/README_inference.md @@ -6,7 +6,7 @@ Python dependencies available through `pip` is provided in `requirements.txt`. H We provide a script `setup_env` that setup a `conda` environment and installs all dependencies. You can change the name of the environment and CUDA version in `setup_env`. Run: ```bash -wget https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.2-cudnn8-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post112-cp37-cp37m-linux_x86_64.whl +wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/HelixFold/paddlepaddle_gpu-2.4.1-cp37-cp37m-linux_x86_64.whl sh setup_env conda activate helixfold # activate the conda environment ``` diff --git a/apps/protein_folding/helixfold/README_train.md b/apps/protein_folding/helixfold/README_train.md index 69574fc7..697024bb 100644 --- a/apps/protein_folding/helixfold/README_train.md +++ b/apps/protein_folding/helixfold/README_train.md @@ -14,7 +14,7 @@ To reproduce the results reported in our paper, specific environment settings ar ## Installation PaddlePaddle `dev` package is required to run HelixFold. Script `setup_env` is used to setup the `conda` environment, installing all dependencies. 
Locate to the directory of `helixfold` and run: ```bash -wget https://paddle-wheel.bj.bcebos.com/develop/linux/linux-gpu-cuda11.2-cudnn8-mkl-gcc8.2-avx/paddlepaddle_gpu-0.0.0.post112-cp37-cp37m-linux_x86_64.whl +wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/HelixFold/paddlepaddle_gpu-2.4.1-cp37-cp37m-linux_x86_64.whl sh setup_env conda activate helixfold # activate the conda environment ``` diff --git a/apps/protein_folding/helixfold/alphafold_paddle/data/pipeline.py b/apps/protein_folding/helixfold/alphafold_paddle/data/pipeline.py index b78968eb..33f978b0 100644 --- a/apps/protein_folding/helixfold/alphafold_paddle/data/pipeline.py +++ b/apps/protein_folding/helixfold/alphafold_paddle/data/pipeline.py @@ -157,6 +157,8 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: hhsearch_hits = parsers.parse_hhr(hhsearch_result) mgnify_msa = mgnify_msa[:self.mgnify_max_hits] mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits] + uniref90_msa = uniref90_msa[:self.uniref_max_hits] + uniref90_deletion_matrix = uniref90_deletion_matrix[:self.uniref_max_hits] if self._use_small_bfd: jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( diff --git a/apps/protein_folding/helixfold/alphafold_paddle/data/utils.py b/apps/protein_folding/helixfold/alphafold_paddle/data/utils.py index 4a6d8e1e..4754892e 100644 --- a/apps/protein_folding/helixfold/alphafold_paddle/data/utils.py +++ b/apps/protein_folding/helixfold/alphafold_paddle/data/utils.py @@ -114,13 +114,16 @@ def load_labels(cif_path: str, pdb_id: str, chain_id: str = 'A') -> FeatureDict: # keys that should be ignored when conducting crop & pad def is_ignored_key(k): + """tbd.""" return k in ignored_keys # keys that have batch dim, e.g. msa features which have shape [N_msa, N_res, ...] def is_batched_key(k): + """tbd.""" return k in batched_keys def align_feat(feat, size): + """Align feature.""" # get num res from aatype assert 'aatype' in feat.keys(), \ "'aatype' missing from batch, which is not expected." 
@@ -148,7 +151,32 @@ def pad(key, array, start_axis, align_size, num_res): return feat +def align_label(label, size): + """Align label.""" + num_res = label['all_atom_mask'].shape[1] + + if num_res % size != 0: + align_size = (num_res // size + 1) * size + + def pad(key, array, start_axis, align_size, num_res): + if is_ignored_key(key): + return array + d_seq = start_axis # choose the dim to crop / pad + if is_batched_key(key): + d_seq += 1 + pad_shape = list(array.shape) + pad_shape[d_seq] = align_size - num_res + pad_array = paddle.zeros(pad_shape, dtype=array.dtype) + array = paddle.concat([array, pad_array], axis=d_seq) + return array + + label = {k: pad(k, v, 1, align_size, num_res) for k, v in label.items()} + + return label + + def unpad_prediction(feat, pred): + """Unpad prediction.""" unpad_pred = deepcopy(pred) n = feat['aatype'].shape[0] diff --git a/apps/protein_folding/helixfold/gpu_infer.sh b/apps/protein_folding/helixfold/gpu_infer.sh index 2892d2a8..dd84b266 100644 --- a/apps/protein_folding/helixfold/gpu_infer.sh +++ b/apps/protein_folding/helixfold/gpu_infer.sh @@ -58,6 +58,7 @@ else --model_names=${MODELS} \ --output_dir=${OUTPUT_DIR} \ --disable_amber_relax \ + --seed 2022 \ --preset='reduced_dbs' \ --random_seed=0 \ ${@:2} diff --git a/apps/protein_folding/helixfold/requirements.txt b/apps/protein_folding/helixfold/requirements.txt index 803b2594..de103ebc 100644 --- a/apps/protein_folding/helixfold/requirements.txt +++ b/apps/protein_folding/helixfold/requirements.txt @@ -13,4 +13,4 @@ scipy==1.7.0 tensorflow-cpu==2.5.0 tensorboardX==2.5 etcd3 -./paddlepaddle_gpu-0.0.0.post112-cp37-cp37m-linux_x86_64.whl \ No newline at end of file +./paddlepaddle_gpu-2.4.1-cp37-cp37m-linux_x86_64.whl \ No newline at end of file diff --git a/apps/protein_folding/helixfold/train.py b/apps/protein_folding/helixfold/train.py index fd80f336..09d1afa3 100644 --- a/apps/protein_folding/helixfold/train.py +++ b/apps/protein_folding/helixfold/train.py @@ -39,7 +39,7 @@ from utils.init_env import init_seed, init_distributed_env from utils.misc import TrainLogger, set_logging_level from alphafold_paddle.model import config -from alphafold_paddle.data.utils import align_feat +from alphafold_paddle.data.utils import align_feat, align_label from ppfleetx.distributed.protein_folding import dap, bp, dp from ppfleetx.distributed.protein_folding.scg import scg @@ -164,6 +164,7 @@ def eval(args, model, eval_dataset, compute_loss, cache_dir=None): s1 = time_me() if args.dap_degree > 1: batch['feat'] = align_feat(batch['feat'], args.dap_degree) + batch['label'] = align_label(batch['label'], args.dap_degree) res = model(batch, compute_loss=compute_loss) if compute_loss: diff --git a/apps/protein_folding/helixfold/utils/metric.py b/apps/protein_folding/helixfold/utils/metric.py index 0707ecdf..bd9bf0b4 100644 --- a/apps/protein_folding/helixfold/utils/metric.py +++ b/apps/protein_folding/helixfold/utils/metric.py @@ -30,8 +30,8 @@ def dist_all_reduce(x, return_num=False, distributed=False): x_num = len(x) x_sum = 0 if x_num == 0 else np.sum(x) if distributed: - x_num = dp.all_reduce(paddle.to_tensor(x_num, dtype='int64')).numpy()[0] - x_sum = dp.all_reduce(paddle.to_tensor(x_sum, dtype='float32')).numpy()[0] + x_num = int(dp.all_reduce(paddle.to_tensor(x_num, dtype='int64'))) + x_sum = float(dp.all_reduce(paddle.to_tensor(x_sum, dtype='float32'))) x_mean = 0 if x_num == 0 else x_sum / x_num if return_num: return x_mean, x_num diff --git a/apps/protein_folding/helixfold_cpu/.github/BP_DAP_DP.png 
b/apps/protein_folding/helixfold_cpu/.github/BP_DAP_DP.png new file mode 100644 index 00000000..d1512b1b Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/BP_DAP_DP.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/HelixFold_accuracy.png b/apps/protein_folding/helixfold_cpu/.github/HelixFold_accuracy.png new file mode 100644 index 00000000..eebbaac0 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/HelixFold_accuracy.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/HelixFold_computational_performance.png b/apps/protein_folding/helixfold_cpu/.github/HelixFold_computational_performance.png new file mode 100644 index 00000000..4ef748d5 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/HelixFold_computational_performance.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf.png b/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf.png new file mode 100644 index 00000000..0b418b09 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf_compare.png b/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf_compare.png new file mode 100644 index 00000000..0ef0007e Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/HelixFold_perf_compare.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/LIT-PCBA_result.png b/apps/protein_folding/helixfold_cpu/.github/LIT-PCBA_result.png new file mode 100644 index 00000000..2fbd4ee4 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/LIT-PCBA_result.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/LinearRNA.jpg b/apps/protein_folding/helixfold_cpu/.github/LinearRNA.jpg new file mode 100644 index 00000000..bab55403 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/LinearRNA.jpg differ diff --git a/apps/protein_folding/helixfold_cpu/.github/PaddleHelix_Structure.png b/apps/protein_folding/helixfold_cpu/.github/PaddleHelix_Structure.png new file mode 100644 index 00000000..df852f9b Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/PaddleHelix_Structure.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/helixfold_pipeline.png b/apps/protein_folding/helixfold_cpu/.github/helixfold_pipeline.png new file mode 100644 index 00000000..b3df6f0e Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/helixfold_pipeline.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/memory_optimize.png b/apps/protein_folding/helixfold_cpu/.github/memory_optimize.png new file mode 100644 index 00000000..14afb67a Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/memory_optimize.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/op_fuse.png b/apps/protein_folding/helixfold_cpu/.github/op_fuse.png new file mode 100644 index 00000000..b1e8db00 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/op_fuse.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/optimus_framework3.png b/apps/protein_folding/helixfold_cpu/.github/optimus_framework3.png new file mode 100644 index 00000000..ff232bdb Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/optimus_framework3.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/paddlehelix_features.jpg 
b/apps/protein_folding/helixfold_cpu/.github/paddlehelix_features.jpg new file mode 100644 index 00000000..9ff0d892 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/paddlehelix_features.jpg differ diff --git a/apps/protein_folding/helixfold_cpu/.github/paddlehelix_logo.png b/apps/protein_folding/helixfold_cpu/.github/paddlehelix_logo.png new file mode 100644 index 00000000..6c80c636 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/paddlehelix_logo.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/pcqm4mv2_result.png b/apps/protein_folding/helixfold_cpu/.github/pcqm4mv2_result.png new file mode 100644 index 00000000..e7f71282 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/pcqm4mv2_result.png differ diff --git a/apps/protein_folding/helixfold_cpu/.github/tensor_fuse.png b/apps/protein_folding/helixfold_cpu/.github/tensor_fuse.png new file mode 100644 index 00000000..f9707cf4 Binary files /dev/null and b/apps/protein_folding/helixfold_cpu/.github/tensor_fuse.png differ diff --git "a/apps/protein_folding/helixfold_cpu/.github/\351\243\236\346\241\250-\350\236\272\346\227\213\346\241\250_logo.png" "b/apps/protein_folding/helixfold_cpu/.github/\351\243\236\346\241\250-\350\236\272\346\227\213\346\241\250_logo.png" new file mode 100644 index 00000000..45e72c29 Binary files /dev/null and "b/apps/protein_folding/helixfold_cpu/.github/\351\243\236\346\241\250-\350\236\272\346\227\213\346\241\250_logo.png" differ diff --git a/apps/protein_folding/helixfold_cpu/.gitignore b/apps/protein_folding/helixfold_cpu/.gitignore new file mode 100644 index 00000000..215b6e2c --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/.gitignore @@ -0,0 +1,10 @@ +**/*pyc +**/__pycache__ +*/*/__pycache__ +*/*/*/__pycache__ +*/*/*/scripts +paddlecloud* +internal* +*/internal* +.DS_Store +*/.DS_Store diff --git a/apps/protein_folding/helixfold_cpu/confidence.py b/apps/protein_folding/helixfold_cpu/confidence.py new file mode 100644 index 00000000..e152be1e --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/confidence.py @@ -0,0 +1,155 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for processing confidence metrics.""" + +from typing import Dict, Optional, Tuple +import numpy as np +import scipy.special + + +def compute_plddt(logits: np.ndarray) -> np.ndarray: + """Computes per-residue pLDDT from logits. + + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + + Returns: + plddt: [num_res] per-residue pLDDT. + """ + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) + probs = scipy.special.softmax(logits, axis=-1) + predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) + return predicted_lddt_ca * 100 + + +def _calculate_bin_centers(breaks: np.ndarray): + """Gets the bin centers from the bin edges. 
+ + Args: + breaks: [num_bins - 1] the error bin edges. + + Returns: + bin_centers: [num_bins] the error bin centers. + """ + step = (breaks[1] - breaks[0]) + + # Add half-step to get the center + bin_centers = breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]], + axis=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: np.ndarray, + aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Calculates expected aligned distance errors for every pair of residues. + + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + + # Tuple of expected aligned distance error and max possible error. + return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1), + np.asarray(bin_centers[-1])) + + +def compute_predicted_aligned_error( + logits: np.ndarray, + breaks: np.ndarray) -> Dict[str, np.ndarray]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + aligned_confidence_probs = scipy.special.softmax( + logits, + axis=-1) + predicted_aligned_error, max_predicted_aligned_error = ( + _calculate_expected_aligned_error( + alignment_confidence_breaks=breaks, + aligned_distance_error_probs=aligned_confidence_probs)) + return { + 'aligned_confidence_probs': aligned_confidence_probs, + 'predicted_aligned_error': predicted_aligned_error, + 'max_predicted_aligned_error': max_predicted_aligned_error, + } + + +def predicted_tm_score( + logits: np.ndarray, + breaks: np.ndarray, + residue_weights: Optional[np.ndarray] = None) -> np.ndarray: + """Computes predicted TM alignment score. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + + Returns: + ptm_score: the predicted TM alignment score. + """ + + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the + # exp. resolved head's probability. + if residue_weights is None: + residue_weights = np.ones(logits.shape[0]) + + bin_centers = _calculate_bin_centers(breaks) + + num_res = np.sum(residue_weights) + # Clip num_res to avoid negative/undefined d0. + clipped_num_res = max(num_res, 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. 
(5) in + # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + # Yang & Skolnick "Scoring function for automated + # assessment of protein structure template quality" 2004 + d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 + + # Convert logits to probs + probs = scipy.special.softmax(logits, axis=-1) + + # TM-Score term for every bin + tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) + # E_distances tm(distance) + predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) + + normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum()) + per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) + return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) diff --git a/apps/protein_folding/helixfold_cpu/config.py b/apps/protein_folding/helixfold_cpu/config.py new file mode 100644 index 00000000..8bcf1dbf --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/config.py @@ -0,0 +1,463 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model config.""" + +import copy +import ml_collections + + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' + + +def model_config(name: str) -> ml_collections.ConfigDict: + """Get the ConfigDict of a CASP14 model.""" + if name not in CONFIG_DIFFS: + raise ValueError(f'Invalid model name {name}.') + cfg = copy.deepcopy(CONFIG) + cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) + return cfg + + +CONFIG_DIFFS = { + 'model_1': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 + 'data.common.max_extra_msa': 5120, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_2': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2 + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_3': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1 + 'data.common.max_extra_msa': 5120, + }, + 'model_4': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2 + 'data.common.max_extra_msa': 5120, + }, + 'model_5': { + # Jumper et al. (2021) Suppl. 
Table 5, Model 1.2.3 + 'model.global_config.subbatch_size': 32, + }, + 'initial_model_5_dcu': { + 'data.eval.max_msa_clusters': 128, + 'data.common.max_extra_msa': 512, + 'model.global_config.subbatch_size': 64, + 'model.heads.structure_module.structural_violation_loss_weight': 0.0, + 'model.heads.experimentally_resolved.weight': 0.0, + }, + 'initial': { + 'data.eval.max_msa_clusters': 128, + 'data.common.max_extra_msa': 1024, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.structure_module.structural_violation_loss_weight': 0.0, + 'model.heads.experimentally_resolved.weight': 0.0, + }, + 'finetune': { + 'data.eval.max_msa_clusters': 512, + 'data.common.max_extra_msa': 5120, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + }, + + 'msa128_ex1024_vio0': { + 'data.eval.max_msa_clusters': 128, + 'data.common.max_extra_msa': 1024, + 'model.heads.structure_module.structural_violation_loss_weight': 0.0, + 'model.heads.experimentally_resolved.weight': 0.0, + }, + 'msa128_ex1024_temp_vio0': { + 'data.eval.max_msa_clusters': 128, + 'data.common.max_extra_msa': 1024, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.structure_module.structural_violation_loss_weight': 0.0, + 'model.heads.experimentally_resolved.weight': 0.0, + }, + 'msa128_ex1024_ft': { + 'data.eval.max_msa_clusters': 128, + 'data.common.max_extra_msa': 1024, + 'model.heads.structure_module.structural_violation_loss_weight': 1.0, + 'model.heads.experimentally_resolved.weight': 0.01, + }, + 'msa512_ex1024': { + 'data.eval.max_msa_clusters': 512, + 'data.common.max_extra_msa': 1024, + }, + 'msa512_ex1024_rec6': { + 'data.common.num_recycle': 6, + 'model.num_recycle': 6, + 'data.eval.max_msa_clusters': 512, + 'data.common.max_extra_msa': 1024, + }, + + # The following models are fine-tuned from the corresponding models above + # with an additional predicted_aligned_error head that can produce + # predicted TM-score (pTM) and predicted aligned errors. 
+ 'model_1_ptm': { + 'data.common.max_extra_msa': 5120, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_2_ptm': { + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_3_ptm': { + 'data.common.max_extra_msa': 5120, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_4_ptm': { + 'data.common.max_extra_msa': 5120, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_5_ptm': { + 'model.heads.predicted_aligned_error.weight': 0.1 + } +} + +CONFIG = ml_collections.ConfigDict({ + 'data': { + 'common': { + 'masked_msa': { + 'profile_prob': 0.1, + 'same_prob': 0.1, + 'uniform_prob': 0.1 + }, + 'max_extra_msa': 1024, + 'msa_cluster_features': True, + 'num_recycle': 3, + 'reduce_msa_clusters_by_max_templates': False, + 'resample_msa_in_recycling': True, + 'template_features': [ + 'template_all_atom_positions', 'template_sum_probs', + 'template_aatype', 'template_all_atom_masks', + 'template_domain_names' + ], + 'unsupervised_features': [ + 'aatype', 'residue_index', 'sequence', 'msa', 'domain_name', + 'num_alignments', 'seq_length', 'between_segment_residues', + 'deletion_matrix' + ], + 'use_templates': False, + }, + 'eval': { + 'feat': { + 'aatype': [NUM_RES], + 'all_atom_mask': [NUM_RES, None], + 'all_atom_positions': [NUM_RES, None, None], + 'alt_chi_angles': [NUM_RES, None], + 'atom14_alt_gt_exists': [NUM_RES, None], + 'atom14_alt_gt_positions': [NUM_RES, None, None], + 'atom14_atom_exists': [NUM_RES, None], + 'atom14_atom_is_ambiguous': [NUM_RES, None], + 'atom14_gt_exists': [NUM_RES, None], + 'atom14_gt_positions': [NUM_RES, None, None], + 'atom37_atom_exists': [NUM_RES, None], + 'backbone_affine_mask': [NUM_RES], + 'backbone_affine_tensor': [NUM_RES, None], + 'bert_mask': [NUM_MSA_SEQ, NUM_RES], + 'chi_angles': [NUM_RES, None], + 'chi_mask': [NUM_RES, None], + 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_row_mask': [NUM_EXTRA_SEQ], + 'is_distillation': [], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'msa_row_mask': [NUM_MSA_SEQ], + 'pseudo_beta': [NUM_RES, None], + 'pseudo_beta_mask': [NUM_RES], + 'random_crop_to_size_seed': [None], + 'residue_index': [NUM_RES], + 'residx_atom14_to_atom37': [NUM_RES, None], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [NUM_RES, None, None], + 'rigidgroups_group_exists': [NUM_RES, None], + 'rigidgroups_group_is_ambiguous': [NUM_RES, None], + 'rigidgroups_gt_exists': [NUM_RES, None], + 'rigidgroups_gt_frames': [NUM_RES, None, None], + 'seq_length': [], + 'seq_mask': [NUM_RES], + 'target_feat': [NUM_RES, None], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [ + NUM_TEMPLATES, NUM_RES, None, None], + 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], + 'template_backbone_affine_tensor': [ + NUM_TEMPLATES, NUM_RES, 
None], + 'template_mask': [NUM_TEMPLATES], + 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], + 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], + 'template_sum_probs': [NUM_TEMPLATES, None], + 'true_msa': [NUM_MSA_SEQ, NUM_RES] + }, + 'fixed_size': True, + 'subsample_templates': False, # We want top templates. + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 512, + 'max_templates': 4, + 'num_ensemble': 1, + 'num_blocks': 5, # for msa block deletion + 'randomize_num_blocks': False, + 'msa_fraction_per_block': 0.3, + }, + }, + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'max_relative_feature': 32, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'seq_channel': 384, + 'template': { + 'attention': { + 'gating': False, + 'key_dim': 64, + 'num_head': 4, + 'value_dim': 64 + }, + 'dgram_features': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39 + }, + 'embed_torsion_angles': False, + 'enabled': False, + 'template_pair_stack': { + 'num_block': 2, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 
'max_templates': 4, + 'subbatch_size': 48, + 'use_template_unit_vector': False, + } + }, + 'global_config': { + 'deterministic': False, + 'subbatch_size': 48, + 'use_remat': False, + 'zero_init': True + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 'weight': 0.3 + }, + 'predicted_aligned_error': { + # `num_bins - 1` bins uniformly space the + # [0, max_error_bin A] range. + # The final bin covers [max_error_bin A, +infty] + # 31A gives bins with 0.5A width. + 'max_error_bin': 31., + 'num_bins': 64, + 'num_channels': 128, + 'filter_by_resolution': True, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'structure_module': { + 'num_layer': 8, + 'fape': { + 'clamp_distance': 10.0, + 'clamp_type': 'relu', + 'loss_unit_distance': 10.0 + }, + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'compute_in_graph_metrics': True, + 'dropout': 0.1, + 'num_channel': 384, + 'num_head': 12, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 10.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5, + 'length_scale': 10., + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'masked_msa': { + 'num_output': 23, + 'weight': 2.0 + }, + }, + 'num_recycle': 3, + 'resample_msa_in_recycling': True + }, +}) diff --git a/apps/protein_folding/helixfold_cpu/dap.py b/apps/protein_folding/helixfold_cpu/dap.py new file mode 100644 index 00000000..9b6a2601 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/dap.py @@ -0,0 +1,507 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Dynamic Axial Parallelism and Duality Async Operation helper functions +paper ref: FastFold: Reducing AlphaFold Training Time from 11 Days to 67 Hours, https://arxiv.org/abs/2203.00854 +code ref: https://github.com/hpcaitech/FastFold.git +""" + +import warnings +import paddle +from paddle import distributed as dist +from paddle.autograd import PyLayer + +__all__ = [ + 'init_dap', + 'dap_is_initialized', + 'get_tensor_model_parallel_group', + 'get_data_parallel_group', + 'get_tensor_model_parallel_world_size', + 'get_tensor_model_parallel_rank', + 'get_data_parallel_world_size', + 'get_data_parallel_rank', + 'get_tensor_model_parallel_src_rank', + 'scatter', + 'gather', + 'all_gather', + 'all_gather_opp', + 'all_to_all', + 'all_to_all_opp', + 'row_to_col', + 'col_to_row' + ] + +# Data parallel group that the current rank belongs to. 
+_DATA_PARALLEL_GROUP = None +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None + +# These values enable us to change the mpu sizes on the fly. +_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_TENSOR_MODEL_PARALLEL_RANK = None + +# communication whether use_calc_stream (sync) or not (async). Default True +_COMM_SYNC = None + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator) + + +def divide(numerator, denominator): + ensure_divisibility(numerator, denominator) + return numerator // denominator + +def init_dap(tensor_model_parallel_size_=1, sync=True): + + global _COMM_SYNC + assert _COMM_SYNC is None, \ + 'communication manner `sync` is already initialized' + _COMM_SYNC = sync + + world_size = dist.get_world_size() + rank = dist.get_rank() + + # check dist config + ensure_divisibility(world_size, tensor_model_parallel_size_) + data_parallel_size_ = world_size // tensor_model_parallel_size_ + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + for i in range(tensor_model_parallel_size_): + ranks = list(range(i, world_size, tensor_model_parallel_size_)) + group = dist.new_group(ranks) + print('> dp ranks:', ranks, 'dp group:', group) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ + 'tensor model parallel group is already initialized' + # Build the model-parallel groups. + for i in range(data_parallel_size_): + ranks = list(range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)) + group = dist.new_group(ranks) + print('> mp ranks:', ranks, 'mp group', group) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + if dist.get_rank() == 0: + print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_)) + print('> initialize data parallel with size {}'.format(data_parallel_size_)) + +def dap_is_initialized(): + """Check if model and data parallel groups are initialized.""" + global _DATA_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GROUP + if _TENSOR_MODEL_PARALLEL_GROUP is None or \ + _DATA_PARALLEL_GROUP is None: + return False + return True + +def is_sync(): + return _COMM_SYNC + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not initialized.") + return 1 + global _TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _TENSOR_MODEL_PARALLEL_WORLD_SIZE + return get_tensor_model_parallel_group().nranks + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not 
initialized.") + return 0 + global _TENSOR_MODEL_PARALLEL_RANK + if _TENSOR_MODEL_PARALLEL_RANK is not None: + return _TENSOR_MODEL_PARALLEL_RANK + return get_tensor_model_parallel_group().rank + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not initialized.") + return 1 + return get_data_parallel_group().nranks + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not initialized.") + return 0 + return get_data_parallel_group().rank + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +@paddle.no_grad() +def _gather(tensor, axis=-1): + tensor_list = [] + dist.all_gather(tensor_list, + tensor, + group=get_tensor_model_parallel_group(), + use_calc_stream=True) + output = paddle.concat(tensor_list, axis=axis) + return output + + +@paddle.no_grad() +def _split(tensor, axis=-1): + ensure_divisibility(tensor.shape[axis], get_tensor_model_parallel_world_size()) + tensor_list = paddle.split(tensor, get_tensor_model_parallel_world_size(), axis=axis) + + output = tensor_list[get_tensor_model_parallel_rank()] + + return output + + +class Scatter(PyLayer): + """ Scatter PyLayer Op""" + @staticmethod + def forward(ctx, input, axis:-1): + ctx.axis = axis + return _split(input, axis=axis) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, axis=ctx.axis) + + +def scatter(input, axis=-1): + """ split a tensor according axis by dap size """ + if get_tensor_model_parallel_world_size() == 1: + return input + + if not input.stop_gradient: + output = Scatter.apply(input, axis=axis) + else: + output = _split(input, axis=axis) + return output + + +class Gather(PyLayer): + """ Gather PyLayer Op """ + @staticmethod + def forward(ctx, input, axis=-1): + ctx.axis = axis + return _gather(input, axis=axis) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, axis=ctx.axis) + +def gather(input, axis=-1): + """ gather tensor form all rank in dap group in axis """ + if get_tensor_model_parallel_world_size() == 1: + return input + + if not input.stop_gradient: + output = Gather.apply(input, axis=axis) + else: + output = _gather(input, axis=axis) + return output + +@paddle.no_grad() +def _all_gather(tensor, axis=-1, sync=True): + if not sync: + dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True) + + group = get_tensor_model_parallel_group() + ring_id = group.id + nranks = group.nranks + output = paddle._C_ops.c_allgather(tensor, 'use_calc_stream', sync, 'ring_id', ring_id, 'nranks', nranks) + + return output + +@paddle.no_grad() +def _reduce_scatter(tensor, sync=True): + if not sync: + dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True) + + group = get_tensor_model_parallel_group() + ring_id = group.id + nranks = group.nranks + output = paddle._C_ops.c_reducescatter(tensor, 'use_calc_stream', sync, 'ring_id', ring_id, 'nranks', nranks) + paddle.device.cuda.synchronize() + return output + +class AllGather(PyLayer): + """ AllGather PyLayer Op """ + @staticmethod + def forward(ctx, input, axis=-1, sync=True): + ctx.axis = axis 
+ ctx.sync = sync + output = _all_gather(input, axis=axis, sync=sync) + return output + + @staticmethod + def backward(ctx, grad_output): + if not ctx.sync: + dist.wait(grad_output, group=get_tensor_model_parallel_group(), use_calc_stream=ctx.sync) + return grad_output + +class AllGather_Opp(PyLayer): + """ Duality Async Operation for AllGather """ + @staticmethod + def forward(ctx, input, axis=-1, sync=True): + ctx.axis = axis + ctx.sync = sync + return input + + @staticmethod + def backward(ctx, grad_output): + output = _reduce_scatter(grad_output, sync=ctx.sync) + return output + + +def all_gather(input, axis=-1, sync=None): + """ gather tensors from all rank in dap group and all get the result. + if sync=None, sync will be assign according init_dap setting. + + when using async communication, sync=False, do not use the output as same as input. + E.g. do not use `a = all_gather(a, ...)`, recommend to use `b = all_gather(a, ...)` + """ + if get_tensor_model_parallel_world_size() == 1: + return input + + if sync is None: + sync = is_sync() + + if not input.stop_gradient: + output = AllGather.apply(input, axis, sync=sync) + else: + output = _all_gather(input, axis, sync=sync) + return output + + +def all_gather_opp(output, axis=-1, sync=None): + """ Duality Async Operation for all_gather. + if sync=None, sync will be assign according init_dap setting. + """ + nranks = get_tensor_model_parallel_world_size() + if nranks == 1: + return output + + if sync is None: + sync = is_sync() + + if not sync: + dist.wait(output, group=get_tensor_model_parallel_group(), use_calc_stream=sync) + + if not output.stop_gradient: + output = AllGather_Opp.apply(output, axis, sync=sync) + + if axis != 0: + output = paddle.concat(paddle.split(output, nranks, 0), axis=axis) + + return output + + +@paddle.no_grad() +def _all_to_all(tensor, in_axis=-1, out_axis=-1, sync=True): + if not sync: + dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True) + + group = get_tensor_model_parallel_group() + ring_id = group.id + + output = paddle._C_ops.alltoall(tensor, 'use_calc_stream', sync, 'ring_id', ring_id) + + return output + + +class All_to_All(PyLayer): + """ All_to_All PyLayer Op""" + @staticmethod + def forward(ctx, + input, + in_axis=-1, + out_axis=-1, + sync=True): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + ctx.sync = sync + return _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=sync) + + @staticmethod + def backward(ctx, grad_output): + if not ctx.sync: + dist.wait(grad_output, group=get_tensor_model_parallel_group(), use_calc_stream=ctx.sync) + return grad_output + + +class All_to_All_Opp(PyLayer): + """ Duality Async Operation for All_to_All """ + @staticmethod + def forward(ctx, output, in_axis=-1, out_axis=-1, sync=True): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + ctx.sync = sync + return output + + @staticmethod + def backward(ctx, grad_output): + return _all_to_all(grad_output, in_axis=ctx.out_axis, out_axis=ctx.in_axis, sync=ctx.sync) + + +class All2All(PyLayer): + @staticmethod + def forward(ctx, + input, + in_axis=-1, + out_axis=-1): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + return _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=True) + + @staticmethod + def backward(ctx, grad_output): + return _all_to_all(grad_output, in_axis=ctx.out_axis, out_axis=ctx.in_axis, sync=True) + + +def all_to_all(input, in_axis, out_axis, sync=True): + """ all to all according in_axis and out_axis. 
+ if sync=None, sync will be assign according init_dap setting. + """ + if get_tensor_model_parallel_world_size() == 1: + return input + + if sync is None: + sync = is_sync() + + if in_axis != 0: + ensure_divisibility(input.shape[in_axis], get_tensor_model_parallel_world_size()) + input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=in_axis), axis=0) + + if not input.stop_gradient: + output = All_to_All.apply(input, in_axis=in_axis, out_axis=out_axis, sync=sync) + else: + output = _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=sync) + + return output + + +def all_to_all_opp(output, in_axis, out_axis, sync=True): + """ Duality Async Operation for all_to_all. + if sync=None, sync will be assign according init_dap setting. + """ + if get_tensor_model_parallel_world_size() == 1: + return output + + if sync is None: + sync = is_sync() + + if not sync: + dist.wait(output, group=get_tensor_model_parallel_group(), use_calc_stream=sync) + + if not output.stop_gradient: + output = All_to_All_Opp.apply(output, in_axis=in_axis, out_axis=out_axis, sync=sync) + + if out_axis != 0: + ensure_divisibility(output.shape[0], get_tensor_model_parallel_world_size()) + output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=out_axis) + + return output + + +def row_to_col(input): + """ N, S, R, C => N, R, S, C using sync all_to_all """ + if get_tensor_model_parallel_world_size() == 1: + return input + + ensure_divisibility(input.shape[2], get_tensor_model_parallel_world_size()) + input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=2), axis=0) + + if not input.stop_gradient: + output = All2All.apply(input, in_axis=2, out_axis=1) + else: + output = _all_to_all(input, in_axis=2, out_axis=1) + + output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=1) + return output + + +def col_to_row(input): + """ N, R, S, C => N, S, R, C using sync all_to_all """ + if get_tensor_model_parallel_world_size() == 1: + return input + + ensure_divisibility(input.shape[1], get_tensor_model_parallel_world_size()) + input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=1), axis=0) + + if not input.stop_gradient: + output = All2All.apply(input, in_axis=1, out_axis=2) + else: + output = _all_to_all(input, in_axis=1, out_axis=2) + + output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=2) + return output + + +@paddle.no_grad() +def grad_sync(param_groups, comm_group): + """ + sync the gradients of params + """ + + nranks = comm_group.nranks + + if nranks < 2: + return + + for group in param_groups: + if group.get("dap", False): + for p in group['params']: + if p.is_distributed: + continue + + grad = p.grad + if grad is None: + continue + + paddle.distributed.all_reduce( + grad, use_calc_stream=True, group=comm_group) + + return None diff --git a/apps/protein_folding/helixfold_cpu/inspect_input.py b/apps/protein_folding/helixfold_cpu/inspect_input.py new file mode 100644 index 00000000..c07648e3 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/inspect_input.py @@ -0,0 +1,10 @@ +import numpy as np +import pickle as pkl +import pdb + + +f = '/home/yangw/sources/helix_fold/output/T1026/features.pkl' +with open(f,'rb') as h: + df = pkl.load(h) + print(df.keys()) + pdb.set_trace() \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/layers/__init__.py 
b/apps/protein_folding/helixfold_cpu/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/protein_folding/helixfold_cpu/layers/backbones.py b/apps/protein_folding/helixfold_cpu/layers/backbones.py new file mode 100644 index 00000000..4bb0bab6 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/backbones.py @@ -0,0 +1,571 @@ +import paddle +import paddle.nn as nn +import dap +from paddle.distributed.fleet.utils import recompute +from tools import all_atom, residue_constants +from layers.basics import ( + MSAColumnAttention, + MSARowAttentionWithPairBias, + MSAColumnGlobalAttention, + Transition, + OuterProductMean, + TriangleAttention, + TriangleMultiplication, + dgram_from_positions +) +from layers.embeddings import TemplateEmbedding + + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +class EvoformerIteration(nn.Layer): + """Single iteration (block) of Evoformer stack. + + Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10 + """ + def __init__(self, channel_num, config, global_config, is_extra_msa=False): + super(EvoformerIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + + # Row-wise Gated Self-attention with Pair Bias + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + channel_num, self.config.msa_row_attention_with_pair_bias, + self.global_config, is_extra_msa) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_row_attention_with_pair_bias) + self.msa_row_attn_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + if self.is_extra_msa: + self.msa_column_global_attention = MSAColumnGlobalAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_global_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + else: + self.msa_column_attention = MSAColumnAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + self.msa_transition = Transition( + channel_num, self.config.msa_transition, self.global_config, + is_extra_msa, 'msa_transition') + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_transition) + self.msa_transition_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # OuterProductMean + self.outer_product_mean = OuterProductMean(channel_num, + self.config.outer_product_mean, self.global_config, + self.is_extra_msa, name='outer_product_mean') + + # Dropout + dropout_rate, dropout_axis = self._parse_dropout_params( + self.outer_product_mean) + self.outer_product_mean_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # Triangle Multiplication. 
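+        # Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
+        # Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"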
+ self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # TriangleAttention. + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # Pair transition. + self.pair_transition = Transition( + channel_num, self.config.pair_transition, self.global_config, + is_extra_msa, 'pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, + msa_act, # [1, 512, len_dim, 256], dtype='float32' + pair_act, # [1, len_dim, len_dim, 128], dtype='float32' + msa_mask, # [1, 512, len_dim], dtype='float32' + pair_mask # [1, len_dim, len_dim], dtype='float32' + ): + # [B, N_seq//dap_size, N_res, c_m] + residual = self.msa_row_attention_with_pair_bias( + msa_act, msa_mask, pair_act) + residual = self.msa_row_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res//dap_size, c_m] + msa_act = dap.row_to_col(msa_act) + + if self.is_extra_msa: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_global_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + else: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + # return msa_act, pair_act, pair_mask # 128GB + + # [B, N_res//dap_size, N_res, c_z] + residual = self.outer_product_mean(msa_act, 
msa_mask) + residual = self.outer_product_mean_dropout(residual) + pair_act = pair_act + residual + + # return msa_act, pair_act, pair_mask # single-thread computation 129 GB + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all(msa_act, in_axis=1, out_axis=2) + + # scatter if using dap, otherwise do nothing + pair_mask_row = dap.scatter(pair_mask, axis=1) + pair_mask_col = dap.scatter(pair_mask, axis=2) + + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask_row) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + # return msa_act, pair_act, pair_mask # 141 GB + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_multiplication_incoming(pair_act, pair_mask_col) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + # return msa_act, pair_act, pair_mask # 141 GB + + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_attention_starting_node(pair_act, pair_mask_row) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + # return msa_act, pair_act, pair_mask # 149 GB + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_attention_ending_node(pair_act, pair_mask_col) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + # return msa_act, pair_act, pair_mask # 303 GB + + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + + # wait if using async communication and dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all_opp(msa_act, in_axis=1, out_axis=2) + + return msa_act, pair_act + + +class ExtraEvoformerIterations(nn.Layer): + def __init__(self, channel_num, config, global_config): + super(ExtraEvoformerIterations, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = True + self.n_layers = self.config['extra_msa_stack_num_block'] + self.extra_msa_stack = nn.LayerList([EvoformerIteration( + channel_num, + self.config['evoformer'], + self.global_config, + self.is_extra_msa + ) for _ in range(self.n_layers)]) + + def forward(self, extra_msa_act, extra_pair_act, extra_msa_mask, mask_2d): + for extra_msa_stack_iteration in self.extra_msa_stack: + extra_msa_act, extra_pair_act = extra_msa_stack_iteration( + extra_msa_act, extra_pair_act, extra_msa_mask, mask_2d) + return extra_msa_act, extra_pair_act + + +class Embeddings(nn.Layer): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, channel_num, config, global_config): + super(Embeddings, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 
3 "InputEmbedder" + self.preprocess_1d = nn.Linear(channel_num['target_feat'], + self.config.msa_channel, name='preprocess_1d') + self.preprocess_msa = nn.Linear(channel_num['msa_feat'], + self.config.msa_channel, name='preprocess_msa') + self.left_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='left_single') + self.right_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='right_single') + + # RecyclingEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + if self.config.recycle_pos: + self.prev_pos_linear = nn.Linear(self.config.prev_pos.num_bins, + self.config.pair_channel) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + self.pair_activiations = nn.Linear( + 2 * self.config.max_relative_feature + 1, + self.config.pair_channel) + + if self.config.recycle_features: + self.prev_msa_first_row_norm = nn.LayerNorm( + self.config.msa_channel) + self.prev_pair_norm = nn.LayerNorm(self.config.pair_channel) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + self.channel_num['template_angle'] = 57 + self.channel_num['template_pair'] = 88 + self.template_embedding = TemplateEmbedding( + self.channel_num, self.config.template, self.global_config) + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + self.extra_msa_activations = nn.Linear( + 25, # 23 (20aa+unknown+gap+mask) + 1 (has_del) + 1 (del_val) + self.config.extra_msa_channel) + + def _pseudo_beta_fn(self, aatype, all_atom_positions): + gly_id = paddle.ones_like(aatype) * residue_constants.restype_order['G'] # gly_id = (1, len_dim) + is_gly = paddle.equal(aatype, gly_id) # is_gly = (1, len_dim) + is_gly_dim = len(is_gly.shape) + new_is_gly = paddle.unsqueeze(is_gly, axis=-1) + new_is_gly.stop_gradient = True + + ca_idx = residue_constants.atom_order['CA'] # 1 + cb_idx = residue_constants.atom_order['CB'] # 3 + n = len(all_atom_positions.shape) + pseudo_beta = paddle.where( + paddle.tile(new_is_gly, [1] * is_gly_dim + [3]), # 1, len_dim, 3 + paddle.squeeze(all_atom_positions.slice([n-2], [ca_idx], [ca_idx+1]),axis=-2), # 1, len_dim + paddle.squeeze(all_atom_positions.slice([n-2], [cb_idx], [cb_idx+1]),axis=-2) # 1, len_dim + ) + return pseudo_beta # = (1, len_dim, 3) + + def _create_extra_msa_feature(self, + extra_msa, + extra_has_deletion, + extra_deletion_value): + # 23: 20aa + unknown + gap + bert mask + extra_msa = extra_msa.astype(paddle.int32) + msa_1hot = nn.functional.one_hot(extra_msa, 23) + msa_feat = [msa_1hot, + paddle.unsqueeze(extra_has_deletion, axis=-1), + paddle.unsqueeze(extra_deletion_value, axis=-1)] + return paddle.concat(msa_feat, axis=-1) + + def forward(self, + target_feat, + msa_feat, + seq_mask, + aatype, + residue_index, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + extra_msa, + extra_has_deletion, + extra_deletion_value, + prev_pos=None, + prev_msa_first_row=None, + prev_pair=None): + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 
3 "InputEmbedder" + preprocess_1d = self.preprocess_1d(target_feat) + # preprocess_msa = self.preprocess_msa(batch['msa_feat']) + msa_activations = paddle.unsqueeze(preprocess_1d, axis=1) + \ + self.preprocess_msa(msa_feat) + + right_single = self.right_single(target_feat) # 1, n_res, 22 -> 1, n_res, 128 + right_single = paddle.unsqueeze(right_single, axis=1) # 1, n_res, 128 -> 1, 1, n_res, 128 + left_single = self.left_single(target_feat) # 1, n_res, 22 -> 1, n_res, 128 + left_single = paddle.unsqueeze(left_single, axis=2) # 1, n_res, 128 -> 1, n_res, 1, 128 + pair_activations = left_single + right_single + + mask_2d = paddle.unsqueeze(seq_mask, axis=1) * paddle.unsqueeze(seq_mask, axis=2) + # Inject previous outputs for recycling. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + + if self.config.recycle_pos: # and prev_pos is not None: + prev_pseudo_beta = self._pseudo_beta_fn(aatype, prev_pos) + dgram = dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += self.prev_pos_linear(dgram) + + if self.config.recycle_features: + if prev_msa_first_row is not None: + prev_msa_first_row = self.prev_msa_first_row_norm( + prev_msa_first_row) + + # A workaround for `jax.ops.index_add` + msa_first_row = paddle.squeeze(msa_activations[:, 0, :], axis=1) + msa_first_row += prev_msa_first_row + msa_first_row = paddle.unsqueeze(msa_first_row, axis=1) + msa_activations_raw = paddle.concat([msa_first_row, msa_activations[:, 1:, :]], axis=1) + + if 'prev_pair' is not None: + pair_activations += self.prev_pair_norm(prev_pair) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + pos = residue_index # [bs, N_res] + offset = paddle.unsqueeze(pos, axis=[-1]) - \ + paddle.unsqueeze(pos, axis=[-2]) + offset = offset.astype(dtype=paddle.int32) + rel_pos = nn.functional.one_hot( + paddle.clip( + offset + self.config.max_relative_feature, + min=0, + max=2 * self.config.max_relative_feature), + 2 * self.config.max_relative_feature + 1) + rel_pos_bias = self.pair_activiations(rel_pos) + pair_activations += rel_pos_bias + + # TemplateEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: # [TODO] check if valid + #template_batch = {k: batch[k] for k in batch if k.startswith('template_')} + # pdb.set_trace() + template_pair_repr = self.template_embedding( + pair_activations, # 1xlxlx128 + template_mask, # 1x4 + template_aatype, # 1xl + template_pseudo_beta_mask, # 1xl + template_pseudo_beta, # 1xlx3 + template_all_atom_positions, # 1xlx37x3 + template_all_atom_masks, # 1xlx37 + mask_2d # 1xlxl + ) + pair_activations += template_pair_repr + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + extra_msa_feat = self._create_extra_msa_feature( # [INFO] done + extra_msa, extra_has_deletion, extra_deletion_value + ) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + # ================================================== + # Extra MSA Stack + # Jumper et al. (2021) Suppl. Alg. 
18 "ExtraMsaStack" + # ================================================== + # extra_msa_stack_input = { + # 'msa': extra_msa_activations, + # 'pair': pair_activations, + # } + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + extra_msa_act = dap.scatter(extra_msa_activations, axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + extra_pair_act = dap.scatter(pair_activations, axis=1) + + return ( + msa_activations_raw, # (1, 508, len_dim, 256) + extra_msa_act, # (1, 5120, len_dim, 64) + extra_pair_act, # (1, len_dim, len_dim, 128) + mask_2d # (1, len_dim, len_dim) + ) + + +class ExtraMsa(nn.Layer): + def __init__(self, channel_num, config, global_config): + super(ExtraMsa, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + self.extra_msa_stack = nn.LayerList() + for _ in range(self.config.extra_msa_stack_num_block): + self.extra_msa_stack.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=True)) + + def _create_extra_msa_feature(self, + extra_msa, + extra_has_deletion, + extra_deletion_value): + # 23: 20aa + unknown + gap + bert mask + extra_msa = extra_msa.astype(paddle.int32) + msa_1hot = nn.functional.one_hot(extra_msa, 23) + msa_feat = [msa_1hot, + paddle.unsqueeze(extra_has_deletion, axis=-1), + paddle.unsqueeze(extra_deletion_value, axis=-1)] + return paddle.concat(msa_feat, axis=-1) + + def forward(self, + extra_msa_act, + extra_pair_act, + extra_msa_mask, + mask_2d + ): + for extra_msa_stack_iteration in self.extra_msa_stack: + extra_msa_act_new, extra_pair_act_new = recompute_wrapper( # [TODO] check if valid + extra_msa_stack_iteration, + extra_msa_act, + extra_pair_act, + extra_msa_mask, + mask_2d, + is_recompute=self.training) + extra_msa_act = extra_msa_act_new + extra_pair_act = extra_pair_act_new + + # gather if using dap, otherwise do nothing + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + extra_pair_act= dap.gather(extra_pair_act, axis=1) # 1xlxlx128 + # msa_activations_raw = 1x508xlx256 + + return extra_msa_act, extra_pair_act + + +class SingleTemplateEmbedding(nn.Layer): + def __init__(self, + channel_num, + config, # model_config['model']['embeddings_and_evoformer'] + global_config + ): + super(SingleTemplateEmbedding, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # Embed templates torsion angles + if self.config.template.enabled and self.config.template.embed_torsion_angles: + c = self.config.msa_channel + self.template_single_embedding = nn.Linear( + self.channel_num['template_angle'], c) + self.template_projection = nn.Linear(c, c) + + def forward(self, + msa_mask, + torsion_angles_mask, + msa_activations_raw, + template_features + ): + template_activations = self.template_single_embedding( + template_features) + template_activations = nn.functional.relu(template_activations) + template_activations = self.template_projection(template_activations) + + # Concatenate the templates to the msa. + msa_activations = paddle.concat( + [msa_activations_raw, template_activations], axis=1) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. 
+        torsion_angle_mask = torsion_angles_mask[..., 2]
+        torsion_angle_mask = torsion_angle_mask.astype(msa_mask.dtype)
+        msa_mask = paddle.concat([msa_mask, torsion_angle_mask], axis=1)
+        return msa_activations, msa_mask
+
+
+class SingleActivations(nn.Layer):
+    def __init__(self, channel_num, config, global_config):
+        super(SingleActivations, self).__init__()
+        self.channel_num = channel_num
+        self.config = config
+        self.global_config = global_config
+
+        self.single_activations = nn.Linear(
+            self.config['msa_channel'], self.config['seq_channel'])
+
+    def forward(self, msa_activation):
+        return self.single_activations(msa_activation)
diff --git a/apps/protein_folding/helixfold_cpu/layers/basics.py b/apps/protein_folding/helixfold_cpu/layers/basics.py
new file mode 100644
index 00000000..edac118c
--- /dev/null
+++ b/apps/protein_folding/helixfold_cpu/layers/basics.py
@@ -0,0 +1,955 @@
+from paddle import nn
+import paddle
+from tools import dap
+import numpy as np
+import functools
+import numbers
+import collections.abc
+
+
+def set_tensor_constant(tensor, constant):
+    tensor.set_value(paddle.full_like(tensor, constant))
+
+
+def init_gate_linear(linear):
+    set_tensor_constant(linear.weight, 0)
+    set_tensor_constant(linear.bias, 1)
+
+
+def init_final_linear(linear):
+    set_tensor_constant(linear.weight, 0)
+
+# alternative way to reduce memory cost during Evoformer
+def subbatch(f, arg_idx, dim, bs, out_idx):
+    """ Converts a function into one that is applied to sub-batches of an
+    input dimension.
+
+    Args:
+        f(Callable): original function.
+        arg_idx([int]): indices of the inputs to be subbatched.
+        dim([int]): index of the dimension to be subbatched.
+        bs(int): subbatch size.
+        out_idx(int): index of the output dimension that needs stacking.
+
+    Returns:
+        converted function.
+    """
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        assert len(arg_idx) == len(dim), 'Number of batching args and number of batching dims should match.'
+
+        inps = [args[i] for i in arg_idx]
+        dim_width = [inp.shape[d] for inp, d in zip(inps, dim)]
+        assert len(set(dim_width)) == 1, 'Batch sizes should be kept equal.'
+
+        inp_dim = {inp: d for inp, d in zip(inps, dim)}
+
+        dim_width = dim_width[0]
+        if dim_width < bs:
+            return f(*args, **kwargs)
+
+        outs = []
+        for slice_at in np.arange(0, dim_width, bs):
+            _args = []
+            for i, inp in enumerate(args):
+                if i in arg_idx:
+                    inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + bs])
+                _args.append(inp)
+            outs.append(f(*_args, **kwargs))
+
+        return paddle.concat(outs, out_idx)
+
+    return wrapper
+
+
+def mask_mean(mask: paddle.Tensor, value: paddle.Tensor, axis=None, drop_mask_channel=False, eps=1e-10):
+    if drop_mask_channel:
+        mask = mask[:, 0]
+
+    mask_shape = mask.shape
+    value_shape = value.shape
+    assert len(mask_shape) == len(value_shape)
+
+    if isinstance(axis, numbers.Integral):
+        axis = [axis]
+    elif axis is None:
+        axis = list(range(len(mask_shape)))
+
+    assert isinstance(axis, collections.abc.Iterable), \
+        'axis needs to be either an iterable, integer or "None"'
+
+    broadcast_factor = 1.
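+    # Where the mask is broadcast (size 1) against a larger value axis, fold that axis size
+    # into the denominator so the mean is still taken over every value element.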
+ for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (paddle.sum(mask * value, axis=axis) / + (paddle.sum(mask, axis=axis) * broadcast_factor + eps)) + + +def set_tensor_constant(tensor:paddle.Tensor, constant): + tensor.set_value(paddle.full_like(tensor, constant)) + + +def init_gate_linear(linear:nn.Linear): + set_tensor_constant(linear.weight, 0) + set_tensor_constant(linear.bias, 1) + + +class Attention(nn.Layer): + """Multihead attention.""" + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(Attention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, bias, nonbatched_bias=None): + """Builds Attention module. + Arguments: + q_data: A tensor of queries, shape [batch, row_size, N_queries, q_channels]. + m_data: A tensor of memories from which the keys and values are + projected, shape [batch, row_size, N_keys, m_channels]. + bias: A bias for the attention, shape [batch, row_size, num_head, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + + Returns: + A float32 tensor of shape [batch_size, row_size, N_queries, output_dim]. 
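+
+        Note:
+            Callers that pass `nonbatched_bias` produce it with `dap.all_gather`; this module
+            applies `dap.all_gather_opp` to it before adding it to the attention logits.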
+ """ + c = self.key_dim ** (-0.5) + q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w) * c + # q_data [1,48,5120,64] + # self.query_w [64, 8, 8] + k = paddle.einsum('nbka,ahc->nbkhc', m_data, self.key_w) + v = paddle.einsum('nbka,ahc->nbkhc', m_data, self.value_w) + logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias # segment fault when input following test samples + # q [1, 48, 5120, 8, 8] + # k [1, 48, 5120, 8, 8] + # bias [1, 48, 1, 1, 5120] + + if nonbatched_bias is not None: + nonbatched_bias_after = dap.all_gather_opp(nonbatched_bias, axis=2) + logits += paddle.unsqueeze(nonbatched_bias_after, axis=1) + + weights = nn.functional.softmax(logits) + + # by paddlepaddle team + if weights.shape[-1] != v.shape[2]: + v = paddle.tile(v, [1,1,weights.shape[-1], 1, 1]) + weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + return output + + +class GlobalAttention(nn.Layer): + """Global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 + """ + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(GlobalAttention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, q_mask): + k = paddle.einsum('nbka,ac->nbkc', m_data, self.key_w) + v = paddle.einsum('nbka,ac->nbkc', m_data, self.value_w) + + # NOTE: differ from non-global version using q_avg for attn + q_avg = mask_mean(q_mask, q_data, axis=2) + c = self.key_dim ** (-0.5) + q = paddle.einsum('nba,ahc->nbhc', q_avg, self.query_w) * c + + q_mask_ = paddle.unsqueeze(q_mask, axis=2)[..., 0] + bias = 1e9 * (q_mask_ - 1.) 
+ + logits = paddle.einsum('nbhc,nbkc->nbhk', q, k) + bias + weights = nn.functional.softmax(logits) + weighted_avg = paddle.einsum('nbhk,nbkc->nbhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg = paddle.unsqueeze(weighted_avg, axis=2) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + else: + output = paddle.einsum('nbhc,hco->nbo', weighted_avg, + self.output_w) + self.output_b + output = paddle.unsqueeze(output, axis=-1) + + return output + + +class MSAColumnAttention(nn.Layer): + """MSA per-column attention. + + Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.bs = self.global_config.subbatch_size # 48 + assert config.orientation == 'per_column' + + msa_channel = channel_num['msa_channel'] + self.query_norm = nn.LayerNorm(msa_channel) + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + + def subbatch_attention(self, q_mat:paddle.Tensor, m_mat:paddle.Tensor, bias:paddle.Tensor): + arg_idx = [0,1,2] + dim = [1,1,1] + out_idx = 1 + #inps = [args[i] for i in arg_idx] + inps = [q_mat, m_mat, bias] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.attention(q_mat, m_mat, bias) + + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.attention(_args[0], _args[1], _args[2])) + + return paddle.concat(outs, out_idx) + + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + msa_act = self.query_norm(msa_act) + + msa_act = self.subbatch_attention(msa_act, msa_act, bias) + # unit = self.bs + # n_inps = msa_act.shape[0] + # if msa_act.shape[1] < unit: + # msa_act = self.attention(msa_act, msa_act, bias) + # else: + # for i_inp in range(n_inps): + # for i in range(msa_act.shape[1] // unit): + # q_sub_data = paddle.unsqueeze(msa_act[i_inp, unit*i:unit*(i+1)], axis=0) + # bias_sub = paddle.unsqueeze(bias[i_inp, unit*i:unit*(i+1)], axis=0) + # msa_act[i_inp, unit*i:unit*(i+1)] = self.attention(q_sub_data, q_sub_data, bias_sub) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class Transition(nn.Layer): + """Transition layer. + + Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" + Jumper et al. (2021) Suppl. Alg. 
15 "PairTransition" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, + transition_type): + super(Transition, self).__init__() + assert transition_type in ['msa_transition', 'pair_transition'] + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + self.transition_type = transition_type + self.bs = self.global_config.subbatch_size # 48 + + if transition_type == 'msa_transition' and is_extra_msa: + in_dim = channel_num['extra_msa_channel'] + elif transition_type == 'msa_transition' and not is_extra_msa: + in_dim = channel_num['msa_channel'] + elif transition_type == 'pair_transition': + in_dim = channel_num['pair_channel'] + + self.input_layer_norm = nn.LayerNorm(in_dim) + self.transition1 = nn.Linear( + in_dim, int(in_dim * self.config.num_intermediate_factor), + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingNormal())) + + if self.global_config.zero_init: + last_init = nn.initializer.Constant(0.0) + else: + last_init = nn.initializer.TruncatedNormal() + + self.transition2 = nn.Linear( + int(in_dim * self.config.num_intermediate_factor), in_dim, + weight_attr=paddle.ParamAttr(initializer=last_init)) + + def subbatch_transition(self, act:paddle.Tensor): + arg_idx = [0] + dim = [1] + out_idx = 1 + #inps = [args[i] for i in arg_idx] + inps = [act] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.transition_module(act) + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.transition_module(_args[0])) + + return paddle.concat(outs, out_idx) + + def transition_module(self, x): + x = self.transition1(x) + x = nn.functional.relu(x) + x = self.transition2(x) + return x + + def forward(self, act): # edit by zjh@intel SMG 20220825 + act = self.input_layer_norm(act) + + # act = self.subbatch_transition(act) # [TODO] change slice appendage to slice on-site + dim_width = act.shape[1] + if dim_width < self.bs: + act = self.transition_module(act) + else: + for i in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + act[:, i:(i + self.bs)] = self.transition_module(act[:, i:(i + self.bs)]) + return act + + +class MSARowAttentionWithPairBias(nn.Layer): + """MSA per-row attention biased by the pair representation. + + Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa): + super(MSARowAttentionWithPairBias, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + self.bs = self.global_config.subbatch_size + assert config.orientation == 'per_row' + + if is_extra_msa: + self.query_norm = nn.LayerNorm(channel_num['extra_msa_channel']) + else: + self.query_norm = nn.LayerNorm(channel_num['msa_channel']) + + self.feat_2d_norm = nn.LayerNorm(channel_num['pair_channel']) + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. 
/ np.sqrt(channel_num['pair_channel']))) + + if is_extra_msa: + extra_msa_channel = channel_num['extra_msa_channel'] + self.attention = Attention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + else: + msa_channel = channel_num['msa_channel'] + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + def subbatch_attention(self, + q_mat:paddle.Tensor, + m_mat:paddle.Tensor, + bias:paddle.Tensor, + nonbatched_bias:paddle.Tensor + ): + arg_idx = [0,1,2] + dim = [1,1,1] + out_idx = 1 + #inps = [args[i] for i in arg_idx] + inps = [q_mat, m_mat, bias] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.attention(q_mat, m_mat, bias) + + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.attention(_args[0], _args[1], _args[2], nonbatched_bias)) + + return paddle.concat(outs, out_idx) + + def forward(self, msa_act, msa_mask, pair_act): + + pair_act = self.feat_2d_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum( + 'nqkc,ch->nhqk', pair_act, self.feat_2d_weights) + + # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + # [B, N_seq, N_res] => [B, N_seq//dap_size, N_res] + msa_mask = dap.scatter(msa_mask, axis=1) + + + bias = 1e9 * (msa_mask - 1.) + # [B, N_seq//dap_size, N_res] => [B, N_seq//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + msa_act = self.query_norm(msa_act) + + # if not self.training: + # low memory mode using subbatch + # msa_act = self.subbatch_attention(msa_act, msa_act, bias, nonbatched_bias) + + unit = self.bs + n_inps = msa_act.shape[0] + if msa_act.shape[1] < unit: + msa_act = self.attention(msa_act, msa_act, bias, nonbatched_bias) + else: + for i_inp in range(n_inps): + for i in range(msa_act.shape[1] // unit): + q_sub_data = paddle.unsqueeze(msa_act[i_inp, unit*i:unit*(i+1)], axis=0) + bias_sub = paddle.unsqueeze(bias[i_inp, unit*i:unit*(i+1)], axis=0) + msa_act[i_inp, unit*i:unit*(i+1)] = self.attention( + q_sub_data, q_sub_data, bias_sub, nonbatched_bias) + + # msa_act = self.sliced_attention(msa_act, msa_act, bias, nonbatched_bias) + # else: + # msa_act = self.attention(msa_act, msa_act, bias, nonbatched_bias) + + return msa_act + + +class MSAColumnGlobalAttention(nn.Layer): + """MSA per-column global attention. + + Jumper et al. (2021) Suppl. Alg. 
19 "MSAColumnGlobalAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnGlobalAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.bs = self.global_config.subbatch_size + assert config.orientation == 'per_column' + + extra_msa_channel = channel_num['extra_msa_channel'] + self.query_norm = nn.LayerNorm(extra_msa_channel) + self.attention = GlobalAttention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + + def subbatch_attention(self, msa_act1:paddle.Tensor, msa_act2:paddle.Tensor, msa_mask:paddle.Tensor): + arg_idx = [0,1,2] + dim = [1,1,1] + out_idx = 1 + #inps = [args[i] for i in arg_idx] + inps = [msa_act1, msa_act2, msa_mask] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.attention(msa_act1, msa_act2, msa_mask) + + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.attention(_args[0], _args[1], _args[2])) + + return paddle.concat(outs, out_idx) + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + msa_mask = paddle.unsqueeze(msa_mask, axis=-1) + msa_act = self.query_norm(msa_act) + + if not self.training: + # low memory mode using subbatch + # sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + # self.global_config.subbatch_size, 1) + # msa_act = sb_attn(msa_act, msa_act, msa_mask) + msa_act = self.subbatch_attention(msa_act, msa_act, msa_mask) + else: + msa_act = self.attention(msa_act, msa_act, msa_mask) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class OuterProductMean(nn.Layer): + """Computes mean outer product. + + Jumper et al. (2021) Suppl. Alg. 
10 "OuterProductMean" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, name='outer_product_mean'): + super(OuterProductMean, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + if is_extra_msa: + c_m = channel_num['extra_msa_channel'] + else: + c_m = channel_num['msa_channel'] + + self.layer_norm_input = nn.LayerNorm(c_m, name='layer_norm_input') + self.left_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='left_projection') + self.right_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='right_projection') + + if self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.KaimingNormal() + + self.output_w = paddle.create_parameter( + [self.config.num_outer_channel, self.config.num_outer_channel, channel_num['pair_channel']], + 'float32', default_initializer=init_w) + self.output_b = paddle.create_parameter( + [channel_num['pair_channel']], 'float32', + default_initializer=nn.initializer.Constant(value=0.0)) + + def compute_chunk(self, left_act, right_act): + # This is equivalent to + # + # act = jnp.einsum('abc,ade->dceb', left_act, right_act) + # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b + # + # but faster. maybe for subbatch inference? + + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, num_outer_channel, N_res//dap_size] + left_act = left_act.transpose([0, 1, 3, 2]) + # wait if using async communication and dap, otherwise do nothing + right_act_after = dap.all_gather_opp(right_act, axis=2) + # [B, N_seq, num_outer_channel, N_res//dap_size], [B, N_seq, N_res, num_outer_channel] + # => [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size] + act = paddle.einsum('nacb,nade->ndceb', left_act, right_act_after) + # [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size], [num_outer_channel, num_outer_channel, c_z] + # => [B, N_res, N_res//dap_size, c_z] + act = paddle.einsum('ndceb,cef->ndbf', act, self.output_w) + self.output_b + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + return act.transpose([0, 2, 1, 3]) + + def forward(self, act, mask): + """Builds OuterProductMean module. + + Arguments: + act: MSA representation, shape [batch, N_seq, N_res, c_m]. + mask: MSA mask, shape [batch, N_seq, N_res]. + + Returns: + Update to pair representation, shape [batch, N_res, N_res, c_z]. 
+ """ + # [B, N_seq, N_res//dap_size, c_m] + act = self.layer_norm_input(act) + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + right_act_before = self.right_projection(act) + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, N_res, num_outer_channel] + right_act = dap.all_gather(right_act_before, axis=2) + + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + left_act = self.left_projection(act) + # [B, N_seq, N_res] => [B, N_seq, N_res, 1] + mask = paddle.unsqueeze(mask, axis=-1) + # [B, N_seq, N_res, 1] => [B, N_seq, N_res//dap_size, 1] + mask_col = dap.scatter(mask, axis=2) + left_act = mask_col * left_act + + # [B, N_seq, N_res//dap_size, 1], [B, N_seq, N_res, 1] => [B, N_res//dap_size, N_res, 1] + epsilon = 1e-3 + norm = paddle.einsum('nabc,nadc->nbdc', mask_col, mask) + epsilon + + + + # if not self.training: + # # low memory mode using subbatch + # sb_chunk = subbatch(self.compute_chunk, [0], [2], + # self.config.chunk_size, 1) + # act = sb_chunk(left_act, right_act) + # else: + act = self.compute_chunk(left_act, right_act) + + act = act / norm + + return act + + +class TriangleAttention(nn.Layer): + """Triangle Attention. + + Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" + Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_attention'): + super(TriangleAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.bs = self.global_config.subbatch_size + + assert config.orientation in ['per_row', 'per_column'] + + self.query_norm = nn.LayerNorm(channel_num['pair_channel'], + name='query_norm') + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. / np.sqrt(channel_num['pair_channel']))) + + self.attention = Attention(self.config, self.global_config, + channel_num['pair_channel'], channel_num['pair_channel'], + channel_num['pair_channel']) + + def subbatch_attention(self, q_mat:paddle.Tensor, m_mat:paddle.Tensor, bias:paddle.Tensor, nonbatched_bias:paddle.Tensor): + arg_idx = [0,1,2] + dim = [1,1,1] + out_idx = 1 + #inps = [args[i] for i in arg_idx] + inps = [q_mat, m_mat, bias] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.attention(q_mat, m_mat, bias) + + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.attention(_args[0], _args[1], _args[2], nonbatched_bias)) + + return paddle.concat(outs, out_idx) + + def forward(self, pair_act, pair_mask): + """Builds TriangleAttention module. + + Arguments: + pair_act: [batch, N_res, N_res, c_z] pair activations tensor + pair_mask: [batch, N_res, N_res] mask of non-padded regions in the tensor. + + Returns: + Update to pair_act, shape [batch, N_res, N_res, c_z]. + """ + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + pair_mask = pair_mask.transpose([0, 2, 1]) + + # [B, N_res//dap_size, N_res] + bias = 1e9 * (pair_mask - 1.) 
+ # [B, N_res//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + pair_act = self.query_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum('bqkc,ch->bhqk', pair_act, self.feat_2d_weights) + + # # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + # pair_act = self.subbatch_attention(pair_act, pair_act, bias, nonbatched_bias) + + unit = self.bs + n_inps = pair_act.shape[0] + for i_inp in range(n_inps): + for i in range(pair_act.shape[1] // unit): + q_sub_data = paddle.unsqueeze(pair_act[i_inp, unit*i:unit*(i+1)], axis=0) + bias_sub = paddle.unsqueeze(bias[i_inp, unit*i:unit*(i+1)], axis=0) + pair_act[i_inp, unit*i:unit*(i+1)] = self.attention( + q_sub_data, q_sub_data, bias_sub, nonbatched_bias) + + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + + return pair_act + + +class TriangleMultiplication(nn.Layer): + """Triangle multiplication layer ("outgoing" or "incoming"). + + Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" + Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_multiplication'): + super(TriangleMultiplication, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.layer_norm_input = nn.LayerNorm(self.channel_num['pair_channel'], name='layer_norm_input') + self.left_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_projection') + self.right_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_projection') + self.left_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_gate') + init_gate_linear(self.left_gate) + self.right_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_gate') + init_gate_linear(self.right_gate) + + # line 4 + self.center_layer_norm = nn.LayerNorm(self.config.num_intermediate_channel, name='center_layer_norm') + self.output_projection = nn.Linear(self.config.num_intermediate_channel, + self.channel_num['pair_channel'], name='output_projection') + init_final_linear(self.output_projection) + # line 3 + self.gating_linear = nn.Linear(self.channel_num['pair_channel'], + self.channel_num['pair_channel'], name='output_projection') + init_gate_linear(self.gating_linear) + + def forward(self, act, mask): + """Builds TriangleMultiplication module. + + Arguments: + act: Pair activations, shape [batch, N_res, N_res, c_z] + mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Outputs, same shape/type as act. 
+ """ + # Outgoing [batch, N_res//dap_size, N_res] => [batch, N_res//dap_size, N_res, 1] + # Incoming [batch, N_res, N_res//dap_size] => [batch, N_res, N_res//dap_size, 1] + mask = paddle.unsqueeze(mask, axis=-1) # [batch, N_res, N_res, 1] + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + act = self.layer_norm_input(act) # line 1 + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = mask * self.left_projection(act) + right_proj_act = mask * self.right_projection(act) + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_gate_values = nn.functional.sigmoid(self.left_gate(act)) + right_gate_values = nn.functional.sigmoid(self.right_gate(act)) + + # Outgoing [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = left_proj_act * left_gate_values + right_proj_act_before = right_proj_act * right_gate_values + + + # "Outgoing" edges equation: 'ikc,jkc->ijc' + # "Incoming" edges equation: 'kjc,kic->ijc' + # Note on the Suppl. Alg. 11 & 12 notation: + # For the "outgoing" edges, a = left_proj_act and b = right_proj_act + # For the "incoming" edges, it's swapped: + # b = left_proj_act and a = right_proj_act + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + # [B, N_res//dap_size, N_res, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + # [B, N_res, N_res//dap_size, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=2) + else: + raise ValueError('unknown equation.') + + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + gate_values = nn.functional.sigmoid(self.gating_linear(act)) # line 3 + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + dim, out_idx = 1, 1 + equation = 'bikc,bjkc->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + dim, out_idx = 2, 2 + equation = 'bkjc,bkic->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=2) + else: + raise ValueError('unknown equation.') + + # if not self.training: + # einsum_fn = subbatch(paddle.einsum, [1], [dim], self.global_config.subbatch_size, out_idx) + # act = einsum_fn(equation, left_proj_act, right_proj_act_after) + # else: + # Outgoing equation = 'bikc,bjkc->bijc' + # [B, N_res//dap_size, N_res, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res//dap_size, N_res, num_intermediate_channel] + + # Incoming equation = 'bkjc,bkic->bijc' + # [B, N_res, N_res//dap_size, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res, N_res//dap_size, num_intermediate_channel] + act = paddle.einsum(equation, left_proj_act, right_proj_act_after) + + act = self.center_layer_norm(act) + act = self.output_projection(act) + + act = act * gate_values + + 
return act + + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + lower_breaks = paddle.linspace(min_bin, max_bin, num_bins) + lower_breaks = paddle.square(lower_breaks) + upper_breaks = paddle.concat([lower_breaks[1:], + paddle.to_tensor([1e8], dtype='float32')]) + + def _squared_difference(x, y): + return paddle.square(x - y) + + dist2 = paddle.sum( + _squared_difference( + paddle.unsqueeze(positions, axis=-2), + paddle.unsqueeze(positions, axis=-3)), + axis=-1, keepdim=True) + + dgram = ((dist2 > lower_breaks).astype('float32') * + (dist2 < upper_breaks).astype('float32')) + return dgram diff --git a/apps/protein_folding/helixfold_cpu/layers/embeddings.py b/apps/protein_folding/helixfold_cpu/layers/embeddings.py new file mode 100644 index 00000000..dd5517d4 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/embeddings.py @@ -0,0 +1,352 @@ +import pdb +from paddle import nn +import paddle +from layers.basics import ( + Attention, + TriangleAttention, + TriangleMultiplication, + Transition, + dgram_from_positions +) +from tools import residue_constants, quat_affine +import numpy as np + +class TemplatePair(nn.Layer): + """Pair processing for the templates. + + Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack" lines 2-6 + """ + def __init__(self, channel_num, config, global_config): + super(TemplatePair, self).__init__() + self.config = config + self.global_config = global_config + + channel_num = {} + channel_num['pair_channel'] = self.config.triangle_attention_ending_node.value_dim + + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.pair_transition = Transition(channel_num, self.config.pair_transition, + self.global_config, is_extra_msa=False, + transition_type='pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 
'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, pair_act, pair_mask): + """Builds one block of TemplatePair module. + + Arguments: + pair_act: Pair activations for single template, shape [batch, N_res, N_res, c_t]. + pair_mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Updated pair_act, shape [batch, N_res, N_res, c_t]. + """ + + residual = self.triangle_attention_starting_node(pair_act, pair_mask) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_attention_ending_node(pair_act, pair_mask) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_incoming(pair_act, pair_mask) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + return pair_act + + +class SingleTemplateEmbedding(nn.Layer): + """Embeds a single template. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 + """ + def __init__(self, channel_num, config, global_config): + super(SingleTemplateEmbedding, self).__init__() + self.config = config + self.channel_num = channel_num + # {'target_feat': 22, + # 'msa_feat': 49, + # 'extra_msa_channel': 64, + # 'msa_channel': 256, + # 'pair_channel': 128, + # 'seq_channel': 384, + # 'template_pair': 85} + self.global_config = global_config + # self.dtype = query_embedding_dtype + self.embedding2d = nn.Linear(channel_num['template_pair'], + self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + + self.template_pair_stack = nn.LayerList() + for _ in range(self.config.template_pair_stack.num_block): + self.template_pair_stack.append(TemplatePair( + self.channel_num, self.config.template_pair_stack, self.global_config)) + + self.output_layer_norm = nn.LayerNorm(self.config.attention.key_dim) + + def forward(self, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d): + """Build the single template embedding. + + Arguments: + query_embedding: Query pair representation, shape [batch, N_res, N_res, c_z]. + batch: A batch of template features (note the template dimension has been + stripped out as this module only runs over a single template). + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [N_res, N_res, c_z]. 
+ """ + dtype = mask_2d.dtype + num_res = template_aatype.shape[1] + template_mask = template_pseudo_beta_mask + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + template_mask_2d = template_mask_2d.astype(dtype) + + template_dgram = dgram_from_positions( + template_pseudo_beta, + **self.config.dgram_features) + template_dgram = template_dgram.astype(dtype) + + aatype = nn.functional.one_hot(template_aatype, 22) + aatype = aatype.astype(dtype) + + to_concat = [template_dgram, template_mask_2d[..., None]] + to_concat.append(paddle.tile(aatype[..., None, :, :], + [1, num_res, 1, 1])) + to_concat.append(paddle.tile(aatype[..., None, :], + [1, 1, num_res, 1])) + + #if self.config.use_template_unit_vector: + n, ca, c = [residue_constants.atom_order[a] + for a in ('N', 'CA', 'C')] + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=template_all_atom_positions[..., n, :], # reference shape [1, len, 37, 3] + ca_xyz=template_all_atom_positions[..., ca, :], + c_xyz=template_all_atom_positions[..., c, :]) + affines = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot), + translation=trans, + rotation=rot) + + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affines.translation, axis=-1)] + affine_vec = affines.invert_point(points, extra_dims=1) + inv_distance_scalar = paddle.rsqrt( + 1e-6 + sum([paddle.square(x) for x in affine_vec])) + + # Backbone affine mask: whether the residue has C, CA, N + # (the template mask defined above only considers pseudo CB). + template_mask = ( + template_all_atom_masks[..., n] * + template_all_atom_masks[..., ca] * + template_all_atom_masks[..., c]) + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) + + unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] + unit_vector = [x.astype(dtype) for x in unit_vector] + + ### [UnboundLocalError] local variable 'x' .... + if not self.config.use_template_unit_vector: + unit_vector = [paddle.zeros_like(x) for x in unit_vector] + to_concat.extend(unit_vector) + + template_mask_2d = template_mask_2d.astype(dtype) + to_concat.append(template_mask_2d[..., None]) + + act = paddle.concat(to_concat, axis=-1) + # Mask out non-template regions so we don't get arbitrary values in the + # distogram for these regions. + act *= template_mask_2d[..., None] + + act = self.embedding2d(act) + for pair_encoder in self.template_pair_stack: # InvalidArgumentError + act = pair_encoder(act, mask_2d) + + act = self.output_layer_norm(act) + return act + + +class TemplateEmbedding(nn.Layer): + """Embeds a set of templates. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + Jumper et al. (2021) Suppl. Alg. 
17 "TemplatePointwiseAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(TemplateEmbedding, self).__init__() + self.config = config + self.global_config = global_config + + self.single_template_embedding = SingleTemplateEmbedding( + channel_num, config, global_config) + self.attention = Attention( + config.attention, global_config, + channel_num['pair_channel'], + config.attention.key_dim, + channel_num['pair_channel']) + + def subbatch_attention(self, + msa_act1:paddle.Tensor, + msa_act2:paddle.Tensor, + msa_mask:paddle.Tensor): + arg_idx = [0,1] + dim = [1,1] + out_idx = 1 + self.bs = self.config.subbatch_size + #inps = [args[i] for i in arg_idx] + inps = [msa_act1, msa_act2, msa_mask] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + inp_dim = {inp: d for inp, d in zip(inps, dim)} + dim_width = dim_width[0] + if dim_width < self.bs: + return self.attention(msa_act1, msa_act2, msa_mask) + + outs = [] + for slice_at in np.arange(0, dim_width, self.bs): # use np.arange to escape the warning: for-range when cvt to static graph + _args = [] + for i, inp in enumerate(inps): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + self.bs]) + _args.append(inp) + outs.append(self.attention(_args[0], _args[1], _args[2])) + + return paddle.concat(outs, out_idx) + + def forward(self, + query_embedding, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d): + """Build TemplateEmbedding module. + + Arguments: + query_embedding: Query pair representation, shape [n_batch, N_res, N_res, c_z]. + template_batch: A batch of template features. + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [n_batch, N_res, N_res, c_z]. + """ + num_templates = template_mask.shape[0] + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + num_res = query_embedding.shape[1] + dtype = query_embedding.dtype + template_mask = template_mask.astype(dtype) + + query_channels = query_embedding.shape[-1] + template_batch = {'template_mask': template_mask} + + outs = [] + for i in range(num_templates): + # By default, num_templates = 4 + template_aatype = paddle.squeeze(template_aatype.slice([1], [i], [i+1]), axis=1) + template_pseudo_beta_mask = paddle.squeeze(template_pseudo_beta_mask.slice([1], [i], [i+1]), axis=1) + template_pseudo_beta = paddle.squeeze(template_pseudo_beta.slice([1], [i], [i+1]), axis=1) + template_all_atom_positions = paddle.squeeze(template_all_atom_positions.slice([1], [i], [i+1]), axis=1) + template_all_atom_masks = paddle.squeeze(template_all_atom_masks.slice([1], [i], [i+1]), axis=1) + outs.append(self.single_template_embedding( + template_aatype, # [1,len_dim] + template_pseudo_beta_mask, # [1,len_dim] + template_pseudo_beta, # [1,len_dim, 3] + template_all_atom_positions, # [1,len_dim, 37, 3] + template_all_atom_masks, # [1,len_dim, 37] + mask_2d)) # [1,len_dim, len_dim] + + template_pair_repr = paddle.stack(outs, axis=1) + + flat_query = paddle.reshape( + query_embedding, [-1, num_res * num_res, 1, query_channels]) + flat_templates = paddle.reshape( + paddle.transpose(template_pair_repr, [0, 2, 3, 1, 4]), + [-1, num_res * num_res, num_templates, num_channels]) + + bias = 1e9 * (template_mask[:, None, None, None, :] - 1.) 
+ # OK until here + + # if not self.training: + # sb_attn = subbatch(self.attention, [0, 1], [1, 1], + # self.config.subbatch_size, 1) + #emb = self.subbatch_attention(flat_query, flat_templates, bias) # will comes out with a huge graph + # emb = self.attention(flat_query, flat_templates, bias) + # else: + # emb = self.attention(flat_query, flat_templates, bias) + emb = self.attention(flat_query, flat_templates, bias) + + + emb = paddle.reshape( + emb, [-1, num_res, num_res, query_channels]) + + # No gradients if no templates. + emb *= (paddle.sum(template_mask) > 0.).astype(emb.dtype) + return emb + diff --git a/apps/protein_folding/helixfold_cpu/layers/head.py b/apps/protein_folding/helixfold_cpu/layers/head.py new file mode 100644 index 00000000..8f512489 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/head.py @@ -0,0 +1,886 @@ +import paddle +import paddle.nn as nn +import tools.lddt as lddt +from tools import quat_affine, residue_constants +from tools.model_utils import init_final_linear +import numpy as np +from tools import r3, all_atom + + +def generate_new_affine(sequence_mask): + t_shape = sequence_mask.shape[:-1] # (batch, N_res, 1) + assert len(t_shape) == 2 + t_shape.append(3) # (batch, N_res, 3) + q_shape = sequence_mask.shape[:-1] + [1] # (batch, N_res, 1) + quaternion = paddle.tile( + paddle.reshape( + paddle.to_tensor([1.0, 0.0, 0.0, 0.0]), [1, 1, 4]), + repeat_times=q_shape) + translation = paddle.zeros(t_shape) + return quat_affine.QuatAffine(quaternion, translation) + + +def sigmoid_cross_entropy(logits, labels): + """Computes sigmoid cross entropy given logits and multiple class labels.""" + log_p = nn.functional.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable + log_not_p = nn.functional.log_sigmoid(-logits) + loss = -labels * log_p - (1. - labels) * log_not_p + return loss + + +def softmax_cross_entropy(logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels.""" + loss = -paddle.sum(labels * nn.functional.log_softmax(logits), axis=-1) + return loss + + +def _distogram_log_loss(logits, bin_edges, batch, num_bins): + """Log loss of a distogram.""" + positions = batch['pseudo_beta'] + mask = batch['pseudo_beta_mask'] + + assert positions.shape[-1] == 3 + + sq_breaks = paddle.square(bin_edges).unsqueeze([1, 2]) + + dist2 = paddle.sum( + paddle.square( + paddle.unsqueeze(positions, axis=-2) - + paddle.unsqueeze(positions, axis=-3)), + axis=-1, + keepdim=True) + + true_bins = paddle.sum(dist2 > sq_breaks, axis=-1) + + errors = softmax_cross_entropy( + labels=nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + square_mask = paddle.unsqueeze(mask, axis=-2) * paddle.unsqueeze(mask, axis=-1) + + avg_error = ( + paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-6 + paddle.sum(square_mask, axis=[-2, -1]))) + dist2 = dist2[..., 0] + return { + 'loss': avg_error, + 'true_dist': paddle.sqrt(1e-6 + dist2)} + + +def l2_normalize(x, axis=-1, epsilon=1e-12): + return x / paddle.sqrt( + paddle.maximum( + paddle.sum(paddle.square(x), axis=axis, keepdim=True), + paddle.to_tensor([epsilon], dtype='float32'))) + + +def squared_difference(x, y): + return paddle.square(x - y) + + +class MaskedMsaHead(nn.Layer): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 
1.9.9 "Masked MSA prediction" + """ + def __init__(self, channel_num, config, global_config, name='masked_msa_head'): + super(MaskedMsaHead, self).__init__() + self.config = config + self.global_config = global_config + self.num_output = config.num_output + self.logits = nn.Linear(channel_num['msa_channel'], self.num_output, name='logits') + + def forward(self, msa_representation): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [batch, N_seq, N_res, c_m]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [batch, N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + logits = self.logits(msa_representation['msa']) + return {logits:logits} + + def loss(self, value, batch): + errors = softmax_cross_entropy( + labels=nn.functional.one_hot(batch['true_msa'], num_classes=self.num_output), + logits=value['logits']) + loss = (paddle.sum(errors * batch['bert_mask'], axis=[-2, -1]) / + (1e-8 + paddle.sum(batch['bert_mask'], axis=[-2, -1]))) + return {'loss': loss} + + +class PredictedLDDTHead(nn.Layer): + """Head to predict the per-residue LDDT to be used as a confidence measure. + + Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" + Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca" + """ + + def __init__(self, channel_num, config, global_config, name='predicted_lddt_head'): + super(PredictedLDDTHead, self).__init__() + self.config = config + self.global_config = global_config + + self.input_layer_norm = nn.LayerNorm(channel_num['seq_channel'], + name='input_layer_norm') + self.act_0 = nn.Linear(channel_num['seq_channel'], + self.config.num_channels, name='act_0') + self.act_1 = nn.Linear(self.config.num_channels, + self.config.num_channels, name='act_1') + self.logits = nn.Linear(self.config.num_channels, + self.config.num_bins, name='logits') + + def forward(self, representations): + """Builds PredictedLDDTHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'structure_module': Single representation from the structure module, + shape [n_batch, N_res, c_s]. + + Returns: + Dictionary containing : + * 'logits': logits of shape [n_batch, N_res, N_bins] with + (unnormalized) log probabilies of binned predicted lDDT. 
+ """ + act = representations['structure_module'] + act = self.input_layer_norm(act) + act = nn.functional.relu(self.act_0(act)) + act = nn.functional.relu(self.act_1(act)) + logits = self.logits(act) + + return dict(logits=logits) + + def loss(self, value, batch): + # Shape (n_batch, num_res, 37, 3) + pred_all_atom_pos = value['structure_module']['final_atom_positions'] + # Shape (n_batch, num_res, 37, 3) + true_all_atom_pos = paddle.cast(batch['all_atom_positions'], 'float32') + # Shape (n_batch, num_res, 37) + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + # Shape (batch_size, num_res) + lddt_ca = lddt.lddt( + # Shape (batch_size, num_res, 3) + predicted_points=pred_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 3) + true_points=true_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 1) + true_points_mask=all_atom_mask[:, :, 1:2], + cutoff=15., + per_residue=True) + lddt_ca = lddt_ca.detach() + + # Shape (batch_size, num_res) + num_bins = self.config.num_bins + bin_index = paddle.floor(lddt_ca * num_bins) + + # protect against out of range for lddt_ca == 1 + bin_index = paddle.minimum(bin_index, paddle.to_tensor(num_bins - 1, dtype='float32')) + lddt_ca_one_hot = nn.functional.one_hot(paddle.cast(bin_index, 'int64'), num_classes=num_bins) + + # Shape (n_batch, num_res, num_channel) + logits = value['predicted_lddt']['logits'] + errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + # Shape (num_res,) + mask_ca = all_atom_mask[:, :, residue_constants.atom_order['CA']] + mask_ca = paddle.to_tensor(mask_ca, dtype='float32') + loss = paddle.sum(errors * mask_ca, axis=-1) / (paddle.sum(mask_ca, axis=-1) + 1e-8) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + output = {'loss': loss} + return output + + +class PredictedAlignedErrorHead(nn.Layer): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" + """ + def __init__(self, channel_num, config, global_config, + name='predicted_aligned_error_head'): + super(PredictedAlignedErrorHead, self).__init__() + self.config = config + self.global_config = global_config + + self.logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='logits') + + def forward(self, representations): + """Builds PredictedAlignedErrorHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [B, N_res, N_res, c_z]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * logits: logits for aligned error, shape [B, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1]. 
+ """ + logits = self.logits(representations['pair']) + breaks = paddle.linspace(0., self.config.max_error_bin, + self.config.num_bins-1) + + return dict(logits=logits, breaks=breaks) + + def loss(self, value, batch): + # Shape (B, num_res, 7) + predicted_affine = quat_affine.QuatAffine.from_tensor( + value['structure_module']['final_affines']) + # Shape (B, num_res, 7) + true_rot = paddle.to_tensor(batch['backbone_affine_tensor_rot'], dtype='float32') + true_trans = paddle.to_tensor(batch['backbone_affine_tensor_trans'], dtype='float32') + true_affine = quat_affine.QuatAffine( + quaternion=None, + translation=true_trans, + rotation=true_rot) + # Shape (B, num_res) + mask = batch['backbone_affine_mask'] + # Shape (B, num_res, num_res) + square_mask = mask[..., None] * mask[:, None, :] + num_bins = self.config.num_bins + # (num_bins - 1) + breaks = value['predicted_aligned_error']['breaks'] + # (B, num_res, num_res, num_bins) + logits = value['predicted_aligned_error']['logits'] + + # Compute the squared error for each alignment. + def _local_frame_points(affine): + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affine.translation, axis=-1)] + return affine.invert_point(points, extra_dims=1) + error_dist2_xyz = [ + paddle.square(a - b) + for a, b in zip(_local_frame_points(predicted_affine), + _local_frame_points(true_affine))] + error_dist2 = sum(error_dist2_xyz) + # Shape (B, num_res, num_res) + # First num_res are alignment frames, second num_res are the residues. + error_dist2 = error_dist2.detach() + + sq_breaks = paddle.square(breaks) + true_bins = paddle.sum(paddle.cast((error_dist2[..., None] > sq_breaks), 'int32'), axis=-1) + + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + loss = (paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-8 + paddle.sum(square_mask, axis=[-2, -1]))) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class ExperimentallyResolvedHead(nn.Layer): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, channel_num, config, global_config, name='experimentally_resolved_head'): + super(ExperimentallyResolvedHead, self).__init__() + self.config = config + self.global_config = global_config + self.logits = nn.Linear(channel_num['seq_channel'], 37, name='logits') + + def forward(self, representations): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [B, N_res, c_s]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [B, N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. + """ + logits = self.logits(representations['single']) + return dict(logits=logits) + + def loss(self, value, batch): + logits = value['logits'] + assert len(logits.shape) == 3 + + # Does the atom appear in the amino acid? + atom_exists = batch['atom37_atom_exists'] + # Is the atom resolved in the experiment? 
Subset of atom_exists, + # *except for OXT* + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) + loss = paddle.sum(xent * atom_exists, axis=[-2, -1]) / (1e-8 + paddle.sum(atom_exists, axis=[-2, -1])) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class DistogramHead(nn.Layer): + """Head to predict a distogram. + + Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" + """ + + def __init__(self, channel_num, config, name='distogram_head'): + super(DistogramHead, self).__init__() + self.config = config + # self.global_config = global_config + + self.half_logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='half_logits') + init_final_linear(self.half_logits) + + def forward(self, representations): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [batch, N_res, N_res, c_z]. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [batch, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [batch, N_bins - 1]. + """ + half_logits = self.half_logits(representations['pair']) + + logits = half_logits + paddle.transpose(half_logits, perm=[0, 2, 1, 3]) + breaks = paddle.linspace(self.config.first_break, self.config.last_break, + self.config.num_bins - 1) + breaks = paddle.tile(breaks[None, :], + repeat_times=[logits.shape[0], 1]) + + return { + 'logits': logits, + 'bin_edges': breaks} + + def loss(self, value, batch): + return _distogram_log_loss(value['logits'], value['bin_edges'], + batch, self.config.num_bins) + + +class InvariantPointAttention(nn.Layer): + """Invariant Point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + + Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention" + """ + def __init__(self, channel_num, config, global_config, + dist_epsilon=1e-8): + super(InvariantPointAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.dist_epsilon = dist_epsilon + + num_head = self.config.num_head + num_scalar_qk = self.config.num_scalar_qk + num_point_qk = self.config.num_point_qk + num_scalar_v = self.config.num_scalar_v + num_point_v = self.config.num_point_v + num_output = self.config.num_channel + + assert num_scalar_qk > 0 + assert num_point_qk > 0 + assert num_point_v > 0 + + self.q_scalar = nn.Linear( + channel_num['seq_channel'], num_head * num_scalar_qk) + self.kv_scalar = nn.Linear( + channel_num['seq_channel'], + num_head * (num_scalar_v + num_scalar_qk)) + + self.q_point_local = nn.Linear( + channel_num['seq_channel'], num_head * 3 * num_point_qk) + self.kv_point_local = nn.Linear( + channel_num['seq_channel'], + num_head * 3 * (num_point_qk + num_point_v)) + + tpw = np.log(np.exp(1.) - 1.) 
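+        # Inverse softplus of 1: softplus(tpw) == 1, so the trainable point weights start out at 1.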
+        self.trainable_point_weights = paddle.create_parameter(
+            [num_head], 'float32',
+            default_initializer=nn.initializer.Constant(tpw))
+
+        self.attention_2d = nn.Linear(channel_num['pair_channel'], num_head)
+
+        if self.global_config.zero_init:
+            init_w = nn.initializer.Constant(value=0.0)
+        else:
+            init_w = nn.initializer.XavierUniform()
+
+        c = num_scalar_v + num_point_v * 4 + channel_num['pair_channel']
+        self.output_projection = nn.Linear(
+            num_head * c, num_output,
+            weight_attr=paddle.ParamAttr(initializer=init_w))
+
+    def forward(self, single_act: paddle.Tensor, pair_act: paddle.Tensor,
+                mask: paddle.Tensor, affine: quat_affine.QuatAffine):
+        # single_act: [B, N, C]
+        # pair_act: [B, N, M, C']
+        # mask: [B, N, 1]
+        num_residues = single_act.shape[1]
+        num_head = self.config.num_head
+        num_scalar_qk = self.config.num_scalar_qk
+        num_point_qk = self.config.num_point_qk
+        num_scalar_v = self.config.num_scalar_v
+        num_point_v = self.config.num_point_v
+        num_output = self.config.num_channel
+
+        # Construct scalar queries of shape:
+        # [batch_size, num_query_residues, num_head, num_points]
+        q_scalar = self.q_scalar(single_act)
+        q_scalar = paddle.reshape(
+            q_scalar, [-1, num_residues, num_head, num_scalar_qk])
+
+        # Construct scalar keys/values of shape:
+        # [batch_size, num_target_residues, num_head, num_points]
+        kv_scalar = self.kv_scalar(single_act)
+        kv_scalar = paddle.reshape(
+            kv_scalar,
+            [-1, num_residues, num_head, num_scalar_v + num_scalar_qk])
+        k_scalar, v_scalar = paddle.split(
+            kv_scalar, [num_scalar_qk, -1], axis=-1)
+
+        # Construct query points of shape:
+        # [batch_size, num_residues, num_head, num_point_qk]
+        q_point_local = self.q_point_local(single_act)
+        q_point_local = paddle.split(q_point_local, 3, axis=-1)
+
+        q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
+        q_point = [
+            paddle.reshape(x, [-1, num_residues, num_head, num_point_qk])
+            for x in q_point_global]
+
+        # Construct key and value points.
+        # Key points shape [batch_size, num_residues, num_head, num_point_qk]
+        # Value points shape [batch_size, num_residues, num_head, num_point_v]
+        kv_point_local = self.kv_point_local(single_act)
+        kv_point_local = paddle.split(kv_point_local, 3, axis=-1)
+
+        kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
+        kv_point_global = [
+            paddle.reshape(x, [-1, num_residues, num_head, num_point_qk + num_point_v])
+            for x in kv_point_global]
+
+        k_point, v_point = list(
+            zip(*[
+                paddle.split(x, [num_point_qk, -1], axis=-1)
+                for x in kv_point_global
+            ]))
+
+        # We assume that all queries and keys come iid from N(0, 1) distribution
+        # and compute the variances of the attention logits.
+        # Each scalar pair (q, k) contributes Var q*k = 1
+        scalar_variance = max(num_scalar_qk, 1) * 1.
+        # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
+        point_variance = max(num_point_qk, 1) * 9. / 2
+
+        # Allocate equal variance to scalar, point and attention 2d parts so that
+        # the sum is 1.
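+        # i.e. each of the three logit terms is scaled by sqrt(1 / (3 * its variance)).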
+ + num_logit_terms = 3 + scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance)) + point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance)) + attention_2d_weights = np.sqrt(1.0 / (num_logit_terms)) + + trainable_point_weights = nn.functional.softplus( + self.trainable_point_weights) + point_weights *= paddle.unsqueeze( + trainable_point_weights, axis=1) + + # [B, R, H, C] => [B, H, R, C], put head dim first + q_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in q_point] + k_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in k_point] + v_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in v_point] + + dist2 = [ + paddle.square(paddle.unsqueeze(qx, axis=-2) - \ + paddle.unsqueeze(kx, axis=-3)) + for qx, kx in zip(q_point, k_point)] + dist2 = sum(dist2) + + attn_qk_point = -0.5 * paddle.sum( + paddle.unsqueeze(point_weights, axis=[1, 2]) * dist2, axis=-1) + + q = paddle.transpose(scalar_weights * q_scalar, [0, 2, 1, 3]) + k = paddle.transpose(k_scalar, [0, 2, 1, 3]) + v = paddle.transpose(v_scalar, [0, 2, 1, 3]) + attn_qk_scalar = paddle.matmul(q, paddle.transpose(k, [0, 1, 3, 2])) + attn_logits = attn_qk_scalar + attn_qk_point + + attention_2d = self.attention_2d(pair_act) + attention_2d = paddle.transpose(attention_2d, [0, 3, 1, 2]) + attention_2d = attention_2d_weights * attention_2d + attn_logits += attention_2d + + mask_2d = mask * paddle.transpose(mask, [0, 2, 1]) + attn_logits -= 1e5 * (1. - mask_2d.unsqueeze(1)) + + # [batch_size, num_head, num_query_residues, num_target_residues] + attn = nn.functional.softmax(attn_logits) + + # o_i^h + # [batch_size, num_query_residues, num_head, num_head * num_scalar_v] + result_scalar = paddle.matmul(attn, v) + result_scalar = paddle.transpose(result_scalar, [0, 2, 1, 3]) + + # o_i^{hp} + # [batch_size, num_query_residues, num_head, num_head * num_point_v] + result_point_global = [ + paddle.sum(paddle.unsqueeze(attn, -1) * paddle.unsqueeze(vx, -3), + axis=-2) for vx in v_point] + result_point_global = [ + paddle.transpose(x, [0, 2, 1, 3]) for x in result_point_global] + + # \tilde{o}_i^h + # [batch_size, num_residues, num_head, pair_channel] + result_attention_over_2d = paddle.einsum( + 'nhij,nijc->nihc', attn, pair_act) + + # Reshape, global-to-local and save + result_scalar = paddle.reshape( + result_scalar, [-1, num_residues, num_head * num_scalar_v]) + result_point_global = [ + paddle.reshape(x, [-1, num_residues, num_head * num_point_v]) + for x in result_point_global] + result_point_local = affine.invert_point( + result_point_global, extra_dims=1) + result_attention_over_2d = paddle.reshape( + result_attention_over_2d, + [-1, num_residues, num_head * self.channel_num['pair_channel']]) + + result_point_local_norm = paddle.sqrt( + self.dist_epsilon + paddle.square(result_point_local[0]) + \ + paddle.square(result_point_local[1]) + \ + paddle.square(result_point_local[2])) + + output_features = [result_scalar] + output_features.extend(result_point_local) + output_features.extend( + [result_point_local_norm, result_attention_over_2d]) + + final_act = paddle.concat(output_features, axis=-1) + return self.output_projection(final_act) + + +class MultiRigidSidechain(nn.Layer): + """Class to make side chain atoms.""" + def __init__(self, channel_num, config, global_config): + super(MultiRigidSidechain, self).__init__() + + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + c = self.config.num_channel + self.input_projection = nn.Linear(channel_num['seq_channel'], c) + 
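+        # Second projection for the initial single activations; both projections are summed in forward().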
self.input_projection_1 = nn.Linear(channel_num['seq_channel'], c) + + for i in range(self.config.num_residual_block): + l1, l2 = 'resblock1', 'resblock2' + if i > 0: + l1, l2 = f'resblock1_{i}', f'resblock2_{i}' + + init_w_1 = nn.initializer.KaimingNormal() + if self.global_config.zero_init: + init_w_2 = nn.initializer.Constant(value=0.) + else: + init_w_2 = nn.initializer.XavierUniform() + + setattr(self, l1, nn.Linear( + c, c, weight_attr=paddle.ParamAttr(initializer=init_w_1))) + setattr(self, l2, nn.Linear( + c, c, weight_attr=paddle.ParamAttr(initializer=init_w_2))) + + self.unnormalized_angles = nn.Linear(c, 14) + + def forward(self, affine, single_act, init_single_act, aatype): + single_act = self.input_projection(nn.functional.relu(single_act)) + init_single_act = self.input_projection_1( + nn.functional.relu(init_single_act)) + act = single_act + init_single_act + + for i in range(self.config.num_residual_block): + l1, l2 = 'resblock1', 'resblock2' + if i > 0: + l1, l2 = f'resblock1_{i}', f'resblock2_{i}' + + old_act = act + act = getattr(self, l1)(nn.functional.relu(act)) + act = getattr(self, l2)(nn.functional.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). + num_res = act.shape[1] + unnormalized_angles = self.unnormalized_angles( + nn.functional.relu(act)) + unnormalized_angles = paddle.reshape( + unnormalized_angles, [-1, num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # (B, N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # (B, N, 7, 2) + } + + # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" + backbone_to_global = r3.rigids_from_quataffine(affine) + all_frames_to_global = all_atom.torsion_angles_to_frames( + aatype, backbone_to_global, angles) + pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + # Outputs1 (Rot + Trans) + outputs.update({ + 'atom_pos': pred_positions.translation, # (B, N, 14, 3) + 'frames_rot': all_frames_to_global.rot.rotation, # (B, N, 8, 3, 3) + 'frames_trans': all_frames_to_global.trans.translation, # (B, N, 8, 3) + }) + + # ## Outputs2 (Rigids) + # outputs.update({ + # 'atom_pos': pred_positions.translation, # (B, N, 14, 3) + # 'frames': all_frames_to_global, # (B, N, 8, 3, 3) + # }) + + return outputs + + +class FoldIteration(nn.Layer): + """A single iteration of the main structure module loop. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21 + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. 
+ """ + def __init__(self, channel_num, config, global_config): + super(FoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.invariant_point_attention = InvariantPointAttention( + channel_num, config, global_config) + self.attention_layer_norm = nn.LayerNorm(channel_num['seq_channel']) + + for i in range(self.config.num_layer_in_transition): + if i < self.config.num_layer_in_transition - 1: + init_w = nn.initializer.KaimingNormal() + elif self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.XavierUniform() + + layer_name, c_in = 'transition', channel_num['seq_channel'] + if i > 0: + layer_name, c_in = f'transition_{i}', self.config.num_channel + + setattr(self, layer_name, nn.Linear( + c_in, self.config.num_channel, + weight_attr=paddle.ParamAttr(initializer=init_w))) + + self.ipa_dropout = nn.Dropout(p=self.config.dropout) + self.transition_dropout = nn.Dropout(p=self.config.dropout) + self.transition_layer_norm = nn.LayerNorm(self.config.num_channel) + + if self.global_config.zero_init: + last_init_w = nn.initializer.Constant(value=0.0) + else: + last_init_w = nn.initializer.XavierUniform() + + # Jumper et al. (2021) Alg. 23 "Backbone update" + self.affine_update = nn.Linear( + self.config.num_channel, 6, + weight_attr=paddle.ParamAttr(initializer=last_init_w)) + + self.rigid_sidechain = MultiRigidSidechain( + channel_num, self.config.sidechain, self.global_config) + + def forward(self, activations, init_single_act, static_pair_act, + seq_mask, aatype): + affine = quat_affine.QuatAffine.from_tensor(activations['affine']) + act = activations['act'] + + attn = self.invariant_point_attention( + act, static_pair_act, seq_mask, affine) + act += attn + act = self.ipa_dropout(act) + act = self.attention_layer_norm(act) + + input_act = act + for i in range(self.config.num_layer_in_transition): + layer_name = 'transition' + if i > 0: + layer_name = f'transition_{i}' + + act = getattr(self, layer_name)(act) + + if i < self.config.num_layer_in_transition - 1: + act = nn.functional.relu(act) + + act += input_act + act = self.transition_dropout(act) + act = self.transition_layer_norm(act) + + affine_update = self.affine_update(act) + affine = affine.pre_compose(affine_update) + + sc = self.rigid_sidechain( + affine.scale_translation(self.config.position_scale), + act, init_single_act, aatype) + outputs = {'affine': affine.to_tensor(), 'sc': sc} + + affine = affine.stop_rot_gradient() + new_activations = { + 'act': act, + 'affine': affine.to_tensor() + } + return new_activations, outputs + + +class StructureModule(nn.Layer): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 
20 "StructureModule" + """ + def __init__(self, channel_num, config, global_config): + super(StructureModule, self).__init__() + assert config.num_layer > 0 + + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.single_layer_norm = nn.LayerNorm(channel_num['seq_channel']) + self.initial_projection = nn.Linear( + channel_num['seq_channel'], config.num_channel) + self.pair_layer_norm = nn.LayerNorm(channel_num['pair_channel']) + + self.fold_iteration = FoldIteration( + channel_num, config, global_config) + + def forward(self, representations, batch): + """tbd.""" + + output = self._generate_affines(representations, batch) + + ret = dict() + ret['representations'] = {'structure_module': output['act']} + + # NOTE: pred unit is nanometer, *position_scale to scale back to + # angstroms to match unit of PDB files. + # (L, B, N, 7), L = FoldIteration layers + scale = paddle.to_tensor( + [1.] * 4 + [self.config.position_scale] * 3, 'float32') + ret['traj'] = output['affine'] * paddle.unsqueeze( + scale, axis=[0, 1, 2]) + + ret['sidechains'] = output['sc'] + + # (B, N, 14, 3) + atom14_pred_positions = output['sc']['atom_pos'][-1] + ret['final_atom14_positions'] = atom14_pred_positions + + # (B, N, 14) + ret['final_atom14_mask'] = batch['atom14_atom_exists'] + + # (B, N, 37, 3) + atom37_pred_positions = all_atom.atom14_to_atom37( + atom14_pred_positions, batch) + atom37_pred_positions *= paddle.unsqueeze( + batch['atom37_atom_exists'], axis=-1) + ret['final_atom_positions'] = atom37_pred_positions + + # (B, N, 37) + ret['final_atom_mask'] = batch['atom37_atom_exists'] + + # (B, N, 7) + ret['final_affines'] = ret['traj'][-1] + + return ret + + def _generate_affines(self, representations, batch): + """Generate predicted affines for a single chain. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + + This is the main part of the structure module - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Representations dictionary. + batch: Batch dictionary. + + Returns: + A dictionary containing residue affines and sidechain positions. 
+ """ + seq_mask = paddle.unsqueeze(batch['seq_mask'], axis=-1) + + single_act = self.single_layer_norm(representations['single']) + + init_single_act = single_act + single_act = self.initial_projection(single_act) + pair_act = self.pair_layer_norm(representations['pair']) + affine = generate_new_affine(seq_mask) + + outputs = [] + activations = {'act': single_act, 'affine': affine.to_tensor()} + for _ in range(self.config.num_layer): + activations, output = self.fold_iteration( + activations, init_single_act, pair_act, + seq_mask, batch['aatype']) + outputs.append(output) + + output = dict() + for k in outputs[0].keys(): + if k == 'sc': + output[k] = dict() + for l in outputs[0][k].keys(): + output[k][l] = paddle.stack([o[k][l] for o in outputs]) + else: + output[k] = paddle.stack([o[k] for o in outputs]) + + output['act'] = activations['act'] + return output diff --git a/apps/protein_folding/helixfold_cpu/layers/net.py b/apps/protein_folding/helixfold_cpu/layers/net.py new file mode 100644 index 00000000..e1f030ae --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/net.py @@ -0,0 +1,319 @@ +import pdb +import numpy as np +import paddle +import paddle.nn as nn +from paddle.fluid.framework import _dygraph_tracer +from paddle.distributed.fleet.utils import recompute +from tools import residue_constants +from tools import folding +from layers.subnets import EmbeddingsAndEvoformer +from layers.head import ( + MaskedMsaHead, + DistogramHead, + PredictedLDDTHead, + PredictedAlignedErrorHead, + ExperimentallyResolvedHead) + +# Map head name in config to head name in model params +Head_names = { + 'masked_msa': 'masked_msa_head', + 'distogram': 'distogram_head', + 'predicted_lddt': 'predicted_lddt_head', + 'predicted_aligned_error': 'predicted_aligned_error_head', + 'experimentally_resolved': 'experimentally_resolved_head', # finetune loss +} + + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +class AlphaFold(nn.Layer): + """AlphaFold model with recycling. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" + """ + def __init__(self, config): + super(AlphaFold, self).__init__() + self.channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + } + self.config = config + self.global_config = config.global_config + + self.alphafold_iteration = AlphaFoldIteration( + self.channel_num, + self.config, + self.global_config) + + def forward(self, + batch, + ensemble_representations=False, + ): + """Run the AlphaFold model. + + Arguments: + batch: Dictionary with inputs to the AlphaFold model. + ensemble_representations: Whether to use ensembling of representations. + + Returns: + The output of AlphaFoldIteration is a nested dictionary containing + predictions from the various heads. + + """ + inner_batch, num_residues = batch['aatype'].shape[1:] + + def _get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + + for k in new_prev.keys(): + new_prev[k].stop_gradient = True + + return new_prev + + def _run_single_recycling(prev, recycle_idx): + print(f'########## recycle id: {recycle_idx} ##########') + + if self.config.resample_msa_in_recycling: + # (B, (R+1)*E, N, ...) 
+ # B: batch size, R: recycling number, + # E: ensemble number, N: residue number + num_ensemble = inner_batch // (self.config.num_recycle + 1) + ensembled_batch = dict() + for k in batch.keys(): + start = recycle_idx * num_ensemble + end = start + num_ensemble + ensembled_batch[k] = batch[k][:, start:end] + else: + # (B, E, N, ...) + num_ensemble = inner_batch + ensembled_batch = batch + + non_ensembled_batch = prev + return self.alphafold_iteration( + ensembled_batch, + non_ensembled_batch, + ensemble_representations=ensemble_representations) + + if self.config.num_recycle: + # aatype: (B, E, N), zeros_bn: (B, N) + zeros_bn = paddle.zeros_like(paddle.Tensor(batch['aatype'][:, 0]), dtype='float32') + + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_pos': paddle.tile( + zeros_bn[..., None, None], + [1, 1, residue_constants.atom_type_num, 3]), + 'prev_msa_first_row': paddle.tile( + zeros_bn[..., None], + [1, 1, emb_config.msa_channel]), + 'prev_pair': paddle.tile( + zeros_bn[..., None, None], + [1, 1, num_residues, emb_config.pair_channel]), + } + + if 'num_iter_recycling' in batch: + # Training trick: dynamic recycling number + num_iter = batch['num_iter_recycling'].numpy()[0, 0] + num_iter = min(int(num_iter), self.config.num_recycle) + else: + num_iter = self.config.num_recycle + + for recycle_idx in range(num_iter): + ret = _run_single_recycling(prev, recycle_idx) + prev = _get_prev(ret) + + else: + prev = {} + num_iter = 0 + + return _run_single_recycling(prev, num_iter) + + +class AlphaFoldIteration(nn.Layer): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. Each head also returns a + loss which is combined as a weighted sum to produce the total loss. + + Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 3-22 + """ + + def __init__(self, + channel_num, + config, + global_config, + ): + super(AlphaFoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # copy these config for later usage + self.channel_num['extra_msa_channel'] = config.embeddings_and_evoformer.extra_msa_channel + self.channel_num['msa_channel'] = config.embeddings_and_evoformer.msa_channel + self.channel_num['pair_channel'] = config.embeddings_and_evoformer.pair_channel + self.channel_num['seq_channel'] = config.embeddings_and_evoformer.seq_channel + + self.evoformer = EmbeddingsAndEvoformer( + channel_num=self.channel_num, + config=self.config['embeddings_and_evoformer'], + global_config=self.global_config) + + Head_modules = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': folding.StructureModule, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, # finetune loss + } + + self.used_heads = [] + self.heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if head_name not in Head_modules.keys(): + continue + + self.used_heads.append(head_name) + module = Head_modules[head_name]( + self.channel_num, head_config, self.global_config) + # setattr(self, head_name_, module) + self.heads[head_name] = module + + def filtered_inputs(self, d:dict, ks:list): + ret = [] + for k, v in d.items(): + if k in ks: + ret.append(v) + return ret + + def __call__(self, + ensembled_batch, + non_ensembled_batch, + ensemble_representations=False): + num_ensemble = ensembled_batch['seq_length'].shape[1] + print(ensembled_batch['seq_length'].shape) + if not ensemble_representations: + assert num_ensemble == 1 + + def _slice_batch(i): + b = {k: v[:, i] for k, v in ensembled_batch.items()} + b.update(non_ensembled_batch) + return b + + batch0 = _slice_batch(0) + res_evoformer = self.evoformer(*self.filtered_inputs(batch0, ks=[ + 'target_feat', + 'msa_feat', + 'seq_mask', + 'aatype', + 'residue_index', + 'template_mask', + 'template_aatype', + 'template_pseudo_beta_mask', + 'template_pseudo_beta', + 'template_all_atom_positions', + 'template_all_atom_masks', + 'extra_msa', + 'extra_has_deletion', + 'extra_deletion_value', + 'extra_msa_mask', + 'msa_mask', + 'prev_pos', + 'prev_msa_first_row', + 'prev_pair' + ])) + + representations = { + 'single': res_evoformer[0], + 'pair': res_evoformer[1], + 'msa': res_evoformer[2], + 'msa_first_row': res_evoformer[3] + } + + # MSA representations are not ensembled + msa_representation = representations['msa'] + del representations['msa'] + + if ensemble_representations: + for i in range(1, num_ensemble): + batch = _slice_batch(i) + representations_update = self.evoformer(batch) + for k in representations.keys(): + representations[k] += representations_update[k] + + for k in representations.keys(): + representations[k] /= num_ensemble + 0.0 + + representations['msa'] = msa_representation + ret = {'representations': representations} + + def _forward_heads(representations, ret, batch0): + for head_name, head_config in self._get_heads(): + # Skip PredictedLDDTHead and PredictedAlignedErrorHead until + # StructureModule is executed. 
+ if head_name in ('predicted_lddt', 'predicted_aligned_error'): + continue + else: + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, batch0) + else: + ret[head_name] = self.heads[head_name](representations) + if 'representations' in ret[head_name]: + # Extra representations from the head. Used by the + # structure module to provide activations for the PredictedLDDTHead. + representations.update(ret[head_name].pop('representations')) + + if self.config.heads.get('predicted_lddt.weight', 0.0): + # Add PredictedLDDTHead after StructureModule executes. + head_name = 'predicted_lddt' + # Feed all previous results to give access to structure_module result. + head_config = self.config.heads[head_name] + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, batch0) + else: + ret[head_name] = self.heads[head_name](representations) + + if ('predicted_aligned_error' in self.config.heads + and self.config.heads.get('predicted_aligned_error.weight', 0.0)): + # Add PredictedAlignedErrorHead after StructureModule executes. + head_name = 'predicted_aligned_error' + # Feed all previous results to give access to structure_module result. + head_config = self.config.heads[head_name] + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, batch0) + else: + ret[head_name] = self.heads[head_name](representations) + + return ret + + tracer = _dygraph_tracer() + if tracer._amp_dtype == "bfloat16": + raise NotImplementedError("Currently CPU optimized inference is unsupported on bfloat16.") + else: + with paddle.no_grad(): + ret = _forward_heads(representations, ret, batch0) + + return ret + + def _get_heads(self): + assert 'structure_module' in self.used_heads + head_names = [h for h in self.used_heads] + + for k in head_names: + yield k, self.config.heads[k] diff --git a/apps/protein_folding/helixfold_cpu/layers/static_backbones.py b/apps/protein_folding/helixfold_cpu/layers/static_backbones.py new file mode 100644 index 00000000..d7fae9e3 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/static_backbones.py @@ -0,0 +1,343 @@ +import pdb +from layers.backbones import ( + EvoformerIteration, + Embeddings, + SingleTemplateEmbedding, + SingleActivations +) +from layers.static_basics import StaticModule, JitModule +import paddle +from paddle.distributed.fleet.utils import recompute +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +from paddle import nn +from tools import all_atom +import time +import os +from joblib import delayed, Parallel +import numpy as np +from argparse import ArgumentParser as Parser + + +class StaticEvoformerIteration(JitModule): + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer']['evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + is_extra_msa:bool, + module_prefix:str = 'evoformeriteration', + root_weights:str = 'static_modules', + is_pdinfer_init:bool = True, + ) -> None: + self.c = config + self.gc = global_config + super(StaticEvoformerIteration, self).__init__( + config=self.c, + global_config=self.gc, + pdmodule=EvoformerIteration(channel_num, self.c, 
self.gc, is_extra_msa), + feed_dict=feed_dict, + module_prefix=module_prefix, + n_cpus=n_cpus, + channel_num=channel_num, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + # basic hyper-params + self.is_extra_msa = is_extra_msa + + +class StaticEmbeddings(JitModule): + ''' + Embedding layer in EmbeddingsAndEvoformer + input: + target_feat, # [1, len_dim, 22], dtype='float32' + msa_feat, # [1, 508, len_dim, 49], dtype='float32' + seq_mask, # [1, len_dim], dtype='float32' + aatype, # [1, len_dim], dtype='int32' + residue_index, # [1, len_dim], dtype='float32' + template_mask, # [1, 4], dtype='float32' + template_aatype, # [1, 4, len_dim], dtype="int32" + template_pseudo_beta_mask, # [1, 4, len_dim], dtype='float32' + template_pseudo_beta, # [1, 4, len_dim, 3], dtype='float32' + template_all_atom_positions, # [1, 4, len_dim, 37, 3], dtype='float32' + template_all_atom_masks, # [1, 4, len_dim, 37], dtype='float32' + extra_msa, # [1, 5120, len_dim], dtype='float32' + extra_has_deletion, # [1, 5120, len_dim], dtype='float32' + extra_deletion_value, # [1, 5120, len_dim], dtype='float32' + prev_pos=None, # [1, len_dim, 37, 3], dtype='float32' + prev_msa_first_row=None, # [1, len_dim, 256], dtype='float32' + prev_pair=None # [1, len_dim, len_dim, 128], dtype='float32' + output: + msa_activations_raw, # (1, 508, len_dim, 256) + extra_msa_act, # (1, 5120, len_dim, 64) + extra_pair_act, # (1, len_dim, len_dim, 128) + mask_2d # (1, len_dim, len_dim) + ''' + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + module_prefix:str='embeddings', + root_weights:str='static_modules', + is_pdinfer_init:bool=False) -> None: + self.c = config + self.gc = global_config + super(StaticEmbeddings, self).__init__( + config=self.c, + global_config=self.gc, + pdmodule=Embeddings(channel_num, self.c, self.gc), + feed_dict=feed_dict, + module_prefix=module_prefix, + n_cpus=n_cpus, + channel_num=channel_num, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + +class StaticSingleTemplateEmbedding(JitModule): + ''' + Embedding layer in EmbeddingsAndEvoformer + input: + msa_mask, # [1, 508, len_dim], dtype='float32' + torsion_angles_mask, # ret from folding, [1, 4, len_dim, 7], dtype='float32' + msa_activations_raw, # [1, 508, len_dim, 256], dtype='float32' + template_features # [1, 4, len_dim, 57], dtype='float32' + output: + msa_activations, # + msa_mask # + ''' + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + module_prefix:str='singletemplateembedding', + root_weights:str='static_modules', + is_pdinfer_init:bool=False) -> None: + self.c = config + self.gc = global_config + if self.c.template.enabled: + channel_num['template_angle'] = 57 + channel_num['template_pair'] = 88 + super(StaticSingleTemplateEmbedding, self).__init__( + config=self.c, + global_config=self.gc, + pdmodule=SingleTemplateEmbedding(channel_num, self.c, self.gc), + feed_dict=feed_dict, + module_prefix=module_prefix, + n_cpus=n_cpus, + channel_num=channel_num, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + +class StaticSingleActivations(JitModule): + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + module_prefix:str='single_activations', + root_weights:str='static_modules', + 
is_pdinfer_init:bool=False) -> None: + self.c = config + self.gc = global_config + super(StaticSingleActivations, self).__init__( + config=self.c, + global_config=self.gc, + pdmodule=SingleActivations(channel_num, self.c, self.gc), + feed_dict=feed_dict, + module_prefix=module_prefix, + n_cpus=n_cpus, + channel_num=channel_num, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + +class StaticExtraMsa(object): + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + module_prefix:str='extramsa', + root_weights:str='static_modules', + is_pdinfer_init:bool=False) -> None: + + self.c = config + self.gc = global_config + n_layers = self.c['extra_msa_stack_num_block'] + self.extra_msa_stack = [] + self.is_extra_msa = True + for i in range(n_layers): + self.extra_msa_stack.append(StaticEvoformerIteration( + self.c['evoformer'], + self.gc, + feed_dict, + channel_num, + n_cpus, + is_extra_msa=self.is_extra_msa, + module_prefix='%s.evoformeriteration_%d' % (module_prefix, i), + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + )) + + def pd2np(self, d:dict) -> dict: + res = {} + for k,v in d.items(): + if isinstance(v, pd.Tensor): + res[k] = v.detach().numpy() + else: + res[k] = v + return res + + def __call__(self, + feeddict:dict + ) -> dict: + extra_msa_act = feeddict['extra_msa_act'] # (1, 5120, len, 64) + extra_pair_act = feeddict['extra_pair_act'] # (1, len, len, 128) + extra_msa_mask = feeddict['extra_msa_mask'] # (1, 5120, len) + mask_2d = feeddict['mask_2d'] # (1, len, len) + for i, extra_msa_stack_iteration in enumerate(self.extra_msa_stack): + print('# [INFO] extra_msa_stack_iteration_%d' % i) + res = extra_msa_stack_iteration(self.pd2np({ + 'extra_msa_act':extra_msa_act, + 'extra_pair_act':extra_pair_act, + 'extra_msa_mask':extra_msa_mask, + 'mask_2d':mask_2d + })) + ks = list(res.keys()) + extra_msa_act = res[ks[0]] # ['extra_msa_act_new'] + extra_pair_act = res[ks[1]] # ['extra_pair_act_new'] + return { + 'extra_msa_act':extra_msa_act, # (1, 5120, len, 64) + 'extra_pair_act':extra_pair_act # (1, len, len, 128) + } + + +class StaticEvoformer(object): + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + feed_dict:dict, + channel_num:dict, + n_cpus:int, + module_prefix:str='evoformer', + root_weights:str='static_modules', + is_pdinfer_init:bool=False) -> None: + + self.c = config + self.gc = global_config + n_layers = self.c['evoformer_num_block'] + self.is_extra_msa = False + + feed2evoformer = { + 'msa_act':feed_dict['msa_activations'], + 'pair_act':feed_dict['extra_pair_act'], + 'msa_mask':feed_dict['msa_mask'], + 'pair_mask':feed_dict['mask_2d'] + } + if not os.path.isfile('{}/{}.evoformeriteration_0.pdiparams'.format( + root_weights, module_prefix)): + Parallel(n_jobs=-1)( + delayed(self._create_layer)( + feed2evoformer, + channel_num, + n_cpus, + is_extra_msa=self.is_extra_msa, + module_prefix='%s.evoformeriteration_%d' % (module_prefix, i), + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) for i in range(n_layers) + ) + + print('# [INFO] parallel compilation of iteration layers were done') + # sequential compilation of evoformer iterations + self.evoformer_iteration = [] + for i in range(n_layers): + self.evoformer_iteration.append(StaticEvoformerIteration( + self.c['evoformer'], + self.gc, + feed2evoformer, + channel_num, + n_cpus, + is_extra_msa=self.is_extra_msa, + 
module_prefix='%s.evoformeriteration_%d' % (module_prefix, i), + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + )) + + len_dims = list(feed_dict['msa_activations'][:,0].shape) + len_dims[1] = 1 + self.single_activations = StaticSingleActivations( + self.c, + self.gc, + {'msa_activation': np.ones(len_dims, dtype='float32')}, + channel_num, + n_cpus, + module_prefix='%s.single_activations' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + def _create_layer(self, feed_dict, channel_num, n_cpus, is_extra_msa, module_prefix, root_weights, is_pdinfer_init): + StaticEvoformerIteration( + self.c['evoformer'], + self.gc, + feed_dict, + channel_num, + n_cpus, + is_extra_msa, + module_prefix, + root_weights, + is_pdinfer_init + ) + + def pd2np(self, d:dict) -> dict: + res = {} + for k,v in d.items(): + if isinstance(v, pd.Tensor): + res[k] = v.detach().numpy() + else: + res[k] = v + return res + + def __call__(self, feed_dict:dict) -> dict: + feed2block = { + 'msa_act':feed_dict['msa_activations'], # (1, 508, len, 256) + 'pair_act':feed_dict['extra_pair_act'], # (1, len, len, 128) + 'msa_mask':feed_dict['msa_mask'], # (1, 508, len) + 'pair_mask':feed_dict['mask_2d'] # (1, len, len) + } + for i, evoformer_block in enumerate(self.evoformer_iteration): + print('# [INFO] evoformer_iteration_%d' % i) + res = evoformer_block(self.pd2np(feed2block)) + ks = list(res.keys()) + msa_activations = res[ks[0]] # ['msa_act'] + extra_pair_act = res[ks[1]] # ['pair_act'] + feed2block['msa_act'] = msa_activations + feed2block['pair_act'] = extra_pair_act + + single_acts = self.single_activations(self.pd2np({ + 'msa_activation':msa_activations[:, 0] + })) + k = list(single_acts.keys())[0] + single_activations = single_acts[k] + return { + 'single_activations':single_activations, + 'pair_activations':extra_pair_act, + 'msa_activations':msa_activations + } diff --git a/apps/protein_folding/helixfold_cpu/layers/static_basics.py b/apps/protein_folding/helixfold_cpu/layers/static_basics.py new file mode 100644 index 00000000..c845aad5 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/static_basics.py @@ -0,0 +1,122 @@ +import paddle as pd +from paddle import nn +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import os +import pdb +import warnings + + +class StaticModule(object): + def __init__(self, + config:dict, + global_config:dict, + pdmodule:nn.Layer, + feed_dict:dict, + module_prefix:str, + n_cpus:int, + channel_num:dict, + root_weights:str = 'static_params', + is_pdinfer_init:bool = True + ) -> None: + # basic hyper-params + self.c = config + self.gc = global_config + self.module_prefix = module_prefix + self.channel_num = channel_num + + f_topo = os.path.join(root_weights, '{}.pdmodel'.format(self.module_prefix)) + f_bin = os.path.join(root_weights, '{}.pdiparams'.format(self.module_prefix)) + if not (os.path.isfile(f_topo) and os.path.isfile(f_bin)): + print('# [statc_basics] build static model') + pdmodule.eval() + warnings.filterwarnings('ignore', 'DAP comm') + specs = [InputSpec(shape=list(v.shape), dtype=v.dtype, name=k) for k, v in feed_dict.items()] + net = to_static(pdmodule, input_spec=specs) + save(net, os.path.join(root_weights,self.module_prefix)) + warnings.resetwarnings() + + print('# [statc_basics] build executor of static model') + pd_cfg = pdinfer.Config(f_topo, f_bin) + if not is_pdinfer_init: + pd_cfg.set_cpu_math_library_num_threads(n_cpus) + 
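            # CPU-only tuning for the predictor built below: bound the math-library
            # thread pool to n_cpus and enable oneDNN (MKL-DNN) kernels; both knobs
            # are skipped when the caller signals that paddle inference was already
            # initialised elsewhere (is_pdinfer_init).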
pd_cfg.enable_mkldnn() + self.predictor = pdinfer.create_predictor(pd_cfg) + + print('# [statc_basics] build input ports') + self.input_names = self.predictor.get_input_names() + self.input_ports = {} + for k in self.input_names: + assert k in feed_dict.keys() + self.input_ports[k] = self.predictor.get_input_handle(k) + + print('# [statc_basics] build output ports') + self.output_names = self.predictor.get_output_names() + self.output_ports = {} + for k in self.output_names: + self.output_ports[k] = self.predictor.get_output_handle(k) + + def __call__(self, feed_dict:dict) -> dict: + for k, input_port in self.input_ports.items(): + input_port.copy_from_cpu(feed_dict[k]) + self.predictor.run() + return {k:output_port.copy_to_cpu() for k, output_port in self.output_ports.items()} + + +class JitModule(object): + def __init__(self, + config:dict, + global_config:dict, + pdmodule:nn.Layer, + feed_dict:dict, + module_prefix:str, + n_cpus:int, + channel_num:dict, + root_weights:str = 'static_params', + is_pdinfer_init:bool = True + ) -> None: + # basic hyper-params + self.c = config + self.gc = global_config + self.module_prefix = module_prefix + self.channel_num = channel_num + self.n_cpus = n_cpus + self.is_pdinfer_init = is_pdinfer_init + + self.f_topo = os.path.join(root_weights, '{}.pdmodel'.format(self.module_prefix)) + self.f_bin = os.path.join(root_weights, '{}.pdiparams'.format(self.module_prefix)) + if not (os.path.isfile(self.f_topo) and os.path.isfile(self.f_bin)): + print('# [statc_basics] build static model') + pdmodule.eval() + warnings.filterwarnings('ignore', 'DAP comm') + specs = [InputSpec(shape=list(v.shape), dtype=v.dtype, name=k) for k, v in feed_dict.items()] + net = to_static(pdmodule, input_spec=specs) + save(net, os.path.join(root_weights,self.module_prefix)) + warnings.resetwarnings() + + + def __call__(self, feed_dict:dict) -> dict: + # print('# [{}.basics] build JIT graph'.format(self.module_prefix)) + pd_cfg = pdinfer.Config(self.f_topo, self.f_bin) + if not self.is_pdinfer_init: + pd_cfg.set_cpu_math_library_num_threads(self.n_cpus) + pd_cfg.enable_mkldnn() + predictor = pdinfer.create_predictor(pd_cfg) + + # print('# [{}.basics] build input ports'.format(self.module_prefix)) + input_names = predictor.get_input_names() + input_ports = {} + for k in input_names: + assert k in feed_dict.keys() + input_ports[k] = predictor.get_input_handle(k) + + # print('# [{}.basics] build output ports'.format(self.module_prefix)) + output_names = predictor.get_output_names() + output_ports = {} + for k in output_names: + output_ports[k] = predictor.get_output_handle(k) + for k, input_port in input_ports.items(): + input_port.copy_from_cpu(feed_dict[k]) + predictor.run() + return {k:pd.Tensor(output_port.copy_to_cpu()) for k, output_port in output_ports.items()} diff --git a/apps/protein_folding/helixfold_cpu/layers/static_net.py b/apps/protein_folding/helixfold_cpu/layers/static_net.py new file mode 100644 index 00000000..81f614ad --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/static_net.py @@ -0,0 +1,320 @@ +import pdb +import numpy as np +import paddle +import paddle.nn as nn +from paddle.fluid.framework import _dygraph_tracer +from paddle.distributed.fleet.utils import recompute +from tools import residue_constants +from tools import folding +from layers.static_subnets import StaticEmbeddingsAndEvoformer +from layers.head import ( + MaskedMsaHead, + DistogramHead, + PredictedLDDTHead, + PredictedAlignedErrorHead, + ExperimentallyResolvedHead) + +# Map head name in 
config to head name in model params +Head_names = { + 'masked_msa': 'masked_msa_head', + 'distogram': 'distogram_head', + 'predicted_lddt': 'predicted_lddt_head', + 'predicted_aligned_error': 'predicted_aligned_error_head', + 'experimentally_resolved': 'experimentally_resolved_head', # finetune loss +} + + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +class AlphaFold(object): + """AlphaFold model with recycling. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" + """ + def __init__(self, + config, + seq_len, + n_cpus, + module_prefix='alphafold', + root_weights='static_modules', + is_pdinfer_init=False + ): + super(AlphaFold, self).__init__() + self.channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + } + self.config = config + self.global_config = config.global_config + + self.alphafold_iteration = StaticAlphaFoldIteration( + self.config, + self.global_config, + self.channel_num, + seq_len, + n_cpus, + module_prefix='%s.alphafold_iteration' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init) + + def __call__(self, + batch, + ensemble_representations=False, + ): + """Run the AlphaFold model. + + Arguments: + batch: Dictionary with inputs to the AlphaFold model. + ensemble_representations: Whether to use ensembling of representations. + + Returns: + The output of AlphaFoldIteration is a nested dictionary containing + predictions from the various heads. + + """ + inner_batch, num_residues = batch['aatype'].shape[1:] + + def _get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + + for k in new_prev.keys(): + new_prev[k].stop_gradient = True + + return new_prev + + def _run_single_recycling(prev, recycle_idx): + print(f'########## recycle id: {recycle_idx} ##########') + + if self.config.resample_msa_in_recycling: + # (B, (R+1)*E, N, ...) + # B: batch size, R: recycling number, + # E: ensemble number, N: residue number + num_ensemble = inner_batch // (self.config.num_recycle + 1) + ensembled_batch = dict() + for k in batch.keys(): + start = recycle_idx * num_ensemble + end = start + num_ensemble + ensembled_batch[k] = batch[k][:, start:end] + else: + # (B, E, N, ...) 
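                # without MSA resampling the ensemble axis is reused as-is, so every
                # recycling pass sees the same ensembled features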
+ num_ensemble = inner_batch + ensembled_batch = batch + + non_ensembled_batch = prev + return self.alphafold_iteration( + ensembled_batch, + non_ensembled_batch, + ensemble_representations=ensemble_representations) + + if self.config.num_recycle: + # aatype: (B, E, N), zeros_bn: (B, N) + zeros_bn = paddle.zeros_like(paddle.Tensor(batch['aatype'][:, 0]), dtype='float32') + + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_pos': paddle.tile( + zeros_bn[..., None, None], + [1, 1, residue_constants.atom_type_num, 3]), + 'prev_msa_first_row': paddle.tile( + zeros_bn[..., None], + [1, 1, emb_config.msa_channel]), + 'prev_pair': paddle.tile( + zeros_bn[..., None, None], + [1, 1, num_residues, emb_config.pair_channel]), + } + + if 'num_iter_recycling' in batch: + # Training trick: dynamic recycling number + num_iter = batch['num_iter_recycling'].numpy()[0, 0] + num_iter = min(int(num_iter), self.config.num_recycle) + else: + num_iter = self.config.num_recycle + + for recycle_idx in range(num_iter): + ret = _run_single_recycling(prev, recycle_idx) + prev = _get_prev(ret) + + else: + prev = {} + num_iter = 0 + + return _run_single_recycling(prev, num_iter) + + +class StaticAlphaFoldIteration(object): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. Each head also returns a + loss which is combined as a weighted sum to produce the total loss. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 + """ + + def __init__(self, + config, + global_config, + channel_num, + seq_len, + n_cpus, + module_prefix = 'alphafold_iteration', + root_weights = 'static_modules', + is_pdinfer_init=False + ): + super(StaticAlphaFoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # copy these config for later usage + self.channel_num['extra_msa_channel'] = config.embeddings_and_evoformer.extra_msa_channel + self.channel_num['msa_channel'] = config.embeddings_and_evoformer.msa_channel + self.channel_num['pair_channel'] = config.embeddings_and_evoformer.pair_channel + self.channel_num['seq_channel'] = config.embeddings_and_evoformer.seq_channel + + self.evoformer = StaticEmbeddingsAndEvoformer( + config=self.config['embeddings_and_evoformer'], + global_config=self.global_config, + seq_len=seq_len, + channel_num=self.channel_num, + n_cpus=n_cpus, + module_prefix='%s.embeddings_and_evoformer' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init) + + Head_modules = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': folding.StructureModule, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, # finetune loss + } + + self.used_heads = [] + self.heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if head_name not in Head_modules.keys(): + continue + + self.used_heads.append(head_name) + module = Head_modules[head_name]( + self.channel_num, head_config, self.global_config) + module.eval() + # setattr(self, head_name_, module) + self.heads[head_name] = module + + def np2pd(self, d:dict) -> dict: + res = {} + for k,v in d.items(): + if isinstance(v, np.ndarray): + if k == 'aatype': + res[k] = paddle.Tensor(v, dtype='int32') + else: + res[k] = 
paddle.Tensor(v, dtype='float32') + else: + res[k] = v + return res + + def __call__(self, + ensembled_batch, + non_ensembled_batch, + ensemble_representations=False): + num_ensemble = ensembled_batch['seq_length'].shape[1] + print(ensembled_batch['seq_length'].shape) + if not ensemble_representations: + assert num_ensemble == 1 + + def _slice_batch(i): + b = {k: v[:, i] for k, v in ensembled_batch.items()} + b.update(non_ensembled_batch) + return b + + batch0 = _slice_batch(0) + representations = self.evoformer(batch0) + + # MSA representations are not ensembled + msa_representation = representations['msa'] + del representations['msa'] + + if ensemble_representations: + for i in range(1, num_ensemble): + batch = _slice_batch(i) + representations_update = self.evoformer(batch) + for k in representations.keys(): + representations[k] += representations_update[k] + + for k in representations.keys(): + representations[k] /= num_ensemble + 0.0 + + representations['msa'] = msa_representation + ret = {'representations': representations} + + def _forward_heads(representations, ret, batch0): + for head_name, head_config in self._get_heads(): + # Skip PredictedLDDTHead and PredictedAlignedErrorHead until + # StructureModule is executed. + if head_name in ('predicted_lddt', 'predicted_aligned_error'): + continue + else: + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, self.np2pd(batch0)) + else: + ret[head_name] = self.heads[head_name](representations) + if 'representations' in ret[head_name]: + # Extra representations from the head. Used by the + # structure module to provide activations for the PredictedLDDTHead. + representations.update(ret[head_name].pop('representations')) + + if self.config.heads.get('predicted_lddt.weight', 0.0): + # Add PredictedLDDTHead after StructureModule executes. + head_name = 'predicted_lddt' + # Feed all previous results to give access to structure_module result. + head_config = self.config.heads[head_name] + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, self.np2pd(batch0)) + else: + ret[head_name] = self.heads[head_name](representations) + + if ('predicted_aligned_error' in self.config.heads + and self.config.heads.get('predicted_aligned_error.weight', 0.0)): + # Add PredictedAlignedErrorHead after StructureModule executes. + head_name = 'predicted_aligned_error' + # Feed all previous results to give access to structure_module result. 
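                # by this point `representations` has already been updated with the
                # extra activations popped from the structure_module output, so heads
                # run here see them alongside the evoformer outputs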
+ head_config = self.config.heads[head_name] + # ret[head_name] = getattr(self, head_name_)(representations, batch0) + if head_name == 'structure_module': + ret[head_name] = self.heads[head_name](representations, self.np2pd(batch0)) + else: + ret[head_name] = self.heads[head_name](representations) + + return ret + + tracer = _dygraph_tracer() + if tracer._amp_dtype == "bfloat16": + raise NotImplementedError("Currently CPU optimized inference is unsupported on bfloat16.") + else: + with paddle.no_grad(): + ret = _forward_heads(representations, ret, batch0) + + return ret + + def _get_heads(self): + assert 'structure_module' in self.used_heads + head_names = [h for h in self.used_heads] + + for k in head_names: + yield k, self.config.heads[k] diff --git a/apps/protein_folding/helixfold_cpu/layers/static_subnets.py b/apps/protein_folding/helixfold_cpu/layers/static_subnets.py new file mode 100644 index 00000000..bcb265c9 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/static_subnets.py @@ -0,0 +1,212 @@ +import pdb +from layers.static_backbones import ( + StaticEmbeddings, + StaticExtraMsa, + StaticSingleTemplateEmbedding, + StaticEvoformer +) +import paddle as pd +from paddle import nn +from tools import all_atom +import numpy as np + + +class StaticEmbeddingsAndEvoformer(object): + def __init__(self, + config:dict, # cfg['model']['embeddings_and_evoformer'] + global_config:dict, + seq_len:int, + channel_num:dict, + n_cpus:int, + module_prefix:str='embeddingsandevoformer', + root_weights:str='static_modules', + is_pdinfer_init:bool=False) -> None: + + # [INFO] build sample & configuration + self.c = config + self.gc = global_config + feed_dict = { + 'target_feat': np.ones([1, seq_len, 22], dtype='float32'), + 'msa_feat': np.ones([1, 508, seq_len, 49], dtype='float32'), + 'seq_mask': np.ones([1, seq_len], dtype='float32'), + 'aatype': np.ones([1, seq_len], dtype='float32'), + 'residue_index': np.ones([1, seq_len], dtype='float32'), + 'template_mask': np.ones([1, 4], dtype='float32'), + 'template_aatype': np.ones([1, 4, seq_len], dtype="int32"), # define + 'template_pseudo_beta_mask': np.ones([1, 4, seq_len], dtype='float32'), + 'template_pseudo_beta': np.ones([1, 4, seq_len, 3], dtype='float32'), + 'template_all_atom_positions': np.ones([1, 4, seq_len, 37, 3], dtype='float32'), + 'template_all_atom_masks': np.ones([1, 4, seq_len, 37], dtype='float32'), + 'extra_msa': np.ones([1, 5120, seq_len], dtype='float32'), + 'extra_has_deletion': np.ones([1, 5120, seq_len], dtype='float32'), + 'extra_deletion_value': np.ones([1, 5120, seq_len], dtype='float32'), + 'extra_msa_mask': np.ones([1, 5120, seq_len], dtype='float32'), + 'msa_mask': np.ones([1, 508, seq_len], dtype='float32'), + 'prev_pos': np.ones([1, seq_len, 37, 3], dtype='float32'), + 'prev_msa_first_row': np.ones([1, seq_len, 256], dtype='float32'), + 'prev_pair': np.ones([1, seq_len, seq_len, 128], dtype='float32') + } + + # [INFO] build embedding alyer + feed2embeddings = {k:feed_dict[k] for k in [ + 'target_feat', + 'msa_feat', + 'seq_mask', + 'aatype', + 'residue_index', + 'template_mask', + 'template_aatype', + 'template_pseudo_beta_mask', + 'template_pseudo_beta', + 'template_all_atom_positions', + 'template_all_atom_masks', + 'extra_msa', + 'extra_has_deletion', + 'extra_deletion_value', + 'prev_pos', + 'prev_msa_first_row', + 'prev_pair'] + } + self.embeddings = StaticEmbeddings( + config, + global_config, + feed2embeddings, + channel_num, + n_cpus, + module_prefix='%s.embeddings' % module_prefix, + 
root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + # [INFO] build ExtraMSA layer + feed2extramsa = { + 'extra_msa_act':np.ones([1, 5120, seq_len, 64], dtype='float32'), + 'extra_pair_act':np.ones([1, seq_len, seq_len, 128], dtype='float32'), + 'extra_msa_mask':feed_dict['extra_msa_mask'], + 'mask_2d': np.ones([1, seq_len, seq_len], dtype='float32') + } + self.extra_msa = StaticExtraMsa( + config, + global_config, + feed2extramsa, + channel_num, + n_cpus, + module_prefix='%s.extramsa' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + # [INFO] build single template embedding layer + feed2singletemplate = { + 'msa_mask': np.ones([1, 508, seq_len], dtype='float32'), + 'torsion_angles_mask': np.ones([1, 4, seq_len, 7], dtype='float32'), + 'msa_activations_raw': np.ones([1, 508, seq_len, 256], dtype='float32'), + 'template_features': np.ones([1, 4, seq_len, 57], dtype='float32') + } + self.single_template_embedding = StaticSingleTemplateEmbedding( + config, + global_config, + feed2singletemplate, + channel_num, + n_cpus, + module_prefix='%s.singletemplateembedding' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + # [INFO] build evoformer stack + feed2evoformer = { + 'msa_activations': np.ones([1, 508, seq_len, 256], dtype='float32'), + 'extra_pair_act': np.ones([1, seq_len, seq_len, 128], dtype='float32'), + 'msa_mask': np.ones([1, 508, seq_len], dtype='float32'), + 'mask_2d': np.ones([1, seq_len, seq_len], dtype='float32') + } + self.evoformer = StaticEvoformer( + config, + global_config, + feed2evoformer, + channel_num, + n_cpus, + module_prefix='%s.evoformer' % module_prefix, + root_weights=root_weights, + is_pdinfer_init=is_pdinfer_init + ) + + def pd2np(self, d:dict) -> dict: + res = {} + for k,v in d.items(): + if isinstance(v, pd.Tensor): + res[k] = v.detach().numpy() + else: + res[k] = v + return res + + def __call__(self, feed_dict:dict) -> dict: + feed_dict = self.pd2np(feed_dict) + + # [INFO] embeddings + res_embeddings = self.embeddings(feed_dict) + msa_activations_raw, extra_msa_act, extra_pair_act, mask_2d = res_embeddings.values() + + # [INFO] extra_msa + feed_to_extra_msa = { + 'extra_msa_act':extra_msa_act, # (1, 5120, len_dim, 64) + 'extra_pair_act':extra_pair_act, # (1, len_dim, len_dim, 128) + 'extra_msa_mask':feed_dict['extra_msa_mask'], # [1, 5120, len_dim] + 'mask_2d':mask_2d # (1, len_dim, len_dim) + } + feed_to_extra_msa = self.pd2np(feed_to_extra_msa) + res_extra_msa = self.extra_msa(feed_to_extra_msa) # [OK] I/O valid + extra_msa_act, extra_pair_act = res_extra_msa.values() + + # [INFO] template angle features + template_aatype = pd.Tensor(feed_dict['template_aatype']) + template_all_atom_positions = pd.Tensor(feed_dict['template_all_atom_positions']) + template_all_atom_masks = pd.Tensor(feed_dict['template_all_atom_masks']) + + if self.c.template.enabled and self.c.template.embed_torsion_angles: + num_templ, num_res = template_aatype.shape[1:] + + aatype_one_hot = nn.functional.one_hot(template_aatype, 22) + # Embed the templates aatype, torsion angles and masks. 
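            # 57 template-angle channels = 22 (aatype one-hot) + 14 (sin/cos of the
            # 7 torsion angles) + 14 (alternative sin/cos) + 7 (torsion mask),
            # matching channel_num['template_angle'] set for the template path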
+ # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=template_aatype, + all_atom_pos=template_all_atom_positions, + all_atom_mask=template_all_atom_masks, + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not self.gc.zero_init) + + template_features = pd.concat([ + aatype_one_hot, + pd.reshape(ret['torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + pd.reshape(ret['alt_torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1) + + res_single_template = self.single_template_embedding(self.pd2np({ + 'msa_mask': feed_dict['msa_mask'], + 'torsion_angles_mask': ret['torsion_angles_mask'].detach().numpy(), + 'msa_activations_raw': msa_activations_raw, + 'template_features': template_features.detach().numpy() + })) + msa_activations, msa_mask = res_single_template.values() + + # [INFO] evoformer + feed_to_evoformer = self.pd2np({ + 'msa_activations':msa_activations, + 'extra_pair_act':extra_pair_act, + 'msa_mask':msa_mask, + 'mask_2d':mask_2d + }) + res_evoformer = self.evoformer(feed_to_evoformer) + single_activations, pair_activations, msa_activations = res_evoformer.values() + + num_seq = feed_dict['msa_feat'].shape[1] + return { + 'single':single_activations, + 'pair':pair_activations, + 'msa':msa_activations[:, :num_seq], + 'msa_first_row':msa_activations[:, 0] + } diff --git a/apps/protein_folding/helixfold_cpu/layers/subnets.py b/apps/protein_folding/helixfold_cpu/layers/subnets.py new file mode 100644 index 00000000..0f5ea77b --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/layers/subnets.py @@ -0,0 +1,354 @@ + +import pdb +from paddle.distributed.fleet.utils import recompute +import paddle +from paddle import nn +from tools import dap, all_atom, residue_constants +from layers.basics import dgram_from_positions +from layers.backbones import EvoformerIteration +from layers.embeddings import TemplateEmbedding +import numpy as np + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +class EmbeddingsAndEvoformer(nn.Layer): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, channel_num, config, global_config): + super(EmbeddingsAndEvoformer, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + self.preprocess_1d = nn.Linear(channel_num['target_feat'], + self.config.msa_channel, name='preprocess_1d') + self.preprocess_msa = nn.Linear(channel_num['msa_feat'], + self.config.msa_channel, name='preprocess_msa') + self.left_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='left_single') + self.right_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='right_single') + + # RecyclingEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + if self.config.recycle_pos: + self.prev_pos_linear = nn.Linear(self.config.prev_pos.num_bins, + self.config.pair_channel) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 
5 "one_hot" + if self.config.max_relative_feature: + self.pair_activiations = nn.Linear( + 2 * self.config.max_relative_feature + 1, + self.config.pair_channel) + + if self.config.recycle_features: + self.prev_msa_first_row_norm = nn.LayerNorm( + self.config.msa_channel) + self.prev_pair_norm = nn.LayerNorm(self.config.pair_channel) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + self.channel_num['template_angle'] = 57 + self.channel_num['template_pair'] = 88 + self.template_embedding = TemplateEmbedding( + self.channel_num, self.config.template, self.global_config) + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + self.extra_msa_activations = nn.Linear( + 25, # 23 (20aa+unknown+gap+mask) + 1 (has_del) + 1 (del_val) + self.config.extra_msa_channel) + + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + self.extra_msa_stack = nn.LayerList() + for _ in range(self.config.extra_msa_stack_num_block): + self.extra_msa_stack.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=True)) + + # Embed templates torsion angles + if self.config.template.enabled and self.config.template.embed_torsion_angles: + c = self.config.msa_channel + self.template_single_embedding = nn.Linear( + self.channel_num['template_angle'], c) + self.template_projection = nn.Linear(c, c) + + # Main trunk of the network + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 + self.evoformer_iteration = nn.LayerList() + for _ in range(self.config.evoformer_num_block): + self.evoformer_iteration.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=False)) + + self.single_activations = nn.Linear( + self.config.msa_channel, self.config.seq_channel) + + def _pseudo_beta_fn(self, aatype, all_atom_positions): + gly_id = paddle.ones_like(aatype) * residue_constants.restype_order['G'] # gly_id = (1, len_dim) + is_gly = paddle.equal(aatype, gly_id) # is_gly = (1, len_dim) + is_gly_dim = len(is_gly.shape) + new_is_gly = paddle.unsqueeze(is_gly, axis=-1) + new_is_gly.stop_gradient = True + + ca_idx = residue_constants.atom_order['CA'] # 1 + cb_idx = residue_constants.atom_order['CB'] # 3 + n = len(all_atom_positions.shape) + pseudo_beta = paddle.where( + paddle.tile(new_is_gly, [1] * is_gly_dim + [3]), # 1, len_dim, 3 + paddle.squeeze(all_atom_positions.slice([n-2], [ca_idx], [ca_idx+1]),axis=-2), # 1, len_dim + paddle.squeeze(all_atom_positions.slice([n-2], [cb_idx], [cb_idx+1]),axis=-2) # 1, len_dim + ) + return pseudo_beta # = (1, len_dim, 3) + + def _create_extra_msa_feature(self, + extra_msa, + extra_has_deletion, + extra_deletion_value): + # 23: 20aa + unknown + gap + bert mask + extra_msa = extra_msa.astype(paddle.int32) + msa_1hot = nn.functional.one_hot(extra_msa, 23) + msa_feat = [msa_1hot, + paddle.unsqueeze(extra_has_deletion, axis=-1), + paddle.unsqueeze(extra_deletion_value, axis=-1)] + return paddle.concat(msa_feat, axis=-1) + + #def forward(self, batch): + def forward(self, + target_feat, + msa_feat, + seq_mask, + aatype, + residue_index, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + extra_msa, + extra_has_deletion, + extra_deletion_value, + extra_msa_mask, + msa_mask, + prev_pos=None, + prev_msa_first_row=None, + prev_pair=None): + # InputEmbedder + 
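        # target_feat (22 channels) is projected to msa_channel and broadcast-added
        # to the embedded msa_feat rows; left_single/right_single project target_feat
        # to pair_channel and their outer sum over residue pairs initialises the
        # pair representation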
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + preprocess_1d = self.preprocess_1d(target_feat) + # preprocess_msa = self.preprocess_msa(batch['msa_feat']) + msa_activations = paddle.unsqueeze(preprocess_1d, axis=1) + \ + self.preprocess_msa(msa_feat) + + right_single = self.right_single(target_feat) # 1, n_res, 22 -> 1, n_res, 128 + right_single = paddle.unsqueeze(right_single, axis=1) # 1, n_res, 128 -> 1, 1, n_res, 128 + left_single = self.left_single(target_feat) # 1, n_res, 22 -> 1, n_res, 128 + left_single = paddle.unsqueeze(left_single, axis=2) # 1, n_res, 128 -> 1, n_res, 1, 128 + pair_activations = left_single + right_single + + mask_2d = paddle.unsqueeze(seq_mask, axis=1) * paddle.unsqueeze(seq_mask, axis=2) + # Inject previous outputs for recycling. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + + if self.config.recycle_pos: # and prev_pos is not None: + prev_pseudo_beta = self._pseudo_beta_fn(aatype, prev_pos) + dgram = dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += self.prev_pos_linear(dgram) + + + if self.config.recycle_features: + if prev_msa_first_row is not None: + prev_msa_first_row = self.prev_msa_first_row_norm( + prev_msa_first_row) + + # A workaround for `jax.ops.index_add` + msa_first_row = paddle.squeeze(msa_activations[:, 0, :], axis=1) + msa_first_row += prev_msa_first_row + msa_first_row = paddle.unsqueeze(msa_first_row, axis=1) + msa_activations_raw = paddle.concat([msa_first_row, msa_activations[:, 1:, :]], axis=1) + + if 'prev_pair' is not None: + pair_activations += self.prev_pair_norm(prev_pair) + + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + pos = residue_index # [bs, N_res] + offset = paddle.unsqueeze(pos, axis=[-1]) - \ + paddle.unsqueeze(pos, axis=[-2]) + offset = offset.astype(dtype=paddle.int32) + rel_pos = nn.functional.one_hot( + paddle.clip( + offset + self.config.max_relative_feature, + min=0, + max=2 * self.config.max_relative_feature), + 2 * self.config.max_relative_feature + 1) + rel_pos_bias = self.pair_activiations(rel_pos) + pair_activations += rel_pos_bias + + # TemplateEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: # [TODO] check if valid + #template_batch = {k: batch[k] for k in batch if k.startswith('template_')} + # pdb.set_trace() + template_pair_repr = self.template_embedding( + pair_activations, # 1xlxlx128 + template_mask, # 1x4 + template_aatype, # 1xl + template_pseudo_beta_mask, # 1xl + template_pseudo_beta, # 1xlx3 + template_all_atom_positions, # 1xlx37x3 + template_all_atom_masks, # 1xlx37 + mask_2d # 1xlxl + ) + pair_activations += template_pair_repr + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + extra_msa_feat = self._create_extra_msa_feature( # [INFO] done + extra_msa, extra_has_deletion, extra_deletion_value + ) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + + # ================================================== + # Extra MSA Stack + # Jumper et al. (2021) Suppl. Alg. 
18 "ExtraMsaStack" + # ================================================== + # extra_msa_stack_input = { + # 'msa': extra_msa_activations, + # 'pair': pair_activations, + # } + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + extra_msa_act = dap.scatter(extra_msa_activations, axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + extra_pair_act = dap.scatter(pair_activations, axis=1) + + # [INFO] --- extra_msa start --- + for extra_msa_stack_iteration in self.extra_msa_stack: + print('# [INFO] inference one MSA stack iteration') + extra_msa_act, extra_pair_act = extra_msa_stack_iteration( + extra_msa_act, + extra_pair_act, + extra_msa_mask, + mask_2d) + + # gather if using dap, otherwise do nothing + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + extra_pair_act= dap.gather(extra_pair_act, axis=1) + + # [INFO] --- extra_msa end --- + # ================================================== + # Template angle feat + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8 + # ================================================== + if self.config.template.enabled and self.config.template.embed_torsion_angles: + num_templ, num_res = template_aatype.shape[1:] + + aatype_one_hot = nn.functional.one_hot(template_aatype, 22) + # Embed the templates aatype, torsion angles and masks. + # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=template_aatype, + all_atom_pos=template_all_atom_positions, + all_atom_mask=template_all_atom_masks, + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not self.global_config.zero_init) + + template_features = paddle.concat([ + aatype_one_hot, + paddle.reshape(ret['torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + paddle.reshape(ret['alt_torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1) + + template_activations = self.template_single_embedding( + template_features) + template_activations = nn.functional.relu(template_activations) + template_activations = self.template_projection(template_activations) + + # Concatenate the templates to the msa. + msa_activations = paddle.concat( + [msa_activations_raw, template_activations], axis=1) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. + torsion_angle_mask = ret['torsion_angles_mask'][..., 2] + torsion_angle_mask = torsion_angle_mask.astype(msa_mask.dtype) + msa_mask = paddle.concat([msa_mask, torsion_angle_mask], axis=1) + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + msa_activations = dap.scatter(msa_activations, axis=1) # [TODO] check if valid + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + extra_pair_act = dap.scatter(extra_pair_act, axis=1) + + # ================================================== + # Main MSA Stack + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 17-18 + # ================================================== + + # [INFO] --- evoformer start --- + i_block=0 + for evoformer_block in self.evoformer_iteration: + print('# [INFO] evoformer iteration %d' % i_block) + i_block += 1 + msa_act, pair_act = recompute_wrapper( + evoformer_block, + msa_activations, + extra_pair_act, + msa_mask, + mask_2d, + is_recompute=self.training) + msa_activations = msa_act + extra_pair_act = pair_act + + # gather if using dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res, c_m] + msa_act = dap.gather(msa_act, axis=1) + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + pair_act = dap.gather(pair_act, axis=1) + + msa_activations = msa_act + pair_activations = pair_act + single_activations = self.single_activations(msa_activations[:, 0]) + + # [INFO] --- evoformer end --- + num_seq = msa_feat.shape[1] + # output = { + # 'single': single_activations, + # 'pair': pair_activations, + # # Crop away template rows such that they are not used + # # in MaskedMsaHead. + # 'msa': msa_activations[:, :num_seq], + # 'msa_first_row': msa_activations[:, 0], + # } + + return single_activations, pair_activations, msa_activations[:, :num_seq], msa_activations[:, 0] diff --git a/apps/protein_folding/helixfold_cpu/model.py b/apps/protein_folding/helixfold_cpu/model.py new file mode 100644 index 00000000..75c4d718 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/model.py @@ -0,0 +1,240 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
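A minimal, self-contained sketch of the export-and-run pattern the Static*/Jit* wrappers above are built on (trace a dygraph layer with InputSpec, dump .pdmodel/.pdiparams, reload it through paddle.inference); `TinyNet`, its sizes, the thread count and the output prefix are illustrative assumptions, not names used by this patch.

import os
import numpy as np
from paddle import nn
from paddle.jit import to_static, save
from paddle.static import InputSpec
from paddle import inference as pdinfer


class TinyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


feed = {'x': np.ones([1, 8], dtype='float32')}
prefix = 'static_modules/tiny'
os.makedirs('static_modules', exist_ok=True)

# 1) trace the dygraph layer into a static graph and dump tiny.pdmodel / tiny.pdiparams
net = TinyNet()
net.eval()
specs = [InputSpec(shape=list(v.shape), dtype=v.dtype, name=k) for k, v in feed.items()]
save(to_static(net, input_spec=specs), prefix)

# 2) rebuild a CPU predictor from the dumped files and run it on numpy inputs
cfg = pdinfer.Config(prefix + '.pdmodel', prefix + '.pdiparams')
cfg.set_cpu_math_library_num_threads(4)
cfg.enable_mkldnn()
predictor = pdinfer.create_predictor(cfg)
for name in predictor.get_input_names():
    predictor.get_input_handle(name).copy_from_cpu(feed[name])
predictor.run()
outputs = {name: predictor.get_output_handle(name).copy_to_cpu()
           for name in predictor.get_output_names()}
print({k: v.shape for k, v in outputs.items()})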
+ +"""Model.""" + +import os +import io +import time +import pickle +import logging +import pathlib +import numpy as np +import ml_collections +from copy import deepcopy +from typing import Dict, Optional + +import paddle +from tools import utils +from layers.backbones import * +from layers.subnets import * +import protein +from tools import residue_constants + +try: + import tensorflow.compat.v1 as tf + from tools import input_pipeline + from tools import proteins_dataset + + USE_TF = True +except Exception: + from tools import input_pipeline + + USE_TF = False + +logger = logging.getLogger(__name__) + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 + + +def print_shape(d, level=0): + tabs = '\t' * level + for k, v in d.items(): + if type(v) is dict: + print(tabs + k) + print_shape(v, level=level+1) + else: + print(tabs + f'{k}: {v.shape} {v.dtype}') + + +def tensor_to_numpy(pred_dict): + for k in pred_dict.keys(): + if isinstance(pred_dict[k], paddle.Tensor): + pred_dict[k] = pred_dict[k].numpy() + + elif type(pred_dict[k]) is dict: + tensor_to_numpy(pred_dict[k]) + + +def slice_pred_dict(pred_dict, slice_idx, ignores=['breaks', 'traj', 'sidechains']): + for k in pred_dict.keys(): + if k in ignores: + continue + + if type(pred_dict[k]) is dict: + pred_dict[k] = slice_pred_dict(pred_dict[k], slice_idx, + ignores=ignores) + + else: + pred_dict[k] = pred_dict[k][slice_idx] + + return pred_dict + + +class RunModel(object): + """Wrapper for paddle model.""" + + def __init__(self, + name: str, + config: ml_collections.ConfigDict, + params_path: str, + dynamic_subbatch_size: bool = True): + self.name = name + self.config = config + self.dynamic_subbatch_size = dynamic_subbatch_size + + channel_num = { + 'target_feat': TARGET_FEAT_DIM, + 'msa_feat': MSA_FEAT_DIM, + } + self.alphafold = modules.AlphaFold(channel_num, config.model) + self.init_params(str(params_path)) + self.alphafold.eval() + + def init_params(self, params_path: str): + if params_path.endswith('.npz'): + logger.info('Load as AlphaFold pre-trained model') + with open(params_path, 'rb') as f: + params = np.load(io.BytesIO(f.read()), allow_pickle=False) + params = dict(params) + + pd_params = utils.jax_params_to_paddle(params) + pd_params = {k[len('alphafold.'):]: v for k, v in pd_params.items()} + + elif params_path.endswith('.pd'): + logger.info('Load as Paddle model') + pd_params = paddle.load(params_path) + + else: + raise ValueError('Unsupported params file type') + + self.alphafold.set_state_dict(pd_params) + + def preprocess(self, + raw_features: Dict[str, np.ndarray], + random_seed: int, + pkl: pathlib.Path = None) -> Dict[str, paddle.Tensor]: + """Convert raw input features to model input features""" + if pkl is not None and pkl.exists(): + logger.info(f'Use cached {pkl}') + with open(pkl, 'rb') as f: + features = pickle.load(f) + + print('########## feature shape ##########') + print_shape(features) + return utils.map_to_tensor(features, add_batch=True) + + print('Processing input features') + data_config = deepcopy(self.config.data) + feature_names = data_config.common.unsupervised_features + if data_config.common.use_templates: + feature_names += data_config.common.template_features + + + num_residues = int(raw_features['seq_length'][0]) + data_config.eval.crop_size = num_residues + + if 'deletion_matrix_int' in raw_features: + raw_features['deletion_matrix'] = (raw_features.pop( + 'deletion_matrix_int').astype(np.float32)) + + if raw_features['msa'].shape[0] > 10000: + raw_features['msa'] = raw_features['msa'][:10000] + 
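            # cap the raw MSA at 10000 sequences to bound memory; num_alignments and
            # the deletion matrix are adjusted below to stay consistent with the cap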
raw_features['num_alignments'] = np.ones_like(raw_features['num_alignments']) * 10000 + + if 'deletion_matrix' in raw_features: + raw_features['deletion_matrix'] = raw_features['deletion_matrix'][:10000] + + if USE_TF: + data_config.eval.delete_msa_block = False + + tf_graph = tf.Graph() + with tf_graph.as_default(), tf.device('/device:CPU:0'): + tf.compat.v1.set_random_seed(random_seed) + tensor_dict = proteins_dataset.np_to_tensor_dict( + np_example=raw_features, features=feature_names) + + processed_batch = input_pipeline.process_tensors_from_config( + tensor_dict, data_config) + + tf_graph.finalize() + + with tf.Session(graph=tf_graph) as sess: + features = sess.run(processed_batch) + + else: + + array_dict = input_pipeline.np_to_array_dict( + np_example=raw_features, features=feature_names, + use_templates=data_config.common.use_templates) + features = input_pipeline.process_arrays_from_config( + array_dict, data_config) + features = {k: v for k, v in features.items() if v.dtype != 'O'} + + extra_msa_length = data_config.common.max_extra_msa + for k in ['extra_msa', 'extra_has_deletion', 'extra_deletion_value', + 'extra_msa_mask']: + features[k] = features[k][:, :extra_msa_length] + + for k in features.keys(): + if features[k].dtype == np.int64: + features[k] = features[k].astype(np.int32) + + elif features[k].dtype == np.float64: + features[k] = features[k].astype(np.float32) + + if pkl is not None: + with open(pkl, 'wb') as f: + pickle.dump(features, f, protocol=4) + + print('Preprocessesing finished') + print('########## feature shape ##########') + print_shape(features) + return utils.map_to_tensor(features, add_batch=True) + + def predict(self, + feat: Dict[str, paddle.Tensor], + ensemble_representations: bool = True, + return_representations: bool = True): + """Predict protein structure and encoding representation""" + if self.dynamic_subbatch_size: + seq_len = feat['aatype'].shape[-1] + extra_msa_num = feat['extra_msa'].shape[-2] + self.update_subbatch_size(seq_len, extra_msa_num) + + with paddle.no_grad(): + ret = self.alphafold( + feat, {}, + ensemble_representations=ensemble_representations, + return_representations=return_representations, + compute_loss=False) + + print('Prediction finished') + tensor_to_numpy(ret) + return ret + + def update_subbatch_size(self, seq_len, extra_msa_num): + if extra_msa_num == 5120: + if seq_len < 200: + # disable subbatch + self.alphafold.global_config.subbatch_size = 5120 + + elif extra_msa_num == 1024: + if seq_len < 600: + # disable subbatch + self.alphafold.global_config.subbatch_size = 1024 + + else: + raise ValueError('Unknown subbatch strategy') diff --git a/apps/protein_folding/helixfold_cpu/modules.py b/apps/protein_folding/helixfold_cpu/modules.py new file mode 100644 index 00000000..93985d69 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/modules.py @@ -0,0 +1,2189 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
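The thresholds in update_subbatch_size above are easiest to sanity-check in isolation; the snippet below restates them as a pure function, where the `default` argument (whatever the config already holds) and the asserted values are assumptions for illustration.

def pick_subbatch_size(seq_len: int, extra_msa_num: int, default: int) -> int:
    """Restate RunModel.update_subbatch_size as a pure function."""
    if extra_msa_num == 5120:
        # short sequences fit in memory without subbatching
        return 5120 if seq_len < 200 else default
    if extra_msa_num == 1024:
        return 1024 if seq_len < 600 else default
    raise ValueError('Unknown subbatch strategy')


assert pick_subbatch_size(150, 5120, default=4) == 5120  # subbatching disabled
assert pick_subbatch_size(400, 5120, default=4) == 4     # config value kept
assert pick_subbatch_size(599, 1024, default=4) == 1024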
+ +"""Modules.""" + +import pdb +import numpy as np + +import paddle +import paddle.nn as nn +from paddle.fluid.framework import _dygraph_tracer +from paddle.distributed.fleet.utils import recompute + +from alphafold_paddle.common import residue_constants +from alphafold_paddle.model.utils import mask_mean, subbatch +from alphafold_paddle.model import folding, lddt, quat_affine, all_atom +from alphafold_paddle.model.utils import init_gate_linear, init_final_linear +from alphafold_paddle.distributed import dap + +# Map head name in config to head name in model params +Head_names = { + 'masked_msa': 'masked_msa_head', + 'distogram': 'distogram_head', + 'predicted_lddt': 'predicted_lddt_head', + 'predicted_aligned_error': 'predicted_aligned_error_head', + 'experimentally_resolved': 'experimentally_resolved_head', # finetune loss +} + + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +def softmax_cross_entropy(logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels.""" + loss = -paddle.sum(labels * paddle.nn.functional.log_softmax(logits), axis=-1) + return loss + + +def sigmoid_cross_entropy(logits, labels): + """Computes sigmoid cross entropy given logits and multiple class labels.""" + log_p = paddle.nn.functional.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable + log_not_p = paddle.nn.functional.log_sigmoid(-logits) + loss = -labels * log_p - (1. - labels) * log_not_p + return loss + + +class AlphaFold(nn.Layer): + """AlphaFold model with recycling. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" + """ + def __init__(self, channel_num, config): + super(AlphaFold, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = config.global_config + + self.alphafold_iteration = AlphaFoldIteration( + self.channel_num, self.config, self.global_config) + + def forward(self, + batch, + label, + ensemble_representations=False, + return_representations=False, + compute_loss=True): + """Run the AlphaFold model. + + Arguments: + batch: Dictionary with inputs to the AlphaFold model. + ensemble_representations: Whether to use ensembling of representations. + return_representations: Whether to also return the intermediate + representations. + + Returns: + The output of AlphaFoldIteration is a nested dictionary containing + predictions from the various heads. + + """ + inner_batch, num_residues = batch['aatype'].shape[1:] + + def _get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + + for k in new_prev.keys(): + new_prev[k].stop_gradient = True + + return new_prev + + def _run_single_recycling(prev, recycle_idx, compute_loss): + if not self.training: + print(f'########## recycle id: {recycle_idx} ##########') + + if self.config.resample_msa_in_recycling: + # (B, (R+1)*E, N, ...) + # B: batch size, R: recycling number, + # E: ensemble number, N: residue number + num_ensemble = inner_batch // (self.config.num_recycle + 1) + ensembled_batch = dict() + for k in batch.keys(): + start = recycle_idx * num_ensemble + end = start + num_ensemble + ensembled_batch[k] = batch[k][:, start:end] + else: + # (B, E, N, ...) 
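                # resampling disabled: the whole second axis is treated as the
                # ensemble dimension and identical features are fed to every pass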
+ num_ensemble = inner_batch + ensembled_batch = batch + + non_ensembled_batch = prev + return self.alphafold_iteration( + ensembled_batch, label, non_ensembled_batch, + compute_loss=compute_loss, + ensemble_representations=ensemble_representations) + + if self.config.num_recycle: + # aatype: (B, E, N), zeros_bn: (B, N) + zeros_bn = paddle.zeros_like(batch['aatype'][:, 0], dtype='float32') + + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_pos': paddle.tile( + zeros_bn[..., None, None], + [1, 1, residue_constants.atom_type_num, 3]), + 'prev_msa_first_row': paddle.tile( + zeros_bn[..., None], + [1, 1, emb_config.msa_channel]), + 'prev_pair': paddle.tile( + zeros_bn[..., None, None], + [1, 1, num_residues, emb_config.pair_channel]), + } + + if 'num_iter_recycling' in batch: + # Training trick: dynamic recycling number + num_iter = batch['num_iter_recycling'].numpy()[0, 0] + num_iter = min(int(num_iter), self.config.num_recycle) + else: + num_iter = self.config.num_recycle + + for recycle_idx in range(num_iter): + ret = _run_single_recycling(prev, recycle_idx, compute_loss=False) + prev = _get_prev(ret) + + else: + prev = {} + num_iter = 0 + + return _run_single_recycling(prev, num_iter, compute_loss=compute_loss) + + +class AlphaFoldIteration(nn.Layer): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. Each head also returns a + loss which is combined as a weighted sum to produce the total loss. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 + """ + + def __init__(self, channel_num, config, global_config): + super(AlphaFoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # copy these config for later usage + self.channel_num['extra_msa_channel'] = config.embeddings_and_evoformer.extra_msa_channel + self.channel_num['msa_channel'] = config.embeddings_and_evoformer.msa_channel + self.channel_num['pair_channel'] = config.embeddings_and_evoformer.pair_channel + self.channel_num['seq_channel'] = config.embeddings_and_evoformer.seq_channel + + self.evoformer = EmbeddingsAndEvoformer( + self.channel_num, self.config.embeddings_and_evoformer, + self.global_config) + + Head_modules = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': folding.StructureModule, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, # finetune loss + } + + self.used_heads = [] + for head_name, head_config in sorted(self.config.heads.items()): + if head_name not in Head_modules: + continue + + self.used_heads.append(head_name) + module = Head_modules[head_name]( + self.channel_num, head_config, self.global_config) + + head_name_ = Head_names.get(head_name, head_name) + setattr(self, head_name_, module) + + def forward(self, + ensembled_batch, + label, + non_ensembled_batch, + compute_loss=False, + ensemble_representations=False): + num_ensemble = ensembled_batch['seq_length'].shape[1] + if not ensemble_representations: + assert num_ensemble == 1 + + def _slice_batch(i): + b = {k: v[:, i] for k, v in ensembled_batch.items()} + b.update(non_ensembled_batch) + return b + + batch0 = _slice_batch(0) + representations = self.evoformer(batch0) + + # MSA representations are not ensembled + 
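        # stash the MSA activations computed from batch0 and remove them from the
        # dict, so the ensembling average below runs only over the remaining
        # representations; they are restored right after averaging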
msa_representation = representations['msa'] + del representations['msa'] + # MaskedMSAHead is apply on batch0 + label['bert_mask'] = batch0['bert_mask'] + label['true_msa'] = batch0['true_msa'] + label['residue_index'] = batch0['residue_index'] + + if ensemble_representations: + for i in range(1, num_ensemble): + batch = _slice_batch(i) + representations_update = self.evoformer(batch) + for k in representations.keys(): + representations[k] += representations_update[k] + + for k in representations.keys(): + representations[k] /= num_ensemble + 0.0 + + representations['msa'] = msa_representation + ret = {'representations': representations} + + def loss(head_name_, head_config, ret, head_name, filter_ret=True): + if filter_ret: + value = ret[head_name] + else: + value = ret + loss_output = getattr(self, head_name_).loss(value, label) + ret[head_name].update(loss_output) + loss = head_config.weight * ret[head_name]['loss'] + return loss + + def _forward_heads(representations, ret, batch0): + total_loss = 0. + for head_name, head_config in self._get_heads(): + head_name_ = Head_names.get(head_name, head_name) + # Skip PredictedLDDTHead and PredictedAlignedErrorHead until + # StructureModule is executed. + if head_name in ('predicted_lddt', 'predicted_aligned_error'): + continue + else: + ret[head_name] = getattr(self, head_name_)(representations, batch0) + if 'representations' in ret[head_name]: + # Extra representations from the head. Used by the + # structure module to provide activations for the PredictedLDDTHead. + representations.update(ret[head_name].pop('representations')) + if compute_loss: + total_loss += loss(head_name_, head_config, ret, head_name) + + if self.config.heads.get('predicted_lddt.weight', 0.0): + # Add PredictedLDDTHead after StructureModule executes. + head_name = 'predicted_lddt' + # Feed all previous results to give access to structure_module result. + head_name_ = Head_names.get(head_name, head_name) + head_config = self.config.heads[head_name] + ret[head_name] = getattr(self, head_name_)(representations, batch0) + if compute_loss: + total_loss += loss(head_name_, head_config, ret, head_name, filter_ret=False) + + if ('predicted_aligned_error' in self.config.heads + and self.config.heads.get('predicted_aligned_error.weight', 0.0)): + # Add PredictedAlignedErrorHead after StructureModule executes. + head_name = 'predicted_aligned_error' + # Feed all previous results to give access to structure_module result. 
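                # note: filter_ret=False below hands the full `ret` (including the
                # structure_module output) to this head's loss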
+ head_config = self.config.heads[head_name] + head_name_ = Head_names.get(head_name, head_name) + ret[head_name] = getattr(self, head_name_)(representations, batch0) + if compute_loss: + total_loss += loss(head_name_, head_config, ret, head_name, filter_ret=False) + + return ret, total_loss + + tracer = _dygraph_tracer() + if tracer._amp_dtype == "bfloat16": + with paddle.amp.auto_cast(enable=False): + for key, value in representations.items(): + if value.dtype in [paddle.fluid.core.VarDesc.VarType.BF16]: + temp_value = value.cast('float32') + temp_value.stop_gradient = value.stop_gradient + representations[key] = temp_value + for key, value in batch0.items(): + if value.dtype in [paddle.fluid.core.VarDesc.VarType.BF16]: + temp_value = value.cast('float32') + temp_value.stop_gradient = value.stop_gradient + batch0[key] = temp_value + ret, total_loss = _forward_heads(representations, ret, batch0) + + else: + ret, total_loss = _forward_heads(representations, ret, batch0) + + if compute_loss: + return ret, total_loss + else: + return ret + + def _get_heads(self): + assert 'structure_module' in self.used_heads + head_names = [h for h in self.used_heads] + + for k in head_names: + yield k, self.config.heads[k] + + +class Attention(nn.Layer): + """Multihead attention.""" + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(Attention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, bias, nonbatched_bias=None): + """Builds Attention module. + Arguments: + q_data: A tensor of queries, shape [batch, row_size, N_queries, q_channels]. + m_data: A tensor of memories from which the keys and values are + projected, shape [batch, row_size, N_keys, m_channels]. + bias: A bias for the attention, shape [batch, row_size, num_head, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + + Returns: + A float32 tensor of shape [batch_size, row_size, N_queries, output_dim]. 
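
        Shape walk-through (illustrative numbers): with batch=1, row_size=5,
        N_queries=N_keys=7, q_channels=m_channels=64, num_head=8 and 8 channels
        per head, q/k/v come out as [1, 5, 7, 8, 8], logits and weights as
        [1, 5, 8, 7, 7], and the result as [1, 5, 7, output_dim]. Callers in this
        file build `bias` as 1e9 * (mask - 1.), i.e. 0 on valid positions and
        -1e9 on padding, so padded keys get ~0 weight after the softmax.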
+ """ + c = self.key_dim ** (-0.5) + q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w) * c + k = paddle.einsum('nbka,ahc->nbkhc', m_data, self.key_w) + v = paddle.einsum('nbka,ahc->nbkhc', m_data, self.value_w) + logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias + + if nonbatched_bias is not None: + nonbatched_bias_after = dap.all_gather_opp(nonbatched_bias, axis=2) + logits += paddle.unsqueeze(nonbatched_bias_after, axis=1) + + weights = nn.functional.softmax(logits) + weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + return output + + +class GlobalAttention(nn.Layer): + """Global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 + """ + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(GlobalAttention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, q_mask): + k = paddle.einsum('nbka,ac->nbkc', m_data, self.key_w) + v = paddle.einsum('nbka,ac->nbkc', m_data, self.value_w) + + # NOTE: differ from non-global version using q_avg for attn + q_avg = mask_mean(q_mask, q_data, axis=2) + c = self.key_dim ** (-0.5) + q = paddle.einsum('nba,ahc->nbhc', q_avg, self.query_w) * c + + q_mask_ = paddle.unsqueeze(q_mask, axis=2)[..., 0] + bias = 1e9 * (q_mask_ - 1.) 
+ + logits = paddle.einsum('nbhc,nbkc->nbhk', q, k) + bias + weights = nn.functional.softmax(logits) + weighted_avg = paddle.einsum('nbhk,nbkc->nbhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg = paddle.unsqueeze(weighted_avg, axis=2) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + else: + output = paddle.einsum('nbhc,hco->nbo', weighted_avg, + self.output_w) + self.output_b + output = paddle.unsqueeze(output, axis=-1) + + return output + + +class MSARowAttentionWithPairBias(nn.Layer): + """MSA per-row attention biased by the pair representation. + + Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa): + super(MSARowAttentionWithPairBias, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + assert config.orientation == 'per_row' + + if is_extra_msa: + self.query_norm = nn.LayerNorm(channel_num['extra_msa_channel']) + else: + self.query_norm = nn.LayerNorm(channel_num['msa_channel']) + + self.feat_2d_norm = nn.LayerNorm(channel_num['pair_channel']) + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. / np.sqrt(channel_num['pair_channel']))) + + if is_extra_msa: + extra_msa_channel = channel_num['extra_msa_channel'] + self.attention = Attention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + else: + msa_channel = channel_num['msa_channel'] + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + def forward(self, msa_act, msa_mask, pair_act): + + pair_act = self.feat_2d_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum( + 'nqkc,ch->nhqk', pair_act, self.feat_2d_weights) + + # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + # [B, N_seq, N_res] => [B, N_seq//dap_size, N_res] + msa_mask = dap.scatter(msa_mask, axis=1) + + bias = 1e9 * (msa_mask - 1.) + # [B, N_seq//dap_size, N_res] => [B, N_seq//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + msa_act = self.query_norm(msa_act) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, bias, nonbatched_bias) + else: + msa_act = self.attention(msa_act, msa_act, bias, nonbatched_bias) + + return msa_act + + +class MSAColumnGlobalAttention(nn.Layer): + """MSA per-column global attention. + + Jumper et al. (2021) Suppl. Alg. 
19 "MSAColumnGlobalAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnGlobalAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + assert config.orientation == 'per_column' + + extra_msa_channel = channel_num['extra_msa_channel'] + self.query_norm = nn.LayerNorm(extra_msa_channel) + self.attention = GlobalAttention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + msa_mask = paddle.unsqueeze(msa_mask, axis=-1) + msa_act = self.query_norm(msa_act) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, msa_mask) + else: + msa_act = self.attention(msa_act, msa_act, msa_mask) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class MSAColumnAttention(nn.Layer): + """MSA per-column attention. + + Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + assert config.orientation == 'per_column' + + msa_channel = channel_num['msa_channel'] + self.query_norm = nn.LayerNorm(msa_channel) + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + msa_act = self.query_norm(msa_act) + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, bias) + else: + msa_act = self.attention(msa_act, msa_act, bias) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class Transition(nn.Layer): + """Transition layer. + + Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" + Jumper et al. (2021) Suppl. Alg. 
15 "PairTransition" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, + transition_type): + super(Transition, self).__init__() + assert transition_type in ['msa_transition', 'pair_transition'] + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + self.transition_type = transition_type + + if transition_type == 'msa_transition' and is_extra_msa: + in_dim = channel_num['extra_msa_channel'] + elif transition_type == 'msa_transition' and not is_extra_msa: + in_dim = channel_num['msa_channel'] + elif transition_type == 'pair_transition': + in_dim = channel_num['pair_channel'] + + self.input_layer_norm = nn.LayerNorm(in_dim) + self.transition1 = nn.Linear( + in_dim, int(in_dim * self.config.num_intermediate_factor), + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingNormal())) + + if self.global_config.zero_init: + last_init = nn.initializer.Constant(0.0) + else: + last_init = nn.initializer.TruncatedNormal() + + self.transition2 = nn.Linear( + int(in_dim * self.config.num_intermediate_factor), in_dim, + weight_attr=paddle.ParamAttr(initializer=last_init)) + + def forward(self, act, mask): + act = self.input_layer_norm(act) + + def transition_module(x): + x = self.transition1(x) + x = nn.functional.relu(x) + x = self.transition2(x) + return x + + if not self.training: + # low memory mode using subbatch + sb_transition = subbatch(transition_module, [0], [1], + self.global_config.subbatch_size, 1) + act = sb_transition(act) + else: + act = transition_module(act) + + return act + + +class MaskedMsaHead(nn.Layer): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + def __init__(self, channel_num, config, global_config, name='masked_msa_head'): + super(MaskedMsaHead, self).__init__() + self.config = config + self.global_config = global_config + self.num_output = config.num_output + self.logits = nn.Linear(channel_num['msa_channel'], self.num_output, name='logits') + + def forward(self, representations, batch): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [batch, N_seq, N_res, c_m]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [batch, N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + del batch + logits = self.logits(representations['msa']) + return {'logits': logits} + + def loss(self, value, batch): + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(batch['true_msa'], num_classes=self.num_output), + logits=value['logits']) + loss = (paddle.sum(errors * batch['bert_mask'], axis=[-2, -1]) / + (1e-8 + paddle.sum(batch['bert_mask'], axis=[-2, -1]))) + return {'loss': loss} + + +class PredictedLDDTHead(nn.Layer): + """Head to predict the per-residue LDDT to be used as a confidence measure. + + Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" + Jumper et al. (2021) Suppl. Alg. 
29 "predictPerResidueLDDT_Ca" + """ + + def __init__(self, channel_num, config, global_config, name='predicted_lddt_head'): + super(PredictedLDDTHead, self).__init__() + self.config = config + self.global_config = global_config + + self.input_layer_norm = nn.LayerNorm(channel_num['seq_channel'], + name='input_layer_norm') + self.act_0 = nn.Linear(channel_num['seq_channel'], + self.config.num_channels, name='act_0') + self.act_1 = nn.Linear(self.config.num_channels, + self.config.num_channels, name='act_1') + self.logits = nn.Linear(self.config.num_channels, + self.config.num_bins, name='logits') + + def forward(self, representations, batch): + """Builds PredictedLDDTHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'structure_module': Single representation from the structure module, + shape [n_batch, N_res, c_s]. + + Returns: + Dictionary containing : + * 'logits': logits of shape [n_batch, N_res, N_bins] with + (unnormalized) log probabilies of binned predicted lDDT. + """ + act = representations['structure_module'] + act = self.input_layer_norm(act) + act = nn.functional.relu(self.act_0(act)) + act = nn.functional.relu(self.act_1(act)) + logits = self.logits(act) + + return dict(logits=logits) + + def loss(self, value, batch): + # Shape (n_batch, num_res, 37, 3) + pred_all_atom_pos = value['structure_module']['final_atom_positions'] + # Shape (n_batch, num_res, 37, 3) + true_all_atom_pos = paddle.cast(batch['all_atom_positions'], 'float32') + # Shape (n_batch, num_res, 37) + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + # Shape (batch_size, num_res) + lddt_ca = lddt.lddt( + # Shape (batch_size, num_res, 3) + predicted_points=pred_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 3) + true_points=true_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 1) + true_points_mask=all_atom_mask[:, :, 1:2], + cutoff=15., + per_residue=True) + lddt_ca = lddt_ca.detach() + + # Shape (batch_size, num_res) + num_bins = self.config.num_bins + bin_index = paddle.floor(lddt_ca * num_bins) + + # protect against out of range for lddt_ca == 1 + bin_index = paddle.minimum(bin_index, paddle.to_tensor(num_bins - 1, dtype='float32')) + lddt_ca_one_hot = paddle.nn.functional.one_hot(paddle.cast(bin_index, 'int64'), num_classes=num_bins) + + # Shape (n_batch, num_res, num_channel) + logits = value['predicted_lddt']['logits'] + errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + # Shape (num_res,) + mask_ca = all_atom_mask[:, :, residue_constants.atom_order['CA']] + mask_ca = paddle.to_tensor(mask_ca, dtype='float32') + loss = paddle.sum(errors * mask_ca, axis=-1) / (paddle.sum(mask_ca, axis=-1) + 1e-8) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + output = {'loss': loss} + return output + + +class PredictedAlignedErrorHead(nn.Layer): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 
1.9.7 "TM-score prediction" + """ + def __init__(self, channel_num, config, global_config, + name='predicted_aligned_error_head'): + super(PredictedAlignedErrorHead, self).__init__() + self.config = config + self.global_config = global_config + + self.logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='logits') + + def forward(self, representations, batch): + """Builds PredictedAlignedErrorHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [B, N_res, N_res, c_z]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * logits: logits for aligned error, shape [B, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1]. + """ + logits = self.logits(representations['pair']) + breaks = paddle.linspace(0., self.config.max_error_bin, + self.config.num_bins-1) + + return dict(logits=logits, breaks=breaks) + + def loss(self, value, batch): + # Shape (B, num_res, 7) + predicted_affine = quat_affine.QuatAffine.from_tensor( + value['structure_module']['final_affines']) + # Shape (B, num_res, 7) + true_rot = paddle.to_tensor(batch['backbone_affine_tensor_rot'], dtype='float32') + true_trans = paddle.to_tensor(batch['backbone_affine_tensor_trans'], dtype='float32') + true_affine = quat_affine.QuatAffine( + quaternion=None, + translation=true_trans, + rotation=true_rot) + # Shape (B, num_res) + mask = batch['backbone_affine_mask'] + # Shape (B, num_res, num_res) + square_mask = mask[..., None] * mask[:, None, :] + num_bins = self.config.num_bins + # (num_bins - 1) + breaks = value['predicted_aligned_error']['breaks'] + # (B, num_res, num_res, num_bins) + logits = value['predicted_aligned_error']['logits'] + + # Compute the squared error for each alignment. + def _local_frame_points(affine): + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affine.translation, axis=-1)] + return affine.invert_point(points, extra_dims=1) + error_dist2_xyz = [ + paddle.square(a - b) + for a, b in zip(_local_frame_points(predicted_affine), + _local_frame_points(true_affine))] + error_dist2 = sum(error_dist2_xyz) + # Shape (B, num_res, num_res) + # First num_res are alignment frames, second num_res are the residues. + error_dist2 = error_dist2.detach() + + sq_breaks = paddle.square(breaks) + true_bins = paddle.sum(paddle.cast((error_dist2[..., None] > sq_breaks), 'int32'), axis=-1) + + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + loss = (paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-8 + paddle.sum(square_mask, axis=[-2, -1]))) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class ExperimentallyResolvedHead(nn.Layer): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 
1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, channel_num, config, global_config, name='experimentally_resolved_head'): + super(ExperimentallyResolvedHead, self).__init__() + self.config = config + self.global_config = global_config + self.logits = nn.Linear(channel_num['seq_channel'], 37, name='logits') + + def forward(self, representations, batch): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [B, N_res, c_s]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [B, N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. + """ + logits = self.logits(representations['single']) + return dict(logits=logits) + + def loss(self, value, batch): + logits = value['logits'] + assert len(logits.shape) == 3 + + # Does the atom appear in the amino acid? + atom_exists = batch['atom37_atom_exists'] + # Is the atom resolved in the experiment? Subset of atom_exists, + # *except for OXT* + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) + loss = paddle.sum(xent * atom_exists, axis=[-2, -1]) / (1e-8 + paddle.sum(atom_exists, axis=[-2, -1])) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class DistogramHead(nn.Layer): + """Head to predict a distogram. + + Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" + """ + + def __init__(self, channel_num, config, name='distogram_head'): + super(DistogramHead, self).__init__() + self.config = config + # self.global_config = global_config + + self.half_logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='half_logits') + init_final_linear(self.half_logits) + + def forward(self, representations, batch): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [batch, N_res, N_res, c_z]. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [batch, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [batch, N_bins - 1]. 
+ """ + half_logits = self.half_logits(representations['pair']) + + logits = half_logits + paddle.transpose(half_logits, perm=[0, 2, 1, 3]) + breaks = paddle.linspace(self.config.first_break, self.config.last_break, + self.config.num_bins - 1) + breaks = paddle.tile(breaks[None, :], + repeat_times=[logits.shape[0], 1]) + + return { + 'logits': logits, + 'bin_edges': breaks} + + def loss(self, value, batch): + return _distogram_log_loss(value['logits'], value['bin_edges'], + batch, self.config.num_bins) + + +def _distogram_log_loss(logits, bin_edges, batch, num_bins): + """Log loss of a distogram.""" + positions = batch['pseudo_beta'] + mask = batch['pseudo_beta_mask'] + + assert positions.shape[-1] == 3 + + sq_breaks = paddle.square(bin_edges).unsqueeze([1, 2]) + + dist2 = paddle.sum( + paddle.square( + paddle.unsqueeze(positions, axis=-2) - + paddle.unsqueeze(positions, axis=-3)), + axis=-1, + keepdim=True) + + true_bins = paddle.sum(dist2 > sq_breaks, axis=-1) + + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + square_mask = paddle.unsqueeze(mask, axis=-2) * paddle.unsqueeze(mask, axis=-1) + + avg_error = ( + paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-6 + paddle.sum(square_mask, axis=[-2, -1]))) + dist2 = dist2[..., 0] + return { + 'loss': avg_error, + 'true_dist': paddle.sqrt(1e-6 + dist2)} + + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + lower_breaks = paddle.linspace(min_bin, max_bin, num_bins) + lower_breaks = paddle.square(lower_breaks) + upper_breaks = paddle.concat([lower_breaks[1:], + paddle.to_tensor([1e8], dtype='float32')]) + + def _squared_difference(x, y): + return paddle.square(x - y) + + dist2 = paddle.sum( + _squared_difference( + paddle.unsqueeze(positions, axis=-2), + paddle.unsqueeze(positions, axis=-3)), + axis=-1, keepdim=True) + + dgram = ((dist2 > lower_breaks).astype('float32') * + (dist2 < upper_breaks).astype('float32')) + return dgram + + +class EvoformerIteration(nn.Layer): + """Single iteration (block) of Evoformer stack. + + Jumper et al. (2021) Suppl. Alg. 
6 "EvoformerStack" lines 2-10 + """ + def __init__(self, channel_num, config, global_config, is_extra_msa=False): + super(EvoformerIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + + # Row-wise Gated Self-attention with Pair Bias + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + channel_num, self.config.msa_row_attention_with_pair_bias, + self.global_config, is_extra_msa) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_row_attention_with_pair_bias) + self.msa_row_attn_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + if self.is_extra_msa: + self.msa_column_global_attention = MSAColumnGlobalAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_global_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + else: + self.msa_column_attention = MSAColumnAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + self.msa_transition = Transition( + channel_num, self.config.msa_transition, self.global_config, + is_extra_msa, 'msa_transition') + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_transition) + self.msa_transition_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # OuterProductMean + self.outer_product_mean = OuterProductMean(channel_num, + self.config.outer_product_mean, self.global_config, + self.is_extra_msa, name='outer_product_mean') + + # Dropout + dropout_rate, dropout_axis = self._parse_dropout_params( + self.outer_product_mean) + self.outer_product_mean_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # Triangle Multiplication. + self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # TriangleAttention. + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # Pair transition. 
+ self.pair_transition = Transition( + channel_num, self.config.pair_transition, self.global_config, + is_extra_msa, 'pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, msa_act, pair_act, masks): + msa_mask, pair_mask = masks['msa'], masks['pair'] + + # [B, N_seq//dap_size, N_res, c_m] + residual = self.msa_row_attention_with_pair_bias( + msa_act, msa_mask, pair_act) + residual = self.msa_row_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res//dap_size, c_m] + msa_act = dap.row_to_col(msa_act) + + if self.is_extra_msa: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_global_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act, msa_mask) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + else: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act, msa_mask) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + # [B, N_res//dap_size, N_res, c_z] + residual = self.outer_product_mean(msa_act, msa_mask) + residual = self.outer_product_mean_dropout(residual) + pair_act = pair_act + residual + + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all(msa_act, in_axis=1, out_axis=2) + + # scatter if using dap, otherwise do nothing + pair_mask_row = dap.scatter(pair_mask, axis=1) + pair_mask_col = dap.scatter(pair_mask, axis=2) + + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask_row) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_multiplication_incoming(pair_act, pair_mask_col) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_attention_starting_node(pair_act, pair_mask_row) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_attention_ending_node(pair_act, pair_mask_col) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act, pair_mask) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res, 
N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + + # wait if using async communication and dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all_opp(msa_act, in_axis=1, out_axis=2) + + return msa_act, pair_act + + +class EmbeddingsAndEvoformer(nn.Layer): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, channel_num, config, global_config): + super(EmbeddingsAndEvoformer, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + self.preprocess_1d = nn.Linear(channel_num['target_feat'], + self.config.msa_channel, name='preprocess_1d') + self.preprocess_msa = nn.Linear(channel_num['msa_feat'], + self.config.msa_channel, name='preprocess_msa') + self.left_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='left_single') + self.right_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='right_single') + + # RecyclingEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + if self.config.recycle_pos: + self.prev_pos_linear = nn.Linear(self.config.prev_pos.num_bins, + self.config.pair_channel) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + self.pair_activiations = nn.Linear( + 2 * self.config.max_relative_feature + 1, + self.config.pair_channel) + + if self.config.recycle_features: + self.prev_msa_first_row_norm = nn.LayerNorm( + self.config.msa_channel) + self.prev_pair_norm = nn.LayerNorm(self.config.pair_channel) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + self.channel_num['template_angle'] = 57 + self.channel_num['template_pair'] = 88 + self.template_embedding = TemplateEmbedding( + self.channel_num, self.config.template, self.global_config) + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + self.extra_msa_activations = nn.Linear( + 25, # 23 (20aa+unknown+gap+mask) + 1 (has_del) + 1 (del_val) + self.config.extra_msa_channel) + + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + self.extra_msa_stack = nn.LayerList() + for _ in range(self.config.extra_msa_stack_num_block): + self.extra_msa_stack.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=True)) + + # Embed templates torsion angles + if self.config.template.enabled and self.config.template.embed_torsion_angles: + c = self.config.msa_channel + self.template_single_embedding = nn.Linear( + self.channel_num['template_angle'], c) + self.template_projection = nn.Linear(c, c) + + # Main trunk of the network + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 17-18 + self.evoformer_iteration = nn.LayerList() + for _ in range(self.config.evoformer_num_block): + self.evoformer_iteration.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=False)) + + self.single_activations = nn.Linear( + self.config.msa_channel, self.config.seq_channel) + + def _pseudo_beta_fn(self, aatype, all_atom_positions, all_atom_masks): + gly_id = paddle.ones_like(aatype) * residue_constants.restype_order['G'] + is_gly = paddle.equal(aatype, gly_id) + + ca_idx = residue_constants.atom_order['CA'] + cb_idx = residue_constants.atom_order['CB'] + + n = len(all_atom_positions.shape) + pseudo_beta = paddle.where( + paddle.tile(paddle.unsqueeze(is_gly, axis=-1), [1] * len(is_gly.shape) + [3]), + paddle.squeeze(all_atom_positions.slice([n-2], [ca_idx], [ca_idx+1]), axis=-2), + paddle.squeeze(all_atom_positions.slice([n-2], [cb_idx], [cb_idx+1]), axis=-2) + ) + + if all_atom_masks is not None: + m = len(all_atom_masks) + pseudo_beta_mask = paddle.where( + is_gly, + paddle.squeeze( + all_atom_masks.slice([m-1], [ca_idx], [ca_idx+1]), + axis=-1), + paddle.squeeze( + all_atom_masks.slice([m-1], [cb_idx], [cb_idx+1]), + axis=-1)) + pseudo_beta_mask = paddle.squeeze(pseudo_beta_mask, axis=-1) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + def _create_extra_msa_feature(self, batch): + # 23: 20aa + unknown + gap + bert mask + msa_1hot = nn.functional.one_hot(batch['extra_msa'], 23) + msa_feat = [msa_1hot, + paddle.unsqueeze(batch['extra_has_deletion'], axis=-1), + paddle.unsqueeze(batch['extra_deletion_value'], axis=-1)] + return paddle.concat(msa_feat, axis=-1) + + def forward(self, batch): + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + preprocess_1d = self.preprocess_1d(batch['target_feat']) + # preprocess_msa = self.preprocess_msa(batch['msa_feat']) + msa_activations = paddle.unsqueeze(preprocess_1d, axis=1) + \ + self.preprocess_msa(batch['msa_feat']) + + right_single = self.right_single(batch['target_feat']) # 1, n_res, 22 -> 1, n_res, 128 + right_single = paddle.unsqueeze(right_single, axis=1) # 1, n_res, 128 -> 1, 1, n_res, 128 + left_single = self.left_single(batch['target_feat']) # 1, n_res, 22 -> 1, n_res, 128 + left_single = paddle.unsqueeze(left_single, axis=2) # 1, n_res, 128 -> 1, n_res, 1, 128 + pair_activations = left_single + right_single + + mask_2d = paddle.unsqueeze(batch['seq_mask'], axis=1) * paddle.unsqueeze(batch['seq_mask'], axis=2) + + # Inject previous outputs for recycling. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 
32 "RecyclingEmbedder" + if self.config.recycle_pos and 'prev_pos' in batch: + prev_pseudo_beta = self._pseudo_beta_fn( + batch['aatype'], batch['prev_pos'], None) + dgram = dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += self.prev_pos_linear(dgram) + + if self.config.recycle_features: + if 'prev_msa_first_row' in batch: + prev_msa_first_row = self.prev_msa_first_row_norm( + batch['prev_msa_first_row']) + + # A workaround for `jax.ops.index_add` + msa_first_row = paddle.squeeze(msa_activations[:, 0, :], axis=1) + msa_first_row += prev_msa_first_row + msa_first_row = paddle.unsqueeze(msa_first_row, axis=1) + msa_activations = paddle.concat([msa_first_row, msa_activations[:, 1:, :]], axis=1) + + if 'prev_pair' in batch: + pair_activations += self.prev_pair_norm(batch['prev_pair']) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + pos = batch['residue_index'] # [bs, N_res] + offset = paddle.unsqueeze(pos, axis=[-1]) - \ + paddle.unsqueeze(pos, axis=[-2]) + rel_pos = nn.functional.one_hot( + paddle.clip( + offset + self.config.max_relative_feature, + min=0, + max=2 * self.config.max_relative_feature), + 2 * self.config.max_relative_feature + 1) + rel_pos_bias = self.pair_activiations(rel_pos) + pair_activations += rel_pos_bias + + # TemplateEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + template_batch = {k: batch[k] for k in batch + if k.startswith('template_')} + template_pair_repr = self.template_embedding( + pair_activations, template_batch, mask_2d) + pair_activations += template_pair_repr + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + extra_msa_feat = self._create_extra_msa_feature(batch) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + + # ================================================== + # Extra MSA Stack + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + # ================================================== + extra_msa_stack_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, + } + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + extra_msa_stack_input['msa'] = dap.scatter(extra_msa_stack_input['msa'], axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + extra_msa_stack_input['pair'] = dap.scatter(extra_msa_stack_input['pair'], axis=1) + + for extra_msa_stack_iteration in self.extra_msa_stack: + extra_msa_act, extra_pair_act = recompute_wrapper(extra_msa_stack_iteration, + extra_msa_stack_input['msa'], + extra_msa_stack_input['pair'], + {'msa': batch['extra_msa_mask'], + 'pair': mask_2d}, + is_recompute=self.training) + extra_msa_stack_output = { + 'msa': extra_msa_act, + 'pair': extra_pair_act} + extra_msa_stack_input = { + 'msa': extra_msa_stack_output['msa'], + 'pair': extra_msa_stack_output['pair']} + + # gather if using dap, otherwise do nothing + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + extra_msa_stack_output['pair'] = dap.gather(extra_msa_stack_output['pair'], axis=1) + + evoformer_input = { + 'msa': msa_activations, + 'pair': extra_msa_stack_output['pair'], + } + + evoformer_masks = { + 'msa': batch['msa_mask'], + 'pair': mask_2d, + } + + # ================================================== + # Template angle feat + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 7-8 + # ================================================== + if self.config.template.enabled and self.config.template.embed_torsion_angles: + num_templ, num_res = batch['template_aatype'].shape[1:] + + aatype_one_hot = nn.functional.one_hot(batch['template_aatype'], 22) + # Embed the templates aatype, torsion angles and masks. + # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=batch['template_aatype'], + all_atom_pos=batch['template_all_atom_positions'], + all_atom_mask=batch['template_all_atom_masks'], + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not self.global_config.zero_init) + + template_features = paddle.concat([ + aatype_one_hot, + paddle.reshape(ret['torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + paddle.reshape(ret['alt_torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1) + pdb.set_trace() + + template_activations = self.template_single_embedding( + template_features) + template_activations = nn.functional.relu(template_activations) + template_activations = self.template_projection(template_activations) + + # Concatenate the templates to the msa. + evoformer_input['msa'] = paddle.concat( + [evoformer_input['msa'], template_activations], axis=1) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. + torsion_angle_mask = ret['torsion_angles_mask'][..., 2] + torsion_angle_mask = torsion_angle_mask.astype( + evoformer_masks['msa'].dtype) + evoformer_masks['msa'] = paddle.concat( + [evoformer_masks['msa'], torsion_angle_mask], axis=1) + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + evoformer_input['msa'] = dap.scatter(evoformer_input['msa'], axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + evoformer_input['pair'] = dap.scatter(evoformer_input['pair'], axis=1) + + # ================================================== + # Main MSA Stack + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 + # ================================================== + for evoformer_block in self.evoformer_iteration: + msa_act, pair_act = recompute_wrapper(evoformer_block, + evoformer_input['msa'], + evoformer_input['pair'], + evoformer_masks, + is_recompute=self.training) + evoformer_output = { + 'msa': msa_act, + 'pair': pair_act} + evoformer_input = { + 'msa': evoformer_output['msa'], + 'pair': evoformer_output['pair'], + } + + # gather if using dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res, c_m] + evoformer_output['msa'] = dap.gather(evoformer_output['msa'], axis=1) + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + evoformer_output['pair'] = dap.gather(evoformer_output['pair'], axis=1) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + single_activations = self.single_activations(msa_activations[:, 0]) + + num_seq = batch['msa_feat'].shape[1] + output = { + 'single': single_activations, + 'pair': pair_activations, + # Crop away template rows such that they are not used + # in MaskedMsaHead. + 'msa': msa_activations[:, :num_seq], + 'msa_first_row': msa_activations[:, 0], + } + + return output + + +class OuterProductMean(nn.Layer): + """Computes mean outer product. + + Jumper et al. (2021) Suppl. Alg. 
10 "OuterProductMean" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, name='outer_product_mean'): + super(OuterProductMean, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + if is_extra_msa: + c_m = channel_num['extra_msa_channel'] + else: + c_m = channel_num['msa_channel'] + + self.layer_norm_input = nn.LayerNorm(c_m, name='layer_norm_input') + self.left_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='left_projection') + self.right_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='right_projection') + + if self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.KaimingNormal() + + self.output_w = paddle.create_parameter( + [self.config.num_outer_channel, self.config.num_outer_channel, channel_num['pair_channel']], + 'float32', default_initializer=init_w) + self.output_b = paddle.create_parameter( + [channel_num['pair_channel']], 'float32', + default_initializer=nn.initializer.Constant(value=0.0)) + + def forward(self, act, mask): + """Builds OuterProductMean module. + + Arguments: + act: MSA representation, shape [batch, N_seq, N_res, c_m]. + mask: MSA mask, shape [batch, N_seq, N_res]. + + Returns: + Update to pair representation, shape [batch, N_res, N_res, c_z]. + """ + # [B, N_seq, N_res//dap_size, c_m] + act = self.layer_norm_input(act) + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + right_act_before = self.right_projection(act) + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, N_res, num_outer_channel] + right_act = dap.all_gather(right_act_before, axis=2) + + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + left_act = self.left_projection(act) + # [B, N_seq, N_res] => [B, N_seq, N_res, 1] + mask = paddle.unsqueeze(mask, axis=-1) + # [B, N_seq, N_res, 1] => [B, N_seq, N_res//dap_size, 1] + mask_col = dap.scatter(mask, axis=2) + left_act = mask_col * left_act + + # [B, N_seq, N_res//dap_size, 1], [B, N_seq, N_res, 1] => [B, N_res//dap_size, N_res, 1] + epsilon = 1e-3 + norm = paddle.einsum('nabc,nadc->nbdc', mask_col, mask) + epsilon + + def compute_chunk(left_act, right_act): + # This is equivalent to + # + # act = jnp.einsum('abc,ade->dceb', left_act, right_act) + # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b + # + # but faster. maybe for subbatch inference? 
+ + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, num_outer_channel, N_res//dap_size] + left_act = left_act.transpose([0, 1, 3, 2]) + # wait if using async communication and dap, otherwise do nothing + right_act_after = dap.all_gather_opp(right_act, axis=2) + # [B, N_seq, num_outer_channel, N_res//dap_size], [B, N_seq, N_res, num_outer_channel] + # => [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size] + act = paddle.einsum('nacb,nade->ndceb', left_act, right_act_after) + # [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size], [num_outer_channel, num_outer_channel, c_z] + # => [B, N_res, N_res//dap_size, c_z] + act = paddle.einsum('ndceb,cef->ndbf', act, self.output_w) + self.output_b + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + return act.transpose([0, 2, 1, 3]) + + if not self.training: + # low memory mode using subbatch + sb_chunk = subbatch(compute_chunk, [0], [2], + self.config.chunk_size, 1) + act = sb_chunk(left_act, right_act) + else: + act = compute_chunk(left_act, right_act) + + act = act / norm + + return act + + +class TriangleAttention(nn.Layer): + """Triangle Attention. + + Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" + Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_attention'): + super(TriangleAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + assert config.orientation in ['per_row', 'per_column'] + + self.query_norm = nn.LayerNorm(channel_num['pair_channel'], + name='query_norm') + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. / np.sqrt(channel_num['pair_channel']))) + + self.attention = Attention(self.config, self.global_config, + channel_num['pair_channel'], channel_num['pair_channel'], + channel_num['pair_channel']) + + + def forward(self, pair_act, pair_mask): + """Builds TriangleAttention module. + + Arguments: + pair_act: [batch, N_res, N_res, c_z] pair activations tensor + pair_mask: [batch, N_res, N_res] mask of non-padded regions in the tensor. + + Returns: + Update to pair_act, shape [batch, N_res, N_res, c_z]. + """ + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + pair_mask = pair_mask.transpose([0, 2, 1]) + + # [B, N_res//dap_size, N_res] + bias = 1e9 * (pair_mask - 1.) + # [B, N_res//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + pair_act = self.query_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum('bqkc,ch->bhqk', pair_act, self.feat_2d_weights) + + # # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + pair_act = sb_attn(pair_act, pair_act, bias, nonbatched_bias) + else: + pair_act = self.attention(pair_act, pair_act, bias, nonbatched_bias) + + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + + return pair_act + + +class TriangleMultiplication(nn.Layer): + """Triangle multiplication layer ("outgoing" or "incoming"). + + Jumper et al. 
(2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" + Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_multiplication'): + super(TriangleMultiplication, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.layer_norm_input = nn.LayerNorm(self.channel_num['pair_channel'], name='layer_norm_input') + self.left_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_projection') + self.right_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_projection') + self.left_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_gate') + init_gate_linear(self.left_gate) + self.right_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_gate') + init_gate_linear(self.right_gate) + + # line 4 + self.center_layer_norm = nn.LayerNorm(self.config.num_intermediate_channel, name='center_layer_norm') + self.output_projection = nn.Linear(self.config.num_intermediate_channel, + self.channel_num['pair_channel'], name='output_projection') + init_final_linear(self.output_projection) + # line 3 + self.gating_linear = nn.Linear(self.channel_num['pair_channel'], + self.channel_num['pair_channel'], name='output_projection') + init_gate_linear(self.gating_linear) + + def forward(self, act, mask): + """Builds TriangleMultiplication module. + + Arguments: + act: Pair activations, shape [batch, N_res, N_res, c_z] + mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Outputs, same shape/type as act. + """ + # Outgoing [batch, N_res//dap_size, N_res] => [batch, N_res//dap_size, N_res, 1] + # Incoming [batch, N_res, N_res//dap_size] => [batch, N_res, N_res//dap_size, 1] + mask = paddle.unsqueeze(mask, axis=-1) # [batch, N_res, N_res, 1] + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + act = self.layer_norm_input(act) # line 1 + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = mask * self.left_projection(act) + right_proj_act = mask * self.right_projection(act) + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_gate_values = nn.functional.sigmoid(self.left_gate(act)) + right_gate_values = nn.functional.sigmoid(self.right_gate(act)) + + # Outgoing [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = left_proj_act * left_gate_values + right_proj_act_before = right_proj_act * right_gate_values + + + # "Outgoing" edges equation: 'ikc,jkc->ijc' + # "Incoming" edges equation: 'kjc,kic->ijc' + # Note on the Suppl. Alg. 
11 & 12 notation: + # For the "outgoing" edges, a = left_proj_act and b = right_proj_act + # For the "incoming" edges, it's swapped: + # b = left_proj_act and a = right_proj_act + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + # [B, N_res//dap_size, N_res, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + # [B, N_res, N_res//dap_size, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=2) + else: + raise ValueError('unknown equation.') + + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + gate_values = nn.functional.sigmoid(self.gating_linear(act)) # line 3 + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + dim, out_idx = 1, 1 + equation = 'bikc,bjkc->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + dim, out_idx = 2, 2 + equation = 'bkjc,bkic->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=2) + else: + raise ValueError('unknown equation.') + + if not self.training: + einsum_fn = subbatch(paddle.einsum, [1], [dim], + self.global_config.subbatch_size, out_idx) + act = einsum_fn(equation, left_proj_act, right_proj_act_after) + else: + # Outgoing equation = 'bikc,bjkc->bijc' + # [B, N_res//dap_size, N_res, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res//dap_size, N_res, num_intermediate_channel] + + # Incoming equation = 'bkjc,bkic->bijc' + # [B, N_res, N_res//dap_size, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res, N_res//dap_size, num_intermediate_channel] + act = paddle.einsum(equation, left_proj_act, right_proj_act_after) + + act = self.center_layer_norm(act) + act = self.output_projection(act) + + act = act * gate_values + + return act + + +class TemplatePair(nn.Layer): + """Pair processing for the templates. + + Jumper et al. (2021) Suppl. Alg. 
16 "TemplatePairStack" lines 2-6 + """ + def __init__(self, channel_num, config, global_config): + super(TemplatePair, self).__init__() + self.config = config + self.global_config = global_config + + channel_num = {} + channel_num['pair_channel'] = self.config.triangle_attention_ending_node.value_dim + + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.pair_transition = Transition(channel_num, self.config.pair_transition, + self.global_config, is_extra_msa=False, + transition_type='pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, pair_act, pair_mask): + """Builds one block of TemplatePair module. + + Arguments: + pair_act: Pair activations for single template, shape [batch, N_res, N_res, c_t]. + pair_mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Updated pair_act, shape [batch, N_res, N_res, c_t]. 
+ """ + + residual = self.triangle_attention_starting_node(pair_act, pair_mask) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_attention_ending_node(pair_act, pair_mask) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_incoming(pair_act, pair_mask) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act, pair_mask) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + return pair_act + + +class SingleTemplateEmbedding(nn.Layer): + """Embeds a single template. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 + """ + def __init__(self, channel_num, config, global_config): + super(SingleTemplateEmbedding, self).__init__() + self.config = config + self.channel_num = channel_num + self.global_config = global_config + + self.embedding2d = nn.Linear(channel_num['template_pair'], + self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + + self.template_pair_stack = nn.LayerList() + for _ in range(self.config.template_pair_stack.num_block): + self.template_pair_stack.append(TemplatePair( + self.channel_num, self.config.template_pair_stack, self.global_config)) + + self.output_layer_norm = nn.LayerNorm(self.config.attention.key_dim) + + def forward(self, query_embedding, batch, mask_2d): + """Build the single template embedding. + + Arguments: + query_embedding: Query pair representation, shape [batch, N_res, N_res, c_z]. + batch: A batch of template features (note the template dimension has been + stripped out as this module only runs over a single template). + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [N_res, N_res, c_z]. 
+ """ + assert mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_res = batch['template_aatype'].shape[1] + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + template_mask_2d = template_mask_2d.astype(dtype) + + template_dgram = dgram_from_positions( + batch['template_pseudo_beta'], + **self.config.dgram_features) + template_dgram = template_dgram.astype(dtype) + + aatype = nn.functional.one_hot(batch['template_aatype'], 22) + aatype = aatype.astype(dtype) + + to_concat = [template_dgram, template_mask_2d[..., None]] + to_concat.append(paddle.tile(aatype[..., None, :, :], + [1, num_res, 1, 1])) + to_concat.append(paddle.tile(aatype[..., None, :], + [1, 1, num_res, 1])) + + n, ca, c = [residue_constants.atom_order[a] + for a in ('N', 'CA', 'C')] + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][..., n, :], + ca_xyz=batch['template_all_atom_positions'][..., ca, :], + c_xyz=batch['template_all_atom_positions'][..., c, :]) + affines = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot), + translation=trans, + rotation=rot) + + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affines.translation, axis=-1)] + affine_vec = affines.invert_point(points, extra_dims=1) + inv_distance_scalar = paddle.rsqrt( + 1e-6 + sum([paddle.square(x) for x in affine_vec])) + + # Backbone affine mask: whether the residue has C, CA, N + # (the template mask defined above only considers pseudo CB). + template_mask = ( + batch['template_all_atom_masks'][..., n] * + batch['template_all_atom_masks'][..., ca] * + batch['template_all_atom_masks'][..., c]) + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) + + unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] + unit_vector = [x.astype(dtype) for x in unit_vector] + if not self.config.use_template_unit_vector: + unit_vector = [paddle.zeros_like(x) for x in unit_vector] + to_concat.extend(unit_vector) + + template_mask_2d = template_mask_2d.astype(dtype) + to_concat.append(template_mask_2d[..., None]) + + act = paddle.concat(to_concat, axis=-1) + # Mask out non-template regions so we don't get arbitrary values in the + # distogram for these regions. + act *= template_mask_2d[..., None] + + act = self.embedding2d(act) + for pair_encoder in self.template_pair_stack: + act = pair_encoder(act, mask_2d) + + act = self.output_layer_norm(act) + return act + + +class TemplateEmbedding(nn.Layer): + """Embeds a set of templates. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(TemplateEmbedding, self).__init__() + self.config = config + self.global_config = global_config + + self.single_template_embedding = SingleTemplateEmbedding( + channel_num, config, global_config) + self.attention = Attention( + config.attention, global_config, + channel_num['pair_channel'], + config.attention.key_dim, + channel_num['pair_channel']) + + def forward(self, query_embedding, template_batch, mask_2d): + """Build TemplateEmbedding module. + + Arguments: + query_embedding: Query pair representation, shape [n_batch, N_res, N_res, c_z]. + template_batch: A batch of template features. 
+ mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [n_batch, N_res, N_res, c_z]. + """ + + num_templates = template_batch['template_mask'].shape[1] + + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + + num_res = query_embedding.shape[1] + + dtype = query_embedding.dtype + template_mask = template_batch['template_mask'] + template_mask = template_mask.astype(dtype) + + query_channels = query_embedding.shape[-1] + + outs = [] + for i in range(num_templates): + # By default, num_templates = 4 + batch0 = {k: paddle.squeeze(v.slice([1], [i], [i+1]), axis=1) + for k, v in template_batch.items()} + outs.append(self.single_template_embedding( + query_embedding, batch0, mask_2d)) + + template_pair_repr = paddle.stack(outs, axis=1) + + flat_query = paddle.reshape( + query_embedding, [-1, num_res * num_res, 1, query_channels]) + flat_templates = paddle.reshape( + paddle.transpose(template_pair_repr, [0, 2, 3, 1, 4]), + [-1, num_res * num_res, num_templates, num_channels]) + + bias = 1e9 * (template_mask[:, None, None, None, :] - 1.) + + if not self.training: + sb_attn = subbatch(self.attention, [0, 1], [1, 1], + self.config.subbatch_size, 1) + emb = sb_attn(flat_query, flat_templates, bias) + + else: + emb = self.attention(flat_query, flat_templates, bias) + + emb = paddle.reshape( + emb, [-1, num_res, num_res, query_channels]) + + # No gradients if no templates. + emb *= (paddle.sum(template_mask) > 0.).astype(emb.dtype) + return emb diff --git a/apps/protein_folding/helixfold_cpu/modules_bk.py b/apps/protein_folding/helixfold_cpu/modules_bk.py new file mode 100644 index 00000000..c6664452 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/modules_bk.py @@ -0,0 +1,2191 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Modules.""" + +import numpy as np + +import paddle +import paddle.nn as nn +from paddle.fluid.framework import _dygraph_tracer +from paddle.distributed.fleet.utils import recompute + +from alphafold_paddle.common import residue_constants +from alphafold_paddle.model.utils import mask_mean, subbatch +from alphafold_paddle.model import folding, lddt, quat_affine, all_atom +from alphafold_paddle.model.utils import init_gate_linear, init_final_linear +from alphafold_paddle.distributed import dap + +# Map head name in config to head name in model params +Head_names = { + 'masked_msa': 'masked_msa_head', + 'distogram': 'distogram_head', + 'predicted_lddt': 'predicted_lddt_head', + 'predicted_aligned_error': 'predicted_aligned_error_head', + 'experimentally_resolved': 'experimentally_resolved_head', # finetune loss +} + + +def recompute_wrapper(func, *args, is_recompute=True): + """Function wrapper for recompute""" + if is_recompute: + return recompute(func, *args) + else: + return func(*args) + + +def softmax_cross_entropy(logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels.""" + loss = -paddle.sum(labels * paddle.nn.functional.log_softmax(logits), axis=-1) + return loss + + +def sigmoid_cross_entropy(logits, labels): + """Computes sigmoid cross entropy given logits and multiple class labels.""" + log_p = paddle.nn.functional.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable + log_not_p = paddle.nn.functional.log_sigmoid(-logits) + loss = -labels * log_p - (1. - labels) * log_not_p + return loss + + +class AlphaFold(nn.Layer): + """AlphaFold model with recycling. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" + """ + def __init__(self, channel_num, config): + super(AlphaFold, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = config.global_config + + self.alphafold_iteration = AlphaFoldIteration( + self.channel_num, self.config, self.global_config) + + def forward(self, + batch, + label, + ensemble_representations=False, + return_representations=False, + compute_loss=True): + """Run the AlphaFold model. + + Arguments: + batch: Dictionary with inputs to the AlphaFold model. + ensemble_representations: Whether to use ensembling of representations. + return_representations: Whether to also return the intermediate + representations. + + Returns: + The output of AlphaFoldIteration is a nested dictionary containing + predictions from the various heads. + + """ + inner_batch, num_residues = batch['aatype'].shape[1:] + + def _get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + + for k in new_prev.keys(): + new_prev[k].stop_gradient = True + + return new_prev + + def _run_single_recycling(prev, recycle_idx, compute_loss): + if not self.training: + print(f'########## recycle id: {recycle_idx} ##########') + + if self.config.resample_msa_in_recycling: + # (B, (R+1)*E, N, ...) + # B: batch size, R: recycling number, + # E: ensemble number, N: residue number + num_ensemble = inner_batch // (self.config.num_recycle + 1) + ensembled_batch = dict() + for k in batch.keys(): + start = recycle_idx * num_ensemble + end = start + num_ensemble + ensembled_batch[k] = batch[k][:, start:end] + else: + # (B, E, N, ...) 
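+                # Without MSA resampling, every recycling pass reuses the same
+                # ensemble dimension, so E is the full second axis of the batch.
+                # Illustrative example (hypothetical sizes): with num_recycle=3
+                # and E=1, the resampling branch above sees tensors of shape
+                # (B, 4, N, ...) and recycle r uses the slice [:, r:r+1],
+                # whereas this branch always feeds all E entries unchanged.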
+ num_ensemble = inner_batch + ensembled_batch = batch + + non_ensembled_batch = prev + return self.alphafold_iteration( + ensembled_batch, label, non_ensembled_batch, + compute_loss=compute_loss, + ensemble_representations=ensemble_representations) + + if self.config.num_recycle: + # aatype: (B, E, N), zeros_bn: (B, N) + zeros_bn = paddle.zeros_like(batch['aatype'][:, 0], dtype='float32') + + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_pos': paddle.tile( + zeros_bn[..., None, None], + [1, 1, residue_constants.atom_type_num, 3]), + 'prev_msa_first_row': paddle.tile( + zeros_bn[..., None], + [1, 1, emb_config.msa_channel]), + 'prev_pair': paddle.tile( + zeros_bn[..., None, None], + [1, 1, num_residues, emb_config.pair_channel]), + } + + if 'num_iter_recycling' in batch: + # Training trick: dynamic recycling number + num_iter = batch['num_iter_recycling'].numpy()[0, 0] + num_iter = min(int(num_iter), self.config.num_recycle) + else: + num_iter = self.config.num_recycle + + for recycle_idx in range(num_iter): + ret = _run_single_recycling(prev, recycle_idx, compute_loss=False) + prev = _get_prev(ret) + + else: + prev = {} + num_iter = 0 + + return _run_single_recycling(prev, num_iter, compute_loss=compute_loss) + + +class AlphaFoldIteration(nn.Layer): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. Each head also returns a + loss which is combined as a weighted sum to produce the total loss. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 + """ + + def __init__(self, channel_num, config, global_config): + super(AlphaFoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # copy these config for later usage + self.channel_num['extra_msa_channel'] = config.embeddings_and_evoformer.extra_msa_channel + self.channel_num['msa_channel'] = config.embeddings_and_evoformer.msa_channel + self.channel_num['pair_channel'] = config.embeddings_and_evoformer.pair_channel + self.channel_num['seq_channel'] = config.embeddings_and_evoformer.seq_channel + + self.evoformer = EmbeddingsAndEvoformer( + self.channel_num, self.config.embeddings_and_evoformer, + self.global_config) + + Head_modules = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': folding.StructureModule, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, # finetune loss + } + + self.used_heads = [] + for head_name, head_config in sorted(self.config.heads.items()): + if head_name not in Head_modules: + continue + + self.used_heads.append(head_name) + module = Head_modules[head_name]( + self.channel_num, head_config, self.global_config) + + head_name_ = Head_names.get(head_name, head_name) + setattr(self, head_name_, module) + + def forward(self, + ensembled_batch, + label, + non_ensembled_batch, + compute_loss=False, + ensemble_representations=False): + num_ensemble = ensembled_batch['seq_length'].shape[1] + if not ensemble_representations: + assert num_ensemble == 1 + + def _slice_batch(i): + b = {k: v[:, i] for k, v in ensembled_batch.items()} + b.update(non_ensembled_batch) + return b + + batch0 = _slice_batch(0) + representations = self.evoformer(batch0) + + # MSA representations are not ensembled + 
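+        # Only the batch0 pass keeps its MSA entry; when ensembling, the
+        # remaining representations are summed and averaged below, and the
+        # batch0 MSA representation is put back afterwards, matching the
+        # bert_mask/true_msa labels taken from batch0 for MaskedMsaHead.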
+        msa_representation = representations['msa']
+        del representations['msa']
+        # MaskedMsaHead is applied to batch0
+        label['bert_mask'] = batch0['bert_mask']
+        label['true_msa'] = batch0['true_msa']
+        label['residue_index'] = batch0['residue_index']
+
+        if ensemble_representations:
+            for i in range(1, num_ensemble):
+                batch = _slice_batch(i)
+                representations_update = self.evoformer(batch)
+                for k in representations.keys():
+                    representations[k] += representations_update[k]
+
+            for k in representations.keys():
+                representations[k] /= num_ensemble + 0.0
+
+        representations['msa'] = msa_representation
+        ret = {'representations': representations}
+
+        def loss(head_name_, head_config, ret, head_name, filter_ret=True):
+            if filter_ret:
+                value = ret[head_name]
+            else:
+                value = ret
+            loss_output = getattr(self, head_name_).loss(value, label)
+            ret[head_name].update(loss_output)
+            loss = head_config.weight * ret[head_name]['loss']
+            return loss
+
+        def _forward_heads(representations, ret, batch0):
+            total_loss = 0.
+            for head_name, head_config in self._get_heads():
+                head_name_ = Head_names.get(head_name, head_name)
+                # Skip PredictedLDDTHead and PredictedAlignedErrorHead until
+                # StructureModule is executed.
+                if head_name in ('predicted_lddt', 'predicted_aligned_error'):
+                    continue
+                else:
+                    ret[head_name] = getattr(self, head_name_)(representations, batch0)
+                    if 'representations' in ret[head_name]:
+                        # Extra representations from the head. Used by the
+                        # structure module to provide activations for the PredictedLDDTHead.
+                        representations.update(ret[head_name].pop('representations'))
+                    if compute_loss:
+                        total_loss += loss(head_name_, head_config, ret, head_name)
+
+            if self.config.heads.get('predicted_lddt.weight', 0.0):
+                # Add PredictedLDDTHead after StructureModule executes.
+                head_name = 'predicted_lddt'
+                # Feed all previous results to give access to structure_module result.
+                head_name_ = Head_names.get(head_name, head_name)
+                head_config = self.config.heads[head_name]
+                ret[head_name] = getattr(self, head_name_)(representations, batch0)
+                if compute_loss:
+                    total_loss += loss(head_name_, head_config, ret, head_name, filter_ret=False)
+
+            if ('predicted_aligned_error' in self.config.heads
+                    and self.config.heads.get('predicted_aligned_error.weight', 0.0)):
+                # Add PredictedAlignedErrorHead after StructureModule executes.
+                head_name = 'predicted_aligned_error'
+                # Feed all previous results to give access to structure_module result.
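+                # (loss() is called with filter_ret=False below, so the head's
+                #  loss sees the full `ret` dict, including the structure_module
+                #  outputs it needs.)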
+ head_config = self.config.heads[head_name] + head_name_ = Head_names.get(head_name, head_name) + ret[head_name] = getattr(self, head_name_)(representations, batch0) + if compute_loss: + total_loss += loss(head_name_, head_config, ret, head_name, filter_ret=False) + + return ret, total_loss + + tracer = _dygraph_tracer() + if tracer._amp_dtype == "bfloat16": + with paddle.amp.auto_cast(enable=False): + for key, value in representations.items(): + if value.dtype in [paddle.fluid.core.VarDesc.VarType.BF16]: + temp_value = value.cast('float32') + temp_value.stop_gradient = value.stop_gradient + representations[key] = temp_value + for key, value in batch0.items(): + if value.dtype in [paddle.fluid.core.VarDesc.VarType.BF16]: + temp_value = value.cast('float32') + temp_value.stop_gradient = value.stop_gradient + batch0[key] = temp_value + ret, total_loss = _forward_heads(representations, ret, batch0) + + else: + ret, total_loss = _forward_heads(representations, ret, batch0) + + if compute_loss: + return ret, total_loss + else: + return ret + + def _get_heads(self): + assert 'structure_module' in self.used_heads + head_names = [h for h in self.used_heads] + + for k in head_names: + yield k, self.config.heads[k] + + +class Attention(nn.Layer): + """Multihead attention.""" + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(Attention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, bias, nonbatched_bias=None): + """Builds Attention module. + Arguments: + q_data: A tensor of queries, shape [batch, row_size, N_queries, q_channels]. + m_data: A tensor of memories from which the keys and values are + projected, shape [batch, row_size, N_keys, m_channels]. + bias: A bias for the attention, shape [batch, row_size, num_head, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + + Returns: + A float32 tensor of shape [batch_size, row_size, N_queries, output_dim]. 
+ """ + c = self.key_dim ** (-0.5) + q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w) * c + k = paddle.einsum('nbka,ahc->nbkhc', m_data, self.key_w) + v = paddle.einsum('nbka,ahc->nbkhc', m_data, self.value_w) + logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias + + if nonbatched_bias is not None: + nonbatched_bias_after = dap.all_gather_opp(nonbatched_bias, axis=2) + logits += paddle.unsqueeze(nonbatched_bias_after, axis=1) + + weights = nn.functional.softmax(logits) + weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + return output + + +class GlobalAttention(nn.Layer): + """Global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 + """ + + def __init__(self, config, global_config, q_dim, kv_dim, output_dim): + super(GlobalAttention, self).__init__() + self.config = config + self.global_config = global_config + + num_head = self.config.num_head + key_dim = self.config.get('key_dim', q_dim) + value_dim = self.config.get('value_dim', kv_dim) + + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + self.key_dim = key_dim + self.value_dim = value_dim + + self.query_w = paddle.create_parameter( + [q_dim, num_head, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.key_w = paddle.create_parameter( + [kv_dim, key_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + self.value_w = paddle.create_parameter( + [kv_dim, value_dim], 'float32', + default_initializer=nn.initializer.XavierUniform()) + + if self.config.gating: + self.gating_w = paddle.create_parameter( + [q_dim, num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + self.gating_b = paddle.create_parameter( + [num_head, value_dim], 'float32', + default_initializer=nn.initializer.Constant(1.0)) + + if self.global_config.zero_init: + init = nn.initializer.Constant(0.0) + else: + init = nn.initializer.XavierUniform() + + self.output_w = paddle.create_parameter( + [num_head, value_dim, output_dim], 'float32', + default_initializer=init) + self.output_b = paddle.create_parameter( + [output_dim], 'float32', + default_initializer=nn.initializer.Constant(0.0)) + + def forward(self, q_data, m_data, q_mask): + k = paddle.einsum('nbka,ac->nbkc', m_data, self.key_w) + v = paddle.einsum('nbka,ac->nbkc', m_data, self.value_w) + + # NOTE: differ from non-global version using q_avg for attn + q_avg = mask_mean(q_mask, q_data, axis=2) + c = self.key_dim ** (-0.5) + q = paddle.einsum('nba,ahc->nbhc', q_avg, self.query_w) * c + + q_mask_ = paddle.unsqueeze(q_mask, axis=2)[..., 0] + bias = 1e9 * (q_mask_ - 1.) 
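+        # The padding mask is folded into an additive bias: keys with
+        # q_mask == 1 get bias 0, padded keys get -1e9 and receive ~zero
+        # weight after the softmax below.
+        # Illustrative example (hypothetical values):
+        #   q_mask_ = [[1., 1., 0.]]  ->  bias = [[0., 0., -1e9]]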
+ + logits = paddle.einsum('nbhc,nbkc->nbhk', q, k) + bias + weights = nn.functional.softmax(logits) + weighted_avg = paddle.einsum('nbhk,nbkc->nbhc', weights, v) + + if self.config.gating: + gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, + self.gating_w) + self.gating_b + gate_values = nn.functional.sigmoid(gate_values) + weighted_avg = paddle.unsqueeze(weighted_avg, axis=2) + weighted_avg *= gate_values + + output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, + self.output_w) + self.output_b + else: + output = paddle.einsum('nbhc,hco->nbo', weighted_avg, + self.output_w) + self.output_b + output = paddle.unsqueeze(output, axis=-1) + + return output + + +class MSARowAttentionWithPairBias(nn.Layer): + """MSA per-row attention biased by the pair representation. + + Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa): + super(MSARowAttentionWithPairBias, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + assert config.orientation == 'per_row' + + if is_extra_msa: + self.query_norm = nn.LayerNorm(channel_num['extra_msa_channel']) + else: + self.query_norm = nn.LayerNorm(channel_num['msa_channel']) + + self.feat_2d_norm = nn.LayerNorm(channel_num['pair_channel']) + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. / np.sqrt(channel_num['pair_channel']))) + + if is_extra_msa: + extra_msa_channel = channel_num['extra_msa_channel'] + self.attention = Attention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + else: + msa_channel = channel_num['msa_channel'] + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + def forward(self, msa_act, msa_mask, pair_act): + + pair_act = self.feat_2d_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum( + 'nqkc,ch->nhqk', pair_act, self.feat_2d_weights) + + # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + # [B, N_seq, N_res] => [B, N_seq//dap_size, N_res] + msa_mask = dap.scatter(msa_mask, axis=1) + + bias = 1e9 * (msa_mask - 1.) + # [B, N_seq//dap_size, N_res] => [B, N_seq//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + msa_act = self.query_norm(msa_act) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, bias, nonbatched_bias) + else: + msa_act = self.attention(msa_act, msa_act, bias, nonbatched_bias) + + return msa_act + + +class MSAColumnGlobalAttention(nn.Layer): + """MSA per-column global attention. + + Jumper et al. (2021) Suppl. Alg. 
19 "MSAColumnGlobalAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnGlobalAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + assert config.orientation == 'per_column' + + extra_msa_channel = channel_num['extra_msa_channel'] + self.query_norm = nn.LayerNorm(extra_msa_channel) + self.attention = GlobalAttention( + self.config, self.global_config, + extra_msa_channel, extra_msa_channel, extra_msa_channel) + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + msa_mask = paddle.unsqueeze(msa_mask, axis=-1) + msa_act = self.query_norm(msa_act) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, msa_mask) + else: + msa_act = self.attention(msa_act, msa_act, msa_mask) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class MSAColumnAttention(nn.Layer): + """MSA per-column attention. + + Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(MSAColumnAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + assert config.orientation == 'per_column' + + msa_channel = channel_num['msa_channel'] + self.query_norm = nn.LayerNorm(msa_channel) + self.attention = Attention( + self.config, self.global_config, + msa_channel, msa_channel, msa_channel) + + def forward(self, msa_act, msa_mask): + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res] => [B, N_seq, N_res//dap_size] + msa_mask = dap.scatter(msa_mask, axis=2) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + msa_mask = paddle.transpose(msa_mask, [0, 2, 1]) + + bias = 1e9 * (msa_mask - 1.) + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + msa_act = self.query_norm(msa_act) + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + msa_act = sb_attn(msa_act, msa_act, bias) + else: + msa_act = self.attention(msa_act, msa_act, bias) + + msa_act = paddle.transpose(msa_act, [0, 2, 1, 3]) + return msa_act + + +class Transition(nn.Layer): + """Transition layer. + + Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" + Jumper et al. (2021) Suppl. Alg. 
15 "PairTransition" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, + transition_type): + super(Transition, self).__init__() + assert transition_type in ['msa_transition', 'pair_transition'] + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + self.transition_type = transition_type + + if transition_type == 'msa_transition' and is_extra_msa: + in_dim = channel_num['extra_msa_channel'] + elif transition_type == 'msa_transition' and not is_extra_msa: + in_dim = channel_num['msa_channel'] + elif transition_type == 'pair_transition': + in_dim = channel_num['pair_channel'] + + self.input_layer_norm = nn.LayerNorm(in_dim) + self.transition1 = nn.Linear( + in_dim, int(in_dim * self.config.num_intermediate_factor), + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingNormal())) + + if self.global_config.zero_init: + last_init = nn.initializer.Constant(0.0) + else: + last_init = nn.initializer.TruncatedNormal() + + self.transition2 = nn.Linear( + int(in_dim * self.config.num_intermediate_factor), in_dim, + weight_attr=paddle.ParamAttr(initializer=last_init)) + + def forward(self, act, mask): + act = self.input_layer_norm(act) + + def transition_module(x): + x = self.transition1(x) + x = nn.functional.relu(x) + x = self.transition2(x) + return x + + if not self.training: + # low memory mode using subbatch + sb_transition = subbatch(transition_module, [0], [1], + self.global_config.subbatch_size, 1) + act = sb_transition(act) + else: + act = transition_module(act) + + return act + + +class MaskedMsaHead(nn.Layer): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + def __init__(self, channel_num, config, global_config, name='masked_msa_head'): + super(MaskedMsaHead, self).__init__() + self.config = config + self.global_config = global_config + self.num_output = config.num_output + self.logits = nn.Linear(channel_num['msa_channel'], self.num_output, name='logits') + + def forward(self, representations, batch): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [batch, N_seq, N_res, c_m]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [batch, N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + del batch + logits = self.logits(representations['msa']) + return {'logits': logits} + + def loss(self, value, batch): + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(batch['true_msa'], num_classes=self.num_output), + logits=value['logits']) + loss = (paddle.sum(errors * batch['bert_mask'], axis=[-2, -1]) / + (1e-8 + paddle.sum(batch['bert_mask'], axis=[-2, -1]))) + return {'loss': loss} + + +class PredictedLDDTHead(nn.Layer): + """Head to predict the per-residue LDDT to be used as a confidence measure. + + Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" + Jumper et al. (2021) Suppl. Alg. 
29 "predictPerResidueLDDT_Ca" + """ + + def __init__(self, channel_num, config, global_config, name='predicted_lddt_head'): + super(PredictedLDDTHead, self).__init__() + self.config = config + self.global_config = global_config + + self.input_layer_norm = nn.LayerNorm(channel_num['seq_channel'], + name='input_layer_norm') + self.act_0 = nn.Linear(channel_num['seq_channel'], + self.config.num_channels, name='act_0') + self.act_1 = nn.Linear(self.config.num_channels, + self.config.num_channels, name='act_1') + self.logits = nn.Linear(self.config.num_channels, + self.config.num_bins, name='logits') + + def forward(self, representations, batch): + """Builds PredictedLDDTHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'structure_module': Single representation from the structure module, + shape [n_batch, N_res, c_s]. + + Returns: + Dictionary containing : + * 'logits': logits of shape [n_batch, N_res, N_bins] with + (unnormalized) log probabilies of binned predicted lDDT. + """ + act = representations['structure_module'] + act = self.input_layer_norm(act) + act = nn.functional.relu(self.act_0(act)) + act = nn.functional.relu(self.act_1(act)) + logits = self.logits(act) + + return dict(logits=logits) + + def loss(self, value, batch): + # Shape (n_batch, num_res, 37, 3) + pred_all_atom_pos = value['structure_module']['final_atom_positions'] + # Shape (n_batch, num_res, 37, 3) + true_all_atom_pos = paddle.cast(batch['all_atom_positions'], 'float32') + # Shape (n_batch, num_res, 37) + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + # Shape (batch_size, num_res) + lddt_ca = lddt.lddt( + # Shape (batch_size, num_res, 3) + predicted_points=pred_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 3) + true_points=true_all_atom_pos[:, :, 1, :], + # Shape (batch_size, num_res, 1) + true_points_mask=all_atom_mask[:, :, 1:2], + cutoff=15., + per_residue=True) + lddt_ca = lddt_ca.detach() + + # Shape (batch_size, num_res) + num_bins = self.config.num_bins + bin_index = paddle.floor(lddt_ca * num_bins) + + # protect against out of range for lddt_ca == 1 + bin_index = paddle.minimum(bin_index, paddle.to_tensor(num_bins - 1, dtype='float32')) + lddt_ca_one_hot = paddle.nn.functional.one_hot(paddle.cast(bin_index, 'int64'), num_classes=num_bins) + + # Shape (n_batch, num_res, num_channel) + logits = value['predicted_lddt']['logits'] + errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + # Shape (num_res,) + mask_ca = all_atom_mask[:, :, residue_constants.atom_order['CA']] + mask_ca = paddle.to_tensor(mask_ca, dtype='float32') + loss = paddle.sum(errors * mask_ca, axis=-1) / (paddle.sum(mask_ca, axis=-1) + 1e-8) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + output = {'loss': loss} + return output + + +class PredictedAlignedErrorHead(nn.Layer): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 
1.9.7 "TM-score prediction" + """ + def __init__(self, channel_num, config, global_config, + name='predicted_aligned_error_head'): + super(PredictedAlignedErrorHead, self).__init__() + self.config = config + self.global_config = global_config + + self.logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='logits') + + def forward(self, representations, batch): + """Builds PredictedAlignedErrorHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [B, N_res, N_res, c_z]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * logits: logits for aligned error, shape [B, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1]. + """ + logits = self.logits(representations['pair']) + breaks = paddle.linspace(0., self.config.max_error_bin, + self.config.num_bins-1) + + return dict(logits=logits, breaks=breaks) + + def loss(self, value, batch): + # Shape (B, num_res, 7) + predicted_affine = quat_affine.QuatAffine.from_tensor( + value['structure_module']['final_affines']) + # Shape (B, num_res, 7) + true_rot = paddle.to_tensor(batch['backbone_affine_tensor_rot'], dtype='float32') + true_trans = paddle.to_tensor(batch['backbone_affine_tensor_trans'], dtype='float32') + true_affine = quat_affine.QuatAffine( + quaternion=None, + translation=true_trans, + rotation=true_rot) + # Shape (B, num_res) + mask = batch['backbone_affine_mask'] + # Shape (B, num_res, num_res) + square_mask = mask[..., None] * mask[:, None, :] + num_bins = self.config.num_bins + # (num_bins - 1) + breaks = value['predicted_aligned_error']['breaks'] + # (B, num_res, num_res, num_bins) + logits = value['predicted_aligned_error']['logits'] + + # Compute the squared error for each alignment. + def _local_frame_points(affine): + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affine.translation, axis=-1)] + return affine.invert_point(points, extra_dims=1) + error_dist2_xyz = [ + paddle.square(a - b) + for a, b in zip(_local_frame_points(predicted_affine), + _local_frame_points(true_affine))] + error_dist2 = sum(error_dist2_xyz) + # Shape (B, num_res, num_res) + # First num_res are alignment frames, second num_res are the residues. + error_dist2 = error_dist2.detach() + + sq_breaks = paddle.square(breaks) + true_bins = paddle.sum(paddle.cast((error_dist2[..., None] > sq_breaks), 'int32'), axis=-1) + + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + loss = (paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-8 + paddle.sum(square_mask, axis=[-2, -1]))) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class ExperimentallyResolvedHead(nn.Layer): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 
1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, channel_num, config, global_config, name='experimentally_resolved_head'): + super(ExperimentallyResolvedHead, self).__init__() + self.config = config + self.global_config = global_config + self.logits = nn.Linear(channel_num['seq_channel'], 37, name='logits') + + def forward(self, representations, batch): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [B, N_res, c_s]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [B, N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. + """ + logits = self.logits(representations['single']) + return dict(logits=logits) + + def loss(self, value, batch): + logits = value['logits'] + assert len(logits.shape) == 3 + + # Does the atom appear in the amino acid? + atom_exists = batch['atom37_atom_exists'] + # Is the atom resolved in the experiment? Subset of atom_exists, + # *except for OXT* + all_atom_mask = paddle.cast(batch['all_atom_mask'], 'float32') + + xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) + loss = paddle.sum(xent * atom_exists, axis=[-2, -1]) / (1e-8 + paddle.sum(atom_exists, axis=[-2, -1])) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + resolution = paddle.squeeze(batch['resolution'], axis=-1) + loss *= paddle.cast((resolution >= self.config.min_resolution) + & (resolution <= self.config.max_resolution), 'float32') + + output = {'loss': loss} + return output + + +class DistogramHead(nn.Layer): + """Head to predict a distogram. + + Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" + """ + + def __init__(self, channel_num, config, name='distogram_head'): + super(DistogramHead, self).__init__() + self.config = config + # self.global_config = global_config + + self.half_logits = nn.Linear(channel_num['pair_channel'], + self.config.num_bins, name='half_logits') + init_final_linear(self.half_logits) + + def forward(self, representations, batch): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [batch, N_res, N_res, c_z]. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [batch, N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [batch, N_bins - 1]. 
+ """ + half_logits = self.half_logits(representations['pair']) + + logits = half_logits + paddle.transpose(half_logits, perm=[0, 2, 1, 3]) + breaks = paddle.linspace(self.config.first_break, self.config.last_break, + self.config.num_bins - 1) + breaks = paddle.tile(breaks[None, :], + repeat_times=[logits.shape[0], 1]) + + return { + 'logits': logits, + 'bin_edges': breaks} + + def loss(self, value, batch): + return _distogram_log_loss(value['logits'], value['bin_edges'], + batch, self.config.num_bins) + + +def _distogram_log_loss(logits, bin_edges, batch, num_bins): + """Log loss of a distogram.""" + positions = batch['pseudo_beta'] + mask = batch['pseudo_beta_mask'] + + assert positions.shape[-1] == 3 + + sq_breaks = paddle.square(bin_edges).unsqueeze([1, 2]) + + dist2 = paddle.sum( + paddle.square( + paddle.unsqueeze(positions, axis=-2) - + paddle.unsqueeze(positions, axis=-3)), + axis=-1, + keepdim=True) + + true_bins = paddle.sum(dist2 > sq_breaks, axis=-1) + + errors = softmax_cross_entropy( + labels=paddle.nn.functional.one_hot(true_bins, num_classes=num_bins), logits=logits) + + square_mask = paddle.unsqueeze(mask, axis=-2) * paddle.unsqueeze(mask, axis=-1) + + avg_error = ( + paddle.sum(errors * square_mask, axis=[-2, -1]) / + (1e-6 + paddle.sum(square_mask, axis=[-2, -1]))) + dist2 = dist2[..., 0] + return { + 'loss': avg_error, + 'true_dist': paddle.sqrt(1e-6 + dist2)} + + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + lower_breaks = paddle.linspace(min_bin, max_bin, num_bins) + lower_breaks = paddle.square(lower_breaks) + upper_breaks = paddle.concat([lower_breaks[1:], + paddle.to_tensor([1e8], dtype='float32')]) + + def _squared_difference(x, y): + return paddle.square(x - y) + + dist2 = paddle.sum( + _squared_difference( + paddle.unsqueeze(positions, axis=-2), + paddle.unsqueeze(positions, axis=-3)), + axis=-1, keepdim=True) + + dgram = ((dist2 > lower_breaks).astype('float32') * + (dist2 < upper_breaks).astype('float32')) + return dgram + + +class EvoformerIteration(nn.Layer): + """Single iteration (block) of Evoformer stack. + + Jumper et al. (2021) Suppl. Alg. 
6 "EvoformerStack" lines 2-10 + """ + def __init__(self, channel_num, config, global_config, is_extra_msa=False): + super(EvoformerIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + + # Row-wise Gated Self-attention with Pair Bias + self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( + channel_num, self.config.msa_row_attention_with_pair_bias, + self.global_config, is_extra_msa) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_row_attention_with_pair_bias) + self.msa_row_attn_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + if self.is_extra_msa: + self.msa_column_global_attention = MSAColumnGlobalAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_global_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + else: + self.msa_column_attention = MSAColumnAttention( + channel_num, config.msa_column_attention, global_config) + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_column_attention) + self.msa_col_attn_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + self.msa_transition = Transition( + channel_num, self.config.msa_transition, self.global_config, + is_extra_msa, 'msa_transition') + dropout_rate, dropout_axis = self._parse_dropout_params( + self.msa_transition) + self.msa_transition_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # OuterProductMean + self.outer_product_mean = OuterProductMean(channel_num, + self.config.outer_product_mean, self.global_config, + self.is_extra_msa, name='outer_product_mean') + + # Dropout + dropout_rate, dropout_axis = self._parse_dropout_params( + self.outer_product_mean) + self.outer_product_mean_dropout = nn.Dropout( + dropout_rate, axis=dropout_axis) + + # Triangle Multiplication. + self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # TriangleAttention. + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + # Pair transition. 
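+        # Applied last in each block, after the triangle updates, mirroring
+        # Jumper et al. (2021) Suppl. Alg. 15 "PairTransition".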
+ self.pair_transition = Transition( + channel_num, self.config.pair_transition, self.global_config, + is_extra_msa, 'pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, msa_act, pair_act, masks): + msa_mask, pair_mask = masks['msa'], masks['pair'] + + # [B, N_seq//dap_size, N_res, c_m] + residual = self.msa_row_attention_with_pair_bias( + msa_act, msa_mask, pair_act) + residual = self.msa_row_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res//dap_size, c_m] + msa_act = dap.row_to_col(msa_act) + + if self.is_extra_msa: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_global_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act, msa_mask) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + else: + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_column_attention(msa_act, msa_mask) + residual = self.msa_col_attn_dropout(residual) + msa_act = msa_act + residual + + # [B, N_seq, N_res//dap_size, c_m] + residual = self.msa_transition(msa_act, msa_mask) + residual = self.msa_transition_dropout(residual) + msa_act = msa_act + residual + + # [B, N_res//dap_size, N_res, c_z] + residual = self.outer_product_mean(msa_act, msa_mask) + residual = self.outer_product_mean_dropout(residual) + pair_act = pair_act + residual + + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all(msa_act, in_axis=1, out_axis=2) + + # scatter if using dap, otherwise do nothing + pair_mask_row = dap.scatter(pair_mask, axis=1) + pair_mask_col = dap.scatter(pair_mask, axis=2) + + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask_row) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_multiplication_incoming(pair_act, pair_mask_col) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + # [B, N_res//dap_size, N_res, c_z] + residual = self.triangle_attention_starting_node(pair_act, pair_mask_row) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res//dap_size, c_z] + pair_act = dap.row_to_col(pair_act) + # [B, N_res, N_res//dap_size, c_z] + residual = self.triangle_attention_ending_node(pair_act, pair_mask_col) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act, pair_mask) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + # [B, N_res, 
N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + pair_act = dap.col_to_row(pair_act) + + # wait if using async communication and dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] + msa_act = dap.all_to_all_opp(msa_act, in_axis=1, out_axis=2) + + return msa_act, pair_act + + +class EmbeddingsAndEvoformer(nn.Layer): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, channel_num, config, global_config): + super(EmbeddingsAndEvoformer, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + # InputEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + self.preprocess_1d = nn.Linear(channel_num['target_feat'], + self.config.msa_channel, name='preprocess_1d') + self.preprocess_msa = nn.Linear(channel_num['msa_feat'], + self.config.msa_channel, name='preprocess_msa') + self.left_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='left_single') + self.right_single = nn.Linear(channel_num['target_feat'], self.config.pair_channel, + name='right_single') + + # RecyclingEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + if self.config.recycle_pos: + self.prev_pos_linear = nn.Linear(self.config.prev_pos.num_bins, + self.config.pair_channel) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + self.pair_activiations = nn.Linear( + 2 * self.config.max_relative_feature + 1, + self.config.pair_channel) + + if self.config.recycle_features: + self.prev_msa_first_row_norm = nn.LayerNorm( + self.config.msa_channel) + self.prev_pair_norm = nn.LayerNorm(self.config.pair_channel) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + self.channel_num['template_angle'] = 57 + self.channel_num['template_pair'] = 88 + self.template_embedding = TemplateEmbedding( + self.channel_num, self.config.template, self.global_config) + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + self.extra_msa_activations = nn.Linear( + 25, # 23 (20aa+unknown+gap+mask) + 1 (has_del) + 1 (del_val) + self.config.extra_msa_channel) + + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + self.extra_msa_stack = nn.LayerList() + for _ in range(self.config.extra_msa_stack_num_block): + self.extra_msa_stack.append(EvoformerIteration( + self.channel_num, self.config.evoformer, self.global_config, + is_extra_msa=True)) + + # Embed templates torsion angles + if self.config.template.enabled and self.config.template.embed_torsion_angles: + c = self.config.msa_channel + self.template_single_embedding = nn.Linear( + self.channel_num['template_angle'], c) + self.template_projection = nn.Linear(c, c) + + # Main trunk of the network + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 17-18
+        self.evoformer_iteration = nn.LayerList()
+        for _ in range(self.config.evoformer_num_block):
+            self.evoformer_iteration.append(EvoformerIteration(
+                self.channel_num, self.config.evoformer, self.global_config,
+                is_extra_msa=False))
+
+        self.single_activations = nn.Linear(
+            self.config.msa_channel, self.config.seq_channel)
+
+    def _pseudo_beta_fn(self, aatype, all_atom_positions, all_atom_masks):
+        gly_id = paddle.ones_like(aatype) * residue_constants.restype_order['G']
+        is_gly = paddle.equal(aatype, gly_id)
+
+        ca_idx = residue_constants.atom_order['CA']
+        cb_idx = residue_constants.atom_order['CB']
+
+        n = len(all_atom_positions.shape)
+        pseudo_beta = paddle.where(
+            paddle.tile(paddle.unsqueeze(is_gly, axis=-1),
+                        [1] * len(is_gly.shape) + [3]),
+            paddle.squeeze(
+                all_atom_positions.slice([n-2], [ca_idx], [ca_idx+1]),
+                axis=-2),
+            paddle.squeeze(
+                all_atom_positions.slice([n-2], [cb_idx], [cb_idx+1]),
+                axis=-2))
+
+        if all_atom_masks is not None:
+            m = len(all_atom_masks.shape)
+            pseudo_beta_mask = paddle.where(
+                is_gly,
+                paddle.squeeze(
+                    all_atom_masks.slice([m-1], [ca_idx], [ca_idx+1]),
+                    axis=-1),
+                paddle.squeeze(
+                    all_atom_masks.slice([m-1], [cb_idx], [cb_idx+1]),
+                    axis=-1))
+            pseudo_beta_mask = paddle.squeeze(pseudo_beta_mask, axis=-1)
+            return pseudo_beta, pseudo_beta_mask
+        else:
+            return pseudo_beta
+
+    def _create_extra_msa_feature(self, batch):
+        # 23: 20aa + unknown + gap + bert mask
+        msa_1hot = nn.functional.one_hot(batch['extra_msa'], 23)
+        msa_feat = [msa_1hot,
+                    paddle.unsqueeze(batch['extra_has_deletion'], axis=-1),
+                    paddle.unsqueeze(batch['extra_deletion_value'], axis=-1)]
+        return paddle.concat(msa_feat, axis=-1)
+
+    def forward(self, batch):
+        # InputEmbedder
+        # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
+        # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
+        preprocess_1d = self.preprocess_1d(batch['target_feat'])
+        # preprocess_msa = self.preprocess_msa(batch['msa_feat'])
+        msa_activations = paddle.unsqueeze(preprocess_1d, axis=1) + \
+            self.preprocess_msa(batch['msa_feat'])
+
+        right_single = self.right_single(batch['target_feat'])  # 1, n_res, 22 -> 1, n_res, 128
+        right_single = paddle.unsqueeze(right_single, axis=1)  # 1, n_res, 128 -> 1, 1, n_res, 128
+        left_single = self.left_single(batch['target_feat'])  # 1, n_res, 22 -> 1, n_res, 128
+        left_single = paddle.unsqueeze(left_single, axis=2)  # 1, n_res, 128 -> 1, n_res, 1, 128
+        pair_activations = left_single + right_single
+
+        mask_2d = paddle.unsqueeze(batch['seq_mask'], axis=1) * paddle.unsqueeze(batch['seq_mask'], axis=2)
+
+        # Inject previous outputs for recycling.
+        # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
+        # Jumper et al. (2021) Suppl. Alg.
32 "RecyclingEmbedder" + if self.config.recycle_pos and 'prev_pos' in batch: + prev_pseudo_beta = self._pseudo_beta_fn( + batch['aatype'], batch['prev_pos'], None) + dgram = dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += self.prev_pos_linear(dgram) + + if self.config.recycle_features: + if 'prev_msa_first_row' in batch: + prev_msa_first_row = self.prev_msa_first_row_norm( + batch['prev_msa_first_row']) + + # A workaround for `jax.ops.index_add` + msa_first_row = paddle.squeeze(msa_activations[:, 0, :], axis=1) + msa_first_row += prev_msa_first_row + msa_first_row = paddle.unsqueeze(msa_first_row, axis=1) + msa_activations = paddle.concat([msa_first_row, msa_activations[:, 1:, :]], axis=1) + + if 'prev_pair' in batch: + pair_activations += self.prev_pair_norm(batch['prev_pair']) + + # RelPosEmbedder + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if self.config.max_relative_feature: + pos = batch['residue_index'] # [bs, N_res] + offset = paddle.unsqueeze(pos, axis=[-1]) - \ + paddle.unsqueeze(pos, axis=[-2]) + rel_pos = nn.functional.one_hot( + paddle.clip( + offset + self.config.max_relative_feature, + min=0, + max=2 * self.config.max_relative_feature), + 2 * self.config.max_relative_feature + 1) + rel_pos_bias = self.pair_activiations(rel_pos) + pair_activations += rel_pos_bias + + # TemplateEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + if self.config.template.enabled: + template_batch = {k: batch[k] for k in batch + if k.startswith('template_')} + template_pair_repr = self.template_embedding( + pair_activations, template_batch, mask_2d) + pair_activations += template_pair_repr + + # ExtraMSAEmbedder + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + extra_msa_feat = self._create_extra_msa_feature(batch) + extra_msa_activations = self.extra_msa_activations(extra_msa_feat) + + # ================================================== + # Extra MSA Stack + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + # ================================================== + extra_msa_stack_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, + } + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + extra_msa_stack_input['msa'] = dap.scatter(extra_msa_stack_input['msa'], axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + extra_msa_stack_input['pair'] = dap.scatter(extra_msa_stack_input['pair'], axis=1) + + for extra_msa_stack_iteration in self.extra_msa_stack: + extra_msa_act, extra_pair_act = recompute_wrapper(extra_msa_stack_iteration, + extra_msa_stack_input['msa'], + extra_msa_stack_input['pair'], + {'msa': batch['extra_msa_mask'], + 'pair': mask_2d}, + is_recompute=self.training) + extra_msa_stack_output = { + 'msa': extra_msa_act, + 'pair': extra_pair_act} + extra_msa_stack_input = { + 'msa': extra_msa_stack_output['msa'], + 'pair': extra_msa_stack_output['pair']} + + # gather if using dap, otherwise do nothing + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + extra_msa_stack_output['pair'] = dap.gather(extra_msa_stack_output['pair'], axis=1) + + evoformer_input = { + 'msa': msa_activations, + 'pair': extra_msa_stack_output['pair'], + } + + evoformer_masks = { + 'msa': batch['msa_mask'], + 'pair': mask_2d, + } + + # ================================================== + # Template angle feat + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 7-8 + # ================================================== + if self.config.template.enabled and self.config.template.embed_torsion_angles: + num_templ, num_res = batch['template_aatype'].shape[1:] + + aatype_one_hot = nn.functional.one_hot(batch['template_aatype'], 22) + # Embed the templates aatype, torsion angles and masks. + # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=batch['template_aatype'], + all_atom_pos=batch['template_all_atom_positions'], + all_atom_mask=batch['template_all_atom_masks'], + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not self.global_config.zero_init) + + template_features = paddle.concat([ + aatype_one_hot, + paddle.reshape(ret['torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + paddle.reshape(ret['alt_torsion_angles_sin_cos'], + [-1, num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1) + + template_activations = self.template_single_embedding( + template_features) + template_activations = nn.functional.relu(template_activations) + template_activations = self.template_projection(template_activations) + + # Concatenate the templates to the msa. + evoformer_input['msa'] = paddle.concat( + [evoformer_input['msa'], template_activations], axis=1) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. + torsion_angle_mask = ret['torsion_angles_mask'][..., 2] + torsion_angle_mask = torsion_angle_mask.astype( + evoformer_masks['msa'].dtype) + evoformer_masks['msa'] = paddle.concat( + [evoformer_masks['msa'], torsion_angle_mask], axis=1) + + # scatter if using dap, otherwise do nothing + # [B, N_seq, N_res, c_m] => [B, N_seq//dap_size, N_res, c_m] + evoformer_input['msa'] = dap.scatter(evoformer_input['msa'], axis=1) + # [B, N_res, N_res, c_z] => [B, N_res//dap_size, N_res, c_z] + evoformer_input['pair'] = dap.scatter(evoformer_input['pair'], axis=1) + + # ================================================== + # Main MSA Stack + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 + # ================================================== + for evoformer_block in self.evoformer_iteration: + msa_act, pair_act = recompute_wrapper(evoformer_block, + evoformer_input['msa'], + evoformer_input['pair'], + evoformer_masks, + is_recompute=self.training) + evoformer_output = { + 'msa': msa_act, + 'pair': pair_act} + evoformer_input = { + 'msa': evoformer_output['msa'], + 'pair': evoformer_output['pair'], + } + + # gather if using dap, otherwise do nothing + # [B, N_seq//dap_size, N_res, c_m] => [B, N_seq, N_res, c_m] + evoformer_output['msa'] = dap.gather(evoformer_output['msa'], axis=1) + # [B, N_res//dap_size, N_res, c_z] => [B, N_res, N_res, c_z] + evoformer_output['pair'] = dap.gather(evoformer_output['pair'], axis=1) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + single_activations = self.single_activations(msa_activations[:, 0]) + + num_seq = batch['msa_feat'].shape[1] + output = { + 'single': single_activations, + 'pair': pair_activations, + # Crop away template rows such that they are not used + # in MaskedMsaHead. + 'msa': msa_activations[:, :num_seq], + 'msa_first_row': msa_activations[:, 0], + } + + return output + + +class OuterProductMean(nn.Layer): + """Computes mean outer product. + + Jumper et al. (2021) Suppl. Alg. 
10 "OuterProductMean" + """ + + def __init__(self, channel_num, config, global_config, is_extra_msa, name='outer_product_mean'): + super(OuterProductMean, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + if is_extra_msa: + c_m = channel_num['extra_msa_channel'] + else: + c_m = channel_num['msa_channel'] + + self.layer_norm_input = nn.LayerNorm(c_m, name='layer_norm_input') + self.left_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='left_projection') + self.right_projection = nn.Linear( + c_m, self.config.num_outer_channel, name='right_projection') + + if self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.KaimingNormal() + + self.output_w = paddle.create_parameter( + [self.config.num_outer_channel, self.config.num_outer_channel, channel_num['pair_channel']], + 'float32', default_initializer=init_w) + self.output_b = paddle.create_parameter( + [channel_num['pair_channel']], 'float32', + default_initializer=nn.initializer.Constant(value=0.0)) + + def forward(self, act, mask): + """Builds OuterProductMean module. + + Arguments: + act: MSA representation, shape [batch, N_seq, N_res, c_m]. + mask: MSA mask, shape [batch, N_seq, N_res]. + + Returns: + Update to pair representation, shape [batch, N_res, N_res, c_z]. + """ + # [B, N_seq, N_res//dap_size, c_m] + act = self.layer_norm_input(act) + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + right_act_before = self.right_projection(act) + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, N_res, num_outer_channel] + right_act = dap.all_gather(right_act_before, axis=2) + + # [B, N_seq, N_res//dap_size, c_m] => [B, N_seq, N_res//dap_size, num_outer_channel] + left_act = self.left_projection(act) + # [B, N_seq, N_res] => [B, N_seq, N_res, 1] + mask = paddle.unsqueeze(mask, axis=-1) + # [B, N_seq, N_res, 1] => [B, N_seq, N_res//dap_size, 1] + mask_col = dap.scatter(mask, axis=2) + left_act = mask_col * left_act + + # [B, N_seq, N_res//dap_size, 1], [B, N_seq, N_res, 1] => [B, N_res//dap_size, N_res, 1] + epsilon = 1e-3 + norm = paddle.einsum('nabc,nadc->nbdc', mask_col, mask) + epsilon + + def compute_chunk(left_act, right_act): + # This is equivalent to + # + # act = jnp.einsum('abc,ade->dceb', left_act, right_act) + # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b + # + # but faster. maybe for subbatch inference? 
+ + # [B, N_seq, N_res//dap_size, num_outer_channel] => [B, N_seq, num_outer_channel, N_res//dap_size] + left_act = left_act.transpose([0, 1, 3, 2]) + # wait if using async communication and dap, otherwise do nothing + right_act_after = dap.all_gather_opp(right_act, axis=2) + # [B, N_seq, num_outer_channel, N_res//dap_size], [B, N_seq, N_res, num_outer_channel] + # => [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size] + act = paddle.einsum('nacb,nade->ndceb', left_act, right_act_after) + # [B, N_res, num_outer_channel, num_outer_channel, N_res//dap_size], [num_outer_channel, num_outer_channel, c_z] + # => [B, N_res, N_res//dap_size, c_z] + act = paddle.einsum('ndceb,cef->ndbf', act, self.output_w) + self.output_b + # [B, N_res, N_res//dap_size, c_z] => [B, N_res//dap_size, N_res, c_z] + return act.transpose([0, 2, 1, 3]) + + if not self.training: + # low memory mode using subbatch + sb_chunk = subbatch(compute_chunk, [0], [2], + self.config.chunk_size, 1) + act = sb_chunk(left_act, right_act) + else: + act = compute_chunk(left_act, right_act) + + act = act / norm + + return act + + +class TriangleAttention(nn.Layer): + """Triangle Attention. + + Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" + Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_attention'): + super(TriangleAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + assert config.orientation in ['per_row', 'per_column'] + + self.query_norm = nn.LayerNorm(channel_num['pair_channel'], + name='query_norm') + self.feat_2d_weights = paddle.create_parameter( + [channel_num['pair_channel'], self.config.num_head], 'float32', + default_initializer=nn.initializer.Normal( + std=1. / np.sqrt(channel_num['pair_channel']))) + + self.attention = Attention(self.config, self.global_config, + channel_num['pair_channel'], channel_num['pair_channel'], + channel_num['pair_channel']) + + + def forward(self, pair_act, pair_mask): + """Builds TriangleAttention module. + + Arguments: + pair_act: [batch, N_res, N_res, c_z] pair activations tensor + pair_mask: [batch, N_res, N_res] mask of non-padded regions in the tensor. + + Returns: + Update to pair_act, shape [batch, N_res, N_res, c_z]. + """ + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + pair_mask = pair_mask.transpose([0, 2, 1]) + + # [B, N_res//dap_size, N_res] + bias = 1e9 * (pair_mask - 1.) + # [B, N_res//dap_size, 1, 1, N_res] + bias = paddle.unsqueeze(bias, axis=[2, 3]) + + pair_act = self.query_norm(pair_act) + + # [B, N_res//dap_size, N_res, cz], [cz, head] => [B, head, N_res//dap_size, N_res] + nonbatched_bias_before = paddle.einsum('bqkc,ch->bhqk', pair_act, self.feat_2d_weights) + + # # [B, head, N_res//dap_size, N_res] => [B, head, N_res, N_res] + nonbatched_bias = dap.all_gather(nonbatched_bias_before, axis=2) + + if not self.training: + # low memory mode using subbatch + sb_attn = subbatch(self.attention, [0, 1, 2], [1, 1, 1], + self.global_config.subbatch_size, 1) + pair_act = sb_attn(pair_act, pair_act, bias, nonbatched_bias) + else: + pair_act = self.attention(pair_act, pair_act, bias, nonbatched_bias) + + if self.config.orientation == 'per_column': + pair_act = pair_act.transpose([0, 2, 1, 3]) + + return pair_act + + +class TriangleMultiplication(nn.Layer): + """Triangle multiplication layer ("outgoing" or "incoming"). + + Jumper et al. 
(2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" + Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" + """ + + def __init__(self, channel_num, config, global_config, name='triangle_multiplication'): + super(TriangleMultiplication, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.layer_norm_input = nn.LayerNorm(self.channel_num['pair_channel'], name='layer_norm_input') + self.left_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_projection') + self.right_projection = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_projection') + self.left_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='left_gate') + init_gate_linear(self.left_gate) + self.right_gate = nn.Linear(self.channel_num['pair_channel'], + self.config.num_intermediate_channel, name='right_gate') + init_gate_linear(self.right_gate) + + # line 4 + self.center_layer_norm = nn.LayerNorm(self.config.num_intermediate_channel, name='center_layer_norm') + self.output_projection = nn.Linear(self.config.num_intermediate_channel, + self.channel_num['pair_channel'], name='output_projection') + init_final_linear(self.output_projection) + # line 3 + self.gating_linear = nn.Linear(self.channel_num['pair_channel'], + self.channel_num['pair_channel'], name='output_projection') + init_gate_linear(self.gating_linear) + + def forward(self, act, mask): + """Builds TriangleMultiplication module. + + Arguments: + act: Pair activations, shape [batch, N_res, N_res, c_z] + mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Outputs, same shape/type as act. + """ + # Outgoing [batch, N_res//dap_size, N_res] => [batch, N_res//dap_size, N_res, 1] + # Incoming [batch, N_res, N_res//dap_size] => [batch, N_res, N_res//dap_size, 1] + mask = paddle.unsqueeze(mask, axis=-1) # [batch, N_res, N_res, 1] + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + act = self.layer_norm_input(act) # line 1 + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = mask * self.left_projection(act) + right_proj_act = mask * self.right_projection(act) + + # Outgoing [B, N_res//dap_size, N_res, c_z] => [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, c_z] => [B, N_res, N_res//dap_size, num_intermediate_channel] + left_gate_values = nn.functional.sigmoid(self.left_gate(act)) + right_gate_values = nn.functional.sigmoid(self.right_gate(act)) + + # Outgoing [B, N_res//dap_size, N_res, num_intermediate_channel] + # Incoming [B, N_res, N_res//dap_size, num_intermediate_channel] + left_proj_act = left_proj_act * left_gate_values + right_proj_act_before = right_proj_act * right_gate_values + + + # "Outgoing" edges equation: 'ikc,jkc->ijc' + # "Incoming" edges equation: 'kjc,kic->ijc' + # Note on the Suppl. Alg. 
11 & 12 notation: + # For the "outgoing" edges, a = left_proj_act and b = right_proj_act + # For the "incoming" edges, it's swapped: + # b = left_proj_act and a = right_proj_act + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + # [B, N_res//dap_size, N_res, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + # [B, N_res, N_res//dap_size, num_intermediate_channel] => [B, N_res, N_res, num_intermediate_channel] + right_proj_act = dap.all_gather(right_proj_act_before, axis=2) + else: + raise ValueError('unknown equation.') + + + # Outgoing [B, N_res//dap_size, N_res, c_z] + # Incoming [B, N_res, N_res//dap_size, c_z] + gate_values = nn.functional.sigmoid(self.gating_linear(act)) # line 3 + + if self.config.equation == 'ikc,jkc->ijc': + # Outgoing + dim, out_idx = 1, 1 + equation = 'bikc,bjkc->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=1) + elif self.config.equation == 'kjc,kic->ijc': + # Incoming + dim, out_idx = 2, 2 + equation = 'bkjc,bkic->bijc' + + # [B, N_res, N_res, num_intermediate_channel] + right_proj_act_after = dap.all_gather_opp(right_proj_act, axis=2) + else: + raise ValueError('unknown equation.') + + if not self.training: + einsum_fn = subbatch(paddle.einsum, [1], [dim], + self.global_config.subbatch_size, out_idx) + act = einsum_fn(equation, left_proj_act, right_proj_act_after) + else: + # Outgoing equation = 'bikc,bjkc->bijc' + # [B, N_res//dap_size, N_res, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res//dap_size, N_res, num_intermediate_channel] + + # Incoming equation = 'bkjc,bkic->bijc' + # [B, N_res, N_res//dap_size, num_intermediate_channel], [B, N_res, N_res, num_intermediate_channel] + # => [B, N_res, N_res//dap_size, num_intermediate_channel] + act = paddle.einsum(equation, left_proj_act, right_proj_act_after) + + act = self.center_layer_norm(act) + act = self.output_projection(act) + + act = act * gate_values + + return act + + +class TemplatePair(nn.Layer): + """Pair processing for the templates. + + Jumper et al. (2021) Suppl. Alg. 
16 "TemplatePairStack" lines 2-6 + """ + def __init__(self, channel_num, config, global_config): + super(TemplatePair, self).__init__() + self.config = config + self.global_config = global_config + + channel_num = {} + channel_num['pair_channel'] = self.config.triangle_attention_ending_node.value_dim + + self.triangle_attention_starting_node = TriangleAttention(channel_num, + self.config.triangle_attention_starting_node, self.global_config, + name='triangle_attention_starting_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_starting_node) + self.triangle_starting_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_attention_ending_node = TriangleAttention(channel_num, + self.config.triangle_attention_ending_node, self.global_config, + name='triangle_attention_ending_node') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_attention_ending_node) + self.triangle_ending_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_outgoing = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_outgoing, self.global_config, + name='triangle_multiplication_outgoing') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_outgoing) + self.triangle_outgoing_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.triangle_multiplication_incoming = TriangleMultiplication(channel_num, + self.config.triangle_multiplication_incoming, self.global_config, + name='triangle_multiplication_incoming') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.triangle_multiplication_incoming) + self.triangle_incoming_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + self.pair_transition = Transition(channel_num, self.config.pair_transition, + self.global_config, is_extra_msa=False, + transition_type='pair_transition') + + dropout_rate, dropout_axis = self._parse_dropout_params( + self.pair_transition) + self.pair_transition_dropout = nn.Dropout(dropout_rate, axis=dropout_axis) + + + def _parse_dropout_params(self, module): + dropout_rate = 0.0 if self.global_config.deterministic else \ + module.config.dropout_rate + dropout_axis = None + if module.config.shared_dropout: + dropout_axis = { + 'per_row': [0, 2, 3], + 'per_column': [0, 1, 3], + }[module.config.orientation] + + return dropout_rate, dropout_axis + + def forward(self, pair_act, pair_mask): + """Builds one block of TemplatePair module. + + Arguments: + pair_act: Pair activations for single template, shape [batch, N_res, N_res, c_t]. + pair_mask: Pair mask, shape [batch, N_res, N_res]. + + Returns: + Updated pair_act, shape [batch, N_res, N_res, c_t]. 
+ """ + + residual = self.triangle_attention_starting_node(pair_act, pair_mask) + residual = self.triangle_starting_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_attention_ending_node(pair_act, pair_mask) + residual = self.triangle_ending_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_outgoing(pair_act, pair_mask) + residual = self.triangle_outgoing_dropout(residual) + pair_act = pair_act + residual + + residual = self.triangle_multiplication_incoming(pair_act, pair_mask) + residual = self.triangle_incoming_dropout(residual) + pair_act = pair_act + residual + + residual = self.pair_transition(pair_act, pair_mask) + residual = self.pair_transition_dropout(residual) + pair_act = pair_act + residual + + return pair_act + + +class SingleTemplateEmbedding(nn.Layer): + """Embeds a single template. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 + """ + def __init__(self, channel_num, config, global_config): + super(SingleTemplateEmbedding, self).__init__() + self.config = config + self.channel_num = channel_num + self.global_config = global_config + + self.embedding2d = nn.Linear(channel_num['template_pair'], + self.config.template_pair_stack.triangle_attention_ending_node.value_dim) + + self.template_pair_stack = nn.LayerList() + for _ in range(self.config.template_pair_stack.num_block): + self.template_pair_stack.append(TemplatePair( + self.channel_num, self.config.template_pair_stack, self.global_config)) + + self.output_layer_norm = nn.LayerNorm(self.config.attention.key_dim) + + def forward(self, query_embedding, batch, mask_2d): + """Build the single template embedding. + + Arguments: + query_embedding: Query pair representation, shape [batch, N_res, N_res, c_z]. + batch: A batch of template features (note the template dimension has been + stripped out as this module only runs over a single template). + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [N_res, N_res, c_z]. 
+ """ + assert mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_res = batch['template_aatype'].shape[1] + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + template_mask_2d = template_mask_2d.astype(dtype) + + template_dgram = dgram_from_positions( + batch['template_pseudo_beta'], + **self.config.dgram_features) + template_dgram = template_dgram.astype(dtype) + + aatype = nn.functional.one_hot(batch['template_aatype'], 22) + aatype = aatype.astype(dtype) + + to_concat = [template_dgram, template_mask_2d[..., None]] + to_concat.append(paddle.tile(aatype[..., None, :, :], + [1, num_res, 1, 1])) + to_concat.append(paddle.tile(aatype[..., None, :], + [1, 1, num_res, 1])) + + n, ca, c = [residue_constants.atom_order[a] + for a in ('N', 'CA', 'C')] + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][..., n, :], + ca_xyz=batch['template_all_atom_positions'][..., ca, :], + c_xyz=batch['template_all_atom_positions'][..., c, :]) + affines = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot), + translation=trans, + rotation=rot) + + points = [paddle.unsqueeze(x, axis=-2) for x in + paddle.unstack(affines.translation, axis=-1)] + affine_vec = affines.invert_point(points, extra_dims=1) + inv_distance_scalar = paddle.rsqrt( + 1e-6 + sum([paddle.square(x) for x in affine_vec])) + + # Backbone affine mask: whether the residue has C, CA, N + # (the template mask defined above only considers pseudo CB). + template_mask = ( + batch['template_all_atom_masks'][..., n] * + batch['template_all_atom_masks'][..., ca] * + batch['template_all_atom_masks'][..., c]) + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) + + unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] + unit_vector = [x.astype(dtype) for x in unit_vector] + if not self.config.use_template_unit_vector: + unit_vector = [paddle.zeros_like(x) for x in unit_vector] + to_concat.extend(unit_vector) + + template_mask_2d = template_mask_2d.astype(dtype) + to_concat.append(template_mask_2d[..., None]) + + act = paddle.concat(to_concat, axis=-1) + # Mask out non-template regions so we don't get arbitrary values in the + # distogram for these regions. + act *= template_mask_2d[..., None] + + act = self.embedding2d(act) + for pair_encoder in self.template_pair_stack: + act = pair_encoder(act, mask_2d) + + act = self.output_layer_norm(act) + return act + + +class TemplateEmbedding(nn.Layer): + """Embeds a set of templates. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + """ + + def __init__(self, channel_num, config, global_config): + super(TemplateEmbedding, self).__init__() + self.config = config + self.global_config = global_config + + self.single_template_embedding = SingleTemplateEmbedding( + channel_num, config, global_config) + self.attention = Attention( + config.attention, global_config, + channel_num['pair_channel'], + config.attention.key_dim, + channel_num['pair_channel']) + + def forward(self, query_embedding, template_batch, mask_2d): + """Build TemplateEmbedding module. + + Arguments: + query_embedding: Query pair representation, shape [n_batch, N_res, N_res, c_z]. + template_batch: A batch of template features. 
+ mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + + Returns: + A template embedding [n_batch, N_res, N_res, c_z]. + """ + + num_templates = template_batch['template_mask'].shape[1] + + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + + num_res = query_embedding.shape[1] + + dtype = query_embedding.dtype + template_mask = template_batch['template_mask'] + template_mask = template_mask.astype(dtype) + + query_channels = query_embedding.shape[-1] + + outs = [] + for i in range(num_templates): + # By default, num_templates = 4 + batch0 = {k: paddle.squeeze(v.slice([1], [i], [i+1]), axis=1) + for k, v in template_batch.items()} + outs.append(self.single_template_embedding( + query_embedding, batch0, mask_2d)) + + template_pair_repr = paddle.stack(outs, axis=1) + + flat_query = paddle.reshape( + query_embedding, [-1, num_res * num_res, 1, query_channels]) + flat_templates = paddle.reshape( + paddle.transpose(template_pair_repr, [0, 2, 3, 1, 4]), + [-1, num_res * num_res, num_templates, num_channels]) + + bias = 1e9 * (template_mask[:, None, None, None, :] - 1.) + + if not self.training: + sb_attn = subbatch(self.attention, [0, 1], [1, 1], + self.config.subbatch_size, 1) + emb = sb_attn(flat_query, flat_templates, bias) + + else: + emb = self.attention(flat_query, flat_templates, bias) + + emb = paddle.reshape( + emb, [-1, num_res, num_res, query_channels]) + + # No gradients if no templates. + emb *= (paddle.sum(template_mask) > 0.).astype(emb.dtype) + return emb diff --git a/apps/protein_folding/helixfold_cpu/playground.py b/apps/protein_folding/helixfold_cpu/playground.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/protein_folding/helixfold_cpu/protein.py b/apps/protein_folding/helixfold_cpu/protein.py new file mode 100644 index 00000000..4e58b901 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/protein.py @@ -0,0 +1,279 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" + +import dataclasses +import io +from typing import Any, Mapping, Optional +from tools import residue_constants +from Bio.PDB import PDBParser +import numpy as np + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. 
+ aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.') + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. 
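A small illustration of the chain-ID mapping built in the next few lines (example values only); note that np.unique sorts, so integer chain indices follow the sorted order of the chain IDs, not the order in which the chains first appear:

import numpy as np

chain_ids = ['B', 'B', 'A']                    # per-residue PDB chain IDs
unique_chain_ids = np.unique(chain_ids)        # array(['A', 'B']) -- sorted
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}  # {'A': 0, 'B': 1}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])   # [1, 1, 0]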
+ unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors)) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append('MODEL 1') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append(_chain_end( + atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]], + residue_index[i - 1])) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], residue_index[-1])) + pdb_lines.append('ENDMDL') + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. 
+ + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + remove_leading_feature_dimension: bool = True) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + remove_leading_feature_dimension: Whether to remove the leading dimension + of the `features` values. + + Returns: + A protein instance. + """ + fold_output = result['structure_module'] + + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if 'asym_id' in features: + chain_index = _maybe_remove_leading_dim(features['asym_id']) + else: + chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + return Protein( + aatype=_maybe_remove_leading_dim(features['aatype']), + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1, + chain_index=chain_index, + b_factors=b_factors) diff --git a/apps/protein_folding/helixfold_cpu/setup_env.sh b/apps/protein_folding/helixfold_cpu/setup_env.sh new file mode 100644 index 00000000..8cc65690 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/setup_env.sh @@ -0,0 +1,3 @@ +python -m pip install numpy +python -m pip install joblib +python -m pip install paddlepaddle==2.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/stut_alphafold.py b/apps/protein_folding/helixfold_cpu/stut_alphafold.py new file mode 100644 index 00000000..745ee2b3 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_alphafold.py @@ -0,0 +1,69 @@ +from layers.static_net import AlphaFold +from config import model_config +import time +import numpy as np +from tqdm import tqdm +import paddle +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] static UT of pdinfer.HelixFold') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model'] + +n_warm = 0 +n_iter = 1 +ignore_eval = False +root_weights = 'static_modules' +module_prefix = 'alphafold' + +### create sample input +len_dim = 765 +feed_dict = { + 'target_feat': np.ones([1, 4, len_dim, 22], dtype='float32'), + 'msa_feat': np.ones([1, 4, 508, len_dim, 49], dtype='float32'), + 'seq_mask': np.ones([1, 4, len_dim], dtype='float32'), + 'seq_length': np.ones([1, 4, len_dim], dtype='int32'), + 'aatype': np.ones([1, 4, len_dim], dtype='float32'), + 'residue_index': np.ones([1, 4, len_dim], dtype='float32'), + 'template_mask': np.ones([1, 4, 4], dtype='float32'), + 'template_aatype': np.ones([1, 4, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': np.ones([1, 4, 4, len_dim], dtype='float32'), + 'template_pseudo_beta': np.ones([1, 4, 4, len_dim, 3], dtype='float32'), + 'template_all_atom_positions': np.ones([1, 4, 4, len_dim, 37, 3], dtype='float32'), + 'template_all_atom_masks': np.ones([1, 4, 4, len_dim, 37], dtype='float32'), + 'extra_msa': np.ones([1, 
4, 5120, len_dim], dtype='float32'), + 'extra_has_deletion': np.ones([1, 4, 5120, len_dim], dtype='float32'), + 'extra_deletion_value': np.ones([1, 4, 5120, len_dim], dtype='float32'), + 'extra_msa_mask': np.ones([1, 4, 5120, len_dim], dtype='float32'), + 'msa_mask': np.ones([1, 4, 508, len_dim], dtype='float32'), + 'prev_pos': np.ones([1, 4, len_dim, 37, 3], dtype='float32'), + 'prev_msa_first_row': np.ones([1, 4, len_dim, 256], dtype='float32'), + 'prev_pair': np.ones([1, 4, len_dim, len_dim, 128], dtype='float32'), + 'atom14_atom_exists': np.ones([1, 4, len_dim, 14], dtype='float32'), + 'atom37_atom_exists': np.ones([1, 4, len_dim, 37], dtype='float32'), + 'residx_atom37_to_atom14': np.ones([1, 4, len_dim, 37], dtype='float32') +} + +print('# [INFO] build and save static graph of HelixFold') +model = AlphaFold( + config=c, + seq_len=len_dim, + n_cpus=n_cpus, + module_prefix=module_prefix, + root_weights=root_weights, + is_pdinfer_init=False) + +print('# [INFO] inference on static graph') +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict, False) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) diff --git a/apps/protein_folding/helixfold_cpu/stut_embeddings.py b/apps/protein_folding/helixfold_cpu/stut_embeddings.py new file mode 100644 index 00000000..e1903db2 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_embeddings.py @@ -0,0 +1,73 @@ +from layers.static_backbones import StaticEmbeddings +from config import model_config +import paddle as pd +import numpy as np +from tqdm import tqdm +import time +import os +from argparse import ArgumentParser as Parser + + +parser = Parser('# [INFO] UT of static helixfold.evoformeriteration') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +n_warm = 3 +n_iter = 13 +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] +len_dim = 206 +feed_dict = { + 'target_feat': np.ones([1, len_dim, 22], dtype='float32'), + 'msa_feat': np.ones([1, 508, len_dim, 49], dtype='float32'), # 508 -> 512 + 'seq_mask': np.ones([1, len_dim], dtype='float32'), + 'aatype': np.ones([1, len_dim], dtype='int32'), + 'residue_index': np.ones([1, len_dim], dtype='float32'), + 'template_mask': np.ones([1, 4], dtype='float32'), + 'template_aatype': np.ones([1, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': np.ones([1, 4, len_dim], dtype='float32'), + 'template_pseudo_beta': np.ones([1, 4, len_dim, 3], dtype='float32'), + 'template_all_atom_positions': np.ones([1, 4, len_dim, 37, 3], dtype='float32'), + 'template_all_atom_masks': np.ones([1, 4, len_dim, 37], dtype='float32'), + 'extra_msa': np.ones([1, 5120, len_dim], dtype='float32'), + 'extra_has_deletion': np.ones([1, 5120, len_dim], dtype='float32'), + 'extra_deletion_value': np.ones([1, 5120, len_dim], dtype='float32'), + 'prev_pos': np.ones([1, len_dim, 37, 3], dtype='float32'), + 'prev_msa_first_row': np.ones([1, len_dim, 256], dtype='float32'), + 'prev_pair': np.ones([1, len_dim, len_dim, 128], dtype='float32'), +} +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +model = StaticEmbeddings( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, 
+ n_cpus=n_cpus, + module_prefix='evoformeriteration', + root_weights='static_modules', + is_pdinfer_init=False +) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +for name, output in outputs.items(): + print('# [INFO] {} -> {}'.format( + name, output.shape + )) diff --git a/apps/protein_folding/helixfold_cpu/stut_embeddingsandevoformer.py b/apps/protein_folding/helixfold_cpu/stut_embeddingsandevoformer.py new file mode 100644 index 00000000..56c269fa --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_embeddingsandevoformer.py @@ -0,0 +1,80 @@ +from layers.static_subnets import StaticEmbeddingsAndEvoformer +from config import model_config +import time +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] static UT of pdinfer.StaticEmbeddingsAndEvoformer') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 1 +n_iter = 2 +ignore_eval = False +root_weights = 'static_modules' +module_prefix = 'embeddingsandevoformer' +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 512 +feed_dict = { + 'target_feat': np.ones([1, len_dim, 22], dtype='float32'), + 'msa_feat': np.ones([1, 508, len_dim, 49], dtype='float32'), + 'seq_mask': np.ones([1, len_dim], dtype='float32'), + 'aatype': np.ones([1, len_dim], dtype='float32'), + 'residue_index': np.ones([1, len_dim], dtype='float32'), + 'template_mask': np.ones([1, 4], dtype='float32'), + 'template_aatype': np.ones([1, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': np.ones([1, 4, len_dim], dtype='float32'), + 'template_pseudo_beta': np.ones([1, 4, len_dim, 3], dtype='float32'), + 'template_all_atom_positions': np.ones([1, 4, len_dim, 37, 3], dtype='float32'), + 'template_all_atom_masks': np.ones([1, 4, len_dim, 37], dtype='float32'), + 'extra_msa': np.ones([1, 5120, len_dim], dtype='float32'), + 'extra_has_deletion': np.ones([1, 5120, len_dim], dtype='float32'), + 'extra_deletion_value': np.ones([1, 5120, len_dim], dtype='float32'), + 'extra_msa_mask': np.ones([1, 5120, len_dim], dtype='float32'), + 'msa_mask': np.ones([1, 508, len_dim], dtype='float32'), + 'prev_pos': np.ones([1, len_dim, 37, 3], dtype='float32'), + 'prev_msa_first_row': np.ones([1, len_dim, 256], dtype='float32'), + 'prev_pair': np.ones([1, len_dim, len_dim, 128], dtype='float32'), + # 'torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'alt_torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'torsion_angles_mask': pd.ones([1, 4, len_dim, 7]) +} + +print('# [INFO] build and save static graph of Attention') +model = StaticEmbeddingsAndEvoformer( + config=c, + global_config=gc, + seq_len=len_dim, + channel_num=channel_num, + n_cpus=n_cpus, + module_prefix=module_prefix, + root_weights=root_weights, + is_pdinfer_init=False) + +print('# [INFO] inference on static graph') +# run +dts = 0. 
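The warm-up-then-average timing loop below recurs in each of these stut_*.py scripts; as a sketch only, the same measurement pattern factored into a helper (the `bench` name and the `run` callable are hypothetical, not part of this code):

import time

def bench(run, n_iter, n_warm):
    """Average wall-clock seconds per call of run(), ignoring the first n_warm calls."""
    assert n_iter > n_warm
    total = 0.
    for i in range(n_iter):
        t0 = time.time()
        run()
        if i >= n_warm:
            total += time.time() - t0
    return total / (n_iter - n_warm)

# e.g. avg = bench(lambda: model(feed_dict), n_iter=13, n_warm=3)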
+for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + print('# [INFO] output of {}:'.format(module_prefix)) + for k, v in outputs.items(): + print('{} -> {}'.format(k, v.shape)) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) diff --git a/apps/protein_folding/helixfold_cpu/stut_evoformer.py b/apps/protein_folding/helixfold_cpu/stut_evoformer.py new file mode 100644 index 00000000..e9a03120 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_evoformer.py @@ -0,0 +1,69 @@ +import pdb +from layers.static_backbones import StaticEvoformer +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser +import warnings + +model_prefix = 'evoformer' +parser = Parser('[pd.infer] UT of static helixfold.{}'.format(model_prefix)) +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 1024 # 1024=350GB +feed_dict = { + 'msa_activations': np.ones([1, 508, len_dim, 256], dtype='float32'), + 'extra_pair_act': np.ones([1, len_dim, len_dim, 128], dtype='float32'), + 'msa_mask': np.ones([1, 508, len_dim], dtype='float32'), + 'mask_2d': np.ones([1, len_dim, len_dim], dtype='float32') +} + +print('# [INFO] build and save static graph of {}'.format(model_prefix)) +model = StaticEvoformer( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, + n_cpus=n_cpus, + module_prefix='evoformer', + root_weights='static_modules', + is_pdinfer_init=False) + +print('# [INFO] inference on static graph') +# run +dts = 0. 
+for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + print('# [INFO] output of {}:'.format(model_prefix)) + for k, v in outputs.items(): + print('{} -> {}'.format(k, v.shape)) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/stut_evoformeriteration.py b/apps/protein_folding/helixfold_cpu/stut_evoformeriteration.py new file mode 100644 index 00000000..9c0a740d --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_evoformeriteration.py @@ -0,0 +1,50 @@ +from layers.static_backbones import StaticEvoformerIteration +from config import model_config +import numpy as np +from tqdm import tqdm +import time +from argparse import ArgumentParser as Parser + + +parser = Parser('# [INFO] UT of static helixfold.evoformeriteration') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +n_warm = 3 +n_iter = 13 +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer'] +gc = cfg['model']['global_config'] +len_dim = 1024 +feed_dict = { + 'msa_act':np.ones([1, 512, len_dim, 256], dtype='float32'), + 'pair_act':np.ones([1, len_dim, len_dim, 128], dtype='float32'), + 'msa_mask':np.ones([1, 512, len_dim], dtype='float32'), + 'pair_mask':np.ones([1, len_dim, len_dim], dtype='float32') +} +channel_num={'msa_channel':256, 'pair_channel':128} +is_extra_msa = False +model = StaticEvoformerIteration( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, + n_cpus=n_cpus, + is_extra_msa=is_extra_msa, + module_prefix='evoformeriteration', + root_weights='static_modules', + is_pdinfer_init=False +) + +# run +dts = 0. 
+for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) +for k, v in outputs.items(): + print('{} -> {}'.format(k, v.shape)) diff --git a/apps/protein_folding/helixfold_cpu/stut_extraevoformeriterations.py b/apps/protein_folding/helixfold_cpu/stut_extraevoformeriterations.py new file mode 100644 index 00000000..efdf0987 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_extraevoformeriterations.py @@ -0,0 +1,69 @@ +import pdb +from layers.static_backbones import StaticExtraEvoformerIterations +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser +import warnings + +model_prefix = 'extramsa' +parser = Parser('[pd.infer] UT of static helixfold.{}'.format(model_prefix)) +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 206 +feed_dict = { + 'extra_msa_act': np.ones([1, 5120, len_dim, 64], dtype='float32'), + 'extra_pair_act': np.ones([1, len_dim, len_dim, 128], dtype='float32'), + 'extra_msa_mask': np.ones([1, 5120, len_dim], dtype='float32'), + 'mask_2d': np.ones([1, len_dim, len_dim], dtype='float32') +} + +print('# [INFO] build and save static graph of {}'.format(model_prefix)) +model = StaticExtraEvoformerIterations( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, + n_cpus=n_cpus, + module_prefix='extramsa', + root_weights='static_modules', + is_pdinfer_init=False) + +print('# [INFO] inference on dynamic graph') +# run +dts = 0. 
+for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + print('# [INFO] output of extramsa:') + for k, v in outputs.items(): + print('{} -> {}'.format(k, v.shape)) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/stut_extramsa.py b/apps/protein_folding/helixfold_cpu/stut_extramsa.py new file mode 100644 index 00000000..5781f146 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_extramsa.py @@ -0,0 +1,69 @@ +import pdb +from layers.static_backbones import StaticExtraMsa +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser +import warnings + +model_prefix = 'extramsa' +parser = Parser('[pd.infer] UT of static helixfold.{}'.format(model_prefix)) +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus +# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 206 +feed_dict = { + 'extra_msa_act': np.ones([1, 5120, len_dim, 64], dtype='float32'), + 'extra_pair_act': np.ones([1, len_dim, len_dim, 128], dtype='float32'), + 'extra_msa_mask': np.ones([1, 5120, len_dim], dtype='float32'), + 'mask_2d': np.ones([1, len_dim, len_dim], dtype='float32') +} + +print('# [INFO] build and save static graph of {}'.format(model_prefix)) +model = StaticExtraMsa( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, + n_cpus=n_cpus, + module_prefix='extramsa', + root_weights='static_modules', + is_pdinfer_init=False) + +print('# [INFO] inference on dynamic graph') +# run +dts = 0. 
+for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + print('# [INFO] output of extramsa:') + for k, v in outputs.items(): + print('{} -> {}'.format(k, v.shape)) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/stut_singletemplateembedding.py b/apps/protein_folding/helixfold_cpu/stut_singletemplateembedding.py new file mode 100644 index 00000000..5f77ef83 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/stut_singletemplateembedding.py @@ -0,0 +1,60 @@ +from layers.static_backbones import StaticSingleTemplateEmbedding +from config import model_config +import paddle as pd +import numpy as np +from tqdm import tqdm +import time +import os +from argparse import ArgumentParser as Parser + + +parser = Parser('# [INFO] UT of static helixfold.singletemplateembedding') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +n_warm = 3 +n_iter = 13 +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] +len_dim = 206 +feed_dict = { + 'msa_mask': np.ones([1, 508, len_dim], dtype='float32'), + 'torsion_angles_mask': np.ones([1, 4, len_dim, 7], dtype='float32'), + 'msa_activations_raw': np.ones([1, 508, len_dim, 256], dtype='float32'), + 'template_features': np.ones([1, 4, len_dim, 57], dtype='float32') +} +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel +} +model = StaticSingleTemplateEmbedding( + config=c, + global_config=gc, + feed_dict=feed_dict, + channel_num=channel_num, + n_cpus=n_cpus, + module_prefix='singletemplateembedding', + root_weights='static_modules', + is_pdinfer_init=False +) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + outputs = model(feed_dict) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +for name, output in outputs.items(): + print('# [INFO] {} -> {}'.format( + name, output.shape + )) diff --git a/apps/protein_folding/helixfold_cpu/tools/__init__.py b/apps/protein_folding/helixfold_cpu/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/protein_folding/helixfold_cpu/tools/all_atom.py b/apps/protein_folding/helixfold_cpu/tools/all_atom.py new file mode 100644 index 00000000..68345a6b --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/all_atom.py @@ -0,0 +1,1166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ops for all atom representations. 
+ +Generally we employ two different representations for all atom coordinates, +one is atom37 where each heavy atom corresponds to a given position in a 37 +dimensional array, This mapping is non amino acid specific, but each slot +corresponds to an atom of a given name, for example slot 12 always corresponds +to 'C delta 1', positions that are not present for a given amino acid are +zeroed out and denoted by a mask. +The other representation we employ is called atom14, this is a more dense way +of representing atoms with 14 slots. Here a given slot will correspond to a +different kind of atom depending on amino acid type, for example slot 5 +corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine. +14 is chosen because it is the maximum number of heavy atoms for any standard +amino acid. +The order of slots can be found in 'residue_constants.residue_atoms'. +Internally the model uses the atom14 representation because it is +computationally more efficient. +The internal atom14 representation is turned into the atom37 at the output of +the network to facilitate easier conversion to existing protein datastructures. +""" + +import numpy as np +from typing import Dict, Optional +import paddle +import paddle.nn as nn +from tools import residue_constants, r3 +import tools.model_utils as utils + + +def squared_difference(x, y): + return paddle.square(x - y) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return paddle.to_tensor(chi_atom_indices) + + +def atom14_to_atom37(atom14_data: paddle.Tensor, # (B, N, 14, ...) + batch: Dict[str, paddle.Tensor] + ) -> paddle.Tensor: # (B, N, 37, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom14_data.shape) in [3, 4] + assert 'residx_atom37_to_atom14' in batch + assert 'atom37_atom_exists' in batch + + atom37_data = utils.batched_gather(atom14_data, + batch['residx_atom37_to_atom14'], + batch_dims=2) + if len(atom14_data.shape) == 3: + atom37_data *= batch['atom37_atom_exists'] + elif len(atom14_data.shape) == 4: + atom37_data *= batch['atom37_atom_exists'][:, :, :, + None].astype(atom37_data.dtype) + return atom37_data + + +def atom37_to_atom14( + atom37_data: paddle.Tensor, # (B, N, 37, ...) + batch: Dict[str, paddle.Tensor]) -> paddle.Tensor: # (B, N, 14, ...) 
+ """Convert atom37 to atom14 representation."""
+ assert len(atom37_data.shape) in [3, 4]
+ assert 'residx_atom14_to_atom37' in batch
+ assert 'atom14_atom_exists' in batch
+
+ atom14_data = utils.batched_gather(atom37_data,
+ batch['residx_atom14_to_atom37'],
+ batch_dims=2)
+ if len(atom37_data.shape) == 3:
+ atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype)
+ elif len(atom37_data.shape) == 4:
+ atom14_data *= batch['atom14_atom_exists'][:, :, :,
+ None].astype(atom14_data.dtype)
+ return atom14_data
+
+
+def atom37_to_frames(
+ aatype: paddle.Tensor, # (B, N)
+ all_atom_positions: paddle.Tensor, # (B, N, 37, 3)
+ all_atom_mask: paddle.Tensor, # (B, N, 37)
+) -> Dict[str, paddle.Tensor]:
+ """Computes the frames for the up to 8 rigid groups for each residue.
+
+ The rigid groups are defined by the possible torsions in a given amino acid.
+ We group the atoms according to their dependence on the torsion angles into
+ "rigid groups". E.g., the positions of atoms in the chi2-group depend on
+ chi1 and chi2, but do not depend on chi3 or chi4.
+ Jumper et al. (2021) Suppl. Table 2 and corresponding text.
+
+ Args:
+ aatype: Amino acid type, given as Tensor with integers.
+ all_atom_positions: atom37 representation of all atom coordinates.
+ all_atom_mask: atom37 representation of mask on all atom coordinates.
+ Returns:
+ Dictionary containing:
+ * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
+ represented as flat 12 dimensional array.
+ * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
+ the given frame are available in the ground truth, e.g. if they were
+ resolved in the experiment.
+ * 'rigidgroups_group_exists': Mask denoting whether given group is in
+ principle present for given amino acid type.
+ * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
+ affected by naming ambiguity.
+ * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
+ corresponding to 'all_atom_positions' represented as flat
+ 12 dimensional array.
+ """
+ # 0: 'backbone group',
+ # 1: 'pre-omega-group', (empty)
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
+ # 3: 'psi-group',
+ # 4,5,6,7: 'chi1,2,3,4-group'
+ aatype_in_shape = aatype.shape
+
+ # If there is a batch axis, just flatten it away, and reshape everything
+ # back at the end of the function.
+ aatype = paddle.reshape(aatype, [-1])
+ all_atom_positions = paddle.reshape(all_atom_positions, [-1, 37, 3])
+ all_atom_mask = paddle.reshape(all_atom_mask, [-1, 37])
+
+ # Create an array with the atom names.
+ # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
+ restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
+
+ # 0: backbone frame
+ restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
+
+ # 3: 'psi-group'
+ restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
+
+ # 4,5,6,7: 'chi1,2,3,4-group'
+ for restype, restype_letter in enumerate(residue_constants.restypes):
+ resname = residue_constants.restype_1to3[restype_letter]
+ for chi_idx in range(4):
+ if residue_constants.chi_angles_mask[restype][chi_idx]:
+ atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
+ restype_rigidgroup_base_atom_names[
+ restype, chi_idx + 4, :] = atom_names[1:]
+
+ # Create mask for existing rigid groups.
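As a reading aid for the atom14/atom37 conversions above, here is a minimal NumPy sketch of the same per-residue gather-and-mask step; the array sizes and values are illustrative, not taken from the real residue tables:

import numpy as np

# 2 residues, 3 source slots, 4 target slots (the real sizes are 14 and 37).
atom14_data = np.array([[1., 2., 3.],
                        [4., 5., 6.]])
# For each target slot, which source slot to read from (per residue).
residx_atom37_to_atom14 = np.array([[0, 2, 1, 0],
                                    [1, 1, 0, 2]])
# Mask of target slots that actually exist for the residue type.
atom37_atom_exists = np.array([[1., 1., 1., 0.],
                               [1., 0., 1., 1.]])

atom37_data = np.take_along_axis(atom14_data, residx_atom37_to_atom14, axis=1)
atom37_data = atom37_data * atom37_atom_exists
print(atom37_data)  # [[1. 3. 2. 0.] [5. 0. 4. 6.]]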
+ restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_mask[:, 0] = 1 + restype_rigidgroup_mask[:, 3] = 1 + restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask + + # Translate atom names into atom37 indices. + lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + restype_rigidgroup_base_atom_names) + restype_rigidgroup_base_atom37_idx = paddle.to_tensor(restype_rigidgroup_base_atom37_idx) + + # Compute the gather indices for all residues in the chain. + # shape (B, N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + restype_rigidgroup_base_atom37_idx, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = utils.batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + + # Compute the Rigids. + gt_frames = r3.rigids_from_3_points_vecs( + point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]), + origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]), + point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]) + ) + + # Compute a mask whether the group exists. + # (B, N, 8) + restype_rigidgroup_mask = paddle.to_tensor(restype_rigidgroup_mask) + group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (B, N, 8, 3) + all_atom_mask.astype('float32'), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = paddle.min(gt_atoms_exist, axis=-1) * group_exists # (B, N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + rots = paddle.to_tensor(rots, dtype='float32') + gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + restype_rigidgroup_is_ambiguous = paddle.to_tensor(restype_rigidgroup_is_ambiguous, dtype='float32') + restype_rigidgroup_rots = paddle.to_tensor(restype_rigidgroup_rots, dtype='float32') + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + residx_rigidgroup_ambiguity_rot = utils.batched_gather( + restype_rigidgroup_rots, aatype) + + # Create the alternative ground truth frames. 
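The call to r3.rigids_from_3_points_vecs above builds each rigid frame from three atoms by Gram-Schmidt orthogonalization, following Jumper et al. (2021) Suppl. Alg. 21. A standalone NumPy sketch of that construction, assuming r3 follows the published convention (function and variable names here are illustrative):

import numpy as np

def frame_from_3_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    # Gram-Schmidt: the x-axis points from point_on_neg_x_axis towards origin.
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)
    e1 = point_on_xy_plane - origin
    e1 = e1 - np.dot(e1, e0) * e0      # remove the component along e0
    e1 = e1 / np.linalg.norm(e1)
    e2 = np.cross(e0, e1)
    rotation = np.stack([e0, e1, e2], axis=-1)  # columns are the frame axes
    return rotation, origin                     # rotation matrix + translation

rot, trans = frame_from_3_points(np.array([1., 0., 0.]),   # e.g. C
                                 np.array([0., 0., 0.]),   # e.g. CA (origin)
                                 np.array([0., 1., 0.]))   # e.g. N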
+ alt_gt_frames = r3.rigids_mul_rots( + gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot)) + + gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames) + alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames) + + # reshape back to original residue layout + gt_frames_flat12 = paddle.reshape(gt_frames_flat12, aatype_in_shape + [8, 12]) + gt_exists = paddle.reshape(gt_exists, aatype_in_shape + [8]) + group_exists = paddle.reshape(group_exists, aatype_in_shape + [8]) + residx_rigidgroup_is_ambiguous = paddle.reshape( + residx_rigidgroup_is_ambiguous, aatype_in_shape + [8]) + alt_gt_frames_flat12 = paddle.reshape( + alt_gt_frames_flat12, aatype_in_shape + [8, 12]) + + return { + 'rigidgroups_gt_frames': gt_frames_flat12, # (B, N, 8, 12) + 'rigidgroups_gt_exists': gt_exists, # (B, N, 8) + 'rigidgroups_group_exists': group_exists, # (B, N, 8) + 'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous, # (B, N, 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # (B, N, 8, 12) + } + + +def atom37_to_torsion_angles( + aatype: paddle.Tensor, # (B, T, N) + all_atom_pos: paddle.Tensor, # (B, T, N, 37, 3) + all_atom_mask: paddle.Tensor, # (B, T, N, 37) + placeholder_for_undefined=False, +) -> Dict[str, paddle.Tensor]: + """Computes the 7 torsion angles (in sin, cos encoding) for each residue. + + The 7 torsion angles are in the order + '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]', + here pre_omega denotes the omega torsion angle between the given amino acid + and the previous amino acid. + + Args: + aatype: Amino acid type, given as array with integers. + all_atom_pos: atom37 representation of all atom coordinates. + all_atom_mask: atom37 representation of mask on all atom coordinates. + placeholder_for_undefined: flag denoting whether to set masked torsion + angles to zero. + Returns: + Dict containing: + * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final + 2 dimensions denote sin and cos respectively + * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but + with the angle shifted by pi for all chi angles affected by the naming + ambiguities. + * 'torsion_angles_mask': Mask for which chi angles are present. + """ + + # Map aatype > 20 to 'Unknown' (20). + aatype = paddle.minimum(aatype.astype('int'), paddle.to_tensor([20]).astype('int')) + + num_batch, num_temp, num_res = aatype.shape + + # Compute the backbone angles. + pad = paddle.zeros([num_batch, num_temp, 1, 37, 3]) + prev_all_atom_pos = paddle.concat([pad, all_atom_pos[..., :-1, :, :]], axis=-3) + + pad = paddle.zeros([num_batch, num_temp, 1, 37]) + prev_all_atom_mask = paddle.concat([pad, all_atom_mask[..., :-1, :]], axis=-2) + + # For each torsion angle collect the 4 atom positions that define this angle. + # shape (B, T, N, atoms=4, xyz=3) + pre_omega_atom_pos = paddle.concat( + [prev_all_atom_pos[..., 1:3, :], # prev CA, C + all_atom_pos[..., 0:2, :] # this N, CA + ], axis=-2) + + phi_atom_pos = paddle.concat( + [prev_all_atom_pos[..., 2:3, :], # prev C + all_atom_pos[..., 0:3, :] # this N, CA, C + ], axis=-2) + + psi_atom_pos = paddle.concat( + [all_atom_pos[..., 0:3, :], # this N, CA, C + all_atom_pos[..., 4:5, :] # this O + ], axis=-2) + + # Collect the masks from these atoms. 
+ # Shape [batch, n_temp, num_res] + pre_omega_mask = ( + paddle.prod(prev_all_atom_mask[..., 1:3], axis=-1) # prev CA, C + * paddle.prod(all_atom_mask[..., 0:2], axis=-1)) # this N, CA + phi_mask = ( + prev_all_atom_mask[..., 2] # prev C + * paddle.prod(all_atom_mask[..., 0:3], axis=-1)) # this N, CA, C + psi_mask = ( + paddle.prod(all_atom_mask[..., 0:3], axis=-1) * # this N, CA, C + all_atom_mask[..., 4]) # this O + + # Collect the atoms for the chi-angles. + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + + # Select atoms to compute chis. Shape: [batch, num_temp, num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0) + + # Gather atom positions. Shape: [batch, num_temp, num_res, chis=4, atoms=4, xyz=3]. + chis_atom_pos = utils.batched_gather( + params=all_atom_pos, indices=atom_indices, axis=0, + batch_dims=3) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = paddle.to_tensor(chi_angles_mask) + + # Compute the chi angle mask. I.e. which chis angles exist according to the + # aatype. Shape [batch, num_temp, num_res, chis=4]. + chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0, batch_dims=0) + # Constrain the chis_mask to those chis, where the ground truth coordinates of + # all defining four atoms are available. + # Gather the chi angle atoms mask. Shape: [batch, num_temp, num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=all_atom_mask, indices=atom_indices, axis=0, + batch_dims=3) + # Check if all 4 chi angle atoms were set. Shape: [batch, num_temp, num_res, chis=4]. + chi_angle_atoms_mask = paddle.prod(chi_angle_atoms_mask, axis=[-1]) + chis_mask = chis_mask * chi_angle_atoms_mask + + # Stack all torsion angle atom positions. + # Shape (B, T, N, torsions=7, atoms=4, xyz=3) + torsions_atom_pos = paddle.concat( + [pre_omega_atom_pos[:, :, :, None, :, :], + phi_atom_pos[:, :, :, None, :, :], + psi_atom_pos[:, :, :, None, :, :], + chis_atom_pos + ], axis=3) + + # Stack up masks for all torsion angles. + # shape (B, T, N, torsions=7) + torsion_angles_mask = paddle.concat( + [pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask + ], axis=-1) + + # Create a frame from the first three atoms: + # First atom: point on x-y-plane + # Second atom: point on negative x-axis + # Third atom: origin + # r3.Rigids (B, T, N, torsions=7) + torsion_frames = r3.rigids_from_3_points_vecs( + point_on_neg_x_axis=r3.Vecs(torsions_atom_pos[..., 1, :]), + origin=r3.Vecs(torsions_atom_pos[..., 2, :]), + point_on_xy_plane=r3.Vecs(torsions_atom_pos[..., 0, :])) + + # Compute the position of the forth atom in this frame (y and z coordinate + # define the chi angle) + # r3.Vecs (B, T, N, torsions=7) + forth_atom_rel_pos = r3.rigids_mul_vecs( + r3.invert_rigids(torsion_frames), + r3.vecs_from_tensor(torsions_atom_pos[..., 3, :])) + + # Normalize to have the sin and cos of the torsion angle. + # paddle.Tensor (B, T, N, torsions=7, sincos=2) + torsion_angles_sin_cos = paddle.stack( + [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1) + torsion_angles_sin_cos /= paddle.sqrt( + paddle.sum(paddle.square(torsion_angles_sin_cos), axis=-1, keepdim=True) + + 1e-8) + + # Mirror psi, because we computed it from the Oxygen-atom. 
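The block above reads each torsion's sin/cos from the (z, y) coordinates of the fourth atom expressed in a frame built from the first three. For reference, a self-contained NumPy dihedral helper that yields the same normalized sin/cos up to sign convention (all names here are illustrative):

import numpy as np

def dihedral_sin_cos(p0, p1, p2, p3, eps=1e-8):
    # Standard dihedral formula over four atom positions.
    b0 = p1 - p0
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / (np.linalg.norm(b1) + eps)
    v = b0 - np.dot(b0, b1) * b1   # b0 projected onto the plane normal to b1
    w = b2 - np.dot(b2, b1) * b1   # b2 projected onto the same plane
    cos_t = np.dot(v, w)
    sin_t = np.dot(np.cross(b1, v), w)
    norm = np.sqrt(sin_t ** 2 + cos_t ** 2) + eps
    return sin_t / norm, cos_t / norm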
+ torsion_angles_sin_cos *= paddle.to_tensor( + [1., 1., -1., 1., 1., 1., 1.])[None, None, None, :, None] + + # Create alternative angles for ambiguous atom names. + chi_is_ambiguous = utils.batched_gather( + paddle.to_tensor(residue_constants.chi_pi_periodic), aatype) + # chi_is_ambiguous (B, T, N, torsions=4) + mirror_torsion_angles = paddle.concat( + [paddle.ones([num_batch, num_temp, num_res, 3]), + 1.0 - 2.0 * chi_is_ambiguous], axis=-1) + # mirror_torsion_angles (B, T, N, torsions=7) + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, :, None]) + + if placeholder_for_undefined: + # Add placeholder torsions in place of undefined torsion angles + # (e.g. N-terminus pre-omega) + placeholder_torsions = paddle.stack([ + paddle.ones(torsion_angles_sin_cos.shape[:-1]), + paddle.zeros(torsion_angles_sin_cos.shape[:-1]) + ], axis=-1) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + return { + 'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, T, N, 7, 2) + 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, T, N, 7, 2) + 'torsion_angles_mask': torsion_angles_mask # (B, T, N, 7) + } + + +def torsion_angles_to_frames( + aatype: paddle.Tensor, # (B, N) + backb_to_global: r3.Rigids, # (B, N) + torsion_angles_sin_cos: paddle.Tensor # (B, N, 7, 2) +) -> r3.Rigids: # (B, N, 8) + """Compute rigid group frames from torsion angles. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10 + Jumper et al. (2021) Suppl. Alg. 25 "makeRotX" + + Args: + aatype: aatype for each residue + backb_to_global: Rigid transformations describing transformation from + backbone frame to global frame. + torsion_angles_sin_cos: sin and cosine of the 7 torsion angles + Returns: + Frames corresponding to all the Sidechain Rigid Transforms + """ + assert len(aatype.shape) == 2 + assert len(backb_to_global.rot.xx.shape) == 2 + assert len(torsion_angles_sin_cos.shape) == 4 + assert torsion_angles_sin_cos.shape[2] == 7 + assert torsion_angles_sin_cos.shape[3] == 2 + + # Gather the default frames for all rigid groups.(New) + # # r3.Rigids with shape (B, N, 8) + restype_rigid_group_default_frame = paddle.to_tensor( + residue_constants.restype_rigid_group_default_frame) # (21, 8, 4, 4) + # (B, N, 8, 4, 4) + m = utils.batched_gather(restype_rigid_group_default_frame, aatype) + + default_frames = r3.rigids_from_tensor4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_batch, num_residues = aatype.shape + sin_angles = paddle.concat([paddle.zeros([num_batch, num_residues, 1]), sin_angles], + axis=-1) + cos_angles = paddle.concat([paddle.ones([num_batch, num_residues, 1]), cos_angles], + axis=-1) + zeros = paddle.zeros_like(sin_angles) + ones = paddle.ones_like(sin_angles) + + # all_rots are r3.Rots with shape (B, N, 8) + all_rots = r3.Rots(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. 
+ all_frames = r3.rigids_mul_rots(default_frames, all_rots) + + # slice, concat and unsqueeze Rigids + def slice_rigids(rigid, start, end): + """slice along the last axis of rot.xx and trans.x""" + assert len(rigid.rot.xx.shape) == 3 + rotation = rigid.rot.rotation[..., start:end, :, :] + translation = rigid.trans.translation[..., start:end, :] + return r3.Rigids(rot=r3.Rots(rotation), trans=r3.Vecs(translation)) + + def concat_rigids(*arg): + """concat along the last axis of rot.xx and trans.x""" + assert len(arg) > 1 + assert len(arg[0].rot.xx.shape) == len(arg[1].rot.xx.shape) + rotation = paddle.concat([r.rot.rotation for r in arg], axis=-3) + translation = paddle.concat([r.trans.translation for r in arg], axis=-2) + return r3.Rigids(rot=r3.Rots(rotation), trans=r3.Vecs(translation)) + + def unsqueeze_rigids(rigid, axis=-1): + """add an axis in the axis of rot.xx and trans.x""" + if axis < 0: + axis_t = axis - 1 + axis_r = axis - 2 + else: + axis_t = axis + axis_r = axis + + rotation = paddle.unsqueeze(rigid.rot.rotation, axis=axis_r) + translation = paddle.unsqueeze(rigid.trans.translation, axis=axis_t) + return r3.Rigids(rot=r3.Rots(rotation), trans=r3.Vecs(translation)) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. + + chi2_frame_to_frame = slice_rigids(all_frames, 5, 6) + chi3_frame_to_frame = slice_rigids(all_frames, 6, 7) + chi4_frame_to_frame = slice_rigids(all_frames, 7, 8) + + chi1_frame_to_backb = slice_rigids(all_frames, 4, 5) + chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb, + chi2_frame_to_frame) + chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb, + chi3_frame_to_frame) + chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb, + chi4_frame_to_frame) + + all_frames_to_backb = concat_rigids( + slice_rigids(all_frames, 0, 5), + chi2_frame_to_backb, + chi3_frame_to_backb, + chi4_frame_to_backb) + + # Create the global frames. + # shape (B, N, 8) + all_frames_to_global = r3.rigids_mul_rigids( + unsqueeze_rigids(backb_to_global), + all_frames_to_backb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: paddle.Tensor, # (B, N) + all_frames_to_global: r3.Rigids # (B, N, 8) +) -> r3.Vecs: # (B, N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11 + + Args: + aatype: aatype for each residue. + all_frames_to_global: All per residue coordinate frames. + Returns: + Positions of all atom coordinates in global frame. + """ + # Pick the appropriate transform for every atom. + restype_atom14_to_rigid_group = paddle.to_tensor( + residue_constants.restype_atom14_to_rigid_group)[None, ...] 
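The chi2/chi3/chi4 chaining above is plain composition of rigid transforms: each chi frame is expressed relative to the previous one and multiplied onto it. A toy NumPy composition with made-up translations, just to show the rule (R, t) = (Ra.Rb, Ra.tb + ta):

import numpy as np

def compose(rigid_a, rigid_b):
    # Compose two rigid transforms (R, t): apply rigid_b first, then rigid_a.
    rot_a, trans_a = rigid_a
    rot_b, trans_b = rigid_b
    return rot_a @ rot_b, rot_a @ trans_b + trans_a

chi1_to_backb = (np.eye(3), np.array([1., 0., 0.]))    # illustrative values
chi2_to_chi1 = (np.eye(3), np.array([0., 2., 0.]))
chi2_to_backb = compose(chi1_to_backb, chi2_to_chi1)   # translation (1., 2., 0.)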
+ + # [1, 21, 14] -> # [n_batch, 21, 14] + n_batch = aatype.shape[0] + if n_batch > 1: + restype_atom14_to_rigid_group = paddle.tile( + restype_atom14_to_rigid_group, repeat_times=[n_batch, 1, 1]) + + residx_to_group_idx = utils.batched_gather( + restype_atom14_to_rigid_group, + aatype, batch_dims=1) + + # 8 rigid groups: + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + # (B, N, 14, 8) + group_mask = nn.functional.one_hot( + residx_to_group_idx, num_classes=8) + + def _convert(x, y): + return paddle.sum(paddle.unsqueeze(x, -2) * y, axis=-1) + + # r3.Rigids with shape (B, N, 14) + map_atoms_to_global = r3.Rigids( + rot=all_frames_to_global.rot.map(_convert, group_mask), + trans=all_frames_to_global.trans.map(_convert, group_mask)) + + # Gather the literature atom positions for each residue. + # r3.Vecs with shape (B, N, 14) + restype_atom14_rigid_group_positions = paddle.to_tensor( + residue_constants.restype_atom14_rigid_group_positions)[None, ...] + # [1, 21, 14, 3] -> [B, 21, 14, 3] + if n_batch > 1: + restype_atom14_rigid_group_positions = paddle.tile( + restype_atom14_rigid_group_positions, repeat_times=[n_batch, 1, 1, 1]) + + lit_positions = r3.vecs_from_tensor( + utils.batched_gather( + restype_atom14_rigid_group_positions, + aatype, batch_dims=1)) + + # Transform each atom from its local frame to the global frame. + # r3.Vecs with shape (B, N, 14) + pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions) + + # Mask out non-existing atoms. + restype_atom14_mask = paddle.to_tensor( + residue_constants.restype_atom14_mask)[None, ...] + # [1, 21, 14] -> [B, 21, 14] + if n_batch > 1: + restype_atom14_mask = paddle.tile( + restype_atom14_mask, repeat_times=[n_batch, 1, 1]) + + mask = utils.batched_gather( + restype_atom14_mask, aatype, batch_dims=1) + pred_positions = pred_positions.map(lambda x, m: x * m, mask) + + return pred_positions + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: paddle.Tensor, # (B, N, 37(14), 3) + pred_atom_mask: paddle.Tensor, # (B, N, 37(14)) + residue_index: paddle.Tensor, # (B, N) + max_angstrom_tolerance=1.5 + ) -> paddle.Tensor: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. 
+ """ + batch_size = pred_atom_positions.shape[0] + this_ca_pos = pred_atom_positions[:, :-1, 1, :] # (B, N - 1, 3) + this_ca_mask = pred_atom_mask[:, :-1, 1] # (B, N - 1) + next_ca_pos = pred_atom_positions[:, 1:, 1, :] # (B, N - 1, 3) + next_ca_mask = pred_atom_mask[:, 1:, 1] # (B, N - 1) + has_no_gap_mask = ((residue_index[:, 1:] - residue_index[:, :-1]) == 1.0) + + ca_ca_distance = paddle.sqrt(1e-6 + paddle.sum(squared_difference(this_ca_pos, next_ca_pos), axis=-1)) + violations = (ca_ca_distance - residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + ca_ca_violation_tmp = [] + for i in range(batch_size): + ca_ca_violation_i = utils.mask_mean(mask=mask[i], value=violations[i]) + ca_ca_violation_tmp.append(ca_ca_violation_i) + ca_ca_violation = paddle.to_tensor(ca_ca_violation_tmp, stop_gradient=False) + ca_ca_violation = paddle.squeeze(ca_ca_violation, axis=-1) + return ca_ca_violation + + +def between_residue_bond_loss( + pred_atom_positions: paddle.Tensor, # (B, N, 37(14), 3) + pred_atom_mask: paddle.Tensor, # (B, N, 37(14)) + residue_index: paddle.Tensor, # (B, N) + aatype: paddle.Tensor, # (B, N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0 +) -> Dict[str, paddle.Tensor]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + batch_size = aatype.shape[0] + + assert len(pred_atom_positions.shape) == 4 + assert len(pred_atom_mask.shape) == 3 + assert len(residue_index.shape) == 2 + assert len(aatype.shape) == 2 + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:, :-1, 1, :] # (B, N - 1, 3) + this_ca_mask = pred_atom_mask[:, :-1, 1] # (B, N - 1) + this_c_pos = pred_atom_positions[:, :-1, 2, :] # (B, N - 1, 3) + this_c_mask = pred_atom_mask[:, :-1, 2] # (B, N - 1) + next_n_pos = pred_atom_positions[:, 1:, 0, :] # (B, N - 1, 3) + next_n_mask = pred_atom_mask[:, 1:, 0] # (B, N - 1) + next_ca_pos = pred_atom_positions[:, 1:, 1, :] # (B, N - 1, 3) + next_ca_mask = pred_atom_mask[:, 1:, 1] # (B, N - 1) + has_no_gap_mask = ((residue_index[:, 1:] - residue_index[:, :-1]) == 1.0) + + + # Compute loss for the C--N bond. + c_n_bond_length = paddle.sqrt(1e-6 + paddle.sum(squared_difference(this_c_pos, next_n_pos), axis=-1)) + + # The C-N bond to proline has slightly different length because of the ring. 
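The bond-length term computed below is a flat-bottom penalty: zero within tolerance_factor_soft standard deviations of the literature value and linear outside (Jumper et al. 2021, Suppl. eq. 44-45). A one-function NumPy sketch with illustrative numbers:

import numpy as np

def flat_bottom_penalty(value, gt_value, gt_stddev, tolerance_factor=12.0):
    # Zero inside gt_value +/- tolerance_factor * gt_stddev, linear outside.
    error = np.sqrt(1e-6 + np.square(value - gt_value))
    return np.maximum(error - tolerance_factor * gt_stddev, 0.0)

# A bond stretched 0.5 A beyond an (illustrative) target of 1.33 A with sigma 0.02 A:
print(flat_bottom_penalty(1.83, 1.33, 0.02))  # ~0.26, i.e. 0.5 - 12 * 0.02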
+ next_is_proline = paddle.cast((aatype[:, 1:] == residue_constants.resname_to_idx['PRO']), 'float32') + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. - next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = paddle.sqrt(1e-6 + paddle.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = nn.functional.relu(c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = paddle.sum(mask * c_n_loss_per_residue, axis=-1) / (paddle.sum(mask, axis=-1) + 1e-6) + c_n_violation_mask = mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + ca_c_bond_length = paddle.sqrt(1e-6 + paddle.sum( + squared_difference(this_ca_pos, this_c_pos), axis=-1)) + n_ca_bond_length = paddle.sqrt(1e-6 + paddle.sum( + squared_difference(next_n_pos, next_ca_pos), axis=-1)) + + ca_c_bond_length = paddle.unsqueeze(ca_c_bond_length, axis=-1) + c_n_bond_length = paddle.unsqueeze(c_n_bond_length, axis=-1) + n_ca_bond_length = paddle.unsqueeze(n_ca_bond_length, axis=-1) + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length + + ca_c_n_cos_angle = paddle.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = paddle.sqrt(1e-6 + paddle.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = nn.functional.relu(ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = paddle.sum(mask * ca_c_n_loss_per_residue, axis=-1) / (paddle.sum(mask, axis=-1) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = paddle.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = paddle.sqrt(1e-6 + paddle.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = nn.functional.relu(c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = paddle.sum(mask * c_n_ca_loss_per_residue, axis=-1) / (paddle.sum(mask, axis=-1) + 1e-6) + c_n_ca_violation_mask = mask * (c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + tmpsum = paddle.zeros(shape=[batch_size, 1]) + per_residue_loss_sum = (c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue) + tmp_per_residue_loss1 = paddle.concat(x=[per_residue_loss_sum, tmpsum], axis=-1) + tmp_per_residue_loss2 = paddle.concat(x=[tmpsum, per_residue_loss_sum], axis=-1) + per_residue_loss_sum = 0.5 * (tmp_per_residue_loss1 + tmp_per_residue_loss2) + + # Compute hard violations. 
+ violation_mask = paddle.max( + paddle.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + tmp_violation_mask1 = paddle.concat(x=[violation_mask, tmpsum], axis=-1) + tmp_violation_mask2 = paddle.concat(x=[tmpsum, violation_mask], axis=-1) + violation_mask = paddle.maximum(tmp_violation_mask1, tmp_violation_mask2) + + return {'c_n_loss_mean': c_n_loss, # shape (B) + 'ca_c_n_loss_mean': ca_c_n_loss, # shape (B) + 'c_n_ca_loss_mean': c_n_ca_loss, # shape (B) + 'per_residue_loss_sum': per_residue_loss_sum, # shape (B, N) + 'per_residue_violation_mask': violation_mask # shape (B, N) + } + + +def between_residue_clash_loss( + atom14_pred_positions: paddle.Tensor, # (B, N, 14, 3) + atom14_atom_exists: paddle.Tensor, # (B, N, 14) + atom14_atom_radius: paddle.Tensor, # (B, N, 14) + residue_index: paddle.Tensor, # (B, N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5 +) -> Dict[str, paddle.Tensor]: + """Loss to penalize steric clashes between residues. + + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (B, N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (B, N, 14) + """ + assert len(atom14_pred_positions.shape) == 4 + assert len(atom14_atom_exists.shape) == 3 + assert len(atom14_atom_radius.shape) == 3 + assert len(residue_index.shape) == 2 + + # Create the distance matrix. + # (B, N, N, 14, 14) + atom14_pred_positions1 = paddle.unsqueeze(atom14_pred_positions, axis=[2,4]) + atom14_pred_positions2 = paddle.unsqueeze(atom14_pred_positions, axis=[1,3]) + dists = paddle.sqrt(1e-10 + paddle.sum(squared_difference(atom14_pred_positions1, atom14_pred_positions2), axis=-1)) + + # Create the mask for valid distances. + # shape (B, N, N, 14, 14) + atom14_atom_exists1 = paddle.unsqueeze(atom14_atom_exists, axis=[2,4]) + atom14_atom_exists2 = paddle.unsqueeze(atom14_atom_exists, axis=[1,3]) + dists_mask = (atom14_atom_exists1 * atom14_atom_exists2) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + residue_index1 = paddle.unsqueeze(residue_index, axis=[2,3,4]) + residue_index2 = paddle.unsqueeze(residue_index, axis=[1,3,4]) + dists_mask *= (residue_index1 < residue_index2) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = nn.functional.one_hot(paddle.to_tensor([2]), num_classes=14) + n_one_hot = nn.functional.one_hot(paddle.to_tensor([0]), num_classes=14) + neighbour_mask = ((residue_index1 + 1) == residue_index2) + tmp_c_one_hot = paddle.unsqueeze(c_one_hot, axis=[1,2,4]) + tmp_n_one_hot = paddle.unsqueeze(n_one_hot, axis=[1,2,3]) + c_n_bonds = neighbour_mask * tmp_c_one_hot * tmp_n_one_hot + + dists_mask *= (1. 
- c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = nn.functional.one_hot(paddle.to_tensor(cys_sg_idx), num_classes=14) + cys_sg_one_hot1 = paddle.unsqueeze(cys_sg_one_hot, axis=[1,2,4]) + cys_sg_one_hot2 = paddle.unsqueeze(cys_sg_one_hot, axis=[1,2,3]) + disulfide_bonds = (cys_sg_one_hot1 * cys_sg_one_hot2) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (B, N, N, 14, 14) + atom14_atom_radius1 = paddle.unsqueeze(atom14_atom_radius, axis=[2,4]) + atom14_atom_radius2 = paddle.unsqueeze(atom14_atom_radius, axis=[1,3]) + dists_lower_bound = dists_mask * (atom14_atom_radius1 + atom14_atom_radius2) + + # Compute the error. + # shape (B, N, N, 14, 14) + dists_to_low_error = dists_mask * nn.functional.relu(dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape (B) + mean_loss = (paddle.sum(dists_to_low_error, axis=[1,2,3,4]) / (1e-6 + paddle.sum(dists_mask, axis=[1,2,3,4]))) + + # Compute the per atom loss sum. + # shape (B, N, 14) + per_atom_loss_sum = (paddle.sum(dists_to_low_error, axis=[1, 3]) + + paddle.sum(dists_to_low_error, axis=[2, 4])) + + # Compute the hard clash mask. + # shape (B, N, N, 14, 14) + clash_mask = dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. + # shape (B, N, 14) + per_atom_clash_mask = paddle.maximum( + paddle.max(clash_mask, axis=[1, 3]), + paddle.max(clash_mask, axis=[2, 4])) + + return {'mean_loss': mean_loss, # shape (B) + 'per_atom_loss_sum': per_atom_loss_sum, # shape (B, N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (B, N, 14) + } + + +def within_residue_violations( + atom14_pred_positions: paddle.Tensor, # (B, N, 14, 3) + atom14_atom_exists: paddle.Tensor, # (B, N, 14) + atom14_dists_lower_bound: paddle.Tensor, # (B, N, 14, 14) + atom14_dists_upper_bound: paddle.Tensor, # (B, N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[str, paddle.Tensor]: + """Loss to penalize steric clashes within residues. + + This is a loss penalizing any steric violations or clashes of non-bonded atoms + in a given peptide. This loss corresponds to the part with + the same residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_dists_lower_bound: Lower bound on allowed distances. + atom14_dists_upper_bound: Upper bound on allowed distances + tighten_bounds_for_loss: Extra factor to tighten loss + + Returns: + Dict containing: + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (B, N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (B, N, 14) + """ + assert len(atom14_pred_positions.shape) == 4 + assert len(atom14_atom_exists.shape) == 3 + assert len(atom14_dists_lower_bound.shape) == 4 + assert len(atom14_dists_upper_bound.shape) == 4 + + # Compute the mask for each residue. + # shape (B, N, 14, 14) + dists_masks = (1. 
- paddle.unsqueeze(paddle.eye(14, 14), axis=[0, 1])) + atom14_atom_exists1 = paddle.unsqueeze(atom14_atom_exists, axis=-1) + atom14_atom_exists2 = paddle.unsqueeze(atom14_atom_exists, axis=-2) + dists_masks *= (atom14_atom_exists1 * atom14_atom_exists2) + + # Distance matrix + # shape (B, N, 14, 14) + atom14_pred_positions1 = paddle.unsqueeze(atom14_pred_positions, axis=-2) + atom14_pred_positions2 = paddle.unsqueeze(atom14_pred_positions, axis=-3) + dists = paddle.sqrt(1e-10 + paddle.sum( + squared_difference(atom14_pred_positions1, atom14_pred_positions2), + axis=-1)) + + # Compute the loss. + # shape (B, N, 14, 14) + dists_to_low_error = nn.functional.relu( + atom14_dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = nn.functional.relu( + dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (B, N, 14) + per_atom_loss_sum = (paddle.sum(loss, axis=2) + paddle.sum(loss, axis=3)) + + # Compute the violations mask. + # shape (B, N, 14, 14) + violations = dists_masks * ((dists < atom14_dists_lower_bound) | + (dists > atom14_dists_upper_bound)) + + # Compute the per atom violations. + # shape (B, N, 14) + per_atom_violations = paddle.maximum(paddle.max(violations, axis=2), paddle.max(violations, axis=3)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (B, N, 14) + 'per_atom_violations': per_atom_violations # shape (B, N, 14) + } + + +def find_optimal_renaming( + atom14_gt_positions: paddle.Tensor, # (B, N, 14, 3) + atom14_alt_gt_positions: paddle.Tensor, # (B, N, 14, 3) + atom14_atom_is_ambiguous: paddle.Tensor, # (B, N, 14) + atom14_gt_exists: paddle.Tensor, # (B, N, 14) + atom14_pred_positions: paddle.Tensor, # (B, N, 14, 3) + atom14_atom_exists: paddle.Tensor, # (B, N, 14) +) -> paddle.Tensor: # (B, N): + """Find optimal renaming for ground truth that maximizes LDDT. + + Jumper et al. (2021) Suppl. Alg. 26 + "renameSymmetricGroundTruthAtoms" lines 1-5 + + Args: + atom14_gt_positions: Ground truth positions in global frame of ground truth. + atom14_alt_gt_positions: Alternate ground truth positions in global frame of + ground truth with coordinates of ambiguous atoms swapped relative to + 'atom14_gt_positions'. + atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous + atoms, see Jumper et al. (2021) Suppl. Table 3 + atom14_gt_exists: Mask denoting whether atom at positions exists in ground + truth. + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + + Returns: + Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to + prediction and 0. otherwise + """ + assert len(atom14_gt_positions.shape) == 4 + assert len(atom14_alt_gt_positions.shape) == 4 + assert len(atom14_atom_is_ambiguous.shape) == 3 + assert len(atom14_gt_exists.shape) == 3 + assert len(atom14_pred_positions.shape) == 4 + assert len(atom14_atom_exists.shape) == 3 + + # Create the pred distance matrix. + # shape (B, N, N, 14, 14) + atom14_pred_positions1 = paddle.unsqueeze(atom14_pred_positions, axis=[2,4]) + atom14_pred_positions2 = paddle.unsqueeze(atom14_pred_positions, axis=[1,3]) + pred_dists = paddle.sqrt(1e-10 + paddle.sum( + squared_difference(atom14_pred_positions1, atom14_pred_positions2), + axis=-1)) + + # Compute distances for ground truth with original and alternative names. 
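Both the clash loss and the renaming logic here build (B, N, N, 14, 14) distance tensors by unsqueezing two copies of a (B, N, 14, 3) array on complementary axes and letting broadcasting pair every atom with every other atom. A small NumPy equivalent of that broadcast, with random data just to show the shapes:

import numpy as np

B, N = 1, 3
pos = np.random.rand(B, N, 14, 3)

# (B, N, 1, 14, 1, 3) minus (B, 1, N, 1, 14, 3) broadcasts to all pairs at once.
diff = pos[:, :, None, :, None, :] - pos[:, None, :, None, :, :]
dists = np.sqrt(1e-10 + np.sum(np.square(diff), axis=-1))
print(dists.shape)  # (1, 3, 3, 14, 14)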
+ # shape (B, N, N, 14, 14) + atom14_gt_positions1 = paddle.unsqueeze(atom14_gt_positions, axis=[2,4]) + atom14_gt_positions2 = paddle.unsqueeze(atom14_gt_positions, axis=[1,3]) + gt_dists = paddle.sqrt(1e-10 + paddle.sum( + squared_difference(atom14_gt_positions1, atom14_gt_positions2), + axis=-1)) + + atom14_alt_gt_positions1 = paddle.unsqueeze(atom14_alt_gt_positions, axis=[2,4]) + atom14_alt_gt_positions2 = paddle.unsqueeze(atom14_alt_gt_positions, axis=[1,3]) + alt_gt_dists = paddle.sqrt(1e-10 + paddle.sum( + squared_difference(atom14_alt_gt_positions1, atom14_alt_gt_positions2), + axis=-1)) + + # Compute LDDT's. + # shape (B, N, N, 14, 14) + lddt = paddle.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = paddle.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (B, N, N, 14, 14) + atom14_gt_exists1 = paddle.unsqueeze(atom14_gt_exists, axis=[2,4]) + atom14_gt_exists2 = paddle.unsqueeze(atom14_gt_exists, axis=[1,3]) + atom14_atom_is_ambiguous1 = paddle.unsqueeze(atom14_atom_is_ambiguous, axis=[2,4]) + atom14_atom_is_ambiguous2 = paddle.unsqueeze(atom14_atom_is_ambiguous, axis=[1,3]) + mask = (atom14_gt_exists1 * # rows + atom14_atom_is_ambiguous1 * # rows + atom14_gt_exists2 * # cols + (1. - atom14_atom_is_ambiguous2)) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. + # shape (B, N) + per_res_lddt = paddle.sum(mask * lddt, axis=[2, 3, 4]) + alt_per_res_lddt = paddle.sum(mask * alt_lddt, axis=[2, 3, 4]) + + # Decide for each residue, whether alternative naming is better. + # shape (B, N) + alt_naming_is_better = paddle.cast((alt_per_res_lddt < per_res_lddt), 'float32') + + return alt_naming_is_better # shape (B, N) + + +def frame_aligned_point_error( + pred_frames: r3.Rigids, + target_frames: r3.Rigids, + frames_mask: paddle.Tensor, + pred_positions: r3.Vecs, + target_positions: r3.Vecs, + positions_mask: paddle.Tensor, + length_scale: float, + l1_clamp_distance: Optional[float] = None, + epsilon=1e-4) -> paddle.Tensor: + """Measure point error under different alignments. + + Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE" + + Computes error between two structures with B points under A alignments derived + from the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. + target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + length_scale: length scale to divide loss by. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame Aligned Point Error. 
+ """ + def unsqueeze_rigids(rigid, axis=-1): + """add an axis in the axis of rot.xx and trans.x""" + if axis < 0: + axis_t = axis - 1 + axis_r = axis - 2 + else: + axis_t = axis + axis_r = axis + + rotation = paddle.unsqueeze(rigid.rot.rotation, axis=axis_r) + translation = paddle.unsqueeze(rigid.trans.translation, axis=axis_t) + return r3.Rigids(rot=r3.Rots(rotation), trans=r3.Vecs(translation)) + + def unsqueeze_vecs(vecs, axis=-1): + """add an axis in the axis of rot.xx and trans.x""" + if axis < 0: + axis_t = axis - 1 + else: + axis_t = axis + + translation = paddle.unsqueeze(vecs.translation, axis=axis_t) + return r3.Vecs(translation) + + # Compute array of predicted positions in the predicted frames. + # r3.Vecs (num_frames, num_positions) + local_pred_pos = r3.rigids_mul_vecs( + unsqueeze_rigids(r3.invert_rigids(pred_frames)), + unsqueeze_vecs(pred_positions, axis=1)) + + # Compute array of target positions in the target frames. + # r3.Vecs (num_frames, num_positions) + local_target_pos = r3.rigids_mul_vecs( + unsqueeze_rigids(r3.invert_rigids(target_frames)), + unsqueeze_vecs(target_positions, axis=1)) + + # Compute errors between the structures. + # paddle.Tensor (num_frames, num_positions) + error_dist = paddle.sqrt(r3.vecs_squared_distance(local_pred_pos, local_target_pos) + epsilon) + + if l1_clamp_distance: + error_dist = paddle.clip(error_dist, min=0, max=l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error *= paddle.unsqueeze(frames_mask, axis=-1) + normed_error *= paddle.unsqueeze(positions_mask, axis=-2) + + normalization_factor = ( + paddle.sum(frames_mask, axis=-1) * + paddle.sum(positions_mask, axis=-1)) + return (paddle.sum(normed_error, axis=[-2, -1]) / + (epsilon + normalization_factor)) diff --git a/apps/protein_folding/helixfold_cpu/tools/dap.py b/apps/protein_folding/helixfold_cpu/tools/dap.py new file mode 100644 index 00000000..9b6a2601 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/dap.py @@ -0,0 +1,507 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Dynamic Axial Parallelism and Duality Async Operation helper functions +paper ref: FastFold: Reducing AlphaFold Training Time from 11 Days to 67 Hours, https://arxiv.org/abs/2203.00854 +code ref: https://github.com/hpcaitech/FastFold.git +""" + +import warnings +import paddle +from paddle import distributed as dist +from paddle.autograd import PyLayer + +__all__ = [ + 'init_dap', + 'dap_is_initialized', + 'get_tensor_model_parallel_group', + 'get_data_parallel_group', + 'get_tensor_model_parallel_world_size', + 'get_tensor_model_parallel_rank', + 'get_data_parallel_world_size', + 'get_data_parallel_rank', + 'get_tensor_model_parallel_src_rank', + 'scatter', + 'gather', + 'all_gather', + 'all_gather_opp', + 'all_to_all', + 'all_to_all_opp', + 'row_to_col', + 'col_to_row' + ] + +# Data parallel group that the current rank belongs to. 
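For the DAP helpers defined below in this module, 'scatter' keeps only the local shard of a tensor along one axis and 'gather' concatenates the shards back. A single-process NumPy sketch of those semantics (world_size, rank and the tensor shape are illustrative):

import numpy as np

world_size, rank = 4, 1                       # illustrative DAP group
msa_act = np.random.rand(1, 128, 206, 64)     # (N, S, R, C)

shards = np.split(msa_act, world_size, axis=1)    # scatter: split along one axis...
local = shards[rank]                              # ...and keep only this rank's shard
restored = np.concatenate(shards, axis=1)         # gather: concatenate the shards back
assert restored.shape == msa_act.shape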
+_DATA_PARALLEL_GROUP = None +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None + +# These values enable us to change the mpu sizes on the fly. +_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_TENSOR_MODEL_PARALLEL_RANK = None + +# communication whether use_calc_stream (sync) or not (async). Default True +_COMM_SYNC = None + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator) + + +def divide(numerator, denominator): + ensure_divisibility(numerator, denominator) + return numerator // denominator + +def init_dap(tensor_model_parallel_size_=1, sync=True): + + global _COMM_SYNC + assert _COMM_SYNC is None, \ + 'communication manner `sync` is already initialized' + _COMM_SYNC = sync + + world_size = dist.get_world_size() + rank = dist.get_rank() + + # check dist config + ensure_divisibility(world_size, tensor_model_parallel_size_) + data_parallel_size_ = world_size // tensor_model_parallel_size_ + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + for i in range(tensor_model_parallel_size_): + ranks = list(range(i, world_size, tensor_model_parallel_size_)) + group = dist.new_group(ranks) + print('> dp ranks:', ranks, 'dp group:', group) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ + 'tensor model parallel group is already initialized' + # Build the model-parallel groups. + for i in range(data_parallel_size_): + ranks = list(range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)) + group = dist.new_group(ranks) + print('> mp ranks:', ranks, 'mp group', group) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + if dist.get_rank() == 0: + print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_)) + print('> initialize data parallel with size {}'.format(data_parallel_size_)) + +def dap_is_initialized(): + """Check if model and data parallel groups are initialized.""" + global _DATA_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GROUP + if _TENSOR_MODEL_PARALLEL_GROUP is None or \ + _DATA_PARALLEL_GROUP is None: + return False + return True + +def is_sync(): + return _COMM_SYNC + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not initialized.") + return 1 + global _TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _TENSOR_MODEL_PARALLEL_WORLD_SIZE + return get_tensor_model_parallel_group().nranks + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + if not dap_is_initialized(): + warnings.warn("DAP comminication group is not 
initialized.")
+ return 0
+ global _TENSOR_MODEL_PARALLEL_RANK
+ if _TENSOR_MODEL_PARALLEL_RANK is not None:
+ return _TENSOR_MODEL_PARALLEL_RANK
+ return get_tensor_model_parallel_group().rank
+
+
+def get_data_parallel_world_size():
+ """Return world size for the data parallel group."""
+ if not dap_is_initialized():
+ warnings.warn("DAP communication group is not initialized.")
+ return 1
+ return get_data_parallel_group().nranks
+
+
+def get_data_parallel_rank():
+ """Return my rank for the data parallel group."""
+ if not dap_is_initialized():
+ warnings.warn("DAP communication group is not initialized.")
+ return 0
+ return get_data_parallel_group().rank
+
+
+def get_tensor_model_parallel_src_rank():
+ """Calculate the global rank corresponding to the first local rank
+ in the tensor model parallel group."""
+ global_rank = dist.get_rank()
+ local_world_size = get_tensor_model_parallel_world_size()
+ return (global_rank // local_world_size) * local_world_size
+
+
+@paddle.no_grad()
+def _gather(tensor, axis=-1):
+ tensor_list = []
+ dist.all_gather(tensor_list,
+ tensor,
+ group=get_tensor_model_parallel_group(),
+ use_calc_stream=True)
+ output = paddle.concat(tensor_list, axis=axis)
+ return output
+
+
+@paddle.no_grad()
+def _split(tensor, axis=-1):
+ ensure_divisibility(tensor.shape[axis], get_tensor_model_parallel_world_size())
+ tensor_list = paddle.split(tensor, get_tensor_model_parallel_world_size(), axis=axis)
+
+ output = tensor_list[get_tensor_model_parallel_rank()]
+
+ return output
+
+
+class Scatter(PyLayer):
+ """ Scatter PyLayer Op"""
+ @staticmethod
+ def forward(ctx, input, axis=-1):
+ ctx.axis = axis
+ return _split(input, axis=axis)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather(grad_output, axis=ctx.axis)
+
+
+def scatter(input, axis=-1):
+ """ split a tensor along the given axis across the dap group """
+ if get_tensor_model_parallel_world_size() == 1:
+ return input
+
+ if not input.stop_gradient:
+ output = Scatter.apply(input, axis=axis)
+ else:
+ output = _split(input, axis=axis)
+ return output
+
+
+class Gather(PyLayer):
+ """ Gather PyLayer Op """
+ @staticmethod
+ def forward(ctx, input, axis=-1):
+ ctx.axis = axis
+ return _gather(input, axis=axis)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output, axis=ctx.axis)
+
+def gather(input, axis=-1):
+ """ gather tensors from all ranks in the dap group along the given axis """
+ if get_tensor_model_parallel_world_size() == 1:
+ return input
+
+ if not input.stop_gradient:
+ output = Gather.apply(input, axis=axis)
+ else:
+ output = _gather(input, axis=axis)
+ return output
+
+@paddle.no_grad()
+def _all_gather(tensor, axis=-1, sync=True):
+ if not sync:
+ dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True)
+
+ group = get_tensor_model_parallel_group()
+ ring_id = group.id
+ nranks = group.nranks
+ output = paddle._C_ops.c_allgather(tensor, 'use_calc_stream', sync, 'ring_id', ring_id, 'nranks', nranks)
+
+ return output
+
+@paddle.no_grad()
+def _reduce_scatter(tensor, sync=True):
+ if not sync:
+ dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True)
+
+ group = get_tensor_model_parallel_group()
+ ring_id = group.id
+ nranks = group.nranks
+ output = paddle._C_ops.c_reducescatter(tensor, 'use_calc_stream', sync, 'ring_id', ring_id, 'nranks', nranks)
+ paddle.device.cuda.synchronize()
+ return output
+
+class AllGather(PyLayer):
+ """ AllGather PyLayer Op """
+ @staticmethod
+ def forward(ctx, input, axis=-1, sync=True):
+ ctx.axis = axis
+ ctx.sync = sync + output = _all_gather(input, axis=axis, sync=sync) + return output + + @staticmethod + def backward(ctx, grad_output): + if not ctx.sync: + dist.wait(grad_output, group=get_tensor_model_parallel_group(), use_calc_stream=ctx.sync) + return grad_output + +class AllGather_Opp(PyLayer): + """ Duality Async Operation for AllGather """ + @staticmethod + def forward(ctx, input, axis=-1, sync=True): + ctx.axis = axis + ctx.sync = sync + return input + + @staticmethod + def backward(ctx, grad_output): + output = _reduce_scatter(grad_output, sync=ctx.sync) + return output + + +def all_gather(input, axis=-1, sync=None): + """ gather tensors from all rank in dap group and all get the result. + if sync=None, sync will be assign according init_dap setting. + + when using async communication, sync=False, do not use the output as same as input. + E.g. do not use `a = all_gather(a, ...)`, recommend to use `b = all_gather(a, ...)` + """ + if get_tensor_model_parallel_world_size() == 1: + return input + + if sync is None: + sync = is_sync() + + if not input.stop_gradient: + output = AllGather.apply(input, axis, sync=sync) + else: + output = _all_gather(input, axis, sync=sync) + return output + + +def all_gather_opp(output, axis=-1, sync=None): + """ Duality Async Operation for all_gather. + if sync=None, sync will be assign according init_dap setting. + """ + nranks = get_tensor_model_parallel_world_size() + if nranks == 1: + return output + + if sync is None: + sync = is_sync() + + if not sync: + dist.wait(output, group=get_tensor_model_parallel_group(), use_calc_stream=sync) + + if not output.stop_gradient: + output = AllGather_Opp.apply(output, axis, sync=sync) + + if axis != 0: + output = paddle.concat(paddle.split(output, nranks, 0), axis=axis) + + return output + + +@paddle.no_grad() +def _all_to_all(tensor, in_axis=-1, out_axis=-1, sync=True): + if not sync: + dist.wait(tensor, group=get_tensor_model_parallel_group(), use_calc_stream=True) + + group = get_tensor_model_parallel_group() + ring_id = group.id + + output = paddle._C_ops.alltoall(tensor, 'use_calc_stream', sync, 'ring_id', ring_id) + + return output + + +class All_to_All(PyLayer): + """ All_to_All PyLayer Op""" + @staticmethod + def forward(ctx, + input, + in_axis=-1, + out_axis=-1, + sync=True): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + ctx.sync = sync + return _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=sync) + + @staticmethod + def backward(ctx, grad_output): + if not ctx.sync: + dist.wait(grad_output, group=get_tensor_model_parallel_group(), use_calc_stream=ctx.sync) + return grad_output + + +class All_to_All_Opp(PyLayer): + """ Duality Async Operation for All_to_All """ + @staticmethod + def forward(ctx, output, in_axis=-1, out_axis=-1, sync=True): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + ctx.sync = sync + return output + + @staticmethod + def backward(ctx, grad_output): + return _all_to_all(grad_output, in_axis=ctx.out_axis, out_axis=ctx.in_axis, sync=ctx.sync) + + +class All2All(PyLayer): + @staticmethod + def forward(ctx, + input, + in_axis=-1, + out_axis=-1): + ctx.in_axis = in_axis + ctx.out_axis = out_axis + return _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=True) + + @staticmethod + def backward(ctx, grad_output): + return _all_to_all(grad_output, in_axis=ctx.out_axis, out_axis=ctx.in_axis, sync=True) + + +def all_to_all(input, in_axis, out_axis, sync=True): + """ all to all according in_axis and out_axis. 
+    If sync is None, sync is assigned according to the init_dap setting.
+    """
+    if get_tensor_model_parallel_world_size() == 1:
+        return input
+
+    if sync is None:
+        sync = is_sync()
+
+    if in_axis != 0:
+        ensure_divisibility(input.shape[in_axis], get_tensor_model_parallel_world_size())
+        input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=in_axis), axis=0)
+
+    if not input.stop_gradient:
+        output = All_to_All.apply(input, in_axis=in_axis, out_axis=out_axis, sync=sync)
+    else:
+        output = _all_to_all(input, in_axis=in_axis, out_axis=out_axis, sync=sync)
+
+    return output
+
+
+def all_to_all_opp(output, in_axis, out_axis, sync=True):
+    """ Duality Async Operation for all_to_all.
+    If sync is None, sync is assigned according to the init_dap setting.
+    """
+    if get_tensor_model_parallel_world_size() == 1:
+        return output
+
+    if sync is None:
+        sync = is_sync()
+
+    if not sync:
+        dist.wait(output, group=get_tensor_model_parallel_group(), use_calc_stream=sync)
+
+    if not output.stop_gradient:
+        output = All_to_All_Opp.apply(output, in_axis=in_axis, out_axis=out_axis, sync=sync)
+
+    if out_axis != 0:
+        ensure_divisibility(output.shape[0], get_tensor_model_parallel_world_size())
+        output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=out_axis)
+
+    return output
+
+
+def row_to_col(input):
+    """ N, S, R, C => N, R, S, C using sync all_to_all """
+    if get_tensor_model_parallel_world_size() == 1:
+        return input
+
+    ensure_divisibility(input.shape[2], get_tensor_model_parallel_world_size())
+    input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=2), axis=0)
+
+    if not input.stop_gradient:
+        output = All2All.apply(input, in_axis=2, out_axis=1)
+    else:
+        output = _all_to_all(input, in_axis=2, out_axis=1)
+
+    output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=1)
+    return output
+
+
+def col_to_row(input):
+    """ N, R, S, C => N, S, R, C using sync all_to_all """
+    if get_tensor_model_parallel_world_size() == 1:
+        return input
+
+    ensure_divisibility(input.shape[1], get_tensor_model_parallel_world_size())
+    input = paddle.concat(paddle.split(input, get_tensor_model_parallel_world_size(), axis=1), axis=0)
+
+    if not input.stop_gradient:
+        output = All2All.apply(input, in_axis=1, out_axis=2)
+    else:
+        output = _all_to_all(input, in_axis=1, out_axis=2)
+
+    output = paddle.concat(paddle.split(output, get_tensor_model_parallel_world_size(), axis=0), axis=2)
+    return output
+
+
+@paddle.no_grad()
+def grad_sync(param_groups, comm_group):
+    """
+    Synchronize (all-reduce) the gradients of parameters in DAP parameter groups
+    across `comm_group`.
+    """
+
+    nranks = comm_group.nranks
+
+    if nranks < 2:
+        return
+
+    for group in param_groups:
+        if group.get("dap", False):
+            for p in group['params']:
+                if p.is_distributed:
+                    continue
+
+                grad = p.grad
+                if grad is None:
+                    continue
+
+                paddle.distributed.all_reduce(
+                    grad, use_calc_stream=True, group=comm_group)
+
+    return None
diff --git a/apps/protein_folding/helixfold_cpu/tools/data_transforms.py b/apps/protein_folding/helixfold_cpu/tools/data_transforms.py
new file mode 100644
index 00000000..4dec980d
--- /dev/null
+++ b/apps/protein_folding/helixfold_cpu/tools/data_transforms.py
@@ -0,0 +1,624 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data for AlphaFold.""" + +from tools import residue_constants +from tools import shape_helpers +from tools import shape_placeholders +from tools import input_utils +import numpy as np +import tensorflow.compat.v1 as tf + +# Pylint gets confused by the curry1 decorator because it changes the number +# of arguments to the function. +# pylint:disable=no-value-for-parameter + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def cast_64bit_ints(protein): + + for k, v in protein.items(): + if v.dtype == tf.int64: + protein[k] = tf.cast(v, tf.int32) + return protein + + +_MSA_FEATURE_NAMES = [ + 'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', + 'true_msa' +] + + +def make_seq_mask(protein): + protein['seq_mask'] = tf.ones( + shape_helpers.shape_list(protein['aatype']), dtype=tf.float32) + return protein + + +def make_template_mask(protein): + protein['template_mask'] = tf.ones( + shape_helpers.shape_list(protein['template_domain_names']), + dtype=tf.float32) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +@curry1 +def add_distillation_flag(protein, distillation): + protein['is_distillation'] = tf.constant(float(distillation), + shape=[], + dtype=tf.float32) + return protein + + +def make_all_atom_aatype(protein): + protein['all_atom_aatype'] = protein['aatype'] + return protein + + +def fix_templates_aatype(protein): + """Fixes aatype encoding of templates.""" + # Map one-hot to indices. + protein['template_aatype'] = tf.argmax( + protein['template_aatype'], output_type=tf.int32, axis=-1) + # Map hhsearch-aatype to our aatype. + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = tf.constant(new_order_list, dtype=tf.int32) + protein['template_aatype'] = tf.gather(params=new_order, + indices=protein['template_aatype']) + return protein + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as residue_constants.""" + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype) + protein['msa'] = tf.gather(new_order, protein['msa'], axis=0) + + perm_matrix = np.zeros((22, 22), dtype=np.float32) + perm_matrix[range(len(new_order_list)), new_order_list] = 1. 
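+  # perm_matrix is a one-hot permutation matrix; applying it below reorders the
+  # profile features from the HHblits alphabet into the residue_constants ordering.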
+ + for k in protein: + if 'profile' in k: # Include both hhblits and psiblast profiles + num_dim = protein[k].shape.as_list()[-1] + assert num_dim in [20, 21, 22], ( + 'num_dim for %s out of expected range: %s' % (k, num_dim)) + protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1) + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + protein['aatype'] = tf.argmax( + protein['aatype'], axis=-1, output_type=tf.int32) + for k in [ + 'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence', + 'superfamily', 'deletion_matrix', 'resolution', + 'between_segment_residues', 'residue_index', 'template_all_atom_masks']: + if k in protein: + final_dim = shape_helpers.shape_list(protein[k])[-1] + if isinstance(final_dim, int) and final_dim == 1: + protein[k] = tf.squeeze(protein[k], axis=-1) + + for k in ['seq_length', 'num_alignments']: + if k in protein: + protein[k] = protein[k][0] # Remove fake sequence dimension + return protein + + +def make_random_crop_to_size_seed(protein): + """Random seed for cropping residues and templates.""" + protein['random_crop_to_size_seed'] = utils.make_random_seed() + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a proportion of the MSA with 'X'.""" + msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) < + replace_proportion) + x_idx = 20 + gap_idx = 21 + msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx) + protein['msa'] = tf.where(msa_mask, + tf.ones_like(protein['msa']) * x_idx, + protein['msa']) + aatype_mask = ( + tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) < + replace_proportion) + + protein['aatype'] = tf.where(aatype_mask, + tf.ones_like(protein['aatype']) * x_idx, + protein['aatype']) + return protein + + +@curry1 +def sample_msa(protein, max_seq, keep_extra): + """Sample MSA randomly, remaining sequences are stored as `extra_*`. + + Args: + protein: batch to sample msa from. + max_seq: number of sequences to sample. + keep_extra: When True sequences not sampled are put into fields starting + with 'extra_*'. + + Returns: + Protein with sampled msa. + """ + num_seq = tf.shape(protein['msa'])[0] + shuffled = tf.random_shuffle(tf.range(1, num_seq)) + index_order = tf.concat([[0], shuffled], axis=0) + num_sel = tf.minimum(max_seq, num_seq) + + sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel]) + + for k in _MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein['extra_' + k] = tf.gather(protein[k], not_sel_seq) + protein[k] = tf.gather(protein[k], sel_seq) + + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + """MSA features are cropped so only `max_extra_msa` sequences are kept.""" + num_seq = tf.shape(protein['extra_msa'])[0] + num_sel = tf.minimum(max_extra_msa, num_seq) + select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel] + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices) + + return protein + + +def delete_extra_msa(protein): + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + del protein['extra_' + k] + return protein + + +@curry1 +def block_delete_msa(protein, config): + """Sample MSA by deleting contiguous blocks. + + Jumper et al. (2021) Suppl. Alg. 
1 "MSABlockDeletion" + + Arguments: + protein: batch dict containing the msa + config: ConfigDict with parameters + + Returns: + updated protein + """ + num_seq = shape_helpers.shape_list(protein['msa'])[0] + block_num_seq = tf.cast( + tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block), + tf.int32) + + if config.randomize_num_blocks: + nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32) + else: + nb = config.num_blocks + + del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32) + del_blocks = del_block_starts[:, None] + tf.range(block_num_seq) + del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1) + del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0] + + # Make sure we keep the original sequence + sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None], + del_indices[None]) + keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0) + keep_indices = tf.concat([[0], keep_indices], axis=0) + + for k in _MSA_FEATURE_NAMES: + if k in protein: + protein[k] = tf.gather(protein[k], keep_indices) + + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask + weights = tf.concat([ + tf.ones(21), + gap_agreement_weight * tf.ones(1), + np.zeros(1)], 0) + + # Make agreement score as weighted Hamming distance + sample_one_hot = (protein['msa_mask'][:, :, None] * + tf.one_hot(protein['msa'], 23)) + extra_one_hot = (protein['extra_msa_mask'][:, :, None] * + tf.one_hot(protein['extra_msa'], 23)) + + num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot) + extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot) + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. 
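+  # Flattening the one-hot MSAs to (num_seq, num_res * 23) turns the weighted
+  # agreement einsum above into a single matmul, producing an
+  # (extra_num_seq, num_seq) agreement matrix.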
+ agreement = tf.matmul( + tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), + tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]), + transpose_b=True) + + # Assign each sequence in the extra sequences to the closest MSA sample + protein['extra_cluster_assignment'] = tf.argmax( + agreement, axis=1, output_type=tf.int32) + + return protein + + +@curry1 +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = shape_helpers.shape_list(protein['msa'])[0] + def csum(x): + return tf.math.unsorted_segment_sum( + x, protein['extra_cluster_assignment'], num_seq) + + mask = protein['extra_msa_mask'] + mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center + + msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23)) + msa_sum += tf.one_hot(protein['msa'], 23) # Original sequence + protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] + + del msa_sum + + del_sum = csum(mask * protein['extra_deletion_matrix']) + del_sum += protein['deletion_matrix'] # Original sequence + protein['cluster_deletion_mean'] = del_sum / mask_counts + del del_sum + + return protein + + +def make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + protein['msa_mask'] = tf.ones( + shape_helpers.shape_list(protein['msa']), dtype=tf.float32) + protein['msa_row_mask'] = tf.ones( + shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + """Create pseudo beta features.""" + is_gly = tf.equal(aatype, residue_constants.restype_order['G']) + ca_idx = residue_constants.atom_order['CA'] + cb_idx = residue_constants.atom_order['CB'] + pseudo_beta = tf.where( + tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :]) + + if all_atom_masks is not None: + pseudo_beta_mask = tf.where( + is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=''): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ['', 'template_'] + protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = ( + pseudo_beta_fn( + protein['template_aatype' if prefix else 'all_atom_aatype'], + protein[prefix + 'all_atom_positions'], + protein['template_all_atom_masks' if prefix else 'all_atom_mask'])) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = tf.convert_to_tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = shape_helpers.shape_list(probs) + num_classes = ds[-1] + counts = tf.random.categorical( + tf.reshape(tf.log(probs + epsilon), [-1, num_classes]), + 1, + dtype=tf.int32) + return tf.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). 
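+  # The profile is the per-residue frequency of each of the 22 MSA symbols
+  # (20 amino acids plus X and gap), averaged over all sequences in the MSA.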
+ protein['hhblits_profile'] = tf.reduce_mean( + tf.one_hot(protein['msa'], 22), axis=0) + return protein + + +@curry1 +def make_masked_msa(protein, config, replace_fraction): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly + random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein['hhblits_profile'] + + config.same_prob * tf.one_hot(protein['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0. + categorical_probs = tf.pad( + categorical_probs, pad_shapes, constant_values=mask_prob) + + sh = shape_helpers.shape_list(protein['msa']) + mask_position = tf.random.uniform(sh) < replace_fraction + + bert_msa = shaped_categorical(categorical_probs) + bert_msa = tf.where(mask_position, bert_msa, protein['msa']) + + # Mix real and masked MSA + protein['bert_mask'] = tf.cast(mask_position, tf.float32) + protein['true_msa'] = protein['msa'] + protein['msa'] = bert_msa + + return protein + + +@curry1 +def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, + num_res, num_templates=0): + """Guess at the MSA and sequence dimensions to make fixed size.""" + + pad_size_map = { + NUM_RES: num_res, + NUM_MSA_SEQ: msa_cluster_size, + NUM_EXTRA_SEQ: extra_msa_size, + NUM_TEMPLATES: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. + if k == 'extra_cluster_assignment': + continue + shape = v.shape.as_list() + schema = shape_schema[k] + assert len(shape) == len(schema), ( + f'Rank mismatch between shape and shape schema for {k}: ' + f'{shape} vs {schema}') + pad_size = [ + pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) + ] + padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)] + if padding: + protein[k] = tf.pad( + v, padding, name=f'pad_to_fixed_{k}') + protein[k].set_shape(pad_size) + + return protein + + +@curry1 +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping + # for compatibility with domain datasets. + has_break = tf.clip_by_value( + tf.cast(protein['between_segment_residues'], tf.float32), + 0, 1) + aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1) + + target_feat = [ + tf.expand_dims(has_break, axis=-1), + aatype_1hot, # Everyone gets the original sequence. + ] + + msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1) + has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.) + deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi) + + msa_feat = [ + msa_1hot, + tf.expand_dims(has_deletion, axis=-1), + tf.expand_dims(deletion_value, axis=-1), + ] + + if 'cluster_profile' in protein: + deletion_mean_value = ( + tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)) + msa_feat.extend([ + protein['cluster_profile'], + tf.expand_dims(deletion_mean_value, axis=-1), + ]) + + if 'extra_deletion_matrix' in protein: + protein['extra_has_deletion'] = tf.clip_by_value( + protein['extra_deletion_matrix'], 0., 1.) + protein['extra_deletion_value'] = tf.atan( + protein['extra_deletion_matrix'] / 3.) * (2. 
/ np.pi) + + protein['msa_feat'] = tf.concat(msa_feat, axis=-1) + protein['target_feat'] = tf.concat(target_feat, axis=-1) + return protein + + +@curry1 +def select_feat(protein, feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +@curry1 +def crop_templates(protein, max_templates): + for k, v in protein.items(): + if k.startswith('template_'): + protein[k] = v[:max_templates] + return protein + + +@curry1 +def random_crop_to_size(protein, crop_size, max_templates, shape_schema, + subsample_templates=False): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + seq_length = protein['seq_length'] + if 'template_mask' in protein: + num_templates = tf.cast( + shape_helpers.shape_list(protein['template_mask'])[0], tf.int32) + else: + num_templates = tf.constant(0, dtype=tf.int32) + num_res_crop_size = tf.math.minimum(seq_length, crop_size) + + # Ensures that the cropping of residues and templates happens in the same way + # across ensembling iterations. + # Do not use for randomness that should vary in ensembling. + seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed']) + + if subsample_templates: + templates_crop_start = tf.random.stateless_uniform( + shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32, + seed=seed_maker()) + else: + templates_crop_start = 0 + + num_templates_crop_size = tf.math.minimum( + num_templates - templates_crop_start, max_templates) + + num_res_crop_start = tf.random.stateless_uniform( + shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1, + dtype=tf.int32, seed=seed_maker()) + + templates_select_indices = tf.argsort(tf.random.stateless_uniform( + [num_templates], seed=seed_maker())) + + for k, v in protein.items(): + if k not in shape_schema or ( + 'template' not in k and NUM_RES not in shape_schema[k]): + continue + + # randomly permute the templates before cropping them. + if k.startswith('template') and subsample_templates: + v = tf.gather(v, templates_select_indices) + + crop_sizes = [] + crop_starts = [] + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], + shape_helpers.shape_list(v))): + is_num_res = (dim_size == NUM_RES) + if i == 0 and k.startswith('template'): + crop_size = num_templates_crop_size + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size = (num_res_crop_size if is_num_res else + (-1 if dim is None else dim)) + crop_sizes.append(crop_size) + crop_starts.append(crop_start) + protein[k] = tf.slice(v, crop_starts, crop_sizes) + + protein['seq_length'] = num_res_crop_size + return protein + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom14_mask.append([(1. if name else 0.) 
for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.] * 14) + + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37, + protein['aatype']) + residx_atom14_mask = tf.gather(restype_atom14_mask, + protein['aatype']) + + protein['atom14_atom_exists'] = residx_atom14_mask + protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37 + + # create the gather indices for mapping back + residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14, + protein['aatype']) + protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14 + + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = tf.gather(restype_atom37_mask, + protein['aatype']) + protein['atom37_atom_exists'] = residx_atom37_mask + + return protein diff --git a/apps/protein_folding/helixfold_cpu/tools/folding.py b/apps/protein_folding/helixfold_cpu/tools/folding.py new file mode 100644 index 00000000..86d0d530 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/folding.py @@ -0,0 +1,991 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and utilities for the structure module.""" +import pdb +import ml_collections +import numpy as np +import paddle +import paddle.nn as nn +import functools +from typing import Dict +from tools import residue_constants +from tools import all_atom +from tools import quat_affine +from tools import r3 +from tools import model_utils as utils + +def squared_difference(x, y): + return paddle.square(x - y) + +class InvariantPointAttention(nn.Layer): + """Invariant Point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + + Jumper et al. (2021) Suppl. Alg. 
22 "InvariantPointAttention" + """ + def __init__(self, channel_num, config, global_config, + dist_epsilon=1e-8): + super(InvariantPointAttention, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + self.dist_epsilon = dist_epsilon + + num_head = self.config.num_head + num_scalar_qk = self.config.num_scalar_qk + num_point_qk = self.config.num_point_qk + num_scalar_v = self.config.num_scalar_v + num_point_v = self.config.num_point_v + num_output = self.config.num_channel + + assert num_scalar_qk > 0 + assert num_point_qk > 0 + assert num_point_v > 0 + + self.q_scalar = nn.Linear( + channel_num['seq_channel'], num_head * num_scalar_qk) + self.kv_scalar = nn.Linear( + channel_num['seq_channel'], + num_head * (num_scalar_v + num_scalar_qk)) + + self.q_point_local = nn.Linear( + channel_num['seq_channel'], num_head * 3 * num_point_qk) + self.kv_point_local = nn.Linear( + channel_num['seq_channel'], + num_head * 3 * (num_point_qk + num_point_v)) + + tpw = np.log(np.exp(1.) - 1.) + self.trainable_point_weights = paddle.create_parameter( + [num_head], 'float32', + default_initializer=nn.initializer.Constant(tpw)) + + self.attention_2d = nn.Linear(channel_num['pair_channel'], num_head) + + if self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.XavierUniform() + + c = num_scalar_v + num_point_v * 4 + channel_num['pair_channel'] + self.output_projection = nn.Linear( + num_head * c, num_output, + weight_attr=paddle.ParamAttr(initializer=init_w)) + + def forward(self, single_act: paddle.Tensor, pair_act: paddle.Tensor, + mask: paddle.Tensor, affine: quat_affine.QuatAffine): + # single_act: [B, N, C] + # pair_act: [B, N, M, C'] + # mask: [B, N, 1] + num_residues = single_act.shape[1] + num_head = self.config.num_head + num_scalar_qk = self.config.num_scalar_qk + num_point_qk = self.config.num_point_qk + num_scalar_v = self.config.num_scalar_v + num_point_v = self.config.num_point_v + num_output = self.config.num_channel + + # Construct scalar queries of shape: + # [batch_size, num_query_residues, num_head, num_points] + q_scalar = self.q_scalar(single_act) + q_scalar = paddle.reshape( + q_scalar, [-1, num_residues, num_head, num_scalar_qk]) + + # Construct scalar keys/values of shape: + # [batch_size, num_target_residues, num_head, num_points] + kv_scalar = self.kv_scalar(single_act) + kv_scalar = paddle.reshape( + kv_scalar, + [-1, num_residues, num_head, num_scalar_v + num_scalar_qk]) + k_scalar, v_scalar = paddle.split( + kv_scalar, [num_scalar_qk, -1], axis=-1) + + # Construct query points of shape: + # [batch_size, num_residues, num_head, num_point_qk] + q_point_local = self.q_point_local(single_act) + q_point_local = paddle.split(q_point_local, 3, axis=-1) + + q_point_global = affine.apply_to_point(q_point_local, extra_dims=1) + q_point = [ + paddle.reshape(x, [-1, num_residues, num_head, num_point_qk]) + for x in q_point_global] + + # Construct key and value points. 
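+        # Like the queries, key/value points are predicted in each residue's local
+        # frame and then mapped into the global frame with the affine, so the point
+        # distances below are invariant to a global rigid motion of the structure.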
+        # Key points shape [batch_size, num_residues, num_head, num_point_qk]
+        # Value points shape [batch_size, num_residues, num_head, num_point_v]
+        kv_point_local = self.kv_point_local(single_act)
+        kv_point_local = paddle.split(kv_point_local, 3, axis=-1)
+
+        kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
+        kv_point_global = [
+            paddle.reshape(x, [-1, num_residues, num_head, num_point_qk + num_point_v])
+            for x in kv_point_global]
+
+        k_point, v_point = list(
+            zip(*[
+                paddle.split(x, [num_point_qk, -1], axis=-1)
+                for x in kv_point_global
+            ]))
+
+        # We assume that all queries and keys come iid from N(0, 1) distribution
+        # and compute the variances of the attention logits.
+        # Each scalar pair (q, k) contributes Var q*k = 1
+        scalar_variance = max(num_scalar_qk, 1) * 1.
+        # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
+        point_variance = max(num_point_qk, 1) * 9. / 2
+
+        # Allocate equal variance to scalar, point and attention 2d parts so that
+        # the sum is 1.
+
+        num_logit_terms = 3
+        scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
+        point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
+        attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))
+
+        trainable_point_weights = nn.functional.softplus(
+            self.trainable_point_weights)
+        point_weights *= paddle.unsqueeze(
+            trainable_point_weights, axis=1)
+
+        # [B, R, H, C] => [B, H, R, C], put head dim first
+        q_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in q_point]
+        k_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in k_point]
+        v_point = [paddle.transpose(x, [0, 2, 1, 3]) for x in v_point]
+
+        dist2 = [
+            paddle.square(paddle.unsqueeze(qx, axis=-2) - \
+                          paddle.unsqueeze(kx, axis=-3))
+            for qx, kx in zip(q_point, k_point)]
+        dist2 = sum(dist2)
+
+        attn_qk_point = -0.5 * paddle.sum(
+            paddle.unsqueeze(point_weights, axis=[1, 2]) * dist2, axis=-1)
+
+        q = paddle.transpose(scalar_weights * q_scalar, [0, 2, 1, 3])
+        k = paddle.transpose(k_scalar, [0, 2, 1, 3])
+        v = paddle.transpose(v_scalar, [0, 2, 1, 3])
+        attn_qk_scalar = paddle.matmul(q, paddle.transpose(k, [0, 1, 3, 2]))
+        attn_logits = attn_qk_scalar + attn_qk_point
+
+        attention_2d = self.attention_2d(pair_act)
+        attention_2d = paddle.transpose(attention_2d, [0, 3, 1, 2])
+        attention_2d = attention_2d_weights * attention_2d
+        attn_logits += attention_2d
+
+        mask_2d = mask * paddle.transpose(mask, [0, 2, 1])
+        attn_logits -= 1e5 * (1.
- mask_2d.unsqueeze(1)) + + # [batch_size, num_head, num_query_residues, num_target_residues] + attn = nn.functional.softmax(attn_logits) + + # o_i^h + # [batch_size, num_query_residues, num_head, num_head * num_scalar_v] + result_scalar = paddle.matmul(attn, v) + result_scalar = paddle.transpose(result_scalar, [0, 2, 1, 3]) + + # o_i^{hp} + # [batch_size, num_query_residues, num_head, num_head * num_point_v] + result_point_global = [ + paddle.sum(paddle.unsqueeze(attn, -1) * paddle.unsqueeze(vx, -3), + axis=-2) for vx in v_point] + result_point_global = [ + paddle.transpose(x, [0, 2, 1, 3]) for x in result_point_global] + + # \tilde{o}_i^h + # [batch_size, num_residues, num_head, pair_channel] + result_attention_over_2d = paddle.einsum( + 'nhij,nijc->nihc', attn, pair_act) + + # Reshape, global-to-local and save + result_scalar = paddle.reshape( + result_scalar, [-1, num_residues, num_head * num_scalar_v]) + result_point_global = [ + paddle.reshape(x, [-1, num_residues, num_head * num_point_v]) + for x in result_point_global] + result_point_local = affine.invert_point( + result_point_global, extra_dims=1) + result_attention_over_2d = paddle.reshape( + result_attention_over_2d, + [-1, num_residues, num_head * self.channel_num['pair_channel']]) + + result_point_local_norm = paddle.sqrt( + self.dist_epsilon + paddle.square(result_point_local[0]) + \ + paddle.square(result_point_local[1]) + \ + paddle.square(result_point_local[2])) + + output_features = [result_scalar] + output_features.extend(result_point_local) + output_features.extend( + [result_point_local_norm, result_attention_over_2d]) + + final_act = paddle.concat(output_features, axis=-1) + return self.output_projection(final_act) + + +class FoldIteration(nn.Layer): + """A single iteration of the main structure module loop. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21 + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. + """ + def __init__(self, channel_num, config, global_config): + super(FoldIteration, self).__init__() + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.invariant_point_attention = InvariantPointAttention( + channel_num, config, global_config) + self.attention_layer_norm = nn.LayerNorm(channel_num['seq_channel']) + + for i in range(self.config.num_layer_in_transition): + if i < self.config.num_layer_in_transition - 1: + init_w = nn.initializer.KaimingNormal() + elif self.global_config.zero_init: + init_w = nn.initializer.Constant(value=0.0) + else: + init_w = nn.initializer.XavierUniform() + + layer_name, c_in = 'transition', channel_num['seq_channel'] + if i > 0: + layer_name, c_in = f'transition_{i}', self.config.num_channel + + setattr(self, layer_name, nn.Linear( + c_in, self.config.num_channel, + weight_attr=paddle.ParamAttr(initializer=init_w))) + + self.ipa_dropout = nn.Dropout(p=self.config.dropout) + self.transition_dropout = nn.Dropout(p=self.config.dropout) + self.transition_layer_norm = nn.LayerNorm(self.config.num_channel) + + if self.global_config.zero_init: + last_init_w = nn.initializer.Constant(value=0.0) + else: + last_init_w = nn.initializer.XavierUniform() + + # Jumper et al. (2021) Alg. 
23 "Backbone update" + self.affine_update = nn.Linear( + self.config.num_channel, 6, + weight_attr=paddle.ParamAttr(initializer=last_init_w)) + + self.rigid_sidechain = MultiRigidSidechain( + channel_num, self.config.sidechain, self.global_config) + + def forward(self, activations, init_single_act, static_pair_act, + seq_mask, aatype): + affine = quat_affine.QuatAffine.from_tensor(activations['affine']) + act = activations['act'] + + attn = self.invariant_point_attention( + act, static_pair_act, seq_mask, affine) + act += attn + act = self.ipa_dropout(act) + act = self.attention_layer_norm(act) + + input_act = act + for i in range(self.config.num_layer_in_transition): + layer_name = 'transition' + if i > 0: + layer_name = f'transition_{i}' + + act = getattr(self, layer_name)(act) + + if i < self.config.num_layer_in_transition - 1: + act = nn.functional.relu(act) + + act += input_act + act = self.transition_dropout(act) + act = self.transition_layer_norm(act) + + affine_update = self.affine_update(act) + affine = affine.pre_compose(affine_update) + + sc = self.rigid_sidechain( + affine.scale_translation(self.config.position_scale), + act, init_single_act, aatype) + outputs = {'affine': affine.to_tensor(), 'sc': sc} + + affine = affine.stop_rot_gradient() + new_activations = { + 'act': act, + 'affine': affine.to_tensor() + } + return new_activations, outputs + + +class StructureModule(nn.Layer): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + """ + def __init__(self, channel_num, config, global_config): + super(StructureModule, self).__init__() + assert config.num_layer > 0 + + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + self.single_layer_norm = nn.LayerNorm(channel_num['seq_channel']) + self.initial_projection = nn.Linear( + channel_num['seq_channel'], config.num_channel) + self.pair_layer_norm = nn.LayerNorm(channel_num['pair_channel']) + + self.fold_iteration = FoldIteration( + channel_num, config, global_config) + + def forward(self, representations, batch): + """tbd.""" + + output = self._generate_affines(representations, batch) + + ret = dict() + ret['representations'] = {'structure_module': output['act']} + + # NOTE: pred unit is nanometer, *position_scale to scale back to + # angstroms to match unit of PDB files. + # (L, B, N, 7), L = FoldIteration layers + scale = paddle.to_tensor( + [1.] * 4 + [self.config.position_scale] * 3, 'float32') + ret['traj'] = output['affine'] * paddle.unsqueeze( + scale, axis=[0, 1, 2]) + + ret['sidechains'] = output['sc'] + + # (B, N, 14, 3) + atom14_pred_positions = output['sc']['atom_pos'][-1] + ret['final_atom14_positions'] = atom14_pred_positions + + # (B, N, 14) + ret['final_atom14_mask'] = batch['atom14_atom_exists'] + + # (B, N, 37, 3) + atom37_pred_positions = all_atom.atom14_to_atom37( + atom14_pred_positions, batch) + atom37_pred_positions *= paddle.unsqueeze( + batch['atom37_atom_exists'], axis=-1) + ret['final_atom_positions'] = atom37_pred_positions + + # (B, N, 37) + ret['final_atom_mask'] = batch['atom37_atom_exists'] + + # (B, N, 7) + ret['final_affines'] = ret['traj'][-1] + + return ret + + def loss(self, value, batch): + ret = {'loss': 0.} + + ret['metrics'] = {} + # If requested, compute in-graph metrics. + if self.config.compute_in_graph_metrics: + atom14_pred_positions = value['final_atom14_positions'] + # Compute renaming and violations. 
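+            # Resolve 180-degree-symmetric sidechain atoms: for each residue, pick
+            # whichever ground-truth naming (original or swapped) better matches the
+            # prediction; the renamed ground truth is then used by all downstream losses.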
+ value.update(compute_renamed_ground_truth(batch, paddle.to_tensor(atom14_pred_positions))) + value['violations'] = find_structural_violations( + batch, atom14_pred_positions, self.config) + + # Several violation metrics: + violation_metrics = compute_violation_metrics( + batch=batch, + atom14_pred_positions=atom14_pred_positions, + violations=value['violations']) + ret['metrics'].update(violation_metrics) + + backbone_loss(ret, batch, value, self.config) + + if 'renamed_atom14_gt_positions' not in value: + tmp_atom14_positions = value['final_atom14_positions'] + value.update(compute_renamed_ground_truth(batch, paddle.to_tensor(tmp_atom14_positions))) + + sc_loss = sidechain_loss(batch, value, self.config) + + ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] + self.config.sidechain.weight_frac * sc_loss['loss']) + ret['sidechain_fape'] = sc_loss['fape'] + + supervised_chi_loss(ret, batch, value, self.config) + + # Finetune loss + if self.config.structural_violation_loss_weight: + if 'violations' not in value: + value['violations'] = find_structural_violations(batch, value['final_atom14_positions'], self.config) + structural_violation_loss(ret, batch, value, self.config) + return ret + + def _generate_affines(self, representations, batch): + """Generate predicted affines for a single chain. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + + This is the main part of the structure module - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Representations dictionary. + batch: Batch dictionary. + + Returns: + A dictionary containing residue affines and sidechain positions. + """ + seq_mask = paddle.unsqueeze(batch['seq_mask'], axis=-1) + + single_act = self.single_layer_norm(representations['single']) + + init_single_act = single_act + single_act = self.initial_projection(single_act) + pair_act = self.pair_layer_norm(representations['pair']) + affine = generate_new_affine(seq_mask) + + outputs = [] + activations = {'act': single_act, 'affine': affine.to_tensor()} + for _ in range(self.config.num_layer): + activations, output = self.fold_iteration( + activations, init_single_act, pair_act, + seq_mask, batch['aatype'].astype(dtype='int32')) + outputs.append(output) + + output = dict() + for k in outputs[0].keys(): + if k == 'sc': + output[k] = dict() + for l in outputs[0][k].keys(): + output[k][l] = paddle.stack([o[k][l] for o in outputs]) + else: + output[k] = paddle.stack([o[k] for o in outputs]) + + output['act'] = activations['act'] + return output + + +def compute_renamed_ground_truth( + batch: Dict[str, paddle.Tensor], + atom14_pred_positions: paddle.Tensor) -> Dict[str, paddle.Tensor]: + """Find optimal renaming of ground truth based on the predicted positions. + + Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + Shape (B, N). + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. + * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. 
+ atom14_pred_positions: Array of atom positions in global frame with shape + (B, N, 14, 3). + + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + atom14_gt_positions_pd = paddle.to_tensor(batch['atom14_gt_positions']) + atom14_alt_gt_positions_pd = paddle.to_tensor(batch['atom14_alt_gt_positions']) + atom14_atom_is_ambiguous_pd = paddle.to_tensor(batch['atom14_atom_is_ambiguous']) + atom14_gt_exists_pd = paddle.to_tensor(batch['atom14_gt_exists']) + atom14_atom_exists_pd = paddle.to_tensor(batch['atom14_atom_exists']) + # (B, N) + alt_naming_is_better = all_atom.find_optimal_renaming( + atom14_gt_positions=atom14_gt_positions_pd, + atom14_alt_gt_positions=atom14_alt_gt_positions_pd, + atom14_atom_is_ambiguous=atom14_atom_is_ambiguous_pd, + atom14_gt_exists=atom14_gt_exists_pd, + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=atom14_atom_exists_pd) + + renamed_atom14_gt_positions = ( + (1. - alt_naming_is_better[:, :, None, None]) + * atom14_gt_positions_pd + + alt_naming_is_better[:, :, None, None] + * atom14_alt_gt_positions_pd + ) + + tmp_atom14_alt_gt_exists = paddle.to_tensor(batch['atom14_alt_gt_exists']) + + renamed_atom14_gt_mask = ( + (1. - alt_naming_is_better[:, :, None]) * atom14_gt_exists_pd + + alt_naming_is_better[:, :, None] * tmp_atom14_alt_gt_exists) + + return { + 'alt_naming_is_better': alt_naming_is_better, # (B, N) + 'renamed_atom14_gt_positions': renamed_atom14_gt_positions, # (B, N, 14, 3) + 'renamed_atom14_gt_exists': renamed_atom14_gt_mask, # (B, N, 14) + } + + +def backbone_loss(ret, batch, value, config): + """Backbone FAPE Loss. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17 + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. + batch: Batch, needs to contain 'backbone_affine_tensor', + 'backbone_affine_mask'. + value: Dictionary containing structure module output, needs to contain + 'traj', a trajectory of rigids. + config: Configuration of loss, should contain 'fape.clamp_distance' and + 'fape.loss_unit_distance'. + """ + affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj']) + rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory) + + gt_rot = paddle.to_tensor(batch['backbone_affine_tensor_rot'], dtype='float32') + gt_trans = paddle.to_tensor(batch['backbone_affine_tensor_trans'], dtype='float32') + gt_affine = quat_affine.QuatAffine( + quaternion=None, + translation=gt_trans, + rotation=gt_rot) + gt_rigid = r3.rigids_from_quataffine(gt_affine) + backbone_mask = batch['backbone_affine_mask'] + backbone_mask = paddle.to_tensor(backbone_mask) + + fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=config.fape.clamp_distance, + length_scale=config.fape.loss_unit_distance) + + fape_loss = [] + index = 0 + for rigid_trajectory_rot_item,rigid_trajectory_trans_item in zip(rigid_trajectory.rot,rigid_trajectory.trans): + rigid_trajectory_item = r3.Rigids(rigid_trajectory_rot_item, rigid_trajectory_trans_item) + index+=1 + middle_fape_loss = fape_loss_fn(rigid_trajectory_item, gt_rigid, backbone_mask, + rigid_trajectory_trans_item, gt_rigid.trans, + backbone_mask) + fape_loss.append(middle_fape_loss) + fape_loss = paddle.stack(fape_loss) + + if 'use_clamped_fape' in batch: + # Jumper et al. (2021) Suppl. Sec. 
1.11.5 "Loss clamping details" + use_clamped_fape = batch['use_clamped_fape'][0, 0] + + unclamped_fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=None, + length_scale=config.fape.loss_unit_distance) + + fape_loss_unclamped = [] + index_t = 0 + for rigid_trajectory_rot_item_t, rigid_trajectory_trans_item_t in zip(rigid_trajectory.rot, rigid_trajectory.trans): + rigid_trajectory_item_t = r3.Rigids(rigid_trajectory_rot_item_t, rigid_trajectory_trans_item_t) + index_t+=1 + middle_fape_loss_t = unclamped_fape_loss_fn(rigid_trajectory_item_t, gt_rigid, backbone_mask, + rigid_trajectory_trans_item_t, gt_rigid.trans, + backbone_mask) + fape_loss_unclamped.append(middle_fape_loss_t) + fape_loss_unclamped = paddle.stack(fape_loss_unclamped) + + fape_loss = (fape_loss * use_clamped_fape + fape_loss_unclamped * (1 - use_clamped_fape)) + + ret['fape'] = fape_loss[-1] + ret['backbone_fape'] = paddle.mean(fape_loss) + ret['loss'] += paddle.mean(fape_loss) + + +def sidechain_loss(batch, value, config): + """All Atom FAPE Loss using renamed rigids.""" + # Rename Frames + # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7 + alt_naming_is_better = value['alt_naming_is_better'] + + renamed_gt_frames = ( + (1. - alt_naming_is_better[:, :, None, None]) + * batch['rigidgroups_gt_frames'] + + alt_naming_is_better[:, :, None, None] + * batch['rigidgroups_alt_gt_frames']) + + batch_size = renamed_gt_frames.shape[0] + flat_gt_frames = r3.rigids_from_tensor_flat12( + paddle.reshape(renamed_gt_frames, [batch_size, -1, 12])) + flat_frames_mask = paddle.reshape(batch['rigidgroups_gt_exists'], [batch_size, -1]) + + flat_gt_positions = r3.vecs_from_tensor( + paddle.reshape(value['renamed_atom14_gt_positions'], [batch_size, -1, 3])) + flat_positions_mask = paddle.reshape(value['renamed_atom14_gt_exists'], [batch_size, -1]) + + # Compute frame_aligned_point_error score for the final layer. + pred_frames_rot = value['sidechains']['frames_rot'] + pred_frames_trans = value['sidechains']['frames_trans'] + tmp_rots = paddle.reshape(pred_frames_rot[-1], [batch_size, -1, 3, 3]) + tmp_vecs = paddle.reshape(pred_frames_trans[-1], [batch_size, -1, 3]) + tmp_rots = r3.rots_from_tensor3x3(tmp_rots) + tmp_vecs = r3.vecs_from_tensor(tmp_vecs) + flat_pred_frames = r3.Rigids(rot=tmp_rots, trans=tmp_vecs) + + pred_positions = value['sidechains']['atom_pos'] + pred_positions = paddle.reshape(pred_positions[-1], [batch_size, -1, 3]) + flat_pred_positions = r3.vecs_from_tensor(pred_positions) + + # FAPE Loss on sidechains + fape = all_atom.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + l1_clamp_distance=config.sidechain.atom_clamp_distance, + length_scale=config.sidechain.length_scale) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(ret, batch, value, config): + """Computes loss for structural violations.""" + assert config.sidechain.weight_frac + + # Put all violation losses together to one large loss. 
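+    # The bond-length and bond-angle terms are already per-structure means; the clash
+    # and within-residue terms are summed over atoms and normalized by the number of
+    # existing atoms before being scaled by structural_violation_loss_weight.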
+ violations = value['violations'] + num_atoms = paddle.sum(batch['atom14_atom_exists'], dtype='float32') + ret['loss'] += (config.structural_violation_loss_weight * ( + violations['between_residues']['bonds_c_n_loss_mean'] + + violations['between_residues']['angles_ca_c_n_loss_mean'] + + violations['between_residues']['angles_c_n_ca_loss_mean'] + + paddle.sum( + violations['between_residues']['clashes_per_atom_loss_sum'] + + violations['within_residues']['per_atom_loss_sum']) / + (1e-6 + num_atoms))) + + +def find_structural_violations( + batch: Dict[str, paddle.Tensor], + atom14_pred_positions: paddle.Tensor, # (B, N, 14, 3) + config: ml_collections.ConfigDict): + """Computes several checks for structural violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = all_atom.between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch['atom14_atom_exists'], + residue_index=paddle.cast(batch['residue_index'], 'float32'), + aatype=batch['aatype_index'], + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (B, N, 14). + temp_atomtype_radius = np.array([ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ]) + atomtype_radius = paddle.to_tensor(temp_atomtype_radius) + atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather( + atomtype_radius, batch['residx_atom14_to_atom37']) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom.between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_atom_radius=atom14_atom_radius, + residue_index=paddle.cast(batch['residue_index'], 'float32'), + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). + restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + atom14_dists_lower_bound = utils.batched_gather( + paddle.to_tensor(restype_atom14_bounds['lower_bound']), batch['aatype_index']) + atom14_dists_upper_bound = utils.batched_gather( + paddle.to_tensor(restype_atom14_bounds['upper_bound']), batch['aatype_index']) + within_residue_violations = all_atom.within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). 
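+    # Element-wise max over the three per-residue masks: a residue is flagged if it
+    # fails the backbone bond/angle check, has any between-residue clash, or violates
+    # any within-residue distance bound.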
+ per_residue_violations_mask = paddle.max(paddle.stack([ + connection_violations['per_residue_violation_mask'], + paddle.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + paddle.max(within_residue_violations['per_atom_violations'], axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # (B) + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # (B) + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # (B) + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (B, N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (B, N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # (B) + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (B, N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (B, N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (B, N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (B, N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (B, N) + } + + +def compute_violation_metrics( + batch: Dict[str, paddle.Tensor], + atom14_pred_positions: paddle.Tensor, # (B, N, 14, 3) + violations: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tensor]: + """Compute several metrics to assess the structural violations.""" + batch_size = atom14_pred_positions.shape[0] + ret = {} + extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=paddle.cast(batch['atom14_atom_exists'], 'float32'), + residue_index=paddle.cast(batch['residue_index'], 'float32')) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + + violations_between_residue_bond_tmp = [] + for i in range(batch_size): + violations_between_residue_bond_i = utils.mask_mean(mask=batch['seq_mask'][i], + value=violations['between_residues']['connections_per_residue_violation_mask'][i]) + violations_between_residue_bond_tmp.append(violations_between_residue_bond_i) + violations_between_residue_bond = paddle.to_tensor(violations_between_residue_bond_tmp, + stop_gradient=False) + violations_between_residue_bond = paddle.squeeze(violations_between_residue_bond, axis=-1) + ret['violations_between_residue_bond'] = violations_between_residue_bond + + violations_between_residue_clash_tmp = [] + for i in range(batch_size): + violations_between_residue_clash_i = utils.mask_mean(mask=batch['seq_mask'][i], + value=paddle.max(violations['between_residues']['clashes_per_atom_clash_mask'], + axis=-1)[i]) + violations_between_residue_clash_tmp.append(violations_between_residue_clash_i) + violations_between_residue_clash = paddle.to_tensor(violations_between_residue_clash_tmp, + stop_gradient=False) + violations_between_residue_clash = paddle.squeeze(violations_between_residue_clash, axis=-1) + ret['violations_between_residue_clash'] = violations_between_residue_clash + + violations_within_residue_tmp = [] + for i in range(batch_size): + violations_within_residue_i = utils.mask_mean(mask=batch['seq_mask'][i], + value=paddle.max(violations['within_residues']['per_atom_violations'], axis=-1)[i]) + violations_within_residue_tmp.append(violations_within_residue_i) + violations_within_residue = paddle.to_tensor(violations_within_residue_tmp, + 
dtype='float32', stop_gradient=False) + violations_within_residue = paddle.squeeze(violations_within_residue, axis=-1) + ret['violations_within_residue'] = violations_within_residue + + violations_per_residue_tmp = [] + for i in range(batch_size): + violations_per_residue_i = utils.mask_mean(mask=batch['seq_mask'][i], + value=violations['total_per_residue_violations_mask'][i]) + violations_per_residue_tmp.append(violations_per_residue_i) + violations_per_residue = paddle.to_tensor(violations_per_residue_tmp, + dtype='float32', stop_gradient=False) + violations_per_residue = paddle.squeeze(violations_per_residue, axis=-1) + ret['violations_per_residue'] = violations_per_residue + return ret + + +def supervised_chi_loss(ret, batch, value, config): + """Computes loss for direct chi angle supervision. + + Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss" + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. + batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'. + value: Dictionary containing structure module output, needs to contain + value['sidechains']['angles_sin_cos'] for angles and + value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized + angles. + config: Configuration of loss, should contain 'chi_weight' and + 'angle_norm_weight', 'angle_norm_weight' scales angle norm term, + 'chi_weight' scales torsion term. + """ + eps = 1e-6 + + sequence_mask = batch['seq_mask'] + num_res = sequence_mask.shape[1] + batch_size = sequence_mask.shape[0] + chi_mask = batch['chi_mask'] + pred_angles = paddle.reshape(value['sidechains']['angles_sin_cos'], [batch_size, -1, num_res, 7, 2]) + pred_angles = pred_angles[:, :, :, 3:] + + residue_type_one_hot = paddle.nn.functional.one_hot(batch['aatype_index'], + num_classes=residue_constants.restype_num + 1) + chi_pi_periodic = paddle.einsum('nijk,nkl->nijl', residue_type_one_hot[:, None, ...], + paddle.to_tensor(residue_constants.chi_pi_periodic)[None]) + + sin_cos_true_chi = batch['chi_angles_sin_cos'][:, None, ...] + + # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = paddle.sum(squared_difference(sin_cos_true_chi, pred_angles), axis=-1) + sq_chi_error_shifted = paddle.sum(squared_difference(sin_cos_true_chi_shifted, pred_angles), axis=-1) + sq_chi_error = paddle.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss_tmp = [] + for i in range(batch_size): + sq_chi_loss_i = utils.mask_mean(mask=paddle.unsqueeze(chi_mask[i], axis=0), value=sq_chi_error[i]) + sq_chi_loss_tmp.append(sq_chi_loss_i) + sq_chi_loss = paddle.to_tensor(sq_chi_loss_tmp, stop_gradient=False) + sq_chi_loss = paddle.squeeze(sq_chi_loss, axis=-1) + ret['chi_loss'] = sq_chi_loss + ret['loss'] += config.chi_weight * sq_chi_loss + + unnormed_angles = paddle.reshape(value['sidechains']['unnormalized_angles_sin_cos'], [batch_size, -1, num_res, 7, 2]) + angle_norm = paddle.sqrt(paddle.sum(paddle.square(unnormed_angles), axis=-1) + eps) + norm_error = paddle.abs(angle_norm - 1.) 
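+    # Angle-norm regularizer: penalize unnormalized (sin, cos) predictions whose norm
+    # deviates from 1, averaged over residues with a per-example mask_mean below.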
+ angle_norm_loss_tmp = [] + for i in range(batch_size): + angle_norm_loss_i = utils.mask_mean(mask=paddle.unsqueeze(sequence_mask[i], axis=[0,2]), value=norm_error[i]) + angle_norm_loss_tmp.append(angle_norm_loss_i) + angle_norm_loss = paddle.to_tensor(angle_norm_loss_tmp, stop_gradient=False) + angle_norm_loss = paddle.squeeze(angle_norm_loss, axis=-1) + ret['angle_norm_loss'] = angle_norm_loss + ret['loss'] += config.angle_norm_weight * angle_norm_loss + + +def generate_new_affine(sequence_mask): + t_shape = sequence_mask.shape[:-1] # (batch, N_res, 1) + assert len(t_shape) == 2 + t_shape.append(3) # (batch, N_res, 3) + q_shape = sequence_mask.shape[:-1] + [1] # (batch, N_res, 1) + quaternion = paddle.tile( + paddle.reshape( + paddle.to_tensor([1.0, 0.0, 0.0, 0.0]), [1, 1, 4]), + repeat_times=q_shape) + translation = paddle.zeros(t_shape) + return quat_affine.QuatAffine(quaternion, translation) + + +def l2_normalize(x, axis=-1, epsilon=1e-12): + return x / paddle.sqrt( + paddle.maximum( + paddle.sum(paddle.square(x), axis=axis, keepdim=True), + paddle.to_tensor([epsilon], dtype='float32'))) + + +class MultiRigidSidechain(nn.Layer): + """Class to make side chain atoms.""" + def __init__(self, channel_num, config, global_config): + super(MultiRigidSidechain, self).__init__() + + self.channel_num = channel_num + self.config = config + self.global_config = global_config + + c = self.config.num_channel + self.input_projection = nn.Linear(channel_num['seq_channel'], c) + self.input_projection_1 = nn.Linear(channel_num['seq_channel'], c) + + for i in range(self.config.num_residual_block): + l1, l2 = 'resblock1', 'resblock2' + if i > 0: + l1, l2 = f'resblock1_{i}', f'resblock2_{i}' + + init_w_1 = nn.initializer.KaimingNormal() + if self.global_config.zero_init: + init_w_2 = nn.initializer.Constant(value=0.) + else: + init_w_2 = nn.initializer.XavierUniform() + + setattr(self, l1, nn.Linear( + c, c, weight_attr=paddle.ParamAttr(initializer=init_w_1))) + setattr(self, l2, nn.Linear( + c, c, weight_attr=paddle.ParamAttr(initializer=init_w_2))) + + self.unnormalized_angles = nn.Linear(c, 14) + + def forward(self, affine, single_act, init_single_act, aatype): + single_act = self.input_projection(nn.functional.relu(single_act)) + init_single_act = self.input_projection_1( + nn.functional.relu(init_single_act)) + act = single_act + init_single_act + + for i in range(self.config.num_residual_block): + l1, l2 = 'resblock1', 'resblock2' + if i > 0: + l1, l2 = f'resblock1_{i}', f'resblock2_{i}' + + old_act = act + act = getattr(self, l1)(nn.functional.relu(act)) + act = getattr(self, l2)(nn.functional.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). + num_res = act.shape[1] + unnormalized_angles = self.unnormalized_angles( + nn.functional.relu(act)) + unnormalized_angles = paddle.reshape( + unnormalized_angles, [-1, num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # (B, N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # (B, N, 7, 2) + } + + # Jumper et al. (2021) Suppl. Alg. 
24 "computeAllAtomCoordinates" + backbone_to_global = r3.rigids_from_quataffine(affine) + all_frames_to_global = all_atom.torsion_angles_to_frames( + aatype, backbone_to_global, angles) + pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + # Outputs1 (Rot + Trans) + outputs.update({ + 'atom_pos': pred_positions.translation, # (B, N, 14, 3) + 'frames_rot': all_frames_to_global.rot.rotation, # (B, N, 8, 3, 3) + 'frames_trans': all_frames_to_global.trans.translation, # (B, N, 8, 3) + }) + + # ## Outputs2 (Rigids) + # outputs.update({ + # 'atom_pos': pred_positions.translation, # (B, N, 14, 3) + # 'frames': all_frames_to_global, # (B, N, 8, 3, 3) + # }) + + return outputs diff --git a/apps/protein_folding/helixfold_cpu/tools/input_pipeline.py b/apps/protein_folding/helixfold_cpu/tools/input_pipeline.py new file mode 100644 index 00000000..25fca582 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/input_pipeline.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature pre-processing input pipeline for AlphaFold.""" + +from tools import data_transforms +from tools import shape_placeholders +import tensorflow.compat.v1 as tf +import tree + +# Pylint gets confused by the curry1 decorator because it changes the number +# of arguments to the function. +# pylint:disable=no-value-for-parameter + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def nonensembled_map_fns(data_config): + """Input pipeline functions which are not ensembled.""" + common_cfg = data_config.common + + map_fns = [ + data_transforms.correct_msa_restypes, + data_transforms.add_distillation_flag(False), + data_transforms.cast_64bit_ints, + data_transforms.squeeze_features, + # Keep to not disrupt RNG. + data_transforms.randomly_replace_msa_with_unknown(0.0), + data_transforms.make_seq_mask, + data_transforms.make_msa_mask, + # Compute the HHblits profile if it's not set. This has to be run before + # sampling the MSA. 
+ data_transforms.make_hhblits_profile, + data_transforms.make_random_crop_to_size_seed, + ] + if common_cfg.use_templates: + map_fns.extend([ + data_transforms.fix_templates_aatype, + data_transforms.make_template_mask, + data_transforms.make_pseudo_beta('template_') + ]) + map_fns.extend([ + data_transforms.make_atom14_masks, + ]) + + return map_fns + + +def ensembled_map_fns(data_config): + """Input pipeline functions that can be ensembled and averaged.""" + common_cfg = data_config.common + eval_cfg = data_config.eval + + map_fns = [] + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates + else: + pad_msa_clusters = eval_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + if eval_cfg.delete_msa_block: + map_fns.append(data_transforms.block_delete_msa(eval_cfg)) + + map_fns.append( + data_transforms.sample_msa( + max_msa_clusters, + keep_extra=True)) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + map_fns.append( + data_transforms.make_masked_msa(common_cfg.masked_msa, + eval_cfg.masked_msa_replace_fraction)) + + if common_cfg.msa_cluster_features: + map_fns.append(data_transforms.nearest_neighbor_clusters()) + map_fns.append(data_transforms.summarize_clusters()) + + # Crop after creating the cluster profiles. + if max_extra_msa: + map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) + else: + map_fns.append(data_transforms.delete_extra_msa) + + map_fns.append(data_transforms.make_msa_feat()) + + crop_feats = dict(eval_cfg.feat) + + if eval_cfg.fixed_size: + map_fns.append(data_transforms.select_feat(list(crop_feats))) + map_fns.append(data_transforms.random_crop_to_size( + eval_cfg.crop_size, + eval_cfg.max_templates, + crop_feats, + eval_cfg.subsample_templates)) + map_fns.append(data_transforms.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + eval_cfg.crop_size, + eval_cfg.max_templates)) + else: + map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) + + return map_fns + + +def process_tensors_from_config(tensors, data_config): + """Apply filters and maps to an existing dataset, based on the config.""" + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_map_fns(data_config) + fn = compose(fns) + d['ensemble_index'] = i + return fn(d) + + eval_cfg = data_config.eval + tensors = compose( + nonensembled_map_fns( + data_config))( + tensors) + + tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) + num_ensemble = eval_cfg.num_ensemble + if data_config.common.resample_msa_in_recycling: + # Separate batch per ensembling & recycling step. 
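+    # With resampling, the effective ensemble dimension becomes
+    # num_ensemble * (num_recycle + 1), one slice per recycling iteration.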
+ num_ensemble *= data_config.common.num_recycle + 1 + + if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: + fn_output_signature = tree.map_structure( + tf.TensorSpec.from_tensor, tensors_0) + tensors = tf.map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + tf.range(num_ensemble), + parallel_iterations=1, + fn_output_signature=fn_output_signature) + else: + tensors = tree.map_structure(lambda x: x[None], + tensors_0) + return tensors + + +@data_transforms.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x diff --git a/apps/protein_folding/helixfold_cpu/tools/input_utils b/apps/protein_folding/helixfold_cpu/tools/input_utils new file mode 100644 index 00000000..264dc1cf --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/input_utils @@ -0,0 +1,47 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for various components.""" +import tensorflow.compat.v1 as tf + + +def tf_combine_mask(*masks): + """Take the intersection of float-valued masks.""" + ret = 1 + for m in masks: + ret *= m + return ret + + +class SeedMaker(object): + """Return unique seeds.""" + + def __init__(self, initial_seed=0): + self.next_seed = initial_seed + + def __call__(self): + i = self.next_seed + self.next_seed += 1 + return i + +seed_maker = SeedMaker() + + +def make_random_seed(): + return tf.random.uniform([2], + tf.int32.min, + tf.int32.max, + tf.int32, + seed=seed_maker()) + diff --git a/apps/protein_folding/helixfold_cpu/tools/lddt.py b/apps/protein_folding/helixfold_cpu/tools/lddt.py new file mode 100644 index 00000000..5962bd75 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/lddt.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""lDDT protein distance score.""" + +import paddle + + +def lddt(predicted_points, + true_points, + true_points_mask, + cutoff=15., + per_residue=False): + """Measure (approximate) lDDT for a batch of coordinates. + + lDDT reference: + Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local + superposition-free score for comparing protein structures and models using + distance difference tests. Bioinformatics 29, 2722–2728 (2013). + + lDDT is a measure of the difference between the true distance matrix and the + distance matrix of the predicted points. The difference is computed only on + points closer than cutoff *in the true structure*. 
+ + This function does not compute the exact lDDT value that the original paper + describes because it does not include terms for physical feasibility + (e.g. bond length violations). Therefore this is only an approximate + lDDT score. + + Args: + predicted_points: (batch, length, 3) array of predicted 3D points + true_points: (batch, length, 3) array of true 3D points + true_points_mask: (batch, length, 1) binary-valued float array. This mask + should be 1 for points that exist in the true points. + cutoff: Maximum distance for a pair of points to be included + per_residue: If true, return score for each residue. Note that the overall + lDDT is not exactly the mean of the per_residue lDDT's because some + residues have more contacts than others. + + Returns: + An (approximate, see above) lDDT score in the range 0-1. + """ + + assert len(predicted_points.shape) == 3 + assert predicted_points.shape[-1] == 3 + assert true_points_mask.shape[-1] == 1 + assert len(true_points_mask.shape) == 3 + + # Compute true and predicted distance matrices. + dmat_true = paddle.sqrt(1e-10 + paddle.sum( + (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) + + dmat_predicted = paddle.sqrt(1e-10 + paddle.sum( + (predicted_points[:, :, None] - + predicted_points[:, None, :])**2, axis=-1)) + + cutoff = paddle.to_tensor(cutoff) + + dists_to_score = ( + paddle.cast((dmat_true < cutoff), 'float32') * true_points_mask * + paddle.transpose(true_points_mask, [0, 2, 1]) * + (1. - paddle.eye(dmat_true.shape[1])) # Exclude self-interaction. + ) + + # Shift unscored distances to be far away. + dist_l1 = paddle.abs(dmat_true - dmat_predicted) + + # True lDDT uses a number of fixed bins. + # We ignore the physical plausibility correction to lDDT, though. + score = 0.25 * (paddle.cast((dist_l1 < 0.5), 'float32') + + paddle.cast((dist_l1 < 1.0), 'float32') + + paddle.cast((dist_l1 < 2.0), 'float32') + + paddle.cast((dist_l1 < 4.0), 'float32')) + + # Normalize over the appropriate axes. + reduce_axes = (-1,) if per_residue else (-2, -1) + norm = 1. / (1e-10 + paddle.sum(dists_to_score, axis=reduce_axes)) + score = norm * (1e-10 + paddle.sum(dists_to_score * score, axis=reduce_axes)) + + return score diff --git a/apps/protein_folding/helixfold_cpu/tools/model_utils.py b/apps/protein_folding/helixfold_cpu/tools/model_utils.py new file mode 100644 index 00000000..03f8003e --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/model_utils.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils.""" + +import numbers +import functools +import collections +import paddle +import numpy as np +from typing import Any, Mapping + +import protein +import confidence + + +def jax_params_to_paddle(params): + """ + Rule 1: alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_pair_stack/* ==> + '...template_pair_stack.0.*' + '...template_pair_stack.1.*' + ... 
+
+    Rule 2: alphafold/alphafold_iteration/evoformer/extra_msa_stack/* ==>
+        'alphafold_iteration.evoformer.extra_msa_stack.0.*',
+        'alphafold_iteration.evoformer.extra_msa_stack.1.*',
+        ...
+
+    Rule 3: alphafold/alphafold_iteration/evoformer/evoformer_iteration/* ==>
+        'alphafold.alphafold_iteration.evoformer.evoformer_iteration.0.*',
+        'alphafold.alphafold_iteration.evoformer.evoformer_iteration.1.*',
+        ...
+
+    Rule 4: */__layer_stack_no_state/* ==> '*.*'
+
+    Rule 5: *//weights ==> '*.weight'
+
+    Rule 6: *//bias ==> '*.bias'
+
+    Rule 7: *//scale ==> '*.weight'
+
+    Rule 8: *//offset ==> '*.bias'
+    """
+    rule_1_prefix = 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_pair_stack/'
+    rule_2_prefix = 'alphafold/alphafold_iteration/evoformer/extra_msa_stack/'
+    rule_3_prefix = 'alphafold/alphafold_iteration/evoformer/evoformer_iteration/'
+    rule_4_prefix = '__layer_stack_no_state/'
+
+    pd_params = dict()
+
+    def _parse_stack_or_iteration(rule_prefix, k):
+        n = params[k].shape[0]
+        suffix = k[len(rule_prefix):]
+
+        # rule 4
+        if suffix.startswith(rule_4_prefix):
+            suffix = suffix[len(rule_4_prefix):]
+
+        # rule 5
+        suffix = suffix.replace('//weights', '.weight')
+        # rule 6
+        suffix = suffix.replace('//bias', '.bias')
+        # rule 7
+        suffix = suffix.replace('//scale', '.weight')
+        # rule 8
+        suffix = suffix.replace('//offset', '.bias')
+
+        suffix = suffix.replace('//', '.')
+        suffix = suffix.replace('/', '.')
+
+        prefix = rule_prefix.replace('/', '.')
+        for i in range(n):
+            k_ = f'{prefix}{i}.{suffix}'
+            pd_params[k_] = np.copy(params[k][i])
+
+    for k in params.keys():
+        if k.startswith(rule_1_prefix):
+            _parse_stack_or_iteration(rule_1_prefix, k)
+
+        elif k.startswith(rule_2_prefix):
+            _parse_stack_or_iteration(rule_2_prefix, k)
+
+        elif k.startswith(rule_3_prefix):
+            _parse_stack_or_iteration(rule_3_prefix, k)
+
+        else:
+            k_ = k.replace('//weights', '.weight')
+            k_ = k_.replace('//scale', '.weight')
+            k_ = k_.replace('//offset', '.bias')
+            k_ = k_.replace('//', '.')
+            k_ = k_.replace('/', '.')
+            pd_params[k_] = np.copy(params[k])
+
+    return pd_params
+
+
+def slice_batch(batch, i):
+    b = {k: v[i] for k, v in batch.items()}
+    return b
+
+def add_batch_dim(batch):
+    b = {k: v[None,] for k, v in batch.items()}
+    return b
+
+def map_to_tensor(batch, add_batch=False):
+    if add_batch:
+        batch = add_batch_dim(batch)
+
+    b = {k: paddle.to_tensor(v) for k, v in batch.items()}
+    return b
+
+
+def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
+    if drop_mask_channel:
+        mask = mask[:, 0]
+
+    mask_shape = mask.shape
+    value_shape = value.shape
+    assert len(mask_shape) == len(value_shape)
+
+    if isinstance(axis, numbers.Integral):
+        axis = [axis]
+    elif axis is None:
+        axis = list(range(len(mask_shape)))
+
+    assert isinstance(axis, collections.abc.Iterable), \
+        'axis needs to be either an iterable, integer or "None"'
+
+    broadcast_factor = 1.
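+    # If the mask has a singleton dimension that is broadcast against `value`
+    # (e.g. a [B, N, 1] mask over a [B, N, 14] value), the denominator is
+    # scaled by the broadcast factor so every broadcast position is counted.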
+ for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (paddle.sum(mask * value, axis=axis) / + (paddle.sum(mask, axis=axis) * broadcast_factor + eps)) + + +def batched_gather(params, indices, axis=0, batch_dims=0): + # Implement gather with batching, like tensorflow: + # https://www.tensorflow.org/api_docs/python/tf/gather#batching + # print(params.shape, indices.shape, axis) + p, i = params, indices + rank = len(p.shape) + axis = (rank + axis) % rank + # The stride of axis + stride = p.shape[batch_dims + axis] + + if batch_dims == 0 and len(i.shape) == 1: + return paddle.gather(p, i, axis=axis) + + elif batch_dims == 0: + flat_i = i.reshape([-1]) + gathered = paddle.gather(p, flat_i, axis=axis) + shape = p.shape[:axis] + i.shape + if axis < rank - 1: + shape += params.shape[axis + 1:] + return gathered.reshape(shape) + + b = batch_dims + a = axis + assert p.shape[:b] == i.shape[:b] + bn = np.prod(p.shape[:b]) + + # Shift batch dimensions right to bundle with axis + if a > 0: + perm = list(range(rank)) + perm = perm[b:(b + a)] + perm[:b] + perm[(b + a):] + p = p.transpose(perm) + + # Merge params' batch+axis + p = p.reshape(p.shape[:a] + [-1] + p.shape[(b + a + 1):]) + + # indices = [Batch..., Index...] + # Expand the index values across batch elements + strides = paddle.arange(bn).unsqueeze(-1) * stride + i = i.reshape([bn, -1]) + flat_i = paddle.flatten(i + strides) + + # Do gather + gathered = paddle.gather(p, flat_i, axis=axis) + + # Unbundle batch and index dimensions + unbundled_shape = p.shape[:a] + indices.shape + p.shape[a + 1:] + gathered = gathered.reshape(unbundled_shape) + + # Shift batch dimensions back to the left + if a > 0: + perm = list(range(len(unbundled_shape))) + perm = perm[a:(a + b)] + perm[:a] + perm[(a + b):] + gathered = gathered.transpose(perm) + + return gathered + + +def subbatch(f, arg_idx, dim, bs, out_idx): + """ Converts a function to one that applies to subbatch of an input + dimension. + + Args: + f(Callable): original function. + arg_idx([int]): indices of the inputs to be subbatched. + dim([int]): index of the dimension to be subbatched. + bs(int): subbatch size. + out_idx(int): index of the output dimension that needs stacking + + Returns: + converted function. + """ + @functools.wraps(f) + def wrapper(*args, **kwargs): + assert len(arg_idx) == len(dim), f'Number of batching args and number of batching dims should match.' + + inps = [args[i] for i in arg_idx] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + assert len(set(dim_width)) == 1, f'Batch sizes should be kept equal.' 
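+        # Hypothetical example: with bs=4 and inputs batched along a dim of
+        # size 10, f is applied to slices [0:4], [4:8], [8:12] (the final
+        # slice may be shorter) and the outputs are concatenated along
+        # out_idx below.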
+ + inp_dim = {inp: d for inp, d in zip(inps, dim)} + + dim_width = dim_width[0] + if dim_width < bs: + return f(*args, **kwargs) + + outs = [] + for slice_at in np.arange(0, dim_width, bs): + _args = [] + for i, inp in enumerate(args): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + bs]) + _args.append(inp) + outs.append(f(*_args, **kwargs)) + + return paddle.concat(outs, out_idx) + + return wrapper + + +def get_confidence_metrics( + prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: + """Post processes prediction_result to get confidence metrics.""" + + confidence_metrics = {} + confidence_metrics['plddt'] = confidence.compute_plddt( + prediction_result['predicted_lddt']['logits']) + + if 'predicted_aligned_error' in prediction_result: + confidence_metrics.update(confidence.compute_predicted_aligned_error( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks'])) + + confidence_metrics['ptm'] = confidence.predicted_tm_score( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks']) + + return confidence_metrics + + +def generate_unrelaxed_pdb(aatype, residue_index, model_output, pdb_path, + b_factors=None): + fold_output = model_output['structure_module'] + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + # NOTE: for single protein, chain_index is always 'A' (idx:0) + prot = protein.Protein( + aatype=aatype, + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=residue_index + 1, + chain_index=np.zeros(aatype.shape), + b_factors=b_factors) + + with open(pdb_path, 'w') as f: + f.write(protein.to_pdb(prot)) + + return prot + + +def set_tensor_constant(tensor, constant): + tensor.set_value(paddle.full_like(tensor, constant)) + + +def init_gate_linear(linear): + set_tensor_constant(linear.weight, 0) + set_tensor_constant(linear.bias, 1) + + +def init_final_linear(linear): + set_tensor_constant(linear.weight, 0) diff --git a/apps/protein_folding/helixfold_cpu/tools/protein_features.py b/apps/protein_folding/helixfold_cpu/tools/protein_features.py new file mode 100644 index 00000000..a68da518 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/protein_features.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains descriptions of various protein features.""" +import enum +from typing import Dict, Optional, Sequence, Tuple, Union +from alphafold_paddle.common import residue_constants +import tensorflow.compat.v1 as tf + +# Type aliases. 
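+# FeaturesMetadata maps a feature name to its (tf dtype, shape spec) pair;
+# shape specs may contain the placeholder strings defined below.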
+FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]] + + +class FeatureType(enum.Enum): + ZERO_DIM = 0 # Shape [x] + ONE_DIM = 1 # Shape [num_res, x] + TWO_DIM = 2 # Shape [num_res, num_res, x] + MSA = 3 # Shape [msa_length, num_res, x] + + +# Placeholder values that will be replaced with their true value at runtime. +NUM_RES = "num residues placeholder" +NUM_SEQ = "length msa placeholder" +NUM_TEMPLATES = "num templates placeholder" +# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders +# to be replaced with the number of residues and the number of sequences in the +# multiple sequence alignment, respectively. + + +FEATURES = { + #### Static features of a protein sequence #### + "aatype": (tf.float32, [NUM_RES, 21]), + "between_segment_residues": (tf.int64, [NUM_RES, 1]), + "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]), + "domain_name": (tf.string, [1]), + "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]), + "num_alignments": (tf.int64, [NUM_RES, 1]), + "residue_index": (tf.int64, [NUM_RES, 1]), + "seq_length": (tf.int64, [NUM_RES, 1]), + "sequence": (tf.string, [1]), + "all_atom_positions": (tf.float32, + [NUM_RES, residue_constants.atom_type_num, 3]), + "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]), + "resolution": (tf.float32, [1]), + "template_domain_names": (tf.string, [NUM_TEMPLATES]), + "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]), + "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]), + "template_all_atom_positions": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 + ]), + "template_all_atom_masks": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 + ]), +} + +FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()} +FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()} + + +def register_feature(name: str, + type_: tf.dtypes.DType, + shape_: Tuple[Union[str, int]]): + """Register extra features used in custom datasets.""" + FEATURES[name] = (type_, shape_) + FEATURE_TYPES[name] = type_ + FEATURE_SIZES[name] = shape_ + + +def shape(feature_name: str, + num_residues: int, + msa_length: int, + num_templates: Optional[int] = None, + features: Optional[FeaturesMetadata] = None): + """Get the shape for the given feature name. + + This is near identical to _get_tf_shape_no_placeholders() but with 2 + differences: + * This method does not calculate a single placeholder from the total number of + elements (eg given and size := 12, this won't deduce NUM_RES + must be 4) + * This method will work with tensors + + Args: + feature_name: String identifier for the feature. If the feature name ends + with "_unnormalized", this suffix is stripped off. + num_residues: The number of residues in the current domain - some elements + of the shape can be dynamic and will be replaced by this value. + msa_length: The number of sequences in the multiple sequence alignment, some + elements of the shape can be dynamic and will be replaced by this value. + If the number of alignments is unknown / not read, please pass None for + msa_length. + num_templates (optional): The number of templates in this tfexample. + features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES. + + Returns: + List of ints representation the tensor size. + + Raises: + ValueError: If a feature is requested but no concrete placeholder value is + given. 
+ """ + features = features or FEATURES + if feature_name.endswith("_unnormalized"): + feature_name = feature_name[:-13] + + unused_dtype, raw_sizes = features[feature_name] + replacements = {NUM_RES: num_residues, + NUM_SEQ: msa_length} + + if num_templates is not None: + replacements[NUM_TEMPLATES] = num_templates + + sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] + for dimension in sizes: + if isinstance(dimension, str): + raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( + feature_name, raw_sizes, replacements)) + return sizes diff --git a/apps/protein_folding/helixfold_cpu/tools/proteins_dataset.py b/apps/protein_folding/helixfold_cpu/tools/proteins_dataset.py new file mode 100644 index 00000000..990950d1 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/proteins_dataset.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets consisting of proteins.""" +from typing import Dict, Mapping, Optional, Sequence +from tools import protein_features +import numpy as np +import tensorflow.compat.v1 as tf + +TensorDict = Dict[str, tf.Tensor] + + +def parse_tfexample( + raw_data: bytes, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> Dict[str, tf.train.Feature]: + """Read a single TF Example proto and return a subset of its features. + + Args: + raw_data: A serialized tf.Example proto. + features: A dictionary of features, mapping string feature names to a tuple + (dtype, shape). This dictionary should be a subset of + protein_features.FEATURES (or the dictionary itself for all features). + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + feature_map = { + k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) + for k, v in features.items() + } + parsed_features = tf.io.parse_single_example(raw_data, feature_map) + reshaped_features = parse_reshape_logic(parsed_features, features, key=key) + + return reshaped_features + + +def _first(tensor: tf.Tensor) -> tf.Tensor: + """Returns the 1st element - the input can be a tensor or a scalar.""" + return tf.reshape(tensor, shape=(-1,))[0] + + +def parse_reshape_logic( + parsed_features: TensorDict, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> TensorDict: + """Transforms parsed serial features to the correct shape.""" + # Find out what is the number of sequences and the number of alignments. 
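+  # These counts substitute the NUM_RES / NUM_SEQ / NUM_TEMPLATES placeholders
+  # from protein_features when the tensors are reshaped below.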
+ num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) + + if "num_alignments" in parsed_features: + num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) + else: + num_msa = 0 + + if "template_domain_names" in parsed_features: + num_templates = tf.cast( + tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) + else: + num_templates = 0 + + if key is not None and "key" in features: + parsed_features["key"] = [key] # Expand dims from () to (1,). + + # Reshape the tensors according to the sequence length and num alignments. + for k, v in parsed_features.items(): + new_shape = protein_features.shape( + feature_name=k, + num_residues=num_residues, + msa_length=num_msa, + num_templates=num_templates, + features=features) + new_shape_size = tf.constant(1, dtype=tf.int32) + for dim in new_shape: + new_shape_size *= tf.cast(dim, tf.int32) + + assert_equal = tf.assert_equal( + tf.size(v), new_shape_size, + name="assert_%s_shape_correct" % k, + message="The size of feature %s (%s) could not be reshaped " + "into %s" % (k, tf.size(v), new_shape)) + if "template" not in k: + # Make sure the feature we are reshaping is not empty. + assert_non_empty = tf.assert_greater( + tf.size(v), 0, name="assert_%s_non_empty" % k, + message="The feature %s is not set in the tf.Example. Either do not " + "request the feature or use a tf.Example that has the " + "feature set." % k) + with tf.control_dependencies([assert_non_empty, assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + else: + with tf.control_dependencies([assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + + return parsed_features + + +def _make_features_metadata( + feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: + """Makes a feature name to type and shape mapping from a list of names.""" + # Make sure these features are always read. + required_features = ["aatype", "sequence", "seq_length"] + feature_names = list(set(feature_names) | set(required_features)) + + features_metadata = {name: protein_features.FEATURES[name] + for name in feature_names} + return features_metadata + + +def create_tensor_dict( + raw_data: bytes, + features: Sequence[str], + key: Optional[str] = None, + ) -> TensorDict: + """Creates a dictionary of tensor features. + + Args: + raw_data: A serialized tf.Example proto. + features: A list of strings of feature names to be returned in the dataset. + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + return parse_tfexample(raw_data, features_metadata, key) + + +def np_to_tensor_dict( + np_example: Mapping[str, np.ndarray], + features: Sequence[str], + ) -> TensorDict: + """Creates dict of tensors from a dict of NumPy arrays. + + Args: + np_example: A dict of NumPy feature arrays. + features: A list of strings of feature names to be returned in the dataset. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. 
+ """ + features_metadata = _make_features_metadata(features) + tensor_dict = {k: tf.constant(v) for k, v in np_example.items() + if k in features_metadata} + + # Ensures shapes are as expected. Needed for setting size of empty features + # e.g. when no template hits were found. + tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) + return tensor_dict diff --git a/apps/protein_folding/helixfold_cpu/tools/quat_affine.py b/apps/protein_folding/helixfold_cpu/tools/quat_affine.py new file mode 100644 index 00000000..4b487f41 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/quat_affine.py @@ -0,0 +1,558 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quaternion geometry modules. + +This introduces a representation of coordinate frames that is based around a +‘QuatAffine’ object. This object describes an array of coordinate frames. +It consists of vectors corresponding to the +origin of the frames as well as orientations which are stored in two +ways, as unit quaternions as well as a rotation matrices. +The rotation matrices are derived from the unit quaternions and the two are kept +in sync. +For an explanation of the relation between unit quaternions and rotations see +https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation + +This representation is used in the model for the backbone frames. + +One important thing to note here, is that while we update both representations +the jit compiler is going to ensure that only the parts that are +actually used are executed. +""" + +import paddle +import numpy as np +from typing import Tuple + + +QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr +QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii +QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj +QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk + +QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij +QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik +QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk + +QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir +QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr +QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr + +QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) +QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], + [ 0,-1, 0, 0], + [ 0, 0,-1, 0], + [ 0, 0, 0,-1]] + +QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0,-1, 0]] + +QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], + [ 0, 0, 0,-1], + [ 1, 0, 0, 0], + [ 0, 1, 0, 0]] + +QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0,-1, 0, 0], + [ 1, 0, 0, 0]] + +QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :] + + +def rot_to_quat(rot): + """Convert rotation matrix to quaternion. + + Note that this function calls self_adjoint_eig which is extremely expensive on + the GPU. 
If at all possible, this function should run on the CPU. + + Args: + rot: rotation matrix (see below for format). rotation matrix should be shape (..., 3, 3) + + Returns: + Quaternion as (..., 4) tensor. + """ + rot = [ [rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + # pylint: disable=bad-whitespace + k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,], + [ zy - yz, xx - yy - zz, xy + yx, xz + zx,], + [ xz - zx, xy + yx, yy - xx - zz, yz + zy,], + [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]] + + k = (1./3.) * paddle.stack([paddle.stack(x, axis=-1) for x in k], + axis=-2) + + # Get eigenvalues in non-decreasing order and associated. + _, qs = paddle.linalg.eigh(k) + return qs[..., -1] + +def quat_to_rot(normalized_quat): + """Convert a normalized quaternion to a rotation matrix. Quat (..., 4)""" + + mat = paddle.unsqueeze(normalized_quat, [-1, -3]) # normalized_quat[..., None, :, None] + rot_tensor = paddle.sum( + paddle.to_tensor(np.reshape(QUAT_TO_ROT, (4, 4, 9))) * + normalized_quat[..., :, None, None] * + mat, + axis=(-3, -2)) # (..., 4, 4, 9) -> (..., 9) + t_shape = rot_tensor.shape[:-1] + t_shape.extend([3, 3]) + rot = paddle.reshape(rot_tensor, t_shape) # Unstack. (..., 3, 3) + return rot + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + mat = paddle.unsqueeze(vec, [-1, -3]) # vec[..., None, :, None] + return paddle.sum( + paddle.to_tensor(QUAT_MULTIPLY_BY_VEC) * + quat[..., :, None, None] * + mat, + axis=(-3, -2)) + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + mat = paddle.unsqueeze(quat2, [-1, -3]) # quat2[..., None, :, None] + return paddle.sum( + paddle.to_tensor(QUAT_MULTIPLY) * + quat1[..., :, None, None] * + mat, + axis=(-3, -2)) + +def apply_rot_to_vec(rot, vec, unstack=False): + """Multiply rotation matrix by a vector. vec is a list. + Returns: a list of 3 tensors of the points + """ + if unstack: + x, y, z = [vec[..., i] for i in range(3)] + else: + x, y, z = vec + return [rot[..., 0, 0] * x + rot[..., 0, 1] * y + rot[..., 0, 2] * z, + rot[..., 1, 0] * x + rot[..., 1, 1] * y + rot[..., 1, 2] * z, + rot[..., 2, 0] * x + rot[..., 2, 1] * y + rot[..., 2, 2] * z] + +def apply_rot_to_vec_np(rot, vec, unstack=False): + """Multiply rotation matrix by a vector. vec is a list. + Returns: a list of 3 tensors of the points + """ + if unstack: + x, y, z = [vec[..., i] for i in range(3)] + else: + x, y, z = vec + return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z, + rot[1][0] * x + rot[1][1] * y + rot[1][2] * z, + rot[2][0] * x + rot[2][1] * y + rot[2][2] * z] + +def apply_inverse_rot_to_vec(rot, vec): + """Multiply the inverse of a rotation matrix by a vector. vec is a list. + Returns: a list of 3 tensors of the points + """ + # Inverse rotation is just transpose + x, y, z = vec + return [rot[..., 0, 0] * x + rot[..., 1, 0] * y + rot[..., 2, 0] * z, + rot[..., 0, 1] * x + rot[..., 1, 1] * y + rot[..., 2, 1] * z, + rot[..., 0, 2] * x + rot[..., 1, 2] * y + rot[..., 2, 2] * z] + + +class QuatAffine(object): + """Affine transformation represented by quaternion and vector.""" + + def __init__(self, + quaternion: paddle.Tensor, + translation: paddle.Tensor, + rotation=None, normalize=True): + """Initialize from quaternion and translation. + + Args: + quaternion: Rotation represented by a quaternion, to be applied + before translation. Must be a unit quaternion unless normalize==True. 
+ shape (batch, N_res, 4) + translation: Translation represented as a vector. (batch, N_res, 3) + rotation: Same rotation as the quaternion, represented as a (batch, N_res, 3, 3) + tensor. If None, rotation will be calculated from the quaternion. + normalize: If True, l2 normalize the quaternion on input. + """ + + if quaternion is not None: + assert quaternion.shape[-1] == 4 + + if normalize and quaternion is not None: + q_length = paddle.norm(quaternion, axis=-1) + quaternion = quaternion / q_length[..., None] + + if rotation is None: + rotation = quat_to_rot(quaternion) + + self.quaternion = quaternion + self.rotation = rotation + self.translation = translation + + assert rotation.shape[-1] == 3 and rotation.shape[-2] == 3 + assert translation.shape[-1] == 3 + + def to_tensor(self): + return paddle.concat([self.quaternion, self.translation], axis=-1) + + def stop_rot_gradient(self): + """ + stop the gradient of rotations + """ + quat = self.quaternion + if not quat is None: + quat = quat.detach() + return QuatAffine( + quaternion=quat, + translation=self.translation, + rotation=self.rotation.detach(), + normalize=False) + + def scale_translation(self, position_scale): + """Return a new quat affine with a different scale for translation.""" + + return QuatAffine(self.quaternion, + position_scale * self.translation, + rotation=self.rotation, normalize=False) + + @classmethod + def from_tensor(cls, tensor, normalize=False): + assert tensor.shape[-1] == 7 + quaternion = tensor[..., 0:4] + translation = tensor[..., 4:7] + return cls(quaternion, translation, normalize=normalize) + + def pre_compose(self, update): + """Return a new QuatAffine which applies the transformation update first. + + Args: + update: Length-6 vector. 3-vector of x, y, and z such that the quaternion + update is (1, x, y, z) and zero for the 3-vector is the identity + quaternion. 3-vector for translation concatenated. + + Returns: + New QuatAffine object. + """ + vector_quaternion_update = update[..., 0:3] + trans_update = [update[..., 3], update[..., 4], update[..., 5]] + + new_quaternion = (self.quaternion + + quat_multiply_by_vec(self.quaternion, + vector_quaternion_update)) + + trans_update = apply_rot_to_vec(self.rotation, trans_update) + trans_update = paddle.stack(trans_update, axis=-1) + new_translation = self.translation + trans_update + + return QuatAffine(new_quaternion, new_translation) + + def apply_to_point(self, point, extra_dims=0): + """Apply affine to a point. + + Args: + point: List of 3 tensors to apply affine. + each with shape [batch_size, num_residues, num_head*num_point_qk] + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. + """ + rotation = self.rotation # [batch_size, num_residues, 3, 3] + translation = self.translation # [batch_size, num_residues, 3] + for _ in range(extra_dims): + translation = paddle.unsqueeze(translation, axis=-2) + rotation = paddle.unsqueeze(rotation, axis=-3) + + rot_point = apply_rot_to_vec(rotation, point) + return [rot_point[0] + translation[..., 0], + rot_point[1] + translation[..., 1], + rot_point[2] + translation[..., 2]] + + def invert_point(self, transformed_point, extra_dims=0): + """Apply inverse of transformation to a point. 
+ + Args: + transformed_point: List of 3 tensors to apply affine + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. + """ + rotation = self.rotation + translation = self.translation + for _ in range(extra_dims): + translation = paddle.unsqueeze(translation, axis=-2) + rotation = paddle.unsqueeze(rotation, axis=-3) + + rot_point = [ + transformed_point[0] - translation[..., 0], + transformed_point[1] - translation[..., 1], + transformed_point[2] - translation[..., 2]] + + return apply_inverse_rot_to_vec(rotation, rot_point) + + def invert(self): + """Return a new quat affine of the invert transformation.""" + pass # TODO + + +######Paddle Implementation +def _multiply(a, b): + return paddle.stack([ + paddle.stack([ + a[..., 0, 0]*b[..., 0, 0] + a[..., 0, 1]*b[..., 1, 0] + a[..., 0, 2]*b[..., 2, 0], + a[..., 0, 0]*b[..., 0, 1] + a[..., 0, 1]*b[..., 1, 1] + a[..., 0, 2]*b[..., 2, 1], + a[..., 0, 0]*b[..., 0, 2] + a[..., 0, 1]*b[..., 1, 2] + a[..., 0, 2]*b[..., 2, 2]], axis=-1), + + paddle.stack([ + a[..., 1, 0]*b[..., 0, 0] + a[..., 1, 1]*b[..., 1, 0] + a[..., 1, 2]*b[..., 2, 0], + a[..., 1, 0]*b[..., 0, 1] + a[..., 1, 1]*b[..., 1, 1] + a[..., 1, 2]*b[..., 2, 1], + a[..., 1, 0]*b[..., 0, 2] + a[..., 1, 1]*b[..., 1, 2] + a[..., 1, 2]*b[..., 2, 2]], axis=-1), + + paddle.stack([ + a[..., 2, 0]*b[..., 0, 0] + a[..., 2, 1]*b[..., 1, 0] + a[..., 2, 2]*b[..., 2, 0], + a[..., 2, 0]*b[..., 0, 1] + a[..., 2, 1]*b[..., 1, 1] + a[..., 2, 2]*b[..., 2, 1], + a[..., 2, 0]*b[..., 0, 2] + a[..., 2, 1]*b[..., 1, 2] + a[..., 2, 2]*b[..., 2, 2]], axis=-1)], + axis=-2) + + +def make_canonical_transform( + n_xyz: paddle.Tensor, + ca_xyz: paddle.Tensor, + c_xyz: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Returns translation and rotation matrices to canonicalize residue atoms. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, n_res, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, n_res, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, n_res, 3] of carbon xyz coordinates. + + Returns: + A tuple (translation, rotation) where: + translation is an array of shape [batch, n_res, 3] defining the translation. + rotation is an array of shape [batch, n_res, 3, 3] defining the rotation. + After applying the translation and rotation to all atoms in a residue: + * All atoms will be shifted so that CA is at the origin, + * All atoms will be rotated so that C is at the x-axis, + * All atoms will be shifted so that N is in the xy plane. + """ + assert len(n_xyz.shape) == 3, n_xyz.shape + assert n_xyz.shape[-1] == 3, n_xyz.shape + assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, ( + n_xyz.shape, ca_xyz.shape, c_xyz.shape) + + # Place CA at the origin. + translation = -ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + # Place C on the x-axis. + c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] + # Rotate by angle c1 in the x-y plane (around the z-axis). 
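+    # The angle is chosen so that the rotation sends C's y-coordinate to zero,
+    # i.e. C ends up in the x-z plane before the second rotation.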
+ norm = paddle.sqrt(1e-20 + c_x ** 2 + c_y ** 2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + zeros = paddle.zeros_like(sin_c1) + ones = paddle.ones_like(sin_c1) + + c1_rot_matrix = paddle.stack([cos_c1, -sin_c1, zeros, + sin_c1, cos_c1, zeros, + zeros, zeros, ones], axis=-1) + c1_rot_matrix = c1_rot_matrix.reshape(sin_c1.shape + [3,3]) # edit by zjh@intel SMG 20220825 change [3, 3] -> (3, 3) + + # Rotate by angle c2 in the x-z plane (around the y-axis). + norm = paddle.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2) + sin_c2 = c_z / norm + cos_c2 = paddle.sqrt(c_x ** 2 + c_y ** 2) / norm + c2_rot_matrix = paddle.stack([cos_c2, zeros, sin_c2, + zeros, ones, zeros, + -sin_c2, zeros, cos_c2], axis=-1) + c2_rot_matrix = c2_rot_matrix.reshape(sin_c2.shape + [3,3]) # edit by zjh@intel SMG 20220825 change [3, 3] -> (3, 3) + + c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix) + n_xyz = paddle.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True), axis=-1) + + # Place N in the x-y plane. + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + # Rotate by angle alpha in the y-z plane (around the x-axis). + norm = paddle.sqrt(1e-20 + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + n_rot_matrix = paddle.stack([ones, zeros, zeros, + zeros, cos_n, -sin_n, + zeros, sin_n, cos_n], axis=-1) + n_rot_matrix = n_rot_matrix.reshape(sin_n.shape + [3,3]) # edit by zjh@intel SMG 20220825 change [3, 3] -> (3, 3) + # pylint: enable=bad-whitespace + + return (translation, _multiply(n_rot_matrix, c_rot_matrix)) + + +def make_transform_from_reference( + n_xyz: paddle.Tensor, + ca_xyz: paddle.Tensor, + c_xyz: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, n_res, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, n_res, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, n_res, 3] of carbon xyz coordinates. + + Returns: + A tuple (rotation, translation) where: + rotation is an array of shape [batch, n_res, 3, 3] defining the rotation. + translation is an array of shape [batch, n_res, 3] defining the translation. + After applying the translation and rotation to the reference backbone, + the coordinates will approximately equal to the input coordinates. + + The order of translation and rotation differs from make_canonical_transform + because the rotation from this function should be applied before the + translation, unlike make_canonical_transform. 
+ """ + translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz) + return paddle.transpose(rotation, (0, 1, 3, 2)), -translation + +#######Numpy Implementation +def _multiply_np(a, b): + return np.stack([ + np.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0], + a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1], + a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]), + + np.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0], + a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1], + a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]), + + np.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0], + a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1], + a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])]) + + +def make_canonical_transform_np( + n_xyz: np.ndarray, + ca_xyz: np.ndarray, + c_xyz: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Returns translation and rotation matrices to canonicalize residue atoms. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (translation, rotation) where: + translation is an array of shape [batch, 3] defining the translation. + rotation is an array of shape [batch, 3, 3] defining the rotation. + After applying the translation and rotation to all atoms in a residue: + * All atoms will be shifted so that CA is at the origin, + * All atoms will be rotated so that C is at the x-axis, + * All atoms will be shifted so that N is in the xy plane. + """ + assert len(n_xyz.shape) == 2, n_xyz.shape + assert n_xyz.shape[-1] == 3, n_xyz.shape + assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (n_xyz.shape, ca_xyz.shape, c_xyz.shape) + + # Place CA at the origin. + translation = -ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + # Place C on the x-axis. + c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)] + # Rotate by angle c1 in the x-y plane (around the z-axis). + sin_c1 = -c_y / np.sqrt(1e-20 + c_x**2 + c_y**2) + cos_c1 = c_x / np.sqrt(1e-20 + c_x**2 + c_y**2) + zeros = np.zeros_like(sin_c1) + ones = np.ones_like(sin_c1) + # pylint: disable=bad-whitespace + c1_rot_matrix = np.stack([np.array([cos_c1, -sin_c1, zeros]), + np.array([sin_c1, cos_c1, zeros]), + np.array([zeros, zeros, ones])]) + + # Rotate by angle c2 in the x-z plane (around the y-axis). + sin_c2 = c_z / np.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2) + cos_c2 = np.sqrt(c_x**2 + c_y**2) / np.sqrt( + 1e-20 + c_x**2 + c_y**2 + c_z**2) + c2_rot_matrix = np.stack([np.array([cos_c2, zeros, sin_c2]), + np.array([zeros, ones, zeros]), + np.array([-sin_c2, zeros, cos_c2])]) + + c_rot_matrix = _multiply_np(c2_rot_matrix, c1_rot_matrix) + n_xyz = np.stack(apply_rot_to_vec_np(c_rot_matrix, n_xyz, unstack=True)).T + + # Place N in the x-y plane. + _, n_y, n_z = [n_xyz[:, i] for i in range(3)] + # Rotate by angle alpha in the y-z plane (around the x-axis). 
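+    # This last rotation zeroes N's z-coordinate, leaving N in the x-y plane
+    # while C stays on the x-axis (the rotation is about the x-axis).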
+ sin_n = -n_z / np.sqrt(1e-20 + n_y**2 + n_z**2) + cos_n = n_y / np.sqrt(1e-20 + n_y**2 + n_z**2) + n_rot_matrix = np.stack([np.array([ones, zeros, zeros]), + np.array([zeros, cos_n, -sin_n]), + np.array([zeros, sin_n, cos_n])]) + + return (translation, np.transpose(_multiply_np(n_rot_matrix, c_rot_matrix), [2, 0, 1])) + + +def make_transform_from_reference_np( + n_xyz: np.ndarray, + ca_xyz: np.ndarray, + c_xyz: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (rotation, translation) where: + rotation is an array of shape [batch, 3, 3] defining the rotation. + translation is an array of shape [batch, 3] defining the translation. + After applying the translation and rotation to the reference backbone, + the coordinates will approximately equal to the input coordinates. + + The order of translation and rotation differs from make_canonical_transform + because the rotation from this function should be applied before the + translation, unlike make_canonical_transform. + """ + translation, rotation = make_canonical_transform_np(n_xyz, ca_xyz, c_xyz) + return np.transpose(rotation, (0, 2, 1)), -translation \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/tools/r3.py b/apps/protein_folding/helixfold_cpu/tools/r3.py new file mode 100644 index 00000000..167bbc43 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/r3.py @@ -0,0 +1,492 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformations for 3D coordinates. + +This Module contains objects for representing Vectors (Vecs), Rotation Matrices +(Rots) and proper Rigid transformation (Rigids). These are represented as +named tuples with arrays for each entry, for example a set of +[N, M] points would be represented as a Vecs object with arrays of shape [N, M] +for x, y and z. + +This is being done to improve readability by making it very clear what objects +are geometric objects rather than relying on comments and array shapes. +Another reason for this is to avoid using matrix +multiplication primitives like matmul or einsum, on modern accelerator hardware +these can end up on specialized cores such as tensor cores on GPU or the MXU on +cloud TPUs, this often involves lower computational precision which can be +problematic for coordinate geometry. 
Also these cores are typically optimized +for larger matrices than 3 dimensional, this code is written to avoid any +unintended use of these cores on both GPUs and TPUs. +""" +import paddle +import collections +from typing import List +from tools import quat_affine + +# Array of rigid 3D transformations, stored as array of rotations and +# array of translations. +Rigids = collections.namedtuple('Rigids', ['rot', 'trans']) + +class Vecs: + def __init__(self, *args): + + if len(args) == 1: + if type(args[0]) in [list, tuple] and len(args[0]) == 3: + self.translation = paddle.stack(args[0], axis=-1) + elif len(args[0]) == 1: + self.translation = args[0] + elif args[0].shape[-1]==3: + self.translation = args[0] + else: + raise ValueError('Invalid number of inputs') + elif len(args) == 3: + self.translation = paddle.stack(args, axis=-1) + else: + raise ValueError('Invalid number of inputs') + + def map(self, map_fn, *args): + result = [] + for i in range(3): + r = map_fn(self.translation[..., i], *args) + result.append(r) + + if result[0].shape[-1] == 1: + return Vecs(paddle.concat(result, axis=-1)) + else: + return Vecs(paddle.stack(result, axis=-1)) + + @property + def shape(self): + return self.translation.shape + + @property + def x(self): + return self.translation[..., 0] + + @property + def y(self): + return self.translation[..., 1] + + @property + def z(self): + return self.translation[..., 2] + + def __getitem__(self,index): + return Vecs(self.translation[index]) + def __str__(self): + return str(self.translation.shape) + def __repr__(self): + return str(self.translation.shape) + + def reshape(self,*argv): + return self.translation.reshape(*argv) + + +class Rots: + def __init__(self, *args): + if len(args) == 1: + args = args[0] + if len(args) == 9: + rots = paddle.stack(args, axis=-1) + self.rotation = rots.reshape(rots.shape[:-1] + [3, 3]) + else: + if args.shape[-1] == 3 and args.shape[-2] == 3: + self.rotation = args + elif args.shape[-1] == 9: + self.rotation = args.reshape(args.shape[:-1] + [3, 3]) + else: + raise ValueError('Invalid shape of input') + elif len(args) == 9: + rots = paddle.stack(args, axis=-1) + self.rotation = rots.reshape(rots.shape[:-1] + [3, 3]) + else: + raise ValueError('Invalid number of inputs') + + def map(self, map_fn, *args): + result_i = [] + for i in range(3): + result_j = [] + for j in range(3): + r = map_fn(self.rotation[..., i, j], *args) + result_j.append(r) + + if result_j[0].shape[-1] == 1: + result_i.append(paddle.concat(result_j, axis=-1)) + else: + result_i.append(paddle.stack(result_j, axis=-1)) + + return Rots(paddle.stack(result_i, axis=-2)) + + @property + def shape(self): + return self.rotation.shape + + @property + def xx(self): + return self.rotation[..., 0, 0] + + @property + def xy(self): + return self.rotation[..., 0, 1] + + @property + def xz(self): + return self.rotation[..., 0, 2] + + @property + def yx(self): + return self.rotation[..., 1, 0] + + @property + def yy(self): + return self.rotation[..., 1, 1] + + @property + def yz(self): + return self.rotation[..., 1, 2] + + @property + def zx(self): + return self.rotation[..., 2, 0] + + @property + def zy(self): + return self.rotation[..., 2, 1] + + @property + def zz(self): + return self.rotation[..., 2, 2] + + def __getitem__(self,index): + return Rots(self.rotation[index]) + def __str__(self): + return str(self.rotation.shape) + def __repr__(self): + return str(self.rotation.shape) + def reshape(self,*argv): + return self.rotation.reshape(*argv) + + +def squared_difference(x, 
y): + return paddle.square(x - y) + + +def invert_rigids(r: Rigids) -> Rigids: + """Computes group inverse of rigid transformations 'r'.""" + inv_rots = invert_rots(r.rot) + t = rots_mul_vecs(inv_rots, r.trans) + inv_trans = Vecs(-t.x, -t.y, -t.z) + return Rigids(inv_rots, inv_trans) + + +def invert_rots(m: Rots) -> Rots: + """Computes inverse of rotations 'm'.""" + return Rots(m.xx, m.yx, m.zx, + m.xy, m.yy, m.zy, + m.xz, m.yz, m.zz) + + +def rigids_from_3_points_vecs( + point_on_neg_x_axis: Vecs, + origin: Vecs, + point_on_xy_plane: Vecs, +) -> Rigids: + """Create Rigids from 3 points. + + Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points" + This creates a set of rigid transformations from 3 points by Gram Schmidt + orthogonalization. + + Args: + point_on_neg_x_axis: Vecs corresponding to points on the negative x axis + origin: Origin of resulting rigid transformations + point_on_xy_plane: Vecs corresponding to points in the xy plane + Returns: + Rigid transformations from global frame to local frames derived from + the input points. + """ + m = rots_from_two_vecs( + e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis), + e1_unnormalized=vecs_sub(point_on_xy_plane, origin)) + + return Rigids(rot=m, trans=origin) + + +def rigids_from_3_points( + point_on_neg_x_axis: paddle.Tensor, + origin: paddle.Tensor, + point_on_xy_plane: paddle.Tensor, + eps: float = 1e-8) -> Rigids: + """Create Rigids from 3 points. + + Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points" + This creates a set of rigid transformations from 3 points by Gram Schmidt + orthogonalization. + + Argss: + point_on_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates + point_on_xy_plane: [*, 3] coordinates + eps: small regularizer added to squared norm before taking square root. + Returns: + Rigids corresponding to transformations from global frame + to local frames derived from the input points. + """ + point_on_neg_x_axis = paddle.unbind(point_on_neg_x_axis, axis=-1) + origin = paddle.unbind(origin, axis=-1) + point_on_xy_plane = paddle.unbind(point_on_xy_plane, axis=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, point_on_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(point_on_xy_plane, origin)] + + norms = paddle.sqrt(paddle.square(e0[0]) + + paddle.square(e0[1]) + + paddle.square(e0[2]) + eps) + e0 = [c / norms for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + norms = paddle.sqrt(paddle.square(e1[0]) + + paddle.square(e1[1]) + + paddle.square(e1[2]) + eps) + e1 = [c / norms for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = paddle.stack([c for tup in zip(e0, e1, e2) for c in tup], axis=-1) + + return Rigids(Rots(rots), Vecs(origin)) + + +def rigids_from_list(l: List[paddle.Tensor]) -> Rigids: + """Converts flat list of arrays to rigid transformations.""" + assert len(l) == 12 + return Rigids(Rots(*(l[:9])), Vecs(*(l[9:]))) + + +def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids: + """Converts QuatAffine object to the corresponding Rigids object.""" + return Rigids(Rots(a.rotation), + Vecs(a.translation)) + + +def rigids_from_tensor4x4(m: paddle.Tensor) -> Rigids: + """Construct Rigids from an 4x4 array. + + Here the 4x4 is representing the transformation in homogeneous coordinates. 
+ + Argss: + m: [*, 4, 4] homogenous transformation tensor + Returns: + Rigids corresponding to transformations m + """ + assert m.shape[-1] == 4 + assert m.shape[-2] == 4 + return Rigids( + Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]), + Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3])) + + +def rigids_from_tensor_flat9(m: paddle.Tensor) -> Rigids: + """Flat9 encoding: first two columns of rotation matrix + translation.""" + assert m.shape[-1] == 9 + e0 = Vecs(m[..., 0], m[..., 1], m[..., 2]) + e1 = Vecs(m[..., 3], m[..., 4], m[..., 5]) + trans = Vecs(m[..., 6], m[..., 7], m[..., 8]) + return Rigids(rot=rots_from_two_vecs(e0, e1), + trans=trans) + + +def rigids_from_tensor_flat12( + m: paddle.Tensor # shape (..., 12) + ) -> Rigids: # shape (...) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + assert m.shape[-1] == 12 + return Rigids(Rots(m[..., :9]), Vecs(m[..., 9:])) + + +def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids: + """Group composition of Rigids 'a' and 'b'.""" + return Rigids( + rots_mul_rots(a.rot, b.rot), + vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans))) + + +def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids: + """Compose rigid transformations 'r' with rotations 'm'.""" + return Rigids(rots_mul_rots(r.rot, m), r.trans) + + +def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs: + """Apply rigid transforms 'r' to points 'v'.""" + return vecs_add(rots_mul_vecs(r.rot, v), r.trans) + + +def rigids_to_list(r: Rigids) -> List[paddle.Tensor]: + """Turn Rigids into flat list, inverse of 'rigids_from_list'.""" + return list(r.rot) + list(r.trans) + + +def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine: + """Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'.""" + return quat_affine.QuatAffine( + quaternion=None, + rotation=r.rot.rotation, + translation=r.trans.translation) + + +def rigids_to_tensor_flat9( + r: Rigids) -> paddle.Tensor: # shape (..., 9) + """Flat9 encoding: first two columns of rotation matrix + translation.""" + return paddle.stack( + [r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy] + + list(r.trans), axis=-1) + + +def rigids_to_tensor_flat12( + r: Rigids # shape (...) + ) -> paddle.Tensor: # shape (..., 12) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + + return paddle.stack([r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy, r.rot.xz, r.rot.yz, r.rot.zz] + + [r.trans.x, r.trans.y, r.trans.z], axis=-1) + + +def rots_from_tensor3x3( + m: paddle.Tensor, # shape (..., 3, 3) + ) -> Rots: # shape (...) + """Convert rotations represented as (3, 3) array to Rots.""" + assert m.shape[-1] == 3 + assert m.shape[-2] == 3 + return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]) + + +def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots: + """Create rotation matrices from unnormalized vectors for the x and y-axes. + + This creates a rotation matrix from two vectors using Gram-Schmidt + orthogonalization. + + Args: + e0_unnormalized: vectors lying along x-axis of resulting rotation + e1_unnormalized: vectors lying in xy-plane of resulting rotation + Returns: + Rotations resulting from Gram-Schmidt procedure. + """ + # Normalize the unit vector for the x-axis, e0. + e0 = vecs_robust_normalize(e0_unnormalized) + + # make e1 perpendicular to e0. 
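    # Gram-Schmidt step: subtract from e1 its projection onto e0 so that e0 and
    # e1 are orthonormal before the cross product defines e2 below.
    # Illustrative example (hypothetical inputs):
    #   >>> ex = Vecs(paddle.to_tensor([1.]), paddle.to_tensor([0.]), paddle.to_tensor([0.]))
    #   >>> ey = Vecs(paddle.to_tensor([1.]), paddle.to_tensor([1.]), paddle.to_tensor([0.]))
    #   >>> r = rots_from_two_vecs(ex, ey)
    #   >>> float(r.yy)   # orthogonalized e1 becomes the unit y-axis, so yy ~= 1.0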
+ c = vecs_dot_vecs(e1_unnormalized, e0) + e1 = Vecs(e1_unnormalized.x - c * e0.x, + e1_unnormalized.y - c * e0.y, + e1_unnormalized.z - c * e0.z) + e1 = vecs_robust_normalize(e1) + + # Compute e2 as cross product of e0 and e1. + e2 = vecs_cross_vecs(e0, e1) + + return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + +def rots_mul_rots(a: Rots, b: Rots) -> Rots: + """Composition of rotations 'a' and 'b'.""" + c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx)) + c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy)) + c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz)) + return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + +def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs: + """Apply rotations 'm' to vectors 'v'.""" + return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z, + m.yx * v.x + m.yy * v.y + m.yz * v.z, + m.zx * v.x + m.zy * v.y + m.zz * v.z) + + +def vecs_add(v1: Vecs, v2: Vecs) -> Vecs: + """Add two vectors 'v1' and 'v2'.""" + return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z) + + +def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> paddle.Tensor: + """Dot product of vectors 'v1' and 'v2'.""" + return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z + + +def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs: + """Cross product of vectors 'v1' and 'v2'.""" + return Vecs(v1.y * v2.z - v1.z * v2.y, + v1.z * v2.x - v1.x * v2.z, + v1.x * v2.y - v1.y * v2.x) + + +def vecs_from_tensor(x: paddle.Tensor # shape (..., 3) + ) -> Vecs: # shape (...) + """Converts from tensor of shape (3,) to Vecs.""" + num_components = x.shape[-1] + assert num_components == 3 + return Vecs(x[..., 0], x[..., 1], x[..., 2]) + + +def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs: + """Normalizes vectors 'v'. + + Argss: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + normalized vectors + """ + norms = vecs_robust_norm(v, epsilon) + return Vecs(v.x / norms, v.y / norms, v.z / norms) + + +def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> paddle.Tensor: + """Computes norm of vectors 'v'. + + Args: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + norm of 'v' + """ + return paddle.sqrt(paddle.square(v.x) + + paddle.square(v.y) + + paddle.square(v.z) + epsilon) + + +def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs: + """Computes v1 - v2.""" + return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z) + + +def vecs_squared_distance(v1: Vecs, v2: Vecs) -> paddle.Tensor: + """Computes squared euclidean difference between 'v1' and 'v2'.""" + return (squared_difference(v1.x, v2.x) + + squared_difference(v1.y, v2.y) + + squared_difference(v1.z, v2.z)) + + +def vecs_to_tensor(v: Vecs # shape (...) + ) -> paddle.Tensor: # shape(..., 3) + """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'.""" + return paddle.stack([v.x, v.y, v.z], axis=-1) diff --git a/apps/protein_folding/helixfold_cpu/tools/residue_constants.py b/apps/protein_folding/helixfold_cpu/tools/residue_constants.py new file mode 100644 index 00000000..8fabb710 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/residue_constants.py @@ -0,0 +1,897 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import List, Mapping, Tuple + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. 
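# For a pi-periodic chi angle, chi and chi + pi describe the same physical
# conformation, so ground-truth targets are usually generated for both the
# original and the shifted angle. A minimal sketch of how these tables are
# indexed (chi_pi_periodic is defined just below; restype_order further down):
#   >>> aa = restype_order['D']      # ASP
#   >>> chi_angles_mask[aa]          # [1.0, 1.0, 0.0, 0.0] -> ASP has two chis
#   >>> chi_pi_periodic[aa]          # [0.0, 1.0, 0.0, 0.0] -> chi2 is pi-periodic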
+chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 
0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, 
-1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', + 'CH2', 'N', 'NE1', 'O'], + 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', + 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': {'OD1': 'OD2'}, + 'GLU': {'OE1': 'OE2'}, + 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'}, + 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple( + 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]]]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. 
+ residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. + """ + stereo_chemical_props_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt' + ) + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle(atom1, atom2, atom3, + float(angle_degree) / 180. * np.pi, + float(stddev_degree) / 180. * np.pi)) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, + residue_virtual_bonds, + residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
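# Every residue type is padded to the same 37 atom slots ("atom37"), so
# per-atom data can be stored densely as arrays of shape [num_res, 37, ...];
# atom_order below maps a PDB atom name to its slot index. Minimal sketch:
#   >>> atom_order['CA']   # 1 -> the alpha carbon always occupies slot 1
#   >>> atom_order['CB']   # 3
#   >>> atom_type_num      # 37 slots per residue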
+atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. 
+ + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError('The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError(f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). 
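  # The resulting STANDARD_ATOM_MASK[restype, atom_type] is 1 only for atoms
  # that exist in that residue type; the extra UNK row stays all zeros.
  # Illustrative lookups (hypothetical usage):
  #   >>> STANDARD_ATOM_MASK[restype_order['G'], atom_order['CA']]   # 1
  #   >>> STANDARD_ATOM_MASK[restype_order['G'], atom_order['CB']]   # 0, glycine has no CB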
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1]*(4-len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
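  # The result is a homogeneous transform: the first three columns hold the
  # orthonormal frame (ex, ey, ex x ey), the fourth column holds the
  # translation, and the last row is [0, 0, 0, 1], so frames can be chained by
  # plain 4x4 matrix multiplication. Layout (illustrative):
  #   [[ex_x, ey_x, ez_x, t_x],
  #    [ex_y, ey_y, ez_y, t_y],
  #    [ex_z, ey_z, ez_z, t_z],
  #    [   0,    0,    0,   1]]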
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, + atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: np.array(pos) for name, _, pos + in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1., 0., 0.]), + translation=atom_positions['N']) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C']) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2]) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = 
chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1., 0., 0.]), + translation=axis_end_atom_position) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } diff --git a/apps/protein_folding/helixfold_cpu/tools/shape_helpers.py b/apps/protein_folding/helixfold_cpu/tools/shape_helpers.py new file mode 100644 index 00000000..75598183 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/shape_helpers.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for dealing with shapes of TensorFlow tensors.""" +import tensorflow.compat.v1 as tf + + +def shape_list(x): + """Return list of dimensions of a tensor, statically where possible. + + Like `x.shape.as_list()` but with tensors instead of `None`s. + + Args: + x: A tensor. + Returns: + A list with length equal to the rank of the tensor. The n-th element of the + list is an integer when that dimension is statically known otherwise it is + the n-th element of `tf.shape(x)`. + """ + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.get_shape().dims is None: + return tf.shape(x) + + static = x.get_shape().as_list() + shape = tf.shape(x) + + ret = [] + for i in range(len(static)): + dim = static[i] + if dim is None: + dim = shape[i] + ret.append(dim) + return ret + diff --git a/apps/protein_folding/helixfold_cpu/tools/shape_placeholders.py b/apps/protein_folding/helixfold_cpu/tools/shape_placeholders.py new file mode 100644 index 00000000..12347e00 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/shape_placeholders.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Placeholder values for run-time varying dimension sizes.""" + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' diff --git a/apps/protein_folding/helixfold_cpu/tools/utils.py b/apps/protein_folding/helixfold_cpu/tools/utils.py new file mode 100644 index 00000000..1c63547c --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/tools/utils.py @@ -0,0 +1,304 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils.""" + +import os +import numbers +import functools +import collections +import paddle +import numpy as np +from typing import Any, Mapping + +import protein +import confidence + + +def jax_params_to_paddle(params): + """ + Rule 1: alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_pair_stack/* ==> + '...template_pair_stack.0.*' + '...template_pair_stack.1.*' + ... + + Rule 2: alphafold/alphafold_iteration/evoformer/extra_msa_stack/* ==> + 'alphafold_iteration.evoformer.extra_msa_stack.0.*', + 'alphafold_iteration.evoformer.extra_msa_stack.1.*', + ... 
+ + Rule 3: alphafold/alphafold_iteration/evoformer/evoformer_iteration/* ==> + 'alphafold.alphafold_iteration.evoformer.evoformer_iteration.0.*', + 'alphafold.alphafold_iteration.evoformer.evoformer_iteration.1.*', + ... + + Rule 4: */__layer_stack_no_state/* ==> '*.*' + + Rule 5: *//weights ==> '*.weight' + + Rule 6: *//bias ==> '*.bias' + + Rule 7: *//scale ==> '*.weight' + + Rule 8: *//offset ==> '*.bias' + """ + rule_1_prefix = 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_pair_stack/' + rule_2_prefix = 'alphafold/alphafold_iteration/evoformer/extra_msa_stack/' + rule_3_prefix = 'alphafold/alphafold_iteration/evoformer/evoformer_iteration/' + rule_4_prefix = '__layer_stack_no_state/' + + pd_params = dict() + + def _parse_stack_or_iteration(rule_prefix, k): + n = params[k].shape[0] + suffix = k[len(rule_prefix):] + + # rule 4 + if suffix.startswith(rule_4_prefix): + suffix = suffix[len(rule_4_prefix):] + + # rule 5 + suffix = suffix.replace('//weights', '.weight') + # rule 6 + suffix = suffix.replace('//bias', '.bias') + # rule 7 + suffix = suffix.replace('//scale', '.weight') + # rule 8 + suffix = suffix.replace('//offset', '.bias') + + suffix = suffix.replace('//', '.') + suffix = suffix.replace('/', '.') + + prefix = rule_prefix.replace('/', '.') + for i in range(n): + k_ = f'{prefix}{i}.{suffix}' + pd_params[k_] = np.copy(params[k][i]) + + for k in params.keys(): + if k.startswith(rule_1_prefix): + _parse_stack_or_iteration(rule_1_prefix, k) + + elif k.startswith(rule_2_prefix): + _parse_stack_or_iteration(rule_2_prefix, k) + + elif k.startswith(rule_3_prefix): + _parse_stack_or_iteration(rule_3_prefix, k) + + else: + k_ = k.replace('//weights', '.weight') + k_ = k_.replace('//scale', '.weight') + k_ = k_.replace('//offset', '.bias') + k_ = k_.replace('//', '.') + k_ = k_.replace('/', '.') + pd_params[k_] = np.copy(params[k]) + + return pd_params + + +def slice_batch(batch, i): + b = {k: v[i] for k, v in batch.items()} + return b + +def add_batch_dim(batch): + b = {k: v[None,] for k, v in batch.items()} + return b + +def map_to_tensor(batch, add_batch=False): + if add_batch: + batch = add_batch_dim(batch) + + b = {k: paddle.to_tensor(v) for k, v in batch.items()} + return b + + +def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): + if drop_mask_channel: + mask = mask[:, 0] + + mask_shape = mask.shape + value_shape = value.shape + assert len(mask_shape) == len(value_shape) + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(mask_shape))) + + assert isinstance(axis, collections.Iterable), \ + 'axis needs to be either an iterable, integer or "None"' + + broadcast_factor = 1. 
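    # broadcast_factor compensates for mask dimensions of size 1 that are
    # broadcast against a larger value dimension, so the denominator still
    # counts every value element the mask effectively covers.
    # Illustrative use (hypothetical tensors):
    #   >>> m = paddle.to_tensor([[1., 1., 0.]])
    #   >>> v = paddle.to_tensor([[2., 4., 100.]])
    #   >>> float(mask_mean(m, v))   # ~3.0; the masked-out 100. is ignored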
+ for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (paddle.sum(mask * value, axis=axis) / + (paddle.sum(mask, axis=axis) * broadcast_factor + eps)) + + +def batched_gather(params, indices, axis=0, batch_dims=0): + # Implement gather with batching, like tensorflow: + # https://www.tensorflow.org/api_docs/python/tf/gather#batching + # print(params.shape, indices.shape, axis) + p, i = params, indices + rank = len(p.shape) + axis = (rank + axis) % rank + # The stride of axis + stride = p.shape[batch_dims + axis] + + if batch_dims == 0 and len(i.shape) == 1: + return paddle.gather(p, i, axis=axis) + + elif batch_dims == 0: + flat_i = i.reshape([-1]) + gathered = paddle.gather(p, flat_i, axis=axis) + shape = p.shape[:axis] + i.shape + if axis < rank - 1: + shape += params.shape[axis + 1:] + return gathered.reshape(shape) + + b = batch_dims + a = axis + assert p.shape[:b] == i.shape[:b] + bn = np.prod(p.shape[:b]) + + # Shift batch dimensions right to bundle with axis + if a > 0: + perm = list(range(rank)) + perm = perm[b:(b + a)] + perm[:b] + perm[(b + a):] + p = p.transpose(perm) + + # Merge params' batch+axis + p = p.reshape(p.shape[:a] + [-1] + p.shape[(b + a + 1):]) + + # indices = [Batch..., Index...] + # Expand the index values across batch elements + strides = paddle.arange(bn).unsqueeze(-1) * stride + i = i.reshape([bn, -1]) + flat_i = paddle.flatten(i + strides) + + # Do gather + gathered = paddle.gather(p, flat_i, axis=axis) + + # Unbundle batch and index dimensions + unbundled_shape = p.shape[:a] + indices.shape + p.shape[a + 1:] + gathered = gathered.reshape(unbundled_shape) + + # Shift batch dimensions back to the left + if a > 0: + perm = list(range(len(unbundled_shape))) + perm = perm[a:(a + b)] + perm[:a] + perm[(a + b):] + gathered = gathered.transpose(perm) + + return gathered + + +def subbatch(f, arg_idx, dim, bs, out_idx): + """ Converts a function to one that applies to subbatch of an input + dimension. + + Args: + f(Callable): original function. + arg_idx([int]): indices of the inputs to be subbatched. + dim([int]): index of the dimension to be subbatched. + bs(int): subbatch size. + out_idx(int): index of the output dimension that needs stacking + + Returns: + converted function. + """ + @functools.wraps(f) + def wrapper(*args, **kwargs): + assert len(arg_idx) == len(dim), f'Number of batching args and number of batching dims should match.' + + inps = [args[i] for i in arg_idx] + dim_width = [inp.shape[d] for inp, d in zip(inps, dim)] + assert len(set(dim_width)) == 1, f'Batch sizes should be kept equal.' 
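        # subbatch trades peak memory for extra kernel launches: the wrapped
        # function is applied to chunks of size `bs` along `dim`, and the
        # chunk outputs are concatenated along `out_idx`.
        # Illustrative usage (hypothetical shapes):
        #   >>> chunked_square = subbatch(paddle.square, arg_idx=[0], dim=[0], bs=4, out_idx=0)
        #   >>> chunked_square(paddle.ones([10, 8])).shape   # [10, 8], computed 4 rows at a time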
+ + inp_dim = {inp: d for inp, d in zip(inps, dim)} + + dim_width = dim_width[0] + if dim_width < bs: + return f(*args, **kwargs) + + outs = [] + for slice_at in np.arange(0, dim_width, bs): + _args = [] + for i, inp in enumerate(args): + if i in arg_idx: + inp = inp.slice([inp_dim[inp]], [slice_at], [slice_at + bs]) + _args.append(inp) + outs.append(f(*_args, **kwargs)) + + return paddle.concat(outs, out_idx) + + return wrapper + + +def get_confidence_metrics( + prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: + """Post processes prediction_result to get confidence metrics.""" + + confidence_metrics = {} + confidence_metrics['plddt'] = confidence.compute_plddt( + prediction_result['predicted_lddt']['logits']) + + if 'predicted_aligned_error' in prediction_result: + confidence_metrics.update(confidence.compute_predicted_aligned_error( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks'])) + + confidence_metrics['ptm'] = confidence.predicted_tm_score( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks']) + + return confidence_metrics + + +def generate_unrelaxed_pdb(aatype, residue_index, model_output, pdb_path, + b_factors=None): + fold_output = model_output['structure_module'] + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + # NOTE: for single protein, chain_index is always 'A' (idx:0) + prot = protein.Protein( + aatype=aatype, + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=residue_index + 1, + chain_index=np.zeros(aatype.shape), + b_factors=b_factors) + + with open(pdb_path, 'w') as f: + f.write(protein.to_pdb(prot)) + + return prot + + +def set_tensor_constant(tensor, constant): + tensor.set_value(paddle.full_like(tensor, constant)) + + +def init_gate_linear(linear): + set_tensor_constant(linear.weight, 0) + set_tensor_constant(linear.bias, 1) + + +def init_final_linear(linear): + set_tensor_constant(linear.weight, 0) diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_attention.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_attention.py new file mode 100644 index 00000000..627f2cf5 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_attention.py @@ -0,0 +1,114 @@ +import pdb +from layers.basics import Attention +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser + + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, required=True, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['template']['attention'] +gc = cfg['model']['global_config'] +q_mat = pd.ones([1, 583696, 1, 128]) +m_mat = pd.ones([1, 583696, 4, 64]) +bias = pd.ones([1, 1, 1, 1, 4]) +n_warm = 3 +n_iter = 13 +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/attention' if is_dynamic_input else 'static_params/attention' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + + +print('# [INFO] build and save static graph of Attention') +model = Attention(c, gc,128,64,128) +model.eval() +if not os.path.isfile(f_topo): + if 
is_dynamic_input: + len_dim = None + else: + len_dim = 583696 # None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 1, 128]), + InputSpec(shape=[1, len_dim, 4, 64]), + InputSpec(shape=[1, 1, 1, 1, 4]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in range(n_iter): + t0 = time.time() + _ = model(q_mat, m_mat, bias) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +q_mat1 = np.ones([1, 583696, 1, 128], dtype='float32') +m_mat1 = np.ones([1, 583696, 4, 64], dtype='float32') +bias1 = np.ones([1, 1, 1, 1, 4], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['q_data', 'm_data', 'bias'] +inputl_q = predictor.get_input_handle('q_data') +inputl_m = predictor.get_input_handle('m_data') +inputl_b = predictor.get_input_handle('bias') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_q.reshape(q_mat1.shape) + inputl_m.reshape(m_mat1.shape) + inputl_b.reshape(bias1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_2'] +outputl = predictor.get_output_handle('tmp_2') + +# run +dts = 0. +for i in range(n_iter): + t0 = time.time() + inputl_q.copy_from_cpu(q_mat1) + inputl_m.copy_from_cpu(m_mat1) + inputl_b.copy_from_cpu(bias1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddings.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddings.py new file mode 100644 index 00000000..94c432e3 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddings.py @@ -0,0 +1,221 @@ +import pdb +from layers.backbones import Embeddings +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser +import warnings + +parser = Parser('[pd.infer] UT of pdpd.embeddings') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/embeddings' if is_dynamic_input else 'static_params/embeddings' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': TARGET_FEAT_DIM, + 'msa_feat': MSA_FEAT_DIM, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': 
c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 206 +batch = { + 'target_feat': pd.ones([1, len_dim, 22]), + 'msa_feat': pd.ones([1, 508, len_dim, 49]), # 508 -> 512 + 'seq_mask': pd.ones([1, len_dim]), + 'aatype': pd.ones([1, len_dim], dtype='int32'), + 'prev_pos': pd.ones([1, len_dim, 37, 3]), + 'prev_msa_first_row': pd.ones([1, len_dim, 256]), + 'prev_pair': pd.ones([1, len_dim, len_dim, 128]), + 'residue_index': pd.ones([1, len_dim]), + 'template_mask': pd.ones([1, 4]), + 'template_aatype': pd.ones([1, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': pd.ones([1, 4, len_dim]), + 'template_pseudo_beta': pd.ones([1, 4, len_dim, 3]), + 'template_all_atom_positions': pd.ones([1, 4, len_dim, 37, 3]), + 'template_all_atom_masks': pd.ones([1, 4, len_dim, 37]), + 'extra_msa': pd.ones([1, 5120, len_dim]), + 'extra_has_deletion': pd.ones([1, 5120, len_dim]), + 'extra_deletion_value': pd.ones([1, 5120, len_dim]), + # 'extra_msa_mask': pd.ones([1, 5120, len_dim]), + # 'torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'alt_torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'torsion_angles_mask': pd.ones([1, 4, len_dim, 7]) +} + + +print('# [INFO] build and save static graph of Attention') +model = Embeddings(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 22],name='target_feat'), + InputSpec(shape=[1, 508, len_dim, 49],name='msa_feat'), + InputSpec(shape=[1, len_dim],name='seq_mask'), + InputSpec(shape=[1, len_dim],name='aatype', dtype='int32'), + InputSpec(shape=[1, len_dim],name='residue_index'), + InputSpec(shape=[1, 4],name='template_mask'), + InputSpec(shape=[1, 4, len_dim],name='template_aatype', dtype="int32"), + InputSpec(shape=[1, 4, len_dim],name='template_pseudo_beta_mask'), + InputSpec(shape=[1, 4, len_dim, 3],name='template_pseudo_beta'), + InputSpec(shape=[1, 4, len_dim, 37, 3],name='template_all_atom_positions'), + InputSpec(shape=[1, 4, len_dim, 37],name='template_all_atom_masks'), + InputSpec(shape=[1, 5120, len_dim],name='extra_msa'), + InputSpec(shape=[1, 5120, len_dim],name='extra_has_deletion'), + InputSpec(shape=[1, 5120, len_dim],name='extra_deletion_value'), + InputSpec(shape=[1, len_dim, 37, 3],name='prev_pos'), + InputSpec(shape=[1, len_dim, 256],name='prev_msa_first_row'), + InputSpec(shape=[1, len_dim, len_dim, 128],name='prev_pair'), + # InputSpec(shape=[1, 4, len_dim, 7, 2]), + # InputSpec(shape=[1, 4, len_dim, 7, 2]), + # InputSpec(shape=[1, 4, len_dim, 7]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +warnings.filterwarnings('ignore', 'DAP communication') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(**batch) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt +warnings.resetwarnings() + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +target_feat_1 = np.ones([1, len_dim, 22], dtype='float32') +msa_feat_1 = np.ones([1, 508, len_dim, 49], dtype='float32') +seq_mask_1 = np.ones([1, len_dim], dtype='float32') +aatype_1 = np.ones([1, len_dim], dtype='int32') +residue_index_1 = np.ones([1, len_dim], dtype='float32') +template_mask_1 = np.ones([1, 4], dtype='float32') +template_aatype_1 = np.ones([1, 4, len_dim], dtype="int32") +template_pseudo_beta_mask_1 = np.ones([1, 4, len_dim], dtype='float32') +template_pseudo_beta_1 = np.ones([1, 4, len_dim, 3], dtype='float32') +template_all_atom_positions_1 = np.ones([1, 4, len_dim, 37, 3], dtype='float32') +template_all_atom_masks_1 = np.ones([1, 4, len_dim, 37], dtype='float32') +prev_pos_1 = np.ones([1, len_dim, 37, 3], dtype='float32') +prev_msa_first_row_1 = np.ones([1, len_dim, 256], dtype='float32') +prev_pair_1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +extra_msa_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_has_deletion_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_deletion_value_1 = np.ones([1, 5120, len_dim], dtype='float32') +# extra_msa_mask_1 = np.ones([1, 5120, len_dim], dtype='float32') +# msa_mask_1 = np.ones([1, 508, len_dim], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() +inputl_ta = predictor.get_input_handle('target_feat') +inputl_msf = predictor.get_input_handle('msa_feat') +inputl_se = predictor.get_input_handle('seq_mask') +inputl_aa = predictor.get_input_handle('aatype') +inputl_prpo = predictor.get_input_handle('prev_pos') +inputl_prms = predictor.get_input_handle('prev_msa_first_row') +inputl_prpa = predictor.get_input_handle('prev_pair') +inputl_re = predictor.get_input_handle('residue_index') +inputl_tema = predictor.get_input_handle('template_mask') +inputl_teaa = predictor.get_input_handle('template_aatype') +inputl_tepm = predictor.get_input_handle('template_pseudo_beta_mask') +inputl_tepb = predictor.get_input_handle('template_pseudo_beta') +inputl_teap = predictor.get_input_handle('template_all_atom_positions') +inputl_team = predictor.get_input_handle('template_all_atom_masks') +inputl_exms = predictor.get_input_handle('extra_msa') +inputl_exha = predictor.get_input_handle('extra_has_deletion') +inputl_exde = predictor.get_input_handle('extra_deletion_value') +# inputl_exmm = predictor.get_input_handle('extra_msa_mask') +# inputl_msm = predictor.get_input_handle('msa_mask') + +# # 变形输入轴 +# if is_dynamic_input: +# print('# [INFO] re-organize dynamic axes') +# inputl_q.reshape(q_mat1.shape) +# inputl_m.reshape(m_mat1.shape) +# inputl_b.reshape(bias1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() +print(output_names) +outputls = {k:predictor.get_output_handle(k) for k in output_names} + +# run +dts = 0. 
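+# Note on the loop below: each timed iteration includes the copy_from_cpu
+# calls that feed every input handle plus a copy_to_cpu of every named
+# output (used only to print shapes), so the reported static-graph time
+# covers host data transfer as well as predictor.run().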
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_ta.copy_from_cpu(target_feat_1) + inputl_msf.copy_from_cpu(msa_feat_1) + inputl_se.copy_from_cpu(seq_mask_1) + inputl_aa.copy_from_cpu(aatype_1) + inputl_re.copy_from_cpu(residue_index_1) + inputl_tema.copy_from_cpu(template_mask_1) + inputl_teaa.copy_from_cpu(template_aatype_1) + inputl_tepm.copy_from_cpu(template_pseudo_beta_mask_1) + inputl_tepb.copy_from_cpu(template_pseudo_beta_1) + inputl_teap.copy_from_cpu(template_all_atom_positions_1) + inputl_team.copy_from_cpu(template_all_atom_masks_1) + inputl_prpo.copy_from_cpu(prev_pos_1) + inputl_prms.copy_from_cpu(prev_msa_first_row_1) + inputl_prpa.copy_from_cpu(prev_pair_1) + inputl_exms.copy_from_cpu(extra_msa_1) + inputl_exha.copy_from_cpu(extra_has_deletion_1) + inputl_exde.copy_from_cpu(extra_deletion_value_1) + # inputl_exmm.copy_from_cpu(extra_msa_mask_1) + # inputl_msm.copy_from_cpu(msa_mask_1) + + predictor.run() + # msa_activations_raw_1 = outputls[output_names[0]].copy_to_cpu() # (1, 508, len, 256) + # pair_activations = outputls[output_names[1]].copy_to_cpu() # (1, len, len, 128) + # mask_2d = outputls[output_names[2]].copy_to_cpu() # (1, len, len) + output_shapes = {k:outputls[k].copy_to_cpu().shape for k in output_names} + print(output_shapes) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +# print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddingsandevoformer.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddingsandevoformer.py new file mode 100644 index 00000000..70634aa7 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_embeddingsandevoformer.py @@ -0,0 +1,220 @@ +import pdb +from layers.subnets import EmbeddingsAndEvoformer +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/embeddingsandevoformer' if is_dynamic_input else 'static_params/embeddingsandevoformer' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': TARGET_FEAT_DIM, + 'msa_feat': MSA_FEAT_DIM, + 'extra_msa_channel': c.extra_msa_channel, + 'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +with open('/home/yangw/experiments/helix_fold/T1042/model_1_input.pkl', 'rb') as h: + sample = pickle.load(h) +len_dim = 40 +batch = { + 'target_feat': pd.ones([1, len_dim, 22]), + 'msa_feat': pd.ones([1, 508, len_dim, 49]), # 508 -> 512 + 'seq_mask': pd.ones([1, len_dim]), + 'aatype': pd.ones([1, len_dim]), + 'prev_pos': pd.ones([1, len_dim, 37, 3]), + 'prev_msa_first_row': pd.ones([1, 
len_dim, 256]), + 'prev_pair': pd.ones([1, len_dim, len_dim, 128]), + 'residue_index': pd.ones([1, len_dim]), + 'template_mask': pd.ones([1, 4]), + 'template_aatype': pd.ones([1, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': pd.ones([1, 4, len_dim]), + 'template_pseudo_beta': pd.ones([1, 4, len_dim, 3]), + 'template_all_atom_positions': pd.ones([1, 4, len_dim, 37, 3]), + 'template_all_atom_masks': pd.ones([1, 4, len_dim, 37]), + 'extra_msa': pd.ones([1, 5120, len_dim]), + 'extra_has_deletion': pd.ones([1, 5120, len_dim]), + 'extra_deletion_value': pd.ones([1, 5120, len_dim]), + 'extra_msa_mask': pd.ones([1, 5120, len_dim]), + 'msa_mask': pd.ones([1, 508, len_dim]), + # 'torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'alt_torsion_angles_sin_cos': pd.ones([1, 4, len_dim, 7, 2]), + # 'torsion_angles_mask': pd.ones([1, 4, len_dim, 7]) +} + + +print('# [INFO] build and save static graph of Attention') +model = EmbeddingsAndEvoformer(channel_num, c, gc) +model.eval() +with pd.no_grad(): + _ = model(**batch) +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 22],name='target_feat'), + InputSpec(shape=[1, 508, len_dim, 49],name='msa_feat'), + InputSpec(shape=[1, len_dim],name='seq_mask'), + InputSpec(shape=[1, len_dim],name='aatype'), + InputSpec(shape=[1, len_dim],name='residue_index'), + InputSpec(shape=[1, 4],name='template_mask'), + InputSpec(shape=[1, 4, len_dim],name='template_aatype', dtype="int32"), + InputSpec(shape=[1, 4, len_dim],name='template_pseudo_beta_mask'), + InputSpec(shape=[1, 4, len_dim, 3],name='template_pseudo_beta'), + InputSpec(shape=[1, 4, len_dim, 37, 3],name='template_all_atom_positions'), + InputSpec(shape=[1, 4, len_dim, 37],name='template_all_atom_masks'), + InputSpec(shape=[1, 5120, len_dim],name='extra_msa'), + InputSpec(shape=[1, 5120, len_dim],name='extra_has_deletion'), + InputSpec(shape=[1, 5120, len_dim],name='extra_deletion_value'), + InputSpec(shape=[1, 5120, len_dim],name='extra_msa_mask'), + InputSpec(shape=[1, 508, len_dim],name='msa_mask'), + InputSpec(shape=[1, len_dim, 37, 3],name='prev_pos'), + InputSpec(shape=[1, len_dim, 256],name='prev_msa_first_row'), + InputSpec(shape=[1, len_dim, len_dim, 128],name='prev_pair'), + # InputSpec(shape=[1, 4, len_dim, 7, 2]), + # InputSpec(shape=[1, 4, len_dim, 7, 2]), + # InputSpec(shape=[1, 4, len_dim, 7]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in range(n_iter): + t0 = time.time() + _ = model(**batch) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +target_feat_1 = np.ones([1, len_dim, 22], dtype='float32') +msa_feat_1 = np.ones([1, 508, len_dim, 49], dtype='float32') +seq_mask_1 = np.ones([1, len_dim], dtype='float32') +aatype_1 = np.ones([1, len_dim], dtype='int32') +prev_pos_1 = np.ones([1, len_dim, 37, 3], dtype='float32') +prev_msa_first_row_1 = np.ones([1, len_dim, 256], dtype='float32') +prev_pair_1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +residue_index_1 = np.ones([1, len_dim], dtype='float32') +template_mask_1 = np.ones([1, 4], dtype='float32') +template_aatype_1 = np.ones([1, 4, len_dim], dtype="int32") +template_pseudo_beta_mask_1 = np.ones([1, 4, len_dim], dtype='float32') +template_pseudo_beta_1 = np.ones([1, 4, len_dim, 3], dtype='float32') +template_all_atom_positions_1 = np.ones([1, 4, len_dim, 37, 3], dtype='float32') +template_all_atom_masks_1 = np.ones([1, 4, len_dim, 37], dtype='float32') +extra_msa_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_has_deletion_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_deletion_value_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_msa_mask_1 = np.ones([1, 5120, len_dim], dtype='float32') +msa_mask_1 = np.ones([1, 508, len_dim], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() +inputl_ta = predictor.get_input_handle('target_feat') +inputl_msf = predictor.get_input_handle('msa_feat') +inputl_se = predictor.get_input_handle('seq_mask') +inputl_aa = predictor.get_input_handle('aatype') +inputl_prpo = predictor.get_input_handle('prev_pos') +inputl_prms = predictor.get_input_handle('prev_msa_first_row') +inputl_prpa = predictor.get_input_handle('prev_pair') +inputl_re = predictor.get_input_handle('residue_index') +inputl_tema = predictor.get_input_handle('template_mask') +inputl_teaa = predictor.get_input_handle('template_aatype') +inputl_tepm = predictor.get_input_handle('template_pseudo_beta_mask') +inputl_tepb = predictor.get_input_handle('template_pseudo_beta') +inputl_teap = predictor.get_input_handle('template_all_atom_positions') +inputl_team = predictor.get_input_handle('template_all_atom_masks') +inputl_exms = predictor.get_input_handle('extra_msa') +inputl_exha = predictor.get_input_handle('extra_has_deletion') +inputl_exde = predictor.get_input_handle('extra_deletion_value') +inputl_exmm = predictor.get_input_handle('extra_msa_mask') +inputl_msm = predictor.get_input_handle('msa_mask') + +# # 变形输入轴 +# if is_dynamic_input: +# print('# [INFO] re-organize dynamic axes') +# inputl_q.reshape(q_mat1.shape) +# inputl_m.reshape(m_mat1.shape) +# inputl_b.reshape(bias1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() +outputl = predictor.get_output_handle('tmp_2') + +# run +dts = 0. 
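+# Note on the loop below: only the 'tmp_2' output handle is read back, and
+# the timing printouts at the bottom of this script are commented out, so
+# this test mainly checks that the exported EmbeddingsAndEvoformer graph
+# loads and runs end to end.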
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_ta.copy_from_cpu(target_feat_1) + inputl_msf.copy_from_cpu(msa_feat_1) + inputl_se.copy_from_cpu(seq_mask_1) + inputl_aa.copy_from_cpu(aatype_1) + inputl_prpo.copy_from_cpu(prev_pos_1) + inputl_prms.copy_from_cpu(prev_msa_first_row_1) + inputl_prpa.copy_from_cpu(prev_pair_1) + inputl_re.copy_from_cpu(residue_index_1) + inputl_tema.copy_from_cpu(template_mask_1) + inputl_teaa.copy_from_cpu(template_aatype_1) + inputl_tepm.copy_from_cpu(template_pseudo_beta_mask_1) + inputl_tepb.copy_from_cpu(template_pseudo_beta_1) + inputl_teap.copy_from_cpu(template_all_atom_positions_1) + inputl_team.copy_from_cpu(template_all_atom_masks_1) + inputl_exms.copy_from_cpu(extra_msa_1) + inputl_exha.copy_from_cpu(extra_has_deletion_1) + inputl_exde.copy_from_cpu(extra_deletion_value_1) + inputl_exmm.copy_from_cpu(extra_msa_mask_1) + inputl_msm.copy_from_cpu(msa_mask_1) + + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +# if not ignore_eval: +# print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +# print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +# print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_evoformeriteration.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_evoformeriteration.py new file mode 100644 index 00000000..545e49e3 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_evoformeriteration.py @@ -0,0 +1,121 @@ +import pdb +from layers.backbones import EvoformerIteration +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser +from tqdm import tqdm + + +parser = Parser('[pd.infer] UT of helixfold.evoformeriteration') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer'] +gc = cfg['model']['global_config'] +len_dim = 1024 +msa_act = pd.ones([1, 512, len_dim, 256]) +msa_mask = pd.ones([1, 512, len_dim]) +pair_act = pd.ones([1, len_dim, len_dim, 128]) +pair_mask = pd.ones([1, len_dim, len_dim]) +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +module_prefix = 'evoformeriteration' +prefix_weights = 'dynamic_params/evoformeriteration' if is_dynamic_input else 'static_params/evoformeriteration' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' +channel_num = {'msa_channel':256, 'pair_channel':128} + +print('# [INFO] build and save static graph of Evoformeriteration') +model = EvoformerIteration(channel_num, c, gc, is_extra_msa=False) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 512, len_dim, 256]), + InputSpec(shape=[1, len_dim, len_dim, 128]), + InputSpec(shape=[1, 512, len_dim]), + InputSpec(shape=[1, len_dim, len_dim]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. 
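+ # Note: this dynamic-graph loop feeds the module's outputs back in as its
+ # inputs; msa_act and pair_act are overwritten on every iteration, so each
+ # pass runs on the previous iteration's activations.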
+ with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + msa_act, pair_act = model(msa_act, pair_act, msa_mask, pair_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, 512, len_dim, 256], dtype='float32') +msa_mask1 = np.ones([1, 512, len_dim], dtype='float32') +pair_act1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +pair_mask1 = np.ones([1, len_dim, len_dim], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() +inputl_sa = predictor.get_input_handle('msa_act') +inputl_sm = predictor.get_input_handle('msa_mask') +inputl_pa = predictor.get_input_handle('pair_act') +inputl_pm = predictor.get_input_handle('pair_mask') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_sa.reshape(msa_act1.shape) + inputl_sm.reshape(msa_mask1.shape) + inputl_pa.reshape(pair_act1.shape) + inputl_pm.reshape(pair_mask1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_2'] +outputl = predictor.get_output_handle('tmp_2') + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_sa.copy_from_cpu(msa_act1) + inputl_sm.copy_from_cpu(msa_mask1) + inputl_pa.copy_from_cpu(pair_act1) + inputl_pm.copy_from_cpu(pair_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_extramsa.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_extramsa.py new file mode 100644 index 00000000..4be81d24 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_extramsa.py @@ -0,0 +1,152 @@ +import pdb +from layers.backbones import ExtraMsa +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +import pickle +from argparse import ArgumentParser as Parser +import warnings + +model_prefix = 'extramsa' +parser = Parser('[pd.infer] UT of pdpd.{}'.format(model_prefix)) +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/{}'.format(model_prefix) if is_dynamic_input else 'static_params/{}'.format(model_prefix) +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': TARGET_FEAT_DIM, + 'msa_feat': MSA_FEAT_DIM, + 'extra_msa_channel': c.extra_msa_channel, + 
'msa_channel': c.msa_channel, + 'pair_channel': c.pair_channel, + 'seq_channel': c.seq_channel + } +### create sample input +len_dim = 10 +batch = { + 'extra_msa': pd.ones([1, 5120, len_dim]), + 'extra_has_deletion': pd.ones([1, 5120, len_dim]), + 'extra_deletion_value': pd.ones([1, 5120, len_dim]), + 'extra_msa_mask': pd.ones([1, 5120, len_dim]), + 'pair_activations': pd.ones([1, len_dim, len_dim, 128]), + 'mask_2d': pd.ones([1, len_dim, len_dim]) +} + + +print('# [INFO] build and save static graph of {}'.format(model_prefix)) +model = ExtraMsa(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 5120, len_dim],name='extra_msa'), + InputSpec(shape=[1, 5120, len_dim],name='extra_has_deletion'), + InputSpec(shape=[1, 5120, len_dim],name='extra_deletion_value'), + InputSpec(shape=[1, 5120, len_dim],name='extra_msa_mask'), + InputSpec(shape=[1, len_dim, len_dim, 128],name='pair_activations'), + InputSpec(shape=[1, len_dim, len_dim],name='mask_2d') + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +warnings.filterwarnings('ignore', 'DAP communication') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(**batch) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt +warnings.resetwarnings() + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +extra_msa_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_has_deletion_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_deletion_value_1 = np.ones([1, 5120, len_dim], dtype='float32') +extra_msa_mask_1 = np.ones([1, 5120, len_dim], dtype='float32') +pair_activations_1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +mask_2d_1 = np.ones([1, len_dim, len_dim], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() +inputl_exms = predictor.get_input_handle('extra_msa') +inputl_exha = predictor.get_input_handle('extra_has_deletion') +inputl_exde = predictor.get_input_handle('extra_deletion_value') +inputl_exmm = predictor.get_input_handle('extra_msa_mask') +inputl_pact = predictor.get_input_handle('pair_activations') +inputl_ms2d = predictor.get_input_handle('mask_2d') + +# # 变形输入轴 +# if is_dynamic_input: +# print('# [INFO] re-organize dynamic axes') +# inputl_q.reshape(q_mat1.shape) +# inputl_m.reshape(m_mat1.shape) +# inputl_b.reshape(bias1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() +print(output_names) +outputls = {k:predictor.get_output_handle(k) for k in output_names} + +# run +dts = 0. 
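+# Note: this export check runs ExtraMsa on a deliberately tiny sequence
+# (len_dim = 10) while keeping the full 5120-row extra-MSA stack, so the
+# numbers below exercise the exported graph rather than a realistic workload.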
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_exms.copy_from_cpu(extra_msa_1) + inputl_exha.copy_from_cpu(extra_has_deletion_1) + inputl_exde.copy_from_cpu(extra_deletion_value_1) + inputl_exmm.copy_from_cpu(extra_msa_mask_1) + inputl_pact.copy_from_cpu(pair_activations_1) + inputl_ms2d.copy_from_cpu(mask_2d_1) + + predictor.run() + output_shapes = {k:outputls[k].copy_to_cpu().shape for k in output_names} + print(output_shapes) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +# print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_globalattention.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_globalattention.py new file mode 100644 index 00000000..6cfc8357 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_globalattention.py @@ -0,0 +1,115 @@ +import pdb +from layers.basics import GlobalAttention +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.GolbalAttention') +parser.add_argument('--n_cpus', type=int, required=True, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['msa_column_attention'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 13 +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/globalattention' if is_dynamic_input else 'static_params/globalattention' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' +### setup model +model = GlobalAttention(c,gc,64,64,64) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + else: + len_dim = 764 # None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 5120, 64]), + InputSpec(shape=[1, len_dim, 5120, 64]), + InputSpec(shape=[1, len_dim, 5120, 1]) + ]) + save(net, prefix_weights) +else: + len_dim = 764 +### create sample input +msa_act = pd.ones([1, 764, 5120, 64]) +msa_mask = pd.ones([1, 764, 5120, 1]) +# bias = pd.ones([1, 764, 1, 1, 5120]) # not used + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_act, msa_act, msa_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, len_dim, 5120, 64], dtype='float32') +msa_mask1 = np.ones([1, len_dim, 5120, 1], dtype='float32') +# bias1 = np.ones([1, 1, 1, 1, 4], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['q_data', 'm_data', 'q_mask'] +inputl_q = predictor.get_input_handle('q_data') +inputl_m = predictor.get_input_handle('m_data') +inputl_b = predictor.get_input_handle('q_mask') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_q.reshape(msa_act1.shape) + inputl_m.reshape(msa_act1.shape) + inputl_b.reshape(msa_mask1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_10'] +outputl = predictor.get_output_handle('tmp_10') + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_q.copy_from_cpu(msa_act1) + inputl_m.copy_from_cpu(msa_act1) + inputl_b.copy_from_cpu(msa_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_maskedmsahead.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_maskedmsahead.py new file mode 100644 index 00000000..28240c90 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_maskedmsahead.py @@ -0,0 +1,115 @@ +import pdb +from layers.head import MaskedMsaHead +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +from tqdm import tqdm +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.maskedmsahead') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['heads']['masked_msa'] +gc = cfg['model']['global_config'] +### create sample input +# masked_msa +len_dim = 206 +msa_representation = pd.ones([1, 508, len_dim, 256]) + +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/maskedmashead' if is_dynamic_input else 'static_params/maskedmashead' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': TARGET_FEAT_DIM, + 'msa_feat': MSA_FEAT_DIM, + 'msa_channel': 256, + } + +print('# [INFO] build and save static graph of Attention') +model = MaskedMsaHead(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 508, len_dim, 256]), + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_representation) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. 
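+# Caution on the static run loop further down: it still copies q_mat1 /
+# m_mat1 / bias1 into inputl_q / inputl_m / inputl_b, none of which are
+# defined in this script; the only prepared input here is
+# msa_representation1 with its handle inputl_msa, which appears to be what
+# the loop was meant to feed.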
+pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_representation1 = np.ones([1, 508, len_dim, 256], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['q_data', 'm_data', 'bias'] +inputl_msa = predictor.get_input_handle('msa_representation') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_q.reshape(q_mat1.shape) + inputl_m.reshape(m_mat1.shape) + inputl_b.reshape(bias1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_2'] +outputl = predictor.get_output_handle('tmp_2') + +# run +dts = 0. +for i in range(n_iter): + t0 = time.time() + inputl_q.copy_from_cpu(q_mat1) + inputl_m.copy_from_cpu(m_mat1) + inputl_b.copy_from_cpu(bias1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnattention.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnattention.py new file mode 100644 index 00000000..ed5d2b49 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnattention.py @@ -0,0 +1,112 @@ +import pdb +from layers.basics import MSAColumnAttention +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser +from tqdm import tqdm + + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, required=True, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['msa_column_attention'] +gc = cfg['model']['global_config'] +n_warm = 3 +n_iter = 10+n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/msacolumnattention' if is_dynamic_input else 'static_params/msacolumnattention' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +channel_num = {'msa_channel': 64} +print('# [INFO] build and save static graph of Attention') +model = MSAColumnAttention(channel_num, c, gc) +model.eval() +emb_size = 512 +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + else: + len_dim = emb_size # None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 764, 64]), + InputSpec(shape=[1, len_dim, 764]) + ]) + save(net, prefix_weights) +### create sample input +msa_act = pd.ones([1, emb_size, 764, 64]) +msa_mask = pd.ones([1, emb_size, 764]) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_act, msa_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. 
+pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, emb_size, 764, 64], dtype='float32') +msa_mask1 = np.ones([1, emb_size, 764], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['msa_act', 'msa_mask'] +inputl_a = predictor.get_input_handle('msa_act') +inputl_m = predictor.get_input_handle('msa_mask') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(msa_act1.shape) + inputl_m.reshape(msa_mask1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(msa_act1) + inputl_m.copy_from_cpu(msa_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) + # 23.197 sec/iter +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnglobalattention.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnglobalattention.py new file mode 100644 index 00000000..555e022f --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msacolumnglobalattention.py @@ -0,0 +1,119 @@ +import pdb +from layers.basics import MSAColumnGlobalAttention +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['msa_column_attention'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/msacolumnglobalattention' if is_dynamic_input else 'static_params/msacolumnglobalattention' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + } +### create sample input +len_dim = 764 +msa_act = pd.ones([1, len_dim, 5120, 64]) +msa_mask = pd.ones([1, len_dim,5120]) + +print('# [INFO] build and save static graph of Attention') +model = MSAColumnGlobalAttention(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 5120, 64]), + InputSpec(shape=[1, len_dim, 5120]), + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 
0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_act, msa_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, len_dim, 5120, 64], dtype='float32') +msa_mask1 = np.ones([1, len_dim,5120], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['msa_act', 'msa_mask] +inputl_a = predictor.get_input_handle('msa_act') +inputl_m = predictor.get_input_handle('msa_mask') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(msa_act1.shape) + inputl_m.reshape(msa_mask1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(msa_act1) + inputl_m.copy_from_cpu(msa_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_msarowattentionwithpairbias.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msarowattentionwithpairbias.py new file mode 100644 index 00000000..e8545c6e --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_msarowattentionwithpairbias.py @@ -0,0 +1,123 @@ +import pdb +from layers.basics import MSARowAttentionWithPairBias +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['msa_row_attention_with_pair_bias'] +gc = cfg['model']['global_config'] +len_dim = 765 +# [TODO] input dims are wrong @zjh intel SMG 20220810 +msa_act = pd.ones([1, 5120, len_dim, 64]) +# msa_act = pd.ones([1, len_dim, 256]) +msa_mask = pd.ones([1, 5120, len_dim]) +pair_act = pd.ones([1, len_dim, len_dim, 128]) +channel_num = { + 'extra_msa_channel': 64, + 'msa_channel': 64, + 'pair_channel': 128, + } +is_extra_msa = False + +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/msarowattentionwithpairbias' if is_dynamic_input else 'static_params/msarowattentionwithpairbias' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + + +print('# [INFO] build and save static graph of Attention') +model = MSARowAttentionWithPairBias(channel_num, c, gc, is_extra_msa) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = 
None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 5120, len_dim, 64]), + InputSpec(shape=[1, 5120, len_dim]), + InputSpec(shape=[1, len_dim, len_dim, 128]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_act, msa_mask, pair_act) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, 5120, len_dim, 64], dtype='float32') +msa_mask1 = np.ones([1, 5120, len_dim], dtype='float32') +pair_act1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['msa_act', 'msa_mask', 'pair_act'] +inputl_q = predictor.get_input_handle('msa_act') +inputl_m = predictor.get_input_handle('msa_mask') +inputl_b = predictor.get_input_handle('pair_act') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_q.reshape(msa_act1.shape) + inputl_m.reshape(msa_mask1.shape) + inputl_b.reshape(pair_act1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_2'] +outputl = predictor.get_output_handle('tmp_2') + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_q.copy_from_cpu(msa_act1) + inputl_m.copy_from_cpu(msa_mask1) + inputl_b.copy_from_cpu(pair_act1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_outerproductmean.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_outerproductmean.py new file mode 100644 index 00000000..11a3b228 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_outerproductmean.py @@ -0,0 +1,117 @@ +import pdb +from layers.basics import OuterProductMean +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.OuterProductMean') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['outer_product_mean'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/outerproductmean' if is_dynamic_input else 'static_params/outerproductmean' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'pair_channel': 128, #256 + 'msa_channel': 64, # 128 + } +### create sample input 
+len_dim = 206 +msa_act = pd.ones([1, 5120, len_dim, 64]) +msa_mask = pd.ones([1, 5120, len_dim]) +pair_act = pd.ones([1, len_dim, len_dim, 128]) + +print('# [INFO] build and save static graph of Attention') +model = OuterProductMean(channel_num, c, gc, False) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 5120, len_dim, 64]), + InputSpec(shape=[1, 5120, len_dim]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + pair_act = model(msa_act, msa_mask) + pair_act + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +msa_act1 = np.ones([1, 5120, len_dim, 64], dtype='float32') +msa_mask1 = np.ones([1, 5120, len_dim], dtype='float32') +pair_act1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') + +# 获取输入轴 +input_names = predictor.get_input_names() # ['act'] +inputl_a = predictor.get_input_handle('act') +inputl_m = predictor.get_input_handle('mask') + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(msa_act1.shape) + inputl_m.reshape(msa_mask1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. 
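+# Note on the loop below: the residual connection is applied outside the
+# exported graph; each iteration adds pair_act1 to the copied output,
+# mirroring the `model(msa_act, msa_mask) + pair_act` pattern used in the
+# dynamic-graph timing above.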
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(msa_act1) + inputl_m.copy_from_cpu(msa_mask1) + predictor.run() + output = outputl.copy_to_cpu() + pair_act1 + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_plddt.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_plddt.py new file mode 100644 index 00000000..74ad3940 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_plddt.py @@ -0,0 +1,111 @@ +from layers.head import PredictedLDDTHead +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +from tqdm import tqdm +import os +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + 'template_pair': 85, + } +cfg = model_config('model_1') +c = cfg['model']['heads']['predicted_lddt'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 1000 + n_warm +force_static_cvt = True +ignore_eval = False +is_dynamic_input = False +model_name = 'plddt' # 'plddt' +prefix_weights = 'dynamic_params/' + model_name if is_dynamic_input else 'static_params/' + model_name +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +### create sample input +len_dim = 206 +structure_module = pd.ones([1, 384]) + +print('# [INFO] build and save static graph of Attention') +model = PredictedLDDTHead(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo) or force_static_cvt: + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[InputSpec(shape=[1, 384])]) + save(net, prefix_weights) + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(structure_module) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# 创建输入样例 +structure_module1 = np.ones([1, 384], dtype="float32") + +# 获取输入轴 +input_names = predictor.get_input_names() +inputl_structure = predictor.get_input_handle('structure_module') + + +# 变形输入轴 +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_structure.reshape(structure_module1.shape) + +# 获取输出轴 +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. 
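+# Note on the loop below: the head is driven with a single [1, 384] vector
+# (len_dim = 206 is defined above but never used in the input shape) and
+# n_iter is 1000 + n_warm, so this run mostly measures per-call predictor
+# overhead on a minimal input.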
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_structure.copy_from_cpu(structure_module1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_singletemplateembedding.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_singletemplateembedding.py new file mode 100644 index 00000000..91890b16 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_singletemplateembedding.py @@ -0,0 +1,149 @@ +from layers.basics import SingleTemplateEmbedding +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +from tqdm import tqdm +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + 'template_pair': 85, + } +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['template'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +force_static_cvt = False +ignore_eval = False +is_dynamic_input = False +model_name = 'singletemplateembedding' +prefix_weights = 'dynamic_params/' + model_name if is_dynamic_input else 'static_params/' + model_name +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +### create sample input +len_dim = 206 +query_embedding = pd.ones([1, len_dim, len_dim, 128]) +template_aatype = pd.ones([1, len_dim], dtype='int32') +template_pseudo_beta_mask = pd.ones([1, len_dim]) +template_pseudo_beta = pd.ones([1, len_dim, 3]) +template_all_atom_positions = pd.ones([1, len_dim, 37, 3]) +template_all_atom_masks = pd.ones([1, len_dim, 37]) +mask_2d = pd.ones([1, len_dim, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = SingleTemplateEmbedding(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo) or force_static_cvt: + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim], dtype='int32'), + InputSpec(shape=[1, len_dim]), + InputSpec(shape=[1, len_dim, 3]), + InputSpec(shape=[1, len_dim, 37, 3]), + InputSpec(shape=[1, len_dim, 37]), + InputSpec(shape=[1, len_dim, len_dim]) + ]) + save(net, prefix_weights) + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs + +template_aatype1 = np.ones([1, len_dim], dtype="int32") # note: must be int32 +template_pseudo_beta_mask1 = np.ones([1, len_dim], dtype="float32") +template_pseudo_beta1 = np.ones([1, len_dim, 3], dtype="float32") +template_all_atom_positions1 = np.ones([1, len_dim, 37, 3], dtype="float32") +template_all_atom_masks1 = np.ones([1, len_dim, 37], dtype="float32") +mask_2d1 = np.ones([1, len_dim, len_dim], dtype="float32") + +# get input handles +input_names = predictor.get_input_names() +inputl_ta = predictor.get_input_handle('template_aatype') +inputl_tpbm = predictor.get_input_handle('template_pseudo_beta_mask') +inputl_tpb = predictor.get_input_handle('template_pseudo_beta') +inputl_taap = predictor.get_input_handle('template_all_atom_positions') +inputl_taam = predictor.get_input_handle('template_all_atom_masks') +inputl_m2d = predictor.get_input_handle('mask_2d') + +# reshape input handles for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_ta.reshape(template_aatype1.shape) + inputl_tpbm.reshape(template_pseudo_beta_mask1.shape) + inputl_tpb.reshape(template_pseudo_beta1.shape) + inputl_taap.reshape(template_all_atom_positions1.shape) + inputl_taam.reshape(template_all_atom_masks1.shape) + inputl_m2d.reshape(mask_2d1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_ta.copy_from_cpu(template_aatype1) + inputl_tpbm.copy_from_cpu(template_pseudo_beta_mask1) + inputl_tpb.copy_from_cpu(template_pseudo_beta1) + inputl_taap.copy_from_cpu(template_all_atom_positions1) + inputl_taam.copy_from_cpu(template_all_atom_masks1) + inputl_m2d.copy_from_cpu(mask_2d1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_structuremodule.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_structuremodule.py new file mode 100644 index 00000000..e6990d17 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_structuremodule.py @@ -0,0 +1,112 @@ +import pdb +from layers.head import StructureModule +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +from tqdm import tqdm +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.StructureModule') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['heads']['structure_module'] +gc = cfg['model']['global_config'] +### create sample input +# structure_module +len_dim = 206 +msa_representation = pd.ones([1, 508, len_dim, 256]) + +n_warm = 3 +n_iter = 13 +ignore_eval = True +is_dynamic_input = False +prefix_weights = 'dynamic_params/structuremodule' if is_dynamic_input else 'static_params/structuremodule' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +channel_num = { +
'seq_channel': len_dim, + 'pair_channel': 512, + } + +print('# [INFO] build and save static graph of StructureModule') +model = StructureModule(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, 508, len_dim, 256]), + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(msa_representation) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample input +msa_representation1 = np.ones([1, 508, len_dim, 256], dtype='float32') + +# get input handle +input_names = predictor.get_input_names() # ['msa_representation'] +inputl_msa = predictor.get_input_handle('msa_representation') + +# reshape input handle for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_msa.reshape(msa_representation1.shape) + +# get output handle +output_names = predictor.get_output_names() +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in range(n_iter): + t0 = time.time() + inputl_msa.copy_from_cpu(msa_representation1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding.py new file mode 100644 index 00000000..80e30839 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding.py @@ -0,0 +1,167 @@ +import pdb +from layers.embeddings import TemplateEmbedding +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +from tqdm import tqdm +import os +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + 'template_pair': 88, + } +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['template'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +force_static_cvt = True +ignore_eval = False +is_dynamic_input = False +model_name = 'tmp' #'templateembedding' +prefix_weights = 'dynamic_params/' + model_name if is_dynamic_input else 'static_params/' + model_name +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +### create 
sample input +len_dim = 206 +query_embedding = pd.ones([1, len_dim, len_dim, 128]) +template_mask = pd.ones([1, 4]) +template_aatype = pd.ones([1, 4, len_dim], dtype='int32') +template_pseudo_beta_mask = pd.ones([1, 4, len_dim]) +template_pseudo_beta = pd.ones([1, 4, len_dim, 3]) +template_all_atom_positions = pd.ones([1, 4, len_dim, 37, 3]) +template_all_atom_masks = pd.ones([1, 4, len_dim, 37]) +mask_2d = pd.ones([1, 508, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = TemplateEmbedding(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo) or force_static_cvt: + if is_dynamic_input: + len_dim = None + print('# [INFO] create static graph') + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, len_dim, 128], dtype='float32'), + InputSpec(shape=[1, 4], dtype='float32'), + InputSpec(shape=[1, 4, len_dim], dtype='int32'), + InputSpec(shape=[1, 4, len_dim], dtype='float32'), + InputSpec(shape=[1, 4, len_dim, 3], dtype='float32'), + InputSpec(shape=[1, 4, len_dim, 37, 3], dtype='float32'), + InputSpec(shape=[1, 4, len_dim, 37], dtype='float32'), + InputSpec(shape=[1, 508, len_dim], dtype='float32'), + ]) + save(net, prefix_weights) + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(query_embedding, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs (shapes follow the InputSpec above) +print('# [INFO] create numpy samples') +query_embedding1 = np.ones([1, len_dim, len_dim, 128], dtype="float32") +template_mask1 = np.ones([1, 4], dtype="float32") +template_aatype1 = np.ones([1, 4, len_dim], dtype="int32") +template_pseudo_beta_mask1 = np.ones([1, 4, len_dim], dtype="float32") +template_pseudo_beta1 = np.ones([1, 4, len_dim, 3], dtype="float32") +template_all_atom_positions1 = np.ones([1, 4, len_dim, 37, 3], dtype="float32") +template_all_atom_masks1 = np.ones([1, 4, len_dim, 37], dtype="float32") +mask_2d1 = np.ones([1, 508, len_dim], dtype="float32") + +# get input handles +print('# [INFO] get input handles') +input_names = predictor.get_input_names() +inputl_qe = predictor.get_input_handle('query_embedding') +inputl_tm = predictor.get_input_handle('template_mask') +inputl_ta = predictor.get_input_handle('template_aatype') +inputl_tpbm = predictor.get_input_handle('template_pseudo_beta_mask') +inputl_tpb = predictor.get_input_handle('template_pseudo_beta') +inputl_taap = predictor.get_input_handle('template_all_atom_positions') +inputl_taam = predictor.get_input_handle('template_all_atom_masks') +inputl_m2d = predictor.get_input_handle('mask_2d') + +# reshape input handles for dynamic axes +print('# [INFO] reshaping axes') +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_qe.reshape(query_embedding1.shape) + inputl_tm.reshape(template_mask1.shape) + inputl_ta.reshape(template_aatype1.shape) + inputl_tpbm.reshape(template_pseudo_beta_mask1.shape) + inputl_tpb.reshape(template_pseudo_beta1.shape) + inputl_taap.reshape(template_all_atom_positions1.shape) + inputl_taam.reshape(template_all_atom_masks1.shape) + inputl_m2d.reshape(mask_2d1.shape) + +# get output handle +print('# [INFO] get output handles') +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0.
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_qe.copy_from_cpu(query_embedding1) + inputl_tm.copy_from_cpu(template_mask1) + inputl_ta.copy_from_cpu(template_aatype1) + inputl_tpbm.copy_from_cpu(template_pseudo_beta_mask1) + inputl_tpb.copy_from_cpu(template_pseudo_beta1) + inputl_taap.copy_from_cpu(template_all_atom_positions1) + inputl_taam.copy_from_cpu(template_all_atom_masks1) + inputl_m2d.copy_from_cpu(mask_2d1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding_copy.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding_copy.py new file mode 100644 index 00000000..4df23b13 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templateembedding_copy.py @@ -0,0 +1,147 @@ +from layers.embeddings import TemplateEmbedding +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +from tqdm import tqdm +import os +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + 'template_pair': 85, + } +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['template'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +force_static_cvt = True +ignore_eval = False +is_dynamic_input = False +model_name = 'templateembedding' +prefix_weights = 'static_params/' + model_name if is_dynamic_input else 'static_params/' + model_name +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +### create sample input +len_dim = 206 +query_embedding = pd.ones([1, len_dim, len_dim, 128]) +template_mask = pd.ones([1, 4]) +template_aatype = pd.ones([1, len_dim], dtype='int32') +template_pseudo_beta_mask = pd.ones([1, len_dim]) +template_pseudo_beta = pd.ones([1, len_dim, 3]) +template_all_atom_positions = pd.ones([1, len_dim, 37, 3]) +template_all_atom_masks = pd.ones([1, len_dim, 37]) +mask_2d = pd.ones([1, len_dim, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = TemplateEmbedding(channel_num, c, gc) +model.eval() + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(query_embedding, + template_mask, + template_aatype, + template_pseudo_beta_mask, + template_pseudo_beta, + template_all_atom_positions, + template_all_atom_masks, + mask_2d) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs (shapes follow the dynamic-graph tensors above) +query_embedding1 = np.ones([1, len_dim, len_dim, 128], dtype="float32") +template_mask1 = np.ones([1, 4], dtype="float32") +template_aatype1 = np.ones([1, len_dim], dtype="int32") +template_pseudo_beta_mask1 = np.ones([1, len_dim], dtype="float32") +template_pseudo_beta1 = np.ones([1, len_dim, 3], dtype="float32") +template_all_atom_positions1 = np.ones([1, len_dim, 37, 3], dtype="float32") +template_all_atom_masks1 = np.ones([1, len_dim, 37], dtype="float32") +mask_2d1 = np.ones([1, len_dim, len_dim], dtype="float32") + +# get input handles +input_names = predictor.get_input_names() +inputl_qe = predictor.get_input_handle('query_embedding') +inputl_tm = predictor.get_input_handle('template_mask') +inputl_ta = predictor.get_input_handle('template_aatype') +inputl_tpbm = predictor.get_input_handle('template_pseudo_beta_mask') +inputl_tpb = predictor.get_input_handle('template_pseudo_beta') +inputl_taap = predictor.get_input_handle('template_all_atom_positions') +inputl_taam = predictor.get_input_handle('template_all_atom_masks') +inputl_m2d = predictor.get_input_handle('mask_2d') + +# reshape input handles for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_qe.reshape(query_embedding1.shape) + inputl_tm.reshape(template_mask1.shape) + inputl_ta.reshape(template_aatype1.shape) + inputl_tpbm.reshape(template_pseudo_beta_mask1.shape) + inputl_tpb.reshape(template_pseudo_beta1.shape) + inputl_taap.reshape(template_all_atom_positions1.shape) + inputl_taam.reshape(template_all_atom_masks1.shape) + inputl_m2d.reshape(mask_2d1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0.
+for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_qe.copy_from_cpu(query_embedding1) + inputl_tm.copy_from_cpu(template_mask1) + inputl_ta.copy_from_cpu(template_aatype1) + inputl_tpbm.copy_from_cpu(template_pseudo_beta_mask1) + inputl_tpb.copy_from_cpu(template_pseudo_beta1) + inputl_taap.copy_from_cpu(template_all_atom_positions1) + inputl_taam.copy_from_cpu(template_all_atom_masks1) + inputl_m2d.copy_from_cpu(mask_2d1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_templatepair.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templatepair.py new file mode 100644 index 00000000..68df0330 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_templatepair.py @@ -0,0 +1,117 @@ +from layers.embeddings import TemplatePair +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser +from tqdm import tqdm + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + 'template_pair': 88, + } +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['template']['template_pair_stack'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/templatepair' if is_dynamic_input else 'static_params/templatepair' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +### create sample input +len_dim = 206 +pair_act = pd.ones([1, len_dim, len_dim, 64]) +pair_mask = pd.ones([1, len_dim, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = TemplatePair(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, len_dim, 64]), + InputSpec(shape=[1, len_dim, len_dim]) + ]) + save(net, prefix_weights) + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(pair_act, pair_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. 
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs +pair_act1 = np.ones([1, len_dim, len_dim, 64], dtype='float32') +pair_mask1 = np.ones([1, len_dim, len_dim], dtype='float32') + +# get input handles +input_names = predictor.get_input_names() # ['pair_act', 'pair_mask'] +inputl_a = predictor.get_input_handle('pair_act') +inputl_b = predictor.get_input_handle('pair_mask') + +# reshape input handles for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(pair_act1.shape) + inputl_b.reshape(pair_mask1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(pair_act1) + inputl_b.copy_from_cpu(pair_mask1) + + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_transition.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_transition.py new file mode 100644 index 00000000..6996e091 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_transition.py @@ -0,0 +1,112 @@ +import pdb +from layers.basics import Transition +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['msa_transition'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 100 + n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/transition' if is_dynamic_input else 'static_params/transition' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + } +### create sample input +len_dim = 764 +msa_act = pd.ones([1, len_dim, 256]) + +print('# [INFO] build and save static graph of Attention') +model = Transition(channel_num, c, gc, False, 'msa_transition') +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, 256]) + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in range(n_iter): + t0 = time.time() + _ = model(msa_act) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf.
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample input +msa_act1 = np.ones([1, len_dim, 256], dtype='float32') + +# get input handle +input_names = predictor.get_input_names() # ['act'] +inputl_a = predictor.get_input_handle('act') + +# reshape input handle for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(msa_act1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in range(n_iter): + t0 = time.time() + inputl_a.copy_from_cpu(msa_act1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_triangleattention.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_triangleattention.py new file mode 100644 index 00000000..0df00f9a --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_triangleattention.py @@ -0,0 +1,119 @@ +import pdb +from layers.basics import TriangleAttention +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['triangle_attention_starting_node'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/triangleattention' if is_dynamic_input else 'static_params/triangleattention' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'target_feat': 22, + 'msa_feat': 49, + 'extra_msa_channel': 64, + 'msa_channel': 256, + 'pair_channel': 128, + 'seq_channel': 384, + } +### create sample input +len_dim = 206 +pair_act = pd.ones([1, len_dim, len_dim, 128]) +pair_mask = pd.ones([1, len_dim, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = TriangleAttention(channel_num, c, gc, 'triangle_attention') +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, len_dim, 128]), + InputSpec(shape=[1, len_dim, len_dim]), + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(pair_act, pair_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0. +pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf.
or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs +pair_act1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +pair_mask1 = np.ones([1, len_dim, len_dim], dtype='float32') + +# get input handles
input_names = predictor.get_input_names() # ['pair_act', 'pair_mask'] +inputl_a = predictor.get_input_handle('pair_act') +inputl_m = predictor.get_input_handle('pair_mask') + +# reshape input handles for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(pair_act1.shape) + inputl_m.reshape(pair_mask1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(pair_act1) + inputl_m.copy_from_cpu(pair_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/unit_tests/ut_trianglemultiplication.py b/apps/protein_folding/helixfold_cpu/unit_tests/ut_trianglemultiplication.py new file mode 100644 index 00000000..23fa546e --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/unit_tests/ut_trianglemultiplication.py @@ -0,0 +1,114 @@ +import pdb +from layers.basics import TriangleMultiplication +from config import model_config +import paddle as pd +from paddle.jit import to_static, save +from paddle.static import InputSpec +from paddle import inference as pdinfer +import time +import os +import numpy as np +from tqdm import tqdm +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] UT of pdpd.Attention') +parser.add_argument('--n_cpus', type=int, default=64, help='physical cores used during pd.infer') +args = parser.parse_args() +n_cpus = args.n_cpus + +cfg = model_config('model_1') +c = cfg['model']['embeddings_and_evoformer']['evoformer']['triangle_multiplication_incoming'] +gc = cfg['model']['global_config'] + +n_warm = 3 +n_iter = 10 + n_warm +ignore_eval = False +is_dynamic_input = False +prefix_weights = 'dynamic_params/trianglemultiplication' if is_dynamic_input else 'static_params/trianglemultiplication' +f_topo = prefix_weights + '.pdmodel' +f_params = prefix_weights + '.pdiparams' + +TARGET_FEAT_DIM = 22 +MSA_FEAT_DIM = 49 +channel_num = { + 'pair_channel': 128, + } +### create sample input +len_dim = 764 +pair_act = pd.ones([1, len_dim, len_dim, 128]) +pair_mask = pd.ones([1, len_dim, len_dim]) + +print('# [INFO] build and save static graph of Attention') +model = TriangleMultiplication(channel_num, c, gc) +model.eval() +if not os.path.isfile(f_topo): + if is_dynamic_input: + len_dim = None + net = to_static(model, input_spec=[ + InputSpec(shape=[1, len_dim, len_dim, 128]), + InputSpec(shape=[1, len_dim, len_dim]), + ]) + save(net, prefix_weights) + + +print('# [INFO] inference on dynamic graph') +if not ignore_eval: + dy_dts = 0. + with pd.no_grad(): + for i in tqdm(range(n_iter)): + t0 = time.time() + _ = model(pair_act, pair_mask) + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dy_dts += dt + +dts = 0.
+pd_cfg = pdinfer.Config(f_topo, f_params) + +print('# [INFO] inference on static graph') +### optimization based on intel architecture +pd_cfg.set_cpu_math_library_num_threads(n_cpus) +pd_cfg.enable_mkldnn() +#pd_cfg.enable_memory_optim() # no change in perf. or memory +if is_dynamic_input: + pd_cfg.set_mkldnn_cache_capacity(1) +predictor = pdinfer.create_predictor(pd_cfg) + +# create numpy sample inputs +pair_act1 = np.ones([1, len_dim, len_dim, 128], dtype='float32') +pair_mask1 = np.ones([1, len_dim, len_dim], dtype='float32') + +# get input handles +input_names = predictor.get_input_names() # ['pair_act', 'pair_mask'] +inputl_a = predictor.get_input_handle('act') +inputl_m = predictor.get_input_handle('mask') + +# reshape input handles for dynamic axes +if is_dynamic_input: + print('# [INFO] re-organize dynamic axes') + inputl_a.reshape(pair_act1.shape) + inputl_m.reshape(pair_mask1.shape) + +# get output handle +output_names = predictor.get_output_names() # ['tmp_0'] +outputl = predictor.get_output_handle(output_names[0]) + +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + inputl_a.copy_from_cpu(pair_act1) + inputl_m.copy_from_cpu(pair_mask1) + predictor.run() + output = outputl.copy_to_cpu() + t1 = time.time() + dt = t1 - t0 + if i >= n_warm: + dts += dt + +if not ignore_eval: + print('# [dynamic-graph] avg inference time = {}'.format(dy_dts/(n_iter-n_warm))) +print('# [static-graph] avg inference time = {}'.format(dts/(n_iter-n_warm))) + +print(output.shape) \ No newline at end of file diff --git a/apps/protein_folding/helixfold_cpu/ut_alphafold.py b/apps/protein_folding/helixfold_cpu/ut_alphafold.py new file mode 100644 index 00000000..db7fcc40 --- /dev/null +++ b/apps/protein_folding/helixfold_cpu/ut_alphafold.py @@ -0,0 +1,59 @@ +from layers.net import AlphaFold +from config import model_config +import time +from tqdm import tqdm +import paddle +from argparse import ArgumentParser as Parser + +parser = Parser('[pd.infer] dynamic UT of pdinfer.HelixFold') + +cfg = model_config('model_1') +c = cfg['model'] + +n_warm = 0 +n_iter = 1 +ignore_eval = False + +### create sample input +len_dim = 40 +feed_dict = { + 'target_feat': paddle.ones([1, 4, len_dim, 22], dtype='float32'), + 'msa_feat': paddle.ones([1, 4, 508, len_dim, 49], dtype='float32'), + 'seq_mask': paddle.ones([1, 4, len_dim], dtype='float32'), + 'seq_length': paddle.ones([1, 4, len_dim], dtype='int32'), + 'aatype': paddle.ones([1, 4, len_dim], dtype='float32'), + 'residue_index': paddle.ones([1, 4, len_dim], dtype='float32'), + 'template_mask': paddle.ones([1, 4, 4], dtype='float32'), + 'template_aatype': paddle.ones([1, 4, 4, len_dim], dtype="int32"), # define + 'template_pseudo_beta_mask': paddle.ones([1, 4, 4, len_dim], dtype='float32'), + 'template_pseudo_beta': paddle.ones([1, 4, 4, len_dim, 3], dtype='float32'), + 'template_all_atom_positions': paddle.ones([1, 4, 4, len_dim, 37, 3], dtype='float32'), + 'template_all_atom_masks': paddle.ones([1, 4, 4, len_dim, 37], dtype='float32'), + 'extra_msa': paddle.ones([1, 4, 5120, len_dim], dtype='float32'), + 'extra_has_deletion': paddle.ones([1, 4, 5120, len_dim], dtype='float32'), + 'extra_deletion_value': paddle.ones([1, 4, 5120, len_dim], dtype='float32'), + 'extra_msa_mask': paddle.ones([1, 4, 5120, len_dim], dtype='float32'), + 'msa_mask': paddle.ones([1, 4, 508, len_dim], dtype='float32'), + 'prev_pos': paddle.ones([1, 4, len_dim, 37, 3], dtype='float32'), + 'prev_msa_first_row': paddle.ones([1, 4, len_dim, 256], dtype='float32'), + 'prev_pair': paddle.ones([1, 4, len_dim, len_dim, 128], dtype='float32'), 
'atom14_atom_exists': paddle.ones([1, 4, len_dim, 14], dtype='float32'), + 'atom37_atom_exists': paddle.ones([1, 4, len_dim, 37], dtype='float32'), + 'residx_atom37_to_atom14': paddle.ones([1, 4, len_dim, 37], dtype='float32') +} + +print('# [INFO] build dynamic graph of HelixFold') +model = AlphaFold(config=c) +model.eval() + +print('# [INFO] inference on dynamic graph') +# run +dts = 0. +for i in tqdm(range(n_iter)): + t0 = time.time() + with paddle.no_grad(): + outputs = model(feed_dict, False) + dt = time.time() - t0 + if i >= n_warm: + dts += dt +print('# [INFO] avg inference time = {}'.format(dts/(n_iter-n_warm)))
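
Review note, not part of the diff: every unit test above repeats the same static-graph benchmarking boilerplate (copy inputs to the predictor handles, run, fetch the first output, and average the latency after a few warm-up iterations). A shared helper along the lines of the sketch below could replace that block in each script. The function name benchmark_static and its parameters are illustrative only and do not exist in this repository; the sketch assumes a paddle.inference predictor and numpy inputs, which is what these tests already use.

import time

def benchmark_static(predictor, feeds, n_iter, n_warm):
    """Run a paddle.inference predictor n_iter times and return
    (average latency in seconds after n_warm warm-up runs, last output).
    feeds maps static-graph input names to numpy arrays."""
    handles = {name: predictor.get_input_handle(name) for name in feeds}
    out_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    total, output = 0., None
    for i in range(n_iter):
        t0 = time.time()
        for name, arr in feeds.items():
            handles[name].copy_from_cpu(arr)   # stage inputs on the CPU handles
        predictor.run()
        output = out_handle.copy_to_cpu()      # fetch the first graph output
        if i >= n_warm:
            total += time.time() - t0
    return total / (n_iter - n_warm), output

With such a helper, a test like ut_templatepair.py would reduce its static-graph section to roughly: avg_t, output = benchmark_static(predictor, {'pair_act': pair_act1, 'pair_mask': pair_mask1}, n_iter, n_warm), which keeps the timing logic in one place and avoids the copy-paste slips (stale handle names, stray statements) fixed in this patch.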