Source code for neuralogic.nn.module.gnn.gin

from typing import Any

from neuralogic.core.constructs.factories import R, V
from neuralogic.core.constructs.function import Aggregation, Transformation
from neuralogic.core.constructs.function.function import AggregationFunction, TransformationFunction
from neuralogic.core.constructs.metadata import Metadata
from neuralogic.nn.module.module import Module


[docs] class GINConv(Module): """ Implements the Graph Isomorphism Network (GIN) convolution layer. GIN is a powerful GNN layer that can distinguish between different graph structures. """ def __init__( self, in_channels: int, out_channels: int, output_name: str, feature_name: str, edge_name: str, activation: TransformationFunction = Transformation.IDENTITY, aggregation: AggregationFunction = Aggregation.SUM, ): """ Parameters ---------- in_channels : int Number of input features. out_channels : int Number of output features. output_name : str Name of the output relation. feature_name : str Name of the input feature relation. edge_name : str Name of the edge relation. activation : TransformationFunction, optional Activation function to use. Default: Transformation.IDENTITY. aggregation : AggregationFunction, optional Aggregation function to use. Default: Aggregation.SUM. """ self.output_name = output_name self.feature_name = feature_name self.edge_name = edge_name self.in_channels = in_channels self.out_channels = out_channels self.aggregation = aggregation self.activation = activation def __call__(self) -> list[Any]: """ Generates the rules for GIN convolution. Returns ------- list A list of rules defining the GIN convolution. """ head = R.get(self.output_name)(V.I)[self.out_channels, self.in_channels] embed = R.get(f"embed__{self.output_name}") metadata = Metadata(aggregation=self.aggregation) return [ (head <= (R.get(self.feature_name)(V.J), R.get(self.edge_name)(V.J, V.I))) | metadata, (embed(V.I) <= R.get(self.feature_name)(V.I)) | metadata, (head <= embed(V.I)[self.in_channels, self.in_channels]) | Metadata(transformation=self.activation), R.get(self.output_name) / 1 | Metadata(transformation=self.activation), ]