Files @ ab8559c75bdb
Branch filter:

Location: NPO-Accounting/import2ledger/import2ledger/importers/_csv.py

Brett Smith
csv: Support importing squared CSV spreadsheets.

See the test comment for more rationale.
import collections
import csv

class CSVImporterBase:
    """Common base class for importing CSV files.

    Subclasses must define the following:
    * TEMPLATE_KEY: A string, as usual
    * NEEDED_FIELDS: A set of columns that must exist in the CSV file for
      this class to import it.
    * _read_row(self, row): A method that returns an entry data dict, or None
      if there's nothing to import from this row.

    Subclasses may define the following:
    * ENTRY_SEED: A dict with entry data that can be assumed when it's
      coming from this source.
    * COPIED_FIELDS: A dict that maps column names to data keys.  For every
      row where _read_row returns a dict, the values of these columns are
      added to the entry data under the mapped keys.  Values returned by
      _read_row take precedence over copied fields.
      Fields named here must exist in the CSV for it to be imported.
    * _read_header(cls, input_file): Some CSVs include "headers" with smaller
      rows before they get to the "real" data.  This classmethod is expected to
      read those rows and return two values: a dict of entry data read from
      the headers, and a list of column names for the real data (or None if
      no such row was found).  The method is expected to leave input_file at
      the position where the real data starts, so callers can run
      ``csv.DictReader(input_file, column_names)`` after.
      The default implementation reads rows until it finds one long enough to
      include all of the columns required by NEEDED_FIELDS and COPIED_FIELDS,
      then returns (header_data, that_row), where header_data merges the
      dicts returned by _read_header_row.
    * _read_header_row(cls, row): A classmethod that returns either a dict,
      or None.  The default implementation of _read_header calls this method
      on each row.  If it returns a dict, those keys and values will be
      included in the entry data returned by _read_header.  If it returns
      None, _read_header expects this is the row with column names for the
      real data, and uses it in its return value.
    * Reader: A class that accepts the input source and iterates over rows of
      formatted data.  Default csv.reader.
    * DictReader: A class that accepts the input source and iterates over rows
      of data organized into dictionaries.  Default csv.DictReader.
    """
    ENTRY_SEED = {}
    COPIED_FIELDS = {}
    Reader = csv.reader
    DictReader = csv.DictReader

    @classmethod
    def _needed_keys(cls):
        """Return the column names a CSV must have for this class to import it.

        This is NEEDED_FIELDS plus the column names (keys) of COPIED_FIELDS.
        It is computed on demand so can_import and _read_header_row work even
        when a subclass overrides _read_header.
        """
        return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)

    @classmethod
    def _row_rindex(cls, row, default=None):
        """Return the index of the last cell in the row that has a value.

        Return `default` if no cell in the row is truthy.
        """
        for offset, value in enumerate(reversed(row), 1):
            if value:
                return len(row) - offset
        return default

    @classmethod
    def _read_header_row(cls, row):
        """Classify one row read before the real data.

        Return {} (a header row contributing no data) when the row, up to its
        last non-empty cell, is shorter than the number of needed columns.
        Otherwise return None to mark it as the column-name row.
        """
        needed_len = len(cls._needed_keys())
        return {} if cls._row_rindex(row, -1) + 1 < needed_len else None

    @classmethod
    def _read_header(cls, input_file):
        """Read header rows; return (header_data, column_names).

        column_names is the first row _read_header_row returns None for, or
        None if the input ends before such a row is found.
        """
        needed_keys = cls._needed_keys()
        # Still published as class attributes for backward compatibility
        # with any code that reads them after calling _read_header.
        cls._NEEDED_KEYS = needed_keys
        cls._HEADER_MAX_LEN = len(needed_keys)
        header = {}
        for row in cls.Reader(input_file):
            row_data = cls._read_header_row(row)
            if row_data is None:
                return header, row
            header.update(row_data)
        # No row looked like column names: don't pass a short header row
        # off as the field list.
        return header, None

    @classmethod
    def can_import(cls, input_file):
        """Return True if input_file has every column this class needs.

        Return False if the csv module reports a parse error while reading
        the headers.
        """
        try:
            _, fields = cls._read_header(input_file)
        except csv.Error:
            return False
        return cls._needed_keys().issubset(fields or ())

    def __init__(self, input_file):
        self.entry_seed, fields = self._read_header(input_file)
        self.in_csv = self.DictReader(input_file, fields)

    def __iter__(self):
        """Yield a ChainMap of entry data for each importable row.

        Lookups check, in order: the dict from _read_row, the COPIED_FIELDS
        values, the data read from the headers, then ENTRY_SEED.  Rows for
        which _read_row returns None are skipped.
        """
        for row in self.in_csv:
            row_data = self._read_row(row)
            if row_data is None:
                continue
            copied_fields = {
                entry_key: row[row_key]
                for row_key, entry_key in self.COPIED_FIELDS.items()
            }
            yield collections.ChainMap(
                row_data, copied_fields, self.entry_seed, self.ENTRY_SEED)