import collections
import csv


class CSVImporterBase:
    """Shared machinery for CSV-based importers.

    A concrete subclass must provide:

    * TEMPLATE_KEY: the usual template key string.
    * NEEDED_FIELDS: a set of column names that must appear in the CSV
      before this class will agree to import it.
    * _read_row(self, row): given one row dict, return a dict of entry
      data, or None when the row contributes nothing.

    A subclass may additionally provide:

    * ENTRY_SEED: entry data that can be assumed for every row coming
      from this source.
    * COPIED_FIELDS: a mapping of column name -> entry-data key.  Values
      from these columns are copied verbatim into the entry data before
      _read_row runs.  Columns named here must also be present in the
      CSV for the file to be importable.
    * _read_header(cls, input_file): some CSVs begin with short
      "header" rows before the real data.  This classmethod consumes
      those rows and returns a pair: entry data gathered from the
      headers, and the list of column names for the real data.  It must
      leave input_file positioned at the start of the real data so the
      caller can run ``csv.DictReader(input_file, column_names)``.
      The default implementation skips rows until it finds one wide
      enough to hold every column required by NEEDED_FIELDS and
      COPIED_FIELDS, and returns ({}, that_row).
    * _read_header_row(cls, row): classmethod returning a dict or None.
      The default _read_header calls it per row; a dict is merged into
      the header entry data, while None signals that this row holds the
      real data's column names.
    * Reader: class wrapping the input source that yields rows of
      formatted data.  Defaults to csv.reader.
    * DictReader: class wrapping the input source that yields rows as
      dictionaries.  Defaults to csv.DictReader.
    """
    # Defaults; subclasses override as described in the class docstring.
    ENTRY_SEED = {}
    COPIED_FIELDS = {}
    Reader = csv.reader
    DictReader = csv.DictReader

    @classmethod
    def _read_header_row(cls, row):
        # A row wide enough to hold every required column is assumed to
        # be the column-name row; anything shorter is header chatter
        # that contributes no entry data.
        if len(row) >= cls._HEADER_MAX_LEN:
            return None
        return {}

    @classmethod
    def _read_header(cls, input_file):
        """Consume header rows; return (header entry data, column-name row).

        As a side effect, caches _NEEDED_KEYS and _HEADER_MAX_LEN on the
        class — can_import and the default _read_header_row depend on
        these being set here.
        """
        cls._NEEDED_KEYS = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
        cls._HEADER_MAX_LEN = len(cls._NEEDED_KEYS)
        entry_data = {}
        last_row = None
        for last_row in cls.Reader(input_file):
            parsed = cls._read_header_row(last_row)
            if parsed is None:
                # This row names the real data's columns; stop here so
                # the caller can hand the rest of the stream to DictReader.
                break
            entry_data.update(parsed)
        return entry_data, last_row

    @classmethod
    def can_import(cls, input_file):
        """Return True if input_file's columns satisfy this importer."""
        try:
            fields = cls._read_header(input_file)[1]
        except csv.Error:
            return False
        # fields is None when the file had no rows at all.
        return cls._NEEDED_KEYS.issubset(fields or ())

    def __init__(self, input_file):
        # entry_seed holds per-file data gleaned from the header rows;
        # the DictReader picks up right where _read_header stopped.
        self.entry_seed, column_names = self._read_header(input_file)
        self.in_csv = self.DictReader(input_file, column_names)

    def __iter__(self):
        """Yield one entry-data mapping per importable row.

        Each yielded ChainMap layers, highest priority first: the
        subclass's _read_row result, values copied via COPIED_FIELDS,
        per-file header data, then the class-wide ENTRY_SEED.
        """
        for row in self.in_csv:
            parsed = self._read_row(row)
            if parsed is None:
                continue
            direct = {
                entry_key: row[row_key]
                for row_key, entry_key in self.COPIED_FIELDS.items()
            }
            yield collections.ChainMap(
                parsed, direct, self.entry_seed, self.ENTRY_SEED)