Performance Analysis - PyTorch
Number of effective sequences implemented in PyTorch
In the previous post, I compared various languages and libraries in terms of speed. This notebook contains the code used in that comparison, along with details about the choices made to improve the performance of the PyTorch implementation.
# ! pip install pandas
# ! pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
import pandas as pd
import numpy as np
import torch
def get_data(path):
    """Read the sequences of a FASTA file into a fixed-width string array.

    A record starts with a header line beginning with ``>``; every following
    non-blank line up to the next header belongs to that record's sequence.
    Sequences that wrap over several lines are concatenated, and surrounding
    whitespace (including ``\\r`` from Windows line endings) is stripped.
    Header text is ignored, so commas or other punctuation in it are harmless.

    Args:
        path: Path of the FASTA file to read.

    Returns:
        1-D NumPy array with a fixed-width unicode dtype (``<U{longest}``),
        one entry per record, in file order. ``get_nf_pytorch`` reinterprets
        this dtype's characters as ``int32`` codes, so it must stay ``<U``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    seqs = []
    # Sequence lines of the record currently being read; None until the
    # first header is seen, so stray text before it is ignored.
    chunks = None
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line.startswith(">"):
                if chunks is not None:
                    seqs.append("".join(chunks))
                chunks = []
            elif line and chunks is not None:
                chunks.append(line)
    if chunks is not None:
        seqs.append("".join(chunks))
    return np.array(seqs, dtype=str)
# Load the MSA exactly once. The repository data directory is the path the
# notebook ultimately uses, so try it first; fall back to a copy placed next
# to the notebook instead of parsing two files and failing on a missing one.
try:
    seqs = get_data('../data/picked_msa.fasta')
except FileNotFoundError:
    seqs = get_data('picked_msa.fasta')
As a reminder, the pseudocode looks like this:
meff = 0
for seq1 in seqs:
    weight = 0
    for seq2 in seqs:
        if count_matches(seq1, seq2) / len(seq1) > threshold:
            weight += 1
    meff += 1/weight
meff = meff/(len(seq1)**0.5)
@torch.jit.script
def get_nf_pytorch_core(input_data: torch.Tensor, threshold: float = 0.8) -> torch.Tensor:
    """Normalised number of effective sequences (Meff / sqrt(L)) of an encoded MSA.

    For each sequence i, the identity with every sequence j is the fraction
    of positions at which both hold the same code. The weight of sequence i
    is the number of sequences whose identity with it exceeds ``threshold``
    (itself included whenever ``threshold < 1``). The result is
    sum_i 1 / weight_i divided by sqrt(seq_len).

    Args:
        input_data: 2-D tensor of shape (n_seqs, seq_len) with one code per
            alignment position (``get_nf_pytorch`` passes int32 character codes).
        threshold: Identity above which two sequences count as neighbours.

    Returns:
        0-d floating-point tensor on ``input_data``'s device.
    """
    n_seqs, seq_len = input_data.shape
    # Only the per-sequence neighbour count is needed, so keep an
    # n_seqs-long vector instead of the full n_seqs x n_seqs identity matrix.
    neighbour_counts = torch.zeros((n_seqs,), dtype=torch.int64, device=input_data.device)
    for i in range(n_seqs):
        # float32, not float16: float16 resolves only ~5e-4 near 0.8, so an
        # identity just above the threshold could round onto it and be lost
        # by the strict comparison below.
        identity = torch.eq(input_data, input_data[i]).to(torch.float32).mean(-1)
        neighbour_counts[i] = torch.greater(identity, threshold).sum()
    cluster_size = (1.0 / neighbour_counts).sum()
    return cluster_size / (seq_len ** 0.5)


def get_nf_pytorch(input_data, gpu=True, *, threshold=0.8):
    """Compute the normalised number of effective sequences of an alignment.

    Each character is reinterpreted as its int32 code point, which makes the
    pairwise comparison in ``get_nf_pytorch_core`` a plain tensor equality.

    Args:
        input_data: 1-D array (or sequence) of aligned sequences as strings.
            It is converted to a contiguous fixed-width unicode array. NumPy
            pads shorter entries with NUL code points, and those padding
            positions match each other, so all sequences are assumed to have
            the same length (as in an MSA).
        gpu: If true, move the encoded alignment to the default CUDA device
            before computing.
        threshold: Identity above which two sequences count as neighbours.

    Returns:
        0-d tensor on the chosen device, as produced by ``get_nf_pytorch_core``.
    """
    seq_array = np.asarray(input_data)
    if seq_array.dtype.kind != "U":
        seq_array = seq_array.astype(str)
    # A dtype view to a smaller item size needs a contiguous buffer.
    seq_array = np.ascontiguousarray(seq_array)
    n_seqs = seq_array.shape[0]
    # '<U{k}' stores k 4-byte code points per entry. Taking the width from
    # the dtype (instead of reshape(n, -1)) also works for an empty alignment.
    seq_len = seq_array.dtype.itemsize // 4
    codes = torch.from_numpy(seq_array.view(np.int32).reshape(n_seqs, seq_len))
    if gpu:
        codes = codes.cuda()
    # No autocast context: the core only uses ops with explicit dtypes
    # (bool/int/float32) that autocast does not recast, and
    # torch.cuda.amp.autocast is deprecated in recent PyTorch releases.
    with torch.no_grad():
        return get_nf_pytorch_core(codes, float(threshold))
seqs_ = seqs[:100]
get_nf_pytorch(seqs_, False)
%%timeit -n 3 -r 3
seqs_ = seqs[:100]
get_nf_pytorch(seqs_, False)
seqs_ = seqs[:2500]
get_nf_pytorch(seqs_, True)
%%timeit -n 3 -r 3
get_nf_pytorch(seqs_, True)