"""Schemas for classical MD package."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from typing_extensions import Annotated
import zlib

from pydantic import (
    BaseModel,
    Field,
    PlainValidator,
    PlainSerializer,
    WithJsonSchema,
)
from monty.json import MSONable

from pymatgen.core import Structure
from emmet.core.vasp.task_valid import TaskState  # type: ignore[import-untyped]


def compressed_str_validator(s: str) -> str:
    """Decode text produced by ``compressed_str_serializer``.

    The input is expected to be hex-encoded, zlib-compressed UTF-8 text. If
    it cannot be decoded that way, the input is returned unchanged. This
    covers input that is not valid hex, not a zlib stream, not valid UTF-8
    once decompressed, or not a ``str`` at all. As a result, plain
    uncompressed strings are accepted as well.

    Args:
        s: Hex-encoded zlib-compressed text, or an ordinary string.

    Returns:
        The decompressed text, or ``s`` itself when it is not compressed data.
    """
    try:
        compressed_bytes = bytes.fromhex(s)
        decompressed_bytes = zlib.decompress(compressed_bytes)
        return decompressed_bytes.decode("utf-8")
    except (TypeError, ValueError, zlib.error):
        # Deliberate best-effort fallback: anything that isn't hex-encoded zlib
        # data is treated as an already-uncompressed value. ValueError covers
        # bad hex and UnicodeDecodeError; TypeError covers non-str input.
        # Catching only these (instead of a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        return s


def compressed_str_serializer(s: str) -> str:
    """Compress ``s`` with zlib and return the compressed bytes as hex text.

    The string is UTF-8 encoded, zlib-compressed, and rendered with
    ``bytes.hex()`` so the result is itself a plain string.
    ``compressed_str_validator`` reverses this.
    """
    payload = zlib.compress(s.encode("utf-8"))
    return payload.hex()


# Annotated ``str`` type that is stored compressed.
# - Serialization (compressed_str_serializer): the text is zlib-compressed
#   and hex-encoded.
# - Validation (compressed_str_validator): a hex-encoded zlib payload is
#   decompressed back to text; any value that is not such a payload is
#   accepted unchanged.
# - JSON schema: advertised as a plain string.
# NOTE: the order of the metadata items below is significant to pydantic;
# keep it as is.
CompressedStr = Annotated[
    str,
    PlainValidator(compressed_str_validator),
    PlainSerializer(compressed_str_serializer),
    WithJsonSchema({"type": "string"}),
]


@dataclass
class MoleculeSpec(MSONable):
    """A molecule schema to be output by OpenMMGenerators.

    Describes one molecular species within a system. It is a dataclass that
    is also MSONable. The field order below defines the positional
    constructor signature.
    """

    # Name of this molecule species.
    name: str
    # NOTE(review): presumably the number of copies of this molecule in the
    # system — confirm.
    count: int
    # NOTE(review): presumably a multiplicative factor applied to partial
    # charges — confirm.
    charge_scaling: float
    # NOTE(review): presumably the name of the partial-charge assignment
    # method — confirm.
    charge_method: str
    openff_mol: str  # a tk.Molecule object serialized with to_json


class MDTaskDocument(BaseModel, extra="allow"):  # type: ignore[call-arg]
    """Base task document for molecular dynamics (MD) calculations.

    Every declared field has a default, so all of them are optional. The
    model is configured with ``extra="allow"``, so fields not declared here
    are kept on the model instead of being discarded.
    """

    tags: Optional[list[str]] = Field(
        [], title="tag", description="Metadata tagged to a given task."
    )

    dir_name: Optional[str] = Field(None, description="The directory for this MD task")

    state: Optional[TaskState] = Field(None, description="State of this calculation")

    job_uuids: Optional[list] = Field(
        None,
        # Adjacent string literals are concatenated verbatim. The trailing
        # space after "only" prevents the text from reading "onlyhave".
        description="The job_uuids for all contributing jobs, this will only "
        "have a value if the taskdoc is generated by a Flow.",
    )

    calcs_reversed: Optional[list] = Field(
        None,
        title="Calcs reversed data",
        description="Detailed data for each MD calculation contributing to "
        "the task document.",
    )

    # Stored compressed when serialized; see CompressedStr.
    interchange: Optional[CompressedStr] = Field(
        None, description="An interchange object serialized to json."
    )

    mol_specs: Optional[list[MoleculeSpec]] = Field(
        None,
        description="Molecules within the system. Only makes sense "
        "for molecular systems.",
    )

    structure: Optional[Structure] = Field(
        None,
        title="Structure",
        description="The final structure for the simulation. Saved only "
        "if specified by job.",
    )

    force_field: Optional[str] = Field(None, description="The classical MD forcefield.")

    task_type: Optional[str] = Field(None, description="The type of calculation.")

    # TODO: decide where task_label gets added, then re-enable this field.
    # task_label: Optional[str] = Field(None, description="A description of the task")

    last_updated: Optional[datetime] = Field(
        None,
        description="Timestamp for the most recent calculation for this task document",
    )


class ClassicalMDTaskDocument(MDTaskDocument):
    """Definition of the OpenMM task document.

    Inherits every field of ``MDTaskDocument``. The only change is that
    ``mol_specs`` is redeclared with a shorter description.
    """

    mol_specs: Optional[list[MoleculeSpec]] = Field(
        default=None,
        description="Molecules within the system.",
    )
