Source code for neuralogic.nn.module.gnn.gcn

from typing import Optional

from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation, Combination
from neuralogic.core.constructs.factories import R, V
from neuralogic.nn.module.module import Module


class GCNConv(Module):
    r"""
    Graph Convolutional layer from
    `"Semi-supervised Classification with Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_.

    Calling an instance returns the list of rules that define the layer:
    optional self-loop rules, optional normalization rules, the aggregation
    rule for the output predicate, and the output predicate's metadata
    carrying the activation.

    Parameters
    ----------

    in_channels : int
        Input feature size.
    out_channels : int
        Output feature size.
    output_name : str
        Output (head) predicate name of the module.
    feature_name : str
        Feature predicate name to get features from.
    edge_name : str
        Edge predicate name to use for neighborhood relations.
    activation : Transformation
        Activation function of the output.
        Default: ``Transformation.IDENTITY``
    aggregation : Aggregation
        Aggregation function of nodes' neighbors.
        Default: ``Aggregation.SUM``
    add_self_loops : Optional[bool]
        Add self loops if either set to `True` or `None` (if `normalize` is `True`).
        Default: ``None``
    normalize : bool
        Add normalization.
        Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        output_name: str,
        feature_name: str,
        edge_name: str,
        activation: Transformation = Transformation.IDENTITY,
        aggregation: Aggregation = Aggregation.SUM,
        add_self_loops: Optional[bool] = None,
        normalize: bool = True,
    ) -> None:
        self.output_name = output_name
        self.feature_name = feature_name
        self.edge_name = edge_name

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.aggregation = aggregation
        self.activation = activation

        # An unspecified `add_self_loops` follows `normalize`.
        if add_self_loops is None:
            add_self_loops = normalize

        self.normalize = normalize
        self.add_self_loops = add_self_loops

    def __call__(self) -> list:
        """Build and return the list of rules implementing this layer."""
        head = R.get(self.output_name)(V.I)[self.out_channels, self.in_channels]
        metadata = Metadata(
            transformation=Transformation.IDENTITY, aggregation=self.aggregation, combination=Combination.PRODUCT
        )
        id_metadata = Metadata(transformation=Transformation.IDENTITY)

        edge = R.get(self.edge_name)
        edge_count = R.get(f"{self.output_name}__edge_count")

        self_loops = []
        if self.add_self_loops:
            # Switch to a module-private edge predicate holding every original
            # edge plus a fixed-weight (1.0) loop on each node.
            edge = R.get(f"{self.output_name}__edge")
            self_loops = [
                edge(V.I, V.I)[1.0].fixed(),
                (edge(V.I, V.J) <= (R.get(self.edge_name)(V.I, V.J))) | id_metadata,
                edge / 2 | id_metadata,
            ]

        # The body is assembled only once `edge` is final, so that the head
        # rule aggregates over the self-loop predicate whenever self loops are
        # enabled -- including the `normalize=False` case.
        body = [R.get(self.feature_name)(V.J), edge(V.J, V.I)]

        normalization = []
        if self.normalize:
            count_metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.COUNT)

            # NOTE(review): the two COUNT rules combined by PRODUCT + INVERSE,
            # followed by SQRT in the body, presumably yield the symmetric
            # 1 / sqrt(deg(I) * deg(J)) GCN coefficient -- confirm against the
            # backend semantics.
            body.append(Transformation.SQRT(edge_count(V.J, V.I)))
            normalization = [
                (edge_count(V.I, V.J) <= edge(V.J, V.X)) | count_metadata,
                (edge_count(V.I, V.J) <= edge(V.I, V.X)) | count_metadata,
                edge_count / 2 | Metadata(combination=Combination.PRODUCT, transformation=Transformation.INVERSE),
            ]

        return [
            *self_loops,
            *normalization,
            (head <= body) | metadata,
            R.get(self.output_name) / 1 | Metadata(transformation=self.activation),
        ]