Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def _get_masked_element(
return element.loc[mask_values, :]


def _region_as_str_if_list_of_len_one(region: list[str]) -> str | list[str]:
return region if len(region) > 1 else region[0]


def _right_exclusive_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]],
table: AnnData,
Expand Down Expand Up @@ -279,9 +283,10 @@ def _right_exclusive_join_spatialelement_table(
exclusive_table = table[keep, :] if has_match and keep.any() else None
_inplace_fix_subset_categorical_obs(subset_adata=exclusive_table, original_adata=table)
if exclusive_table is not None:
exclusive_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
exclusive_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one(
exclusive_table.obs[region_column_name].unique().tolist()
)
TableModel.validate(exclusive_table)
return element_dict, exclusive_table


Expand Down Expand Up @@ -383,9 +388,10 @@ def _inner_join_spatialelement_table(

_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
if joined_table is not None:
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one(
joined_table.obs[region_column_name].unique().tolist()
)
TableModel.validate(joined_table)
return element_dict, joined_table


Expand Down Expand Up @@ -466,9 +472,10 @@ def _left_join_spatialelement_table(
joined_table = table[joined_indices.tolist(), :].copy() if joined_indices is not None else None
_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
if joined_table is not None:
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = _region_as_str_if_list_of_len_one(
joined_table.obs[region_column_name].unique().tolist()
)
TableModel.validate(joined_table)
return element_dict, joined_table


Expand Down
7 changes: 4 additions & 3 deletions tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,13 @@ def test_join_updates_spatialdata_attrs(sdata_query_aggregation):
_, table = join_spatialelement_table(
sdata=sdata, spatial_element_names="values_circles", table_name="table", how="left"
)
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
assert table.uns["spatialdata_attrs"]["region"] == "values_circles"

# inner join on a single element
_, table = join_spatialelement_table(
sdata=sdata, spatial_element_names="values_circles", table_name="table", how="inner"
)
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
assert table.uns["spatialdata_attrs"]["region"] == "values_circles"

# right_exclusive join: pass a truncated circles element so some table rows have no match.
# values_circles has 9 instances (0-8); keep only 5 → 4 table rows are exclusive.
Expand All @@ -275,7 +275,7 @@ def test_join_updates_spatialdata_attrs(sdata_query_aggregation):
)
assert table is not None
assert table.n_obs == 4
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
assert table.uns["spatialdata_attrs"]["region"] == "values_circles"

# original table metadata must be unchanged
assert set(sdata["table"].uns["spatialdata_attrs"]["region"]) == {"values_circles", "values_polygons"}
Expand Down Expand Up @@ -1008,6 +1008,7 @@ def test_labels_table_joins(full_sdata):
def test_points_table_joins(full_sdata):
full_sdata["table"].uns["spatialdata_attrs"]["region"] = "points_0"
full_sdata["table"].obs["region"] = ["points_0"] * 100
full_sdata["table"].obs["region"] = full_sdata["table"].obs["region"].astype("category")

element_dict, table = join_spatialelement_table(
sdata=full_sdata,
Expand Down
Loading