import enum
from pathlib import Path
from typing import Optional, List, Union, TextIO, Callable, Sequence
from neuralogic.core.constructs.factories import R
from neuralogic.core.constructs.relation import BaseRelation
from neuralogic.core.constructs.rule import Rule
from neuralogic.dataset import Dataset, Sample
from neuralogic.dataset.base import ConvertibleDataset
# A single entry of a logic example/query: either a ground relation or a rule.
DatasetEntries = Union[BaseRelation, Rule]
class Mode(enum.Enum):
    """Strategy for combining multiple CSV sources into dataset examples.

    ONE_EXAMPLE: merge the relations of all sources into a single example.
    EXAMPLE_PER_SOURCE: each source becomes one standalone example.
    ZIP: the i-th rows of all sources together form the i-th example.
    """

    ONE_EXAMPLE = "one"
    EXAMPLE_PER_SOURCE = "example_per_source"
    ZIP = "zip"
class CSVFile:
    """A single CSV source mapped onto logic relations named ``relation_name``.

    Every data row is turned into one ground relation whose terms are the
    row's stripped, lowercased column values. A dedicated column — or a
    default — can supply the relation's value (e.g., a label or weight).
    """

    __slots__ = (
        "relation_name",
        "csv_source",
        "sep",
        "value_column",
        "default_value",
        "value_mapper",
        "term_columns",
        "header",
        "skip_rows",
        "n_rows",
        "replace_empty_column",
    )

    def __init__(
        self,
        relation_name: str,
        csv_source: Union[TextIO, Path],
        sep: str = ",",
        value_column: Optional[Union[str, int]] = None,
        default_value: Optional[Union[float, int]] = None,
        value_mapper: Optional[Callable] = None,
        term_columns: Optional[Sequence[Union[str, int]]] = None,
        header: bool = False,
        skip_rows: int = 0,
        n_rows: Optional[int] = None,
        replace_empty_column: Union[str, float, int] = 0,
    ):
        """
        :param relation_name: name of the relation each row is mapped to
        :param csv_source: open text stream, or a path to the CSV file
        :param sep: column separator (naive split — quoted separators are
            not handled; use pre-cleaned data)
        :param value_column: column holding the relation value — a positional
            index, or a column name when ``header`` is True
        :param default_value: value assigned when ``value_column`` is None,
            or when that column is empty for a given row
        :param value_mapper: callable applied to the raw value string instead
            of the default ``float`` conversion
        :param term_columns: columns used as relation terms (positional
            indices, or names when ``header`` is True); all columns when None
        :param header: True if the first line holds column names
        :param skip_rows: number of data lines to skip (after the header)
        :param n_rows: stop after this many data lines; read all when None
        :param replace_empty_column: substitute term for empty column values
        """
        self.relation_name = relation_name
        self.csv_source = csv_source
        self.sep = sep
        self.value_column = value_column
        self.default_value = default_value
        self.value_mapper = value_mapper
        self.term_columns = term_columns
        self.header = header
        self.skip_rows = skip_rows
        self.n_rows = n_rows
        self.replace_empty_column = replace_empty_column

    @staticmethod
    def _find_index_in_header(header, value) -> int:
        """Return the index of ``value`` in ``header``.

        :raises ValueError: if ``value`` does not occur in ``header``
        """
        for index, header_value in enumerate(header):
            if value == header_value:
                return index
        raise ValueError(f"Value {value} not found in the header {header}")

    def _get_column_indices(self, header) -> Optional[List[int]]:
        """Resolve ``term_columns`` (names or indices) into positional indices.

        Integer entries are taken as positional indices directly; string
        entries are looked up in ``header``. Returns None when no column
        restriction is configured.
        """
        if self.term_columns is None:
            return None

        new_columns = []
        for col_value in self.term_columns:
            if isinstance(col_value, int):
                # Already a positional index — looking it up in a header of
                # strings would never match and always raise.
                new_columns.append(col_value)
            else:
                new_columns.append(CSVFile._find_index_in_header(header, col_value))
        return new_columns

    def _to_logic(self, fp: TextIO) -> "Sequence[DatasetEntries]":
        """Read ``fp`` and return one ground relation per data row."""
        example = []

        use_columns = self.term_columns
        value_column = self.value_column
        default_value = self.default_value
        value_mapper = self.value_mapper
        relation = R.get(self.relation_name)
        replace_empty = self.replace_empty_column
        read_lines = 0

        if self.header:
            header = fp.readline()
            if not header:
                # Empty source: no header, hence no rows.
                return example

            headers = header.strip().split(self.sep)
            # Resolve named columns against the header; integer indices and
            # None pass through unchanged.
            if value_column is not None and not isinstance(value_column, int):
                value_column = CSVFile._find_index_in_header(headers, value_column)
            use_columns = self._get_column_indices(headers)

        for _ in range(self.skip_rows):
            fp.readline()

        while True:
            line = fp.readline()
            # Stop at EOF or at the first blank line.
            if not line or not line.strip():
                break

            terms = line.strip().split(self.sep)

            if use_columns is None:
                line_relation = relation(
                    [(term.strip().lower() if len(term.strip()) else replace_empty) for term in terms]
                )
            else:
                line_relation = relation(
                    [(terms[i].strip().lower() if len(terms[i].strip()) else replace_empty) for i in use_columns]
                )

            if value_column is None:
                if default_value is not None:
                    line_relation = line_relation[float(default_value)]
            else:
                value = terms[value_column].strip()
                if not len(value):
                    # Empty value cell: fall back to the default, then to the
                    # generic empty-column replacement.
                    value = default_value if default_value is not None else replace_empty

                if value_mapper is None:
                    line_relation = line_relation[float(value)]
                else:
                    line_relation = line_relation[value_mapper(value)]

            example.append(line_relation)
            read_lines += 1

            if read_lines == self.n_rows:
                break
        return example

    def to_logic_form(self) -> "Sequence[DatasetEntries]":
        """Open the source if it is a path and convert it to logic relations."""
        if isinstance(self.csv_source, (str, Path)):
            with open(self.csv_source, "r") as fp:
                return self._to_logic(fp)
        return self._to_logic(self.csv_source)
class CSVDataset(ConvertibleDataset):
    """Dataset assembled from one or more :class:`CSVFile` sources.

    ``mode`` controls how the sources are combined into examples and
    ``csv_queries`` optionally supplies one query per example.
    """

    def __init__(
        self,
        csv_files: Union[List[CSVFile], CSVFile],
        csv_queries: Optional[CSVFile] = None,
        mode: Mode = Mode.ONE_EXAMPLE,
    ):
        """
        :param csv_files: one CSV source or a list of them
        :param csv_queries: optional CSV source producing the queries
        :param mode: how multiple sources form examples (see :class:`Mode`)
        """
        self.csv_queries = csv_queries
        # Normalize a single source into a one-element list.
        self.csv_files = [csv_files] if isinstance(csv_files, CSVFile) else csv_files
        self.mode = mode

    def add_csv_file(self, file: CSVFile):
        """Append another CSV source."""
        self.csv_files.append(file)

    def set_query_csv_file(self, file: CSVFile):
        """Set (or replace) the CSV source that yields the queries."""
        self.csv_queries = file

    def to_dataset(self) -> Dataset:
        """Materialize all CSV sources into a :class:`Dataset` of samples.

        :raises NotImplementedError: if ``mode`` is not a recognized
            :class:`Mode` member
        """
        queries = self.csv_queries.to_logic_form() if self.csv_queries else []

        if self.mode == Mode.ONE_EXAMPLE:
            # Merge all sources into one example shared by every sample.
            example: List[DatasetEntries] = []
            for source in self.csv_files:
                example.extend(source.to_logic_form())

            if not queries:
                return Dataset([Sample(None, example)])
            return Dataset([Sample(q, example) for q in queries])
        elif self.mode == Mode.ZIP:
            # The i-th entries of every source together form the i-th example.
            # zip truncates to the shortest source (and to the query list).
            logic_examples = [source.to_logic_form() for source in self.csv_files]

            if not queries:
                return Dataset([Sample(None, zipped_example) for zipped_example in zip(*logic_examples)])
            return Dataset([Sample(q, zipped_example) for q, zipped_example in zip(queries, zip(*logic_examples))])
        elif self.mode == Mode.EXAMPLE_PER_SOURCE:
            # One example per source; queries (if any) pair up positionally.
            if not queries:
                return Dataset([Sample(None, source.to_logic_form()) for source in self.csv_files])
            return Dataset([Sample(q, source.to_logic_form()) for source, q in zip(self.csv_files, queries)])
        raise NotImplementedError