module models


class GATLightningModule_sampler

LightningModule implementation for the GAT (Graph Attention Network) model with sampling.

Args:

  • data_param (object): Object containing the parameters of the input data.
  • weight_clone (torch.Tensor): Weight tensor for the clone loss.
  • weight_type (torch.Tensor): Weight tensor for the type loss.
  • norm_sim (torch.Tensor, optional): Tensor of similarity values between clones, used by the tree loss. Defaults to None.
  • learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
  • heads (int, optional): Number of attention heads. Defaults to 3.
  • dim_h (int, optional): Hidden dimension size. Defaults to 16.
  • weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
  • map_enteties (str, optional): Which entities to predict. Possible values: “both”, “clone”, “type”. Defaults to “both”.
  • n_layers (int, optional): Number of GAT layers. Defaults to 2.
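
The optional scalar arguments above can be collected in one place before building the module. The following is a minimal sketch, not part of the `models` module: `SamplerHParams` is a hypothetical helper that mirrors only the documented defaults and rejects `map_enteties` values outside the three documented options. The tensor arguments (`weight_clone`, `weight_type`, `norm_sim`) and `data_param` are left out because their shapes are not specified here.

```python
from dataclasses import asdict, dataclass

# Documented values for map_enteties (spelling follows the argument name).
VALID_MAP_ENTETIES = ("both", "clone", "type")


@dataclass
class SamplerHParams:
    """Hypothetical container mirroring the documented optional arguments."""

    learning_rate: float = 1e-3
    heads: int = 3
    dim_h: int = 16
    weight_decay: float = 1e-4
    map_enteties: str = "both"
    n_layers: int = 2

    def __post_init__(self) -> None:
        # Fail early on a value the documentation does not list.
        if self.map_enteties not in VALID_MAP_ENTETIES:
            raise ValueError(
                f"map_enteties must be one of {VALID_MAP_ENTETIES}, "
                f"got {self.map_enteties!r}"
            )


# Override only what differs from the documented defaults.
hparams = SamplerHParams(heads=4, map_enteties="clone")
kwargs = asdict(hparams)
```

Assuming the constructor accepts these names as keyword arguments, `kwargs` could then be unpacked alongside the required arguments, for example `GATLightningModule_sampler(data_param, weight_clone, weight_type, **kwargs)`.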

class GAT2

Graph Attention Network (GAT) model.

Args:

  • num_classes_clone (int): Number of clone classes.
  • num_classes_type (int): Number of type classes.
  • heads (int, optional): Number of attention heads. Defaults to 1.
  • dim_h (int, optional): Hidden dimension size. Defaults to 16.
  • map_enteties (str, optional): Which entities to predict. Possible values: “both”, “clone”, “type”. Defaults to “both”.
  • num_node_features (int, optional): Number of node features. Defaults to 2.
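
As a rough illustration of how these arguments could interact, the sketch below computes the layer widths they imply. It rests on two assumptions that this reference does not confirm: that attention heads are concatenated, so the hidden width is `heads * dim_h` (the default behaviour of PyTorch Geometric's `GATConv`), and that `map_enteties` selects which output heads exist. `gat2_layer_sizes` is a hypothetical helper, not a function of the `models` module.

```python
from typing import Dict

VALID_MAP_ENTETIES = ("both", "clone", "type")


def gat2_layer_sizes(
    num_classes_clone: int,
    num_classes_type: int,
    heads: int = 1,
    dim_h: int = 16,
    map_enteties: str = "both",
    num_node_features: int = 2,
) -> Dict[str, int]:
    """Return the widths implied by the arguments (assumed layout)."""
    if map_enteties not in VALID_MAP_ENTETIES:
        raise ValueError(f"unknown map_enteties value: {map_enteties!r}")

    sizes = {
        "input": num_node_features,
        # Assumption: heads are concatenated rather than averaged.
        "hidden": heads * dim_h,
    }
    # Assumption: one output head per predicted entity.
    if map_enteties in ("both", "clone"):
        sizes["clone_logits"] = num_classes_clone
    if map_enteties in ("both", "type"):
        sizes["type_logits"] = num_classes_type
    return sizes


sizes_both = gat2_layer_sizes(num_classes_clone=5, num_classes_type=3, heads=3)
sizes_clone = gat2_layer_sizes(5, 3, map_enteties="clone")
```

Under these assumptions, raising `heads` widens the hidden representation linearly, while `map_enteties` only changes which outputs are produced.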


This file was automatically generated via lazydocs.