# Source code for neuralogic.core.constructs.relation

from typing import Any, Iterable

import numpy as np

from neuralogic.core.constructs import factories, rule
from neuralogic.core.constructs.function import FContainer
from neuralogic.core.constructs.function.enum import Combination, Transformation
from neuralogic.core.constructs.function.function import CombinationFunction, TransformationFunction
from neuralogic.core.constructs.predicate import Predicate


class BaseRelation:
    """
    Represents a relation with a predicate, terms, and an optional activation function.
    """

    __slots__ = "predicate", "function", "terms", "negated"

    def __init__(
        self,
        predicate: Predicate,
        terms: Any = None,
        function: TransformationFunction | CombinationFunction | None = None,
        negated: bool = False,
    ):
        """
        Parameters
        ----------
        predicate : Predicate
            The predicate of the relation.
        terms : Any, optional
            The terms of the relation. A single non-iterable value (or a string) is treated as one term,
            ``None`` terms are skipped and list terms are flattened by one level. Default: None.
        function : Union[TransformationFunction, CombinationFunction], optional
            The activation/combination function. Default: None.
        negated : bool
            Whether the relation is negated. Default: False.
        """
        self.predicate = predicate
        self.function = function
        self.negated = negated
        self.terms = []

        # Strings are iterable but must be kept as a single term.
        if not isinstance(terms, Iterable) or isinstance(terms, str):
            terms = [terms]

        for term in terms:
            if term is None:
                continue
            if isinstance(term, list):
                self.terms.extend(term)
            else:
                self.terms.append(term)

    def __neg__(self) -> "BaseRelation":
        """Returns a copy of the relation with the ``Transformation.REVERSE`` function attached."""
        return self.attach_activation_function(Transformation.REVERSE)

    def __invert__(self) -> "BaseRelation":
        """Returns a copy of the relation with the negation flag flipped.

        Raises
        ------
        ValueError
            If a function is attached to the relation.
        """
        if self.function is not None:
            raise ValueError(f"Cannot negate relation {self} with attached function.")

        # Negation only flips the ``negated`` flag; the predicate (including its hidden/special flags)
        # is kept unchanged.
        return BaseRelation(self.predicate, self.terms, self.function, not self.negated)

    @property
    def T(self) -> "BaseRelation":
        """Returns a copy of the relation with the ``Transformation.TRANSP`` function attached."""
        return self.attach_activation_function(Transformation.TRANSP)

    def attach_activation_function(self, function: TransformationFunction | CombinationFunction) -> "BaseRelation":
        """Attaches an activation or combination function to the relation.

        Parameters
        ----------
        function : Union[TransformationFunction, CombinationFunction]
            The function to attach.

        Returns
        -------
        BaseRelation
            A new relation with the attached function.

        Raises
        ------
        ValueError
            If the relation is negated.
        """
        if self.negated:
            raise ValueError(f"Cannot attach function to negated relation {self}")

        relation = self.__copy__()
        relation.function = function
        return relation

    def __truediv__(self, other: int) -> Predicate:
        """Returns the predicate of this (term-less) relation with arity ``other``, e.g. ``r / 2``.

        Raises
        ------
        TypeError
            If ``other`` is not a non-negative int or the relation already has a non-zero arity.
        """
        if not isinstance(other, int) or self.predicate.arity != 0 or other < 0:
            raise TypeError(f"Invalid arity operand for {self}: {other!r}")

        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        return factories.AtomFactory.get_predicate(name, other, hidden, special)

    def __call__(self, *args: Any) -> "BaseRelation":
        """Returns a new relation with the given terms assigned.

        Terms can be passed either as separate arguments or as a single non-string iterable.

        Raises
        ------
        ValueError
            If the relation already has terms.
        """
        if self.terms:
            raise ValueError("Cannot assign terms twice to a relation")

        if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], str):
            terms = list(args[0])
        else:
            terms = list(args)

        # Let ``__init__`` normalize the terms first (drop ``None``, flatten nested lists) and derive the
        # arity from the normalized terms, so the predicate arity always matches ``len(relation.terms)``.
        relation = BaseRelation(self.predicate, terms, self.function, self.negated)

        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        relation.predicate = factories.AtomFactory.get_predicate(name, len(relation.terms), hidden, special)
        return relation

    def __getitem__(self, item) -> "WeightedRelation":
        """Returns a weighted version of this relation with ``item`` as the weight.

        Raises
        ------
        ValueError
            If the predicate is hidden or special, or if the relation is negated (weighted relations
            cannot carry negation, so it would otherwise be dropped silently).
        """
        if self.predicate.hidden or self.predicate.special:
            raise ValueError(f"Special/Hidden relation {self} cannot have learnable parameters.")
        if self.negated:
            raise ValueError(f"Negated relation {self} cannot have learnable parameters.")

        return WeightedRelation(item, self.predicate, False, self.terms, self.function)

    def __le__(self, other: Iterable["BaseRelation"] | "BaseRelation") -> rule.Rule:
        """Builds a ``rule.Rule`` from this relation and ``other`` (written as ``head <= body``)."""
        return rule.Rule(self, other)

    def to_str(self, end=False) -> str:
        """Returns a string representation of the relation.

        Parameters
        ----------
        end : bool
            Whether to append a dot at the end. Default: False.

        Returns
        -------
        str
            The string representation.
        """
        end = "." if end else ""

        if self.terms:
            terms = ", ".join([str(term) for term in self.terms])

            if self.negated:
                return f"!{self.predicate.to_str()}({terms}){end}"

            if self.function:
                literal = f"{self.predicate.to_str()}({terms})"
                return f"{self.function.wrap(literal)}{end}"

            return f"{self.predicate.to_str()}({terms}){end}"

        if self.negated:
            return f"!{self.predicate.to_str()}{end}"

        if self.function:
            return f"{self.function.wrap(self.predicate.to_str())}{end}"

        return f"{self.predicate.to_str()}{end}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __copy__(self):
        """Returns a shallow copy (the terms list is shared with the original)."""
        relation = BaseRelation.__new__(BaseRelation)
        relation.function = self.function
        relation.terms = self.terms
        relation.predicate = self.predicate
        relation.negated = self.negated

        return relation

    def __and__(self, other: "BaseRelation") -> rule.RuleBody:
        """Combines two relations into a ``rule.RuleBody``.

        Raises
        ------
        TypeError
            If ``other`` is not a ``BaseRelation``.
        """
        if isinstance(other, BaseRelation):
            return rule.RuleBody(self, other)
        raise TypeError(f"Cannot combine {type(self).__name__} with {type(other).__name__}")

    def __add__(self, other: "BaseRelation") -> FContainer:
        """Wraps both relations in an ``FContainer`` with ``Combination.SUM``."""
        return FContainer((self, other), Combination.SUM)

    def __mul__(self, other: "BaseRelation") -> FContainer:
        """Wraps both relations in an ``FContainer`` with ``Combination.ELPRODUCT``."""
        return FContainer((self, other), Combination.ELPRODUCT)

    def __matmul__(self, other: "BaseRelation") -> FContainer:
        """Wraps both relations in an ``FContainer`` with ``Combination.PRODUCT``."""
        return FContainer((self, other), Combination.PRODUCT)
class WeightedRelation(BaseRelation):
    """
    Represents a relation with an associated weight (learnable or fixed).
    """

    __slots__ = "weight", "weight_name", "is_fixed"

    def __init__(
        self,
        weight: Any,
        predicate: Predicate,
        fixed: bool = False,
        terms: Any = None,
        function: TransformationFunction | CombinationFunction | None = None,
    ):
        """
        Parameters
        ----------
        weight : Any
            The weight of the relation. Can be a value, a tuple, or a slice (for named weights,
            ``name: value``). NumPy arrays are converted to (nested) lists.
        predicate : Predicate
            The predicate of the relation.
        fixed : bool
            Whether the weight is fixed. Default: False.
        terms : Any, optional
            The terms of the relation. Default: None.
        function : Union[TransformationFunction, CombinationFunction], optional
            The activation/combination function. Default: None.
        """
        super().__init__(predicate, terms, function, False)

        self.weight = weight
        self.weight_name = None
        self.is_fixed = fixed

        if isinstance(weight, slice):
            self.weight_name = str(weight.start)
            self.weight = weight.stop
        elif isinstance(weight, tuple) and isinstance(weight[0], slice):
            self.weight_name = str(weight[0].start)
            self.weight = (weight[0].stop, *weight[1:])

        # Check the *resolved* weight, so an array given as the value of a named weight
        # (``r["w": np.array(...)]``) is converted as well.
        if isinstance(self.weight, np.ndarray):
            self.weight = self.weight.tolist()

    def fixed(self) -> "WeightedRelation":
        """Returns a copy of the relation with the weight fixed.

        The weight value and the weight name (if any) are preserved.

        Returns
        -------
        WeightedRelation
            The weighted relation with a fixed weight.

        Raises
        ------
        ValueError
            If the weight is already fixed.
        """
        if self.is_fixed:
            raise ValueError(f"Weighted relation {self} is already fixed")

        relation = self.__copy__()
        relation.is_fixed = True
        return relation

    def to_str(self, end=False):
        """Returns a string representation of the weighted relation.

        Tuple weights are rendered as ``{a, b}``, named weights as ``$name=value`` and fixed weights are
        wrapped in ``<...>``.
        """
        if isinstance(self.weight, tuple):
            weight = f"{{{', '.join(str(w) for w in self.weight)}}}"
        else:
            weight = str(self.weight)

        if self.weight_name:
            weight = f"${self.weight_name}={weight}"

        if self.is_fixed:
            return f"<{weight}> {super().to_str(end)}"
        return f"{weight} {super().to_str(end)}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __call__(self, *args: Any) -> BaseRelation:
        raise TypeError(f"Cannot assign terms to weighted relation {self.predicate}")

    def __getitem__(self, item) -> "WeightedRelation":
        raise TypeError(f"Cannot assign weight to weighted relation {self.predicate}")

    def attach_activation_function(self, function: TransformationFunction | CombinationFunction) -> "WeightedRelation":
        """Always raises; functions have to be attached before the weight is added."""
        raise TypeError(
            f"Cannot attach a function to weighted relation {self}. Attach the function before adding weights."
        )

    @property
    def T(self) -> "WeightedRelation":
        raise TypeError(f"Cannot transpose weighted relation {self}. Apply the transposition before adding weights.")

    def __invert__(self) -> "WeightedRelation":
        raise TypeError(f"Weighted relations ({self}) cannot be negated.")

    def __neg__(self) -> "WeightedRelation":
        raise TypeError(f"Cannot negate weighted relation {self}. Apply the reverse function before adding weights.")

    def __copy__(self):
        """Returns a shallow copy (the terms list is shared with the original)."""
        relation = WeightedRelation.__new__(WeightedRelation)
        relation.predicate = self.predicate
        relation.function = self.function
        relation.terms = self.terms
        relation.negated = self.negated
        relation.weight = self.weight
        # Every slot must be populated; ``to_str`` reads ``weight_name`` unconditionally.
        relation.weight_name = self.weight_name
        relation.is_fixed = self.is_fixed

        return relation