Source code for neuralogic.nn.module.general.positional_encoding

import math
from typing import Any

import numpy as np

from neuralogic.core.constructs.factories import R
from neuralogic.core.constructs.function import Combination
from neuralogic.nn.module.module import Module


[docs] class PositionalEncoding(Module): """ Implements positional encoding for logic sequences. It generates a set of rules that add fixed or learnable positional embeddings to the input features. """ def __init__( self, embed_dim: int, max_len: int, output_name: str, input_name: str, arity: int = 1, learnable: bool = False, ): """ Parameters ---------- embed_dim : int The dimension of the positional embedding. max_len : int The maximum length of the sequence. output_name : str The name of the output relation (with positional encoding added). input_name : str The name of the input relation to add encoding to. arity : int, optional The arity of the input/output relations. Default: 1. learnable : bool, optional Whether the positional encodings should be learnable parameters. Default: False. """ self.embed_dim = embed_dim self.max_len = max_len self.output_name = output_name self.input_name = input_name self.arity = arity self.learnable = learnable def __call__(self) -> list[Any]: """ Generates the rules for positional encoding. Returns ------- list A list of rules defining the positional encoding nodes and the final combined output. """ terms = [f"X{i}" for i in range(self.arity - 1)] all_terms = [f"X{i}" for i in range(self.arity)] position = np.arange(0, self.max_len).reshape((self.max_len, 1)) div_term = np.exp(np.arange(0, self.embed_dim, 2) * (-math.log(1000.0) / self.embed_dim)) pe = np.zeros((self.max_len, self.embed_dim)) mul = position * div_term pe[:, 0::2] = np.sin(mul) pe[:, 1::2] = np.cos(mul) pe_rel = R.get(f"{self.output_name}__pe") out_rel = R.get(self.output_name) in_rel = R.get(self.input_name) if self.learnable: rules = [pe_rel(*terms, i)[row] for i, row in enumerate(pe)] else: rules = [pe_rel(*terms, i)[row].fixed() for i, row in enumerate(pe)] rules.append((out_rel(all_terms) <= (pe_rel(all_terms), in_rel(all_terms))) | [Combination.SUM]) return rules