diff --git a/AGENTS.md b/AGENTS.md index 78eed761..7a441e9a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,70 +1,123 @@ # PyAutoArray — Agent Instructions -**PyAutoArray** is the low-level data structures and numerical utilities package for the PyAuto ecosystem. It provides grids, masks, arrays, datasets, inversions, and the decorator system used throughout PyAutoGalaxy and PyAutoLens. +Canonical, agent-agnostic instructions for this repo. `CLAUDE.md` imports this +file; any tool that does not process `@`-imports should read this directly. -## Setup +## What this repo is -```bash -pip install -e ".[dev]" -``` +**PyAutoArray** (package `autoarray`) is the low-level data-structure and +numerical-utility layer: masks, arrays, (y,x) grids, imaging/interferometer +datasets, inversions/pixelizations, convolution/over-sampling operators, and +the grid decorators used throughout PyAutoGalaxy and PyAutoLens. + +Dependency direction: autoarray depends on **autoconf** only. It does **not** +import `autofit`, `autogalaxy`, or `autolens` — never add such an import. +Shared utilities (e.g. `test_mode`, `jax_wrapper`) belong in autoconf. + +## Related repos -## Running Tests +- **Source siblings:** PyAutoConf (upstream). PyAutoGalaxy / PyAutoLens build + directly on autoarray. +- No `_workspace`, `_workspace_test`, or HowTo of its own. The JAX/`xp` path is + exercised by the parity scripts in **autogalaxy_workspace_test** and + **autolens_workspace_test**. +- **docs/** — Sphinx source; published to ReadTheDocs. + +## Quick commands ```bash -python -m pytest test_autoarray/ -python -m pytest test_autoarray/structures/test_arrays.py -python -m pytest test_autoarray/structures/test_arrays.py -s +pip install -e ".[dev]" # install with dev/test extras +python -m pytest test_autoarray/ # full test suite +python -m pytest test_autoarray/structures/test_arrays.py # one focused test (add -s for output) +black autoarray/ # formatter (advisory — not gated) ``` -### Sandboxed / Codex runs +In a sandboxed / restricted environment, point numba and matplotlib at +writable caches: ```bash NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python -m pytest test_autoarray/ ``` -## Key Architecture +## CI / definition of green + +PRs must pass `pytest --cov` on the CI matrix (Python 3.12 **and** 3.13). There +is no black/ruff/flake8 gate — formatting is advisory. (`requires-python` in +`pyproject.toml` is `>=3.9`.) + +## Configuration & defaults -- **Data structures**: `Array2D`, `Grid2D`, `Grid2DIrregular`, `VectorYX2D` — all inherit from `AbstractNDArray` -- **Decorator system** (`structures/decorators/`): `@to_array`, `@to_grid`, `@to_vector_yx`, `@transform` — ensures output type matches input grid type -- **Datasets**: `Imaging`, `Interferometer` — containers for observational data -- **Inversions** (`inversion/`): sparse linear algebra for source reconstruction via pixelizations -- **Operators**: `Convolver` (PSF convolution), over-sampling utilities +autoconf supplies the packaged defaults under `autoarray/config/`. Workspaces +override them via their own `config/` directory; the test suite pushes a local +config dir via `conf.instance.push(...)` in `test_autoarray/conftest.py`. When +a change adds a new config key, mirror it into the packaged defaults so +downstream workspaces inherit it. -## Key Rules +## JAX & `xp` -- The `xp` parameter pattern controls NumPy vs JAX: `xp=np` (default) or `xp=jnp` -- Autoarray types are **not** JAX pytrees — they cannot be returned from `jax.jit` functions -- Decorated functions must return **raw arrays**, not autoarray wrappers -- All files must use Unix line endings (LF) -- Format with `black autoarray/` +NumPy is the default everywhere; JAX is opt-in and never imported at module +level. The `xp` parameter is the single point of control: -## Working on Issues +- `xp=np` (default) — pure NumPy path. +- `xp=jnp` — JAX path; `jax` / `jax.numpy` imported locally inside the function. + +Thread `xp` through **every** nested call (`self.X()`, helpers, properties) — a +missed site silently defaults to `xp=np` and fails when a tracer hits an `np.*` +op. Two patterns cross the `jax.jit` boundary: the `if xp is np:` **guard** for +functions that return a raw `jax.Array`, and **pytree registration** +(`abstract_ndarray._register_as_pytree` / `register_instance_pytree`) for +functions that return a real wrapper or structured object. + +**Unit tests are NumPy-only.** A JAX/`xp` change is validated only by the +parity scripts in `autogalaxy_workspace_test` / `autolens_workspace_test` +(`jax.jit` round-trip + `fitness._vmap` batch eval), not by `test_autoarray/`. + +Full detail (decorator internals, `xp`-threading hazards, both JIT patterns): +**[`docs/agents/jax_and_decorators.md`](docs/agents/jax_and_decorators.md)**. + +## Public API + +The public surface is defined authoritatively in `autoarray/__init__.py` — read +it rather than trusting a hand-maintained list. Canonical import: + +```python +import autoarray as aa +``` + +Core types (`Array2D`, `Grid2D`, `Grid2DIrregular`, `VectorYX2D`, …) inherit +from `AbstractNDArray`; `.array` returns the raw `numpy.ndarray` / `jax.Array`. + +## Key rules / footguns + +- Import direction: autoconf only — never `autofit` / `autogalaxy` / `autolens`. +- Grid-consuming functions decorated with `@aa.decorators.to_array` / `to_grid` + / `to_vector_yx` must return a **raw array** — the decorator wraps it. (Write + `aa.decorators.*`; `aa.grid_dec` is a deprecated alias.) +- Access grid coordinates via `grid.array[:, 0]`, not `grid[:, 0]`. +- All files use Unix line endings (LF, `\n`) — never `\r\n`. + +## Working on issues 1. Read the issue description and any linked plan. -2. Identify affected files and write your changes. -3. Run the full test suite: `python -m pytest test_autoarray/` -4. Ensure all tests pass before opening a PR. -5. If changing public API, note the change in your PR description — downstream packages (PyAutoGalaxy, PyAutoLens) and workspaces may need updates. -## Never rewrite history - -NEVER perform these operations on any repo with a remote: - -- `git init` in a directory already tracked by git -- `rm -rf .git && git init` -- Commit with subject "Initial commit", "Fresh start", "Start fresh", "Reset - for AI workflow", or any equivalent message on a branch with a remote -- `git push --force` to `main` (or any branch tracked as `origin/HEAD`) -- `git filter-repo` / `git filter-branch` on shared branches -- `git rebase -i` rewriting commits already pushed to a shared branch - -If the working tree needs a clean state, the **only** correct sequence is: - - git fetch origin - git reset --hard origin/main - git clean -fd - -This applies equally to humans, local Claude Code, cloud Claude agents, Codex, -and any other agent. The "Initial commit — fresh start for AI workflow" pattern -that appeared independently on origin and local for three workspace repos is -exactly what this rule prevents — it costs ~40 commits of redundant local work -every time it happens. +2. Identify affected files and make the change. +3. Run the full suite: `python -m pytest test_autoarray/`. +4. If you changed public API, say so explicitly — downstream packages + (PyAutoGalaxy, PyAutoLens) and the workspaces may need updates. +5. Ensure all tests pass before opening a PR. + +## Deep dives + +- [`docs/agents/jax_and_decorators.md`](docs/agents/jax_and_decorators.md) — + decorator system, `xp` backend pattern, and the `jax.jit` boundary. + +## Clean state + +Never rewrite history on a repo with a remote (no `git init` over a tracked +tree, no force-push to `main`, no rebasing pushed shared branches). To reset a +dirty tree the only correct sequence is: + +```bash +git fetch origin +git reset --hard origin/main +git clean -fd +``` diff --git a/CLAUDE.md b/CLAUDE.md index bade2d3e..fb9e30eb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,204 +1,5 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Dependency Graph - -PyAutoArray depends on **autoconf** (shared configuration and utilities). -PyAutoArray does **NOT** depend on PyAutoFit, PyAutoGalaxy, or PyAutoLens. -Never import from `autofit`, `autogalaxy`, or `autolens` in this repo. -Shared utilities (e.g. `test_mode`, `jax_wrapper`) belong in autoconf. - -## Commands - -### Install -```bash -pip install -e ".[dev]" -``` - -### Run Tests -```bash -# All tests -python -m pytest test_autoarray/ - -# Single test file -python -m pytest test_autoarray/structures/test_arrays.py - -# With output -python -m pytest test_autoarray/structures/test_arrays.py -s -``` - -### Codex / sandboxed runs - -When running Python from Codex or any restricted environment, set writable cache directories so `numba` and `matplotlib` do not fail on unwritable home or source-tree paths: - -```bash -NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python -m pytest test_autoarray/ -``` - -This workspace is often imported from `/mnt/c/...` and Codex may not be able to write to module `__pycache__` directories or `/home/jammy/.cache`, which can cause import-time `numba` caching failures without this override. - -### Formatting -```bash -black autoarray/ -``` - -### Plot Output Mode - -Set `PYAUTO_OUTPUT_MODE=1` to capture every figure produced by a script into numbered PNG files in `./output_mode//`. This is useful for visually inspecting all plots from an integration test without needing a display. - -```bash -PYAUTO_OUTPUT_MODE=1 python scripts/my_script.py -# -> ./output_mode/my_script/0_fit.png, 1_tracer.png, ... -``` - -When this env var is set, all `save_figure`, `subplot_save`, and `_save_subplot` calls are intercepted — the normal output path is bypassed and figures are written sequentially to the output_mode directory instead. - -## Architecture - -**PyAutoArray** is the low-level data structures and numerical utilities package for the PyAuto ecosystem. It provides: -- **Grid and array structures** — uniform and irregular 2D grids, arrays, vector fields -- **Masks** — 1D and 2D masks that define which pixels are active -- **Datasets** — imaging and interferometer dataset containers -- **Inversions / pixelizations** — sparse linear algebra for source reconstruction -- **Decorators** — input/output homogenisation for grid-consuming functions - -## Core Data Structures - -All data structures inherit from `AbstractNDArray` (`abstract_ndarray.py`). Key subclasses: - -| Class | Description | -|---|---| -| `Array2D` | Uniform 2D array tied to a `Mask2D` | -| `ArrayIrregular` | Unmasked 1D collection of values | -| `Grid2D` | Uniform (y,x) coordinate grid tied to a `Mask2D` | -| `Grid2DIrregular` | Irregular (y,x) coordinate collection | -| `VectorYX2D` | Uniform 2D vector field | -| `VectorYX2DIrregular` | Irregular vector field | - -`AbstractNDArray` provides arithmetic operators (`__add__`, `__sub__`, `__rsub__`, etc.), all decorated with `@to_new_array` and `@unwrap_array` so that operations between autoarray objects and raw scalars/arrays work naturally and return a new autoarray of the same type. - -The `.array` property returns the raw underlying `numpy.ndarray` or `jax.Array`: -```python -arr = aa.ArrayIrregular(values=[1.0, 2.0]) -arr.array # raw numpy array -arr._array # same, internal attribute -``` - -The constructor unwraps nested autoarray objects automatically: -```python -# while isinstance(array, AbstractNDArray): array = array.array -``` - -## Decorator System - -`autoarray/structures/decorators/` contains three output-wrapping decorators used on all grid-consuming functions. They ensure that the **type of the output structure matches the type of the input grid**: - -| Decorator | Grid2D input | Grid2DIrregular input | -|---|---|---| -| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` | -| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` | -| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` | - -### How the decorators work - -All three share `AbstractMaker` (`decorators/abstract.py`). The decorator: -1. Wraps the function in a `wrapper(obj, grid, xp=np, *args, **kwargs)` signature -2. Instantiates the relevant `*Maker` class with the function, object, grid, and `xp` -3. `AbstractMaker.result` checks the grid type and calls the appropriate `via_grid_2d` / `via_grid_2d_irr` method to wrap the raw result - -The function body receives the grid as-is and **must return a raw array** (not an autoarray wrapper). The decorator does the wrapping: - -```python -@aa.grid_dec.to_array -def convergence_2d_from(self, grid, xp=np, **kwargs): - # grid is Grid2D or Grid2DIrregular — access raw values via grid.array[:,0] - y = grid.array[:, 0] - x = grid.array[:, 1] - return xp.sqrt(y**2 + x**2) # return raw array; decorator wraps it -``` - -`AbstractMaker` also stores `use_jax = xp is not np` and exposes `_xp` (either `jnp` or `np`), but the wrapping step always runs regardless of `xp`. Autoarray types are **not registered as JAX pytrees**, so they cannot be directly returned from inside a `jax.jit` trace (see JAX section below). - -### Accessing grid coordinates inside a decorated function - -Inside a decorated function body, access the raw underlying array with `.array`: - -```python -# Correct — works for both numpy and jax backends -y = grid.array[:, 0] -x = grid.array[:, 1] - -# Also correct for simple slicing (returns raw array via __getitem__) -y = grid[:, 0] -x = grid[:, 1] -``` - -The `@transform` decorator (also in `decorators/`) shifts and rotates the input grid to the profile's reference frame before passing it to the function. It calls `obj.transformed_to_reference_frame_grid_from(grid, xp)` (decorated with `@to_grid`) and passes the result as the `grid` argument. After transformation the grid is still an autoarray object; `.array` still works. - -### Decorator stacking order - -Decorators are applied bottom-up (innermost first). The canonical order for mass/light profile methods is: - -```python -@aa.grid_dec.to_array # outermost: wraps output -@aa.grid_dec.transform # innermost: transforms grid input -def convergence_2d_from(self, grid, xp=np, **kwargs): - ... -``` - -## JAX Support - -The `xp` parameter pattern is the single point of control: -- `xp=np` (default) — pure NumPy path -- `xp=jnp` — JAX path; `jax` / `jax.numpy` are only imported locally - -### Why autoarray types cannot be returned from `jax.jit` - -`AbstractNDArray` subclasses (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. The `instance_flatten` / `instance_unflatten` class methods are defined on `AbstractNDArray` but are never passed to `jax.tree_util.register_pytree_node`. As a result: - -- Constructing an autoarray wrapper **inside** a JIT trace is fine (Python-level code runs normally during tracing) -- **Returning** an autoarray wrapper as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type` - -### The `if xp is np:` guard pattern - -Functions that are called directly inside `jax.jit` (i.e., as the outermost call in the lambda) must not return autoarray wrappers on the JAX path. The correct pattern is: - -```python -def convergence_2d_via_hessian_from(self, grid, xp=np): - hessian_yy, hessian_xx = ... - convergence = 0.5 * (hessian_yy + hessian_xx) - - if xp is np: - return aa.ArrayIrregular(values=convergence) # numpy: wrapped - return convergence # jax: raw jax.Array -``` - -This pattern is used in `autogalaxy/operate/lens_calc.py` for all `LensCalc` methods that are called inside `jax.jit`. It does **not** affect decorated helper functions (like `deflections_yx_2d_from`) because those are called as intermediate steps — their autoarray wrappers are consumed by downstream Python code, never returned as JIT outputs. - -## Line Endings — Always Unix (LF) - -All files **must use Unix line endings (LF, `\n`)**. Never write `\r\n` line endings. -## Never rewrite history - -NEVER perform these operations on any repo with a remote: - -- `git init` in a directory already tracked by git -- `rm -rf .git && git init` -- Commit with subject "Initial commit", "Fresh start", "Start fresh", "Reset - for AI workflow", or any equivalent message on a branch with a remote -- `git push --force` to `main` (or any branch tracked as `origin/HEAD`) -- `git filter-repo` / `git filter-branch` on shared branches -- `git rebase -i` rewriting commits already pushed to a shared branch - -If the working tree needs a clean state, the **only** correct sequence is: - - git fetch origin - git reset --hard origin/main - git clean -fd - -This applies equally to humans, local Claude Code, cloud Claude agents, Codex, -and any other agent. The "Initial commit — fresh start for AI workflow" pattern -that appeared independently on origin and local for three workspace repos is -exactly what this rule prevents — it costs ~40 commits of redundant local work -every time it happens. +# PyAutoArray — agent instructions +The canonical, agent-agnostic instructions live in `AGENTS.md`. Claude Code loads them +via the import below; if your tool does not process `@`-imports, open `AGENTS.md` in +this directory and read it directly. +@AGENTS.md diff --git a/README.md b/README.md index 6359b919..eb5a0280 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,57 @@ # PyAutoArray -A library for manipulating arrays within the PyAuto software framework +**PyAutoArray** (package `autoarray`) is the low-level data-structure and +numerical-utility layer of the [PyAuto](https://github.com/PyAutoLabs) +ecosystem. It provides masks, arrays, (y,x) coordinate grids, +imaging/interferometer datasets, inversions/pixelizations for source +reconstruction, and convolution/over-sampling operators. + +`PyAutoGalaxy` and `PyAutoLens` build directly on autoarray: every grid a +profile consumes, every masked image a fit operates on, and the linear-algebra +inversions behind pixelized source reconstruction are autoarray objects. The +package supports both a NumPy and an opt-in JAX (`xp=jnp`) backend. + +## Install + +```bash +pip install autoarray +``` + +## Examples + +A masked 2D array tied to a pixel scale: + +```python +import autoarray as aa + +arr = aa.Array2D.no_mask(values=[[1.0, 2.0], [3.0, 4.0]], pixel_scales=0.1) +arr.shape_native # (2, 2) +arr.native[0, 0] # 1.0 +``` + +A circular mask and the (y,x) coordinate grid of its unmasked pixels: + +```python +mask = aa.Mask2D.circular(shape_native=(50, 50), pixel_scales=0.1, radius=2.0) +mask.pixels_in_mask # 1264 + +grid = aa.Grid2D.from_mask(mask=mask) # shape (1264, 2) +uniform = aa.Grid2D.uniform(shape_native=(10, 10), pixel_scales=0.1) +``` + +A normalized Gaussian PSF convolver: + +```python +convolver = aa.Convolver.from_gaussian( + shape_native=(11, 11), pixel_scales=0.1, sigma=1.0, normalize=True +) +convolver.kernel.shape_native # (11, 11) +convolver.kernel.array.sum() # 1.0 +``` + +## Links + +- Source & tests: [`autoarray/`](autoarray), [`test_autoarray/`](test_autoarray) +- Decorators & JAX deep dive: [`docs/agents/jax_and_decorators.md`](docs/agents/jax_and_decorators.md) +- Agent/contributor instructions: [`AGENTS.md`](AGENTS.md) +- Ecosystem: [PyAutoLabs on GitHub](https://github.com/PyAutoLabs) diff --git a/docs/agents/jax_and_decorators.md b/docs/agents/jax_and_decorators.md new file mode 100644 index 00000000..82460b24 --- /dev/null +++ b/docs/agents/jax_and_decorators.md @@ -0,0 +1,230 @@ +# JAX & the decorator system — deep dive + +Long-form reference for the grid decorators, the `xp` (NumPy/JAX) backend +pattern, and how autoarray types cross the `jax.jit` boundary. The per-repo +`AGENTS.md` files keep only a short summary and link here. This is the single +canonical source for the detail — PyAutoGalaxy and PyAutoLens point at it +rather than re-explaining. + +Everything below is grounded in the installed source under +`autoarray/`, `autogalaxy/`, and `autolens/`. Where a class or function is +named, it exists in the current tree. + +--- + +## 1. The decorator system + +`autoarray/structures/decorators/` contains the output-wrapping decorators +used on all grid-consuming functions. They ensure the **type of the output +structure matches the type of the input grid**. + +Import them as `aa.decorators.*`. (`aa.grid_dec` still resolves as a +**deprecated alias** — `autoarray/__init__.py` defines +`from .structures import decorators as grid_dec # deprecated alias` — but +every shipped profile uses `aa.decorators.*`, so write that form.) + +| Decorator | `Grid2D` input → | `Grid2DIrregular` input → | +|---|---|---| +| `@aa.decorators.to_array` | `Array2D` | `ArrayIrregular` | +| `@aa.decorators.to_grid` | `Grid2D` | `Grid2DIrregular` | +| `@aa.decorators.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` | + +### How they work + +All three share `AbstractMaker` (`decorators/abstract.py`). The decorator: + +1. Wraps the function in a `wrapper(obj, grid, xp=np, *args, **kwargs)` signature. +2. Instantiates the relevant `*Maker` class with the function, object, grid, and `xp`. +3. `AbstractMaker.result` checks the grid type and calls the appropriate + `via_grid_2d` / `via_grid_2d_irr` method to wrap the raw result. + +The function body receives the grid as-is and **must return a raw array** +(not an autoarray wrapper). The decorator does the wrapping: + +```python +@aa.decorators.to_array +def convergence_2d_from(self, grid, xp=np, **kwargs): + # grid is Grid2D or Grid2DIrregular — access raw values via grid.array[:, 0] + y = grid.array[:, 0] + x = grid.array[:, 1] + return xp.sqrt(y**2 + x**2) # return raw array; decorator wraps it +``` + +`AbstractMaker` stores `use_jax = xp is not np` and exposes `_xp` (either `jnp` +or `np`), but the wrapping step always runs regardless of `xp`. + +### Accessing grid coordinates inside a decorated function + +Access the raw underlying array with `.array`: + +```python +# Correct — works for both numpy and jax backends +y = grid.array[:, 0] +x = grid.array[:, 1] + +# Also works for simple slicing (returns raw array via __getitem__) +y = grid[:, 0] +x = grid[:, 1] +``` + +Prefer `grid.array[:, 0]` — after `@transform` the grid is still an autoarray +object and `.array` is the safe way to extract the underlying data for both +numpy and jax backends. + +### `@transform` and stacking order + +`@aa.decorators.transform` shifts and rotates the input grid to the profile's +reference frame before passing it to the function. It calls +`obj.transformed_to_reference_frame_grid_from(grid, xp)` (itself decorated with +`@to_grid`) and passes the result as the `grid` argument. After transformation +the grid is still an autoarray object; `.array` still works. Some call sites +pass `rotate_back=True` (e.g. `@aa.decorators.transform(rotate_back=True)`). + +Decorators apply bottom-up (innermost first). The canonical order for +mass/light profile methods is: + +```python +@aa.decorators.to_array # outermost: wraps output +@aa.decorators.transform # innermost: transforms grid input +def convergence_2d_from(self, grid, xp=np, **kwargs): + ... +``` + +--- + +## 2. `AbstractNDArray` and the `.array` property + +All data structures inherit from `AbstractNDArray` (`abstract_ndarray.py`). +Key subclasses: `Array2D`, `ArrayIrregular`, `Grid2D`, `Grid2DIrregular`, +`VectorYX2D`, `VectorYX2DIrregular`. + +`AbstractNDArray` provides arithmetic operators (`__add__`, `__sub__`, +`__rsub__`, …) so operations between autoarray objects and raw scalars/arrays +return a new autoarray of the same type. The `.array` property returns the raw +underlying `numpy.ndarray` or `jax.Array`: + +```python +arr = aa.ArrayIrregular(values=[1.0, 2.0]) +arr.array # raw numpy (or jax) array +arr._array # same, internal attribute +``` + +The constructor unwraps nested autoarray objects automatically +(`while isinstance(array, AbstractNDArray): array = array.array`). + +--- + +## 3. The `xp` backend pattern + +The codebase is designed so that **NumPy is the default everywhere and JAX is +opt-in**. JAX is never imported at module level — only locally inside functions +when explicitly requested. The `xp` parameter is the single point of control: + +- `xp=np` (default throughout) — pure NumPy path, no JAX dependency at runtime. +- `xp=jnp` — JAX path; `jax` / `jax.numpy` imported locally inside the function. + +When adding a new function that should support JAX: + +1. Default the parameter to `xp=np`. +2. Guard any JAX imports with `if xp is not np:` and import `jax` / `jax.numpy` + locally inside that branch. +3. Add the NumPy implementation as the default path. +4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, + `jnp.vectorize`). + +### Threading `xp` through nested calls + +Adding `xp=np` to a method body and swapping `np.*` for `xp.*` is **only half +the work**. Every nested call inside that body — `self.X()`, `obj.X()`, a helper +in `convert.py`, an inherited `@property`, or a sibling method — must also +receive `xp=xp` if it can route to numpy operations on what would otherwise be +JAX tracers. Otherwise the inner call silently defaults to `xp=np` and fails +when a tracer reaches an `np.*` op. + +Concrete hazards seen in this codebase: + +- **`@property` chains that hardcode `np`.** A property takes no kwargs, so an + xp-aware caller must either inline the computation under `if xp is not np:` + or convert the property to a method. Read every `@property` you call from + xp-aware code; if it does `np.sqrt(...)`, it is a hazard. +- **Inherited methods.** A method may accept `xp` but a call site forgets to + pass it. Within xp-aware functions, grep for `self.X(` / `obj.X(` and verify + `xp=xp` is threaded. +- **`convert.py` helpers.** Helpers like `axis_ratio_and_angle_from`, + `angle_from`, `multipole_comps_from` all take `xp=np`; call sites must thread + it. They also use Python `&` on JAX bool tracers, which silently calls + `__array__()` — replace with `xp.logical_and`. +- **`@cached_property` on traced arrays.** Caches a tracer in `self.__dict__`, + which is invalid across `vmap` batch elements (different batches share the + cache). Use plain `@property` for any value that depends on JAX-traced inputs. + +--- + +## 4. Crossing the `jax.jit` boundary — two patterns + +Autoarray types **are** registered as JAX pytrees (see +`abstract_ndarray._register_as_pytree` and `register_instance_pytree`), so a +wrapper *can* be returned from a jitted function once its class is registered. +Two patterns coexist depending on what the function returns: + +### Pattern 1 — `if xp is np:` guard (raw `jax.Array` return) + +Functions intended to be called directly inside `jax.jit` as the outermost op, +where no wrapper is needed on the JAX path, guard their autoarray wrapping: + +```python +def convergence_2d_via_hessian_from(self, grid, xp=np): + convergence = 0.5 * (hessian_yy + hessian_xx) + + if xp is np: + return aa.ArrayIrregular(values=convergence) # numpy: wrapped + return convergence # jax: raw jax.Array +``` + +All `LensCalc` hessian-derived methods (`convergence_2d_via_hessian_from`, +`shear_yx_2d_via_hessian_from`, `magnification_2d_via_hessian_from`, +`magnification_2d_from`, `tangential_eigen_value_from`, +`radial_eigen_value_from`) use this pattern in +`autogalaxy/operate/lens_calc.py` and return raw `jax.Array` on the JAX path. +Intermediate helpers (e.g. `deflections_yx_2d_from`) do **not** need the guard +— their autoarray wrappers are consumed by downstream Python before any JIT +boundary. + +### Pattern 2 — pytree-registered wrapper return + +Functions that must return a real autoarray wrapper (or a structured object +built from them) rely on JAX pytree registration: + +- `AbstractNDArray` auto-registers its subclass with `jax.tree_util` the first + time an instance is built with `xp=jnp`, via + `autoarray.abstract_ndarray._register_as_pytree`. +- Higher-level types (`FitImaging`, `Tracer`, `DatasetModel`) use + `autoarray.abstract_ndarray.register_instance_pytree(cls, no_flatten=...)`, + which flattens `__dict__` and carries `no_flatten` names through `aux_data` + for per-analysis constants (dataset, settings, cosmology). +- `AnalysisImaging._register_fit_imaging_pytrees` wires these up when + `use_jax=True`, so `jax.jit(analysis.fit_from)(instance)` returns a real + `FitImaging` with `jax.Array` leaves. + +--- + +## 5. Validation — unit tests are NumPy-only + +Library unit tests (`test_autoarray/`, `test_autogalaxy/`, `test_autolens/`) +always run on the NumPy path. **No `xp=jnp` JAX assertion belongs in a library +unit test.** A JAX / `xp` change is validated only by the parity scripts in the +`*_workspace_test` repos. + +**`jax.jit(fn)(concrete_instance)` is NOT a sufficient JAX trace check.** A +`ModelInstance` with concrete float parameters propagates as floats through +`np.*` ops without raising — an un-threaded `xp` bug stays hidden. Use +`jax.vmap(fitness)(jnp.array(params))` (or `Fitness._vmap` on autofit's +wrapper) instead: vmap forces tracer propagation through every leaf and exposes +un-threaded `xp` sites. + +When adding a JAX path to an Analysis class, the workspace_test parity script +must include **both** a `jax.jit(analysis.fit_from)(instance)` round-trip +**and** a `fitness._vmap(parameters)` batch evaluation. PyAutoArray has no own +`autoarray_workspace_test`; array-level JAX changes are exercised downstream in +`autogalaxy_workspace_test/scripts/jax_likelihood_functions/` and the +`autolens_workspace_test` equivalents.