from typing import Sequence, Union
import jpype
from neuralogic.core.constructs.factories import R
from neuralogic.core.constructs.java_objects import JavaFactory
from neuralogic.core.constructs.relation import BaseRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset.base import BaseDataset
# Type of a single entry in a sample's example: either a relation or a rule.
DatasetEntries = Union[BaseRelation, Rule]
[docs]
class Sample:
    """
    A single learning sample: a query (or list of queries) paired with an example.

    The example is stored as given when it is a ``Sequence``; a single non-sequence
    entry is wrapped in a one-element list and ``None`` becomes an empty list.
    """

    __slots__ = ("query", "example")

    def __init__(
        self, query: BaseRelation | list[BaseRelation] | None, example: Sequence[DatasetEntries] | DatasetEntries | None
    ):
        """
        :param query: The query, a list of queries, or ``None``; stored unchanged.
        :param example: The example entries, a single entry, or ``None``.
        """
        self.query = query

        # Normalise the example: None -> empty list, lone entry -> one-element list.
        entries = [] if example is None else example
        self.example = entries if isinstance(entries, Sequence) else [entries]

    def draw(*args, **kwargs):
        """
        Always raise ``NotImplementedError``.

        Accepts any arguments (the instance arrives as the first positional one).
        """
        raise NotImplementedError("sample cannot be drawn unless it is grounded or neuralized")

    def __str__(self) -> str:
        """Return the query as a string; a list of queries is joined with ``", "``."""
        if not isinstance(self.query, list):
            return str(self.query)
        return ", ".join(map(str, self.query))

    def __len__(self) -> int:
        """Return the number of entries in the example (``0`` if it is ``None``)."""
        return 0 if self.example is None else len(self.example)
[docs]
class Dataset(BaseDataset):
    r"""
    Dataset encapsulating (learning) samples in logic format, allowing users to fully take advantage of the
    PyNeuraLogic library.

    Samples are kept in the ``samples`` list and can be accessed, replaced and deleted by index.
    Separate ``_examples`` / ``_queries`` lists are maintained by the deprecated
    ``add_example(s)`` / ``add_query(ies)`` / ``set_examples`` / ``set_queries`` methods.
    """

    __slots__ = ("samples", "_examples", "_queries")

    def __init__(self, samples: list[Sample] | Sample | None = None):
        """
        :param samples: A list of samples (stored as is, not copied), a single sample
            (wrapped in a list), or ``None`` for an empty dataset.
        """
        if samples is None:
            self.samples = []
        elif isinstance(samples, list):
            self.samples = samples
        else:
            self.samples = [samples]

        self._examples: list[list[DatasetEntries]] = []
        self._queries: list[BaseRelation] = []

    def set_samples(self, samples: list[Sample]):
        """Replace the dataset's samples with ``samples`` (the list is stored as is)."""
        self.samples = samples

    def add_samples(self, samples: list[Sample]) -> "Dataset":
        """Append all ``samples`` to the dataset and return the dataset for chaining."""
        self.samples.extend(samples)
        return self

    def add_sample(self, sample: Sample) -> "Dataset":
        """Append one ``sample`` to the dataset and return the dataset for chaining."""
        self.samples.append(sample)
        return self

    def add(self, query: BaseRelation | list[BaseRelation] | None, example: list[DatasetEntries] | None) -> "Dataset":
        """Build a ``Sample`` from ``query`` and ``example``, append it, and return the dataset for chaining."""
        self.samples.append(Sample(query, example))
        return self

    def __getitem__(self, item: int) -> Sample:
        """Return the sample at index ``item``."""
        return self.samples[item]

    def __setitem__(self, key: int, value: Sample):
        """Replace the sample at index ``key`` with ``value``."""
        self.samples[key] = value

    def __delitem__(self, key: int):
        """Remove the sample at index ``key``."""
        del self.samples[key]

    def __str__(self):
        """Return the string forms of all samples joined with ``". "``."""
        return ". ".join(str(s) for s in self.samples)

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    # Deprecated
    # NOTE(review): the marker presumably covers the methods below that manage
    # `_examples` / `_queries` separately from `samples` - confirm its extent.
    def add_example(self, example):
        """Append a single ``example`` to the internal examples list."""
        self.add_examples([example])

    def add_examples(self, examples: list):
        """Append all ``examples`` to the internal examples list."""
        self._examples.extend(examples)

    def add_query(self, query):
        """Append a single ``query`` to the internal queries list."""
        self.add_queries([query])

    def add_queries(self, queries: list):
        """Append all ``queries`` to the internal queries list."""
        self._queries.extend(queries)

    def set_examples(self, examples: list):
        """Replace the internal examples list with ``examples``."""
        self._examples = examples

    def set_queries(self, queries: list):
        """Replace the internal queries list with ``queries``."""
        self._queries = queries

    def generate_features(self, feature_depth: int = 1, count_groundings: bool = True):
        """
        Run the Java ``FeatureGenerator`` over the examples of all samples.

        For every sample, the constants found in its example entries (rule heads, rule
        bodies and plain relations) are each wrapped in a ``__vert`` literal; these
        literals followed by the example entries are converted to one Java clause.

        :param feature_depth: Passed through to ``FeatureGenerator.generateFeatures``.
        :param count_groundings: Assigned to ``FeatureGenerationSettings.COUNT_GROUNDINGS``.
            The attribute is set on the Java class itself, so the value presumably
            persists for later calls as well.
        :return: A tuple ``(table, clauses)``: ``table`` is ``features.table`` converted
            to a list of lists of ``int``, ``clauses`` holds the string form of each
            item in ``features.features``.
        """
        java_factory = JavaFactory()

        vertex_lit = R.get("__vert")
        vertex_lit.predicate.special = False
        vertex_lit.predicate.hidden = False

        clauses = []
        for sample in self.samples:
            # A dict serves as an insertion-ordered set, so the generated vertex
            # literals come out in a deterministic order (a plain set does not
            # guarantee that for string terms across interpreter runs).
            vertices: dict = {}

            for entry in sample.example:
                if isinstance(entry, Rule):
                    vertices.update(dict.fromkeys(self._get_constants(entry.head)))
                    for rel in entry.body:
                        vertices.update(dict.fromkeys(self._get_constants(rel)))
                if isinstance(entry, BaseRelation):
                    vertices.update(dict.fromkeys(self._get_constants(entry)))

            example = [vertex_lit(vert) for vert in vertices]
            example.extend(sample.example)
            clauses.append(java_factory.to_clause(example))

        java_clauses = jpype.java.util.ArrayList(clauses)
        namespace = "cz.cvut.fel.ida.logic.features.generation"

        jpype.JClass(f"{namespace}.FeatureGenerationSettings").COUNT_GROUNDINGS = count_groundings
        features = jpype.JClass(f"{namespace}.FeatureGenerator").generateFeatures(java_clauses, feature_depth)

        table = [[int(value) for value in row] for row in features.table]
        feature_clauses = [str(feature) for feature in features.features]

        return table, feature_clauses

    @staticmethod
    def _get_constants(relation: BaseRelation):
        """
        Return the terms of ``relation`` whose string form does not start with an uppercase letter.

        Capitalised terms are skipped (presumably because they denote variables - confirm).
        Each term is tested on its own; the previous check inspected the relation's string
        form, which made the filter independent of the term being considered.
        """
        # `[:1]` rather than `[0]` so a term with an empty string form does not raise.
        return [term for term in relation.terms if not str(term)[:1].isupper()]