module models
class GATLightningModule_sampler
LightningModule implementation for the GAT (Graph Attention Network) model with sampling.
Args:
data_param (object): Object containing the parameters of the input data.
weight_clone (torch.Tensor): Weight tensor for the clone loss.
weight_type (torch.Tensor): Weight tensor for the type loss.
norm_sim (torch.Tensor, optional): Tensor containing the similarity values between clones for tree loss implementation. Defaults to None.
learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
heads (int, optional): Number of attention heads. Defaults to 3.
dim_h (int, optional): Hidden dimension size. Defaults to 16.
weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
map_enteties (str, optional): Mapping entities to predict. Possible values: "both", "clone", "type". Defaults to "both".
n_layers (int, optional): Number of GAT layers. Defaults to 2.
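The optional arguments listed above can be gathered in one place before the module is constructed. The sketch below is only an illustration: `SamplerHyperparams` and `as_kwargs` are hypothetical names that do not exist in the `models` module, and the positivity check is an assumption rather than documented behavior. The defaults and the allowed `map_enteties` values are taken from the argument list.

```python
from dataclasses import asdict, dataclass

# Values documented for the `map_enteties` argument.
ALLOWED_MAP_ENTETIES = ("both", "clone", "type")


@dataclass(frozen=True)
class SamplerHyperparams:
    """Hypothetical holder for the documented optional arguments."""

    learning_rate: float = 1e-3
    heads: int = 3
    dim_h: int = 16
    weight_decay: float = 1e-4
    map_enteties: str = "both"  # spelling follows the documented argument name
    n_layers: int = 2

    def __post_init__(self):
        # Reject values outside the documented set.
        if self.map_enteties not in ALLOWED_MAP_ENTETIES:
            raise ValueError(
                f"map_enteties must be one of {ALLOWED_MAP_ENTETIES}, "
                f"got {self.map_enteties!r}"
            )
        # Assumption: head, hidden-size, and layer counts must be positive.
        if min(self.heads, self.dim_h, self.n_layers) < 1:
            raise ValueError("heads, dim_h and n_layers must be >= 1")

    def as_kwargs(self) -> dict:
        """Return the fields as a dict keyed by the documented argument names."""
        return asdict(self)
```

Assuming the constructor accepts the documented names as keyword arguments, the result could be passed as `GATLightningModule_sampler(data_param=..., weight_clone=..., weight_type=..., **SamplerHyperparams(heads=4).as_kwargs())`.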
class GAT2
Graph Attention Network (GAT) model.
Args:
num_classes_clone (int): Number of clone classes.
num_classes_type (int): Number of type classes.
heads (int, optional): Number of attention heads. Defaults to 1.
dim_h (int, optional): Hidden dimension size. Defaults to 16.
map_enteties (str, optional): Mapping entities to predict. Possible values: "both", "clone", "type". Defaults to "both".
num_node_features (int, optional): Number of node features. Defaults to 2.
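To make the role of `map_enteties` concrete, the sketch below lists which outputs each documented value would select and how many classes each would cover. It assumes that every predicted entity gets one output sized by its class count; the actual layout inside `GAT2` is not described here and may differ. `expected_output_sizes` is a hypothetical helper, not part of the module.

```python
def expected_output_sizes(map_enteties, num_classes_clone, num_classes_type):
    """Map each predicted entity to its number of classes.

    Hypothetical helper: it only encodes the documented meaning of
    `map_enteties`, not the internals of GAT2.
    """
    if map_enteties == "both":
        return {"clone": num_classes_clone, "type": num_classes_type}
    if map_enteties == "clone":
        return {"clone": num_classes_clone}
    if map_enteties == "type":
        return {"type": num_classes_type}
    raise ValueError(
        f'map_enteties must be "both", "clone" or "type", got {map_enteties!r}'
    )


# Example with made-up class counts: 5 clone classes and 3 type classes.
sizes = expected_output_sizes("both", num_classes_clone=5, num_classes_type=3)
```

With the made-up counts above, `sizes` holds one entry for the clone output and one for the type output; passing "clone" or "type" would keep only the corresponding entry.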
This file was automatically generated via lazydocs.