from typing import Any
from neuralogic.core.constructs import relation
from neuralogic.core.constructs.predicate import Predicate
from neuralogic.core.constructs.term import Constant, Variable
[docs]
class SpecialPredicateFactory:
"""
Factory for creating special predicates, such as 'alldiff', 'neq', 'eq', etc.
Special predicates are handled differently by the backend engine.
"""
def __init__(self, hidden: bool = False):
"""
Parameters
----------
hidden : bool
Whether the created predicates should be hidden. Default: False.
"""
self.is_hidden = hidden
@property
def hidden(self) -> "SpecialPredicateFactory":
return SpecialPredicateFactory(True)
[docs]
def alldiff(self, *args: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("alldiff", len(args), self.is_hidden, True), args)
[docs]
def neq(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("neq", 2, self.is_hidden, True), [a, b])
[docs]
def eq(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("eq", 2, self.is_hidden, True), [a, b])
[docs]
def leq(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("leq", 2, self.is_hidden, True), [a, b])
[docs]
def lt(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("lt", 2, self.is_hidden, True), [a, b])
[docs]
def geq(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("geq", 2, self.is_hidden, True), [a, b])
[docs]
def gt(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("gt", 2, self.is_hidden, True), [a, b])
[docs]
def next(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("next", 2, self.is_hidden, True), [a, b])
[docs]
def maxcard(self, *args: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("maxcard", len(args), self.is_hidden, True), args)
def _in(self, *args: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("in", len(args), self.is_hidden, True), args)
[docs]
def anypred(self) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("anypred", 0, self.is_hidden, True), None)
[docs]
def truepred(self) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("truepred", 0, self.is_hidden, True), None)
[docs]
def add(self, a: Any, b: Any, c: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("add", 3, self.is_hidden, True), [a, b, c])
[docs]
def sub(self, a: Any, b: Any, c: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("sub", 3, self.is_hidden, True), [a, b, c])
[docs]
def mod(self, a: Any, b: Any, c: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("mod", 3, self.is_hidden, True), [a, b, c])
[docs]
def add_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("add_eval", 2, self.is_hidden, True), [a, b])
[docs]
def sub_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("sub_eval", 2, self.is_hidden, True), [a, b])
[docs]
def mod_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("mod_eval", 2, self.is_hidden, True), [a, b])
[docs]
def mul_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("mul_eval", 2, self.is_hidden, True), [a, b])
[docs]
def div_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("div_eval", 2, self.is_hidden, True), [a, b])
[docs]
def max_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("max_eval", 2, self.is_hidden, True), [a, b])
[docs]
def min_eval(self, a: Any, b: Any) -> relation.BaseRelation:
return relation.BaseRelation(Predicate("min_eval", 2, self.is_hidden, True), [a, b])
def __call__(self, *args, **kwargs):
raise TypeError(
"Cannot add terms to not fully initialized relation - 'special' and 'hidden' are keywords, "
"that cannot be used as a predicate name with dot notation (use `get` method instead)"
)
def __getitem__(self, item):
raise TypeError(
"Cannot add terms to not fully initialized relation - 'special' and 'hidden' are keywords, "
"that cannot be used as a predicate name with dot notation (use `get` method instead)"
)
def __getattr__(self, item):
return relation.BaseRelation(Predicate(item, 0, self.is_hidden, True))
[docs]
def get(self, name: str) -> relation.BaseRelation:
"""
Creates a special relation with the given name and 0 arity.
Parameters
----------
name : str
The name of the special predicate.
Returns
-------
relation.BaseRelation
The created special relation.
"""
return relation.BaseRelation(Predicate(name, 0, self.is_hidden, True))
[docs]
class HiddenPredicateFactory:
"""
Factory for creating hidden predicates.
Hidden predicates are not part of the output unless explicitly requested.
"""
@property
def special(self) -> SpecialPredicateFactory:
return SpecialPredicateFactory(hidden=True)
def __getattr__(self, item):
return relation.BaseRelation(Predicate(item, 0, True, False))
[docs]
def get(self, name: str) -> relation.BaseRelation:
return relation.BaseRelation(Predicate(name, 0, True, False))
def __call__(self, *args, **kwargs):
raise TypeError(
"Cannot add terms to not fully initialized relation - 'special' and 'hidden' are keywords, "
"that cannot be used as a predicate name with dot notation (use `get` method instead)"
)
def __getitem__(self, item):
raise TypeError(
"Cannot add terms to not fully initialized relation - 'special' and 'hidden' are keywords, "
"that cannot be used as a predicate name with dot notation (use `get` method instead)"
)
[docs]
class AtomFactory:
"""
Factory for creating atoms (relations) in the logic program.
It supports dot notation for predicate names and provides access to special and hidden factories.
"""
def __init__(self):
"""Initializes the AtomFactory."""
self.instances: dict[str, dict[int, relation.BaseRelation]] = {}
self.special = SpecialPredicateFactory()
self.hidden = HiddenPredicateFactory()
[docs]
def get(self, name: str) -> relation.BaseRelation:
"""
Creates a relation with the given name and 0 arity.
Parameters
----------
name : str
The name of the predicate.
Returns
-------
relation.BaseRelation
The created relation.
"""
return relation.BaseRelation(Predicate(name, 0, False, False))
def __getattr__(self, item) -> relation.BaseRelation:
return relation.BaseRelation(Predicate(item, 0, False, False))
[docs]
@staticmethod
def get_predicate(name: str, arity: int, hidden: bool, special: bool) -> Predicate:
return Predicate(name, arity, hidden, special)
[docs]
class VariableFactory:
"""
Factory for creating variables.
Variables are automatically capitalized unless specified otherwise.
"""
def __getattr__(self, item: str) -> Variable:
return self.get(item)
[docs]
def get(self, item: str, var_type: str | None = None) -> Variable:
"""
Creates a variable with the given name and optional type.
Parameters
----------
item : str
The name of the variable.
var_type : str, optional
The type of the variable. Default: None.
Returns
-------
Variable
The created variable.
"""
return Variable(item.capitalize(), var_type)
[docs]
class ConstantFactory:
"""
Factory for creating constants.
Constants are automatically converted to lowercase unless specified otherwise.
"""
def __getattr__(self, item: str) -> Constant:
return self.get(item)
[docs]
def get(self, item: str, const_type: str | None = None) -> Constant:
"""
Creates a constant with the given name and optional type.
Parameters
----------
item : str
The name of the constant.
const_type : str, optional
The type of the constant. Default: None.
Returns
-------
Constant
The created constant.
"""
return Constant(item.lower(), const_type)
Var = VariableFactory()
Relation = AtomFactory()
Const = ConstantFactory()
V = Var
C = Const
R = Relation