Changeset - 6956c78b0de1
[Not reviewed]
registrasion/controllers/cart.py
Show inline comments
 
import collections
 
import contextlib
 
import datetime
 
import discount
 
import functools
 
import itertools
 

	
 
from django.core.exceptions import ObjectDoesNotExist
 
from django.core.exceptions import ValidationError
 
from django.db import transaction
 
from django.db.models import Max
 
from django.db.models import Q
 
from django.utils import timezone
 

	
 
from registrasion.exceptions import CartValidationError
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 

	
 
from category import CategoryController
 
from conditions import ConditionController
 
from product import ProductController
 
from .category import CategoryController
 
from .discount import DiscountController
 
from .flag import FlagController
 
from .product import ProductController
 

	
 

	
 
def _modifies_cart(func):
 
    ''' Decorator that makes the wrapped function raise ValidationError
 
    if we're doing something that could modify the cart. '''
 
    if we're doing something that could modify the cart.
 

	
 
    It also wraps the execution of this function in a database transaction,
 
    and marks the boundaries of a cart operations batch.
 
    '''
 

	
 
    @functools.wraps(func)
 
    def inner(self, *a, **k):
 
        self._fail_if_cart_is_not_active()
 
        return func(self, *a, **k)
 
        with transaction.atomic():
 
            with CartController.operations_batch(self.cart.user) as mark:
 
                mark.mark = True  # Marker that we've modified the cart
 
                return func(self, *a, **k)
 

	
 
    return inner
 

	
 

	
 
class CartController(object):
 

	
 
    def __init__(self, cart):
 
        self.cart = cart
 

	
 
    @classmethod
 
    def for_user(cls, user):
 
        ''' Returns the user's current cart, or creates a new cart
...
 
@@ -46,155 +55,222 @@ class CartController(object):
 
            existing = commerce.Cart.objects.get(
 
                user=user,
 
                status=commerce.Cart.STATUS_ACTIVE,
 
            )
 
        except ObjectDoesNotExist:
 
            existing = commerce.Cart.objects.create(
 
                user=user,
 
                time_last_updated=timezone.now(),
 
                reservation_duration=datetime.timedelta(),
 
            )
 
        return cls(existing)
 

	
 
    # Marks the carts that are currently in batches
 
    _FOR_USER = {}
 
    _BATCH_COUNT = collections.defaultdict(int)
 
    _MODIFIED_CARTS = set()
 

	
 
    class _ModificationMarker(object):
 
        pass
 

	
 
    @classmethod
 
    @contextlib.contextmanager
 
    def operations_batch(cls, user):
 
        ''' Marks the boundary for a batch of operations on a user's cart.
 

	
 
        These markers can be nested. Only on exiting the outermost marker will
 
        a batch be ended.
 

	
 
        When a batch is ended, discounts are recalculated, and the cart's
 
        revision is increased.
 
        '''
 

	
 
        if user not in cls._FOR_USER:
 
            _ctrl = cls.for_user(user)
 
            cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
 

	
 
        ctrl, _id = cls._FOR_USER[user]
 

	
 
        cls._BATCH_COUNT[_id] += 1
 
        try:
 
            success = False
 

	
 
            marker = cls._ModificationMarker()
 
            yield marker
 

	
 
            if hasattr(marker, "mark"):
 
                cls._MODIFIED_CARTS.add(_id)
 

	
 
            success = True
 
        finally:
 

	
 
            cls._BATCH_COUNT[_id] -= 1
 

	
 
            # Only end on the outermost batch marker, and only if
 
            # it excited cleanly, and a modification occurred
 
            modified = _id in cls._MODIFIED_CARTS
 
            outermost = cls._BATCH_COUNT[_id] == 0
 
            if modified and outermost and success:
 
                ctrl._end_batch()
 
                cls._MODIFIED_CARTS.remove(_id)
 

	
 
            # Clear out the cache on the outermost operation
 
            if outermost:
 
                del cls._FOR_USER[user]
 

	
 
    def _fail_if_cart_is_not_active(self):
 
        self.cart.refresh_from_db()
 
        if self.cart.status != commerce.Cart.STATUS_ACTIVE:
 
            raise ValidationError("You can only amend active carts.")
 

	
 
    @_modifies_cart
 
    def extend_reservation(self):
 
    def _autoextend_reservation(self):
 
        ''' Updates the cart's time last updated value, which is used to
 
        determine whether the cart has reserved the items and discounts it
 
        holds. '''
 

	
 
        reservations = [datetime.timedelta()]
 

	
 
        # If we have vouchers, we're entitled to an hour at minimum.
 
        if len(self.cart.vouchers.all()) >= 1:
 
            reservations.append(inventory.Voucher.RESERVATION_DURATION)
 

	
 
        # Else, it's the maximum of the included products
 
        items = commerce.ProductItem.objects.filter(cart=self.cart)
 
        agg = items.aggregate(Max("product__reservation_duration"))
 
        product_max = agg["product__reservation_duration__max"]
 

	
 
        if product_max is not None:
 
            reservations.append(product_max)
 

	
 
        self.cart.time_last_updated = timezone.now()
 
        self.cart.reservation_duration = max(reservations)
 

	
 
    @_modifies_cart
 
    def end_batch(self):
 
    def _end_batch(self):
 
        ''' Performs operations that occur occur at the end of a batch of
 
        product changes/voucher applications etc.
 
        THIS SHOULD BE PRIVATE
 

	
 
        You need to call this after you've finished modifying the user's cart.
 
        This is normally done by wrapping a block of code using
 
        ``operations_batch``.
 

	
 
        '''
 

	
 
        self.recalculate_discounts()
 
        self.cart.refresh_from_db()
 

	
 
        self._recalculate_discounts()
 

	
 
        self.extend_reservation()
 
        self._autoextend_reservation()
 
        self.cart.revision += 1
 
        self.cart.save()
 

	
 
    @_modifies_cart
 
    @transaction.atomic
 
    def set_quantities(self, product_quantities):
 
        ''' Sets the quantities on each of the products on each of the
 
        products specified. Raises an exception (ValidationError) if a limit
 
        is violated. `product_quantities` is an iterable of (product, quantity)
 
        pairs. '''
 

	
 
        items_in_cart = commerce.ProductItem.objects.filter(cart=self.cart)
 
        items_in_cart = items_in_cart.select_related(
 
            "product",
 
            "product__category",
 
        )
 

	
 
        product_quantities = list(product_quantities)
 

	
 
        # n.b need to add have the existing items first so that the new
 
        # items override the old ones.
 
        all_product_quantities = dict(itertools.chain(
 
            ((i.product, i.quantity) for i in items_in_cart.all()),
 
            product_quantities,
 
        )).items()
 

	
 
        # Validate that the limits we're adding are OK
 
        self._test_limits(all_product_quantities)
 

	
 
        new_items = []
 
        products = []
 
        for product, quantity in product_quantities:
 
            try:
 
                product_item = commerce.ProductItem.objects.get(
 
                    cart=self.cart,
 
                    product=product,
 
                )
 
                product_item.quantity = quantity
 
                product_item.save()
 
            except ObjectDoesNotExist:
 
                commerce.ProductItem.objects.create(
 
                    cart=self.cart,
 
                    product=product,
 
                    quantity=quantity,
 
                )
 
            products.append(product)
 

	
 
        items_in_cart.filter(quantity=0).delete()
 
            if quantity == 0:
 
                continue
 

	
 
        self.end_batch()
 
            item = commerce.ProductItem(
 
                cart=self.cart,
 
                product=product,
 
                quantity=quantity,
 
            )
 
            new_items.append(item)
 

	
 
        to_delete = (
 
            Q(quantity=0) |
 
            Q(product__in=products)
 
        )
 

	
 
        items_in_cart.filter(to_delete).delete()
 
        commerce.ProductItem.objects.bulk_create(new_items)
 

	
 
    def _test_limits(self, product_quantities):
 
        ''' Tests that the quantity changes we intend to make do not violate
 
        the limits and flag conditions imposed on the products. '''
 

	
 
        errors = []
 

	
 
        # Pre-annotate products
 
        products = [p for (p, q) in product_quantities]
 
        r = ProductController.attach_user_remainders(self.cart.user, products)
 
        with_remainders = dict((p, p) for p in r)
 

	
 
        # Test each product limit here
 
        for product, quantity in product_quantities:
 
            if quantity < 0:
 
                errors.append((product, "Value must be zero or greater."))
 

	
 
            prod = ProductController(product)
 
            limit = prod.user_quantity_remaining(self.cart.user)
 
            limit = with_remainders[product].remainder
 

	
 
            if quantity > limit:
 
                errors.append((
 
                    product,
 
                    "You may only have %d of product: %s" % (
 
                        limit, product,
 
                    )
 
                ))
 

	
 
        # Collect by category
 
        by_cat = collections.defaultdict(list)
 
        for product, quantity in product_quantities:
 
            by_cat[product.category].append((product, quantity))
 

	
 
        # Pre-annotate categories
 
        r = CategoryController.attach_user_remainders(self.cart.user, by_cat)
 
        with_remainders = dict((cat, cat) for cat in r)
 

	
 
        # Test each category limit here
 
        for category in by_cat:
 
            ctrl = CategoryController(category)
 
            limit = ctrl.user_quantity_remaining(self.cart.user)
 
            limit = with_remainders[category].remainder
 

	
 
            # Get the amount so far in the cart
 
            to_add = sum(i[1] for i in by_cat[category])
 

	
 
            if to_add > limit:
 
                errors.append((
 
                    category,
 
                    "You may only have %d items in category: %s" % (
 
                        limit, category.name,
 
                    )
 
                ))
 

	
 
        # Test the flag conditions
 
        errs = ConditionController.test_flags(
 
        errs = FlagController.test_flags(
 
            self.cart.user,
 
            product_quantities=product_quantities,
 
        )
 

	
 
        if errs:
 
            for error in errs:
 
                errors.append(error)
 

	
 
        if errors:
 
            raise CartValidationError(errors)
 

	
 
    @_modifies_cart
...
 
@@ -203,25 +279,24 @@ class CartController(object):
 

	
 
        # Try and find the voucher
 
        voucher = inventory.Voucher.objects.get(code=voucher_code.upper())
 

	
 
        # Re-applying vouchers should be idempotent
 
        if voucher in self.cart.vouchers.all():
 
            return
 

	
 
        self._test_voucher(voucher)
 

	
 
        # If successful...
 
        self.cart.vouchers.add(voucher)
 
        self.end_batch()
 

	
 
    def _test_voucher(self, voucher):
 
        ''' Tests whether this voucher is allowed to be applied to this cart.
 
        Raises ValidationError if not. '''
 

	
 
        # Is voucher exhausted?
 
        active_carts = commerce.Cart.reserved_carts()
 

	
 
        # It's invalid for a user to enter a voucher that's exhausted
 
        carts_with_voucher = active_carts.filter(vouchers=voucher)
 
        carts_with_voucher = carts_with_voucher.exclude(pk=self.cart.id)
 
        if carts_with_voucher.count() >= voucher.limit:
...
 
@@ -285,59 +360,64 @@ class CartController(object):
 
        this is normally called before generating or paying an invoice '''
 

	
 
        cart = self.cart
 
        user = self.cart.user
 
        errors = []
 

	
 
        try:
 
            self._test_vouchers(self.cart.vouchers.all())
 
        except ValidationError as ve:
 
            errors.append(ve)
 

	
 
        items = commerce.ProductItem.objects.filter(cart=cart)
 
        items = items.select_related("product", "product__category")
 

	
 
        product_quantities = list((i.product, i.quantity) for i in items)
 
        try:
 
            self._test_limits(product_quantities)
 
        except ValidationError as ve:
 
            self._append_errors(errors, ve)
 

	
 
        try:
 
            self._test_required_categories()
 
        except ValidationError as ve:
 
            self._append_errors(errors, ve)
 

	
 
        # Validate the discounts
 
        discount_items = commerce.DiscountItem.objects.filter(cart=cart)
 
        seen_discounts = set()
 
        # TODO: refactor in terms of available_discounts
 
        # why aren't we doing that here?!
 

	
 
        #     def available_discounts(cls, user, categories, products):
 

	
 
        products = [i.product for i in items]
 
        discounts_with_quantity = DiscountController.available_discounts(
 
            user,
 
            [],
 
            products,
 
        )
 
        discounts = set(i.discount.id for i in discounts_with_quantity)
 

	
 
        discount_items = commerce.DiscountItem.objects.filter(cart=cart)
 
        for discount_item in discount_items:
 
            discount = discount_item.discount
 
            if discount in seen_discounts:
 
                continue
 
            seen_discounts.add(discount)
 
            real_discount = conditions.DiscountBase.objects.get_subclass(
 
                pk=discount.pk)
 
            cond = ConditionController.for_condition(real_discount)
 

	
 
            if not cond.is_met(user):
 
            if discount.id not in discounts:
 
                errors.append(
 
                    ValidationError("Discounts are no longer available")
 
                )
 

	
 
        if errors:
 
            raise ValidationError(errors)
 

	
 
    @_modifies_cart
 
    @transaction.atomic
 
    def fix_simple_errors(self):
 
        ''' This attempts to fix the easy errors raised by ValidationError.
 
        This includes removing items from the cart that are no longer
 
        available, recalculating all of the discounts, and removing voucher
 
        codes that are no longer available. '''
 

	
 
        # Fix vouchers first (this affects available discounts)
 
        to_remove = []
 
        for voucher in self.cart.vouchers.all():
 
            try:
 
                self._test_voucher(voucher)
 
            except ValidationError:
...
 
@@ -351,39 +431,41 @@ class CartController(object):
 
        items = items.select_related("product")
 
        products = set(i.product for i in items)
 
        available = set(ProductController.available_products(
 
            self.cart.user,
 
            products=products,
 
        ))
 

	
 
        not_available = products - available
 
        zeros = [(product, 0) for product in not_available]
 

	
 
        self.set_quantities(zeros)
 

	
 
    @_modifies_cart
 
    @transaction.atomic
 
    def recalculate_discounts(self):
 
        ''' Calculates all of the discounts available for this product.
 
        '''
 
    def _recalculate_discounts(self):
 
        ''' Calculates all of the discounts available for this product.'''
 

	
 
        # Delete the existing entries.
 
        commerce.DiscountItem.objects.filter(cart=self.cart).delete()
 

	
 
        product_items = self.cart.productitem_set.all().select_related(
 
            "product", "product__category", "product__price"
 
        )
 

	
 
        products = [i.product for i in product_items]
 
        discounts = discount.available_discounts(self.cart.user, [], products)
 
        discounts = DiscountController.available_discounts(
 
            self.cart.user,
 
            [],
 
            products,
 
        )
 

	
 
        # The highest-value discounts will apply to the highest-value
 
        # products first.
 
        product_items = reversed(product_items)
 
        for item in product_items:
 
            self._add_discount(item.product, item.quantity, discounts)
 

	
 
    def _add_discount(self, product, quantity, discounts):
 
        ''' Applies the best discounts on the given product, from the given
 
        discounts.'''
 

	
 
        def matches(discount):
registrasion/controllers/category.py
Show inline comments
 
from registrasion.models import commerce
 
from registrasion.models import inventory
 

	
 
from django.db.models import Case
 
from django.db.models import F, Q
 
from django.db.models import Sum
 
from django.db.models import When
 
from django.db.models import Value
 

	
 

	
 
class AllProducts(object):
 
    pass
 

	
 

	
 
class CategoryController(object):
 

	
 
    def __init__(self, category):
 
        self.category = category
 

	
 
    @classmethod
...
 
@@ -25,34 +29,56 @@ class CategoryController(object):
 
        if products is AllProducts:
 
            products = inventory.Product.objects.all().select_related(
 
                "category",
 
            )
 

	
 
        available = ProductController.available_products(
 
            user,
 
            products=products,
 
        )
 

	
 
        return set(i.category for i in available)
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' Returns the number of items from this category that the user may
 
        add in the current cart. '''
 
    @classmethod
 
    def attach_user_remainders(cls, user, categories):
 
        '''
 

	
 
        Return:
 
            queryset(inventory.Product): A queryset containing items from
 
            ``categories``, with an extra attribute -- remainder = the amount
 
            of items from this category that is remaining.
 
        '''
 

	
 
        cat_limit = self.category.limit_per_user
 
        ids = [category.id for category in categories]
 
        categories = inventory.Category.objects.filter(id__in=ids)
 

	
 
        if cat_limit is None:
 
            # We don't need to waste the following queries
 
            return 99999999
 
        cart_filter = (
 
            Q(product__productitem__cart__user=user) &
 
            Q(product__productitem__cart__status=commerce.Cart.STATUS_PAID)
 
        )
 

	
 
        quantity = When(
 
            cart_filter,
 
            then='product__productitem__quantity'
 
        )
 

	
 
        carts = commerce.Cart.objects.filter(
 
            user=user,
 
            status=commerce.Cart.STATUS_PAID,
 
        quantity_or_zero = Case(
 
            quantity,
 
            default=Value(0),
 
        )
 

	
 
        items = commerce.ProductItem.objects.filter(
 
            cart__in=carts,
 
            product__category=self.category,
 
        remainder = Case(
 
            When(limit_per_user=None, then=Value(99999999)),
 
            default=F('limit_per_user') - Sum(quantity_or_zero),
 
        )
 

	
 
        cat_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
 
        return cat_limit - cat_count
 
        categories = categories.annotate(remainder=remainder)
 

	
 
        return categories
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' Returns the quantity of this product that the user add in the
 
        current cart. '''
 

	
 
        with_remainders = self.attach_user_remainders(user, [self.category])
 

	
 
        return with_remainders[0].remainder
registrasion/controllers/conditions.py
Show inline comments
 
import itertools
 
import operator
 

	
 
from collections import defaultdict
 
from collections import namedtuple
 

	
 
from django.db.models import Case
 
from django.db.models import F, Q
 
from django.db.models import Sum
 
from django.db.models import Value
 
from django.db.models import When
 
from django.utils import timezone
 

	
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 

	
 

	
 
ConditionAndRemainder = namedtuple(
 
    "ConditionAndRemainder",
 
    (
 
        "condition",
 
        "remainder",
 
    ),
 
)
 
_BIG_QUANTITY = 99999999  # A big quantity
 

	
 

	
 
class ConditionController(object):
 
    ''' Base class for testing conditions that activate Flag
 
    or Discount objects. '''
 

	
 
    def __init__(self):
 
        pass
 
    def __init__(self, condition):
 
        self.condition = condition
 

	
 
    @staticmethod
 
    def for_condition(condition):
 
        CONTROLLERS = {
 
    def _controllers():
 
        return {
 
            conditions.CategoryFlag: CategoryConditionController,
 
            conditions.IncludedProductDiscount: ProductConditionController,
 
            conditions.ProductFlag: ProductConditionController,
 
            conditions.TimeOrStockLimitDiscount:
 
                TimeOrStockLimitDiscountController,
 
            conditions.TimeOrStockLimitFlag:
 
                TimeOrStockLimitFlagController,
 
            conditions.VoucherDiscount: VoucherConditionController,
 
            conditions.VoucherFlag: VoucherConditionController,
 
        }
 

	
 
    @staticmethod
 
    def for_type(cls):
 
        return ConditionController._controllers()[cls]
 

	
 
    @staticmethod
 
    def for_condition(condition):
 
        try:
 
            return CONTROLLERS[type(condition)](condition)
 
            return ConditionController.for_type(type(condition))(condition)
 
        except KeyError:
 
            return ConditionController()
 

	
 
    SINGLE = True
 
    PLURAL = False
 
    NONE = True
 
    SOME = False
 
    MESSAGE = {
 
        NONE: {
 
            SINGLE:
 
                "%(items)s is no longer available to you",
 
            PLURAL:
 
                "%(items)s are no longer available to you",
 
        },
 
        SOME: {
 
            SINGLE:
 
                "Only %(remainder)d of the following item remains: %(items)s",
 
            PLURAL:
 
                "Only %(remainder)d of the following items remain: %(items)s"
 
        },
 
    }
 

	
 
    @classmethod
 
    def test_flags(
 
            cls, user, products=None, product_quantities=None):
 
        ''' Evaluates all of the flag conditions on the given products.
 

	
 
        If `product_quantities` is supplied, the condition is only met if it
 
        will permit the sum of the product quantities for all of the products
 
        it covers. Otherwise, it will be met if at least one item can be
 
        accepted.
 

	
 
        If all flag conditions pass, an empty list is returned, otherwise
 
        a list is returned containing all of the products that are *not
 
        enabled*. '''
 

	
 
        if products is not None and product_quantities is not None:
 
            raise ValueError("Please specify only products or "
 
                             "product_quantities")
 
        elif products is None:
 
            products = set(i[0] for i in product_quantities)
 
            quantities = dict((product, quantity)
 
                              for product, quantity in product_quantities)
 
        elif product_quantities is None:
 
            products = set(products)
 
            quantities = {}
 

	
 
        # Get the conditions covered by the products themselves
 
        prods = (
 
            product.flagbase_set.select_subclasses()
 
            for product in products
 
        )
 
        # Get the conditions covered by their categories
 
        cats = (
 
            category.flagbase_set.select_subclasses()
 
            for category in set(product.category for product in products)
 
        )
 
    def pre_filter(cls, queryset, user):
 
        ''' Returns only the flag conditions that might be available for this
 
        user. It should hopefully reduce the number of queries that need to be
 
        executed to determine if a flag is met.
 

	
 
        if products:
 
            # Simplify the query.
 
            all_conditions = reduce(operator.or_, itertools.chain(prods, cats))
 
        else:
 
            all_conditions = []
 

	
 
        # All disable-if-false conditions on a product need to be met
 
        do_not_disable = defaultdict(lambda: True)
 
        # At least one enable-if-true condition on a product must be met
 
        do_enable = defaultdict(lambda: False)
 
        # (if either sort of condition is present)
 

	
 
        messages = {}
 

	
 
        for condition in all_conditions:
 
            cond = cls.for_condition(condition)
 
            remainder = cond.user_quantity_remaining(user)
 

	
 
            # Get all products covered by this condition, and the products
 
            # from the categories covered by this condition
 
            cond_products = condition.products.all()
 
            from_category = inventory.Product.objects.filter(
 
                category__in=condition.categories.all(),
 
            ).all()
 
            all_products = cond_products | from_category
 
            all_products = all_products.select_related("category")
 
            # Remove the products that we aren't asking about
 
            all_products = [
 
                product
 
                for product in all_products
 
                if product in products
 
            ]
 

	
 
            if quantities:
 
                consumed = sum(quantities[i] for i in all_products)
 
            else:
 
                consumed = 1
 
            met = consumed <= remainder
 

	
 
            if not met:
 
                items = ", ".join(str(product) for product in all_products)
 
                base = cls.MESSAGE[remainder == 0][len(all_products) == 1]
 
                message = base % {"items": items, "remainder": remainder}
 

	
 
            for product in all_products:
 
                if condition.is_disable_if_false:
 
                    do_not_disable[product] &= met
 
                else:
 
                    do_enable[product] |= met
 

	
 
                if not met and product not in messages:
 
                    messages[product] = message
 

	
 
        valid = {}
 
        for product in itertools.chain(do_not_disable, do_enable):
 
            if product in do_enable:
 
                # If there's an enable-if-true, we need need of those met too.
 
                # (do_not_disable will default to true otherwise)
 
                valid[product] = do_not_disable[product] and do_enable[product]
 
            elif product in do_not_disable:
 
                # If there's a disable-if-false condition, all must be met
 
                valid[product] = do_not_disable[product]
 

	
 
        error_fields = [
 
            (product, messages[product])
 
            for product in valid if not valid[product]
 
        ]
 

	
 
        return error_fields
 

	
 
    def user_quantity_remaining(self, user):
 
        If this filtration implements the same query as is_met, then you should
 
        be able to implement ``is_met()`` in terms of this.
 

	
 
        Arguments:
 

	
 
            queryset (Queryset[c]): The canditate conditions.
 

	
 
            user (User): The user for whom we're testing these conditions.
 

	
 
        Returns:
 
            Queryset[c]: A subset of the conditions that pass the pre-filter
 
                test for this user.
 

	
 
        '''
 

	
 
        # Default implementation does NOTHING.
 
        return queryset
 

	
 
    def passes_filter(self, user):
 
        ''' Returns true if the condition passes the filter '''
 

	
 
        cls = type(self.condition)
 
        qs = cls.objects.filter(pk=self.condition.id)
 
        return self.condition in self.pre_filter(qs, user)
 

	
 
    def user_quantity_remaining(self, user, filtered=False):
 
        ''' Returns the number of items covered by this flag condition the
 
        user can add to the current cart. This default implementation returns
 
        a big number if is_met() is true, otherwise 0.
 

	
 
        Either this method, or is_met() must be overridden in subclasses.
 
        '''
 

	
 
        return 99999999 if self.is_met(user) else 0
 
        return _BIG_QUANTITY if self.is_met(user, filtered) else 0
 

	
 
    def is_met(self, user):
 
    def is_met(self, user, filtered=False):
 
        ''' Returns True if this flag condition is met, otherwise returns
 
        False.
 

	
 
        Either this method, or user_quantity_remaining() must be overridden
 
        in subclasses.
 

	
 
        Arguments:
 

	
 
            user (User): The user for whom this test must be met.
 

	
 
            filter (bool): If true, this condition was part of a queryset
 
                returned by pre_filter() for this user.
 

	
 
        '''
 
        return self.user_quantity_remaining(user) > 0
 
        return self.user_quantity_remaining(user, filtered) > 0
 

	
 

	
 
class CategoryConditionController(ConditionController):
 
class IsMetByFilter(object):
 

	
 
    def __init__(self, condition):
 
        self.condition = condition
 
    def is_met(self, user, filtered=False):
 
        ''' Returns True if this flag condition is met, otherwise returns
 
        False. It determines if the condition is met by calling pre_filter
 
        with a queryset containing only self.condition. '''
 

	
 
        if filtered:
 
            return True  # Why query again?
 

	
 
        return self.passes_filter(user)
 

	
 

	
 
class RemainderSetByFilter(object):
 

	
 
    def user_quantity_remaining(self, user, filtered=True):
 
        ''' returns 0 if the date range is violated, otherwise, it will return
 
        the quantity remaining under the stock limit.
 

	
 
        The filter for this condition must add an annotation called "remainder"
 
        in order for this to work.
 
        '''
 

	
 
        if filtered:
 
            if hasattr(self.condition, "remainder"):
 
                return self.condition.remainder
 

	
 
        # Mark self.condition with a remainder
 
        qs = type(self.condition).objects.filter(pk=self.condition.id)
 
        qs = self.pre_filter(qs, user)
 

	
 
    def is_met(self, user):
 
        ''' returns True if the user has a product from a category that invokes
 
        this condition in one of their carts '''
 
        if len(qs) > 0:
 
            return qs[0].remainder
 
        else:
 
            return 0
 

	
 

	
 
class CategoryConditionController(IsMetByFilter, ConditionController):
 

	
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has a
 
        product from a category invoking that item's condition in one of their
 
        carts. '''
 

	
 
        carts = commerce.Cart.objects.filter(user=user)
 
        carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
 
        enabling_products = inventory.Product.objects.filter(
 
            category=self.condition.enabling_category,
 
        in_user_carts = Q(
 
            enabling_category__product__productitem__cart__user=user
 
        )
 
        released = commerce.Cart.STATUS_RELEASED
 
        in_released_carts = Q(
 
            enabling_category__product__productitem__cart__status=released
 
        )
 
        products_count = commerce.ProductItem.objects.filter(
 
            cart__in=carts,
 
            product__in=enabling_products,
 
        ).count()
 
        return products_count > 0
 
        queryset = queryset.filter(in_user_carts)
 
        queryset = queryset.exclude(in_released_carts)
 

	
 
        return queryset
 

	
 

	
 
class ProductConditionController(ConditionController):
 
class ProductConditionController(IsMetByFilter, ConditionController):
 
    ''' Condition tests for ProductFlag and
 
    IncludedProductDiscount. '''
 

	
 
    def __init__(self, condition):
 
        self.condition = condition
 

	
 
    def is_met(self, user):
 
        ''' returns True if the user has a product that invokes this
 
        condition in one of their carts '''
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has a
 
        product invoking that item's condition in one of their carts. '''
 

	
 
        in_user_carts = Q(enabling_products__productitem__cart__user=user)
 
        released = commerce.Cart.STATUS_RELEASED
 
        in_released_carts = Q(
 
            enabling_products__productitem__cart__status=released
 
        )
 
        queryset = queryset.filter(in_user_carts)
 
        queryset = queryset.exclude(in_released_carts)
 

	
 
        carts = commerce.Cart.objects.filter(user=user)
 
        carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
 
        products_count = commerce.ProductItem.objects.filter(
 
            cart__in=carts,
 
            product__in=self.condition.enabling_products.all(),
 
        ).count()
 
        return products_count > 0
 
        return queryset
 

	
 

	
 
class TimeOrStockLimitConditionController(ConditionController):
 
class TimeOrStockLimitConditionController(
 
            RemainderSetByFilter,
 
            ConditionController,
 
        ):
 
    ''' Common condition tests for TimeOrStockLimit Flag and
 
    Discount.'''
 

	
 
    def __init__(self, ceiling):
 
        self.ceiling = ceiling
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' returns 0 if the date range is violated, otherwise, it will return
 
        the quantity remaining under the stock limit. '''
 

	
 
        # Test date range
 
        if not self._test_date_range():
 
            return 0
 

	
 
        return self._get_remaining_stock(user)
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the date falls into
 
        any specified range, but not yet where the stock limit is not yet
 
        reached.'''
 

	
 
    def _test_date_range(self):
 
        now = timezone.now()
 

	
 
        if self.ceiling.start_time is not None:
 
            if now < self.ceiling.start_time:
 
                return False
 
        # Keep items with no start time, or start time not yet met.
 
        queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now))
 
        queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now))
 

	
 
        if self.ceiling.end_time is not None:
 
            if now > self.ceiling.end_time:
 
                return False
 
        # Filter out items that have been reserved beyond the limits
 
        quantity_or_zero = self._calculate_quantities(user)
 

	
 
        return True
 
        remainder = Case(
 
            When(limit=None, then=Value(_BIG_QUANTITY)),
 
            default=F("limit") - Sum(quantity_or_zero),
 
        )
 

	
 
    def _get_remaining_stock(self, user):
 
        ''' Returns the stock that remains under this ceiling, excluding the
 
        user's current cart. '''
 
        queryset = queryset.annotate(remainder=remainder)
 
        queryset = queryset.filter(remainder__gt=0)
 

	
 
        if self.ceiling.limit is None:
 
            return 99999999
 
        return queryset
 

	
 
        # We care about all reserved carts, but not the user's current cart
 
    @classmethod
 
    def _relevant_carts(cls, user):
 
        reserved_carts = commerce.Cart.reserved_carts()
 
        reserved_carts = reserved_carts.exclude(
 
            user=user,
 
            status=commerce.Cart.STATUS_ACTIVE,
 
        )
 

	
 
        items = self._items()
 
        items = items.filter(cart__in=reserved_carts)
 
        count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
 

	
 
        return self.ceiling.limit - count
 
        return reserved_carts
 

	
 

	
 
class TimeOrStockLimitFlagController(
 
        TimeOrStockLimitConditionController):
 

	
 
    def _items(self):
 
        category_products = inventory.Product.objects.filter(
 
            category__in=self.ceiling.categories.all(),
 
    @classmethod
 
    def _calculate_quantities(cls, user):
 
        reserved_carts = cls._relevant_carts(user)
 

	
 
        # Calculate category lines
 
        item_cats = F('categories__product__productitem__product__category')
 
        reserved_category_products = (
 
            Q(categories=item_cats) &
 
            Q(categories__product__productitem__cart__in=reserved_carts)
 
        )
 

	
 
        # Calculate product lines
 
        reserved_products = (
 
            Q(products=F('products__productitem__product')) &
 
            Q(products__productitem__cart__in=reserved_carts)
 
        )
 

	
 
        category_quantity_in_reserved_carts = When(
 
            reserved_category_products,
 
            then="categories__product__productitem__quantity",
 
        )
 

	
 
        product_quantity_in_reserved_carts = When(
 
            reserved_products,
 
            then="products__productitem__quantity",
 
        )
 
        products = self.ceiling.products.all() | category_products
 

	
 
        product_items = commerce.ProductItem.objects.filter(
 
            product__in=products.all(),
 
        quantity_or_zero = Case(
 
            category_quantity_in_reserved_carts,
 
            product_quantity_in_reserved_carts,
 
            default=Value(0),
 
        )
 
        return product_items
 

	
 
        return quantity_or_zero
 

	
 

	
 
class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController):
 

	
 
    def _items(self):
 
        discount_items = commerce.DiscountItem.objects.filter(
 
            discount=self.ceiling,
 
    @classmethod
 
    def _calculate_quantities(cls, user):
 
        reserved_carts = cls._relevant_carts(user)
 

	
 
        quantity_in_reserved_carts = When(
 
            discountitem__cart__in=reserved_carts,
 
            then="discountitem__quantity"
 
        )
 

	
 
        quantity_or_zero = Case(
 
            quantity_in_reserved_carts,
 
            default=Value(0)
 
        )
 
        return discount_items
 

	
 
        return quantity_or_zero
 

	
 
class VoucherConditionController(ConditionController):
 

	
 
class VoucherConditionController(IsMetByFilter, ConditionController):
 
    ''' Condition test for VoucherFlag and VoucherDiscount.'''
 

	
 
    def __init__(self, condition):
 
        self.condition = condition
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has entered
 
        a voucher that invokes that item's condition in one of their carts. '''
 

	
 
    def is_met(self, user):
 
        ''' returns True if the user has the given voucher attached. '''
 
        carts_count = commerce.Cart.objects.filter(
 
            user=user,
 
            vouchers=self.condition.voucher,
 
        ).count()
 
        return carts_count > 0
 
        return queryset.filter(voucher__cart__user=user)
registrasion/controllers/discount.py
Show inline comments
 
import itertools
 

	
 
from conditions import ConditionController
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 

	
 
from django.db.models import Case
 
from django.db.models import F, Q
 
from django.db.models import Sum
 
from django.db.models import Value
 
from django.db.models import When
 

	
 

	
 
class DiscountAndQuantity(object):
 
    ''' Represents a discount that can be applied to a product or category
 
    for a given user.
 

	
 
    Attributes:
 

	
 
        discount (conditions.DiscountBase): The discount object that the
 
            clause arises from. A given DiscountBase can apply to multiple
 
            clauses.
 

	
...
 
@@ -29,89 +33,167 @@ class DiscountAndQuantity(object):
 

	
 
    def __init__(self, discount, clause, quantity):
 
        self.discount = discount
 
        self.clause = clause
 
        self.quantity = quantity
 

	
 
    def __repr__(self):
 
        return "(discount=%s, clause=%s, quantity=%d)" % (
 
            self.discount, self.clause, self.quantity,
 
        )
 

	
 

	
 
def available_discounts(user, categories, products):
 
    ''' Returns all discounts available to this user for the given categories
 
    and products. The discounts also list the available quantity for this user,
 
    not including products that are pending purchase. '''
 

	
 
    # discounts that match provided categories
 
    category_discounts = conditions.DiscountForCategory.objects.filter(
 
        category__in=categories
 
    )
 
    # discounts that match provided products
 
    product_discounts = conditions.DiscountForProduct.objects.filter(
 
        product__in=products
 
    )
 
    # discounts that match categories for provided products
 
    product_category_discounts = conditions.DiscountForCategory.objects.filter(
 
        category__in=(product.category for product in products)
 
    )
 
    # (Not relevant: discounts that match products in provided categories)
 

	
 
    product_discounts = product_discounts.select_related(
 
        "product",
 
        "product__category",
 
    )
 

	
 
    all_category_discounts = category_discounts | product_category_discounts
 
    all_category_discounts = all_category_discounts.select_related(
 
        "category",
 
    )
 

	
 
    # The set of all potential discounts
 
    potential_discounts = set(itertools.chain(
 
        product_discounts,
 
        all_category_discounts,
 
    ))
 

	
 
    discounts = []
 

	
 
    # Markers so that we don't need to evaluate given conditions more than once
 
    accepted_discounts = set()
 
    failed_discounts = set()
 

	
 
    for discount in potential_discounts:
 
        real_discount = conditions.DiscountBase.objects.get_subclass(
 
            pk=discount.discount.pk,
 
class DiscountController(object):
 

	
 
    @classmethod
 
    def available_discounts(cls, user, categories, products):
 
        ''' Returns all discounts available to this user for the given
 
        categories and products. The discounts also list the available quantity
 
        for this user, not including products that are pending purchase. '''
 

	
 
        filtered_clauses = cls._filtered_discounts(user, categories, products)
 

	
 
        discounts = []
 

	
 
        # Markers so that we don't need to evaluate given conditions
 
        # more than once
 
        accepted_discounts = set()
 
        failed_discounts = set()
 

	
 
        for clause in filtered_clauses:
 
            discount = clause.discount
 
            cond = ConditionController.for_condition(discount)
 

	
 
            past_use_count = clause.past_use_count
 
            if past_use_count >= clause.quantity:
 
                # This clause has exceeded its use count
 
                pass
 
            elif discount not in failed_discounts:
 
                # This clause is still available
 
                is_accepted = discount in accepted_discounts
 
                if is_accepted or cond.is_met(user, filtered=True):
 
                    # This clause is valid for this user
 
                    discounts.append(DiscountAndQuantity(
 
                        discount=discount,
 
                        clause=clause,
 
                        quantity=clause.quantity - past_use_count,
 
                    ))
 
                    accepted_discounts.add(discount)
 
                else:
 
                    # This clause is not valid for this user
 
                    failed_discounts.add(discount)
 
        return discounts
 

	
 
    @classmethod
 
    def _filtered_discounts(cls, user, categories, products):
 
        '''
 

	
 
        Returns:
 
            Sequence[discountbase]: All discounts that passed the filter
 
            function.
 

	
 
        '''
 

	
 
        types = list(ConditionController._controllers())
 
        discounttypes = [
 
            i for i in types if issubclass(i, conditions.DiscountBase)
 
        ]
 

	
 
        # discounts that match provided categories
 
        category_discounts = conditions.DiscountForCategory.objects.filter(
 
            category__in=categories
 
        )
 
        cond = ConditionController.for_condition(real_discount)
 

	
 
        # Count the past uses of the given discount item.
 
        # If this user has exceeded the limit for the clause, this clause
 
        # is not available any more.
 
        past_uses = commerce.DiscountItem.objects.filter(
 
            cart__user=user,
 
            cart__status=commerce.Cart.STATUS_PAID,  # Only past carts count
 
            discount=real_discount,
 
        # discounts that match provided products
 
        product_discounts = conditions.DiscountForProduct.objects.filter(
 
            product__in=products
 
        )
 
        agg = past_uses.aggregate(Sum("quantity"))
 
        past_use_count = agg["quantity__sum"]
 
        if past_use_count is None:
 
            past_use_count = 0
 

	
 
        if past_use_count >= discount.quantity:
 
            # This clause has exceeded its use count
 
            pass
 
        elif real_discount not in failed_discounts:
 
            # This clause is still available
 
            if real_discount in accepted_discounts or cond.is_met(user):
 
                # This clause is valid for this user
 
                discounts.append(DiscountAndQuantity(
 
                    discount=real_discount,
 
                    clause=discount,
 
                    quantity=discount.quantity - past_use_count,
 
                ))
 
                accepted_discounts.add(real_discount)
 
            else:
 
                # This clause is not valid for this user
 
                failed_discounts.add(real_discount)
 
    return discounts
 
        # discounts that match categories for provided products
 
        product_category_discounts = conditions.DiscountForCategory.objects
 
        product_category_discounts = product_category_discounts.filter(
 
            category__in=(product.category for product in products)
 
        )
 
        # (Not relevant: discounts that match products in provided categories)
 

	
 
        product_discounts = product_discounts.select_related(
 
            "product",
 
            "product__category",
 
        )
 

	
 
        all_category_discounts = (
 
            category_discounts | product_category_discounts
 
        )
 
        all_category_discounts = all_category_discounts.select_related(
 
            "category",
 
        )
 

	
 
        valid_discounts = conditions.DiscountBase.objects.filter(
 
            Q(discountforproduct__in=product_discounts) |
 
            Q(discountforcategory__in=all_category_discounts)
 
        )
 

	
 
        all_subsets = []
 

	
 
        for discounttype in discounttypes:
 
            discounts = discounttype.objects.filter(id__in=valid_discounts)
 
            ctrl = ConditionController.for_type(discounttype)
 
            discounts = ctrl.pre_filter(discounts, user)
 
            all_subsets.append(discounts)
 

	
 
        filtered_discounts = list(itertools.chain(*all_subsets))
 

	
 
        # Map from discount key to itself
 
        # (contains annotations needed in the future)
 
        from_filter = dict((i.id, i) for i in filtered_discounts)
 

	
 
        clause_sets = (
 
            product_discounts.filter(discount__in=filtered_discounts),
 
            all_category_discounts.filter(discount__in=filtered_discounts),
 
        )
 

	
 
        clause_sets = (
 
            cls._annotate_with_past_uses(i, user) for i in clause_sets
 
        )
 

	
 
        # The set of all potential discount clauses
 
        discount_clauses = set(itertools.chain(*clause_sets))
 

	
 
        # Replace discounts with the filtered ones
 
        # These are the correct subclasses (saves query later on), and have
 
        # correct annotations from filters if necessary.
 
        for clause in discount_clauses:
 
            clause.discount = from_filter[clause.discount.id]
 

	
 
        return discount_clauses
 

	
 
    @classmethod
 
    def _annotate_with_past_uses(cls, queryset, user):
 
        ''' Annotates the queryset with a usage count for that discount claus
 
        by the given user. '''
 

	
 
        if queryset.model == conditions.DiscountForCategory:
 
            matches = (
 
                Q(category=F('discount__discountitem__product__category'))
 
            )
 
        elif queryset.model == conditions.DiscountForProduct:
 
            matches = (
 
                Q(product=F('discount__discountitem__product'))
 
            )
 

	
 
        in_carts = (
 
            Q(discount__discountitem__cart__user=user) &
 
            Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID)
 
        )
 

	
 
        past_use_quantity = When(
 
            in_carts & matches,
 
            then="discount__discountitem__quantity",
 
        )
 

	
 
        past_use_quantity_or_zero = Case(
 
            past_use_quantity,
 
            default=Value(0),
 
        )
 

	
 
        queryset = queryset.annotate(
 
            past_use_count=Sum(past_use_quantity_or_zero)
 
        )
 
        return queryset
registrasion/controllers/flag.py
Show inline comments
 
new file 100644
 
import itertools
 
import operator
 

	
 
from collections import defaultdict
 
from collections import namedtuple
 
from django.db.models import Count
 
from django.db.models import Q
 

	
 
from .conditions import ConditionController
 

	
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 

	
 

	
 
class FlagController(object):
 

	
 
    SINGLE = True
 
    PLURAL = False
 
    NONE = True
 
    SOME = False
 
    MESSAGE = {
 
        NONE: {
 
            SINGLE:
 
                "%(items)s is no longer available to you",
 
            PLURAL:
 
                "%(items)s are no longer available to you",
 
        },
 
        SOME: {
 
            SINGLE:
 
                "Only %(remainder)d of the following item remains: %(items)s",
 
            PLURAL:
 
                "Only %(remainder)d of the following items remain: %(items)s"
 
        },
 
    }
 

	
 
    @classmethod
 
    def test_flags(
 
            cls, user, products=None, product_quantities=None):
 
        ''' Evaluates all of the flag conditions on the given products.
 

	
 
        If `product_quantities` is supplied, the condition is only met if it
 
        will permit the sum of the product quantities for all of the products
 
        it covers. Otherwise, it will be met if at least one item can be
 
        accepted.
 

	
 
        If all flag conditions pass, an empty list is returned, otherwise
 
        a list is returned containing all of the products that are *not
 
        enabled*. '''
 

	
 
        print "GREPME: test_flags()"
 

	
 
        if products is not None and product_quantities is not None:
 
            raise ValueError("Please specify only products or "
 
                             "product_quantities")
 
        elif products is None:
 
            products = set(i[0] for i in product_quantities)
 
            quantities = dict((product, quantity)
 
                              for product, quantity in product_quantities)
 
        elif product_quantities is None:
 
            products = set(products)
 
            quantities = {}
 

	
 
        if products:
 
            # Simplify the query.
 
            all_conditions = cls._filtered_flags(user, products)
 
        else:
 
            all_conditions = []
 

	
 
        # All disable-if-false conditions on a product need to be met
 
        do_not_disable = defaultdict(lambda: True)
 
        # At least one enable-if-true condition on a product must be met
 
        do_enable = defaultdict(lambda: False)
 
        # (if either sort of condition is present)
 

	
 
        # Count the number of conditions for a product
 
        dif_count = defaultdict(int)
 
        eit_count = defaultdict(int)
 

	
 
        messages = {}
 

	
 
        for condition in all_conditions:
 
            cond = ConditionController.for_condition(condition)
 
            remainder = cond.user_quantity_remaining(user, filtered=True)
 

	
 
            # Get all products covered by this condition, and the products
 
            # from the categories covered by this condition
 

	
 
            ids = [product.id for product in products]
 
            all_products = inventory.Product.objects.filter(id__in=ids)
 
            cond = (
 
                Q(flagbase_set=condition) |
 
                Q(category__in=condition.categories.all())
 
            )
 

	
 
            all_products = all_products.filter(cond)
 
            all_products = all_products.select_related("category")
 

	
 
            if quantities:
 
                consumed = sum(quantities[i] for i in all_products)
 
            else:
 
                consumed = 1
 
            met = consumed <= remainder
 

	
 
            if not met:
 
                items = ", ".join(str(product) for product in all_products)
 
                base = cls.MESSAGE[remainder == 0][len(all_products) == 1]
 
                message = base % {"items": items, "remainder": remainder}
 

	
 
            for product in all_products:
 
                if condition.is_disable_if_false:
 
                    do_not_disable[product] &= met
 
                    dif_count[product] += 1
 
                else:
 
                    do_enable[product] |= met
 
                    eit_count[product] += 1
 

	
 
                if not met and product not in messages:
 
                    messages[product] = message
 

	
 
        total_flags = FlagCounter.count()
 

	
 
        valid = {}
 

	
 
        # the problem is that now, not every condition falls into
 
        # do_not_disable or do_enable '''
 
        # You should look into this, chris :)
 

	
 
        for product in products:
 
            if quantities:
 
                if quantities[product] == 0:
 
                    continue
 

	
 
            f = total_flags.get(product)
 
            if f.dif > 0 and f.dif != dif_count[product]:
 
                do_not_disable[product] = False
 
                if product not in messages:
 
                    messages[product] = "Some disable-if-false " \
 
                                        "conditions were not met"
 
            if f.eit > 0 and product not in do_enable:
 
                do_enable[product] = False
 
                if product not in messages:
 
                    messages[product] = "Some enable-if-true " \
 
                                        "conditions were not met"
 

	
 
        for product in itertools.chain(do_not_disable, do_enable):
 
            f = total_flags.get(product)
 
            if product in do_enable:
 
                # If there's an enable-if-true, we need need of those met too.
 
                # (do_not_disable will default to true otherwise)
 
                valid[product] = do_not_disable[product] and do_enable[product]
 
            elif product in do_not_disable:
 
                # If there's a disable-if-false condition, all must be met
 
                valid[product] = do_not_disable[product]
 

	
 
        error_fields = [
 
            (product, messages[product])
 
            for product in valid if not valid[product]
 
        ]
 

	
 
        return error_fields
 

	
 
    @classmethod
 
    def _filtered_flags(cls, user, products):
 
        '''
 

	
 
        Returns:
 
            Sequence[flagbase]: All flags that passed the filter function.
 

	
 
        '''
 

	
 
        types = list(ConditionController._controllers())
 
        flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
 

	
 
        # Get all flags for the products and categories.
 
        prods = (
 
            product.flagbase_set.all()
 
            for product in products
 
        )
 
        cats = (
 
            category.flagbase_set.all()
 
            for category in set(product.category for product in products)
 
        )
 
        all_flags = reduce(operator.or_, itertools.chain(prods, cats))
 

	
 
        all_subsets = []
 

	
 
        for flagtype in flagtypes:
 
            flags = flagtype.objects.filter(id__in=all_flags)
 
            ctrl = ConditionController.for_type(flagtype)
 
            flags = ctrl.pre_filter(flags, user)
 
            all_subsets.append(flags)
 

	
 
        return itertools.chain(*all_subsets)
 

	
 

	
 
ConditionAndRemainder = namedtuple(
 
    "ConditionAndRemainder",
 
    (
 
        "condition",
 
        "remainder",
 
    ),
 
)
 

	
 

	
 
_FlagCounter = namedtuple(
 
    "_FlagCounter",
 
    (
 
        "products",
 
        "categories",
 
    ),
 
)
 

	
 

	
 
_ConditionsCount = namedtuple(
 
    "ConditionsCount",
 
    (
 
        "dif",
 
        "eit",
 
    ),
 
)
 

	
 

	
 
# TODO: this should be cacheable.
 
class FlagCounter(_FlagCounter):
 

	
 
    @classmethod
 
    def count(cls):
 
        # Get the count of how many conditions should exist per product
 
        flagbases = conditions.FlagBase.objects
 

	
 
        types = (
 
            conditions.FlagBase.ENABLE_IF_TRUE,
 
            conditions.FlagBase.DISABLE_IF_FALSE,
 
        )
 
        keys = ("eit", "dif")
 
        flags = [
 
            flagbases.filter(
 
                condition=condition_type
 
            ).values(
 
                'products', 'categories'
 
            ).annotate(
 
                count=Count('id')
 
            )
 
            for condition_type in types
 
        ]
 

	
 
        cats = defaultdict(lambda: defaultdict(int))
 
        prod = defaultdict(lambda: defaultdict(int))
 

	
 
        for key, flagcounts in zip(keys, flags):
 
            for row in flagcounts:
 
                if row["products"] is not None:
 
                    prod[row["products"]][key] = row["count"]
 
                if row["categories"] is not None:
 
                    cats[row["categories"]][key] = row["count"]
 

	
 
        return cls(products=prod, categories=cats)
 

	
 
    def get(self, product):
 
        p = self.products[product.id]
 
        c = self.categories[product.category.id]
 
        eit = p["eit"] + c["eit"]
 
        dif = p["dif"] + c["dif"]
 
        return _ConditionsCount(dif=dif, eit=eit)
registrasion/controllers/invoice.py
Show inline comments
...
 
@@ -20,24 +20,25 @@ class InvoiceController(ForId, object):
 

	
 
    def __init__(self, invoice):
 
        self.invoice = invoice
 
        self.update_status()
 
        self.update_validity()  # Make sure this invoice is up-to-date
 

	
 
    @classmethod
 
    def for_cart(cls, cart):
 
        ''' Returns an invoice object for a given cart at its current revision.
 
        If such an invoice does not exist, the cart is validated, and if valid,
 
        an invoice is generated.'''
 

	
 
        cart.refresh_from_db()
 
        try:
 
            invoice = commerce.Invoice.objects.exclude(
 
                status=commerce.Invoice.STATUS_VOID,
 
            ).get(
 
                cart=cart,
 
                cart_revision=cart.revision,
 
            )
 
        except ObjectDoesNotExist:
 
            cart_controller = CartController(cart)
 
            cart_controller.validate_cart()  # Raises ValidationError on fail.
 

	
 
            cls.void_all_invoices(cart)
...
 
@@ -65,76 +66,94 @@ class InvoiceController(ForId, object):
 
            )
 
        if condition.percentage is not None:
 
            value = item.product.price * (condition.percentage / 100)
 
        else:
 
            value = condition.price
 
        return value
 

	
 
    @classmethod
 
    @transaction.atomic
 
    def _generate(cls, cart):
 
        ''' Generates an invoice for the given cart. '''
 

	
 
        cart.refresh_from_db()
 

	
 
        issued = timezone.now()
 
        reservation_limit = cart.reservation_duration + cart.time_last_updated
 
        # Never generate a due time that is before the issue time
 
        due = max(issued, reservation_limit)
 

	
 
        # Get the invoice recipient
 
        profile = people.AttendeeProfileBase.objects.get_subclass(
 
            id=cart.user.attendee.attendeeprofilebase.id,
 
        )
 
        recipient = profile.invoice_recipient()
 
        invoice = commerce.Invoice.objects.create(
 
            user=cart.user,
 
            cart=cart,
 
            cart_revision=cart.revision,
 
            status=commerce.Invoice.STATUS_UNPAID,
 
            value=Decimal(),
 
            issue_time=issued,
 
            due_time=due,
 
            recipient=recipient,
 
        )
 

	
 
        product_items = commerce.ProductItem.objects.filter(cart=cart)
 
        product_items = product_items.select_related(
 
            "product",
 
            "product__category",
 
        )
 

	
 
        if len(product_items) == 0:
 
            raise ValidationError("Your cart is empty.")
 

	
 
        product_items = product_items.order_by(
 
            "product__category__order", "product__order"
 
        )
 

	
 
        discount_items = commerce.DiscountItem.objects.filter(cart=cart)
 
        discount_items = discount_items.select_related(
 
            "discount",
 
            "product",
 
            "product__category",
 
        )
 

	
 
        line_items = []
 

	
 
        invoice_value = Decimal()
 
        for item in product_items:
 
            product = item.product
 
            line_item = commerce.LineItem.objects.create(
 
            line_item = commerce.LineItem(
 
                invoice=invoice,
 
                description="%s - %s" % (product.category.name, product.name),
 
                quantity=item.quantity,
 
                price=product.price,
 
                product=product,
 
            )
 
            line_items.append(line_item)
 
            invoice_value += line_item.quantity * line_item.price
 

	
 
        for item in discount_items:
 
            line_item = commerce.LineItem.objects.create(
 
            line_item = commerce.LineItem(
 
                invoice=invoice,
 
                description=item.discount.description,
 
                quantity=item.quantity,
 
                price=cls.resolve_discount_value(item) * -1,
 
                product=item.product,
 
            )
 
            line_items.append(line_item)
 
            invoice_value += line_item.quantity * line_item.price
 

	
 
        commerce.LineItem.objects.bulk_create(line_items)
 

	
 
        invoice.value = invoice_value
 

	
 
        invoice.save()
 

	
 
        return invoice
 

	
 
    def can_view(self, user=None, access_code=None):
 
        ''' Returns true if the accessing user is allowed to view this invoice,
 
        or if the given access code matches this invoice's user's access code.
 
        '''
 

	
 
        if user == self.invoice.user:
...
 
@@ -242,24 +261,27 @@ class InvoiceController(ForId, object):
 
        self.invoice.status = commerce.Invoice.STATUS_REFUNDED
 
        self.invoice.save()
 

	
 
    def _mark_void(self):
 
        ''' Marks the invoice as refunded, and updates the attached cart if
 
        necessary. '''
 
        self.invoice.status = commerce.Invoice.STATUS_VOID
 
        self.invoice.save()
 

	
 
    def _invoice_matches_cart(self):
 
        ''' Returns true if there is no cart, or if the revision of this
 
        invoice matches the current revision of the cart. '''
 

	
 
        self._refresh()
 

	
 
        cart = self.invoice.cart
 
        if not cart:
 
            return True
 

	
 
        return cart.revision == self.invoice.cart_revision
 

	
 
    def update_validity(self):
 
        ''' Voids this invoice if the cart it is attached to has updated. '''
 
        if not self._invoice_matches_cart():
 
            self.void()
 

	
 
    def void(self):
registrasion/controllers/product.py
Show inline comments
 
import itertools
 

	
 
from django.db.models import Case
 
from django.db.models import F, Q
 
from django.db.models import Sum
 
from django.db.models import When
 
from django.db.models import Value
 

	
 
from registrasion.models import commerce
 
from registrasion.models import inventory
 

	
 
from category import CategoryController
 
from conditions import ConditionController
 
from .category import CategoryController
 
from .flag import FlagController
 

	
 

	
 
class ProductController(object):
 

	
 
    def __init__(self, product):
 
        self.product = product
 

	
 
    @classmethod
 
    def available_products(cls, user, category=None, products=None):
 
        ''' Returns a list of all of the products that are available per
 
        flag conditions from the given categories.
 
        TODO: refactor so that all conditions are tested here and
 
        can_add_with_flags calls this method. '''
 
        flag conditions from the given categories. '''
 
        if category is None and products is None:
 
            raise ValueError("You must provide products or a category")
 

	
 
        if category is not None:
 
            all_products = inventory.Product.objects.filter(category=category)
 
            all_products = all_products.select_related("category")
 
        else:
 
            all_products = []
 

	
 
        if products is not None:
 
            all_products = set(itertools.chain(all_products, products))
 

	
 
        cat_quants = dict(
 
            (
 
                category,
 
                CategoryController(category).user_quantity_remaining(user),
 
            )
 
            for category in set(product.category for product in all_products)
 
        )
 
        categories = set(product.category for product in all_products)
 
        r = CategoryController.attach_user_remainders(user, categories)
 
        cat_quants = dict((c, c) for c in r)
 

	
 
        r = ProductController.attach_user_remainders(user, all_products)
 
        prod_quants = dict((p, p) for p in r)
 

	
 
        passed_limits = set(
 
            product
 
            for product in all_products
 
            if cat_quants[product.category] > 0
 
            if cls(product).user_quantity_remaining(user) > 0
 
            if cat_quants[product.category].remainder > 0
 
            if prod_quants[product].remainder > 0
 
        )
 

	
 
        failed_and_messages = ConditionController.test_flags(
 
        failed_and_messages = FlagController.test_flags(
 
            user, products=passed_limits
 
        )
 
        failed_conditions = set(i[0] for i in failed_and_messages)
 

	
 
        out = list(passed_limits - failed_conditions)
 
        out.sort(key=lambda product: product.order)
 

	
 
        return out
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' Returns the quantity of this product that the user add in the
 
        current cart. '''
 
    @classmethod
 
    def attach_user_remainders(cls, user, products):
 
        '''
 

	
 
        prod_limit = self.product.limit_per_user
 
        Return:
 
            queryset(inventory.Product): A queryset containing items from
 
            ``product``, with an extra attribute -- remainder = the amount of
 
            this item that is remaining.
 
        '''
 

	
 
        if prod_limit is None:
 
            # Don't need to run the remaining queries
 
            return 999999  # We can do better
 
        ids = [product.id for product in products]
 
        products = inventory.Product.objects.filter(id__in=ids)
 

	
 
        cart_filter = (
 
            Q(productitem__cart__user=user) &
 
            Q(productitem__cart__status=commerce.Cart.STATUS_PAID)
 
        )
 

	
 
        carts = commerce.Cart.objects.filter(
 
            user=user,
 
            status=commerce.Cart.STATUS_PAID,
 
        quantity = When(
 
            cart_filter,
 
            then='productitem__quantity'
 
        )
 

	
 
        items = commerce.ProductItem.objects.filter(
 
            cart__in=carts,
 
            product=self.product,
 
        quantity_or_zero = Case(
 
            quantity,
 
            default=Value(0),
 
        )
 

	
 
        prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
 
        remainder = Case(
 
            When(limit_per_user=None, then=Value(99999999)),
 
            default=F('limit_per_user') - Sum(quantity_or_zero),
 
        )
 

	
 
        products = products.annotate(remainder=remainder)
 

	
 
        return products
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' Returns the quantity of this product that the user add in the
 
        current cart. '''
 

	
 
        with_remainders = self.attach_user_remainders(user, [self.product])
 

	
 
        return prod_limit - prod_count
 
        return with_remainders[0].remainder
registrasion/templatetags/registrasion_tags.py
Show inline comments
 
from registrasion.models import commerce
 
from registrasion.models import inventory
 
from registrasion.controllers.category import CategoryController
 

	
 
from collections import namedtuple
 
from django import template
 
from django.db.models import Case
 
from django.db.models import Q
 
from django.db.models import Sum
 
from django.db.models import When
 
from django.db.models import Value
 

	
 
register = template.Library()
 

	
 
_ProductAndQuantity = namedtuple("ProductAndQuantity", ["product", "quantity"])
 

	
 

	
 
class ProductAndQuantity(_ProductAndQuantity):
 
    ''' Class that holds a product and a quantity.
 

	
 
    Attributes:
 
        product (models.inventory.Product)
 

	
...
 
@@ -90,38 +94,51 @@ def items_purchased(context, category=None):
 
    ''' Aggregates the items that this user has purchased.
 

	
 
    Arguments:
 
        category (Optional[models.inventory.Category]): the category of items
 
            to restrict to.
 

	
 
    Returns:
 
        [ProductAndQuantity, ...]: A list of product-quantity pairs,
 
            aggregating like products from across multiple invoices.
 

	
 
    '''
 

	
 
    all_items = commerce.ProductItem.objects.filter(
 
        cart__user=context.request.user,
 
        cart__status=commerce.Cart.STATUS_PAID,
 
    ).select_related("product", "product__category")
 
    in_cart = (
 
        Q(productitem__cart__user=context.request.user) &
 
        Q(productitem__cart__status=commerce.Cart.STATUS_PAID)
 
    )
 

	
 
    quantities_in_cart = When(
 
        in_cart,
 
        then="productitem__quantity",
 
    )
 

	
 
    quantities_or_zero = Case(
 
        quantities_in_cart,
 
        default=Value(0),
 
    )
 

	
 
    products = inventory.Product.objects
 

	
 
    if category:
 
        all_items = all_items.filter(product__category=category)
 
        products = products.filter(category=category)
 

	
 
    products = products.select_related("category")
 
    products = products.annotate(quantity=Sum(quantities_or_zero))
 
    products = products.filter(quantity__gt=0)
 

	
 
    pq = all_items.values("product").annotate(quantity=Sum("quantity")).all()
 
    products = inventory.Product.objects.all()
 
    out = []
 
    for item in pq:
 
        prod = products.get(pk=item["product"])
 
        out.append(ProductAndQuantity(prod, item["quantity"]))
 
    for prod in products:
 
        out.append(ProductAndQuantity(prod, prod.quantity))
 
    return out
 

	
 

	
 
@register.filter
 
def multiply(value, arg):
 
    ''' Multiplies value by arg.
 

	
 
    This is useful when displaying invoices, as it lets you multiply the
 
    quantity by the unit value.
 

	
 
    Arguments:
 

	
registrasion/tests/test_cart.py
Show inline comments
...
 
@@ -17,25 +17,25 @@ from registrasion.controllers.product import ProductController
 
from controller_helpers import TestingCartController
 
from patch_datetime import SetTimeMixin
 

	
 
UTC = pytz.timezone('UTC')
 

	
 

	
 
class RegistrationCartTestCase(SetTimeMixin, TestCase):
 

	
 
    def setUp(self):
 
        super(RegistrationCartTestCase, self).setUp()
 

	
 
    def tearDown(self):
 
        if False:
 
        if True:
 
            # If you're seeing segfaults in tests, enable this.
 
            call_command(
 
                'flush',
 
                verbosity=0,
 
                interactive=False,
 
                reset_sequences=False,
 
                allow_cascade=False,
 
                inhibit_post_migrate=False
 
            )
 

	
 
        super(RegistrationCartTestCase, self).tearDown()
 

	
registrasion/tests/test_ceilings.py
Show inline comments
 
import datetime
 
import pytz
 

	
 
from django.core.exceptions import ValidationError
 

	
 
from controller_helpers import TestingCartController
 
from test_cart import RegistrationCartTestCase
 

	
 
from registrasion.controllers.discount import DiscountController
 
from registrasion.controllers.product import ProductController
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 

	
 
UTC = pytz.timezone('UTC')
 

	
 

	
 
class CeilingsTestCases(RegistrationCartTestCase):
 

	
 
    def test_add_to_cart_ceiling_limit(self):
 
        self.make_ceiling("Limit ceiling", limit=9)
 
        self.__add_to_cart_test()
 

	
...
 
@@ -126,24 +128,61 @@ class CeilingsTestCases(RegistrationCartTestCase):
 
        # Unpaid cart within reservation window
 
        second_cart.add_to_cart(self.PROD_1, 1)
 
        with self.assertRaises(ValidationError):
 
            first_cart.validate_cart()
 

	
 
        # Paid cart outside the reservation window
 

	
 
        second_cart.next_cart()
 
        self.add_timedelta(self.RESERVATION + datetime.timedelta(seconds=1))
 
        with self.assertRaises(ValidationError):
 
            first_cart.validate_cart()
 

	
 
    def test_discount_ceiling_aggregates_products(self):
 
        # Create two carts, add 1xprod_1 to each. Ceiling should disappear
 
        # after second.
 
        self.make_discount_ceiling(
 
            "Multi-product limit discount ceiling",
 
            limit=2,
 
        )
 
        for i in xrange(2):
 
            cart = TestingCartController.for_user(self.USER_1)
 
            cart.add_to_cart(self.PROD_1, 1)
 
            cart.next_cart()
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_1],
 
        )
 

	
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_flag_ceiling_aggregates_products(self):
 
        # Create two carts, add 1xprod_1 to each. Ceiling should disappear
 
        # after second.
 
        self.make_ceiling("Multi-product limit ceiling", limit=2)
 

	
 
        for i in xrange(2):
 
            cart = TestingCartController.for_user(self.USER_1)
 
            cart.add_to_cart(self.PROD_1, 1)
 
            cart.next_cart()
 

	
 
        products = ProductController.available_products(
 
            self.USER_1,
 
            products=[self.PROD_1],
 
        )
 

	
 
        self.assertEqual(0, len(products))
 

	
 
    def test_items_released_from_ceiling_by_refund(self):
 
        self.make_ceiling("Limit ceiling", limit=1)
 

	
 
        first_cart = TestingCartController.for_user(self.USER_1)
 
        first_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        first_cart.next_cart()
 

	
 
        second_cart = TestingCartController.for_user(self.USER_2)
 
        with self.assertRaises(ValidationError):
 
            second_cart.add_to_cart(self.PROD_1, 1)
 

	
registrasion/tests/test_discount.py
Show inline comments
 
import pytz
 

	
 
from decimal import Decimal
 

	
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.controllers import discount
 
from registrasion.controllers.discount import DiscountController
 
from controller_helpers import TestingCartController
 

	
 
from test_cart import RegistrationCartTestCase
 

	
 
UTC = pytz.timezone('UTC')
 

	
 

	
 
class DiscountTestCase(RegistrationCartTestCase):
 

	
 
    @classmethod
 
    def add_discount_prod_1_includes_prod_2(
 
            cls,
...
 
@@ -234,173 +234,224 @@ class DiscountTestCase(RegistrationCartTestCase):
 

	
 
        # Both users should be able to apply the same discount
 
        # in the same way
 
        for user in (self.USER_1, self.USER_2):
 
            cart = TestingCartController.for_user(user)
 
            cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
            cart.add_to_cart(self.PROD_3, 1)
 

	
 
            discount_items = list(cart.cart.discountitem_set.all())
 
            # The discount is applied.
 
            self.assertEqual(1, len(discount_items))
 

	
 
    # Tests for the discount.available_discounts enumerator
 
    # Tests for the DiscountController.available_discounts enumerator
 
    def test_enumerate_no_discounts_for_no_input(self):
 
        discounts = discount.available_discounts(self.USER_1, [], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_enumerate_no_discounts_if_condition_not_met(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_category_discount_appears_once_if_met_twice(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_with_category(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_with_product(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_once_with_two_valid_product(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4]
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_product_discount_appears_with_product(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_product_discount_does_not_appear_with_category(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_1], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_1],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_discount_quantity_is_correct_before_first_purchase(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=2)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity
 

	
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(2, discounts[0].quantity)
 

	
 
        cart.next_cart()
 

	
 
    def test_discount_quantity_is_correct_after_first_purchase(self):
 
        self.test_discount_quantity_is_correct_before_first_purchase()
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity
 

	
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(1, discounts[0].quantity)
 

	
 
        cart.next_cart()
 

	
 
    def test_discount_is_gone_after_quantity_exhausted(self):
 
        self.test_discount_quantity_is_correct_after_first_purchase()
 
        discounts = discount.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_product_discount_enabled_twice_appears_twice(self):
 
        self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4],
 
        )
 
        self.assertEqual(2, len(discounts))
 

	
 
    def test_product_discount_applied_on_different_invoices(self):
 
        # quantity=1 means "quantity per product"
 
        self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4],
 
        )
 
        self.assertEqual(2, len(discounts))
 
        # adding one of PROD_3 should make it no longer an available discount.
 
        cart.add_to_cart(self.PROD_3, 1)
 
        cart.next_cart()
 

	
 
        # should still have (and only have) the discount for prod_4
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_discounts_are_released_by_refunds(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=2)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
        cart.next_cart()
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_2, 2)  # The discount will be exhausted
 

	
 
        cart.next_cart()
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
        cart.cart.status = commerce.Cart.STATUS_RELEASED
 
        cart.cart.save()
 

	
 
        discounts = discount.available_discounts(
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
registrasion/util.py
Show inline comments
...
 
@@ -16,12 +16,42 @@ def generate_access_code():
 
    return get_random_string(length=length, allowed_chars=chars)
 

	
 

	
 
def all_arguments_optional(ntcls):
 
    ''' Takes a namedtuple derivative and makes all of the arguments optional.
 
    '''
 

	
 
    ntcls.__new__.__defaults__ = (
 
        (None,) * len(ntcls._fields)
 
    )
 

	
 
    return ntcls
 

	
 

	
 
def lazy(function, *args, **kwargs):
 
    ''' Produces a callable so that functions can be lazily evaluated in
 
    templates.
 

	
 
    Arguments:
 

	
 
        function (callable): The function to call at evaluation time.
 

	
 
        args: Positional arguments, passed directly to ``function``.
 

	
 
        kwargs: Keyword arguments, passed directly to ``function``.
 

	
 
    Return:
 

	
 
        callable: A callable that will evaluate a call to ``function`` with
 
            the specified arguments.
 

	
 
    '''
 

	
 
    NOT_EVALUATED = object()
 
    retval = [NOT_EVALUATED]
 

	
 
    def evaluate():
 
        if retval[0] is NOT_EVALUATED:
 
            retval[0] = function(*args, **kwargs)
 
        return retval[0]
 

	
 
    return evaluate
registrasion/views.py
Show inline comments
 
import sys
 

	
 
from registrasion import forms
 
from registrasion import util
 
from registrasion.models import commerce
 
from registrasion.models import inventory
 
from registrasion.models import people
 
from registrasion.controllers import discount
 
from registrasion.controllers.discount import DiscountController
 
from registrasion.controllers.cart import CartController
 
from registrasion.controllers.credit_note import CreditNoteController
 
from registrasion.controllers.invoice import InvoiceController
 
from registrasion.controllers.product import ProductController
 
from registrasion.exceptions import CartValidationError
 

	
 
from collections import namedtuple
 

	
 
from django.conf import settings
 
from django.contrib.auth.decorators import login_required
 
from django.contrib.auth.decorators import user_passes_test
 
from django.contrib import messages
...
 
@@ -172,51 +172,53 @@ def guided_registration(request):
 

	
 
        available_products = set(ProductController.available_products(
 
            request.user,
 
            products=all_products,
 
        ))
 

	
 
        if len(available_products) == 0:
 
            # We've filled in every category
 
            attendee.completed_registration = True
 
            attendee.save()
 
            return next_step
 

	
 
        for category in cats:
 
            products = [
 
                i for i in available_products
 
                if i.category == category
 
            ]
 

	
 
            prefix = "category_" + str(category.id)
 
            p = _handle_products(request, category, products, prefix)
 
            products_form, discounts, products_handled = p
 

	
 
            section = GuidedRegistrationSection(
 
                title=category.name,
 
                description=category.description,
 
                discounts=discounts,
 
                form=products_form,
 
            )
 
        with CartController.operations_batch(request.user):
 
            for category in cats:
 
                products = [
 
                    i for i in available_products
 
                    if i.category == category
 
                ]
 

	
 
                prefix = "category_" + str(category.id)
 
                p = _handle_products(request, category, products, prefix)
 
                products_form, discounts, products_handled = p
 

	
 
                section = GuidedRegistrationSection(
 
                    title=category.name,
 
                    description=category.description,
 
                    discounts=discounts,
 
                    form=products_form,
 
                )
 

	
 
            if products:
 
                # This product category has items to show.
 
                sections.append(section)
 
                # Add this to the list of things to show if the form errors.
 
                request.session[SESSION_KEY].append(category.id)
 
                if products:
 
                    # This product category has items to show.
 
                    sections.append(section)
 
                    # Add this to the list of things to show if the form
 
                    # errors.
 
                    request.session[SESSION_KEY].append(category.id)
 

	
 
                if request.method == "POST" and not products_form.errors:
 
                    # This is only saved if we pass each form with no errors,
 
                    # and if the form actually has products.
 
                    attendee.guided_categories_complete.add(category)
 
                    if request.method == "POST" and not products_form.errors:
 
                        # This is only saved if we pass each form with no
 
                        # errors, and if the form actually has products.
 
                        attendee.guided_categories_complete.add(category)
 

	
 
    if sections and request.method == "POST":
 
        for section in sections:
 
            if section.form.errors:
 
                break
 
        else:
 
            attendee.save()
 
            if SESSION_KEY in request.session:
 
                del request.session[SESSION_KEY]
 
            # We've successfully processed everything
 
            return next_step
 

	
...
 
@@ -418,40 +420,48 @@ def _handle_products(request, category, products, prefix):
 
            carts = commerce.Cart.objects.filter(user=request.user)
 
            items = commerce.ProductItem.objects.filter(
 
                product__category=category,
 
                cart=carts,
 
            )
 
            if len(items) == 0:
 
                products_form.add_error(
 
                    None,
 
                    "You must have at least one item from this category",
 
                )
 
    handled = False if products_form.errors else True
 

	
 
    discounts = discount.available_discounts(request.user, [], products)
 
    # Making this a function to lazily evaluate when it's displayed
 
    # in templates.
 

	
 
    discounts = util.lazy(
 
        DiscountController.available_discounts,
 
        request.user,
 
        [],
 
        products,
 
    )
 

	
 
    return products_form, discounts, handled
 

	
 

	
 
def _set_quantities_from_products_form(products_form, current_cart):
 

	
 
    quantities = list(products_form.product_quantities())
 

	
 
    id_to_quantity = dict(i[:2] for i in quantities)
 
    pks = [i[0] for i in quantities]
 
    products = inventory.Product.objects.filter(
 
        id__in=pks,
 
    ).select_related("category")
 

	
 
    product_quantities = [
 
        (products.get(pk=i[0]), i[1]) for i in quantities
 
        (product, id_to_quantity[product.id]) for product in products
 
    ]
 
    field_names = dict(
 
        (i[0][0], i[1][2]) for i in zip(product_quantities, quantities)
 
    )
 

	
 
    try:
 
        current_cart.set_quantities(product_quantities)
 
    except CartValidationError as ve:
 
        for ve_field in ve.error_list:
 
            product, message = ve_field.message
 
            if product in field_names:
 
                field = field_names[product]
setup.cfg
Show inline comments
 
[flake8]
 
exclude = registrasion/migrations/*, build/*, docs/*
 
exclude = registrasion/migrations/*, build/*, docs/*, dist/*
0 comments (0 inline, 0 general)