import csv
import io
from collections.abc import Callable
from contextlib import closing
from typing import Any

from neuralogic.core.constructs.relation import BaseRelation, WeightedRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset.base import ConvertibleDataset
from neuralogic.dataset.csv import CSVDataset, CSVFile, Mode
from neuralogic.dataset.logic import Dataset
# Union of the logic constructs (plain relations, weighted relations, rules) that may
# appear as dataset entries.
DatasetEntries = (
    BaseRelation
    | WeightedRelation
    | Rule
)
class DBSource:
    """
    Represents a database source (table) and its configuration for conversion to logic relations.

    The selected columns of ``table_name`` are read through a database cursor into an
    in-memory CSV buffer, which is then wrapped in a ``CSVFile`` that produces
    ``relation_name`` relations.
    """

    __slots__ = (
        "relation_name",
        "table_name",
        "term_columns",
        "value_column",
        "default_value",
        "value_mapper",
        "skip_rows",
        "n_rows",
        "replace_empty_column",
        "sep",
    )

    def __init__(
        self,
        relation_name: str,
        table_name: str,
        term_columns: list[str],
        value_column: str | None = None,
        default_value: float | int = 1.0,
        value_mapper: Callable | None = None,
        skip_rows: int = 0,
        n_rows: int | None = None,
        replace_empty_column: str | float | int = 0,
        sep: str = ",",
    ):
        """
        Parameters
        ----------
        relation_name : str
            The name of the relation to create from the database rows.
        table_name : str
            The name of the database table.
        term_columns : list[str]
            The columns to use as terms for the relation. Must not be empty.
        value_column : str, optional
            The column containing the relation value (weight). Default: None.
        default_value : float | int, optional
            The default value if not found in database. Default: 1.0.
        value_mapper : Callable, optional
            A function to map the database value to a different value. Default: None.
        skip_rows : int, optional
            The number of rows to skip at the beginning. Default: 0.
        n_rows : int, optional
            The maximum number of rows to read. Default: None (all).
        replace_empty_column : str | float | int, optional
            The value to use for empty columns. Default: 0.
        sep : str, optional
            The separator to use for intermediate CSV representation. Default: ",".

        Raises
        ------
        NotImplementedError
            If ``term_columns`` is empty.
        """
        # Fail fast, before any attribute is assigned.
        if not term_columns:
            raise NotImplementedError("Cannot create DBSource with zero terms")

        self.table_name = table_name
        self.relation_name = relation_name
        self.sep = sep
        self.value_column = value_column
        self.default_value = default_value
        self.value_mapper = value_mapper
        self.term_columns = term_columns
        self.skip_rows = skip_rows
        self.n_rows = n_rows
        self.replace_empty_column = replace_empty_column

    def to_csv(self, cursor: Any) -> CSVFile:
        """
        Converts the database source to an intermediate CSV representation.

        If ``cursor`` has a ``copy_to`` method, the table is exported with it; otherwise a
        plain ``SELECT`` is executed and the fetched rows are written with the ``csv``
        module. Both paths separate fields with ``self.sep`` and write NULL values as
        empty fields.

        Parameters
        ----------
        cursor : Any
            The database cursor to use for execution.

        Returns
        -------
        CSVFile
            The intermediate CSVFile object, positioned at the start of the buffer.
        """
        columns = list(self.term_columns)
        # The buffer holds only the selected columns, in this order, so the term columns
        # are simply positions 0..k-1 and the value column (if any) is the last one.
        term_columns = list(range(len(columns)))

        value_column = None
        if self.value_column is not None:
            columns.append(self.value_column)
            value_column = len(columns) - 1

        source = io.StringIO()
        if hasattr(cursor, "copy_to"):
            cursor.copy_to(source, self.table_name, sep=self.sep, null="", columns=columns)
        else:
            # NOTE(review): table/column identifiers are interpolated without escaping; they
            # must come from trusted configuration, never from user input.
            cursor.execute(f"SELECT {','.join(columns)} FROM {self.table_name}")
            # Write with the configured separator so the buffer agrees with the ``sep`` passed
            # to CSVFile below (csv.writer defaults to ","). None values become empty fields,
            # matching null="" on the copy_to path.
            csv_writer = csv.writer(source, delimiter=self.sep, lineterminator="\n")
            csv_writer.writerows(cursor.fetchall())

        source.seek(0)
        return CSVFile(
            self.relation_name,
            source,
            self.sep,
            value_column,
            self.default_value,
            self.value_mapper,
            term_columns,
            False,  # presumably the "has header" flag: the generated buffer has no header row
            self.skip_rows,
            self.n_rows,
            self.replace_empty_column,
        )
class DBDataset(ConvertibleDataset):
    """
    Represents a dataset composed of one or more database sources.

    Each source is exported to an intermediate CSV representation and the result is
    converted through ``CSVDataset``.
    """

    def __init__(
        self,
        connection: Any,
        db_sources: list[DBSource] | DBSource,
        queries_db_source: DBSource | None = None,
        mode: Mode = Mode.ONE_EXAMPLE,
    ):
        """
        Parameters
        ----------
        connection : Any
            The database connection object. Only its ``cursor()`` method is used.
        db_sources : list[DBSource] | DBSource
            The database source(s) containing the examples.
        queries_db_source : DBSource, optional
            The database source containing the queries. Default: None.
        mode : Mode, optional
            The mode of creating samples. Default: Mode.ONE_EXAMPLE.
        """
        self.connection = connection
        # Copy the sequence so add_db_source() never mutates a list owned by the caller
        # (and also works when a tuple is passed).
        self.db_sources = [db_sources] if isinstance(db_sources, DBSource) else list(db_sources)
        self.queries_db_source = queries_db_source
        self.mode = mode

    def add_db_source(self, db_source: DBSource) -> None:
        """
        Appends another database source containing examples.

        Parameters
        ----------
        db_source : DBSource
            The source to append.
        """
        self.db_sources.append(db_source)

    def set_queries(self, db_source: DBSource) -> None:
        """
        Sets the database source containing the queries, replacing any previous one.

        Parameters
        ----------
        db_source : DBSource
            The source of query rows.
        """
        self.queries_db_source = db_source

    def to_dataset(self) -> Dataset:
        """
        Converts the database sources to a Dataset object.

        A single cursor is opened, used to export every source, and closed before the
        CSV data is converted.

        Returns
        -------
        Dataset
            The created Dataset object.
        """
        # PEP 249 mandates cursor.close() but not context-manager support (sqlite3 cursors,
        # for instance, cannot be used in a ``with`` directly), so wrap with closing().
        with closing(self.connection.cursor()) as cur:
            csv_files = [db_source.to_csv(cur) for db_source in self.db_sources]
            csv_queries = None if self.queries_db_source is None else self.queries_db_source.to_csv(cur)

        # The CSV buffers are fully materialized in memory, so the cursor can be closed first.
        csv_dataset = CSVDataset(csv_files, csv_queries, self.mode)
        return csv_dataset.to_dataset()