# Source code for neuralogic.core.neural_module

from __future__ import annotations

from typing import TYPE_CHECKING, Collection

import jpype

from neuralogic.core.builder.dataset import BuiltDataset, GroundedDataset

if TYPE_CHECKING:
    from neuralogic.core.builder import DatasetBuilder
    from neuralogic.core.settings.settings_proxy import SettingsProxy
from neuralogic.core.constructs.java_objects import ValueFactory
from neuralogic.dataset import Dataset
from neuralogic.dataset.base import BaseDataset
from neuralogic.setup import initialize, is_initialized
from neuralogic.utils.visualize import draw_model

# Type alias for Python-side values converted from the Java backend (via ValueFactory.from_java):
# presumably a float for scalar values and a (possibly nested) list for vector/matrix values.
# Used in the return annotations of NeuralModule.train and NeuralModule.test.
# NOTE: unlike annotations (made lazy by `from __future__ import annotations`), this union is
# evaluated at import time, so it requires Python 3.10+.
Value = list | float


class NeuralModule:
    """
    NeuralModule is the base class for all neural models.

    It provides methods for grounding, building, training, and testing.
    """

    def __init__(self):
        """Initializes the neural module.

        Calls ``initialize()`` if ``is_initialized()`` reports that neuralogic is not initialized yet.
        All model components stay ``None`` until ``_initialize_neural_module`` is called.
        """
        if not is_initialized():
            initialize()

        self._need_sync = False
        self._value_factory = ValueFactory()
        self._parsed_model = None
        self._dataset_builder: DatasetBuilder | None = None
        self._settings: SettingsProxy | None = None
        self._neural_model = None

        # Java-side training machinery, populated by _initialize_neural_module.
        self._strategy = None
        self._trainer = None
        self._invalidation = None
        self._evaluation = None
        self._backpropagation = None
        self._weight_updater = None

        # PyTorch integration; only set when the module is initialized with torch=True.
        self._tensor_parameters = None
        self._torch_module = None

    def ground(
        self,
        dataset: BaseDataset,
        *,
        batch_size: int = 1,
        learnable_facts: bool = False,
        progress: bool = False,
    ) -> GroundedDataset:
        """Grounds the provided dataset using the model's settings.

        Parameters
        ----------
        dataset : BaseDataset
            The dataset to ground.
        batch_size : int
            The batch size for grounding. Default: 1.
        learnable_facts : bool
            Whether facts are learnable. Default: False.
        progress : bool
            Whether to show progress. Default: False.

        Returns
        -------
        GroundedDataset
            The grounded dataset.

        Raises
        ------
        ValueError
            If the model is not built.
        """
        if self._dataset_builder is None or self._settings is None:
            raise ValueError("model is not built")

        return self._dataset_builder.ground_dataset(
            dataset,
            self._settings,
            batch_size=batch_size,
            learnable_facts=learnable_facts,
            progress=progress,
        )

    def build_dataset(
        self,
        dataset: BaseDataset | GroundedDataset,
        *,
        batch_size: int = 1,
        learnable_facts: bool = False,
        progress: bool = False,
    ) -> BuiltDataset:
        """Builds (ground and neuralize) the provided dataset.

        Parameters
        ----------
        dataset : Union[BaseDataset, GroundedDataset]
            The dataset to build.
        batch_size : int
            The batch size. Default: 1.
        learnable_facts : bool
            Whether facts are learnable. Default: False.
        progress : bool
            Whether to show progress. Default: False.

        Returns
        -------
        BuiltDataset
            The built dataset.

        Raises
        ------
        ValueError
            If the model is not built.
        """
        if self._dataset_builder is None or self._settings is None:
            raise ValueError("model is not built")

        return self._dataset_builder.build_dataset(
            dataset,
            self._settings,
            batch_size=batch_size,
            learnable_facts=learnable_facts,
            progress=progress,
        )

    def __call__(self, dataset=None):
        """Evaluates the model on the provided dataset (forward pass).

        Parameters
        ----------
        dataset : Any
            A Dataset (or any BaseDataset), GroundedDataset, BuiltDataset, a single sample,
            or a collection of samples.

        Returns
        -------
        list
            One output value per sample (a list even when a single sample is passed). When the
            module was initialized with ``torch=True``, the result of the PyTorch module's
            ``forward(self, samples, results)`` is returned instead.

        Raises
        ------
        ValueError
            If the model is not built or no dataset is provided.
        """
        self._require_built()
        samples, _ = self._dataset_to_samples(dataset)
        if samples is None:
            raise ValueError("no dataset was provided")

        sample_collection = samples if isinstance(samples, Collection) else [samples]

        # All samples are invalidated before any is evaluated; presumably so no evaluation reuses
        # stale cached values (possibly shared between samples) -- TODO confirm with the Java trainer.
        for sample in sample_collection:
            self._trainer.invalidateSample(self._invalidation, sample._java_sample)

        results = [
            self._value_factory.from_java(
                self._trainer.evaluateSample(self._evaluation, sample._java_sample).getOutput(),
            )
            for sample in sample_collection
        ]

        if self._torch_module is None:
            return results
        return self._torch_module.forward(self, samples, results)

    def forward(self, dataset):
        """Evaluates the model on ``dataset``; equivalent to ``self(dataset)``."""
        return self(dataset)

    def train(self, dataset, epochs: int = 1) -> Value:
        """Trains the model on the provided dataset.

        Parameters
        ----------
        dataset : Any
            The dataset to train on. Can be a Dataset, GroundedDataset, BuiltDataset, or a list of samples.
        epochs : int
            The number of epochs to train. Default: 1. For a single sample, at least one
            training step is always performed.

        Returns
        -------
        Union[Tuple[Value, Value, Value], List[Tuple[Value, Value, Value]]]
            The training results (target, output, error). For a single sample, the result of the
            last training step.

        Raises
        ------
        ValueError
            If the model is not built.
        """
        self._require_built()
        samples, batch_size = self._dataset_to_samples(dataset)

        if not isinstance(samples, Collection):
            # learnSample takes no epoch count, so repeat it to honour ``epochs``. At least one step
            # is kept (as before this fix) so that epochs <= 1 behaves exactly as it used to.
            result = None
            for _ in range(max(epochs, 1)):
                result = self._strategy.learnSample(samples._java_sample)
            res = self._result_to_tuple(result)
        else:
            sample_array = jpype.java.util.ArrayList([sample._java_sample for sample in samples])
            results = self._strategy.learnSamples(sample_array, epochs, batch_size)
            res = [self._result_to_tuple(result) for result in results]

        self._update_tensor_parameters()
        return res

    def test(self, dataset) -> Value:
        """Tests the model on the provided dataset.

        Parameters
        ----------
        dataset : Any
            The dataset to test on.

        Returns
        -------
        Union[Value, List[Value]]
            The test results (outputs): a single value for a single sample, otherwise a list.

        Raises
        ------
        ValueError
            If the model is not built.
        """
        self._require_built()
        samples, batch_size = self._dataset_to_samples(dataset)

        if not isinstance(samples, Collection):
            return ValueFactory.from_java(self._strategy.evaluateSample(samples._java_sample))

        sample_array = jpype.java.util.ArrayList([sample._java_sample for sample in samples])
        results = self._strategy.evaluateSamples(sample_array, batch_size)

        return [ValueFactory.from_java(result) for result in results]

    def reset_parameters(self):
        """Resets the model parameters via the Java training strategy's ``resetParameters()``."""
        self._strategy.resetParameters()

    def parameters(self) -> dict:
        """Returns the model parameters.

        Returns
        -------
        dict
            The model parameters (same as :meth:`state_dict`).
        """
        return self.state_dict()

    def state_dict(self) -> dict:
        """Returns the state dictionary of the model.

        Returns
        -------
        dict
            ``{"weights": {index: value}, "weight_names": {index: name}}`` covering only the
            learnable weights of the neural model.
        """
        weights = self._neural_model.getAllWeights()
        weights_dict = {}
        weight_names = {}

        for weight in weights:
            if weight.isLearnable:
                weights_dict[weight.index] = ValueFactory.from_java(weight.value)
                weight_names[weight.index] = str(weight.name)

        return {
            "weights": weights_dict,
            "weight_names": weight_names,
        }

    def tensor_parameters(self):
        """Returns the parameters exposed by the PyTorch integration as a list.

        The PyTorch module computes them from the current weights and they are cached on the instance.

        Raises
        ------
        NotImplementedError
            If the module was not initialized with ``torch=True``.
        """
        if self._torch_module is None:
            raise NotImplementedError(
                "tensor_parameters() requires the PyTorch backend. Call model.build(settings, torch=True) to enable it."
            )

        self._tensor_parameters = self._torch_module.tensor_parameters(
            self._tensor_parameters,
            self._weight_updater,
            self._value_factory,
            self._neural_model,
        )
        return list(self._tensor_parameters)

    def _update_tensor_parameters(self):
        """Lets the PyTorch module refresh the cached tensor parameters; no-op without torch."""
        if self._torch_module is not None:
            self._torch_module.update_tensor_parameters(self._tensor_parameters)

    def load_state_dict(self, state_dict: dict):
        """Loads weight values from a state dictionary into the neural model.

        Parameters
        ----------
        state_dict : dict
            A dictionary in the format returned by :meth:`state_dict`; only its ``"weights"``
            entry is read.
        """
        self._sync_model(state_dict, self._neural_model.getAllWeights())
        self._update_tensor_parameters()

    def draw(
        self,
        filename: str | None = None,
        show=True,
        img_type="png",
        value_detail: int = 0,
        graphviz_path: str | None = None,
        *args,
        **kwargs,
    ):
        """Draws the model using ``neuralogic.utils.visualize.draw_model``.

        All arguments are forwarded unchanged to ``draw_model``; see its documentation for their
        meaning.

        Raises
        ------
        ValueError
            If the model is not built.
        """
        if self._dataset_builder is None or self._settings is None:
            raise ValueError("model is not built")

        return draw_model(self, filename, show, img_type, value_detail, graphviz_path, *args, **kwargs)

    def _require_built(self) -> None:
        """Raises ``ValueError`` if the Java training strategy has not been created yet."""
        if self._strategy is None:
            raise ValueError("model is not built")

    @staticmethod
    def _result_to_tuple(result):
        """Converts a Java training result into a ``(target, output, error)`` tuple of Python values."""
        return (
            ValueFactory.from_java(result.getTarget()),
            ValueFactory.from_java(result.getOutput()),
            ValueFactory.from_java(result.errorValue()),
        )

    def _initialize_neural_module(self, dataset_builder: DatasetBuilder, settings: SettingsProxy, model, torch: bool):
        """Wires the module to a built model and creates the Java training strategy.

        Parameters
        ----------
        dataset_builder : DatasetBuilder
            Builder used by :meth:`ground` and :meth:`build_dataset`.
        settings : SettingsProxy
            Settings providing the optimizer and the Java settings object.
        model :
            The Java neural model the strategy trains.
        torch : bool
            Whether to enable the PyTorch integration.

        Raises
        ------
        ImportError
            If ``torch`` is True but PyTorch is not installed.
        """
        self._dataset_builder = dataset_builder
        self._settings = settings
        self._neural_model = model

        if torch:
            # Only checks that PyTorch is importable; the alias keeps the boolean ``torch`` parameter
            # from being rebound to the module.
            try:
                import torch as _torch  # noqa: F401
            except ImportError as err:
                raise ImportError("torch is not installed in the environment") from err

            from neuralogic.core.torch.neural_module import TorchNeuralModule

            self._torch_module = TorchNeuralModule()

        optimizer = self._settings.optimizer.initialize()
        lr_decay = self._settings.optimizer.get_lr_decay()

        python_strategy = jpype.JClass(
            "cz.cvut.fel.ida.neural.networks.computation.training.strategies.PythonTrainingStrategy"
        )

        self._strategy = python_strategy(settings.settings, model, optimizer, lr_decay)
        self._trainer = self._strategy.getTrainer()
        self._invalidation = self._trainer.getInvalidation()
        self._evaluation = self._trainer.getEvaluation()
        self._backpropagation = self._trainer.getBackpropagation()
        self._weight_updater = self._backpropagation.weightUpdater

        self.reset_parameters()

    def _dataset_to_samples(self, dataset):
        """Normalizes ``dataset`` into a ``(samples, batch_size)`` pair.

        Already built datasets are used directly, grounded datasets are neuralized, and any other
        dataset (``Dataset`` or another ``BaseDataset``) is built via :meth:`build_dataset`.
        Anything else is assumed to be a sample or a collection of samples and gets batch size 1.
        """
        if isinstance(dataset, BuiltDataset):
            return dataset._samples, dataset._batch_size

        if isinstance(dataset, GroundedDataset):
            built = dataset.neuralize()
            return built._samples, built._batch_size

        # Any BaseDataset is accepted, matching the signatures of ground()/build_dataset(). Dataset is
        # listed explicitly in case it does not derive from BaseDataset -- TODO(review) confirm hierarchy.
        if isinstance(dataset, (Dataset, BaseDataset)):
            built = self.build_dataset(dataset)
            return built._samples, built._batch_size

        return dataset, 1

    def _sync_model(self, state_dict: dict | None = None, weights=None):
        """Writes weight values from ``state_dict`` into backend weights.

        Only learnable weights are written. Each weight's value is looked up by its index in
        ``state_dict["weights"]``. It may be a real scalar, a flat sequence of scalars, or a
        sequence of equally long rows, which is written in row-major order.

        Parameters
        ----------
        state_dict : dict, optional
            Source values; defaults to ``self.state_dict()``.
        weights : optional
            Backend weights to write into; defaults to ``self._parsed_model.getAllWeights()``.
            NOTE(review): ``_parsed_model`` is only initialized to None in this class; presumably a
            subclass assigns it -- verify before relying on the default.

        Raises
        ------
        KeyError
            If a learnable weight's index is missing from ``state_dict["weights"]``.
        """
        from numbers import Real

        state_dict = self.state_dict() if state_dict is None else state_dict
        weights = self._parsed_model.getAllWeights() if weights is None else weights
        weight_dict = state_dict["weights"]

        for weight in weights:
            if not weight.isLearnable:
                continue

            weight_value = weight.value
            value = weight_dict[weight.index]

            # Scalar weight. ``Real`` also covers numpy scalars and Fractions, not just float/int.
            if isinstance(value, Real):
                weight_value.set(0, float(value))
                continue

            # Vector weight: a flat sequence of scalars.
            if isinstance(value[0], Real):
                for i, val in enumerate(value):
                    weight_value.set(i, float(val))
                continue

            # Matrix weight: rows flattened in row-major order.
            cols = len(value[0])
            for i, row in enumerate(value):
                for j, val in enumerate(row):
                    weight_value.set(i * cols + j, float(val))

    def _backprop(self, sample, gradient):
        """Backpropagates ``gradient`` through ``sample`` in the Java backend.

        Returns
        -------
        tuple
            ``(state_index, weight_updater)``. ``weight_updater`` is the value returned by
            ``backpropagate``, and ``state_index`` is the backpropagation's ``backproper``.
        """
        _, gradient_value = self._value_factory.get_value(gradient)
        weight_updater = self._backpropagation.backpropagate(sample._java_sample, gradient_value)
        # NOTE(review): this returns the ``backproper`` object itself although the name suggests a state
        # index (e.g. ``backproper.stateIndex``); verify against the consumer of this tuple.
        state_index = self._backpropagation.backproper
        return state_index, weight_updater