Source code for neuralogic.dataset.db

import contextlib
import csv
import io
from collections.abc import Callable
from typing import Any

from neuralogic.core.constructs.relation import BaseRelation, WeightedRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset.base import ConvertibleDataset
from neuralogic.dataset.csv import CSVDataset, CSVFile, Mode
from neuralogic.dataset.logic import Dataset

# Type alias: any one of the logic constructs imported above
# (a plain relation, a weighted relation, or a rule).
DatasetEntries = BaseRelation | WeightedRelation | Rule


[docs] class DBSource: """ Represents a database source (table) and its configuration for conversion to logic relations. """ __slots__ = ( "relation_name", "table_name", "term_columns", "value_column", "default_value", "value_mapper", "skip_rows", "n_rows", "replace_empty_column", "sep", ) def __init__( self, relation_name: str, table_name: str, term_columns: list[str], value_column: str | None = None, default_value: float | int = 1.0, value_mapper: Callable | None = None, skip_rows: int = 0, n_rows: int | None = None, replace_empty_column: str | float | int = 0, sep: str = ",", ): """ Parameters ---------- relation_name : str The name of the relation to create from the database rows. table_name : str The name of the database table. term_columns : List[str] The columns to use as terms for the relation. value_column : str, optional The column containing the relation value (weight). Default: None. default_value : Union[float, int], optional The default value if not found in database. Default: 1.0. value_mapper : Callable, optional A function to map the database value to a different value. Default: None. skip_rows : int, optional The number of rows to skip at the beginning. Default: 0. n_rows : int, optional The maximum number of rows to read. Default: None (all). replace_empty_column : Union[str, float, int], optional The value to use for empty columns. Default: 0. sep : str, optional The separator to use for intermediate CSV representation. Default: ",". """ self.table_name = table_name self.relation_name = relation_name self.sep = sep self.value_column = value_column self.default_value = default_value self.value_mapper = value_mapper self.term_columns = term_columns self.skip_rows = skip_rows self.n_rows = n_rows self.replace_empty_column = replace_empty_column if len(term_columns) == 0: raise NotImplementedError("Cannot create DBSource with zero terms")
[docs] def to_csv(self, cursor: Any) -> CSVFile: """ Converts the database source to an intermediate CSV representation. Parameters ---------- cursor : Any The database cursor to use for execution. Returns ------- CSVFile The intermediate CSVFile object. """ source = io.StringIO() columns = [term for term in self.term_columns] term_columns = list(range(len(columns))) value_column = None if self.value_column is not None: columns.append(self.value_column) value_column = len(columns) - 1 if hasattr(cursor, "copy_to"): cursor.copy_to(source, self.table_name, sep=self.sep, null="", columns=columns) else: cursor.execute(f"SELECT {','.join(columns)} FROM {self.table_name}") results = cursor.fetchall() csv_writer = csv.writer(source, lineterminator="\n") csv_writer.writerows(results) source.seek(0) return CSVFile( self.relation_name, source, self.sep, value_column, self.default_value, self.value_mapper, term_columns, False, self.skip_rows, self.n_rows, self.replace_empty_column, )
class DBDataset(ConvertibleDataset):
    """
    Represents a dataset composed of one or more database sources.
    """

    def __init__(
        self,
        connection: Any,
        db_sources: list[DBSource] | DBSource,
        queries_db_source: DBSource | None = None,
        mode: Mode = Mode.ONE_EXAMPLE,
    ):
        """
        Parameters
        ----------

        connection : Any
            The database connection object. Only ``connection.cursor()`` is used, and the
            returned cursor only needs to provide ``close()``.
        db_sources : Union[List[DBSource], DBSource]
            The database source(s) containing the examples. A single source is wrapped
            in a list; a list is stored as given (not copied).
        queries_db_source : DBSource, optional
            The database source containing the queries. Default: None.
        mode : Mode, optional
            The mode of creating samples. Default: Mode.ONE_EXAMPLE.
        """
        self.connection = connection
        self.db_sources = [db_sources] if isinstance(db_sources, DBSource) else db_sources
        self.queries_db_source = queries_db_source
        self.mode = mode

    def add_db_source(self, db_source: DBSource) -> None:
        """
        Appends another example source to the dataset.

        Parameters
        ----------

        db_source : DBSource
            The database source to add.
        """
        self.db_sources.append(db_source)

    def set_queries(self, db_source: DBSource) -> None:
        """
        Sets (or replaces) the source the queries are read from.

        Parameters
        ----------

        db_source : DBSource
            The database source containing the queries.
        """
        self.queries_db_source = db_source

    def to_dataset(self) -> Dataset:
        """
        Converts the database sources to a Dataset object.

        Every source is first dumped to an intermediate in-memory CSV using a single
        cursor; the CSV files are then converted through ``CSVDataset``.

        Returns
        -------
        Dataset
            The created Dataset object.
        """
        # DB-API 2.0 only guarantees cursor.close(); not every driver's cursor is a
        # context manager (e.g. sqlite3), so close it explicitly via closing().
        with contextlib.closing(self.connection.cursor()) as cur:
            csv_files = [db_source.to_csv(cur) for db_source in self.db_sources]

            csv_queries = None
            if self.queries_db_source is not None:
                csv_queries = self.queries_db_source.to_csv(cur)

        # All rows are already buffered in memory, so the cursor is no longer needed.
        csv_dataset = CSVDataset(csv_files, csv_queries, self.mode)
        return csv_dataset.to_dataset()