from typing import Iterable, Union
import numpy as np
from neuralogic.core.constructs.predicate import Predicate
from neuralogic.core.constructs import rule, factories
from neuralogic.core.constructs.function import Transformation, Combination
class BaseRelation:
    """A predicate applied to a (possibly empty) list of terms.

    A relation may be negated (``~relation``) or wrapped in a transformation /
    combination function (``-relation``, ``relation.T``), but not both.
    Operator overloads build larger constructs from relations:

    * ``relation(*terms)`` returns a new relation with the terms assigned,
    * ``relation[weight]`` returns a :class:`WeightedRelation`,
    * ``relation / arity`` looks up a predicate of the given arity,
    * ``relation <= other`` builds a ``rule.Rule`` and ``a & b`` a ``rule.RuleBody``.
    """

    __slots__ = "predicate", "function", "terms", "negated"

    def __init__(
        self,
        predicate: Predicate,
        terms=None,
        function: Union[Transformation, Combination] = None,
        negated: bool = False,
    ):
        """
        :param predicate: Predicate of the relation.
        :param terms: ``None``, a single term, or an iterable of terms. ``None`` items are
            skipped and ``list`` items are flattened (one level) into the term list. A string
            is treated as a single term, not as an iterable of characters.
        :param function: Optional function wrapped around the relation when rendered.
        :param negated: Whether the relation is negated.
        """
        self.predicate = predicate
        self.function = function
        self.negated = negated
        self.terms = []

        # A str is iterable, but it represents one term rather than a sequence of terms.
        if not isinstance(terms, Iterable) or isinstance(terms, str):
            terms = [terms]

        for term in terms:
            if term is None:
                continue
            if isinstance(term, list):
                self.terms.extend(term)
            else:
                self.terms.append(term)

    def __neg__(self) -> "BaseRelation":
        """Return a copy of this relation wrapped in ``Transformation.REVERSE``."""
        return self.attach_activation_function(Transformation.REVERSE)

    def __invert__(self) -> "BaseRelation":
        """Return a new relation with the negation flag toggled.

        :raises ValueError: If a function is attached to this relation.
        """
        if self.function is not None:
            raise ValueError(f"Cannot negate relation {self} with attached function.")

        # NOTE(review): the third positional argument of ``Predicate`` is always ``True`` here
        # (presumably the ``hidden`` flag), independent of the original predicate - confirm intended.
        predicate = Predicate(self.predicate.name, self.predicate.arity, True, self.predicate.special)
        relation = BaseRelation(predicate, self.terms, self.function, not self.negated)
        return relation

    @property
    def T(self) -> "BaseRelation":
        """Return a copy of this relation wrapped in ``Transformation.TRANSP``."""
        return self.attach_activation_function(Transformation.TRANSP)

    def attach_activation_function(self, function: Union[Transformation, Combination]):
        """Return a shallow copy of this relation with ``function`` attached.

        :param function: Transformation or combination function to attach.
        :raises ValueError: If the relation is negated.
        """
        if self.negated:
            raise ValueError(f"Cannot attach function to negated relation {self}")
        relation = self.__copy__()
        relation.function = function
        return relation

    def __truediv__(self, other):
        """Return the predicate with this relation's name and arity ``other``.

        Only supported when this relation's predicate has arity 0 and ``other`` is a
        non-negative ``int``; the hidden/special flags are carried over.

        :raises NotImplementedError: For any other combination of relation and operand.
        """
        if not isinstance(other, int) or self.predicate.arity != 0 or other < 0:
            raise NotImplementedError
        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        return factories.AtomFactory.Predicate.get_predicate(name, other, hidden, special)

    def __call__(self, *args) -> "BaseRelation":
        """Return a new relation with the given terms assigned.

        Terms may be passed as separate arguments or as a single non-string iterable.
        The predicate arity equals the number of terms the new relation actually stores,
        i.e. after ``None`` items are dropped and list items are flattened.

        :raises ValueError: If this relation already has terms assigned.
        """
        if self.terms:
            raise ValueError(f"Terms have already been assigned to relation {self}.")

        if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], str):
            terms = list(args[0])
        else:
            terms = list(args)

        relation = BaseRelation(self.predicate, terms, self.function, self.negated)

        # Derive the arity from the normalized terms: counting the raw arguments miscounts
        # when an argument is None (skipped) or a list (flattened) by __init__.
        name, hidden, special = self.predicate.name, self.predicate.hidden, self.predicate.special
        relation.predicate = factories.AtomFactory.Predicate.get_predicate(
            name, len(relation.terms), hidden, special
        )
        return relation

    def __getitem__(self, item) -> "WeightedRelation":
        """Return a weighted version of this relation; ``item`` is the weight specification.

        :raises ValueError: If the predicate is hidden or special, or if the relation is
            negated (a ``WeightedRelation`` is always created non-negated, so the negation
            would otherwise be silently lost).
        """
        if self.predicate.hidden or self.predicate.special:
            raise ValueError(f"Special/Hidden relation {self} cannot have learnable parameters.")
        if self.negated:
            raise ValueError(f"Negated relation {self} cannot have learnable parameters.")
        return WeightedRelation(item, self.predicate, False, self.terms, self.function)

    def __le__(self, other: Union[Iterable["BaseRelation"], "BaseRelation"]) -> rule.Rule:
        """Build ``rule.Rule(self, other)`` from ``self <= other``."""
        return rule.Rule(self, other)

    def to_str(self, end=False) -> str:
        """Render the relation as text.

        Negated relations are prefixed with ``!``; an attached function wraps the literal
        via ``function.wrap``.

        :param end: If true, append a trailing ``.``.
        """
        end = "." if end else ""
        literal = self.predicate.to_str()
        if self.terms:
            literal = f"{literal}({', '.join(str(term) for term in self.terms)})"

        if self.negated:
            return f"!{literal}{end}"
        if self.function:
            return f"{self.function.wrap(literal)}{end}"
        return f"{literal}{end}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __copy__(self):
        """Return a shallow copy (the terms list is shared with the original)."""
        relation = BaseRelation.__new__(BaseRelation)
        relation.function = self.function
        relation.terms = self.terms
        relation.predicate = self.predicate
        relation.negated = self.negated
        return relation

    def __and__(self, other) -> rule.RuleBody:
        """Combine two relations into a ``rule.RuleBody``.

        :raises NotImplementedError: If ``other`` is not a relation.
        """
        if isinstance(other, BaseRelation):
            return rule.RuleBody(self, other)
        raise NotImplementedError
class WeightedRelation(BaseRelation):
    """A relation with an attached weight.

    The weight can be a plain value, a named value (``relation["name": value]``,
    which arrives as a ``slice``), or a tuple whose first item is such a slice.
    A fixed weight is rendered inside angle brackets. Terms, weights, functions,
    transposition and negation cannot be applied to a weighted relation.
    """

    __slots__ = "weight", "weight_name", "is_fixed"

    def __init__(
        self, weight, predicate: Predicate, fixed=False, terms=None, function: Union[Transformation, Combination] = None
    ):
        """
        :param weight: Weight value or specification. A ``slice`` ``name:value`` names the
            weight; a tuple whose first item is such a slice names it and keeps the remaining
            items; a ``numpy.ndarray`` value is converted to a (nested) Python list.
        :param predicate: Predicate of the relation.
        :param fixed: Whether the weight is fixed.
        :param terms: Terms of the relation, normalized by ``BaseRelation.__init__``.
        :param function: Optional function wrapped around the relation.
        """
        super().__init__(predicate, terms, function, False)
        self.weight = weight
        self.weight_name = None
        self.is_fixed = fixed

        if isinstance(weight, slice):
            self.weight_name = str(weight.start)
            self.weight = weight.stop
        elif isinstance(weight, tuple) and isinstance(weight[0], slice):
            # ``relation["name": a, b]`` arrives as ``(slice("name", a, None), b)``.
            self.weight_name = str(weight[0].start)
            self.weight = (weight[0].stop, *weight[1:])

        # Inspect the unwrapped value, not the raw argument, so that a named ndarray
        # weight such as ``relation["w": np.array(...)]`` is converted too.
        if isinstance(self.weight, np.ndarray):
            self.weight = self.weight.tolist()

    def fixed(self) -> "WeightedRelation":
        """Return a fixed copy of this weighted relation, keeping the weight name.

        :raises ValueError: If the relation is already fixed.
        """
        if self.is_fixed:
            raise ValueError(f"Weighted relation {self} is already fixed")
        relation = WeightedRelation(self.weight, self.predicate, True, self.terms, self.function)
        # The constructor only derives a name from slice syntax, so carry the existing name over.
        relation.weight_name = self.weight_name
        return relation

    def to_str(self, end=False):
        """Render as ``<weight> <relation>``.

        Tuple weights render as ``{a, b, ...}``, a named weight as ``$name=weight``, and a
        fixed weight is wrapped in ``<...>``.

        :param end: If true, append a trailing ``.``.
        """
        if isinstance(self.weight, tuple):
            weight = f"{{{', '.join(str(w) for w in self.weight)}}}"
        else:
            weight = str(self.weight)

        if self.weight_name:
            weight = f"${self.weight_name}={weight}"

        if self.is_fixed:
            return f"<{weight}> {super().to_str(end)}"
        return f"{weight} {super().to_str(end)}"

    def __str__(self) -> str:
        return self.to_str(True)

    def __repr__(self) -> str:
        return self.__str__()

    def __call__(self, *args) -> BaseRelation:
        raise NotImplementedError(f"Cannot assign terms to weighted relation {self.predicate}")

    def __getitem__(self, item) -> "WeightedRelation":
        raise NotImplementedError(f"Cannot assign weight to weighted relation {self.predicate}")

    def attach_activation_function(self, function: Union[Transformation, Combination]):
        """Always raises: functions must be attached before the weight is added."""
        raise NotImplementedError(
            f"Cannot attach a function to weighted relation {self}. Attach the function before adding weights."
        )

    @property
    def T(self) -> "WeightedRelation":
        """Always raises: transposition must be applied before the weight is added."""
        raise NotImplementedError(
            f"Cannot transpose weighted relation {self} Apply the transposition before adding weights."
        )

    def __invert__(self) -> "WeightedRelation":
        raise NotImplementedError(f"Weighted relations ({self}) cannot be negated.")

    def __neg__(self) -> "WeightedRelation":
        raise NotImplementedError(
            f"Cannot negate weighted relation {self} Apply the reverse function before adding weights."
        )

    def __copy__(self):
        """Return a shallow copy (the terms list is shared with the original)."""
        relation = WeightedRelation.__new__(WeightedRelation)
        relation.predicate = self.predicate
        relation.function = self.function
        relation.terms = self.terms
        relation.weight = self.weight
        # Every slot must be set: with __slots__, an unset weight_name makes to_str raise AttributeError.
        relation.weight_name = self.weight_name
        relation.is_fixed = self.is_fixed
        relation.negated = self.negated
        return relation