from typing import Any
import jpype
from neuralogic.core.constructs.function.function import TransformationFunction
[docs]
class Reshape(TransformationFunction):
    """
    Transformation function that reshapes a tensor to a target shape.

    The target shape is kept in the ``shape`` attribute, either as ``None``
    (no shape given) or as the sequence passed in; a bare ``int`` is stored
    as a one-element tuple.
    """

    __slots__ = ("shape",)

    def __init__(
        self,
        name: str,
        *,
        shape: tuple[int, int] | int | None = None,
    ):
        """
        Parameters
        ----------
        name : str
            The name of the function; forwarded to the base-class initializer.
        shape : tuple[int, int] | int | None, optional
            The target shape. A single ``int`` is wrapped into a one-element
            tuple; any other value is stored unchanged. Default: None.
        """
        super().__init__(name)

        # Normalize the scalar form so the attribute is never a bare int.
        if isinstance(shape, int):
            shape = (shape,)
        self.shape = shape

    def __call__(
        self,
        relation: Any | None = None,
        *,
        shape: tuple[int, int] | int | None = None,
    ) -> Any:
        """
        Build a fresh ``Reshape`` with the given shape and apply it to ``relation``.

        The new instance reuses this instance's ``name`` but takes its shape
        only from the ``shape`` argument; the shape stored on this instance is
        not carried over.

        Parameters
        ----------
        relation : Any, optional
            The relation to apply the reshape to. Default: None.
        shape : tuple[int, int] | int | None, optional
            The target shape of the new instance. Default: None.

        Returns
        -------
        Any
            Whatever ``TransformationFunction.__call__`` returns for the new
            instance and ``relation``.
        """
        configured = Reshape(self.name, shape=shape)
        # Call the base implementation explicitly; going through
        # ``configured(relation)`` would re-enter this method.
        return TransformationFunction.__call__(configured, relation)

    def is_parametrized(self) -> bool:
        """Report that this function is parametrized; always ``True``."""
        return True

    def get(self) -> Any:
        """
        Instantiate the Java-side ``Reshape`` class through JPype.

        Returns
        -------
        Any
            The object produced by calling the Java class with the stored
            shape converted to a list, or with ``None`` when no shape is set.
        """
        if self.shape is None:
            dims = None
        else:
            dims = list(self.shape)

        java_reshape = jpype.JClass("cz.cvut.fel.ida.algebra.functions.transformation.joint.Reshape")
        return java_reshape(dims)

    def wrap(self, content: str) -> str:
        """Return the textual form of this reshape applied to ``content``."""
        return f"reshape({content}, shape={self.shape})"

    def __str__(self) -> str:
        """Return the textual form of this reshape without an operand."""
        return f"reshape(shape={self.shape})"