import collections
import csv
class CSVImporterBase:
    """Common base class for importing CSV files.

    Subclasses must define the following:

    * TEMPLATE_KEY: A string, as usual
    * NEEDED_FIELDS: A set of columns that must exist in the CSV file for
      this class to import it.
    * _read_row(self, row): A method that returns an entry data dict, or None
      if there's nothing to import from this row.

    Subclasses may define the following:

    * ENTRY_SEED: A dict with entry data that can be assumed when it's
      coming from this source.
    * COPIED_FIELDS: A dict that maps column names to data keys. These fields
      will be copied directly to the entry data dict before _read_row is
      called. Fields named here must exist in the CSV for it to be imported.
    * _read_header(cls, input_file): Some CSVs include "headers" with smaller
      rows before they get to the "real" data. This classmethod is expected to
      read those rows and return two values: a dict of entry data read from
      the headers, and a list of column names for the real data. The method
      is expected to leave input_file at the position where the real data
      starts, so callers can run ``csv.DictReader(input_file, column_names)``
      after.
      The default implementation reads rows until it finds one long enough to
      include all of the columns required by NEEDED_FIELDS and COPIED_FIELDS,
      then returns ({}, that_row).
    * _read_header_row(cls, row): A classmethod that returns either a dict,
      or None. The default implementation of _read_header calls this method
      on each row. If it returns a dict, those keys and values will be
      included in the entry data returned by _read_header. If it returns
      None, _read_header expects this is the row with column names for the
      real data, and uses it in its return value.
    * Reader: A class that accepts the input source and iterates over rows of
      formatted data. Default csv.reader.
    * DictReader: A class that accepts the input source and iterates over rows
      of data organized into dictionaries. Default csv.DictReader.
    """
    # Entry data assumed for every row from this source.
    ENTRY_SEED = {}
    # Maps CSV column names -> entry data keys copied verbatim into each entry.
    COPIED_FIELDS = {}
    # Row readers; override to handle other dialects or sources.
    Reader = csv.reader
    DictReader = csv.DictReader

    @classmethod
    def _needed_keys(cls):
        """Return the set of column names this importer requires."""
        return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)

    @classmethod
    def _read_header_row(cls, row):
        """Treat rows shorter than the required column count as header rows.

        Returns {} (no entry data) for a header row, or None for the first
        row long enough to hold all required columns, signalling to
        _read_header that this row names the real data's columns.
        """
        return {} if len(row) < cls._HEADER_MAX_LEN else None

    @classmethod
    def _read_header(cls, input_file):
        """Read leading header rows; return (header entry data, column names).

        Leaves input_file positioned at the start of the real data so the
        caller can wrap it in cls.DictReader with the returned column names.
        If the input is empty, returns ({}, None).
        """
        # Cached on the class for _read_header_row, and kept for backward
        # compatibility with subclasses that read these attributes.
        cls._NEEDED_KEYS = cls._needed_keys()
        cls._HEADER_MAX_LEN = len(cls._NEEDED_KEYS)
        header = {}
        row = None
        for row in cls.Reader(input_file):
            row_data = cls._read_header_row(row)
            if row_data is None:
                # This row names the real data's columns; stop here.
                break
            header.update(row_data)
        return header, row

    @classmethod
    def can_import(cls, input_file):
        """Return True if input_file has every column this importer needs.

        Consumes input_file up to the real data; callers should seek back
        (or reopen) before constructing the importer.
        """
        try:
            _, fields = cls._read_header(input_file)
        except csv.Error:
            return False
        # Compute the required columns directly instead of relying on
        # _read_header having set cls._NEEDED_KEYS: subclasses may override
        # _read_header entirely (as documented), which previously left
        # _NEEDED_KEYS unset and made can_import raise AttributeError.
        return cls._needed_keys().issubset(fields or ())

    def __init__(self, input_file):
        # Header rows contribute entry data; the remaining rows are the
        # real data, keyed by the column names the header ended with.
        self.entry_seed, fields = self._read_header(input_file)
        self.in_csv = self.DictReader(input_file, fields)

    def __iter__(self):
        """Yield one entry-data mapping per importable row.

        Each yielded value is a ChainMap layering (highest priority first):
        _read_row's result, directly copied fields, entry data read from the
        file's header rows, and the class-level ENTRY_SEED.
        """
        for row in self.in_csv:
            row_data = self._read_row(row)
            if row_data is None:
                # Nothing to import from this row.
                continue
            copied_fields = {
                entry_key: row[row_key]
                for row_key, entry_key in self.COPIED_FIELDS.items()
            }
            yield collections.ChainMap(
                row_data, copied_fields, self.entry_seed, self.ENTRY_SEED)