Changeset - 162db2481781
[Not reviewed]
0 9 0
Christopher Neugebauer - 8 years ago 2016-04-29 01:08:45
chrisjrn@gmail.com
Flake8 fixes
9 files changed with 86 insertions and 35 deletions:
0 comments (0 inline, 0 general)
registrasion/controllers/cart.py
Show inline comments
 
import collections
 
import datetime
 
import discount
 
import functools
 
import itertools
 

	
 
from django.core.exceptions import ObjectDoesNotExist
 
from django.core.exceptions import ValidationError
 
from django.db import transaction
 
from django.db.models import Max
 
from django.utils import timezone
 

	
 
from registrasion.exceptions import CartValidationError
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 

	
 
from .category import CategoryController
 
from .conditions import ConditionController
 
from .discount import DiscountController
 
from .flag import FlagController
 
from .product import ProductController
 

	
 

	
 
def _modifies_cart(func):
 
    ''' Decorator that makes the wrapped function raise ValidationError
 
    if we're doing something that could modify the cart. '''
 

	
 
    @functools.wraps(func)
 
    def inner(self, *a, **k):
 
        self._fail_if_cart_is_not_active()
 
        return func(self, *a, **k)
 

	
 
    return inner
 

	
 

	
 
class CartController(object):
 

	
 
    def __init__(self, cart):
 
        self.cart = cart
 

	
 
    @classmethod
 
    def for_user(cls, user):
 
        ''' Returns the user's current cart, or creates a new cart
 
        if there isn't one ready yet. '''
 

	
 
        try:
 
            existing = commerce.Cart.objects.get(
 
                user=user,
 
                status=commerce.Cart.STATUS_ACTIVE,
 
            )
 
        except ObjectDoesNotExist:
 
            existing = commerce.Cart.objects.create(
 
                user=user,
 
                time_last_updated=timezone.now(),
 
                reservation_duration=datetime.timedelta(),
 
            )
 
        return cls(existing)
 

	
 
    def _fail_if_cart_is_not_active(self):
 
        self.cart.refresh_from_db()
 
        if self.cart.status != commerce.Cart.STATUS_ACTIVE:
 
            raise ValidationError("You can only amend active carts.")
 

	
 
    @_modifies_cart
 
    def extend_reservation(self):
 
        ''' Updates the cart's time last updated value, which is used to
 
        determine whether the cart has reserved the items and discounts it
 
        holds. '''
 

	
 
        reservations = [datetime.timedelta()]
 

	
 
        # If we have vouchers, we're entitled to an hour at minimum.
 
        if len(self.cart.vouchers.all()) >= 1:
 
            reservations.append(inventory.Voucher.RESERVATION_DURATION)
 

	
 
        # Else, it's the maximum of the included products
 
        items = commerce.ProductItem.objects.filter(cart=self.cart)
 
        agg = items.aggregate(Max("product__reservation_duration"))
 
        product_max = agg["product__reservation_duration__max"]
 

	
 
        if product_max is not None:
 
            reservations.append(product_max)
 

	
 
        self.cart.time_last_updated = timezone.now()
 
        self.cart.reservation_duration = max(reservations)
 

	
 
    @_modifies_cart
 
    def end_batch(self):
 
        ''' Performs operations that occur occur at the end of a batch of
 
        product changes/voucher applications etc.
 
        THIS SHOULD BE PRIVATE
 
        '''
 

	
 
        self.recalculate_discounts()
 

	
 
        self.extend_reservation()
 
        self.cart.revision += 1
 
        self.cart.save()
...
 
@@ -285,155 +284,159 @@ class CartController(object):
 
    def validate_cart(self):
 
        ''' Determines whether the status of the current cart is valid;
 
        this is normally called before generating or paying an invoice '''
 

	
 
        cart = self.cart
 
        user = self.cart.user
 
        errors = []
 

	
 
        try:
 
            self._test_vouchers(self.cart.vouchers.all())
 
        except ValidationError as ve:
 
            errors.append(ve)
 

	
 
        items = commerce.ProductItem.objects.filter(cart=cart)
 

	
 
        product_quantities = list((i.product, i.quantity) for i in items)
 
        try:
 
            self._test_limits(product_quantities)
 
        except ValidationError as ve:
 
            self._append_errors(errors, ve)
 

	
 
        try:
 
            self._test_required_categories()
 
        except ValidationError as ve:
 
            self._append_errors(errors, ve)
 

	
 
        # Validate the discounts
 
        # TODO: refactor in terms of available_discounts
 
        # why aren't we doing that here?!
 
        discount_items = commerce.DiscountItem.objects.filter(cart=cart)
 
        seen_discounts = set()
 

	
 
        for discount_item in discount_items:
 
            discount = discount_item.discount
 
            if discount in seen_discounts:
 
                continue
 
            seen_discounts.add(discount)
 
            real_discount = conditions.DiscountBase.objects.get_subclass(
 
                pk=discount.pk)
 
            cond = ConditionController.for_condition(real_discount)
 

	
 
            if not cond.is_met(user):
 
                errors.append(
 
                    ValidationError("Discounts are no longer available")
 
                )
 

	
 
        if errors:
 
            raise ValidationError(errors)
 

	
 
    @_modifies_cart
 
    @transaction.atomic
 
    def fix_simple_errors(self):
 
        ''' This attempts to fix the easy errors raised by ValidationError.
 
        This includes removing items from the cart that are no longer
 
        available, recalculating all of the discounts, and removing voucher
 
        codes that are no longer available. '''
 

	
 
        # Fix vouchers first (this affects available discounts)
 
        to_remove = []
 
        for voucher in self.cart.vouchers.all():
 
            try:
 
                self._test_voucher(voucher)
 
            except ValidationError:
 
                to_remove.append(voucher)
 

	
 
        for voucher in to_remove:
 
            self.cart.vouchers.remove(voucher)
 

	
 
        # Fix products and discounts
 
        items = commerce.ProductItem.objects.filter(cart=self.cart)
 
        items = items.select_related("product")
 
        products = set(i.product for i in items)
 
        available = set(ProductController.available_products(
 
            self.cart.user,
 
            products=products,
 
        ))
 

	
 
        not_available = products - available
 
        zeros = [(product, 0) for product in not_available]
 

	
 
        self.set_quantities(zeros)
 

	
 
    @_modifies_cart
 
    @transaction.atomic
 
    def recalculate_discounts(self):
 
        ''' Calculates all of the discounts available for this product.
 
        '''
 

	
 
        # Delete the existing entries.
 
        commerce.DiscountItem.objects.filter(cart=self.cart).delete()
 

	
 
        product_items = self.cart.productitem_set.all().select_related(
 
            "product", "product__category", "product__price"
 
        )
 

	
 
        products = [i.product for i in product_items]
 
        discounts = DiscountController.available_discounts(self.cart.user, [], products)
 
        discounts = DiscountController.available_discounts(
 
            self.cart.user,
 
            [],
 
            products,
 
        )
 

	
 
        # The highest-value discounts will apply to the highest-value
 
        # products first.
 
        product_items = reversed(product_items)
 
        for item in product_items:
 
            self._add_discount(item.product, item.quantity, discounts)
 

	
 
    def _add_discount(self, product, quantity, discounts):
 
        ''' Applies the best discounts on the given product, from the given
 
        discounts.'''
 

	
 
        def matches(discount):
 
            ''' Returns True if and only if the given discount apples to
 
            our product. '''
 
            if isinstance(discount.clause, conditions.DiscountForCategory):
 
                return discount.clause.category == product.category
 
            else:
 
                return discount.clause.product == product
 

	
 
        def value(discount):
 
            ''' Returns the value of this discount clause
 
            as applied to this product '''
 
            if discount.clause.percentage is not None:
 
                return discount.clause.percentage * product.price
 
            else:
 
                return discount.clause.price
 

	
 
        discounts = [i for i in discounts if matches(i)]
 
        discounts.sort(key=value)
 

	
 
        for candidate in reversed(discounts):
 
            if quantity == 0:
 
                break
 
            elif candidate.quantity == 0:
 
                # This discount clause has been exhausted by this cart
 
                continue
 

	
 
            # Get a provisional instance for this DiscountItem
 
            # with the quantity set to as much as we have in the cart
 
            discount_item = commerce.DiscountItem.objects.create(
 
                product=product,
 
                cart=self.cart,
 
                discount=candidate.discount,
 
                quantity=quantity,
 
            )
 

	
 
            # Truncate the quantity for this DiscountItem if we exceed quantity
 
            ours = discount_item.quantity
 
            allowed = candidate.quantity
 
            if ours > allowed:
 
                discount_item.quantity = allowed
 
                discount_item.save()
 
                # Update the remaining quantity.
 
                quantity = ours - allowed
 
            else:
 
                quantity = 0
 

	
 
            candidate.quantity -= discount_item.quantity
registrasion/controllers/conditions.py
Show inline comments
 
import itertools
 
import operator
 

	
 
from collections import defaultdict
 
from collections import namedtuple
 

	
 
from django.db.models import Case
 
from django.db.models import Count
 
from django.db.models import F, Q
 
from django.db.models import Sum
 
from django.db.models import Value
 
from django.db.models import When
 
from django.utils import timezone
 

	
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 

	
 

	
 

	
 
_BIG_QUANTITY = 99999999  # A big quantity
 

	
 

	
 
class ConditionController(object):
 
    ''' Base class for testing conditions that activate Flag
 
    or Discount objects. '''
 

	
 
    def __init__(self, condition):
 
        self.condition = condition
 

	
 
    @staticmethod
 
    def _controllers():
 
        return {
 
            conditions.CategoryFlag: CategoryConditionController,
 
            conditions.IncludedProductDiscount: ProductConditionController,
 
            conditions.ProductFlag: ProductConditionController,
 
            conditions.TimeOrStockLimitDiscount:
 
                TimeOrStockLimitDiscountController,
 
            conditions.TimeOrStockLimitFlag:
 
                TimeOrStockLimitFlagController,
 
            conditions.VoucherDiscount: VoucherConditionController,
 
            conditions.VoucherFlag: VoucherConditionController,
 
        }
 

	
 
    @staticmethod
 
    def for_type(cls):
 
        return ConditionController._controllers()[cls]
 

	
 
    @staticmethod
 
    def for_condition(condition):
 
        try:
 
            return ConditionController.for_type(type(condition))(condition)
 
        except KeyError:
 
            return ConditionController()
 

	
 
    @classmethod
 
    def pre_filter(cls, queryset, user):
 
        ''' Returns only the flag conditions that might be available for this
 
        user. It should hopefully reduce the number of queries that need to be
 
        executed to determine if a flag is met.
 

	
 
        If this filtration implements the same query as is_met, then you should
 
        be able to implement ``is_met()`` in terms of this.
 

	
 
        Arguments:
 

	
 
            queryset (Queryset[c]): The canditate conditions.
 

	
 
            user (User): The user for whom we're testing these conditions.
 

	
 
        Returns:
 
            Queryset[c]: A subset of the conditions that pass the pre-filter
 
                test for this user.
 

	
 
        '''
 

	
 
        # Default implementation does NOTHING.
 
        return queryset
 

	
 
    def passes_filter(self, user):
 
        ''' Returns true if the condition passes the filter '''
 

	
 
        cls = type(self.condition)
 
        qs = cls.objects.filter(pk=self.condition.id)
 
        return self.condition in self.pre_filter(qs, user)
 

	
 
    def user_quantity_remaining(self, user, filtered=False):
 
        ''' Returns the number of items covered by this flag condition the
 
        user can add to the current cart. This default implementation returns
 
        a big number if is_met() is true, otherwise 0.
 

	
 
        Either this method, or is_met() must be overridden in subclasses.
 
        '''
 

	
 
        return _BIG_QUANTITY if self.is_met(user, filtered) else 0
 

	
 
    def is_met(self, user, filtered=False):
 
        ''' Returns True if this flag condition is met, otherwise returns
 
        False.
 

	
 
        Either this method, or user_quantity_remaining() must be overridden
 
        in subclasses.
 

	
 
        Arguments:
...
 
@@ -141,160 +136,160 @@ class RemainderSetByFilter(object):
 

	
 

	
 

	
 
        # Mark self.condition with a remainder
 
        qs = type(self.condition).objects.filter(pk=self.condition.id)
 
        qs = self.pre_filter(qs, user)
 

	
 
        if len(qs) > 0:
 
            return qs[0].remainder
 
        else:
 
            return 0
 

	
 

	
 
class CategoryConditionController(IsMetByFilter, ConditionController):
 

	
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has a
 
        product from a category invoking that item's condition in one of their
 
        carts. '''
 

	
 
        items = commerce.ProductItem.objects.filter(cart__user=user)
 
        items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED)
 
        items = items.select_related("product", "product__category")
 
        categories = [item.product.category for item in items]
 

	
 
        return queryset.filter(enabling_category__in=categories)
 

	
 

	
 
class ProductConditionController(IsMetByFilter, ConditionController):
 
    ''' Condition tests for ProductFlag and
 
    IncludedProductDiscount. '''
 

	
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has a
 
        product invoking that item's condition in one of their carts. '''
 

	
 
        items = commerce.ProductItem.objects.filter(cart__user=user)
 
        items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED)
 
        items = items.select_related("product", "product__category")
 
        products = [item.product for item in items]
 

	
 
        return queryset.filter(enabling_products__in=products)
 

	
 

	
 
class TimeOrStockLimitConditionController(
 
        RemainderSetByFilter,
 
        ConditionController,
 
    ):
 
    ''' Common condition tests for TimeOrStockLimit Flag and
 
    Discount.'''
 

	
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the date falls into
 
        any specified range, but not yet where the stock limit is not yet
 
        reached.'''
 

	
 
        now = timezone.now()
 

	
 
        # Keep items with no start time, or start time not yet met.
 
        queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now))
 
        queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now))
 

	
 
        # Filter out items that have been reserved beyond the limits
 
        quantity_or_zero = self._calculate_quantities(user)
 

	
 
        remainder = Case(
 
            When(limit=None, then=Value(_BIG_QUANTITY)),
 
            default=F("limit") - Sum(quantity_or_zero),
 
        )
 

	
 
        queryset = queryset.annotate(remainder=remainder)
 
        queryset = queryset.filter(remainder__gt=0)
 

	
 
        return queryset
 

	
 
    @classmethod
 
    def _relevant_carts(cls, user):
 
        reserved_carts = commerce.Cart.reserved_carts()
 
        reserved_carts = reserved_carts.exclude(
 
            user=user,
 
            status=commerce.Cart.STATUS_ACTIVE,
 
        )
 
        return reserved_carts
 

	
 

	
 
class TimeOrStockLimitFlagController(
 
        TimeOrStockLimitConditionController):
 

	
 
    @classmethod
 
    def _calculate_quantities(cls, user):
 
        reserved_carts = cls._relevant_carts(user)
 

	
 
        # Calculate category lines
 
        cat_items = F('categories__product__productitem__product__category')
 
        item_cats = F('categories__product__productitem__product__category')
 
        reserved_category_products = (
 
            Q(categories=F('categories__product__productitem__product__category')) &
 
            Q(categories=item_cats) &
 
            Q(categories__product__productitem__cart__in=reserved_carts)
 
        )
 

	
 
        # Calculate product lines
 
        reserved_products = (
 
            Q(products=F('products__productitem__product')) &
 
            Q(products__productitem__cart__in=reserved_carts)
 
        )
 

	
 
        category_quantity_in_reserved_carts = When(
 
            reserved_category_products,
 
            then="categories__product__productitem__quantity",
 
        )
 

	
 
        product_quantity_in_reserved_carts = When(
 
            reserved_products,
 
            then="products__productitem__quantity",
 
        )
 

	
 
        quantity_or_zero = Case(
 
            category_quantity_in_reserved_carts,
 
            product_quantity_in_reserved_carts,
 
            default=Value(0),
 
        )
 

	
 
        return quantity_or_zero
 

	
 

	
 
class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController):
 

	
 
    @classmethod
 
    def _calculate_quantities(cls, user):
 
        reserved_carts = cls._relevant_carts(user)
 

	
 
        quantity_in_reserved_carts = When(
 
            discountitem__cart__in=reserved_carts,
 
            then="discountitem__quantity"
 
        )
 

	
 
        quantity_or_zero = Case(
 
            quantity_in_reserved_carts,
 
            default=Value(0)
 
        )
 

	
 
        return quantity_or_zero
 

	
 

	
 
class VoucherConditionController(IsMetByFilter, ConditionController):
 
    ''' Condition test for VoucherFlag and VoucherDiscount.'''
 

	
 
    @classmethod
 
    def pre_filter(self, queryset, user):
 
        ''' Returns all of the items from queryset where the user has entered
 
        a voucher that invokes that item's condition in one of their carts. '''
 

	
 
        carts = commerce.Cart.objects.filter(
 
            user=user,
 
        )
 
        vouchers = [cart.vouchers.all() for cart in carts]
 

	
 
        return queryset.filter(voucher__in=itertools.chain(*vouchers))
registrasion/controllers/discount.py
Show inline comments
 
import itertools
 

	
 
from conditions import ConditionController
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 

	
 
from django.db.models import Case
 
from django.db.models import Q
 
from django.db.models import Sum
 
from django.db.models import Value
 
from django.db.models import When
 

	
 

	
 
class DiscountAndQuantity(object):
 
    ''' Represents a discount that can be applied to a product or category
 
    for a given user.
 

	
 
    Attributes:
 

	
 
        discount (conditions.DiscountBase): The discount object that the
 
            clause arises from. A given DiscountBase can apply to multiple
 
            clauses.
 

	
 
        clause (conditions.DiscountForProduct|conditions.DiscountForCategory):
 
            A clause describing which product or category this discount item
 
            applies to. This casts to ``str()`` to produce a human-readable
 
            version of the clause.
 

	
 
        quantity (int): The number of times this discount item can be applied
 
            for the given user.
 

	
 
    '''
 

	
 
    def __init__(self, discount, clause, quantity):
 
        self.discount = discount
 
        self.clause = clause
 
        self.quantity = quantity
 

	
 
    def __repr__(self):
 
        return "(discount=%s, clause=%s, quantity=%d)" % (
 
            self.discount, self.clause, self.quantity,
 
        )
 

	
 

	
 
class DiscountController(object):
 

	
 
    @classmethod
 
    def available_discounts(cls, user, categories, products):
 
        ''' Returns all discounts available to this user for the given categories
 
        and products. The discounts also list the available quantity for this user,
 
        not including products that are pending purchase. '''
 
        ''' Returns all discounts available to this user for the given
 
        categories and products. The discounts also list the available quantity
 
        for this user, not including products that are pending purchase. '''
 

	
 

	
 
        filtered_clauses = cls._filtered_discounts(user, categories, products)
 

	
 
        discounts = []
 

	
 
        # Markers so that we don't need to evaluate given conditions more than once
 
        # Markers so that we don't need to evaluate given conditions
 
        # more than once
 
        accepted_discounts = set()
 
        failed_discounts = set()
 

	
 
        for clause in filtered_clauses:
 
            discount = clause.discount
 
            cond = ConditionController.for_condition(discount)
 

	
 
            past_use_count = discount.past_use_count
 

	
 

	
 
            if past_use_count >= clause.quantity:
 
                # This clause has exceeded its use count
 
                pass
 
            elif discount not in failed_discounts:
 
                # This clause is still available
 
                if discount in accepted_discounts or cond.is_met(user, filtered=True):
 
                is_accepted = discount in accepted_discounts
 
                if is_accepted or cond.is_met(user, filtered=True):
 
                    # This clause is valid for this user
 
                    discounts.append(DiscountAndQuantity(
 
                        discount=discount,
 
                        clause=clause,
 
                        quantity=clause.quantity - past_use_count,
 
                    ))
 
                    accepted_discounts.add(discount)
 
                else:
 
                    # This clause is not valid for this user
 
                    failed_discounts.add(discount)
 
        return discounts
 

	
 
    @classmethod
 
    def _filtered_discounts(cls, user, categories, products):
 
        '''
 

	
 
        Returns:
 
            Sequence[discountbase]: All discounts that passed the filter function.
 
            Sequence[discountbase]: All discounts that passed the filter
 
            function.
 

	
 
        '''
 

	
 
        types = list(ConditionController._controllers())
 
        discounttypes = [i for i in types if issubclass(i, conditions.DiscountBase)]
 
        discounttypes = [
 
            i for i in types if issubclass(i, conditions.DiscountBase)
 
        ]
 

	
 
        # discounts that match provided categories
 
        category_discounts = conditions.DiscountForCategory.objects.filter(
 
            category__in=categories
 
        )
 
        # discounts that match provided products
 
        product_discounts = conditions.DiscountForProduct.objects.filter(
 
            product__in=products
 
        )
 
        # discounts that match categories for provided products
 
        product_category_discounts = conditions.DiscountForCategory.objects.filter(
 
        product_category_discounts = conditions.DiscountForCategory.objects
 
        product_category_discounts = product_category_discounts.filter(
 
            category__in=(product.category for product in products)
 
        )
 
        # (Not relevant: discounts that match products in provided categories)
 

	
 
        product_discounts = product_discounts.select_related(
 
            "product",
 
            "product__category",
 
        )
 

	
 
        all_category_discounts = category_discounts | product_category_discounts
 
        all_category_discounts = (
 
            category_discounts | product_category_discounts
 
        )
 
        all_category_discounts = all_category_discounts.select_related(
 
            "category",
 
        )
 

	
 
        valid_discounts = conditions.DiscountBase.objects.filter(
 
            Q(discountforproduct__in=product_discounts) |
 
            Q(discountforcategory__in=all_category_discounts)
 
        )
 

	
 
        all_subsets = []
 

	
 
        for discounttype in discounttypes:
 
            discounts = discounttype.objects.filter(id__in=valid_discounts)
 
            ctrl = ConditionController.for_type(discounttype)
 
            discounts = ctrl.pre_filter(discounts, user)
 
            discounts = cls._annotate_with_past_uses(discounts, user)
 
            all_subsets.append(discounts)
 

	
 
        filtered_discounts = list(itertools.chain(*all_subsets))
 

	
 
        # Map from discount key to itself (contains annotations added by filter)
 
        # Map from discount key to itself
 
        # (contains annotations needed in the future)
 
        from_filter = dict((i.id, i) for i in filtered_discounts)
 

	
 
        # The set of all potential discounts
 
        discount_clauses = set(itertools.chain(
 
            product_discounts.filter(discount__in=filtered_discounts),
 
            all_category_discounts.filter(discount__in=filtered_discounts),
 
        ))
 

	
 
        # Replace discounts with the filtered ones
 
        # These are the correct subclasses (saves query later on), and have
 
        # correct annotations from filters if necessary.
 
        for clause in discount_clauses:
 
            clause.discount = from_filter[clause.discount.id]
 

	
 
        return discount_clauses
 

	
 
    @classmethod
 
    def _annotate_with_past_uses(cls, queryset, user):
 
        ''' Annotates the queryset with a usage count for that discount by the
 
        given user. '''
 

	
 
        past_use_quantity = When(
 
            (
 
                Q(discountitem__cart__user=user) &
 
                Q(discountitem__cart__status=commerce.Cart.STATUS_PAID)
 
            ),
 
            then="discountitem__quantity",
 
        )
 

	
 
        past_use_quantity_or_zero = Case(
 
            past_use_quantity,
 
            default=Value(0),
 
        )
 

	
 
        queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero))
 
        queryset = queryset.annotate(
 
            past_use_count=Sum(past_use_quantity_or_zero)
 
        )
 
        return queryset
registrasion/controllers/flag.py
Show inline comments
...
 
@@ -135,123 +135,130 @@ class FlagController(object):
 
            if f.dif > 0 and f.dif != dif_count[product]:
 
                do_not_disable[product] = False
 
                if product not in messages:
 
                    messages[product] = "Some disable-if-false " \
 
                                        "conditions were not met"
 
            if f.eit > 0 and product not in do_enable:
 
                do_enable[product] = False
 
                if product not in messages:
 
                    messages[product] = "Some enable-if-true " \
 
                                        "conditions were not met"
 

	
 
        for product in itertools.chain(do_not_disable, do_enable):
 
            f = total_flags.get(product)
 
            if product in do_enable:
 
                # If there's an enable-if-true, we need need of those met too.
 
                # (do_not_disable will default to true otherwise)
 
                valid[product] = do_not_disable[product] and do_enable[product]
 
            elif product in do_not_disable:
 
                # If there's a disable-if-false condition, all must be met
 
                valid[product] = do_not_disable[product]
 

	
 
        error_fields = [
 
            (product, messages[product])
 
            for product in valid if not valid[product]
 
        ]
 

	
 
        return error_fields
 

	
 
    @classmethod
 
    def _filtered_flags(cls, user, products):
 
        '''
 

	
 
        Returns:
 
            Sequence[flagbase]: All flags that passed the filter function.
 

	
 
        '''
 

	
 
        types = list(ConditionController._controllers())
 
        flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
 

	
 
        # Get all flags for the products and categories.
 
        prods = (
 
            product.flagbase_set.all()
 
            for product in products
 
        )
 
        cats = (
 
            category.flagbase_set.all()
 
            for category in set(product.category for product in products)
 
        )
 
        all_flags = reduce(operator.or_, itertools.chain(prods, cats))
 

	
 
        all_subsets = []
 

	
 
        for flagtype in flagtypes:
 
            flags = flagtype.objects.filter(id__in=all_flags)
 
            ctrl = ConditionController.for_type(flagtype)
 
            flags = ctrl.pre_filter(flags, user)
 
            all_subsets.append(flags)
 

	
 
        return itertools.chain(*all_subsets)
 

	
 

	
 
ConditionAndRemainder = namedtuple(
 
    "ConditionAndRemainder",
 
    (
 
        "condition",
 
        "remainder",
 
    ),
 
)
 

	
 

	
 
_FlagCounter = namedtuple(
 
    "_FlagCounter",
 
    (
 
        "products",
 
        "categories",
 
    ),
 
)
 

	
 

	
 
_ConditionsCount = namedtuple(
 
    "ConditionsCount",
 
    (
 
        "dif",
 
        "eit",
 
    ),
 
)
 

	
 

	
 
class FlagCounter(_FlagCounter):
 

	
 
    @classmethod
 
    def count(cls):
 
        # Get the count of how many conditions should exist per product
 
        flagbases = conditions.FlagBase.objects
 

	
 
        types = (conditions.FlagBase.ENABLE_IF_TRUE, conditions.FlagBase.DISABLE_IF_FALSE)
 
        types = (
 
            conditions.FlagBase.ENABLE_IF_TRUE,
 
            conditions.FlagBase.DISABLE_IF_FALSE,
 
        )
 
        keys = ("eit", "dif")
 
        flags = [
 
            flagbases.filter(condition=condition_type
 
            ).values('products', 'categories'
 
            ).annotate(count=Count('id'))
 
            flagbases.filter(
 
                condition=condition_type
 
            ).values(
 
                'products', 'categories'
 
            ).annotate(
 
                count=Count('id')
 
            )
 
            for condition_type in types
 
        ]
 

	
 
        cats = defaultdict(lambda: defaultdict(int))
 
        prod = defaultdict(lambda: defaultdict(int))
 

	
 
        for key, flagcounts in zip(keys, flags):
 
            for row in flagcounts:
 
                if row["products"] is not None:
 
                    prod[row["products"]][key] = row["count"]
 
                if row["categories"] is not None:
 
                    cats[row["categories"]][key] = row["count"]
 

	
 
        return cls(products=prod, categories=cats)
 

	
 
    def get(self, product):
 
        p = self.products[product.id]
 
        c = self.categories[product.category.id]
 
        eit = p["eit"] + c["eit"]
 
        dif = p["dif"] + c["dif"]
 
        return _ConditionsCount(dif=dif, eit=eit)
registrasion/controllers/product.py
Show inline comments
 
import itertools
 

	
 
from django.db.models import Sum
 
from registrasion.models import commerce
 
from registrasion.models import inventory
 

	
 
from .category import CategoryController
 
from .conditions import ConditionController
 
from .flag import FlagController
 

	
 

	
 
class ProductController(object):
 

	
 
    def __init__(self, product):
 
        self.product = product
 

	
 
    @classmethod
 
    def available_products(cls, user, category=None, products=None):
 
        ''' Returns a list of all of the products that are available per
 
        flag conditions from the given categories.
 
        TODO: refactor so that all conditions are tested here and
 
        can_add_with_flags calls this method. '''
 
        if category is None and products is None:
 
            raise ValueError("You must provide products or a category")
 

	
 
        if category is not None:
 
            all_products = inventory.Product.objects.filter(category=category)
 
            all_products = all_products.select_related("category")
 
        else:
 
            all_products = []
 

	
 
        if products is not None:
 
            all_products = set(itertools.chain(all_products, products))
 

	
 
        cat_quants = dict(
 
            (
 
                category,
 
                CategoryController(category).user_quantity_remaining(user),
 
            )
 
            for category in set(product.category for product in all_products)
 
        )
 

	
 
        passed_limits = set(
 
            product
 
            for product in all_products
 
            if cat_quants[product.category] > 0
 
            if cls(product).user_quantity_remaining(user) > 0
 
        )
 

	
 
        failed_and_messages = FlagController.test_flags(
 
            user, products=passed_limits
 
        )
 
        failed_conditions = set(i[0] for i in failed_and_messages)
 

	
 
        out = list(passed_limits - failed_conditions)
 
        out.sort(key=lambda product: product.order)
 

	
 
        return out
 

	
 
    def user_quantity_remaining(self, user):
 
        ''' Returns the quantity of this product that the user add in the
 
        current cart. '''
 

	
 
        prod_limit = self.product.limit_per_user
 

	
 
        if prod_limit is None:
 
            # Don't need to run the remaining queries
 
            return 999999  # We can do better
 

	
 
        carts = commerce.Cart.objects.filter(
 
            user=user,
 
            status=commerce.Cart.STATUS_PAID,
 
        )
 

	
 
        items = commerce.ProductItem.objects.filter(
 
            cart__in=carts,
 
            product=self.product,
 
        )
 

	
 
        prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
 

	
 
        return prod_limit - prod_count
registrasion/tests/test_cart.py
Show inline comments
 
import datetime
 
import pytz
 

	
 
from decimal import Decimal
 
from django.contrib.auth.models import User
 
from django.core.exceptions import ObjectDoesNotExist
 
from django.core.exceptions import ValidationError
 
from django.core.management import call_command
 
from django.test import TestCase
 

	
 
from registrasion.models import commerce
 
from registrasion.models import conditions
 
from registrasion.models import inventory
 
from registrasion.models import people
 
from registrasion.controllers.product import ProductController
 

	
 
from controller_helpers import TestingCartController
 
from patch_datetime import SetTimeMixin
 

	
 
UTC = pytz.timezone('UTC')
 

	
 

	
 
class RegistrationCartTestCase(SetTimeMixin, TestCase):
 

	
 
    def setUp(self):
 
        super(RegistrationCartTestCase, self).setUp()
 

	
 
    def tearDown(self):
 
        if False:
 
        if True:
 
            # If you're seeing segfaults in tests, enable this.
 
            call_command(
 
                'flush',
 
                verbosity=0,
 
                interactive=False,
 
                reset_sequences=False,
 
                allow_cascade=False,
 
                inhibit_post_migrate=False
 
            )
 

	
 
        super(RegistrationCartTestCase, self).tearDown()
 

	
 
    @classmethod
 
    def setUpTestData(cls):
 

	
 
        super(RegistrationCartTestCase, cls).setUpTestData()
 

	
 
        cls.USER_1 = User.objects.create_user(
 
            username='testuser',
 
            email='test@example.com',
 
            password='top_secret')
 

	
 
        cls.USER_2 = User.objects.create_user(
 
            username='testuser2',
 
            email='test2@example.com',
 
            password='top_secret')
 

	
 
        attendee1 = people.Attendee.get_instance(cls.USER_1)
 
        people.AttendeeProfileBase.objects.create(
 
            attendee=attendee1,
 
        )
 
        attendee2 = people.Attendee.get_instance(cls.USER_2)
 
        people.AttendeeProfileBase.objects.create(
 
            attendee=attendee2,
 
        )
 

	
 
        cls.RESERVATION = datetime.timedelta(hours=1)
 

	
 
        cls.categories = []
 
        for i in xrange(2):
 
            cat = inventory.Category.objects.create(
 
                name="Category " + str(i + 1),
 
                description="This is a test category",
 
                order=i,
 
                render_type=inventory.Category.RENDER_TYPE_RADIO,
 
                required=False,
 
            )
 
            cls.categories.append(cat)
 

	
 
        cls.CAT_1 = cls.categories[0]
 
        cls.CAT_2 = cls.categories[1]
 

	
 
        cls.products = []
 
        for i in xrange(4):
 
            prod = inventory.Product.objects.create(
 
                name="Product " + str(i + 1),
 
                description="This is a test product.",
 
                category=cls.categories[i / 2],  # 2 products per category
 
                price=Decimal("10.00"),
 
                reservation_duration=cls.RESERVATION,
 
                limit_per_user=10,
 
                order=1,
 
            )
 
            cls.products.append(prod)
 

	
 
        cls.PROD_1 = cls.products[0]
 
        cls.PROD_2 = cls.products[1]
 
        cls.PROD_3 = cls.products[2]
 
        cls.PROD_4 = cls.products[3]
 

	
 
        cls.PROD_4.price = Decimal("5.00")
 
        cls.PROD_4.save()
 

	
 
        # Burn through some carts -- this made some past flag tests fail
 
        current_cart = TestingCartController.for_user(cls.USER_1)
 

	
 
        current_cart.next_cart()
 

	
 
        current_cart = TestingCartController.for_user(cls.USER_2)
 

	
 
        current_cart.next_cart()
 

	
 
    @classmethod
 
    def make_ceiling(cls, name, limit=None, start_time=None, end_time=None):
 
        limit_ceiling = conditions.TimeOrStockLimitFlag.objects.create(
 
            description=name,
 
            condition=conditions.FlagBase.DISABLE_IF_FALSE,
 
            limit=limit,
 
            start_time=start_time,
 
            end_time=end_time
 
        )
 
        limit_ceiling.products.add(cls.PROD_1, cls.PROD_2)
 

	
 
    @classmethod
 
    def make_category_ceiling(
 
            cls, name, limit=None, start_time=None, end_time=None):
registrasion/tests/test_ceilings.py
Show inline comments
...
 
@@ -56,164 +56,168 @@ class CeilingsTestCases(RegistrationCartTestCase):
 
        self.set_time(datetime.datetime(2014, 01, 01, tzinfo=UTC))
 
        with self.assertRaises(ValidationError):
 
            current_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # User should be able to add whilst we're during date range
 
        # On edge of start
 
        self.set_time(datetime.datetime(2015, 01, 01, tzinfo=UTC))
 
        current_cart.add_to_cart(self.PROD_1, 1)
 
        # In middle
 
        self.set_time(datetime.datetime(2015, 01, 15, tzinfo=UTC))
 
        current_cart.add_to_cart(self.PROD_1, 1)
 
        # On edge of end
 
        self.set_time(datetime.datetime(2015, 02, 01, tzinfo=UTC))
 
        current_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # User should not be able to add whilst we're after date range
 
        self.set_time(datetime.datetime(2014, 01, 01, minute=01, tzinfo=UTC))
 
        with self.assertRaises(ValidationError):
 
            current_cart.add_to_cart(self.PROD_1, 1)
 

	
 
    def test_add_to_cart_ceiling_limit_reserved_carts(self):
 
        self.make_ceiling("Limit ceiling", limit=1)
 

	
 
        self.set_time(datetime.datetime(2015, 01, 01, tzinfo=UTC))
 

	
 
        first_cart = TestingCartController.for_user(self.USER_1)
 
        second_cart = TestingCartController.for_user(self.USER_2)
 

	
 
        first_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # User 2 should not be able to add item to their cart
 
        # because user 1 has item reserved, exhausting the ceiling
 
        with self.assertRaises(ValidationError):
 
            second_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # User 2 should be able to add item to their cart once the
 
        # reservation duration is elapsed
 
        self.add_timedelta(self.RESERVATION + datetime.timedelta(seconds=1))
 
        second_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # User 2 pays for their cart
 

	
 
        second_cart.next_cart()
 

	
 
        # User 1 should not be able to add item to their cart
 
        # because user 2 has paid for their reserved item, exhausting
 
        # the ceiling, regardless of the reservation time.
 
        self.add_timedelta(self.RESERVATION * 20)
 
        with self.assertRaises(ValidationError):
 
            first_cart.add_to_cart(self.PROD_1, 1)
 

	
 
    def test_validate_cart_fails_product_ceilings(self):
 
        self.make_ceiling("Limit ceiling", limit=1)
 
        self.__validation_test()
 

	
 
    def test_validate_cart_fails_product_discount_ceilings(self):
 
        self.make_discount_ceiling("Limit ceiling", limit=1)
 
        self.__validation_test()
 

	
 
    def __validation_test(self):
 
        self.set_time(datetime.datetime(2015, 01, 01, tzinfo=UTC))
 

	
 
        first_cart = TestingCartController.for_user(self.USER_1)
 
        second_cart = TestingCartController.for_user(self.USER_2)
 

	
 
        # Adding a valid product should validate.
 
        first_cart.add_to_cart(self.PROD_1, 1)
 
        first_cart.validate_cart()
 

	
 
        # Cart should become invalid if lapsed carts are claimed.
 
        self.add_timedelta(self.RESERVATION + datetime.timedelta(seconds=1))
 

	
 
        # Unpaid cart within reservation window
 
        second_cart.add_to_cart(self.PROD_1, 1)
 
        with self.assertRaises(ValidationError):
 
            first_cart.validate_cart()
 

	
 
        # Paid cart outside the reservation window
 

	
 
        second_cart.next_cart()
 
        self.add_timedelta(self.RESERVATION + datetime.timedelta(seconds=1))
 
        with self.assertRaises(ValidationError):
 
            first_cart.validate_cart()
 

	
 
    def test_discount_ceiling_aggregates_products(self):
 
        # Create two carts, add 1xprod_1 to each. Ceiling should disappear
 
        # after second.
 
        self.make_discount_ceiling(
 
            "Multi-product limit discount ceiling",
 
            limit=2,
 
        )
 
        for i in xrange(2):
 
            cart = TestingCartController.for_user(self.USER_1)
 
            cart.add_to_cart(self.PROD_1, 1)
 
            cart.next_cart()
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [], [self.PROD_1])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_1],
 
        )
 

	
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_flag_ceiling_aggregates_products(self):
 
        # Create two carts, add 1xprod_1 to each. Ceiling should disappear
 
        # after second.
 
        self.make_ceiling("Multi-product limit ceiling", limit=2)
 

	
 
        for i in xrange(2):
 
            cart = TestingCartController.for_user(self.USER_1)
 
            cart.add_to_cart(self.PROD_1, 1)
 
            cart.next_cart()
 

	
 
        products = ProductController.available_products(
 
            self.USER_1,
 
            products=[self.PROD_1],
 
        )
 

	
 
        self.assertEqual(0, len(products))
 

	
 
    def test_items_released_from_ceiling_by_refund(self):
 
        self.make_ceiling("Limit ceiling", limit=1)
 

	
 
        first_cart = TestingCartController.for_user(self.USER_1)
 
        first_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        first_cart.next_cart()
 

	
 
        second_cart = TestingCartController.for_user(self.USER_2)
 
        with self.assertRaises(ValidationError):
 
            second_cart.add_to_cart(self.PROD_1, 1)
 

	
 
        first_cart.cart.status = commerce.Cart.STATUS_RELEASED
 
        first_cart.cart.save()
 

	
 
        second_cart.add_to_cart(self.PROD_1, 1)
 

	
 
    def test_discount_ceiling_only_counts_items_covered_by_ceiling(self):
 
        self.make_discount_ceiling("Limit ceiling", limit=1, percentage=50)
 
        voucher = self.new_voucher(code="VOUCHER")
 

	
 
        discount = conditions.VoucherDiscount.objects.create(
 
            description="VOUCHER RECIPIENT",
 
            voucher=voucher,
 
        )
 
        conditions.DiscountForProduct.objects.create(
 
            discount=discount,
 
            product=self.PROD_1,
 
            percentage=100,
 
            quantity=1
 
        )
 

	
 
        # Buy two of PROD_1, in separate carts:
 
        cart = TestingCartController.for_user(self.USER_1)
 
        # the 100% discount from the voucher should apply to the first item
 
        # and not the ceiling discount.
 
        cart.apply_voucher("VOUCHER")
 
        cart.add_to_cart(self.PROD_1, 1)
 
        self.assertEqual(1, cart.cart.discountitem_set.count())
 

	
 
        cart.next_cart()
 

	
 
        # The second cart has no voucher attached, so should apply the
 
        # ceiling discount
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)
 
        self.assertEqual(1, cart.cart.discountitem_set.count())
registrasion/tests/test_discount.py
Show inline comments
...
 
@@ -152,255 +152,283 @@ class DiscountTestCase(RegistrationCartTestCase):
 
        # The half discount should be applied only once
 
        self.assertEqual(1, discount_items[0].quantity)
 
        self.assertEqual(discount_half.pk, discount_items[0].discount.pk)
 
        # The full discount should be applied twice
 
        self.assertEqual(2, discount_items[1].quantity)
 
        self.assertEqual(discount_full.pk, discount_items[1].discount.pk)
 

	
 
    def test_discount_applies_across_carts(self):
 
        self.add_discount_prod_1_includes_prod_2()
 

	
 
        # Enable the discount during the first cart.
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)
 

	
 
        cart.next_cart()
 

	
 
        # Use the discount in the second cart
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_2, 1)
 

	
 
        # The discount should be applied.
 
        self.assertEqual(1, len(cart.cart.discountitem_set.all()))
 

	
 
        cart.next_cart()
 

	
 
        # The discount should respect the total quantity across all
 
        # of the user's carts.
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_2, 2)
 

	
 
        # Having one item in the second cart leaves one more item where
 
        # the discount is applicable. The discount should apply, but only for
 
        # quantity=1
 
        discount_items = list(cart.cart.discountitem_set.all())
 
        self.assertEqual(1, discount_items[0].quantity)
 

	
 
    def test_discount_applies_only_once_enabled(self):
 
        # Enable the discount during the first cart.
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)
 
        # This would exhaust discount if present
 
        cart.add_to_cart(self.PROD_2, 2)
 

	
 
        cart.next_cart()
 

	
 
        self.add_discount_prod_1_includes_prod_2()
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_2, 2)
 

	
 
        discount_items = list(cart.cart.discountitem_set.all())
 
        self.assertEqual(2, discount_items[0].quantity)
 

	
 
    def test_category_discount_applies_once_per_category(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # Add two items from category 2
 
        cart.add_to_cart(self.PROD_3, 1)
 
        cart.add_to_cart(self.PROD_4, 1)
 

	
 
        discount_items = list(cart.cart.discountitem_set.all())
 
        # There is one discount, and it should apply to one item.
 
        self.assertEqual(1, len(discount_items))
 
        self.assertEqual(1, discount_items[0].quantity)
 

	
 
    def test_category_discount_applies_to_highest_value(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)
 

	
 
        # Add two items from category 2, add the less expensive one first
 
        cart.add_to_cart(self.PROD_4, 1)
 
        cart.add_to_cart(self.PROD_3, 1)
 

	
 
        discount_items = list(cart.cart.discountitem_set.all())
 
        # There is one discount, and it should apply to the more expensive.
 
        self.assertEqual(1, len(discount_items))
 
        self.assertEqual(self.PROD_3, discount_items[0].product)
 

	
 
    def test_discount_quantity_is_per_user(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        # Both users should be able to apply the same discount
 
        # in the same way
 
        for user in (self.USER_1, self.USER_2):
 
            cart = TestingCartController.for_user(user)
 
            cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
            cart.add_to_cart(self.PROD_3, 1)
 

	
 
            discount_items = list(cart.cart.discountitem_set.all())
 
            # The discount is applied.
 
            self.assertEqual(1, len(discount_items))
 

	
 
    # Tests for the DiscountController.available_discounts enumerator
 
    def test_enumerate_no_discounts_for_no_input(self):
 
        discounts = DiscountController.available_discounts(self.USER_1, [], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_enumerate_no_discounts_if_condition_not_met(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_category_discount_appears_once_if_met_twice(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_with_category(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_with_product(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_category_discount_appears_once_with_two_valid_product(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4]
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_product_discount_appears_with_product(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
    def test_product_discount_does_not_appear_with_category(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=1)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_1], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_1],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_discount_quantity_is_correct_before_first_purchase(self):
 
        self.add_discount_prod_1_includes_cat_2(quantity=2)
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(2, discounts[0].quantity)
 

	
 
        cart.next_cart()
 

	
 
    def test_discount_quantity_is_correct_after_first_purchase(self):
 
        self.test_discount_quantity_is_correct_before_first_purchase()
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity
 

	
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(1, discounts[0].quantity)
 

	
 
        cart.next_cart()
 

	
 
    def test_discount_is_gone_after_quantity_exhausted(self):
 
        self.test_discount_quantity_is_correct_after_first_purchase()
 
        discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], [])
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [self.CAT_2],
 
            [],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
    def test_product_discount_enabled_twice_appears_twice(self):
 
        self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_3, self.PROD_4],
 
        )
 
        self.assertEqual(2, len(discounts))
 

	
 
    def test_discounts_are_released_by_refunds(self):
 
        self.add_discount_prod_1_includes_prod_2(quantity=2)
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_1, 1)  # Enable the discount
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
 

	
 
        cart.next_cart()
 

	
 
        cart = TestingCartController.for_user(self.USER_1)
 
        cart.add_to_cart(self.PROD_2, 2)  # The discount will be exhausted
 

	
 
        cart.next_cart()
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(0, len(discounts))
 

	
 
        cart.cart.status = commerce.Cart.STATUS_RELEASED
 
        cart.cart.save()
 

	
 
        discounts = DiscountController.available_discounts(
 
            self.USER_1,
 
            [],
 
            [self.PROD_2],
 
        )
 
        self.assertEqual(1, len(discounts))
registrasion/views.py
Show inline comments
...
 
@@ -334,193 +334,197 @@ def product_category(request, category_id):
 

	
 
    PRODUCTS_FORM_PREFIX = "products"
 
    VOUCHERS_FORM_PREFIX = "vouchers"
 

	
 
    # Handle the voucher form *before* listing products.
 
    # Products can change as vouchers are entered.
 
    v = _handle_voucher(request, VOUCHERS_FORM_PREFIX)
 
    voucher_form, voucher_handled = v
 

	
 
    category_id = int(category_id)  # Routing is [0-9]+
 
    category = inventory.Category.objects.get(pk=category_id)
 

	
 
    products = ProductController.available_products(
 
        request.user,
 
        category=category,
 
    )
 

	
 
    if not products:
 
        messages.warning(
 
            request,
 
            "There are no products available from category: " + category.name,
 
        )
 
        return redirect("dashboard")
 

	
 
    p = _handle_products(request, category, products, PRODUCTS_FORM_PREFIX)
 
    products_form, discounts, products_handled = p
 

	
 
    if request.POST and not voucher_handled and not products_form.errors:
 
        # Only return to the dashboard if we didn't add a voucher code
 
        # and if there's no errors in the products form
 
        messages.success(
 
            request,
 
            "Your reservations have been updated.",
 
        )
 
        return redirect("dashboard")
 

	
 
    data = {
 
        "category": category,
 
        "discounts": discounts,
 
        "form": products_form,
 
        "voucher_form": voucher_form,
 
    }
 

	
 
    return render(request, "registrasion/product_category.html", data)
 

	
 

	
 
def _handle_products(request, category, products, prefix):
 
    ''' Handles a products list form in the given request. Returns the
 
    form instance, the discounts applicable to this form, and whether the
 
    contents were handled. '''
 

	
 
    current_cart = CartController.for_user(request.user)
 

	
 
    ProductsForm = forms.ProductsForm(category, products)
 

	
 
    # Create initial data for each of products in category
 
    items = commerce.ProductItem.objects.filter(
 
        product__in=products,
 
        cart=current_cart.cart,
 
    ).select_related("product")
 
    quantities = []
 
    seen = set()
 

	
 
    for item in items:
 
        quantities.append((item.product, item.quantity))
 
        seen.add(item.product)
 

	
 
    zeros = set(products) - seen
 
    for product in zeros:
 
        quantities.append((product, 0))
 

	
 
    products_form = ProductsForm(
 
        request.POST or None,
 
        product_quantities=quantities,
 
        prefix=prefix,
 
    )
 

	
 
    if request.method == "POST" and products_form.is_valid():
 
        if products_form.has_changed():
 
            _set_quantities_from_products_form(products_form, current_cart)
 

	
 
        # If category is required, the user must have at least one
 
        # in an active+valid cart
 
        if category.required:
 
            carts = commerce.Cart.objects.filter(user=request.user)
 
            items = commerce.ProductItem.objects.filter(
 
                product__category=category,
 
                cart=carts,
 
            )
 
            if len(items) == 0:
 
                products_form.add_error(
 
                    None,
 
                    "You must have at least one item from this category",
 
                )
 
    handled = False if products_form.errors else True
 

	
 
    discounts = DiscountController.available_discounts(request.user, [], products)
 
    discounts = DiscountController.available_discounts(
 
        request.user,
 
        [],
 
        products,
 
    )
 

	
 
    return products_form, discounts, handled
 

	
 

	
 
def _set_quantities_from_products_form(products_form, current_cart):
 

	
 
    quantities = list(products_form.product_quantities())
 

	
 
    pks = [i[0] for i in quantities]
 
    products = inventory.Product.objects.filter(
 
        id__in=pks,
 
    ).select_related("category")
 

	
 
    product_quantities = [
 
        (products.get(pk=i[0]), i[1]) for i in quantities
 
    ]
 
    field_names = dict(
 
        (i[0][0], i[1][2]) for i in zip(product_quantities, quantities)
 
    )
 

	
 
    try:
 
        current_cart.set_quantities(product_quantities)
 
    except CartValidationError as ve:
 
        for ve_field in ve.error_list:
 
            product, message = ve_field.message
 
            if product in field_names:
 
                field = field_names[product]
 
            elif isinstance(product, inventory.Product):
 
                continue
 
            else:
 
                field = None
 
            products_form.add_error(field, message)
 

	
 

	
 
def _handle_voucher(request, prefix):
 
    ''' Handles a voucher form in the given request. Returns the voucher
 
    form instance, and whether the voucher code was handled. '''
 

	
 
    voucher_form = forms.VoucherForm(request.POST or None, prefix=prefix)
 
    current_cart = CartController.for_user(request.user)
 

	
 
    if (voucher_form.is_valid() and
 
            voucher_form.cleaned_data["voucher"].strip()):
 

	
 
        voucher = voucher_form.cleaned_data["voucher"]
 
        voucher = inventory.Voucher.normalise_code(voucher)
 

	
 
        if len(current_cart.cart.vouchers.filter(code=voucher)) > 0:
 
            # This voucher has already been applied to this cart.
 
            # Do not apply code
 
            handled = False
 
        else:
 
            try:
 
                current_cart.apply_voucher(voucher)
 
            except Exception as e:
 
                voucher_form.add_error("voucher", e)
 
            handled = True
 
    else:
 
        handled = False
 

	
 
    return (voucher_form, handled)
 

	
 

	
 
@login_required
 
def checkout(request):
 
    ''' Runs the checkout process for the current cart.
 

	
 
    If the query string contains ``fix_errors=true``, Registrasion will attempt
 
    to fix errors preventing the system from checking out, including by
 
    cancelling expired discounts and vouchers, and removing any unavailable
 
    products.
 

	
 
    Returns:
 
        render or redirect:
 
            If the invoice is generated successfully, or there's already a
 
            valid invoice for the current cart, redirect to ``invoice``.
 
            If there are errors when generating the invoice, render
 
            ``registrasion/checkout_errors.html`` with the following data::
 

	
 
                {
 
                    "error_list", [str, ...]  # The errors to display.
 
                }
 

	
 
    '''
 

	
 
    current_cart = CartController.for_user(request.user)
 

	
 
    if "fix_errors" in request.GET and request.GET["fix_errors"] == "true":
 
        current_cart.fix_simple_errors()
 

	
 
    try:
 
        current_invoice = InvoiceController.for_cart(current_cart.cart)
 
    except ValidationError as ve:
 
        return _checkout_errors(request, ve)
 

	
 
    return redirect("invoice", current_invoice.invoice.id)
0 comments (0 inline, 0 general)