from __future__ import annotations
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, TypeVar
from pulp import lpSum
from pydantic import BaseModel

import itertools
import logging

from models.item import Item
from models.constraints.generic_constraint import GenericConstraint
from models.constraints.metadata_constraint import MetadataConstraint
from models.constraints.bundle_constraint import BundleConstraint
from models.constraints.form_uniqueness_constraint import FormUniquenessConstraint
from models.constraints.total_form_items_constraint import TotalFormItemsConstraint
from models.constraints.enemy_pair_constraint import EnemyPairConstraint
from models.irt_model import IRTModel
from models.bundle import Bundle
from models.objective_function import ObjectiveFunction
from models.advanced_options import AdvancedOptions

if TYPE_CHECKING:
    from models.solution import Solution
    from models.problem import Problem

# Type variable standing for any constraint class derived from GenericConstraint;
# used below to annotate collections that mix several constraint subclasses.
ConstraintType = TypeVar('ConstraintType', bound=GenericConstraint)

class SolverRun(BaseModel):
    """Pydantic model holding the inputs of a single solver run.

    It carries the item pool, the bundles derived from it, the constraints,
    the IRT model, the objective function and the form sizing options, and
    offers helpers to look up, derive and prune those collections.
    """

    items: List[Item] = []
    bundles: List[Bundle] = []
    bundle_first_ordering: bool = True
    # parsed generically from the payload, then replaced in __init__ by
    # concrete constraint objects
    constraints: List[ConstraintType]
    irt_model: IRTModel
    objective_function: ObjectiveFunction
    total_form_items: int
    total_forms: int = 1
    theta_cut_score: float = 0.00
    drift_style: Literal['constant', 'variable'] = 'constant'
    allow_enemies: bool = False
    # may be None; generate_constraints() tolerates that
    advanced_options: Optional[AdvancedOptions]
    engine: str

    def __init__(self, **data) -> None:
        """Validate the payload, then rebuild ``constraints`` as concrete types.

        Each parsed constraint is re-created as a ``MetadataConstraint`` or a
        ``BundleConstraint`` depending on ``reference_attribute.type``.
        Constraints of any other type are not carried over.
        """
        super().__init__(**data)

        # this is all a compensator for dynamically creating objects
        # ideally we'd change the payload to determine what type it is
        constraint_classes = {
            'metadata': MetadataConstraint,
            'bundle': BundleConstraint,
        }

        # repackage to create appropriate constraint types
        typed_constraints: List[ConstraintType] = []
        for constraint in self.constraints:
            constraint_class = constraint_classes.get(
                constraint.reference_attribute.type)
            # NOTE(review): unknown types are silently skipped, as before --
            # confirm this is intended rather than an error condition
            if constraint_class is None:
                continue
            typed_constraints.append(
                constraint_class(
                    reference_attribute=constraint.reference_attribute,
                    minimum=constraint.minimum,
                    maximum=constraint.maximum))

        self.constraints = typed_constraints

    def get_item(self, item_id: int) -> Optional[Item]:
        """Return the first item whose ``id`` equals *item_id*, or ``None``."""
        return next((item for item in self.items if item.id == item_id), None)

    def get_bundle(self, bundle_id: int) -> Optional[Bundle]:
        """Return the first bundle whose ``id`` equals *bundle_id*, or ``None``."""
        return next(
            (bundle for bundle in self.bundles if bundle.id == bundle_id), None)

    def get_constraint_by_type(self, type: str) -> Optional[ConstraintType]:
        """Return the first constraint whose reference attribute has *type*.

        Returns ``None`` when no constraint matches.
        """
        return next((constraint for constraint in self.constraints
                     if type == constraint.reference_attribute.type), None)

    def remove_items(self, items: List[Item]) -> bool:
        """Drop every item contained in *items* from the pool; always returns True."""
        self.items = [item for item in self.items if item not in items]
        return True

    def generate_constraints(self) -> None:
        """Append the derived constraints to ``constraints``.

        Adds the total-form-items constraint, the form-uniqueness constraint
        (only when advanced options are present and request it) and one
        constraint per enemy pair found in the item pool.
        """
        # total form items
        self.constraints.append(
            TotalFormItemsConstraint.create(self.total_form_items))

        # ensure form uniqueness; advanced_options is Optional, so guard
        # against None instead of raising AttributeError
        options = self.advanced_options
        if options is not None and options.ensure_form_uniqueness:
            self.constraints.append(
                FormUniquenessConstraint.create(self.total_form_items - 1))

        # enemies constraints
        # NOTE(review): allow_enemies is not consulted here -- confirm whether
        # it is meant to be handled elsewhere
        for pair in self.enemy_pairs():
            self.constraints.append(EnemyPairConstraint.create(pair))

    def generate_bundles(self):
        """Group items into bundles for every bundle-type constraint.

        For each bundle constraint, the constraint's reference attribute id is
        used as the name of the item attribute to group by. Items whose
        attribute is ``None`` (or missing) are ignored. Existing bundles with
        the same id and type are extended; otherwise a new bundle is created.
        Finally, bundles holding fewer than 3 items are removed.
        """
        logging.info('Generating Bundles...')
        # confirms bundle constraints exists
        bundle_attributes = [
            constraint.reference_attribute for constraint in self.constraints
            if constraint.reference_attribute.type == 'bundle'
        ]

        # normalise a missing bundle list so the loops below can rely on it
        if self.bundles is None:
            self.bundles = []

        for bundle_attribute in bundle_attributes:
            bundle_type = bundle_attribute.id

            for item in self.items:
                bundle_id = getattr(item, bundle_type, None)

                # make sure the item has said attribute
                if bundle_id is None:
                    continue

                # find the bundle this item belongs to, if it already exists
                existing = next(
                    (bundle for bundle in self.bundles
                     if bundle.id == bundle_id and bundle.type == bundle_type),
                    None)

                if existing is None:
                    self.bundles.append(
                        Bundle(id=bundle_id,
                               count=1,
                               items=[item],
                               type=bundle_type))
                else:
                    existing.count += 1
                    existing.items.append(item)

        # temporary compensator for bundle item limits, since we shouldn't be using cases with less than 3 items
        # ideally this should be in the bundles model as a new attribute to handle "constraints of constraints"
        logging.info('Removing bundles with items < 3')
        # filter into a new list: deleting by index while enumerating the same
        # list skips the element that follows each deletion
        self.bundles[:] = [
            bundle for bundle in self.bundles if bundle.count >= 3
        ]

        logging.info('Bundles Generated...')

    def get_constraint(self, name: str) -> Optional[ConstraintType]:
        """Return the first constraint whose reference attribute id is *name*.

        Returns ``None`` when no constraint matches.
        """
        return next((constraint for constraint in self.constraints
                     if constraint.reference_attribute.id == name), None)

    def unbundled_items(self) -> List[Item]:
        """Return the items that do not belong to a bundle.

        When at least one bundle constraint exists, these are the items whose
        ``passage_id`` is ``None``; otherwise the full item list is returned.
        """
        # since the only bundles are based on passage id currently
        # in the future when we have more than just passage based bundles
        # we'll need to develop a more sophisticated way of handling this concern
        has_bundle_constraint = any(
            constraint.reference_attribute.type == 'bundle'
            for constraint in self.constraints)

        if has_bundle_constraint:
            return [item for item in self.items if item.passage_id is None]
        return self.items

    def enemy_pairs(self) -> List[List[int]]:
        """Return the sorted enemy pairs of all items, without exact duplicates.

        Only identical pairs are collapsed; whether ``[a, b]`` and ``[b, a]``
        can both occur depends on ``Item.enemy_pairs`` -- verify there.
        """
        pairs = []

        for item in self.items:
            # add enemy pairs for item to pairs
            pairs += item.enemy_pairs()

        # remove duplicates: sorting makes equal pairs adjacent for groupby
        pairs.sort()
        return [pair for pair, _ in itertools.groupby(pairs)]