from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle
from numba import cuda, jit, prange
from full_dia import alphatims
from full_dia.alphatims import bruker
from full_dia.log import Logger
try:
    # Probe for a `profile` name that is already in scope (presumably one
    # injected into builtins by an external line profiler -- TODO confirm).
    _ = profile
except NameError:

    def profile(func):
        """Fallback no-op decorator: hand the function back untouched."""
        return func
# Module-level logger, obtained from the project's Logger helper.
logger = Logger.get_logger()
[docs]
@jit(nopython=True, nogil=True)
def numba_index_by_bool(idx, ims, mzs, heights):
    """
    Gather the mask-selected entries of three parallel arrays in Numba.

    Parameters
    ----------
    idx : array of bool
        Selection mask; entry ``k`` decides whether element ``k`` of the
        three data arrays is kept.
    ims, mzs, heights : array
        Parallel data arrays indexed together with ``idx``.

    Returns
    -------
    tuple
        Three newly allocated arrays (same dtypes as ``ims``, ``mzs`` and
        ``heights``) holding the kept elements in their original order.
    """
    n_keep = np.sum(idx)
    kept_ims = np.empty(n_keep, dtype=ims.dtype)
    kept_mzs = np.empty(n_keep, dtype=mzs.dtype)
    kept_heights = np.empty(n_keep, dtype=heights.dtype)

    cursor = 0  # next free slot in the output arrays
    for pos in range(len(idx)):
        if not idx[pos]:
            continue
        kept_ims[cursor] = ims[pos]
        kept_mzs[cursor] = mzs[pos]
        kept_heights[cursor] = heights[pos]
        cursor += 1
    return kept_ims, kept_mzs, kept_heights
[docs]
@jit(nopython=True, nogil=True, parallel=True)
def numba_paral_repeat(x, y):
    """
    Expand ``x`` into a longer vector using the offsets in ``y`` (Numba).

    Element ``x[k]`` is written to every output position in
    ``[y[k], y[k + 1])``; the output has ``y[-1]`` entries and the dtype of
    ``x``. ``y`` therefore needs at least ``len(x) + 1`` entries. Output
    positions not covered by any such range are left uninitialised.
    """
    expanded = np.empty(y[-1], dtype=x.dtype)
    for k in prange(len(x)):
        fill_value = x[k]
        first = y[k]
        stop = y[k + 1]
        for pos in range(first, stop):
            expanded[pos] = fill_value
    return expanded
[docs]
@jit(nopython=True, nogil=True, parallel=True)
def numba_paral_sum(select_id, cumlen):
    """
    Sum ``select_id`` over consecutive segments in Numba.

    ``cumlen`` holds the cumulative end offset of every segment, so segment
    ``k`` covers ``select_id[cumlen[k - 1]:cumlen[k]]`` (starting at 0 for
    the first one). Returns an int64 array with one sum per segment.
    """
    n_segments = len(cumlen)
    sums = np.empty(n_segments, dtype=np.int64)
    for seg in prange(n_segments):
        lo = 0 if seg == 0 else cumlen[seg - 1]
        hi = cumlen[seg]
        sums[seg] = np.sum(select_id[lo:hi])
    return sums
[docs]
@jit(nopython=True, nogil=True, parallel=True)
def numba_paral_sort(all_tof, all_push, all_height, cumlen):
    """
    Sort each segment of three parallel vectors by ascending tof (m/z) in Numba.

    ``cumlen`` holds the cumulative end offset of every segment. Inside each
    segment the permutation from ``np.argsort`` on the tof values is applied
    to tof, push and height alike. The inputs are left untouched; three new
    arrays ``(tof, push, height)`` are returned.
    """
    sorted_tof = np.empty(len(all_tof), dtype=all_tof.dtype)
    sorted_push = np.empty(len(all_push), dtype=all_push.dtype)
    sorted_height = np.empty(len(all_height), dtype=all_height.dtype)
    for seg in prange(len(cumlen)):
        lo = 0 if seg == 0 else cumlen[seg - 1]
        hi = cumlen[seg]
        seg_tof = all_tof[lo:hi]
        order = np.argsort(seg_tof)
        sorted_tof[lo:hi] = seg_tof[order]
        sorted_push[lo:hi] = all_push[lo:hi][order]
        sorted_height[lo:hi] = all_height[lo:hi][order]
    return sorted_tof, sorted_push, sorted_height
[docs]
@jit(nopython=True, nogil=True, parallel=True)
def numba_paral_centroid(
    all_tof, all_push, all_height, tol_tof_sum, tol_tof_suppression, tol_push, cumlen
):
    """
    Centroid the profile MS data using DIA-NN's method:
    1. Summarize intensity values within a window range (m/z + 1/K0).
    2. Remove an aggregated point if a higher-intensity aggregated point exists in its neighborhood.

    Parameters
    ----------
    all_tof, all_push, all_height : np.ndarray
        Parallel per-ion arrays, concatenated over all segments. The inner
        loops stop at the first tof difference above the tolerance, so they
        rely on tof being ascending inside each segment.
    tol_tof_sum : number
        Largest tof difference for two points to be summed together.
    tol_tof_suppression : number
        Largest tof difference for two points to compete in suppression.
    tol_push : number
        Largest absolute push difference for two points to be neighbours
        (used by both steps).
    cumlen : np.ndarray
        Cumulative end offset of every segment.

    Returns
    -------
    tuple
        all_height_summed : np.ndarray
            uint32 array: each point's own height plus the heights of all of
            its neighbours.
        all_height_suppressed : np.ndarray
            ``all_height_summed`` with suppressed points set to 0.
    """
    all_height_summed = np.zeros_like(all_height, dtype=np.uint32)
    # Keep-mask: 1 = point survives, 0 = suppressed by a neighbour.
    all_height_suppressed = np.ones_like(all_height)
    # Each iteration only writes inside its own [start, end) range.
    for ms_i in prange(len(cumlen)):
        if ms_i == 0:
            start = 0
            end = cumlen[ms_i]
        else:
            start = cumlen[ms_i - 1]
            end = cumlen[ms_i]
        tof = all_tof[start:end]
        push = all_push[start:end]
        height = all_height[start:end]
        # Nothing to aggregate in a segment whose intensities sum to zero.
        if np.sum(height) == 0:
            continue
        # Step 1: add every point's raw height to itself and, symmetrically,
        # to each neighbour within (tol_tof_sum, tol_push).
        for i in range(len(tof)):
            tof_i = tof[i]
            push_i = push[i]
            height_i = height[i]
            all_height_summed[start + i] += height_i
            for ii in range(i + 1, len(tof)):
                tof_ii = tof[ii]
                push_ii = push[ii]
                height_ii = height[ii]
                delta_tof = tof_ii - tof_i
                delta_push = abs(push_ii - push_i)
                # Later points are even further away in tof: stop scanning.
                if delta_tof > tol_tof_sum:
                    break
                # Close in tof but too far in push: not a neighbour.
                if delta_push > tol_push:
                    continue
                all_height_summed[start + ii] += height_i
                all_height_summed[start + i] += height_ii
        # Step 2: within (tol_tof_suppression, tol_push) the point with the
        # smaller summed height is suppressed.
        for i in range(len(tof)):
            tof_i = tof[i]
            push_i = push[i]
            height_i = all_height_summed[start + i]
            for ii in range(i + 1, len(tof)):
                tof_ii = tof[ii]
                push_ii = push[ii]
                height_ii = all_height_summed[start + ii]
                delta_tof = tof_ii - tof_i
                delta_push = abs(push_ii - push_i)
                if delta_tof > tol_tof_suppression:
                    break
                if delta_push > tol_push:
                    continue
                if height_ii > height_i:
                    all_height_suppressed[start + i] = 0.0
                elif height_ii < height_i:
                    all_height_suppressed[start + ii] = 0.0
                else:
                    # Summed heights tie: fall back to the raw heights. If
                    # those tie as well, both points are kept.
                    height_raw_i = all_height[start + i]
                    height_raw_ii = all_height[start + ii]
                    if height_raw_ii > height_raw_i:
                        all_height_suppressed[start + i] = 0.0
                    elif height_raw_ii < height_raw_i:
                        all_height_suppressed[start + ii] = 0.0
    # Apply the keep-mask: suppressed points end up with intensity 0.
    all_height_suppressed = all_height_summed * all_height_suppressed
    return all_height_summed, all_height_suppressed
[docs]
class Tims:
    """
    Reader for diaPASEF data that also centroids the profile data.

    On construction the .d directory is loaded, one data map is built per
    quadrupole window, and the MS1 map is split into per-window chunks.
    """
@profile
def __init__(self, dir_d: Path) -> None:
    """
    Load a .d directory and build the MS1 / MS2 maps.

    Parameters
    ----------
    dir_d : Path
        Path handed (as ``str``) to ``bruker.TimsTOF``.

    Attributes set
    --------------
    dir_d, bruker, df_settings, frames_num_per_cycle, im_gap :
        Raw handle and the window / mobility settings derived from it.
    d_ms1_maps : dict
        MS1 chunks keyed by swath id, as returned by ``split_ms1_to_chunks``.
    d_ms2_maps : dict
        MS2 maps keyed by swath id (1, 2, ...), as returned by
        ``construct_data_by_quadrupole``.
    """
    self.dir_d = dir_d
    self.bruker = bruker.TimsTOF(str(dir_d))
    self.df_settings, self.frames_num_per_cycle = self.get_dia_windows()
    self.im_gap = self.get_im_gap()

    d_ms1_maps, d_ms2_maps = {}, {}
    # One map per boundary returned by get_dia_quadrupole(): id 0 is MS1,
    # ids 1..n are the quadrupole windows.
    n_maps = len(self.get_dia_quadrupole())
    for swath_id in range(n_maps):
        # `swath_map` rather than `map`, so the builtin is not shadowed.
        swath_map = self.construct_data_by_quadrupole(swath_id)
        if swath_id == 0:  # ms1
            d_ms1_maps = self.split_ms1_to_chunks(swath_map)
        else:  # ms2
            d_ms2_maps[swath_id] = swath_map
    self.d_ms1_maps = d_ms1_maps
    self.d_ms2_maps = d_ms2_maps
@property
def frame_nums(self):
    """Number of frames in the loaded data (``len(self.bruker.frames)``)."""
    frames = self.bruker.frames
    return len(frames)
[docs]
def get_dia_windows(self) -> tuple:
    """
    Exact boundaries of the window (m/z + 1/K0) partitioning.

    The MS/MS frames of a single cycle (cycle index 100, hard-coded) are
    read, and for every quadrupole window in them the minimum and maximum
    mobility value is taken.

    Returns
    -------
    tuple
        df : pd.DataFrame
            One row per window with the columns ``mobility_values_max``,
            ``quad_low_mz_values``, ``quad_high_mz_values`` and
            ``mobility_values_min``.
        frames_num_per_cycle : int
            The frame number per cycle (including the MS1 frame).

    Raises
    ------
    AssertionError
        If the first two MS1 frame indices are not adjacent.
    """
    key_cols = ["quad_low_mz_values", "quad_high_mz_values"]
    out_cols = [
        "mobility_values_max",
        "quad_low_mz_values",
        "quad_high_mz_values",
        "mobility_values_min",
    ]
    # MsMsType: 0 -- MS, 9 -- MS/MS
    ms1_idx = np.where(self.bruker.frames.MsMsType == 0)[0]
    ms1_frame_diff = np.diff(ms1_idx)
    assert ms1_frame_diff[0] == 1, "AlphaTims not add zeroth frame!"
    # NOTE: the cycle length is read from one gap only; it is not checked
    # that every cycle has the same number of frames.
    frames_num_per_cycle = ms1_frame_diff[1]  # has a frame for MS1

    first_frame = 2 + 100 * frames_num_per_cycle
    stop_frame = 100 * frames_num_per_cycle + frames_num_per_cycle + 1
    df_v = []
    for i in range(first_frame, stop_frame):
        df = self.bruker[i]
        df = df[df["precursor_indices"] > 0]  # remove overlap ms1 ions
        # Aggregate the mobility column directly instead of
        # groupby(...).apply(np.maximum.reduce) plus a merge on the keys:
        # that form needs the grouping columns to be passed into apply(),
        # which pandas deprecated in 2.2.
        bounds = (
            df.groupby(key_cols, sort=False)["mobility_values"]
            .agg(["max", "min"])
            .reset_index()
            .rename(
                columns={
                    "max": "mobility_values_max",
                    "min": "mobility_values_min",
                }
            )
        )
        df_v.append(bounds[out_cols])
    df = pd.concat(df_v).reset_index(drop=True)
    return df, frames_num_per_cycle
[docs]
def plot_dia_windows(self):
    """
    For developing: draw every DIA window as a blue rectangle outline.

    The x axis is the quadrupole m/z range and the y axis the mobility
    range of each window. The figure is shown with ``plt.show()``.
    """
    fig, ax = plt.subplots()
    # A white helper line is plotted first over (300..1300, 0.5..1.7);
    # presumably to fix the visible axis range -- TODO confirm.
    ax.plot(np.linspace(300, 1300, 50), np.linspace(0.5, 1.7, 50), "w")

    windows = self.get_dia_windows()[0]
    corners = zip(
        windows["quad_low_mz_values"],
        windows["quad_high_mz_values"],
        windows["mobility_values_min"],
        windows["mobility_values_max"],
    )
    for mz_lo, mz_hi, im_lo, im_hi in corners:
        outline = Rectangle(
            (mz_lo, im_lo), mz_hi - mz_lo, im_hi - im_lo, fc="none", color="blue"
        )
        ax.add_patch(outline)
    plt.show()
[docs]
def get_dia_quadrupole(self) -> np.ndarray:
    """
    Exact boundaries of the quadrupole partitioning.

    The unique (low, high) quadrupole windows of ``self.df_settings`` are
    sorted by their low edge. Every inner boundary is the midpoint between a
    window's low edge and the previous window's high edge, so overlapping
    windows still give a single boundary.

    Returns
    -------
    np.ndarray
        ``n_windows + 1`` boundaries, like: [200, 250, 300, 350 ... 1150, 1200]
    """
    windows = self.df_settings[["quad_low_mz_values", "quad_high_mz_values"]]
    windows = windows.drop_duplicates().sort_values(by="quad_low_mz_values")
    # Explicit writable copy: the array from `.values` can be a read-only
    # view of the frame (pandas Copy-on-Write), so it must not be edited
    # in place.
    low = windows["quad_low_mz_values"].to_numpy(copy=True)
    high = windows["quad_high_mz_values"].to_numpy()
    low[1:] = (low[1:] + high[:-1]) / 2
    return np.concatenate([low, [high[-1]]])
[docs]
@profile
def construct_data_by_quadrupole(self, window_id: int) -> tuple:
    """
    Construct profile and centroid data with specified window_id.

    Frames before the second MS1 frame and from the last MS1 frame onwards
    are dropped. The ions of the requested window are then grouped per
    cycle, sorted by tof inside every cycle and centroided with
    ``numba_paral_centroid``.

    Parameters
    ----------
    window_id : int
        0 refers to MS1, others refer to different quadrupole windows.

    Returns
    -------
    tuple
        all_rt : np.ndarray
            The rt values of all cycles.
        cycle_valid_lens : np.ndarray
            The number of profile ions per cycle.
        all_push : np.ndarray
            The 1/k0 values of profile ions.
        all_tof : np.ndarray
            The m/z values of profile ions.
        all_height : np.ndarray
            The intensities of profile ions.
        cycle_valid_lens2 : np.ndarray
            The number of centroided ions per cycle.
        all_push2 : np.ndarray
            The 1/k0 values of centroided ions.
        all_tof2 : np.ndarray
            The m/z values of centroided ions.
        all_height2 : np.ndarray
            The intensities of centroided ions.
    """
    all_rt = self.bruker.rt_values
    ms1_idx_v = np.where(self.bruker.frames.MsMsType == 0)[0]
    # Keep the frames from the second MS1 frame up to (excluding) the last
    # MS1 frame.
    frame_start = ms1_idx_v[1]
    frame_end = ms1_idx_v[-1]
    all_rt = all_rt[frame_start:frame_end]  # remove start and end
    all_rt = all_rt.astype(np.float32)
    msms_type = self.bruker.frames.MsMsType[frame_start:frame_end]
    # MS1 positions, now relative to the trimmed frame range.
    ms1_idx_v = np.where(msms_type == 0)[0]
    # Per-frame ion counts and the flat per-ion intensity / tof arrays.
    frame_lens = self.bruker.frames.NumPeaks.values
    all_height = self.bruker.intensity_values  # uint16 (per original note)
    all_tof = self.bruker.tof_indices  # uint32 (per original note)
    # Give every ion the index of its push (scan) inside its frame.
    push_lens = np.diff(self.bruker.push_indptr)
    assert len(self.bruker.frames) * self.bruker.scan_max_index == len(
        push_lens
    ), "push exists missing values!"
    push_lens = push_lens.astype(np.uint16)
    push_idx = np.arange(len(push_lens)) % self.bruker.scan_max_index
    # Signed dtype: push indices are subtracted from each other in
    # numba_paral_centroid.
    push_idx = push_idx.astype(np.int16)
    all_push = numba_paral_repeat(push_idx, self.bruker.push_indptr)
    # Map every quadrupole entry to a window id through its centre m/z.
    swath = self.get_dia_quadrupole()
    quad_center_values = self.bruker.quad_mz_values.mean(axis=1)
    quad_window_ids = np.digitize(quad_center_values, swath)
    quad_window_ids = quad_window_ids.astype(np.uint8)
    # Per-ion mask of the requested window, expanded through quad_indptr.
    select_id = quad_window_ids == window_id
    select_id = numba_paral_repeat(select_id, self.bruker.quad_indptr)
    frame_len_cumsum = np.cumsum(frame_lens)
    # Selected ions per frame, counted over all frames before trimming.
    frame_valid_lens = numba_paral_sum(select_id, frame_len_cumsum)
    # Unselect the ions of the dropped leading and trailing frames.
    before_num = frame_lens[:frame_start].sum()
    end_num = frame_lens[frame_end:].sum()
    select_id[0:before_num] = False
    # Guarded: with end_num == 0, `-end_num:` would cover the whole array.
    if end_num > 0:
        select_id[-end_num:] = False
    frame_valid_lens = frame_valid_lens[frame_start:frame_end]
    all_push, all_tof, all_height = numba_index_by_bool(
        select_id, all_push, all_tof, all_height
    )
    assert len(all_rt) == len(frame_valid_lens)
    assert (
        len(all_push) == len(all_tof) == len(all_height) == frame_valid_lens.sum()
    )
    # A cycle starts at an MS1 frame: its rt is that frame's rt, and its
    # ion count is the sum over the frames up to the next MS1 frame.
    all_rt = all_rt[ms1_idx_v]
    cycle_valid_lens = np.add.reduceat(frame_valid_lens, ms1_idx_v)
    # Inside a cycle the ions are sorted by tof only; push is not a sort key.
    cycle_len_cumsum = np.cumsum(cycle_valid_lens)
    result = numba_paral_sort(all_tof, all_push, all_height, cycle_len_cumsum)
    all_tof, all_push, all_height = result
    # Centroid; the tof tolerances are fixed to 2 (summing) and 1
    # (suppression).
    tol_push = self.get_centroid_tol_push()
    tol_tof_summed, tol_tof_suppression = 2, 1
    summed2, all_height2 = numba_paral_centroid(
        all_tof,
        all_push,
        all_height,
        tol_tof_summed,
        tol_tof_suppression,
        tol_push,
        cycle_len_cumsum,
    )
    # Surviving (non-zero) centroided points per cycle.
    tmp = np.split(all_height2, cycle_len_cumsum[:-1])
    cycle_valid_lens2 = np.array(list(map(np.count_nonzero, tmp)))
    select_id = all_height2 > 0
    all_push2, all_tof2, all_height2 = numba_index_by_bool(
        select_id, all_push, all_tof, all_height2
    )
    assert len(all_tof2) == sum(cycle_valid_lens2)
    # Convert push index -> mobility and tof index -> m/z via lookup tables.
    push_to_im = self.bruker.mobility_values.astype(np.float32)
    all_push = push_to_im[all_push]
    all_push2 = push_to_im[all_push2]
    tof_to_mz = self.bruker.mz_values.astype(np.float32)
    all_tof = tof_to_mz[all_tof]
    all_tof2 = tof_to_mz[all_tof2]
    return (
        all_rt,
        cycle_valid_lens,
        all_push,
        all_tof,
        all_height,
        cycle_valid_lens2,
        all_push2,
        all_tof2,
        all_height2,
    )
[docs]
def get_rt_range(self) -> tuple[float, float]:
    """
    Return the smallest and largest RT, read from the MS1 chunk with key 1.
    """
    cycle_rts = self.d_ms1_maps[1][0]
    rt_low = cycle_rts.min()
    rt_high = cycle_rts.max()
    return (rt_low, rt_high)
[docs]
def get_cycle_time(self) -> float:
    """
    Return the cycle time: the mean difference between consecutive RTs of
    the MS1 chunk with key 1.
    """
    rt_steps = np.diff(self.d_ms1_maps[1][0])
    return np.mean(rt_steps)
[docs]
@profile
def copy_map_to_gpu(self, swath_id: int, centroid: bool) -> list:
    """
    Copy profile or centroided MS data to GPU.

    Parameters
    ----------
    swath_id : int
        Specify the SWATH or quadrupole ID; used as the key into both
        ``self.d_ms1_maps`` and ``self.d_ms2_maps``.
    centroid : bool
        True copies the centroided arrays, False the profile arrays.

    Returns
    -------
    list
        Two dicts, the MS1 chunk first and the MS2 data second. Each has
        the keys ``scan_rts`` (not copied to the device),
        ``scan_seek_idx``, ``scan_im``, ``scan_mz`` and ``scan_height``
        (results of ``cuda.to_device``).
    """
    gpu_maps = []
    for source in (self.d_ms1_maps, self.d_ms2_maps):
        (
            cycle_rts,
            prof_lens,
            prof_im,
            prof_mz,
            prof_height,
            cent_lens,
            cent_im,
            cent_mz,
            cent_height,
        ) = source[swath_id]
        if centroid:
            lens, ims, mzs, heights = cent_lens, cent_im, cent_mz, cent_height
        else:
            lens, ims, mzs, heights = prof_lens, prof_im, prof_mz, prof_height
        # Seek table: offset of every cycle's first ion, plus the total.
        seek_idx = np.concatenate([[0], np.cumsum(lens)], dtype=np.int64)
        gpu_maps.append(
            {
                "scan_rts": cycle_rts,
                "scan_seek_idx": cuda.to_device(seek_idx),
                "scan_im": cuda.to_device(ims),
                "scan_mz": cuda.to_device(mzs),
                "scan_height": cuda.to_device(heights),
            }
        )
    return gpu_maps
[docs]
@profile
def split_ms1_to_chunks(self, ms1_map: tuple) -> dict:
    """
    Split the MS1 map into one chunk per swath to save memory.

    Every chunk keeps the ions whose m/z lies inside the swath, widened on
    both sides by three neutron masses (to cover precursor isotopes). A
    cycle with no ion in range gets one placeholder ion (m/z 10.0, height 1,
    im 1.0), so no cycle has length zero.

    Parameters
    ----------
    ms1_map : tuple, the unsplit ms1 map.

    Returns
    -------
    d_ms1_maps : dict
        The key is the swath_id (starting at 1), and the value is the MS1
        chunk data in the same 9-field layout as ``ms1_map``.
    """
    mass_neutron = 1.0033548378
    (
        all_rt,
        cycle_valid_lens,
        all_push,
        all_tof,
        all_height,
        cycle_valid_lens2,
        all_push2,
        all_tof2,
        all_height2,
    ) = ms1_map
    n_cycles = len(all_rt)

    def _extract_chunk(seek_idx, ims, mzs, heights, mz_low, mz_high, height_dtype):
        # Collect, cycle by cycle, the ions with mz_low <= m/z <= mz_high.
        chunk_mz, chunk_im, chunk_height, chunk_len = [], [], [], []
        for j in range(n_cycles):
            lo, hi = seek_idx[j], seek_idx[j + 1]
            scan_mz = mzs[lo:hi]
            in_range = (scan_mz >= mz_low) & (scan_mz <= mz_high)
            if in_range.sum():
                part_im, part_mz, part_height = numba_index_by_bool(
                    in_range, ims[lo:hi], scan_mz, heights[lo:hi]
                )
            else:
                # Placeholder ion for a cycle without any ion in range.
                part_mz = np.array([10.0], dtype=np.float32)
                part_height = np.array([1], dtype=height_dtype)
                part_im = np.array([1.0], dtype=np.float32)
            chunk_mz.append(part_mz)
            chunk_im.append(part_im)
            chunk_height.append(part_height)
            chunk_len.append(len(part_mz))
        joined_mz = np.concatenate(chunk_mz)
        joined_im = np.concatenate(chunk_im)
        joined_height = np.concatenate(chunk_height)
        return np.array(chunk_len), joined_im, joined_mz, joined_height

    # Seek tables for the profile and the centroided arrays.
    profile_seek = np.concatenate([[0], np.cumsum(cycle_valid_lens)])
    centroid_seek = np.concatenate([[0], np.cumsum(cycle_valid_lens2)])
    swath = self.get_dia_quadrupole()

    d_ms1_maps = {}
    for i in range(len(swath) - 1):
        pr_mz_low = swath[i] - 3 * mass_neutron
        pr_mz_high = swath[i + 1] + 3 * mass_neutron
        profile_part = _extract_chunk(
            profile_seek, all_push, all_tof, all_height,
            pr_mz_low, pr_mz_high, np.uint16,
        )
        centroid_part = _extract_chunk(
            centroid_seek, all_push2, all_tof2, all_height2,
            pr_mz_low, pr_mz_high, np.uint32,
        )
        d_ms1_maps[i + 1] = (all_rt, *profile_part, *centroid_part)
    return d_ms1_maps
[docs]
def get_scan_rts(self) -> np.ndarray:
    """
    Get the RT for each cycle: the first field of the MS2 map with key 1.
    """
    first_ms2_map = self.d_ms2_maps[1]
    return first_ms2_map[0]
[docs]
def get_im_gap(self) -> float:
    """
    Calculate the 1/k0 value of a single push.

    The mobility range (max - min) is divided by the largest ``NumScans``
    value of all frames plus one.
    """
    raw = self.bruker
    im_span = raw.mobility_max_value - raw.mobility_min_value
    n_pushes = raw.frames.NumScans.max() + 1
    return im_span / n_pushes
[docs]
def get_centroid_tol_push(self) -> int:
    """
    Calculate how many pushes should be considered as neighbors when centroiding.

    Computes ``10 * scan_max_index / 900 / (mobility_max - mobility_min)``
    and truncates the result to an int.
    """
    raw = self.bruker
    im_span = raw.mobility_max_value - raw.mobility_min_value
    return int(10 * raw.scan_max_index / 900 / im_span)
[docs]
def get_device_name(self) -> str:
    """
    Get the device name like timsTOF Ultra.

    Uses the second of the five objects returned by
    ``alphatims.bruker.read_bruker_sql`` as a Key/Value table and returns
    its ``InstrumentName`` entry.
    """
    sql_tables = alphatims.bruker.read_bruker_sql(str(self.dir_d))
    _, key_value_table, _, _, _ = sql_tables
    properties = dict(zip(key_value_table["Key"], key_value_table["Value"]))
    return properties["InstrumentName"]
[docs]
def load_ms(ws: Path) -> Tims:
    """
    Wrapper function for loading diaPASEF data.

    Builds a ``Tims`` object from ``ws`` and logs the device name and the
    gradient length (last cycle RT divided by 60, reported in minutes).
    """
    ms = Tims(ws)
    device_name = ms.get_device_name()
    last_rt = ms.get_scan_rts()[-1]
    summary = "device_name: {}, gradient: {:.2f}min".format(
        device_name, last_rt / 60.0
    )
    logger.info(summary)
    return ms