# Source code for neuralogic.nn.module.gnn.gine

from typing import Optional

from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
from neuralogic.nn.module.module import Module


class GINEConv(Module):
    r"""
    GINEConv layer from
    `"Strategies for Pre-training Graph Neural Networks" <https://arxiv.org/abs/1905.12265>`_.

    Calling an instance returns a list of rules and predicate metadata that
    together express the layer:

    1. (only when ``edge_dim`` is set) the edge predicate is projected through
       a weight of shape ``[in_channels, edge_dim]``;
    2. for every pair ``(I, J)`` the features of ``J`` and the edge value for
       ``(J, I)`` are combined by sum and passed through ReLU;
    3. these messages are aggregated by sum over ``J``;
    4. the aggregate is summed with the features of ``I`` weighted by
       ``eps + 1``;
    5. the result is exposed, unchanged, under the ``nn_name`` predicate.

    Parameters
    ----------

    in_channels : int
        Input feature size.
    feature_name : str
        Feature predicate name to get features from.
    edge_name : str
        Edge predicate name to use for neighborhood relations.
    nn_name : str
        Neural network predicate name. The layer output is emitted under this
        predicate, and the name also prefixes the helper predicates
        (``<nn_name>__gine_*``) created by the layer.
    eps : float
        :math:`\epsilon`-value. Default: ``0.0``
    train_eps : bool
        Whether ``eps`` is a trainable parameter. When ``False`` the
        ``eps + 1`` weight is marked as fixed. Default: ``False``
    edge_dim : Optional[int]
        Dimension of edge features (``None`` if projection to ``in_channels``
        is not needed). Default: ``None``
    """

    def __init__(
        self,
        in_channels: int,
        feature_name: str,
        edge_name: str,
        nn_name: str,
        eps: float = 0.0,
        train_eps: bool = False,
        edge_dim: Optional[int] = None,
    ):
        # Sizes.
        self.in_channels = in_channels
        self.edge_dim: Optional[int] = edge_dim

        # Predicate names the layer reads from / writes to.
        self.feature_name = feature_name
        self.edge_name = edge_name
        self.nn_name = nn_name

        # Self-loop weighting.
        self.eps = eps
        self.train_eps = train_eps

    def __call__(self):
        """Build and return the list of rules and predicate metadata for this layer."""
        # Helper predicates private to this layer, namespaced by ``nn_name``.
        message = R.get(f"{self.nn_name}__gine_feat_sum")
        aggregated = R.get(f"{self.nn_name}__gine_feat_agg")
        combined = R.get(f"{self.nn_name}__gine_out")

        node_features = R.get(self.feature_name)
        edge_features = R.get(self.edge_name)

        # Central node term, weighted by (1 + eps); the weight is only left
        # unfixed when eps is meant to be trained.
        self_term = node_features(V.I)[self.eps + 1]
        if not self.train_eps:
            self_term = self_term.fixed()

        rules = []

        if self.edge_dim is not None:
            # Edge values need projecting: route them through a dedicated
            # predicate carrying an [in_channels, edge_dim] weight and use that
            # predicate in place of the raw edges from here on.
            edge_features = R.get(f"{self.nn_name}__gine_edge_proj")
            projection = edge_features(V.I, V.J)[self.in_channels, self.edge_dim] <= R.get(self.edge_name)(V.I, V.J)
            rules.append(projection | Metadata(transformation=Transformation.IDENTITY))
            rules.append(edge_features / 2 | Metadata(transformation=Transformation.IDENTITY))

        # Per-pair message: sum of the neighbor's features and the edge value, then ReLU.
        per_pair = message(V.I, V.J) <= (node_features(V.J), edge_features(V.J, V.I))
        rules.append(per_pair | Metadata(transformation=Transformation.RELU, combination=Combination.SUM))
        rules.append(message / 2 | Metadata(transformation=Transformation.IDENTITY))

        # Sum-aggregate the messages over J for each I.
        gather = aggregated(V.I) <= message(V.I, V.J)
        rules.append(gather | Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.SUM))
        rules.append(aggregated / 1 | Metadata(transformation=Transformation.IDENTITY))

        # Add the weighted central node term to the aggregate.
        update = combined(V.I) <= (self_term, aggregated(V.I))
        rules.append(update | Metadata(transformation=Transformation.IDENTITY, combination=Combination.SUM))
        rules.append(combined / 1 | Metadata(transformation=Transformation.IDENTITY))

        # Publish the result under the caller-facing predicate name.
        publish = R.get(self.nn_name)(V.I) <= combined(V.I)
        rules.append(publish | Metadata(transformation=Transformation.IDENTITY))
        rules.append(R.get(self.nn_name) / 1 | Metadata(transformation=Transformation.IDENTITY))

        return rules