Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 136 additions & 42 deletions src/openlayer/lib/core/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import abc
import json
import time
import asyncio
import inspect
import argparse
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Tuple, Optional
from dataclasses import field, dataclass

import pandas as pd
Expand Down Expand Up @@ -36,6 +37,12 @@ class OpenlayerModel(abc.ABC):

It is more conventional to implement the `run` method.

``run`` may be defined as either ``def run`` (called sequentially per row)
or ``async def run``. When ``run`` is async, ``run_batch_from_df`` drives
rows concurrently with ``asyncio.gather``, bounded by ``max_workers``
(default 4 for an async ``run``); pass ``max_workers=1`` to force sequential
execution. Use async-native I/O (``httpx``, ``openai-async``, etc.) inside an
async ``run`` to actually benefit from concurrency.

Refer to Openlayer's templates for examples of how to implement this class.
"""

Expand All @@ -59,6 +66,15 @@ def run_from_cli(self) -> None:
required=False,
help="Custom arguments in format 'key1=value1,key2=value2'",
)
parser.add_argument(
"--max-workers",
type=int,
default=None,
help=(
"Max concurrent rows when run() is async. "
"Defaults to 4 for async run, 1 for sync run."
),
)

# Parse the arguments
args = parser.parse_args()
Expand All @@ -76,9 +92,12 @@ def run_from_cli(self) -> None:
return self.batch(
dataset_path=args.dataset_path,
output_dir=args.output_dir,
max_workers=args.max_workers,
)

def batch(self, dataset_path: str, output_dir: str) -> None:
def batch(
self, dataset_path: str, output_dir: str, max_workers: Optional[int] = None
) -> None:
"""Reads the dataset from a file and runs the model on it."""
# Load the dataset into a pandas DataFrame
fmt = "csv"
Expand All @@ -91,50 +110,125 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
raise ValueError(f"Unsupported dataset format: {dataset_path}")

# Call the model's run_batch method, passing in the DataFrame
output_df, config = self.run_batch_from_df(df)
output_df, config = self.run_batch_from_df(df, max_workers=max_workers)
self.write_output_to_directory(output_df, config, output_dir, fmt)

def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""Function that runs the model and returns the result."""
# Ensure the 'output' column exists
if "output" not in df.columns:
df["output"] = None
def run_batch_from_df(
self, df: pd.DataFrame, max_workers: Optional[int] = None
) -> Tuple[pd.DataFrame, dict]:
"""Function that runs the model and returns the result.

# Get the signature of the 'run' method
If ``run`` is defined as ``async def run(...)``, rows are dispatched
concurrently with ``asyncio.gather`` gated by ``asyncio.Semaphore(max_workers)``.
``max_workers`` defaults to 4 for an async ``run`` (writing `async def`
is the opt-in signal that interleaving is safe). For a synchronous
``run``, rows are processed sequentially and ``max_workers`` must be 1.

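A row's exception propagates and aborts the batch. For the async path,
``asyncio.gather`` re-raises the first exception but does not itself cancel
its siblings; rows still in flight are cancelled no later than when
``asyncio.run`` shuts down the event loop.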
"""
run_signature = inspect.signature(self.run)
valid_params = set(run_signature.parameters)
is_async = inspect.iscoroutinefunction(self.run)

if max_workers is None:
max_workers = 4 if is_async else 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defaulting async to 4 is opinionated and silent. The "writing async def means interleaving is safe" contract is reasonable, but this jumps an async run from sequential to 4 concurrent invocations with no explicit opt-in at the call site. That can surprise a run that hits a rate-limited API or holds non-reentrant state.

Two options: (a) default async to 1 and require --max-workers N to scale, or (b) keep 4 but call it out prominently in the changelog/user docs. Either is fine. Flagging so it's a deliberate choice rather than an accident.

elif max_workers < 1:
raise ValueError("max_workers must be >= 1")

if max_workers > 1 and not is_async:
raise ValueError(
"max_workers > 1 requires an async `run` method. "
"Define `run` as `async def run(self, ...)` to enable "
"concurrent execution."
)

for col in ("output", "steps", "latency", "cost", "tokens", "context"):
if col not in df.columns:
df[col] = None

rows = [
(
idx,
{k: v for k, v in row.to_dict().items() if k in valid_params},
)
for idx, row in df.iterrows()
]

if is_async:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"run_batch_from_df was called from inside a running event "
"loop. Call `await self._run_rows_async(...)` directly "

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message points at an internal. The guidance to "Call await self._run_rows_async(...)" references a private method that takes pre-built (index, kwargs) tuples, which isn't something a user can reasonably call. If invoking from inside a running loop is a real use case, expose a public async def run_batch_from_df_async(df, max_workers=...) and point the error message at that method instead. Otherwise, soften the message so it doesn't direct users at internals.

"from async code."
)
results = asyncio.run(self._run_rows_async(rows, max_workers))
else:
results = [
(idx, self.run(**kwargs), tracer.get_current_trace())
for idx, kwargs in rows
]

for index, output, trace in results:
self._apply_row_result(df, index, output, trace)

for index, row in df.iterrows():
# Filter row_dict to only include keys that are valid parameters
# for the 'run' method
row_dict = row.to_dict()
filtered_kwargs = {
k: v for k, v in row_dict.items() if k in run_signature.parameters
}

# Call the run method with filtered kwargs
output = self.run(**filtered_kwargs)

df.at[index, "output"] = output.output

for k, v in output.other_fields.items():
if k not in df.columns:
df[k] = None
df.at[index, k] = v

trace = tracer.get_current_trace()
if trace:
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
df.at[index, "steps"] = trace.to_dict()
if "latency" in processed_trace:
df.at[index, "latency"] = processed_trace["latency"]
if "cost" in processed_trace:
df.at[index, "cost"] = processed_trace["cost"]
if "tokens" in processed_trace:
df.at[index, "tokens"] = processed_trace["tokens"]
if "context" in processed_trace:
df.at[index, "context"] = processed_trace["context"]

config = {
return df, self._build_config(run_signature, df)

async def _run_rows_async(
self,
rows: List[Tuple[Any, Dict[str, Any]]],
max_workers: int,
) -> List[Tuple[Any, RunReturn, Optional[Any]]]:
"""Drive an async ``run`` over all rows with bounded concurrency.

The first row to raise causes ``asyncio.gather`` to cancel in-flight
siblings and re-raise the original exception.
"""
sem = asyncio.Semaphore(max_workers)

async def _one(index: Any, kwargs: Dict[str, Any]):
async with sem:
output = await self.run(**kwargs)
return index, output, tracer.get_current_trace()

return await asyncio.gather(*(_one(i, k) for i, k in rows))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eager task creation won't scale to large testsets. asyncio.gather(*(_one(...) for ... in rows)) instantiates one coroutine/Task per row up front, and rows (line 150) materializes every filtered-kwargs dict at once. The semaphore bounds concurrency, not task count, so a 100k-row testset creates 100k pending tasks plus 100k kwargs dicts in memory regardless of max_workers.

For typical testsets this is fine, but since the motivating use case is large batches against slow APIs, consider a bounded worker pool (N workers pulling from an asyncio.Queue, or chunked gather) so memory scales with max_workers rather than row count. At minimum, let's document the limitation.


def _apply_row_result(
self,
df: pd.DataFrame,
index: Any,
output: RunReturn,
trace: Optional[Any],
) -> None:
"""Write a single row's output and trace fields into ``df`` in place."""
df.at[index, "output"] = output.output

for k, v in output.other_fields.items():
if k not in df.columns:
df[k] = None
df.at[index, k] = v

if trace:
processed_trace, _ = tracer.post_process_trace(trace_obj=trace)
df.at[index, "steps"] = trace.to_dict()
if "latency" in processed_trace:
df.at[index, "latency"] = processed_trace["latency"]
if "cost" in processed_trace:
df.at[index, "cost"] = processed_trace["cost"]
if "tokens" in processed_trace:
df.at[index, "tokens"] = processed_trace["tokens"]
if "context" in processed_trace:
df.at[index, "context"] = processed_trace["context"]

def _build_config(
self, run_signature: inspect.Signature, df: pd.DataFrame
) -> Dict[str, Any]:
"""Build the config dict returned alongside the output DataFrame."""
config: Dict[str, Any] = {
"outputColumnName": "output",
"inputVariableNames": list(run_signature.parameters.keys()),
"metadata": {
Expand All @@ -154,7 +248,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
for k, v in self.custom_args.items():
config["metadata"][k] = v

return df, config
return config

def write_output_to_directory(
self,
Expand Down