# Source code for neuralogic.nn.module.general.rvnn

from neuralogic.core.constructs.metadata import Metadata
from neuralogic.core.constructs.function import Transformation, Aggregation
from neuralogic.core.constructs.factories import R, V
from neuralogic.nn.module.module import Module


class RvNN(Module):
    r"""
    Recursive Neural Network (RvNN) module which is computed as:

    .. math::
        \mathbf{h}_i = act(agg_{j \in \mathcal{Ch(i)}}(\mathbf{W_{id(j)}} \mathbf{h}_j))

    Where :math:`act` is an activation function, :math:`agg` is an aggregation function and the
    :math:`\mathbf{W}`'s are learnable parameters. :math:`\mathcal{Ch(i)}` represents the ordered
    list of children of node :math:`i`. The :math:`id(j)` function maps node :math:`j` to its index
    (position) in its parent's children list.

    Besides the recursive rule, a leaf rule derives the output of node ``P`` from the input
    predicate whenever the unary ``parent_map(P)`` holds.

    Parameters
    ----------

    input_size : int
        Input feature size.
    output_name : str
        Output (head) predicate name of the module.
    input_name : str
        Input feature predicate name to get leaf features from.
    parent_map_name : str
        Name of the predicate mapping a parent node to its children. It is used as
        ``parent_map(P, C1, ..., C<max_children>)`` in the recursive rule and as
        ``parent_map(P)`` in the leaf rule.
    max_children : int
        Maximum number of children (specifies which <max_children>-ary tree will be considered).
        Default: ``2``
    activation : Transformation
        Activation function of all layers. Default: ``Transformation.TANH``
    aggregation : Aggregation
        Aggregation function of a layer. Default: ``Aggregation.SUM``
    arity : int
        Arity of the input and output predicate (doesn't include the node id term).
        Default: ``1``
    """

    def __init__(
        self,
        input_size: int,
        output_name: str,
        input_name: str,
        parent_map_name: str,
        max_children: int = 2,
        activation: Transformation = Transformation.TANH,
        aggregation: Aggregation = Aggregation.SUM,
        arity: int = 1,
    ):
        # Only stores the configuration; the rules are generated lazily in __call__.
        self.input_size = input_size
        self.output_name = output_name
        self.input_name = input_name
        self.parent_map_name = parent_map_name
        self.max_children = max_children
        self.activation = activation
        self.aggregation = aggregation
        self.arity = arity

    def __call__(self):
        """Build the module's template.

        Returns a list of three entries, in order:

        1. The leaf rule ``out(X0.., P) <= input(X0.., P), parent_map(P)``.
        2. The output predicate setting, which makes its transformation identity.
        3. The recursive rule ``out(X0.., P) <= out(X0.., C1)[W_1], ..., parent_map(P, C1, ...)``.
        """
        # Non-node-id terms shared by the input and output atoms ("X0", "X1", ...).
        prefix = [f"X{k}" for k in range(self.arity)]
        # The node id variable P always comes last.
        node_terms = [*prefix, V.P]

        inp = R.get(self.input_name)
        out = R.get(self.output_name)
        parent_map = R.get(self.parent_map_name)

        # One Metadata instance is shared by the leaf rule and the recursive rule.
        shared_meta = Metadata(transformation=self.activation, aggregation=self.aggregation)

        leaf_rule = (out(node_terms) <= (inp(node_terms), parent_map(V.P))) | shared_meta
        # Identity on the predicate itself; the activation is carried by the rules' metadata.
        output_setting = out / len(node_terms) | [Transformation.IDENTITY]

        child_vars = [f"C{k}" for k in range(1, self.max_children + 1)]
        # Each child position k gets its own named (input_size x input_size) weight,
        # which realises W_{id(j)} from the formula above.
        child_atoms = [
            out([*prefix, child])[f"{self.output_name}__rvnn_{k}" : self.input_size, self.input_size]
            for k, child in enumerate(child_vars, start=1)
        ]
        inner_rule = (out(node_terms) <= (*child_atoms, parent_map(V.P, *child_vars))) | shared_meta

        return [leaf_rule, output_setting, inner_rule]