diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 0580af73..7ef7c1a0 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -242,6 +242,10 @@ def _get_masked_element( return element.loc[mask_values, :] +def _region_as_str_if_list_of_len_one(region: list[str]) -> str | list[str]: + return region if len(region) > 1 else region[0] + + def _right_exclusive_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, @@ -279,9 +283,10 @@ def _right_exclusive_join_spatialelement_table( exclusive_table = table[keep, :] if has_match and keep.any() else None _inplace_fix_subset_categorical_obs(subset_adata=exclusive_table, original_adata=table) if exclusive_table is not None: - exclusive_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = ( + exclusive_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one( exclusive_table.obs[region_column_name].unique().tolist() ) + TableModel.validate(exclusive_table) return element_dict, exclusive_table @@ -383,9 +388,10 @@ def _inner_join_spatialelement_table( _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table) if joined_table is not None: - joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = ( + joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one( joined_table.obs[region_column_name].unique().tolist() ) + TableModel.validate(joined_table) return element_dict, joined_table @@ -466,9 +472,10 @@ def _left_join_spatialelement_table( joined_table = table[joined_indices.tolist(), :].copy() if joined_indices is not None else None _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table) if joined_table is not None: - joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = ( + joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one( joined_table.obs[region_column_name].unique().tolist() ) + TableModel.validate(joined_table) return element_dict, joined_table diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 737c9c51..4f87098a 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -256,13 +256,13 @@ def test_join_updates_spatialdata_attrs(sdata_query_aggregation): _, table = join_spatialelement_table( sdata=sdata, spatial_element_names="values_circles", table_name="table", how="left" ) - assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"] + assert table.uns["spatialdata_attrs"]["region"] == "values_circles" # inner join on a single element _, table = join_spatialelement_table( sdata=sdata, spatial_element_names="values_circles", table_name="table", how="inner" ) - assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"] + assert table.uns["spatialdata_attrs"]["region"] == "values_circles" # right_exclusive join: pass a truncated circles element so some table rows have no match. # values_circles has 9 instances (0-8); keep only 5 → 4 table rows are exclusive. @@ -275,7 +275,7 @@ def test_join_updates_spatialdata_attrs(sdata_query_aggregation): ) assert table is not None assert table.n_obs == 4 - assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"] + assert table.uns["spatialdata_attrs"]["region"] == "values_circles" # original table metadata must be unchanged assert set(sdata["table"].uns["spatialdata_attrs"]["region"]) == {"values_circles", "values_polygons"} @@ -1008,6 +1008,7 @@ def test_labels_table_joins(full_sdata): def test_points_table_joins(full_sdata): full_sdata["table"].uns["spatialdata_attrs"]["region"] = "points_0" full_sdata["table"].obs["region"] = ["points_0"] * 100 + full_sdata["table"].obs["region"] = full_sdata["table"].obs["region"].astype("category") element_dict, table = join_spatialelement_table( sdata=full_sdata,