Source code for neuralogic.core.constructs.rule

from typing import Iterable, Optional

from neuralogic.core.constructs.metadata import Metadata


class RuleBody:
    """Conjunction of body literals built up before a rule head is attached.

    ``literals`` holds the collected body atoms in the order they were added;
    ``metadata`` holds the metadata attached through ``|`` (``None`` until set).
    """

    __slots__ = "literals", "metadata"

    def __init__(self, lit1, lit2):
        """Start a body from its first two literals."""
        self.metadata = None
        self.literals = [lit1, lit2]

    def __and__(self, other):
        """Append another relation to this body in place and return ``self``.

        Raises:
            NotImplementedError: if ``other`` is not a ``BaseRelation``.
        """
        # Function-scope import, kept as in the original
        # (presumably to avoid a circular import -- TODO confirm).
        from neuralogic.core.constructs.relation import BaseRelation

        if not isinstance(other, BaseRelation):
            raise NotImplementedError

        self.literals.append(other)
        return self

    def __str__(self) -> str:
        """Render the literals as a comma-separated list."""
        rendered = [literal.to_str() for literal in self.literals]
        return ", ".join(rendered)

    def __repr__(self) -> str:
        return self.__str__()

    def __or__(self, other) -> "RuleBody":
        """Attach metadata to this body in place and return ``self``.

        An iterable is converted with ``Metadata.from_iterable``; a ``Metadata``
        instance is stored as given.

        Raises:
            NotImplementedError: if ``other`` is neither iterable nor ``Metadata``.
        """
        if isinstance(other, Iterable):
            attached = Metadata.from_iterable(other)
        elif isinstance(other, Metadata):
            attached = other
        else:
            raise NotImplementedError

        self.metadata = attached
        return self
class Rule:
    """A rule made of a head relation, a list of body atoms and optional metadata.

    ``head`` is the head relation, ``body`` is a list of body atoms and
    ``metadata`` is the attached ``Metadata`` (``None`` when nothing is attached).
    """

    __slots__ = "head", "body", "metadata"

    def __init__(self, head, body):
        """Build a rule from ``head`` and ``body``.

        Args:
            head: The head relation; it must not have a function attached.
            body: A ``RuleBody``, an iterable of atoms, or a single atom.

        Raises:
            NotImplementedError: if ``head.function`` is not ``None``, or if a
                special ``alldiff`` body atom contains more than one ``...``.
        """
        self.head = head
        self.metadata: Optional["Metadata"] = None

        if head.function is not None:
            raise NotImplementedError(f"Rule head {head} cannot have a function attached")

        if isinstance(body, RuleBody):
            # Metadata attached to the body is carried over to the rule.
            if body.metadata is not None:
                self._set_metadata(body.metadata)
            body = body.literals

        if not isinstance(body, Iterable):
            body = [body]
        self.body = list(body)

        if self.is_ellipsis_templated():
            self._expand_ellipsis()

    @staticmethod
    def _is_alldiff(atom) -> bool:
        """Return True when ``atom`` uses the special ``alldiff`` predicate."""
        return bool(atom.predicate.special and atom.predicate.name == "alldiff")

    @staticmethod
    def _is_variable(term) -> bool:
        """Return True when ``term`` is not ``...`` and its text starts with an uppercase letter."""
        # Slicing instead of indexing so an empty string is "not a variable"
        # rather than an IndexError.
        return term is not Ellipsis and str(term)[:1].isupper()

    def _expand_ellipsis(self) -> None:
        """Replace ``...`` in special ``alldiff`` atoms with the rule's variables."""
        # Function-scope import, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from neuralogic.core import Relation

        # A dict is used as an insertion-ordered set so the expanded terms come
        # out in the same order on every run (set order is not stable across runs).
        variables = dict.fromkeys(term for term in self.head.terms if self._is_variable(term))
        for atom in self.body:
            if self._is_alldiff(atom):
                continue
            variables.update(dict.fromkeys(term for term in atom.terms if self._is_variable(term)))

        for position, atom in enumerate(self.body):
            if not self._is_alldiff(atom):
                continue

            expanded = []
            replaced = False
            for term in atom.terms:
                if term is not Ellipsis:
                    expanded.append(term)
                    continue
                # Only a single ``...`` per alldiff atom is supported.
                if replaced:
                    raise NotImplementedError
                replaced = True
                expanded.extend(variables)

            if replaced:
                self.body[position] = Relation.special.alldiff(*expanded)

    def is_ellipsis_templated(self) -> bool:
        """Return True when some special ``alldiff`` body atom contains ``...``."""
        return any(
            term is Ellipsis
            for atom in self.body
            if self._is_alldiff(atom)
            for term in atom.terms
        )

    def to_str(self, _: bool = False) -> str:
        """Return the textual form of the rule; the argument is ignored."""
        return str(self)

    def __str__(self) -> str:
        """Render as ``head :- atom, atom.`` followed by the metadata, if any."""
        metadata = "" if self.metadata is None else f" {self.metadata}"
        return f"{self.head.to_str()} :- {', '.join(atom.to_str() for atom in self.body)}.{metadata}"

    def __repr__(self) -> str:
        return self.to_str()

    def __and__(self, other) -> "Rule":
        """Add one atom, or every atom of an iterable, to the body in place; return ``self``."""
        if isinstance(other, Iterable):
            self.body.extend(other)
        else:
            self.body.append(other)
        return self

    def __or__(self, other) -> "Rule":
        """Attach metadata to the rule in place and return ``self``.

        An iterable is converted with ``Metadata.from_iterable``; a ``Metadata``
        instance is used as given.

        Raises:
            NotImplementedError: if ``other`` is neither iterable nor ``Metadata``.
        """
        if isinstance(other, Iterable):
            other = Metadata.from_iterable(other)
        elif not isinstance(other, Metadata):
            raise NotImplementedError

        self._set_metadata(other)
        return self

    def _set_metadata(self, metadata: "Metadata"):
        """Store ``metadata``, first processing a head-dependent aggregation against the head."""
        if metadata.aggregation is not None and metadata.aggregation.rule_head_dependant():
            # Work on a copy so the caller's metadata object is left untouched.
            metadata = metadata.copy()
            metadata.aggregation = metadata.aggregation.process_head(self.head)
        self.metadata = metadata