from __future__ import annotations
import scanpy as sc
import pandas as pd
import numpy as np
from rich import print
from scipy.sparse import csr_matrix
from anndata import AnnData
#import scarches as sca
#from scipy.stats import pearsonr
#import anndata as ad
#from scipy.stats import norm
from sklearn.neighbors import KNeighborsTransformer
import joblib
from TrajAtlas.utils._docs import d
#import lightgbm
import os
from mudata import MuData
# scarches is treated as optional here: a failed import is reported on the
# console and swallowed, so the remaining definitions in this module still load.
try:
    import scarches as sca
except ModuleNotFoundError:
    # The package could not be found at all.
    print(
        "[bold yellow]scarches is not installed. "
        "Install with [green]pip install sca "
        "[yellow]to project your datasets to our Differetiation Atlas."
    )
except Exception:
    # The package was found but raised while being imported.
    print("[bold yellow]You should check whether scarches have properly installed")
# Bundled resources are looked up in the ``datasets`` directory that sits one
# level above the directory containing this module.
location = os.path.dirname(os.path.realpath(__file__))
_datasets_dir = os.path.join(location, "..", "datasets")

# Gene table read by formOsteoAdata when variableFeature == "Default".
highVarGeneFile = os.path.join(_datasets_dir, "varGene_1500.csv")
trajMapFile = os.path.join(_datasets_dir, "trajMap_reference_1.h5ad")

# Reference obs table and fitted KNN transformer loaded by label_transfer.
refObs = os.path.join(_datasets_dir, "pred_obs.csv")
k_neighbor_model = os.path.join(_datasets_dir, "knn_transformer_model.joblib")

# Default model location used by ProjectData.
scanviModel = os.path.join(_datasets_dir, "scanvi_model")

# Gene-importance table and default regressor loaded by pseduo_predict.
rfGeneFile = os.path.join(_datasets_dir, "rf_genes.csv")
pseduoPredFile = os.path.join(_datasets_dir, "pseduoPred", "lightGBM_pred.pkl")
def _weighted_knn_transfer(
    query_adata,
    query_adata_emb,
    ref_adata_obs,
    label_keys,
    knn_model,
    threshold=1,
    pred_unknown=False,
    mode="package",
):
    """Annotate ``query_adata`` cells with a trained weighted KNN classifier.

    For every query cell the neighbours returned by ``knn_model`` vote for
    their reference label, each vote weighted by a Gaussian-style kernel of
    the neighbour distance; the label with the largest summed weight wins.

    Parameters
    ----------
    query_adata: :class:`~anndata.AnnData`
        Dataset whose cells are annotated.
    query_adata_emb: str
        Name of the ``obsm`` entry used for label transfer. If set to "X",
        ``query_adata.X`` is used.
    ref_adata_obs: :class:`pd.DataFrame`
        ``obs`` table of the reference data. The neighbour indices returned
        by ``knn_model`` are used as row positions into its columns.
    label_keys: str
        Column-name prefix: every column of ``ref_adata_obs`` whose name
        starts with it is transferred.
    knn_model: :class:`~sklearn.neighbors._graph.KNeighborsTransformer`
        Fitted KNN transformer that is queried with the embedding.
    threshold: float
        Only used when ``pred_unknown`` is true: a cell keeps its winning
        label when that label's summed weight is ``>= threshold`` and is
        labelled "Unknown" otherwise.
    pred_unknown: bool
        ``False`` by default. If ``False``, ``threshold`` is ignored and each
        cell always receives its winning label.
    mode: str
        Only "package" is implemented; the uncertainty is then
        ``1 - P(pred_label)``. Any other value raises.

    Returns
    -------
    tuple of two :class:`pd.DataFrame`
        ``(pred_labels, uncertainties)``, both indexed by
        ``query_adata.obs_names`` with one object-dtype column per
        transferred label column.

    Raises
    ------
    ValueError
        If ``knn_model`` is not a ``KNeighborsTransformer`` or
        ``query_adata_emb`` is neither "X" nor a key of ``query_adata.obsm``.
    Exception
        If ``mode`` is not "package".
    """
    if not isinstance(knn_model, KNeighborsTransformer):
        raise ValueError(
            "knn_model should be of type sklearn.neighbors._graph.KNeighborsTransformer!"
        )

    if query_adata_emb == "X":
        query_emb = query_adata.X
    elif query_adata_emb in query_adata.obsm.keys():
        query_emb = query_adata.obsm[query_adata_emb]
    else:
        raise ValueError(
            "query_adata_emb should be set to either 'X' or the name of the obsm layer to be used!"
        )

    # Reject unsupported modes before running the neighbour query rather than
    # in the middle of the per-cell loop.
    if mode != "package":
        raise Exception("Inquery Mode!")

    top_k_distances, top_k_indices = knn_model.kneighbors(X=query_emb)

    # Per-cell kernel bandwidth derived from the spread of neighbour distances.
    stds = np.std(top_k_distances, axis=1)
    stds = (2.0 / stds) ** 2
    stds = stds.reshape(-1, 1)

    # Turn distances into weights that sum to 1 for every query cell.
    top_k_distances_tilda = np.exp(-np.true_divide(top_k_distances, stds))
    weights = top_k_distances_tilda / np.sum(
        top_k_distances_tilda, axis=1, keepdims=True
    )

    cols = ref_adata_obs.columns[ref_adata_obs.columns.str.startswith(label_keys)]

    # Results are gathered in plain object arrays and wrapped in DataFrames at
    # the end. Writing through ``df.iloc[i][j] = ...`` is chained assignment,
    # which targets a temporary object under pandas Copy-on-Write.
    n_cells = len(weights)
    pred_values = np.empty((n_cells, len(cols)), dtype=object)
    uncertainty_values = np.empty((n_cells, len(cols)), dtype=object)

    for col_pos, col in enumerate(cols):
        # Reference labels of this column, looked up once per column.
        y_train_labels = ref_adata_obs[col].to_numpy()
        for row, neighbor_idx in enumerate(top_k_indices):
            neighbor_labels = y_train_labels[neighbor_idx]
            row_weights = weights[row]

            # Weighted vote: the first label with the strictly largest summed
            # weight wins.
            best_label, best_prob = None, 0.0
            for candidate_label in np.unique(neighbor_labels):
                candidate_prob = row_weights[
                    neighbor_labels == candidate_label
                ].sum()
                if best_prob < candidate_prob:
                    best_prob = candidate_prob
                    best_label = candidate_label

            if pred_unknown:
                if best_prob >= threshold:
                    pred_label = best_label
                else:
                    pred_label = "Unknown"
            else:
                pred_label = best_label

            uncertainty_values[row, col_pos] = max(1 - best_prob, 0)
            pred_values[row, col_pos] = pred_label

    pred_labels = pd.DataFrame(
        pred_values, index=query_adata.obs_names, columns=cols
    )
    uncertainties = pd.DataFrame(
        uncertainty_values, index=query_adata.obs_names, columns=cols
    )

    print("finished!")
    return pred_labels, uncertainties
def formOsteoAdata(adata, batchVal="sample", missing_threshold=500, variableFeature="Default"):
    """Return a copy of ``adata`` restricted to a fixed gene list.

    Genes from ``variableFeature`` that ``adata`` lacks are added as all-zero
    columns, the result is ordered like ``variableFeature``, and
    ``obs["batch"]`` is filled with the string form of ``obs[batchVal]``.

    Parameters
    ----------
    adata
        Query dataset. NOTE: if ``adata.varm`` holds a "PCs" entry it is
        deleted from the passed object before concatenation.
    batchVal
        Name of the ``adata.obs`` column copied (as ``str``) into
        ``obs["batch"]``.
    missing_threshold
        Largest number of requested genes that may be absent from ``adata``.
    variableFeature
        Gene names to keep. The string "Default" loads column "0" of the
        bundled ``highVarGeneFile`` table.

    Returns
    -------
    New AnnData object with exactly the genes in ``variableFeature``.

    Raises
    ------
    ValueError
        If more than ``missing_threshold`` genes are missing.
    KeyError
        If ``batchVal`` is not a column of ``adata.obs``.
    """
    if isinstance(variableFeature, str) and variableFeature == "Default":
        variableFeature = pd.read_csv(highVarGeneFile, index_col=0)["0"].values

    n_requested = len(variableFeature)
    n_found = adata.var_names.isin(variableFeature).sum()
    if n_requested - n_found > missing_threshold:
        raise ValueError("Too many missing gene! Please check data!")

    # Fail before the expensive concatenation instead of at the very end.
    if batchVal not in adata.obs.columns:
        raise KeyError(
            f"Column '{batchVal}' was not found in adata.obs; it is required to fill obs['batch']."
        )

    print("Total number of genes needed for mapping:", n_requested)
    print("Number of genes found in query dataset:", n_found)

    missing_genes = [
        gene_id
        for gene_id in variableFeature
        if gene_id not in adata.var_names
    ]

    # Build the all-zero block directly as an empty sparse matrix; no dense
    # (n_obs x n_missing) array is allocated. ``var`` is passed as a DataFrame
    # indexed by gene name, which is a documented input type for AnnData.
    missing_gene_adata = sc.AnnData(
        X=csr_matrix((adata.n_obs, len(missing_genes)), dtype="float32"),
        obs=adata.obs.iloc[:, :1],
        var=pd.DataFrame(index=missing_genes),
    )
    missing_gene_adata.layers["counts"] = missing_gene_adata.X

    # The zero-filled object has no "PCs" entry, so the query's is dropped
    # before the two objects are joined along the gene axis.
    if "PCs" in adata.varm.keys():
        del adata.varm["PCs"]

    adata_merged = sc.concat(
        [adata, missing_gene_adata],
        axis=1,
        join="outer",
        index_unique=None,
        merge="unique",
    )
    # Keep only the requested genes, in the requested order.
    adata_merged = adata_merged[:, variableFeature].copy()
    adata_merged.obs["batch"] = adata_merged.obs[batchVal].astype(str)
    return adata_merged
@d.dedent
def ProjectData(
    adata: AnnData,
    modelPath: str | None = None,
    max_epoch: int = 100):
    """Projected query datasets (osteogenesis-related) to scANVI latent space :cite:`xuProbabilisticHarmonizationAnnotation2021` which
    trained with Differentiation Atlas by scArches :cite:`lotfollahiMappingSinglecellData2022`.

    .. seealso::
        - See :doc:`../../../tutorial/1_OPCST_projecting` for how to
          projecting OPCST model to your datasets.

    Parameters
    ----------
    %(adata)s
    modelPath
        Path of the scANVI reference model. ``None`` (the default) uses the
        model bundled with the package.
    max_epoch
        Maximum number of training epochs passed to ``model.train``.

    Returns
    -------
    :class:`adata <anndata.AnnData>` object. Updates :attr:`adata.obsm <anndata.AnnData.obsm>` with the following:

        - ``scANVI`` latent representation returned by the trained query model.

    Raises
    ------
    ImportError
        If scarches could not be imported when this module was loaded.
    """
    # The module-level import of scarches is optional; report its absence
    # clearly before doing any work instead of failing later on ``sca``.
    if "sca" not in globals():
        raise ImportError(
            "scarches could not be imported; it is required by ProjectData."
        )

    # Align the query genes with the gene list the reference model expects.
    adata_immediate = formOsteoAdata(adata)
    if modelPath is None:
        modelPath = scanviModel

    print("projecting....")
    model = sca.models.SCANVI.load_query_data(
        adata_immediate,
        modelPath,
        freeze_dropout=True,
    )
    model.train(
        max_epochs=max_epoch,
        plan_kwargs=dict(weight_decay=0.0),
        check_val_every_n_epoch=10,
    )
    # Store the latent matrix directly on the caller's object.
    adata.obsm["scANVI"] = model.get_latent_representation()
    return adata
@d.dedent
def label_transfer(
    adata: AnnData
):
    """Transfer seven-level annotation system and lineage path to adata.

    The ``scANVI`` entry of ``adata.obsm`` is queried against the bundled KNN
    transformer, and every column of the bundled reference table whose name
    starts with ``pred`` is transferred.

    .. seealso::
        - See :doc:`../../../tutorial/1_OPCST_projecting` for how to
          projecting OPCST model to your datasets.

    Parameters
    ----------
    %(adata)s

    Returns
    -------
    :class:`adata <anndata.AnnData>` object. Also updates :attr:`adata.obs <anndata.AnnData.obs>` with the following:

        - ``pred*`` one column per transferred reference column (the
          predicted annotation levels and lineage paths).
        - ``pred_lineage_lepr``, ``pred_lineage_msc``, ``pred_lineage_chondro``,
          ``pred_lineage_fibro`` are additionally converted to ``str``.
    """
    knn_transformer = joblib.load(k_neighbor_model)
    reference_obs = pd.read_csv(refObs, index_col=0)

    # Only the predicted labels are used; the uncertainties are discarded.
    transferred, _unused_uncertainty = _weighted_knn_transfer(
        knn_model=knn_transformer,
        ref_adata_obs=reference_obs,
        label_keys="pred",
        query_adata=adata,
        query_adata_emb="scANVI",
    )
    adata.obs[transferred.columns] = transferred

    lineage_columns = [
        'pred_lineage_lepr',
        'pred_lineage_msc',
        'pred_lineage_chondro',
        "pred_lineage_fibro",
    ]
    adata.obs[lineage_columns] = adata.obs[lineage_columns].astype("str")
    return adata
@d.dedent
def pseduo_predict(adata: AnnData,
                   modelPath: str = "Default"):
    """Predict common pseudotime.

    .. seealso::
        - See :doc:`../../../tutorial/1_OPCST_projecting` for how to
          projecting OPCST model to your datasets.

    Parameters
    ----------
    %(adata)s
    modelPath
        Path of model to predict pseudotime. ``"Default"`` or ``None`` loads
        the bundled model, which was trained with Differentiation Atlas by
        LightGBMRegressor.

    Returns
    -------
    :class:`adata <anndata.AnnData>` object. Also updates :attr:`adata.obs <anndata.AnnData.obs>` with the following:

        - ``pseduoPred`` value returned by ``model.predict`` for each cell.
    """
    # Genes are kept when their importance exceeds this cut-off.
    importance_cutoff = 0.000008

    rfGene = pd.read_csv(rfGeneFile, index_col=0)
    gene = rfGene["gene"][rfGene["importance"] > importance_cutoff]

    # Hand the gene names over as a plain array, the same kind of object the
    # "Default" path of formOsteoAdata works with.
    adata_immediate = formOsteoAdata(
        adata, variableFeature=gene.to_numpy(), batchVal="sample"
    )

    if modelPath is None or (isinstance(modelPath, str) and modelPath == "Default"):
        modelPath = pseduoPredFile

    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code; only pass model paths from a trusted source.
    model = joblib.load(modelPath)
    adata.obs["pseduoPred"] = model.predict(adata_immediate.layers["counts"])
    return adata
def _lineage_flags(column):
    """Convert one lineage column to booleans, parsing textual flags."""
    if pd.api.types.is_bool_dtype(column) or pd.api.types.is_numeric_dtype(column):
        return column.astype("bool")
    # label_transfer stores these columns as text, and a plain astype("bool")
    # would turn the string "False" into True. Compare the text instead.
    text = column.astype(str).str.strip().str.lower()
    return ~text.isin(["false", "0", "", "nan", "none"])


def substractLineageAdata(adata, lineage: list | None = None):
    """Keep only the cells predicted to belong to any of the given lineages.

    Parameters
    ----------
    adata
        Object whose ``obs`` holds the ``pred_lineage_*`` columns.
    lineage
        List drawn from 'Fibroblast', 'LepR_BMSC', 'MSC', 'Chondro'.
        ``None`` (the default) selects all four.

    Returns
    -------
    ``adata`` subset to the rows where at least one selected lineage flag is
    true. The selected ``obs`` columns are converted to booleans in place and
    the row-wise result is stored in ``obs["lineageSum"]``.

    Raises
    ------
    TypeError
        If ``lineage`` is neither ``None`` nor a list.
    ValueError
        If ``lineage`` contains an unknown lineage name.
    """
    lineageDict = {
        "Fibroblast": "pred_lineage_fibro",
        "LepR_BMSC": "pred_lineage_lepr",
        "MSC": "pred_lineage_msc",
        "Chondro": "pred_lineage_chondro"
    }

    # Resolve the None default first; otherwise the type check rejects it.
    if lineage is None:
        lineage = list(lineageDict)
    elif not isinstance(lineage, list):
        raise TypeError("Lineage argument must be a list containing only the valid lineages: 'Fibroblast', 'LepR_BMSC', 'MSC', 'Chondro'.")

    values = []
    for key in lineage:
        if key not in lineageDict:
            raise ValueError(f"Invalid lineage '{key}' provided. Lineage argument must contain only the valid lineages: 'Fibroblast', 'LepR_BMSC', 'MSC', 'Chondro'.")
        values.append(lineageDict[key])

    adata.obs[values] = adata.obs[values].apply(_lineage_flags)
    # A cell is kept when it is flagged for at least one selected lineage.
    boolVal = adata.obs[values].any(axis=1)
    adata.obs["lineageSum"] = boolVal
    adata = adata[boolVal, :]
    return adata