# Source code for neuralogic.nn.module.gnn.rgcn

from typing import List, Optional

from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
from neuralogic.nn.module.module import Module


class RGCNConv(Module):
    r"""
    Relational Graph Convolutional layer from
    `Modeling Relational Data with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`_.
    It can be expressed as:

    .. math::
        \mathbf{x}^{\prime}_i = act(\mathbf{W_0} \cdot \mathbf{x}_i +
        \sum_{r \in \mathcal{R}} {agg}_{j \in \mathcal{N}_r(i)}(\mathbf{W_r} \cdot \mathbf{x}_j))

    Where *act* is an activation function, *agg* an aggregation function (by default average),
    :math:`W_0` is a learnable root parameter and :math:`W_r` is a learnable parameter for each relation.

    The first part of the equation, ":math:`\mathbf{W_0} \cdot \mathbf{x}_i`", can be expressed in the
    logic form as:

    .. code-block:: logtalk

        R.<output_name>(V.I) <= R.<feature_name>(V.I)[<W0>]

    The second part of the equation,
    ":math:`{agg}_{j \in \mathcal{N}_r(i)}(\mathbf{W_r} \cdot \mathbf{x}_j)`", can be expressed as:

    .. code-block:: logtalk

        R.<output_name>(V.I) <= (R.<feature_name>(V.J)[<Wr>], R.<edge_name>(V.J, relation, V.I))

    where "relation" is a constant name, or as:

    .. code-block:: logtalk

        R.<output_name>(V.I) <= (R.<feature_name>(V.J)[<Wr>], R.<relation>(V.J, V.I))

    The outer summation, together with summing it with the first part, is handled by aggregation of all
    rules with the same head (and substitution).

    Examples
    --------

    The whole computation of this module (parametrized as
    :code:`RGCNConv(1, 2, "h1", "h0", "_edge", ["sibling", "parent"])`) is as follows:

    .. code:: logtalk

        metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.AVG)

        (R.h1(V.I) <= R.h0(V.I)[2, 1]) | metadata
        (R.h1(V.I) <= (R.h0(V.J)[2, 1], R._edge(V.J, sibling, V.I))) | metadata
        (R.h1(V.I) <= (R.h0(V.J)[2, 1], R._edge(V.J, parent, V.I))) | metadata
        R.h1 / 1 [Transformation.IDENTITY]

    Module parametrized as :code:`RGCNConv(1, 2, "h1", "h0", None, ["sibling", "parent"])` translates into:

    .. code:: logtalk

        metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=Aggregation.AVG)

        (R.h1(V.I) <= R.h0(V.I)[2, 1]) | metadata
        (R.h1(V.I) <= (R.h0(V.J)[2, 1], R.sibling(V.J, V.I))) | metadata
        (R.h1(V.I) <= (R.h0(V.J)[2, 1], R.parent(V.J, V.I))) | metadata
        R.h1 / 1 [Transformation.IDENTITY]

    Parameters
    ----------

    in_channels : int
        Input feature size.
    out_channels : int
        Output feature size.
    output_name : str
        Output (head) predicate name of the module.
    feature_name : str
        Feature predicate name to get features from.
    edge_name : Optional[str]
        Edge predicate name to use for neighborhood relations. When :code:`None`, elements from
        :code:`relations` are used instead as (binary) edge predicate names.
    relations : List[str]
        Names of the relations.
    activation : Transformation
        Activation function of the output.
        Default: ``Transformation.IDENTITY``
    aggregation : Aggregation
        Aggregation function of nodes' neighbors.
        Default: ``Aggregation.AVG``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        output_name: str,
        feature_name: str,
        edge_name: Optional[str],
        relations: List[str],
        activation: Transformation = Transformation.IDENTITY,
        aggregation: Aggregation = Aggregation.AVG,
    ):
        # Predicate names used to build the rules.
        self.output_name = output_name
        self.feature_name = feature_name
        self.edge_name = edge_name
        self.relations = relations

        # Weight dimensions are declared as [out_channels, in_channels].
        self.in_channels = in_channels
        self.out_channels = out_channels

        # `aggregation` goes on every rule; `activation` goes on the output predicate only.
        self.aggregation = aggregation
        self.activation = activation

    def _weighted_feature(self, variable):
        """Build a new ``feature_name(variable)`` body atom with an ``[out_channels, in_channels]`` weight."""
        return R.get(self.feature_name)(variable)[self.out_channels, self.in_channels]

    def _neighbour_edge(self, relation: str):
        """Build the body atom connecting neighbour ``V.J`` to node ``V.I`` through ``relation``."""
        if self.edge_name is not None:
            # Shared edge predicate; the relation is passed as a constant middle argument.
            return R.get(self.edge_name)(V.J, relation, V.I)
        # No shared edge predicate: the relation name itself is the binary edge predicate.
        return R.get(relation)(V.J, V.I)

    def __call__(self):
        """Return the list of rules/metadata declarations that make up this layer.

        The list holds:

        * the root rule (``W_0`` term);
        * one rule per relation (``W_r`` terms);
        * the output predicate's activation metadata.
        """
        head = R.get(self.output_name)(V.I)
        metadata = Metadata(transformation=Transformation.IDENTITY, aggregation=self.aggregation)

        root_rule = (head <= self._weighted_feature(V.I)) | metadata

        # A fresh weighted atom is built per relation, so each relation's W_r is its own weight
        # declaration rather than one atom instance shared across several rules.
        relation_rules = [
            (head <= (self._weighted_feature(V.J), self._neighbour_edge(relation))) | metadata
            for relation in self.relations
        ]

        return [
            root_rule,
            *relation_rules,
            R.get(self.output_name) / 1 | Metadata(transformation=self.activation),
        ]