diff --git a/doc/conf.py b/doc/conf.py index 975cad03..41c41496 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -9,6 +9,8 @@ import sys import warnings +from pydantic import BaseModel + # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, @@ -126,6 +128,9 @@ def skip_some_objects(app, what, name, obj, skip, options): """Exclude some objects from the documentation""" if getattr(obj, "__module__", None) == "collections": return True + # Napoleon + Pydantic v2 bug: BaseModel itself triggers __getattr__ error + if obj is BaseModel: + return True def setup(app): diff --git a/petab/v2/C.py b/petab/v2/C.py index e640ae5c..d738804f 100644 --- a/petab/v2/C.py +++ b/petab/v2/C.py @@ -258,6 +258,8 @@ MAPPING_FILES = "mapping_files" #: Extensions key in the YAML file EXTENSIONS = "extensions" +#: PEtab SciML extension +EXT_ID_SCIML = "sciml" # MAPPING diff --git a/petab/v2/core.py b/petab/v2/core.py index fb206502..fb7c9863 100644 --- a/petab/v2/core.py +++ b/petab/v2/core.py @@ -19,6 +19,7 @@ Annotated, Any, Generic, + Literal, Self, TypeVar, get_args, @@ -308,6 +309,22 @@ def __iadd__(self, other: T) -> BaseTable[T]: return self +# SciML extension classes — imported after BaseTable is defined to avoid +# circular imports (sciml.py does not import from core.py). +from .extensions.sciml import ( # noqa: E402 + HybridizationTable, + SciMLConfig, + SciMLExt, +) + + +class ProblemExtensions: + """Runtime extension state attached to a :class:`Problem`.""" + + def __init__(self, sciml: SciMLExt = None): + self.sciml: SciMLExt = sciml or SciMLExt() + + class Observable(BaseModel): """Observable definition.""" @@ -318,9 +335,9 @@ class Observable(BaseModel): #: Observable name. name: str | None = Field(alias=C.OBSERVABLE_NAME, default=None) #: Observable formula. 
- formula: sp.Basic | None = Field(alias=C.OBSERVABLE_FORMULA, default=None) + formula: sp.Basic = Field(alias=C.OBSERVABLE_FORMULA) #: Noise formula. - noise_formula: sp.Basic | None = Field(alias=C.NOISE_FORMULA, default=None) + noise_formula: sp.Basic = Field(alias=C.NOISE_FORMULA) #: Noise distribution. noise_distribution: NoiseDistribution = Field( alias=C.NOISE_DISTRIBUTION, default=NoiseDistribution.NORMAL @@ -926,7 +943,8 @@ class Parameter(BaseModel): ) #: Nominal value. nominal_value: Annotated[ - float | None, BeforeValidator(_convert_nan_to_none) + # PEtab SciML supports arrays via "array" nominal values + float | Literal["array"] | None, BeforeValidator(_convert_nan_to_none) ] = Field(alias=C.NOMINAL_VALUE, default=None) #: Is the parameter to be estimated? estimate: bool = Field(alias=C.ESTIMATE, default=True) @@ -1133,15 +1151,25 @@ def __init__( measurement_tables: list[MeasurementTable] = None, parameter_tables: list[ParameterTable] = None, mapping_tables: list[MappingTable] = None, + extensions: ProblemExtensions = None, config: ProblemConfig = None, ): - from ..v2.lint import default_validation_tasks + from ..v2.lint import default_validation_tasks, sciml_validation_tasks self.config = config self.models: list[Model] = models or [] - self.validation_tasks: list[ValidationTask] = ( - default_validation_tasks.copy() - ) + if ( + config + and config.extensions + and config.extensions.get(C.EXT_ID_SCIML) + ): + self.validation_tasks: list[ValidationTask] = ( + sciml_validation_tasks.copy() + ) + else: + self.validation_tasks: list[ValidationTask] = ( + default_validation_tasks.copy() + ) self.observable_tables = observable_tables or [ObservableTable()] self.condition_tables = condition_tables or [ConditionTable()] @@ -1149,6 +1177,7 @@ def __init__( self.measurement_tables = measurement_tables or [MeasurementTable()] self.mapping_tables = mapping_tables or [MappingTable()] self.parameter_tables = parameter_tables or [ParameterTable()] + 
self.extensions = extensions or ProblemExtensions() def __repr__(self): return f"<{self.__class__.__name__} id={self.id!r}>" @@ -1321,6 +1350,45 @@ def from_yaml( else None ) + extensions = ProblemExtensions() + if config.extensions and config.extensions.get(C.EXT_ID_SCIML): + from petab_sciml import ArrayDataStandard, NNModel, NNModelStandard + + # Neural network classes are constructed via pytorch for now to get + # the proper inputs + neural_networks = [ + NNModel.from_pytorch_module( + NNModelStandard.load_data( + _generate_path( + file_path=nn_config.location, + base_path=base_path, + ) + ).to_pytorch_module(), + nn_model_id=nn_id, + ) + for nn_id, nn_config in ( + config.extensions[C.EXT_ID_SCIML].neural_networks or {} + ).items() + ] + + hybridization_tables = [ + HybridizationTable.from_tsv(f, base_path) + for f in config.extensions[C.EXT_ID_SCIML].hybridization_files + ] + + array_data_files = [ + ArrayDataStandard.load_data(_generate_path(f, base_path)) + for f in config.extensions[C.EXT_ID_SCIML].array_files + ] + + extensions = ProblemExtensions( + sciml=SciMLExt( + neural_networks=neural_networks, + hybridization_tables=hybridization_tables, + array_data_files=array_data_files, + ) + ) + return Problem( config=config, models=models, @@ -1330,6 +1398,7 @@ def from_yaml( measurement_tables=measurement_tables, parameter_tables=parameter_tables, mapping_tables=mapping_tables, + extensions=extensions, ) @staticmethod @@ -1940,14 +2009,21 @@ def validate( validation_results = ValidationResultList() - if self.config and self.config.extensions: - extensions = ",".join(self.config.extensions.keys()) + supported_extensions = {C.EXT_ID_SCIML} + if ( + self.config + and self.config.extensions + and (self.config.extensions.keys() - supported_extensions) + ): + extensions_without_support = ",".join( + self.config.extensions.keys() - supported_extensions + ) validation_results.append( ValidationIssue( ValidationIssueSeverity.WARNING, - "Validation of PEtab extensions 
is not yet implemented, " - "but the given problem uses the following extensions: " - f"{extensions}", + "The given problem uses the following extensions for " + "which validation is not yet implemented: " + f"{extensions_without_support}", ) ) @@ -2505,6 +2581,23 @@ class ProblemConfig(BaseModel): validate_assignment=True, ) + @field_validator("extensions", mode="before") + @classmethod + def _parse_extensions(cls, v): + """Parse extensions dict and convert known extensions to their specific + config classes.""" + if isinstance(v, dict): + parsed_extensions = {} + for ext_name, ext_config in v.items(): + if ext_name == C.EXT_ID_SCIML: + # Convert sciml extension to SciMLConfig + parsed_extensions[ext_name] = SciMLConfig(**ext_config) + else: + # Keep other extensions as ExtensionConfig + parsed_extensions[ext_name] = ExtensionConfig(**ext_config) + return parsed_extensions + return v + # convert parameter_file to list @field_validator( "parameter_files", @@ -2542,12 +2635,22 @@ def to_yaml(self, filename: str | Path): for model_id in data.get("model_files", {}): data["model_files"][model_id][C.MODEL_LOCATION] = str( - data["model_files"][model_id]["location"] + data["model_files"][model_id][C.MODEL_LOCATION] ) if data["id"] is None: # The schema requires a valid id or no id field at all. 
del data["id"] + for ext_id, d_ext in data[C.EXTENSIONS].items(): + if ext_id == C.EXT_ID_SCIML: + # convert Paths to strings + for key in ("array_files", "hybridization_files"): + d_ext[key] = list(map(str, d_ext[key])) + for nn in d_ext["neural_networks"]: + d_ext["neural_networks"][nn][C.MODEL_LOCATION] = str( + d_ext["neural_networks"][nn][C.MODEL_LOCATION] + ) + write_yaml(data, filename) @property diff --git a/petab/v2/extensions/__init__.py b/petab/v2/extensions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/petab/v2/extensions/sciml.py b/petab/v2/extensions/sciml.py new file mode 100644 index 00000000..0b483783 --- /dev/null +++ b/petab/v2/extensions/sciml.py @@ -0,0 +1,233 @@ +"""PEtab SciML extension — classes and runtime state for hybrid ODE/ML +problems.""" + +from __future__ import annotations + +from itertools import chain +from pathlib import Path + +import numpy as np +import pandas as pd +import sympy as sp +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, field_validator + +from petab._utils import _generate_path +from petab.v1.math import sympify_petab + +try: + from petab_sciml import ( + ArrayData, + ArrayDataStandard, + NNModel, + NNModelStandard, + ) +except ModuleNotFoundError: + pass + +from .. import C + + +class Hybridization(BaseModel): + """Assigns PEtab SciML NN inputs and outputs.""" + + #: The target ID. + target_id: str = Field(alias=C.TARGET_ID) + #: The target value. 
+ target_value: sp.Basic = Field(alias=C.TARGET_VALUE) + + #: :meta private: + model_config = ConfigDict( + arbitrary_types_allowed=True, + populate_by_name=True, + extra="allow", + validate_assignment=True, + ) + + @field_validator("target_value", mode="before") + @classmethod + def _sympify(cls, v): + if v is None or isinstance(v, sp.Basic): + return v + if isinstance(v, float) and np.isnan(v): + return None + return sympify_petab(v) + + +class HybridizationTable: + """PEtab SciML hybridization table.""" + + def __init__(self, hybridizations: list[Hybridization] = None, **kwargs): + self.hybridizations: list[Hybridization] = hybridizations or [] + self.rel_path: AnyUrl | Path | None = kwargs.get("rel_path") + self.base_path: AnyUrl | Path | None = kwargs.get("base_path") + + @property + def elements(self) -> list[Hybridization]: + return self.hybridizations + + @classmethod + def from_df(cls, df: pd.DataFrame, **kwargs) -> HybridizationTable: + """Create a HybridizationTable from a DataFrame.""" + if df is None: + return cls(**kwargs) + + hybridizations = [ + Hybridization(**row.to_dict()) for _, row in df.iterrows() + ] + return cls(hybridizations, **kwargs) + + @classmethod + def from_tsv( + cls, + file_path: str | Path, + base_path: str | Path | None = None, + ) -> HybridizationTable: + """Create a HybridizationTable from a TSV file.""" + df = pd.read_csv(_generate_path(file_path, base_path), sep="\t") + return cls.from_df(df, rel_path=file_path, base_path=base_path) + + def to_df(self) -> pd.DataFrame: + """Convert the HybridizationTable to a DataFrame.""" + records = [h.model_dump(by_alias=True) for h in self.hybridizations] + return pd.DataFrame(records) + + def to_tsv(self, file_path: str | Path = None) -> None: + """Write the table to a TSV file.""" + df = self.to_df() + df.to_csv( + file_path or _generate_path(self.rel_path, self.base_path), + sep="\t", + index=False, + ) + + def __getitem__(self, target_id: str) -> Hybridization: + """Get a 
hybridization by target ID.""" + for hybridization in self.hybridizations: + if hybridization.target_id == target_id: + return hybridization + raise KeyError(f"Target ID {target_id} not found") + + def get(self, target_id, default=None): + """Get a hybridization by target ID or return a default value.""" + try: + return self[target_id] + except KeyError: + return default + + +class NeuralNetConfig(BaseModel): + """A neural net in the PEtab SciML problem configuration.""" + + location: AnyUrl | Path + pre_initialization: bool + format: str + + model_config = ConfigDict( + validate_assignment=True, + ) + + +class SciMLConfig(BaseModel): + """The extended configuration of a PEtab SciML problem.""" + + #: The PEtab SciML format version. + version: str = "0.1.0" + #: The paths to the array data files. + array_files: list[AnyUrl | Path] = [] + #: The paths to the hybridization tables. + hybridization_files: list[AnyUrl | Path] = [] + #: The neural network IDs and info. + neural_networks: dict[str, NeuralNetConfig] | None = {} + + model_config = ConfigDict( + validate_assignment=True, + ) + + +class SciMLExt: + """SciML extension runtime state. + + Accessible as ``Problem.extensions.sciml``. 
+ """ + + def __init__( + self, + neural_networks: list = None, + hybridization_tables: list[HybridizationTable] = None, + array_data_files: list = None, + ): + self.neural_networks: list = neural_networks or [] + self.hybridization_tables: list[HybridizationTable] = ( + hybridization_tables or [HybridizationTable()] + ) + self.array_data_files: list = array_data_files or [] + + @property + def hybridizations(self) -> list[Hybridization]: + """Flat list of all hybridizations across all hybridization tables.""" + return list( + chain.from_iterable( + ht.hybridizations for ht in self.hybridization_tables + ) + ) + + @property + def hybridization_df(self) -> pd.DataFrame | None: + """Combined hybridization tables as a single DataFrame.""" + hybs = self.hybridizations + return HybridizationTable(hybs).to_df() if hybs else None + + @hybridization_df.setter + def hybridization_df(self, value: pd.DataFrame): + self.hybridization_tables = [HybridizationTable.from_df(value)] + + def add_hybridization(self, target_id: str, target_value: str): + """Add a hybridization entry. + + If there is more than one hybridization table the entry is added to + the last one. + + Arguments: + target_id: The ID of the target entity in the PEtab problem + or neural network model + target_value: The value that is assigned to the target id. 
+ """ + if not self.hybridization_tables: + self.hybridization_tables.append(HybridizationTable()) + self.hybridization_tables[-1].hybridizations.append( + Hybridization(target_id=target_id, target_value=target_value) + ) + + def add_neural_network_from_dict(self, model_id: str, nn_dict: dict): + """Add a neural network from a dictionary.""" + nn_model = NNModel.model_validate(nn_dict) + nn_model.nn_model_id = model_id + self.neural_networks.append(nn_model) + + def add_neural_network_from_yaml( + self, + model_id: str, + file_path: str | Path, + base_path: str | Path | None = None, + ): + """Add a neural network from a YAML file.""" + self.neural_networks.append( + NNModelStandard.load_data( + _generate_path(file_path=file_path, base_path=base_path), + nn_model_id=model_id, + ) + ) + + def add_array_data_from_dict(self, array_data: dict): + """Add array data from a dictionary.""" + self.array_data_files.append(ArrayData.model_validate(array_data)) + + def add_array_data_from_hdf5( + self, + file_path: str | Path, + base_path: str | Path | None = None, + ): + """Add array data from an HDF5 file.""" + self.array_data_files.append( + ArrayDataStandard.load_data(_generate_path(file_path, base_path)) + ) diff --git a/petab/v2/extensions/sciml_lint.py b/petab/v2/extensions/sciml_lint.py new file mode 100644 index 00000000..c4f65179 --- /dev/null +++ b/petab/v2/extensions/sciml_lint.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from .. 
import core, lint + + +class CheckHybridizationTable(lint.ValidationTask): + """Validate the SciML hybridization table.""" + + def run(self, problem: core.Problem) -> lint.ValidationIssue | None: + messages = [] + + condition_targets = { + c.target_id for ct in problem.conditions for c in ct.changes + } + nn_input_ids = { + inp.input_id + for nn in problem.extensions.sciml.neural_networks + for inp in nn.inputs + } + hyb_target_ids = { + hyb.target_id for hyb in problem.extensions.sciml.hybridizations + } + hyb_target_vals = { + hyb.target_value for hyb in problem.extensions.sciml.hybridizations + } + + # Hybridization targets are not also targets in the condition table + if culprits := (hyb_target_ids & condition_targets): + messages.append( + f"Hybridization target ids `{culprits}` are also " + "target ids in the condition table." + ) + # NN inputs are not used as target values + if culprits := (hyb_target_vals & nn_input_ids): + messages.append( + "The following neural net inputs were used as target values " + f"in the Hybridization table: `{culprits}`." 
+ ) + + if messages: + return lint.ValidationError("\n".join(messages)) + + return None diff --git a/petab/v2/lint.py b/petab/v2/lint.py index 687d58f2..4260e40a 100644 --- a/petab/v2/lint.py +++ b/petab/v2/lint.py @@ -44,6 +44,8 @@ "CheckPriorDistribution", "CheckUndefinedExperiments", "CheckInitialChangeSymbols", + "CheckMappingTable", + "CheckHybridizationTable", "lint_problem", "default_validation_tasks", ] @@ -445,7 +447,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: # check for uniqueness of all primary keys counter = Counter(c.id for c in problem.conditions) - duplicates = {id_ for id_, count in counter.items() if count > 1} + duplicates = sorted(id_ for id_, count in counter.items() if count > 1) if duplicates: return ValidationError( @@ -453,7 +455,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: ) counter = Counter(o.id for o in problem.observables) - duplicates = {id_ for id_, count in counter.items() if count > 1} + duplicates = sorted(id_ for id_, count in counter.items() if count > 1) if duplicates: return ValidationError( @@ -461,7 +463,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: ) counter = Counter(e.id for e in problem.experiments) - duplicates = {id_ for id_, count in counter.items() if count > 1} + duplicates = sorted(id_ for id_, count in counter.items() if count > 1) if duplicates: return ValidationError( @@ -469,7 +471,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: ) counter = Counter(p.id for p in problem.parameters) - duplicates = {id_ for id_, count in counter.items() if count > 1} + duplicates = sorted(id_ for id_, count in counter.items() if count > 1) if duplicates: return ValidationError( @@ -508,7 +510,9 @@ def run(self, problem: Problem) -> ValidationIssue | None: for experiment in problem.experiments: # Check that there are no duplicate timepoints counter = Counter(period.time for period in experiment.periods) - duplicates = {time for time, count in counter.items() 
if count > 1} + duplicates = sorted( + time for time, count in counter.items() if count > 1 + ) if duplicates: messages.append( f"Experiment {experiment.id} contains duplicate " @@ -551,7 +555,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: class CheckAllParametersPresentInParameterTable(ValidationTask): """Ensure all required parameters are contained in the parameter table - with no additional ones.""" + with no additional ones. This also ensures that the mapping table petab ids + are used in the PEtab problem.""" def run(self, problem: Problem) -> ValidationIssue | None: if problem.model is None: @@ -825,8 +830,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: if parameter.prior_distribution not in PRIOR_DISTRIBUTIONS: messages.append( - f"Prior distribution `{parameter.prior_distribution}' " - f"for parameter `{parameter.id}' is not valid." + f"Prior distribution `{parameter.prior_distribution}` " + f"for parameter `{parameter.id}` is not valid." ) continue @@ -834,8 +839,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: exp_num_par := self._num_pars[parameter.prior_distribution] ) != len(parameter.prior_parameters): messages.append( - f"Prior distribution `{parameter.prior_distribution}' " - f"for parameter `{parameter.id}' requires " + f"Prior distribution `{parameter.prior_distribution}` " + f"for parameter `{parameter.id}` requires " f"{exp_num_par} parameters, but got " f"{len(parameter.prior_parameters)} " f"({parameter.prior_parameters})." @@ -848,8 +853,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: _ = parameter.prior_dist.sample(1) except Exception as e: messages.append( - f"Prior parameters `{parameter.prior_parameters}' " - f"for parameter `{parameter.id}' are invalid " + f"Prior parameters `{parameter.prior_parameters}` " + f"for parameter `{parameter.id}` are invalid " f"(hint: {e})." 
) @@ -874,7 +879,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: continue messages.append( - f"Measurement `{measurement}' does not have a model ID, " + f"Measurement `{measurement}` does not have a model ID, " "but there are multiple models available. " "Please specify the model ID in the measurement table." ) @@ -882,8 +887,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: if measurement.model_id not in available_models: messages.append( - f"Measurement `{measurement}' has model ID " - f"`{measurement.model_id}' which does not match " + f"Measurement `{measurement}` has model ID " + f"`{measurement.model_id}` which does not match " "any of the available models: " f"{available_models}." ) @@ -894,6 +899,78 @@ def run(self, problem: Problem) -> ValidationIssue | None: return None +class CheckMappingTable(ValidationTask): + """Validate the mapping table.""" + + def run(self, problem: Problem) -> ValidationIssue | None: + messages = [] + + # Mapping table is optional + if problem.mappings: + # Check that each id, across both the petabEntityId and + # modelEntityId columns, occurs only once + must_be_unique_ids = [] + for mapping in problem.mappings: + petab_id = getattr(mapping, "petab_id", None) + model_id = getattr(mapping, "model_id", None) + + if petab_id: + must_be_unique_ids.append(petab_id) + # Duplicates for annotation-only rows (identity mappings) + # are permitted. + if petab_id == model_id: + continue + if model_id: + must_be_unique_ids.append(model_id) + + non_unique_ids = sorted( + id_ + for id_, count in Counter(must_be_unique_ids).items() + if count > 1 + ) + if non_unique_ids: + return ValidationError( + f"Mapping table contains non-unique IDs: {non_unique_ids}." 
+ ) + + # petabEntityId is not defined elsewhere in the PEtab problem + new_petab_ids = { + m.petab_id + for m in problem.mappings + # Ignore identity mappings used for annotation + if m.petab_id != m.model_id + } + old_petab_ids = ( + {c.id for c in problem.conditions} + | {e.id for e in problem.experiments} + | {o.id for o in problem.observables} + ) + if overdefined_ids := sorted(new_petab_ids & old_petab_ids): + messages.append( + f"PEtab IDs `{overdefined_ids}` are " + "defined in the mapping table but also defined through " + "other PEtab tables." + ) + + for mapping in problem.mappings: + # petabEntityId not referenced in any model, if alias + for model in problem.models: + if ( + mapping.petab_id != mapping.model_id + and model.has_entity_with_id(mapping.petab_id) + ): + messages.append( + f"`{mapping.petab_id}` is used in the mapping " + "table and referenced directly in the model " + f"`{model.model_id}`." + ) + + if messages: + return ValidationError("\n".join(messages)) + + return None + + def get_valid_parameters_for_parameter_table( problem: Problem, ) -> set[str]: @@ -933,9 +1010,20 @@ def get_valid_parameters_for_parameter_table( if p not in invalid ) + # Add petab ids from mapping table if they are used for aliasing + # FIXME only add mapping.petab_id to allowed parameter IDs list if it + # aliases an invalid PEtab ID? 
See + # https://github.com/PEtab-dev/libpetab-python/pull/482#discussion_r3420762034 for mapping in problem.mappings: - if mapping.model_id and mapping.model_id in parameter_ids.keys(): + if mapping.petab_id not in invalid: parameter_ids[mapping.petab_id] = None + # An aliased model id is not a valid parameter id + if ( + mapping.model_id + and mapping.model_id != mapping.petab_id + and mapping.model_id in parameter_ids + ): + del parameter_ids[mapping.model_id] # add output parameters from observable table output_parameters = problem.get_output_parameters() @@ -977,20 +1065,13 @@ def get_required_parameters_for_parameter_table( measurement table as well as all parametric condition table overrides that are not defined in the model. """ - parameter_ids = set() - condition_targets = { - change.target_id - for cond in problem.conditions - for change in cond.changes - } + # Start with mapping table petab ids + parameter_ids = {m.petab_id for m in problem.mappings} # Add parameters from measurement table, unless they are fixed parameters def append_overrides(overrides): parameter_ids.update( - str_p - for p in overrides - if isinstance(p, sp.Symbol) - and (str_p := str(p)) not in condition_targets + str(p) for p in overrides if isinstance(p, sp.Symbol) ) for m in problem.measurements: @@ -1033,9 +1114,24 @@ def append_overrides(overrides): if not problem.model.has_entity_with_id(str(p)) ) - # parameters that are overridden via the condition table are not allowed + # Parameters that are overridden via the condition table are not allowed + condition_targets = { + change.target_id + for cond in problem.conditions + for change in cond.changes + } parameter_ids -= condition_targets + hybridization_targets = { + hyb.target_id for hyb in problem.extensions.sciml.hybridizations + } + parameter_ids -= hybridization_targets + hybridization_target_values = { + str(hyb.target_value) + for hyb in problem.extensions.sciml.hybridizations + } + parameter_ids -= 
hybridization_target_values + return parameter_ids @@ -1090,5 +1186,13 @@ def get_placeholders( CheckUnusedConditions(), CheckPriorDistribution(), CheckInitialChangeSymbols(), - # TODO validate mapping table + CheckMappingTable(), +] + +# Import SciML validation from sciml_lint at the end to avoid circular imports +from ..v2.extensions.sciml_lint import CheckHybridizationTable # noqa: E402 + +#: Validation tasks that should be run for PEtab SciML problems +sciml_validation_tasks = default_validation_tasks + [ + CheckHybridizationTable(), ] diff --git a/pyproject.toml b/pyproject.toml index 0295cfa6..d6dfccc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ maintainers = [ tests = [ "antimony>=3.1.0", "copasi-basico>=0.85", + "petab_sciml", "pysb", "pytest", "pytest-cov", @@ -71,6 +72,9 @@ vis = [ "seaborn", "scipy" ] +sciml = [ + "petab_sciml", +] [project.scripts] petablint = "petab.petablint:main" diff --git a/tests/v2/test_core.py b/tests/v2/test_core.py index 22dbf0e1..7b5ad420 100644 --- a/tests/v2/test_core.py +++ b/tests/v2/test_core.py @@ -181,7 +181,7 @@ def test_measurments(): def test_observable(): - Observable(id="obs1", formula=x + y) + Observable(id="obs1", formula=x + y, noiseFormula=1) Observable(id="obs1", formula="x + y", noise_formula="x + y") Observable(id="obs1", formula=1, noise_formula=2) Observable( @@ -198,9 +198,17 @@ def test_observable(): observable_parameters=[sp.Symbol("p1")], noise_parameters=[sp.Symbol("n1")], ) - assert Observable(id="obs1", formula="x + y", non_petab=1).non_petab == 1 + assert ( + Observable( + id="obs1", + formula="x + y", + noise_formula="x + y", + non_petab=1, + ).non_petab + == 1 + ) - o = Observable(id="obs1", formula=x + y) + o = Observable(id="obs1", formula=x + y, noise_formula=1) assert o.observable_placeholders == [] assert o.noise_placeholders == [] @@ -492,14 +500,14 @@ def test_modify_problem(): problem.condition_df, exp_condition_df, check_dtype=False ) -
problem.add_observable("observable1", "1") + problem.add_observable("observable1", "1", noise_formula=1) problem.add_observable("observable2", "2", noise_formula=2.2) exp_observable_df = pd.DataFrame( data={ OBSERVABLE_ID: ["observable1", "observable2"], OBSERVABLE_FORMULA: [1, 2], - NOISE_FORMULA: [np.nan, 2.2], + NOISE_FORMULA: [1, 2.2], } ).set_index([OBSERVABLE_ID]) assert_frame_equal( diff --git a/tests/v2/test_lint.py b/tests/v2/test_lint.py index 7eb6dc91..a101ee64 100644 --- a/tests/v2/test_lint.py +++ b/tests/v2/test_lint.py @@ -43,7 +43,7 @@ def test_invalid_model_id_in_measurements(): """Test that measurements with an invalid model ID are caught.""" problem = Problem() problem.models.append(SbmlModel.from_antimony("p1 = 1", model_id="model1")) - problem.add_observable("obs1", "A") + problem.add_observable("obs1", "A", 1) problem.add_measurement("obs1", experiment_id="e1", time=0, measurement=1) check = CheckMeasurementModelId() @@ -70,7 +70,7 @@ def test_undefined_experiment_id_in_measurements(): """Test that measurements with an undefined experiment ID are caught.""" problem = Problem() problem.add_experiment("e1", 0, "c1") - problem.add_observable("obs1", "A") + problem.add_observable("obs1", "A", 1) problem.add_measurement("obs1", experiment_id="e1", time=0, measurement=1) check = CheckUndefinedExperiments() @@ -107,3 +107,43 @@ def test_validate_initial_change_symbols(): problem.parameter_tables[0].parameters.remove(problem["p2"]) assert (error := check.run(problem)) is not None assert "contains additional symbols: {'p2'}" in error.message + + +def test_check_mapping_table(): + """Test checks related to the mapping table.""" + problem = Problem() + # FIXME see https://github.com/PEtab-dev/libpetab-python/pull/482#discussion_r3431330125 + problem.model = SbmlModel.from_antimony("a.mean = 1") + problem.add_mapping( + petab_id="a_m", + model_id="a.mean", + name=None, + ) + problem.add_parameter( + "a_m", + estimate=True, + nominal_value=2, + lb=0, + 
ub=10, + ) + + check = CheckMappingTable() + assert check.run(problem) is None + + check = CheckAllParametersPresentInParameterTable() + assert check.run(problem) is None + + # add a petab id without model id but with name for annotation + problem.add_mapping(petab_id="p2", model_id=None, name="Parameter 2") + problem.add_parameter("p2", estimate=True, nominal_value=1, lb=0, ub=10) + + check = CheckMappingTable() + assert check.run(problem) is None + + # Invalid: petabEntityId is referenced in the model + problem.model = SbmlModel.from_antimony("a.mean = 1; a_m = 2") + assert (error := check.run(problem)) is not None + assert ( + "`a_m` is used in the mapping table and referenced directly" + in error.message + ) diff --git a/tests/v2/test_sciml.py b/tests/v2/test_sciml.py new file mode 100644 index 00000000..f25b85b8 --- /dev/null +++ b/tests/v2/test_sciml.py @@ -0,0 +1,141 @@ +import numpy as np +from pydantic import ConfigDict + +from petab.v2.core import * +from petab.v2.core import ModelFile +from petab.v2.extensions.sciml import NeuralNetConfig +from petab.v2.lint import sciml_validation_tasks +from petab.v2.models.sbml_model import SbmlModel + + +def _get_test_problem(): + problem = Problem() + problem.validation_tasks = sciml_validation_tasks + problem.config = ProblemConfig( + format_version="2.0.0", + model_files=ConfigDict( + {"lv": ModelFile(location="lv.xml", language="sbml")} + ), + parameter_files=["parameters.tsv"], + measurement_files=["measurements.tsv"], + observable_files=["observables.tsv"], + experiment_files=["experiments.tsv"], + mapping_files=["mappings.tsv"], + extensions={ + "sciml": { + "version": "0.1.0", + "array_files": ["net1_ps.hdf5"], + "hybridization_files": ["hybridizations.tsv"], + "neural_networks": { + "net1": NeuralNetConfig( + location="net1.yaml", + pre_initialization=False, + format="YAML", + ) + }, + } + }, + ) + problem.model = SbmlModel.from_antimony(""" + model lv + species A, B; + A = 0.442; + B = 4.63; + alpha = 1.3; 
+ gamma_ = 0.8; + -> A; alpha * A; + B -> ; 1.8 * B; + A -> ; 0.9 * A * B; + -> B; gamma_; + end + """) + problem.add_experiment("e1", 0, "") + problem.add_mapping("net1_input1", "net1.inputs[0][0]") + problem.add_mapping("net1_input2", "net1.inputs[0][1]") + problem.add_mapping("net1_output1", "net1.outputs[0][0]") + problem.add_mapping("net1_ps", "net1.parameters") + problem.add_measurement("B_obs", time=1, measurement=1, experiment_id="e1") + problem.add_observable("B_obs", "B", noise_formula="0.05") + problem.add_parameter( + "alpha", estimate=True, lb=0, ub=15, nominal_value=1.3 + ) + problem.add_parameter( + "net1_ps", estimate=True, lb=-np.inf, ub=np.inf, nominal_value="array" + ) + problem.extensions.sciml.add_hybridization("net1_input1", "A") + problem.extensions.sciml.add_hybridization("net1_input2", "B") + problem.extensions.sciml.add_hybridization("gamma_", "net1_output1") + problem.extensions.sciml.add_neural_network_from_dict( + "net1", + nn_dict={ + "nn_model_id": "net1", + "inputs": [{"input_id": "input0"}], + "layers": [ + { + "layer_id": "layer1", + "layer_type": "Linear", + "args": { + "in_features": 2, + "out_features": 1, + "bias": True, + }, + } + ], + "forward": [ + { + "name": "net_input", + "op": "placeholder", + "target": "net_input", + }, + { + "name": "layer1", + "op": "call_module", + "target": "layer1", + "args": ["net_input"], + }, + { + "name": "tanh", + "op": "call_method", + "target": "tanh", + "args": ["layer1"], + }, + ], + }, + ) + + # array data + problem.extensions.sciml.add_array_data_from_dict( + { + "metadata": {"pytorch_format": True}, + "inputs": {}, + "parameters": { + "net1": { + "layer1": { + "bias": np.random.randn(2), + "weight": np.random.randn(2), + } + } + }, + } + ) + + # set the filenames + problem.config.filepath = "problem.yaml" + problem.model.rel_path = "lv.xml" + problem.experiment_tables[0].rel_path = "experiments.tsv" + problem.mapping_tables[0].rel_path = "mappings.tsv" + 
problem.measurement_tables[0].rel_path = "measurements.tsv" + problem.observable_tables[0].rel_path = "observables.tsv" + problem.parameter_tables[0].rel_path = "parameters.tsv" + problem.extensions.sciml.hybridization_tables[ + 0 + ].rel_path = "hybridizations.tsv" + # problem.extensions.sciml.neural_networks[0].rel_path = "net1.yaml" + # problem.extensions.sciml.array_data_files[0].rel_path = "net1_ps.hdf5" + + return problem + + +def test_lint(): + problem = _get_test_problem() + assert problem.validate() == [] diff --git a/tox.ini b/tox.ini index 3f3bbe46..ffeb503c 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,7 @@ description = extras = tests,reports,combine,vis deps= git+https://github.com/PEtab-dev/petab_test_suite@main - git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master\#subdirectory=src/python + -e git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master\#subdirectory=src/python&egg=benchmark_models_petab commands = python -m pip install sympy>=1.12.1