from pydantic import BaseModel, validator
from typing import List, Optional

from models.attribute import Attribute

from lib.irt.item_response_function import ItemResponseFunction
from lib.irt.item_information_function import ItemInformationFunction

class Item(BaseModel):
    """An item with an IRT ``b_param``, enemy links and tagged attributes.

    The IRT helpers (``iif`` / ``irf``) delegate to ``lib.irt`` using this
    item's ``b_param`` and a caller-supplied ``irt_model``.
    """

    id: int
    position: Optional[int] = None
    passage_id: Optional[int] = None
    # Ids of conflicting items; raw input is normalised by ``set_enemies``.
    # pydantic copies mutable defaults per instance, so the shared [] is safe.
    enemies: List[int] = []
    workflow_state: Optional[str] = None
    # None when no attributes were supplied; the lookup helpers below treat
    # that the same as an empty list instead of failing on iteration.
    attributes: Optional[List[Attribute]] = None
    # Parameter forwarded to the IRT helpers as ``b_param``.
    b_param: float = 0.00
    response: Optional[int] = None

    @validator("enemies", pre=True)
    def set_enemies(cls, v) -> List[int]:
        """Normalise raw ``enemies`` input into a list of item ids.

        Accepts:
        * ``None`` or ``""`` -> no enemies;
        * a list/tuple -> passed through for pydantic's own ``int`` coercion;
        * a comma-separated string such as ``"12,34"`` (or any other scalar,
          which is stringified first). Blank tokens, e.g. from ``"1,,2"`` or a
          trailing comma, are skipped instead of raising ``ValueError``.

        Zero ids parsed from a string are dropped, as the previous
        ``filter(None, ...)`` implementation did.
        """
        if v is None or v == "":
            return []
        if isinstance(v, (list, tuple)):
            return list(v)
        # Filter blanks *before* int() so stray commas/whitespace are harmless.
        tokens = (token.strip() for token in str(v).split(","))
        enemies = [int(token) for token in tokens if token]
        # NOTE(review): falsy ids (0) were always discarded; behaviour kept.
        return [enemy for enemy in enemies if enemy]

    def iif(self, irt_model, theta):
        """Return ItemInformationFunction(irt_model) evaluated at ``theta``
        with this item's ``b_param``."""
        return ItemInformationFunction(irt_model).calculate(b_param=self.b_param, theta=theta)

    def irf(self, irt_model, theta):
        """Return ItemResponseFunction(irt_model) evaluated at ``theta``
        with this item's ``b_param``."""
        return ItemResponseFunction(irt_model).calculate(b_param=self.b_param, theta=theta)

    def get_attribute(self, ref_attribute: Attribute) -> Optional[Attribute]:
        """Return this item's attribute matching ``ref_attribute``, else None.

        Two attributes match when their ``id`` values are equal and their
        ``value`` strings are equal ignoring case. The first match wins.
        """
        for attribute in self.attributes or []:
            # Compare *this* attribute; the old code re-checked the whole list
            # and so returned the first attribute whenever any one matched.
            if (
                attribute.id == ref_attribute.id
                and attribute.value.lower() == ref_attribute.value.lower()
            ):
                return attribute
        return None

    def attribute_exists(self, ref_attribute: Attribute) -> bool:
        """Return True if any attribute on this item matches ``ref_attribute``.

        Uses the same rule as ``get_attribute`` (equal ``id``,
        case-insensitive ``value``).
        """
        return self.get_attribute(ref_attribute) is not None

    def iif_irf_sum(self, solver_run):
        """Return the summed ``iif`` plus the summed ``irf`` for ``solver_run``.

        Both sums range over ``solver_run.objective_function.tif_targets``,
        evaluated at each target's ``theta`` with ``solver_run.irt_model``.
        """
        return self.__iif_sum(solver_run) + self.__irf_sum(solver_run)

    def __iif_sum(self, solver_run):
        """Sum ``iif`` over the theta of every TIF target in ``solver_run``."""
        return sum(
            self.iif(solver_run.irt_model, target.theta)
            for target in solver_run.objective_function.tif_targets
        )

    def __irf_sum(self, solver_run):
        """Sum ``irf`` over the theta of every TIF target in ``solver_run``."""
        return sum(
            self.irf(solver_run.irt_model, target.theta)
            for target in solver_run.objective_function.tif_targets
        )

    def enemy_pairs(self, sort: bool = True) -> List[List[int]]:
        """Return one ``[self.id, enemy_id]`` pair per entry in ``enemies``.

        With ``sort=True`` (the default) each pair is ordered ascending, so
        the same relationship seen from either item yields an identical pair.
        """
        return [
            sorted((self.id, enemy_id)) if sort else [self.id, enemy_id]
            for enemy_id in self.enemies
        ]