Source code for neuralogic.nn.module.general.attention

import math
from typing import Optional

from neuralogic.core.constructs.function import Transformation, Combination, Aggregation
from neuralogic.core.constructs.factories import R
from neuralogic.nn.module.module import Module


class Attention(Module):
    r"""
    A single-head attention module based on `"Attention Is All You Need" <https://arxiv.org/abs/1706.03762>`_.

    Parameters
    ----------

    embed_dim : int
        The number of expected features.
    output_name : str
        Output (head) predicate name of the module.
    query_name : str
        The name of the queries predicate.
    key_name : str
        The name of the keys predicate.
    value_name : str
        The name of the values predicate.
    mask_name : str, optional
        The name of the input mask predicate. Default: ``None``
    arity : int
        Arity of the input and output predicates. Default: ``1``
    """

    def __init__(
        self,
        embed_dim: int,
        output_name: str,
        query_name: str,
        key_name: str,
        value_name: str,
        mask_name: Optional[str] = None,
        arity: int = 1,
    ):
        self.embed_dim = embed_dim
        self.output_name = output_name
        self.query_name = query_name
        self.key_name = key_name
        self.value_name = value_name
        self.mask_name = mask_name
        self.arity = arity

    def __call__(self):
        terms = [f"X{i}" for i in range(self.arity - 1)]
        k_terms = [*terms, "Y"]
        h_terms = [*terms, "X", "Y"]
        q_terms = [*terms, "X"]

        # Scaling factor 1 / sqrt(embed_dim), stored below as a fixed (non-learnable) value.
        d_k = 1 / math.sqrt(self.embed_dim)

        dk_rel = R.get(f"{self.output_name}__dk")
        dot_rel = R.get(f"{self.output_name}__dot")

        # Scores: softmax over Y of the scaled product K^T Q; output: sum over Y of score * V.
        metadata = [Combination.PRODUCT, Transformation.IDENTITY, Aggregation.SOFTMAX(agg_terms=["Y"])]
        out_metadata = [Combination.PRODUCT, Aggregation.SUM, Transformation.IDENTITY]

        attention_product_rules = [
            (dot_rel(h_terms) <= (dk_rel, R.get(self.key_name)(k_terms).T, R.get(self.query_name)(q_terms)))
            | metadata,
            dot_rel / (self.arity + 1) | [Transformation.IDENTITY],
        ]

        # Optional attention mask, appended to the body of the score rule.
        if self.mask_name is not None:
            attention_product_rules[0].body.append(R.hidden.get(self.mask_name)(h_terms))

        return [
            dk_rel[d_k].fixed(),
            *attention_product_rules,
            (R.get(self.output_name)(q_terms) <= (dot_rel(h_terms), R.get(self.value_name)(k_terms))) | out_metadata,
            R.get(self.output_name) / self.arity | [Transformation.IDENTITY],
        ]
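
A minimal usage sketch (not part of the module source): the module is typically expanded into template rules via ``Template.add_module``. The predicate names and ``embed_dim`` below are illustrative placeholders.

from neuralogic.core import Template
from neuralogic.nn.module.general.attention import Attention

template = Template()

# Expands into the scaled dot-product rules returned by Attention.__call__,
# reading from the (hypothetical) "query", "key" and "value" predicates.
template.add_module(
    Attention(embed_dim=8, output_name="attention", query_name="query", key_name="key", value_name="value")
)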
class MultiheadAttention(Module):
    r"""
    A multi-head attention module based on `"Attention Is All You Need" <https://arxiv.org/abs/1706.03762>`_.

    Parameters
    ----------

    embed_dim : int
        The number of expected features.
    num_heads : int
        The number of heads.
    output_name : str
        Output (head) predicate name of the module.
    query_name : str
        The name of the queries predicate.
    key_name : str
        The name of the keys predicate.
    value_name : str
        The name of the values predicate.
    vdim : int, optional
        Total number of features for values. Default: ``embed_dim``
    kdim : int, optional
        Total number of features for keys. Default: ``embed_dim``
    mask_name : str, optional
        The name of the input mask predicate. Default: ``None``
    arity : int
        Arity of the input and output predicates. Default: ``1``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_name: str,
        query_name: str,
        key_name: str,
        value_name: str,
        vdim: Optional[int] = None,
        kdim: Optional[int] = None,
        mask_name: Optional[str] = None,
        arity: int = 1,
    ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.output_name = output_name
        self.queries = query_name
        self.keys = key_name
        self.values = value_name
        self.vdim = vdim if vdim is not None else embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.mask_name = mask_name
        self.arity = arity

    def __call__(self):
        terms = [f"X{i}" for i in range(self.arity)]
        dim = self.embed_dim

        # Named learnable weights for the query/key/value projections.
        q_weight = f"{self.output_name}_qw"
        k_weight = f"{self.output_name}_kw"
        v_weight = f"{self.output_name}_vw"

        q_proj_name = f"{self.output_name}__qproj"
        k_proj_name = f"{self.output_name}__kproj"
        v_proj_name = f"{self.output_name}__vproj"

        q_proj = R.get(q_proj_name)
        k_proj = R.get(k_proj_name)
        v_proj = R.get(v_proj_name)
        output_rel = R.get(self.output_name)

        # Single-head attention applied to the projected predicates.
        attention_name = f"{self.output_name}__attention"
        attention = Attention(
            dim // self.num_heads, attention_name, q_proj_name, k_proj_name, v_proj_name, self.mask_name, self.arity
        )

        if self.num_heads != 1:
            # Per-head slice width; an extra head-index term is added to the sub-module.
            size = self.embed_dim // self.num_heads
            attention.arity += 1
            attention_concat = []

            multihead_rules = [
                q_proj / (self.arity + 1) | [Transformation.IDENTITY],
                k_proj / (self.arity + 1) | [Transformation.IDENTITY],
                v_proj / (self.arity + 1) | [Transformation.IDENTITY],
                output_rel / self.arity | [Transformation.IDENTITY],
            ]

            for i in range(self.num_heads):
                # Head i sees rows [i * size, (i + 1) * size) of the projected values.
                meta = [Transformation.SLICE(rows=(i * size, (i + 1) * size))]

                multihead_rules.append((q_proj(i, *terms) <= R.get(self.queries)(terms)[q_weight:dim, dim]) | meta)
                multihead_rules.append((v_proj(i, *terms) <= R.get(self.values)(terms)[v_weight:dim, self.vdim]) | meta)
                multihead_rules.append((k_proj(i, *terms) <= R.get(self.keys)(terms)[k_weight:dim, self.kdim]) | meta)
                attention_concat.append(R.get(attention_name)(i, *terms))

            # Concatenate the heads and apply the output projection.
            multihead_rules.append(
                (output_rel(terms)[dim, dim] <= attention_concat) | [Transformation.IDENTITY, Combination.CONCAT]
            )
        else:
            multihead_rules = [
                (q_proj(terms)[q_weight:dim, dim] <= R.get(self.queries)(terms)) | [Transformation.IDENTITY],
                q_proj / self.arity | [Transformation.IDENTITY],
                (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.values)(terms)) | [Transformation.IDENTITY],
                v_proj / self.arity | [Transformation.IDENTITY],
                (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.keys)(terms)) | [Transformation.IDENTITY],
                k_proj / self.arity | [Transformation.IDENTITY],
                (output_rel(terms)[dim, dim] <= R.get(attention_name)(terms)) | [Transformation.IDENTITY],
                output_rel / self.arity | [Transformation.IDENTITY],
            ]

        return [*attention(), *multihead_rules]
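
A similar usage sketch for the multi-head variant (again illustrative, not part of the source): with ``num_heads > 1`` the module slices the projected queries, keys and values into per-head blocks of ``embed_dim // num_heads`` features, runs ``Attention`` on each slice and concatenates the heads.

from neuralogic.core import Template
from neuralogic.nn.module.general.attention import MultiheadAttention

template = Template()

# Four heads over 16-dimensional embeddings: each head attends over a
# 16 // 4 = 4 dimensional slice; the "query"/"key"/"value" names are placeholders.
template.add_module(
    MultiheadAttention(
        embed_dim=16,
        num_heads=4,
        output_name="multihead_attention",
        query_name="query",
        key_name="key",
        value_name="value",
    )
)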