Source code for TrajAtlas.TrajDiff.trajdiff_utils

from __future__ import annotations
from scipy.stats import binom
import numpy as np
from anndata import AnnData
from mudata import MuData
from rich import print
import pandas as pd
import statsmodels.api as sm
from sklearn.cluster import KMeans
from TrajAtlas.utils._docs import d
from typing import (
    Literal,
)

def _test_binom(length_df,
                times:int = 20):
    # Use binom distribution to test
    sumVal = length_df["true"] + length_df["false"]
    trueVal=length_df["true"]
    null=length_df["null"]
    p_val_list=[]
    for i in range(len(length_df)):
        if null[i]==0:
            null[i]=1/(sumVal[i]*times) # minimal
        p_val= 1- binom.cdf(trueVal[i], sumVal[i], null[i])
        if trueVal[i] == 0:
            p_val=1
        p_val_list.append(p_val)
    length_df["binom_p"]=p_val_list
    return(length_df)
    
def _test_gene_binom(mdata: MuData):
    try:
        sample_adata = mdata["tdiff"]
        pseudobulk=mdata["pseudobulk"]
    except KeyError:
        print(
            "tdiff_mdata should be a MuData object with three slots: feature_key ,'tdiff', and 'pseudobulk' - please run tdiff.count_nhoods(adata) first"
        )
        raise
    times=pseudobulk.uns["shuffle_times"]
    varTable=pseudobulk.varm
    nullTable=varTable["nullPoint"]

    #rateVal=
    sumTable=pseudobulk.uns["sum"]
    trueTable=varTable["truePoint"]
    trueTable = trueTable.fillna(0)
    nullTable = nullTable.fillna(0)
    nullVal = np.where(nullTable == 0, 1 / (sumTable.T * times), nullTable/np.array(sumTable.T * times))
    p_val_array = 1 - binom.cdf(trueTable, sumTable.T, nullVal)
    p_val_array[trueTable == 0] = 1
    p_val_df = pd.DataFrame(p_val_array, index=trueTable.index, columns=trueTable.columns)
    result_df = p_val_df.apply(lambda column: sm.stats.multipletests(column, method='fdr_bh'))
    p_adj_df=result_df.apply(lambda x: x[1])
    p_adj_df.index=p_val_df.index
    p_adj_df.columns=p_val_df.columns.astype("str")
    pseudobulk.varm["gene_p_adj"]=p_adj_df




def _row_scale(row):
    # apply z-score to scale row
    return (row - row.mean()) / row.std()


def _graph_spatial_fdr(
    sample_adata: AnnData,
    neighbors_key: str | None = None,
):
    """FDR correction weighted on inverse of connectivity of neighbourhoods. The distance to the k-th nearest neighbor is used as a measure of connectivity.

    Args:
        sample_adata: Sample-level AnnData.
        neighbors_key: The key in `adata.obsp` to use as KNN graph. Defaults to None.
    """
    # use 1/connectivity as the weighting for the weighted BH adjustment from Cydar
    w = 1 / sample_adata.var["kth_distance"]
    w[np.isinf(w)] = 0

    # Computing a density-weighted q-value.
    pvalues = sample_adata.var["PValue"]
    keep_nhoods = ~pvalues.isna()  # Filtering in case of test on subset of nhoods
    o = pvalues[keep_nhoods].argsort()
    pvalues = pvalues[keep_nhoods][o]
    w = w[keep_nhoods][o]

    adjp = np.zeros(shape=len(o))
    adjp[o] = (sum(w) * pvalues / np.cumsum(w))[::-1].cummin()[::-1]
    adjp = np.array([x if x < 1 else 1 for x in adjp])

    sample_adata.var["SpatialFDR"] = np.nan
    sample_adata.var.loc[keep_nhoods, "SpatialFDR"] = adjp

def _mergeVar(varTable,
            table):
    return(pd.merge(varTable,table,left_index=True,right_index=True,how="left"))


def _test_whole_gene(mdata: MuData):
    try:
        tdiff = mdata["tdiff"]
    except KeyError:
        print(
            "tdiff_mdata should be a MuData object with two slots: feature_key and 'tdiff' - please run tdiff.count_nhoods(adata) first"
        )
        raise
    pseudobulk=mdata["pseudobulk"]
    times=pseudobulk.uns["shuffle_times"]
    varTable=tdiff.varm
    trueVal=np.sum(varTable["Accept"],axis=0)
    sumVal=varTable["Accept"].shape[0]
    nullVal=np.sum(varTable["null_mean"],axis=0)
    
    pval_list=[]
    for i in range(len(nullVal)):
        if(trueVal[i]==0):
            pval_list.append(1)
        else:
            if (nullVal[i]==0):
                nullVal[i]=1/(sumVal*times)
            else:
                    nullVal[i]=nullVal[i]/(sumVal*times)
            pval_list.append(1- binom.cdf(trueVal[i], sumVal, nullVal[i]))
    rejected, adjusted_p_values, _, _ = sm.stats.multipletests(pval_list, method='fdr_bh')
    pseudobulk.var["overall_gene_p"]=np.array(adjusted_p_values)


[docs]@d.dedent def split_gene( mdata : MuData, mode : Literal["Kmean","Stage"]="Kmean", FDR : int = 0.05, kmean_cluster : int = 10, stage_threshold: int = 10, feature_key: str = "rna", select_genes:list = None ): """ Group genes into gene clusters based on their pseudotemporal expression patterns. Currently, we offer two clustering strategies: Kmeans and Stage. In `Kmeans` mode, genes are grouped using Kmeans clustering, which identifies differential expression patterns between two groups. In `Stage` mode, genes are grouped based on the stage (early or late) at which they exhibit differential expression (up or down). .. seealso:: - See :doc:`../../../tutorial/step3_DE` for how to detect pseudotemporal differential genes. Parameters ---------- mdata MuData object processed by `TrajDiff` process. mode Gene group strategy, either `Kmean` or `Stage`. If `Kmean` were selected, genes are grouped using Kmeans clustering. If `Stage` were selected genes are grouped based on the stage. (default: "Kmean") FDR False discovery rate to discover significant genes. (default: 0.05) kmean_cluster If 'Kmeans' is chosen in the mode parameters, you can specify the number of clusters for this parameter. (Default: 10) stage_threshold If 'Stage' is chosen in the mode parameters, you can specify the threshold of every stage (early/late) to select stage-specific genes. (Default: 10) feature_key Key to store the cell-level AnnData object in the MuData object. (Default: "rna") select_genes Custom gene sets are allowed. If not provided, we are using significant genes. (Default: None) Returns ----------------- Gene category. Also update MuData in `MuData[feature_key].var`. """ if mode not in ["Kmean", "Stage"]: raise ValueError("mode must be one of 'Kmean' 或 'Stage'") try: tdiff = mdata["tdiff"] pseudobulk = mdata["pseudobulk"] except KeyError: print("tdiff_mdata should be a MuData object with three slots: feature_key, 'tdiff' and 'pseudobulk' - please run tdiff.count_nhoods(adata) first") raise try: if select_genes == None: sigGene = pseudobulk.var_names[pseudobulk.var["overall_gene_p"]<FDR] select_genes = sigGene.copy() exprMatrix=pseudobulk.varm["exprPoint"].loc[select_genes] fdr_matrix=pseudobulk.varm["gene_p_adj"].loc[select_genes] except KeyError: print("tdiff_mdata should have been run de pipeline. Please run tdiff.de first") raise if mode=="Kmean": exprMatrix=exprMatrix.fillna(0) kmeans = KMeans(n_clusters=kmean_cluster) kmeans.fit(exprMatrix) labels = kmeans.labels_ labels=pd.DataFrame(labels) labels.index=exprMatrix.index labels.columns=["geneGroup"] mdata[feature_key].var["Kmeans"]=np.nan mdata[feature_key].var["Kmeans"].loc[labels.index]=np.array(labels["geneGroup"]) labels=labels["geneGroup"] elif mode=="Stage": fdr_matrix=-np.log(fdr_matrix+0.000000001) # Assuming exprMatrix and fdrMatrix are numpy arrays fdrBinary = fdr_matrix > 1 exprBinary = exprMatrix * fdrBinary start = exprBinary.iloc[:, 0:int(exprBinary.shape[1]/3)] end = exprBinary.iloc[:, int(2 * exprBinary.shape[1]/3):exprBinary.shape[1]] startSum = np.sum(start, axis=1) endSum = np.sum(end, axis=1) # Define different categories of genes geneWholeUp = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] > stage_threshold and endSum[idx] > stage_threshold]) geneStartUp = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] > stage_threshold and abs(endSum[idx]) < stage_threshold]) geneEndUp = np.array([name for idx, name in enumerate(startSum.index) if abs(startSum[idx]) < stage_threshold and endSum[idx] > stage_threshold]) geneWholeDown = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] < -stage_threshold and endSum[idx] < -stage_threshold]) geneStartdown = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] < -stage_threshold and abs(endSum[idx]) < stage_threshold]) geneEndDown = np.array([name for idx, name in enumerate(startSum.index) if abs(startSum[idx]) < stage_threshold and endSum[idx] < -stage_threshold]) geneUpDown = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] > stage_threshold and endSum[idx] < -stage_threshold]) geneDownUp = np.array([name for idx, name in enumerate(startSum.index) if startSum[idx] < -stage_threshold and endSum[idx] > stage_threshold]) geneGroup = np.concatenate((geneWholeUp, geneStartUp, geneEndUp, geneWholeDown, geneStartdown, geneEndDown, geneUpDown, geneDownUp)) geneCat = np.concatenate((np.repeat("Up_Up", len(geneWholeUp)), np.repeat("Up_0", len(geneStartUp)), np.repeat("0_Up", len(geneEndUp)), np.repeat("Down_Down", len(geneWholeDown)), np.repeat("Down_0", len(geneStartdown)), np.repeat("0_Down", len(geneEndDown)), np.repeat("Up_Down", len(geneUpDown)), np.repeat("Down_Up", len(geneDownUp)))) geneSplit = pd.DataFrame({"gene": geneGroup, "geneGroup": geneCat}) geneSplit['geneGroup'] = pd.Categorical(geneSplit['geneGroup']) geneSplit.index=geneSplit["gene"] labels=geneSplit["geneGroup"] mdata[feature_key].var["Stage"]=np.nan mdata[feature_key].var["Stage"].loc[labels.index]=np.array(labels) return(labels)