Source code for neuralogic.core.constructs.function.softmax

from collections.abc import Sequence
from typing import Any

import jpype

from neuralogic.core.constructs.function.function import AggregationFunction


class SoftmaxAggregation(AggregationFunction):
    """
    Represents a Softmax aggregation function.

    It can be parametrized by specific terms (variables) to aggregate over.
    """

    # ``term_indices`` is assigned both in ``__init__`` and in ``process_head``,
    # so it has to be declared as a slot; the previously declared ``var_terms``
    # was never read or written anywhere in this class.
    __slots__ = ("agg_terms", "term_indices")

    def __init__(
        self,
        name: str,
        *,
        agg_terms: Sequence[str] | None = None,
    ):
        """
        Parameters
        ----------
        name : str
            The name of the aggregation function.
        agg_terms : Sequence[str], optional
            The terms (variables) to aggregate over. Default: None.
        """
        super().__init__(name)
        # Until ``process_head`` resolves the terms against a concrete rule head,
        # ``term_indices`` simply mirrors the raw ``agg_terms`` (names or None).
        self.term_indices = agg_terms
        self.agg_terms = agg_terms

    def __call__(self, *, agg_terms: Sequence[str] | None = None) -> Any:
        """
        Creates a new SoftmaxAggregation instance with the provided aggregation terms.

        Parameters
        ----------
        agg_terms : Sequence[str], optional
            The names of the variable terms to aggregate over. Default: None.

        Returns
        -------
        AggregationFunction
            The new SoftmaxAggregation instance.
        """
        softmax = SoftmaxAggregation(self.name, agg_terms=agg_terms)
        return AggregationFunction.__call__(softmax)

    def is_parametrized(self) -> bool:
        """Return True when aggregation terms were provided."""
        return self.agg_terms is not None

    def get(self) -> Any:
        """
        Instantiate the Java ``Softmax`` combination function.

        The constructor receives ``self.term_indices``: the head positions computed by
        ``process_head`` if it has been applied, otherwise the raw ``agg_terms`` value
        (possibly None).
        """
        return jpype.JClass("cz.cvut.fel.ida.algebra.functions.combination.Softmax")(self.term_indices)

    def __str__(self) -> str:
        if self.agg_terms is None:
            return "softmax"
        # ``map(str, ...)`` keeps the representation from crashing on non-string terms.
        return f"softmax(agg_terms=[{', '.join(map(str, self.agg_terms))}])"

    def rule_head_dependant(self) -> bool:
        """Return True when the function needs ``process_head`` (i.e. terms were provided)."""
        return self.agg_terms is not None

    def process_head(self, head) -> "SoftmaxAggregation":
        """
        Processes the rule head to determine the indices of the aggregation terms.

        Parameters
        ----------
        head : Any
            The rule head; its ``terms`` attribute is searched for each aggregation term.

        Returns
        -------
        SoftmaxAggregation
            A new instance with the determined term indices.

        Raises
        ------
        TypeError
            If an aggregation term is not a variable (a string starting with an
            uppercase letter).
        """
        term_indices = []
        # ``dict.fromkeys`` de-duplicates while keeping the caller's order; iterating
        # a ``set`` of strings yields a hash-seed-dependent order between runs.
        for agg_term in dict.fromkeys(self.agg_terms):
            # ``[:1]`` avoids an IndexError on "", and the isinstance check turns
            # non-string terms into the intended TypeError instead of an opaque one.
            if not isinstance(agg_term, str) or not agg_term[:1].isupper():
                raise TypeError(f"Softmax aggregable terms can be only variables. Provided: {agg_term}")
            # Only the first matching head position is recorded.
            # NOTE(review): a term absent from the head is silently skipped — confirm intended.
            for i, term in enumerate(head.terms):
                if agg_term == term:
                    term_indices.append(i)
                    break
        aggregation = SoftmaxAggregation(self.name, agg_terms=self.agg_terms)
        aggregation.term_indices = term_indices
        return aggregation