from typing import Dict, Optional, Union, Callable, List
from neuralogic.core.settings import Settings
from neuralogic.core.builder import DatasetBuilder
from neuralogic.core import Template, BuiltDataset, SettingsProxy, GroundedDataset
from neuralogic.dataset.base import BaseDataset
from neuralogic.utils.visualize import draw_model
[docs]
class AbstractNeuraLogic:
    """Abstract base for a model built from a template.

    Keeps the dataset builder, the settings proxy and the registered hook
    callbacks, and forwards grounding/building of datasets to the dataset
    builder. ``__call__``, ``state_dict`` and ``load_state_dict`` raise
    ``NotImplementedError`` here and are meant to be provided by subclasses.
    """

    def __init__(self, dataset_builder: DatasetBuilder, template: Template, settings: SettingsProxy):
        """Store the builder and settings and snapshot the template rules.

        Args:
            dataset_builder: Builder whose ``parsed_template`` becomes ``self.template``.
            template: Source template; its ``template`` iterable is copied into a list.
            settings: Settings proxy passed along when grounding/building datasets.
        """
        self.dataset_builder = dataset_builder
        self.settings = settings

        # A shallow copy, so later changes to the template's own container
        # do not alter this snapshot.
        self.source_template = list(template.template)
        self.template = dataset_builder.parsed_template

        self.hooks: Dict[str, List[Callable]] = {}
        self.hooks_set = False
        self.need_sync = True

    def __call__(self, sample):
        """Evaluate ``sample``; not implemented in the base class."""
        raise NotImplementedError

    def ground(
        self,
        dataset: BaseDataset,
        *,
        batch_size: int = 1,
        learnable_facts: bool = False,
    ) -> GroundedDataset:
        """Ground ``dataset`` via the dataset builder using this model's settings.

        Args:
            dataset: Dataset to ground.
            batch_size: Forwarded unchanged to the dataset builder.
            learnable_facts: Forwarded unchanged to the dataset builder.

        Returns:
            Whatever ``dataset_builder.ground_dataset`` returns.
        """
        options = {"batch_size": batch_size, "learnable_facts": learnable_facts}
        return self.dataset_builder.ground_dataset(dataset, self.settings, **options)

    def build_dataset(
        self,
        dataset: Union[BaseDataset, GroundedDataset],
        *,
        batch_size: int = 1,
        learnable_facts: bool = False,
        progress: bool = False,
    ) -> BuiltDataset:
        """Build ``dataset`` via the dataset builder using this model's settings.

        Args:
            dataset: Raw or already grounded dataset.
            batch_size: Forwarded unchanged to the dataset builder.
            learnable_facts: Forwarded unchanged to the dataset builder.
            progress: Forwarded unchanged to the dataset builder.

        Returns:
            Whatever ``dataset_builder.build_dataset`` returns.
        """
        options = {
            "batch_size": batch_size,
            "learnable_facts": learnable_facts,
            "progress": progress,
        }
        return self.dataset_builder.build_dataset(dataset, self.settings, **options)

    def set_hooks(self, hooks):
        """Replace the hook mapping and record whether it is non-empty."""
        has_hooks = len(hooks) > 0
        self.hooks_set = has_hooks
        self.hooks = hooks

    def run_hook(self, hook: str, value):
        """Call every callback registered under ``hook`` with ``value``.

        Raises:
            KeyError: If ``hook`` is not a key of the hook mapping.
        """
        callbacks = self.hooks[hook]
        for fn in callbacks:
            fn(value)

    def sync_template(self, state_dict: Optional[Dict] = None, weights=None):
        """Copy learnable weight values from a state dict into weight objects.

        Args:
            state_dict: Mapping with a ``"weights"`` entry that is indexed by
                each weight's ``index``. Defaults to ``self.state_dict()``.
            weights: Iterable of weight objects. Defaults to
                ``self.template.getAllWeights()``.

        A stored value may be a scalar, a flat sequence of numbers, or a
        sequence of rows. Values are written through ``weight.value.set``
        with a flat position; rows are laid out one after another using the
        length of the first row as the stride.
        """
        if state_dict is None:
            state_dict = self.state_dict()
        if weights is None:
            weights = self.template.getAllWeights()
        stored = state_dict["weights"]

        for weight in weights:
            if not weight.isLearnable:
                continue

            target = weight.value
            new_value = stored[weight.index]

            if isinstance(new_value, (float, int)):
                # Scalar weight: single slot at position 0.
                target.set(0, float(new_value))
            elif isinstance(new_value[0], (float, int)):
                # Flat sequence: positions match the sequence indices.
                for position, item in enumerate(new_value):
                    target.set(position, float(item))
            else:
                # Sequence of rows: flatten with the first row's length as stride.
                width = len(new_value[0])
                for row_index, row in enumerate(new_value):
                    offset = row_index * width
                    for col_index, item in enumerate(row):
                        target.set(offset + col_index, float(item))

    def parameters(self) -> Dict:
        """Return the model parameters; an alias for ``state_dict()``."""
        return self.state_dict()

    def state_dict(self) -> Dict:
        """Return the model state; not implemented in the base class."""
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict):
        """Load the model state; not implemented in the base class."""
        raise NotImplementedError

    def draw(
        self,
        filename: Optional[str] = None,
        show=True,
        img_type="png",
        value_detail: int = 0,
        graphviz_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Draw this model by delegating to ``draw_model``.

        All arguments are passed through positionally in the order they are
        declared, followed by any extra ``*args`` and ``**kwargs``.

        Returns:
            Whatever ``draw_model`` returns.
        """
        positional = (self, filename, show, img_type, value_detail, graphviz_path, *args)
        return draw_model(*positional, **kwargs)
[docs]
class AbstractEvaluator:
    """Base evaluator wrapping a model built from a template.

    The constructor builds the model from the template with the given
    settings and installs the template's hooks on it. ``train``, ``test``,
    ``state_dict`` and ``load_state_dict`` are empty placeholders in this
    base class (they return ``None``) and are meant to be overridden.
    """

    def __init__(self, template: Template, settings: Settings):
        """Build the underlying model and register the template's hooks.

        Args:
            template: Template to build the model from; its ``hooks`` are
                passed to the built model's ``set_hooks``.
            settings: Settings used to build the model; a proxy created via
                ``settings.create_proxy()`` is kept as ``self.settings``.
        """
        self.settings = settings.create_proxy()
        self.neuralogic_model = template.build(settings)
        self.neuralogic_model.set_hooks(template.hooks)

    def build_dataset(
        self,
        dataset: Union[BaseDataset, GroundedDataset, BuiltDataset],
        *,
        batch_size: int = 1,
        learnable_facts: bool = False,
        progress: bool = False,
    ):
        """Return a built dataset, building ``dataset`` first when needed.

        Args:
            dataset: A raw dataset, a grounded dataset, or an already built
                dataset.
            batch_size: Forwarded to the model's ``build_dataset``.
            learnable_facts: Forwarded to the model's ``build_dataset``.
            progress: Forwarded to the model's ``build_dataset``.

        Returns:
            The result of the model's ``build_dataset`` for raw or grounded
            input; any other object is returned unchanged.
        """
        # The model's build_dataset accepts grounded datasets as well as raw
        # ones, so both are delegated; previously a grounded dataset fell
        # through and was returned unbuilt.
        if isinstance(dataset, (BaseDataset, GroundedDataset)):
            return self.neuralogic_model.build_dataset(
                dataset,
                batch_size=batch_size,
                learnable_facts=learnable_facts,
                progress=progress,
            )
        return dataset

    @property
    def model(self) -> AbstractNeuraLogic:
        """The underlying model built in the constructor."""
        return self.neuralogic_model

    def train(self, dataset: Optional[Union[BaseDataset, BuiltDataset]] = None, *, generator: bool = True):
        """Train on ``dataset``; a no-op placeholder in the base class."""
        pass

    def test(self, dataset: Optional[Union[BaseDataset, BuiltDataset]] = None, *, generator: bool = True):
        """Evaluate on ``dataset``; a no-op placeholder in the base class."""
        pass

    def parameters(self) -> Dict:
        """Return the evaluator parameters; an alias for ``state_dict()``."""
        return self.state_dict()

    def state_dict(self) -> Dict:
        """Return the evaluator state; a placeholder returning ``None`` here."""
        pass

    def load_state_dict(self, state_dict: Dict):
        """Load the evaluator state; a no-op placeholder in the base class."""
        pass

    def reset_parameters(self):
        """Reset parameters by delegating to the underlying model."""
        self.neuralogic_model.reset_parameters()

    def draw(
        self,
        filename: Optional[str] = None,
        show=True,
        img_type="png",
        value_detail: int = 0,
        graphviz_path: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Draw the underlying model by delegating to its ``draw`` method.

        All arguments are passed through positionally in declaration order,
        followed by any extra ``*args`` and ``**kwargs``.

        Returns:
            Whatever the model's ``draw`` returns.
        """
        return self.neuralogic_model.draw(filename, show, img_type, value_detail, graphviz_path, *args, **kwargs)