module utils


function reverse_log_softmax

reverse_log_softmax(log_probs)

Reverse the log softmax operation to obtain the logits.

Args:

  • log_probs (torch.Tensor): The log probabilities.

Returns:

  • torch.Tensor: The logits.
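
Log softmax subtracts a per-row normalizing constant, so logits can only be recovered up to an additive constant. The following is a minimal sketch of that idea in NumPy, not the package's implementation: the helper names and the choice to center each row at zero mean are assumptions made for illustration.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log softmax along the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def reverse_log_softmax_sketch(log_probs):
    # Hypothetical stand-in: the per-row constant removed by log softmax
    # is lost, so we pick one representative by centering each row.
    return log_probs - log_probs.mean(axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]])
recovered = reverse_log_softmax_sketch(log_softmax(logits))

# The recovered values differ from the original logits only by a
# constant per row, so they yield the same probabilities.
offsets = logits - recovered
```

Any such representative gives back the same log probabilities when passed through log softmax again, which is usually all that downstream code needs.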

function get_results

get_results(
    pred,
    data,
    node_encoder_rev,
    node_encoder_ct,
    node_encoder_cl,
    activation=None
)

Get the results of the prediction for clone and cell type classifications.

Args:

  • pred (torch.Tensor): The prediction tensor.
  • data (torch.Tensor): The data tensor.
  • node_encoder_rev (dict): The reverse node encoder dictionary.
  • node_encoder_ct (dict): The cell type node encoder dictionary.
  • node_encoder_cl (dict): The clone node encoder dictionary.
  • activation (str, optional): The activation function to apply. Defaults to None.

Returns:

  • tuple: A tuple containing the clone results and cell type results as pandas DataFrames.

function get_results_all

get_results_all(
    pred,
    data,
    node_encoder_rev,
    node_encoder_ct,
    node_encoder_cl,
    activation=None
)

Get the results for clone and cell type predictions.

Args:

  • pred (torch.Tensor): Predictions tensor.
  • data (torch_geometric.data.Data): Input data.
  • node_encoder_rev (dict): Reverse node encoder dictionary.
  • node_encoder_ct (dict): Cell type node encoder dictionary.
  • node_encoder_cl (dict): Clone node encoder dictionary.
  • activation (str, optional): Activation function to apply. Can be “softmax”, “raw”, or None. Defaults to None.

Returns:

  • tuple: A tuple containing two pandas DataFrames:
      - clone_res: DataFrame containing clone predictions.
      - ct_res: DataFrame containing cell type predictions.

function get_calibrated_results

get_calibrated_results(
    pred,
    data,
    node_encoder_rev,
    node_encoder_ct,
    node_encoder_cl,
    t
)

Calibrates the predicted results using temperature scaling.

Args:

  • pred (numpy.ndarray): The predicted results.
  • data (pandas.DataFrame): The input data.
  • node_encoder_rev (dict): The reverse node encoder dictionary.
  • node_encoder_ct (dict): The cell type node encoder dictionary.
  • node_encoder_cl (dict): The clone node encoder dictionary.
  • t (tuple): The temperatures used to scale the clone (t[0]) and cell type (t[1]) predictions.

Returns:

  • tuple: A tuple containing the calibrated results for clones and cell types.
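
Temperature scaling divides the logits by a scalar before applying softmax, which softens (temperature above 1) or sharpens (temperature below 1) the probabilities without changing the predicted class. The sketch below illustrates this in NumPy under stated assumptions: the function names, the example logits, and the temperature values are hypothetical and do not reflect how get_calibrated_results is implemented.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax along the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def temperature_scale(logits, temperature):
    # Divide logits by the temperature, then normalize.
    return softmax(logits / temperature)

# Hypothetical temperatures: t[0] for clones, t[1] for cell types.
t = (1.5, 2.0)
clone_logits = np.array([[4.0, 1.0, 0.0]])
ct_logits = np.array([[0.5, 2.5]])

clone_probs = temperature_scale(clone_logits, t[0])
ct_probs = temperature_scale(ct_logits, t[1])
```

With both temperatures above 1, the calibrated probabilities are less peaked than the uncalibrated ones, while the top-ranked class stays the same.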

function get_results_clone

get_results_clone(
    pred,
    data,
    node_encoder_rev,
    node_encoder_cl,
    activation=None
)

Get the clone results based on the predictions.

Args:

  • pred (torch.Tensor): The predictions.
  • data (Data): The data object containing the hold_out indices.
  • node_encoder_rev (dict): The reverse node encoder dictionary.
  • node_encoder_cl (dict): The clone node encoder dictionary.
  • activation (str, optional): The activation function to apply. Defaults to None.

Returns:

  • pd.DataFrame: The clone results.

function get_results_type

get_results_type(pred, data, node_encoder_rev, node_encoder_ct, activation=None)

Get the cell type results based on the predictions.

Args:

  • pred (torch.Tensor): The predicted cell types.
  • data (torch.Tensor): The input data.
  • node_encoder_rev (dict): A dictionary mapping node indices to cell names.
  • node_encoder_ct (dict): A dictionary mapping node indices to cell types.
  • activation (str, optional): The activation function to apply. Defaults to None.

Returns:

  • pd.DataFrame: A DataFrame containing the predicted cell types for the hold-out cells.

function rotate_90_degrees_clockwise

rotate_90_degrees_clockwise(matrix)

Rotates a matrix 90 degrees clockwise.

Args:

  • matrix (numpy.ndarray): The input matrix to be rotated.

Returns:

  • numpy.ndarray: The rotated matrix.
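
As an illustration of what a 90-degree clockwise rotation does to a matrix, the sketch below uses NumPy's `np.rot90` with `k=-1`. Whether rotate_90_degrees_clockwise is implemented this way is an assumption; the example only shows the expected effect.

```python
import numpy as np

matrix = np.array([[1, 2, 3],
                   [4, 5, 6]])

# np.rot90 rotates counterclockwise for positive k,
# so k=-1 gives a single clockwise quarter turn.
rotated = np.rot90(matrix, k=-1)
# rotated is a 3x2 array: [[4, 1], [5, 2], [6, 3]]
```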


function get_attention_visium

get_attention_visium(w, node_encoder_rev, data, coordinates)

Calculate the attention weights used to visualize Visium data.

Args:

  • w (tuple): Tuple containing the edges and weights of the attention graph.
  • node_encoder_rev (dict): Reverse node encoder dictionary.
  • data (torch.Tensor): Hold out data.
  • coordinates (pd.DataFrame): DataFrame containing the coordinates of the nodes.

Returns:

  • pd.DataFrame: DataFrame containing the attention weights for each target node and distance category.

function get_attention

get_attention(w, node_encoder_rev, data, coordinates)

Calculate attention weights for spatial graph nodes based on the given inputs.

Args:

  • w (tuple): A tuple containing two elements, edges and weight:
      - edges (torch.Tensor): Tensor representing the edges of the graph.
      - weight (torch.Tensor): Tensor representing the weights of the edges.
  • node_encoder_rev (dict): A dictionary mapping node indices to their corresponding IDs.
  • data (torch.Tensor): Tensor representing the hold-out data.
  • coordinates (pd.DataFrame): DataFrame containing the coordinates of the nodes.

Returns:

  • pd.DataFrame: DataFrame containing the attention weights for each target node, categorized by distance.

function plot_metrics

plot_metrics(stored_metrics)

Plots the validation accuracy for clone and cell type metrics.

Args:

  • stored_metrics (dict): A dictionary containing the stored metrics.

Returns: None


function check_class_distributions

check_class_distributions(
    data,
    weight_clone,
    weight_type,
    norm_sim,
    no_diploid=False
)

Check the class distributions in the data and validate the inputs.

Args:

  • data (torch_geometric.data.Data): The input data.
  • weight_clone (list): The weights for each clone class.
  • weight_type (list): The weights for each type class.
  • norm_sim (torch.Tensor): The similarity scores.
  • no_diploid (bool, optional): Whether to exclude the diploid class. Defaults to False.

Raises:

  • AssertionError: If the number of clone classes in the training set is not equal to the total number of classes.
  • AssertionError: If the number of clone classes is not equal to the number of weights.
  • AssertionError: If the number of clone classes is not equal to the number of similarity scores.
  • AssertionError: If the number of type classes in the training set is not equal to the total number of classes.

function compute_class_weights

compute_class_weights(y_train)

Calculate class weights based on the class sample count.
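
One common way to derive class weights from sample counts is the "balanced" heuristic, where each class gets weight n_samples / (n_classes * class_count). The sketch below applies that formula; it is an assumption that compute_class_weights uses this exact scheme, and the function name and return type here are hypothetical.

```python
from collections import Counter

def class_weights_sketch(y_train):
    # Count samples per class, then weight each class inversely
    # to its frequency (the "balanced" heuristic).
    counts = Counter(y_train)
    n_samples = len(y_train)
    n_classes = len(counts)
    return {
        label: n_samples / (n_classes * count)
        for label, count in sorted(counts.items())
    }

# Class 0 has three samples, class 1 has two, class 2 has one.
weights = class_weights_sketch([0, 0, 0, 1, 1, 2])
```

Rarer classes receive larger weights, so they contribute more to a weighted loss.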


function balanced_split

balanced_split(data, hold_in, size=0.5)

Splits the data into balanced train and test sets based on the given hold_in indices.

Args:

  • data (object): The data object containing the features and labels.
  • hold_in (list): The indices of the data to be split.
  • size (float, optional): The proportion of data to be included in the test set. Defaults to 0.5.

Returns:

  • train_indices_final (list): The indices of the training set.
  • test_indices_final (list): The indices of the test set.
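
A class-balanced split can be sketched by grouping the hold_in indices by label and splitting each group with the same proportion. The example below is a minimal illustration using only the standard library. It takes a plain list of labels instead of the data object, and the `labels` and `seed` parameters are hypothetical, so it should not be read as the package's implementation.

```python
import random
from collections import defaultdict

def balanced_split_sketch(labels, hold_in, size=0.5, seed=0):
    rng = random.Random(seed)

    # Group the candidate indices by their class label.
    by_class = defaultdict(list)
    for idx in hold_in:
        by_class[labels[idx]].append(idx)

    train, test = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Send the same proportion of every class to the test set.
        n_test = round(len(members) * size)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return sorted(train), sorted(test)

labels = [0, 0, 0, 0, 1, 1, 1, 1, 0, 1]
hold_in = [0, 1, 2, 3, 4, 5, 6, 7]  # indices 8 and 9 are not split
train_idx, test_idx = balanced_split_sketch(labels, hold_in, size=0.5)
```

Splitting per class keeps the class proportions of the test set close to those of the hold_in set, which a purely random split does not guarantee for small classes.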
