"""Input/Output for models.
A model's `ParametricFunction` is not directly saved, only an identifier string is written to file. In order to
reconstruct the model from the file, the same mapping from function to identifier needs to be available. This module
implements a function returning a bidirectional mapping for the `ParametricFunction` implemented in the `pkoffee`
package. This mapping can be extended with additional functions to save other models.
"""
import json
from collections.abc import Iterable, Mapping
from enum import StrEnum
from pathlib import Path
import tomlkit
from bidict import bidict
from tomlkit import aot, document, item
from pkoffee.fit_model import Model
from pkoffee.parametric_function import (
Logistic,
MichaelisMentenSaturation,
Peak2Model,
PeakModel,
Quadratic,
)
[docs]
def pkoffee_function_id_mapping() -> bidict:
    """Build the two-way lookup between identifier strings and the `pkoffee` parametric functions.

    Returns
    -------
    bidict
        Bidirectional mapping whose keys are the string identifiers and whose values are the function classes.
    """
    # One (identifier, function) pair per parametric function imported by this module.
    identifier_function_pairs = (
        ("Quadratic", Quadratic),
        ("MichaelisMentenSaturation", MichaelisMentenSaturation),
        ("Logistic", Logistic),
        ("PeakModel", PeakModel),
        ("Peak2Model", Peak2Model),
    )
    return bidict(identifier_function_pairs)
[docs]
def save_models_json(model_dicts: Iterable[dict], output_path: Path) -> None:
    """Save the model dictionary representations to a json file.

    The models are stored as a list under the top-level ``"Models"`` key.

    Parameters
    ----------
    model_dicts : Iterable[dict]
        Models dictionary representation. Any iterable is accepted, including one-shot generators.
    output_path : Path
        Path to save the models

    Raises
    ------
    TypeError
        If a model dictionary contains a value that `json` cannot serialize.
    """
    # `json` can only encode lists and tuples, so arbitrary iterables (e.g. generators) are materialized first.
    # Serializing before the file is opened keeps an existing file intact when encoding fails.
    serialized = json.dumps({"Models": list(model_dicts)})
    output_path.write_text(serialized, encoding="utf-8")
[docs]
def save_models_toml(model_dicts: Iterable[dict], output_path: Path) -> None:
    """Save the model dictionaries representation to a toml file.

    Each model dictionary becomes one entry of the ``Models`` array of tables.

    Parameters
    ----------
    model_dicts : Iterable[dict]
        Models dictionary representation
    output_path : Path
        Path to save the models
    """
    models_array = aot()
    for model_dict in model_dicts:
        models_array.append(item(model_dict))
    toml_doc = document()
    toml_doc.append("Models", models_array)
    # Render the document before touching the file so a failure does not leave a truncated file behind.
    rendered = toml_doc.as_string()
    # TOML documents must be UTF-8 encoded; do not rely on the locale's default encoding.
    output_path.write_text(rendered, encoding="utf-8")
[docs]
def save_models(
    models: Iterable[Model], function_to_str: Mapping, output_path: Path, file_format: ModelFileFormat | None = None
) -> None:
    """Write a collection of models to a file.

    Parameters
    ----------
    models : Iterable[Model]
        Collection of models to save
    function_to_str : Mapping
        Mapping of function to string identifier used as function representation in the model's file.
    output_path : Path
        Path to the model's file.
    file_format : ModelFileFormat | None
        The format of the model's file. When ``None``, `file_format_from_path` is called on `output_path` to pick it.

    Raises
    ------
    UnsupportedModelFormatError
        If the format is neither `ModelFileFormat.JSON` nor `ModelFileFormat.TOML`.
    """
    selected_format = file_format_from_path(output_path) if file_format is None else file_format
    # Models are converted before dispatching on the format, so conversion errors surface first.
    serialized_models = [model.to_dict(function_to_str) for model in models]
    if selected_format == ModelFileFormat.JSON:
        save_models_json(serialized_models, output_path)
    elif selected_format == ModelFileFormat.TOML:
        save_models_toml(serialized_models, output_path)
    else:
        raise UnsupportedModelFormatError(str(selected_format))
[docs]
def load_models_json(model_file: Path, str_to_function: Mapping) -> list[Model]:
    """Load models from json file.

    The models are read from the list stored under the top-level ``"Models"`` key.

    Parameters
    ----------
    model_file : Path
        Path to the models' file
    str_to_function : Mapping
        Mapping of function string identifier to function classes

    Returns
    -------
    list[Model]
        Loaded models

    Raises
    ------
    KeyError
        If the file has no top-level ``"Models"`` entry.
    """
    # JSON interchange files are UTF-8; decoding with the locale's default encoding would corrupt non-ASCII text.
    with model_file.open("r", encoding="utf-8") as mdlf:
        models_dict = json.load(mdlf)["Models"]
    return [Model.from_dict(m_d, str_to_function) for m_d in models_dict]
[docs]
def load_models_toml(model_file: Path, str_to_function: Mapping) -> list[Model]:
    """Load models from toml file.

    The models are read from the entries stored under the top-level ``Models`` key.

    Parameters
    ----------
    model_file : Path
        Path to the models' file
    str_to_function : Mapping
        Mapping of function string identifier to function classes

    Returns
    -------
    list[Model]
        Loaded models
    """
    # TOML documents must be UTF-8 encoded; do not rely on the locale's default encoding when decoding.
    toml_text = model_file.read_text(encoding="utf-8")
    models_dict = tomlkit.parse(toml_text)["Models"]
    return [Model.from_dict(m_d, str_to_function) for m_d in models_dict]  # type: ignore[misc]
[docs]
def load_models(
    model_file: Path, str_to_function: Mapping, file_format: ModelFileFormat | None = None
) -> list[Model]:
    """Read a collection of models from a file.

    Parameters
    ----------
    model_file : Path
        Path to the model's file
    str_to_function : Mapping
        Mapping of function string identifier to function classes
    file_format : ModelFileFormat | None
        Format of the model file. When ``None``, `file_format_from_path` is called on `model_file` to pick it.

    Returns
    -------
    list[Model]
        Loaded models

    Raises
    ------
    UnsupportedModelFormatError
        If the format is neither `ModelFileFormat.JSON` nor `ModelFileFormat.TOML`.
    """
    selected_format = file_format_from_path(model_file) if file_format is None else file_format
    if selected_format == ModelFileFormat.JSON:
        return load_models_json(model_file, str_to_function)
    if selected_format == ModelFileFormat.TOML:
        return load_models_toml(model_file, str_to_function)
    raise UnsupportedModelFormatError(str(selected_format))