Source code for neuralogic.core.settings

import weakref
from typing import Any

from neuralogic.core.enums import Grounder
from neuralogic.core.settings.settings_proxy import SettingsProxy
from neuralogic.nn.init import Initializer, Uniform
from neuralogic.nn.loss import MSE, ErrorFunction
from neuralogic.nn.optim import Adam, Optimizer


class Settings:
    """Container of engine settings that keeps its live proxies in sync.

    The keyword arguments given to the constructor are stored in ``params``
    and exposed as read/write properties. Assigning to one of those
    properties updates ``params`` and sets the same attribute on every proxy
    created through :meth:`create_proxy` that is still alive.

    Additional key/value options can be stored with item assignment
    (``settings[key] = value``); they are kept in ``kw_params`` and are
    replayed on every proxy created afterwards.
    """

    # NOTE(review): the default ``Adam()``, ``MSE()`` and ``Uniform()``
    # instances below are evaluated once, when the function is defined, so
    # every ``Settings`` built with defaults shares the same three objects.
    # Confirm that callers never mutate them in place.
    def __init__(
        self,
        *,
        optimizer: Optimizer = Adam(),
        error_function: ErrorFunction = MSE(),
        initializer: Initializer = Uniform(),
        iso_value_compression: bool = True,
        chain_pruning: bool = True,
        prune_only_identities: bool = False,
        grounder: Grounder = Grounder.BUP,
    ):
        """Store the given settings; all arguments are keyword-only."""
        # Spelled out explicitly (rather than snapshotting ``locals()``) so
        # that a local variable added here later cannot leak into the
        # keyword arguments passed to ``SettingsProxy``.
        self.params: dict[str, Any] = {
            "optimizer": optimizer,
            "error_function": error_function,
            "initializer": initializer,
            "iso_value_compression": iso_value_compression,
            "chain_pruning": chain_pruning,
            "prune_only_identities": prune_only_identities,
            "grounder": grounder,
        }

        # Weak references: a proxy nobody else holds on to is dropped from
        # the set automatically and no longer receives updates.
        self._proxies: weakref.WeakSet[SettingsProxy] = weakref.WeakSet()

        # Options set through ``settings[key] = value``.
        self.kw_params: dict[str, Any] = {}

    @property
    def iso_value_compression(self) -> bool:
        """Current value of the ``iso_value_compression`` parameter."""
        return self.params["iso_value_compression"]

    @iso_value_compression.setter
    def iso_value_compression(self, iso_value_compression: bool):
        self._update("iso_value_compression", iso_value_compression)

    @property
    def chain_pruning(self) -> bool:
        """Current value of the ``chain_pruning`` parameter."""
        return self.params["chain_pruning"]

    @chain_pruning.setter
    def chain_pruning(self, chain_pruning: bool):
        self._update("chain_pruning", chain_pruning)

    @property
    def prune_only_identities(self) -> bool:
        """Current value of the ``prune_only_identities`` parameter."""
        return self.params["prune_only_identities"]

    @prune_only_identities.setter
    def prune_only_identities(self, prune_only_identities: bool):
        self._update("prune_only_identities", prune_only_identities)

    @property
    def grounder(self) -> Grounder:
        """Current value of the ``grounder`` parameter."""
        return self.params["grounder"]

    @grounder.setter
    def grounder(self, grounder: Grounder):
        self._update("grounder", grounder)

    @property
    def optimizer(self) -> Optimizer:
        """Current value of the ``optimizer`` parameter."""
        return self.params["optimizer"]

    @optimizer.setter
    def optimizer(self, optimizer: Optimizer):
        self._update("optimizer", optimizer)

    @property
    def error_function(self) -> ErrorFunction:
        """Current value of the ``error_function`` parameter."""
        return self.params["error_function"]

    @error_function.setter
    def error_function(self, error_function: ErrorFunction):
        self._update("error_function", error_function)

    @property
    def initializer(self) -> Initializer:
        """Current value of the ``initializer`` parameter."""
        return self.params["initializer"]

    @initializer.setter
    def initializer(self, initializer: Initializer):
        self._update("initializer", initializer)

    def create_proxy(self) -> SettingsProxy:
        """Create a proxy that keeps receiving later changes made here.

        The proxy is built from the current ``params``, registered in the
        weak set of tracked proxies and then given every stored
        ``kw_params`` entry.
        """
        proxy = SettingsProxy(**self.params)
        self._proxies.add(proxy)
        return self._apply_kw_params(proxy)

    def create_disconnected_proxy(self) -> SettingsProxy:
        """Create a proxy holding the current values without tracking it.

        Later changes made on this object are not pushed to the returned
        proxy.
        """
        return self._apply_kw_params(SettingsProxy(**self.params))

    def __setitem__(self, key, value):
        """Set ``key`` on every tracked proxy, then record it in ``kw_params``."""
        # Iterate over a copy so the loop is unaffected by the weak set
        # changing while proxies are being updated.
        for proxy in self._proxies.copy():
            proxy[key] = value
        self.kw_params[key] = value

    def __getitem__(self, item):
        """Return the option stored under ``item``.

        Raises:
            KeyError: if ``item`` was never set through item assignment.
        """
        return self.kw_params[item]

    def _apply_kw_params(self, proxy: SettingsProxy) -> SettingsProxy:
        """Copy every stored ``kw_params`` entry onto ``proxy`` and return it."""
        for key, value in self.kw_params.items():
            proxy[key] = value
        return proxy

    def _update(self, parameter: str, value: Any) -> None:
        """Store a new value for ``parameter`` and push it to tracked proxies.

        Raises:
            NotImplementedError: if ``parameter`` is not one of the keys of
                ``params``.
        """
        if parameter not in self.params:
            raise NotImplementedError(f"Unknown settings parameter: {parameter!r}")

        self.params[parameter] = value
        for proxy in self._proxies.copy():
            setattr(proxy, parameter, value)