Source code for full_dia.quant

from itertools import product

import numpy as np
import pandas as pd
import torch
from numba import cuda

from full_dia import cfg, fxic, tims, utils
from full_dia.log import Logger

try:
    # A profiling tool (e.g. line_profiler's kernprof) may already provide a
    # global named ``profile``; probing it raises NameError when it does not.
    _ = profile
except NameError:

    def profile(func):
        """Fallback no-op decorator: hand ``func`` back untouched."""
        return func
# Module-wide logger obtained from the project's full_dia.log.Logger helper.
logger = Logger.get_logger()
def mask_tensor(
    xic: torch.Tensor, left: torch.Tensor, right: torch.Tensor
) -> torch.Tensor:
    """
    Zero every cycle of the XIC that lies outside the window ``[left, right]``.

    Parameters
    ----------
    xic : torch.Tensor
        Dimension: [n_xic, n_ion, n_cycle]
    left : torch.Tensor
        First cycle index to keep (inclusive); flattened and broadcast against
        the first axis of ``xic``.
    right : torch.Tensor
        Last cycle index to keep (inclusive); flattened and broadcast the same way.

    Returns
    -------
    torch.Tensor
        A new tensor shaped like ``xic`` in which cycles outside the window
        have been multiplied by zero.
    """
    _, _, n_cycle = xic.shape  # unpacking also enforces a 3-D input
    positions = torch.arange(n_cycle, device=xic.device)[None, None, :]
    lo = left.view(-1, 1, 1)
    hi = right.view(-1, 1, 1)
    keep = (positions >= lo) & (positions <= hi)
    return xic * keep
@profile
def interp_xics(x: torch.Tensor, rts_input: np.ndarray, target_dim: int) -> tuple:
    """
    Linearly resample each XIC onto ``target_dim`` evenly spaced time points.

    Also returns the retention times of the new time points.

    Parameters
    ----------
    x : torch.Tensor
        XICs. Dimension: [n_pr, n_ion, n_cycle]
    rts_input : np.ndarray
        Retention time of every cycle, one row per precursor.
        Dimension: [n_pr, n_cycle]. Each row must be sorted ascending, as
        required by ``torch.searchsorted``. A tensor is accepted as well.
    target_dim : int
        Number of output time points per XIC.

    Returns
    -------
    tuple
        new_rts : torch.Tensor
            [n_pr, target_dim] time points evenly spanning rts[:, 0] .. rts[:, -1].
        x_interp : torch.Tensor
            [n_pr, n_ion, target_dim] linearly interpolated XICs.
    """
    # Put the retention times on the same device as the XICs. The previous
    # hard-coded cfg.gpu_id crashed whenever x lived on another device,
    # because the grid below is built on x.device.
    rts = torch.as_tensor(rts_input, device=x.device).contiguous()
    n_pr, n_ion, n_cycle = x.shape

    frac = torch.linspace(0, 1, target_dim, device=x.device).unsqueeze(0)
    new_rts = rts[:, [0]] + (rts[:, [-1]] - rts[:, [0]]) * frac

    # First original time point >= each new time point. Clamp so both the
    # right and left neighbours stay inside [0, n_cycle - 1].
    idx = torch.searchsorted(rts, new_rts)
    idx_right = idx.clamp(max=n_cycle - 1)
    idx_left = (idx_right - 1).clamp(min=0)
    t_left = torch.gather(rts, 1, idx_left)
    t_right = torch.gather(rts, 1, idx_right)

    idx_left_exp = idx_left.unsqueeze(1).expand(n_pr, n_ion, target_dim)
    idx_right_exp = idx_right.unsqueeze(1).expand(n_pr, n_ion, target_dim)
    x_left = torch.gather(x, 2, idx_left_exp)
    x_right = torch.gather(x, 2, idx_right_exp)

    # eps avoids 0/0 when t_right == t_left (e.g. the very first time point,
    # where both neighbours are index 0). The weight is then ~0 and the left
    # sample is used.
    eps = 1e-8
    weight = (new_rts.unsqueeze(1) - t_left.unsqueeze(1)) / (
        t_right.unsqueeze(1) - t_left.unsqueeze(1) + eps
    )
    x_interp = x_left + weight * (x_right - x_left)
    return new_rts, x_interp
def select_other_profiles(x_profile: torch.Tensor, best_profile: torch.Tensor) -> tuple:
    """
    For every ion, pick the tolerance condition whose XIC has the highest
    spectral angle (SA) with the best profile.

    Parameters
    ----------
    x_profile : torch.Tensor
        XIC profiles extracted with different tolerances.
        Dimension: [n_pep, n_condition, n_ion, n_cycle]
    best_profile : torch.Tensor
        The reference (best) profile of each precursor.
        Dimension: [n_pep, n_cycle]

    Returns
    -------
    tuple
        best_x : torch.Tensor
            [n_pep, n_ion, n_cycle]; per ion, the XIC with the highest SA
            against the best profile.
        sas : torch.Tensor
            [n_pep, n_ion]; the SA scores, computed as 1 - 2 * angle / pi.
    """
    device = x_profile.device
    n_pr, n_condition, n_ion, n_cycle = x_profile.shape
    best_profile_exp = best_profile.unsqueeze(1).unsqueeze(1)  # [n_pr, 1, 1, n_cycle]

    # Cosine similarity of every (condition, ion) XIC with the best profile.
    dot = (x_profile * best_profile_exp).sum(dim=-1)  # [n_pr, n_condition, n_ion]
    norm_x = x_profile.norm(dim=-1)  # [n_pr, n_condition, n_ion]
    norm_best = best_profile.norm(dim=-1).unsqueeze(1).unsqueeze(1)  # [n_pr, 1, 1]
    cos_sim = dot / (norm_x * norm_best + 1e-8)

    # Best condition per ion, then gather the matching XIC and similarity.
    best_condition_idx = torch.argmax(cos_sim, dim=1)  # [n_pr, n_ion]
    pr_idx = torch.arange(n_pr, device=device).unsqueeze(1).expand(n_pr, n_ion)
    ion_idx = torch.arange(n_ion, device=device).unsqueeze(0).expand(n_pr, n_ion)
    best_x = x_profile[pr_idx, best_condition_idx, ion_idx, :]  # [n_pr, n_ion, n_cycle]
    best_cos_sim = cos_sim[pr_idx, best_condition_idx, ion_idx]  # [n_pr, n_ion]

    angles = torch.acos(best_cos_sim.clamp(-1, 1))
    sas = 1 - 2 * angles / torch.pi

    # Sanity checks. SA >= 0 holds when the intensities are non-negative.
    # They are skipped for an empty batch, where .max()/.min() would raise.
    if sas.numel() > 0:
        assert sas.max().item() <= 1.0
        assert sas.min().item() >= 0.0
    return best_x, sas
def select_best_profile(x_profile: torch.Tensor) -> torch.Tensor:
    """
    Pick, per precursor, the ion XIC that agrees best with its sibling ions.

    Each ion is scored by its cosine similarities to all ions of the same
    precursor, summed, minus 1.0 (the self-similarity of a non-zero row),
    divided by ``n_ion - 1``. The ion with the highest score wins.

    Parameters
    ----------
    x_profile : torch.Tensor
        Dimension: [n_pr, n_ion, n_cycle]

    Returns
    -------
    torch.Tensor
        Dimension: [n_pr, n_cycle]; the raw (not normalised) XIC of the
        winning ion.
    """
    n_pr, n_ion, _ = x_profile.shape
    unit = x_profile / (x_profile.norm(dim=-1, keepdim=True) + 1e-8)
    similarity = unit @ unit.transpose(1, 2)  # [n_pr, n_ion, n_ion]
    mean_to_others = (similarity.sum(dim=-1) - 1.0) / (n_ion - 1)
    winner = mean_to_others.argmax(dim=-1)  # [n_pr]
    rows = torch.arange(n_pr, device=x_profile.device)
    return x_profile[rows, winner, :]
def interference_correction(
    xics: torch.Tensor, best_profile: torch.Tensor
) -> torch.Tensor:
    """
    DIA-NN's method to correct the interference of profiles.

    For each ion, the ratio ``xic / best_profile`` at every cycle is compared
    with the ratio at the middle cycle. Wherever it exceeds 1.5 times that
    central ratio, the intensity is replaced by
    ``1.5 * central_ratio * best_profile``.

    Parameters
    ----------
    xics : torch.Tensor
        Dimension: [n_pr, n_ion, n_cycle]. Modified in place.
    best_profile : torch.Tensor
        Dimension: [n_pr, n_cycle]

    Returns
    -------
    torch.Tensor
        The same ``xics`` object, after capping.
    """
    center = xics.shape[-1] // 2
    ratio = xics / (best_profile.unsqueeze(1) + 1e-7)  # [n_pr, n_ion, n_cycle]
    limit = 1.5 * ratio[:, :, center].unsqueeze(-1)  # [n_pr, n_ion, 1]
    exceeds = ratio > limit
    capped = limit * best_profile.unsqueeze(1)
    xics[exceeds] = capped[exceeds]
    return xics
def _extract_smoothed_xics(
    df_batch: pd.DataFrame,
    ms1_centroid: dict,
    ms2_centroid: dict,
    im_tolerance: float,
    ppm_tolerance: float,
    locus_start_v: torch.Tensor,
    locus_end_v: torch.Tensor,
    expand_dim: int,
) -> torch.Tensor:
    """Extract XICs at one tolerance, keep the elution locus, resample to ``expand_dim`` points and smooth."""
    _, rts, _, _, xics = fxic.extract_xics(
        df_batch,
        ms1_centroid,
        ms2_centroid,
        im_tolerance=im_tolerance,
        ppm_tolerance=ppm_tolerance,
        cycle_num=13,
        by_pred=False,
    )
    xics = utils.convert_numba_to_tensor(xics)  # 14 ions
    xics = mask_tensor(xics, locus_start_v, locus_end_v)
    # The resampled retention times are not needed downstream: areas are
    # integrated over the uniform grid index, not over rts.
    _, xics = interp_xics(xics, rts, expand_dim)
    # Zero-copy view of the (GPU) tensor for the numba smoothing kernel.
    # Previously one call site used cuda.to_device, which numba resolves to
    # the same view for objects exposing __cuda_array_interface__.
    xics = fxic.gpu_simple_smooth(cuda.as_cuda_array(xics))
    return utils.convert_numba_to_tensor(xics)


@profile
def grid_xic_best(
    df_batch: pd.DataFrame, ms1_centroid: dict, ms2_centroid: dict
) -> tuple:
    """
    Quantify fragment ions by a grid search over extraction tolerances.

    The profile with the highest SA among the other fragment ion profiles
    (searched in ions 2..7 of the first, widest-tolerance extraction) is
    selected as the best profile. All tolerance combinations are then
    traversed, and for each ion the XIC with the highest SA against the best
    profile is kept.

    Precursors whose best-profile apex lies more than 6 grid points from the
    centre are re-extracted with a single fixed tolerance and given a fixed
    integration window.

    Parameters
    ----------
    df_batch : pd.DataFrame
        Provide the precursor information. Reads the columns
        "score_elute_span_left" and "score_elute_span_right".
    ms1_centroid : dict
        The MS1 data.
    ms2_centroid : dict
        The MS2 data.

    Returns
    -------
    tuple
        areas : np.ndarray
            Areas by best profiles, one row per precursor and one column per ion.
        sas : np.ndarray
            The corresponding SA scores, same shape as ``areas``.
            Entries where the area or the SA is 0 are both set to 0.
    """
    locus_start_v = torch.from_numpy(df_batch["score_elute_span_left"].values).to(cfg.gpu_id)
    locus_end_v = torch.from_numpy(df_batch["score_elute_span_right"].values).to(cfg.gpu_id)

    tol_ppm_v = [20.0, 16.0, 12.0, 8.0, 4.0]
    tol_im_v = [0.02, 0.01]
    expand_dim = 64

    xics_v = []
    for search_i, (tol_ppm, tol_im) in enumerate(product(tol_ppm_v, tol_im_v)):
        xics = _extract_smoothed_xics(
            df_batch, ms1_centroid, ms2_centroid, tol_im, tol_ppm,
            locus_start_v, locus_end_v, expand_dim,
        )
        if search_i == 0:
            # Best profile from the top-6 fragment ions of the first extraction.
            best_profile = select_best_profile(xics[:, 2:8, :])  # [n_pr, n_cycle]
            # Flag precursors whose apex is far from the grid centre.
            bad_xic = torch.abs(best_profile.argmax(dim=-1) - expand_dim / 2) > 6
            # Integration window: where best_profile exceeds 20% of its maximum.
            box = best_profile > best_profile.max(dim=-1, keepdim=True)[0] * 0.2
            box = box.int()
            box_left = box.argmax(dim=-1)
            box_right = expand_dim - 1 - torch.flip(box, dims=[1]).argmax(dim=-1)
        xics_v.append(xics)

    xics = torch.stack(xics_v, dim=1)  # [n_pep, tol, n_ion, n_cycle]

    # Find the other profiles with the help of best_profile.
    xics, sas = select_other_profiles(xics, best_profile)
    xics = interference_correction(xics, best_profile)

    # Re-extract the bad XICs. The extra extraction is skipped when none is
    # flagged, because the masked assignments below would then be no-ops.
    if bool(bad_xic.any()):
        xics2 = _extract_smoothed_xics(
            df_batch, ms1_centroid, ms2_centroid, 0.025, 15,
            locus_start_v, locus_end_v, expand_dim,
        )
        xics[bad_xic] = xics2[bad_xic]
        box_left[bad_xic] = 15  # 3-13, 15-64
        box_right[bad_xic] = 50  # 9-13, 50-64

    # Keep only the integration window, then integrate with unit spacing
    # (rts are intentionally not used: trapz(xics) instead of trapz(xics, rts)).
    xics = mask_tensor(xics, box_left, box_right)
    areas = torch.trapz(xics, dim=-1)
    zeros_idx = (areas == 0) | (sas == 0)
    areas[zeros_idx] = 0.0
    sas[zeros_idx] = 0.0
    return areas.cpu().numpy(), sas.cpu().numpy()
@profile
def quant_center_ions(df_input: pd.DataFrame, ms: tims.Tims) -> pd.DataFrame:
    """
    A novel xic extraction method to quantify fragment ions.

    Precursors are processed swath by swath, in batches of
    ``cfg.batch_xic_locus`` rows. The swath's centroided MS data is copied to
    the GPU once per swath and released afterwards, even if a batch fails.

    Parameters
    ----------
    df_input : pd.DataFrame
        Provide the identification information of precursors. Must contain a
        "swath_id" column plus the columns read by ``grid_xic_best``.
    ms : tims.Tims
        Provide the MS data.

    Returns
    -------
    df : pd.DataFrame
        Rows of ``df_input`` regrouped by "swath_id" with a fresh RangeIndex.
        New columns are "score_ion_quant_{i}" and "score_ion_sa_{i}" for
        i in 0 .. cfg.fg_num + 1. An empty input yields an empty frame that
        already carries those columns.
    """
    # Loop-invariant settings and column names.
    quant_cols = ["score_ion_quant_" + str(i) for i in range(cfg.fg_num + 2)]
    sa_cols = ["score_ion_sa_" + str(i) for i in range(cfg.fg_num + 2)]
    batch_n = cfg.batch_xic_locus

    df_good = []
    for swath_id in df_input["swath_id"].unique():
        df_swath = df_input[df_input["swath_id"] == swath_id]
        df_swath = df_swath.reset_index(drop=True)

        ms1_centroid, ms2_centroid = ms.copy_map_to_gpu(swath_id, centroid=True)
        try:
            for _, df_batch in df_swath.groupby(df_swath.index // batch_n):
                df_batch = df_batch.reset_index(drop=True)
                # Grid search for best profiles.
                areas, sas = grid_xic_best(df_batch, ms1_centroid, ms2_centroid)
                df_batch[quant_cols] = areas
                df_batch[sa_cols] = sas
                df_good.append(df_batch)
        finally:
            # Release the GPU copies even when a batch raises.
            utils.release_gpu_scans(ms1_centroid, ms2_centroid)

    if not df_good:
        # pd.concat([]) raises; return an empty frame with the output columns.
        empty = df_input.iloc[0:0].reset_index(drop=True)
        return empty.assign(**dict.fromkeys(quant_cols + sa_cols, np.nan))

    return pd.concat(df_good, axis=0, ignore_index=True)