from collections.abc import Iterable
from typing import Any
import jpype
from neuralogic.core.constructs.function.enum import Combination
from neuralogic.core.constructs.function.function import Function
from neuralogic.core.constructs.function.function_graph import FunctionGraph
class FContainer:
    """
    A container for multiple logic nodes (relations or other containers) and a function to apply to them.

    It allows for nesting and building complex function graphs.

    Attributes are stored in ``__slots__``:

    * ``nodes`` - tuple of the contained nodes (already flattened where the function allows it).
    * ``function`` - the function applied to the nodes.
    """

    __slots__ = "nodes", "function"

    def __init__(self, nodes: Iterable[Any], function: "Function"):
        """
        Parameters
        ----------
        nodes : Iterable[Any]
            The nodes to be contained in the container.
        function : Function
            The function to apply to the nodes.
        """
        self.function = function

        if function.can_flatten:
            self.nodes = self.get_flattened_nodes(nodes, function)
        else:
            # Materialize the iterable: the container is iterated repeatedly
            # (``__str__``, ``__iter__``, ``name``) and ``_get_function_node`` calls
            # ``len`` on it, so a one-shot iterator such as a generator must not be
            # stored as-is.
            self.nodes = tuple(nodes)

    @staticmethod
    def get_flattened_nodes(nodes: Iterable[Any], function: "Function") -> tuple:
        """
        Flattens the nodes if they are FContainers with the same function.

        Only one level is unwrapped here: a nested container whose function has
        the same ``name`` as ``function`` is replaced by its own ``nodes``.

        Parameters
        ----------
        nodes : Iterable[Any]
            The nodes to flatten.
        function : Function
            The function used for flattening criteria.

        Returns
        -------
        tuple
            The flattened nodes.
        """
        flattened = []

        for node in nodes:
            if isinstance(node, FContainer) and node.function.name == function.name:
                flattened.extend(node.nodes)
            else:
                flattened.append(node)

        return tuple(flattened)

    def __add__(self, other: Any) -> "FContainer":
        """Combine this container with ``other`` using ``Combination.SUM``."""
        return FContainer((self, other), Combination.SUM)

    def __mul__(self, other: Any) -> "FContainer":
        """Combine this container with ``other`` using ``Combination.ELPRODUCT``."""
        return FContainer((self, other), Combination.ELPRODUCT)

    def __matmul__(self, other: Any) -> "FContainer":
        """Combine this container with ``other`` using ``Combination.PRODUCT``."""
        return FContainer((self, other), Combination.PRODUCT)

    def __str__(self) -> str:
        """
        Render the container as infix text when the function has an operator,
        otherwise as ``function(arg, ...)`` (or just ``function`` without nodes).
        """
        operator = self.function.operator

        if operator is not None:
            # Nested containers are asked to wrap themselves in parentheses so the
            # infix output keeps its grouping.
            return f" {operator} ".join(
                node.to_str(True) if isinstance(node, FContainer) else node.to_str() for node in self.nodes
            )

        args = ", ".join(node.to_str() for node in self.nodes)
        return f"{self.function}({args})" if args else f"{self.function}"

    @property
    def name(self) -> str:
        """
        Name of the container: its function, followed by the functions of the
        directly nested containers in parentheses (if there are any).
        """
        nested = ", ".join(str(node.function) for node in self.nodes if isinstance(node, FContainer))
        return f"{self.function}({nested})" if nested else f"{self.function}"

    def __iter__(self) -> Iterable[Any]:
        """Yield the non-container nodes of the whole nested structure, depth first."""
        for node in self.nodes:
            if isinstance(node, FContainer):
                yield from node
            else:
                yield node

    def to_function(self) -> "Function":
        """
        Converts the container and its nested structure into a FunctionGraph.

        Returns
        -------
        Function
            The generated FunctionGraph.
        """
        graph = self._get_function_node({}, 0)
        return FunctionGraph(name=self.name, function_graph=graph)

    def _get_function_node(self, input_counter: dict[int, int], start_index: int = 0) -> Any:
        """
        Build the Java ``FunctionGraphNode`` for this container (recursively).

        Parameters
        ----------
        input_counter : dict[int, int]
            Shared mapping from ``id(relation)`` to its assigned input index. It is
            mutated in place, so the same relation object always gets the same index.
        start_index : int
            Offset added to newly assigned input indices. Default: 0.

        Returns
        -------
        Any
            The created ``FunctionGraphNode`` or ``None`` when no usable child
            (nested node or input index) remains.

        Raises
        ------
        ValueError
            If a node is neither an ``FContainer`` nor a ``BaseRelation``.
        """
        # Imported lazily, as in the original code (presumably to avoid a circular import).
        from neuralogic.core.constructs.relation import BaseRelation

        # Slot i holds either a nested graph node or an input index; -1 / None mean "unset".
        next_indices = [-1] * len(self.nodes)
        next_nodes = [None] * len(self.nodes)

        for i, node in enumerate(self.nodes):
            if isinstance(node, FContainer):
                # Pass the offset down: the counter is shared, so a nested container
                # numbering from 0 could hand out an index already used by this level.
                next_nodes[i] = node._get_function_node(input_counter, start_index)
            elif isinstance(node, BaseRelation):
                predicate = node.predicate

                # Hidden, special and underscore-prefixed predicates get no input index.
                if predicate.hidden or predicate.special or predicate.name.startswith("_"):
                    continue

                key = id(node)
                if key not in input_counter:
                    input_counter[key] = len(input_counter) + start_index
                next_indices[i] = input_counter[key]
            else:
                raise ValueError(f"{node} of type {type(node)} inside of body function is not supported")

        # Drop slots that ended up with neither a nested node nor an input index.
        kept = [(child, index) for child, index in zip(next_nodes, next_indices) if child is not None or index != -1]

        if not kept:
            return None

        filtered_next_node = [child for child, _ in kept]
        filtered_next_indices = [index for _, index in kept]

        class_name = "cz.cvut.fel.ida.algebra.functions.combination.FunctionGraph.FunctionGraphNode"
        return jpype.JClass(class_name)(self.function.get(), filtered_next_node, filtered_next_indices)

    def to_str(self, parentheses_wrap: bool = False) -> str:
        """
        Returns a string representation of the container.

        Parameters
        ----------
        parentheses_wrap : bool
            Whether to wrap the output in parentheses. Only applied when the
            function has an operator (infix form). Default: False.

        Returns
        -------
        str
            The string representation.
        """
        if parentheses_wrap and self.function.operator is not None:
            return f"({self})"
        return str(self)