best_of_n.py
:
"""Example to run m serve."""
import os
import pydantic
import mellea
from mellea.stdlib.base import CBlock, Component, ModelOutputThunk
from mellea.stdlib.sampling import SamplingResult
class RankerResponse(pydantic.BaseModel):
best_choice: int
session = mellea.start_session()
def serve(
input: str, requirements: list[str] | None = None, model_options: None | dict = None
) -> ModelOutputThunk | SamplingResult | None:
N = int(os.environ["N"] if "N" in os.environ else "3")
attempts: dict[str, str | CBlock | Component] = {
str(i): session.instruct(input) for i in range(N)
}
attempts["query"] = input
ranking_output = session.instruct(
"Choose the best response to the user's query",
grounding_context=attempts,
format=RankerResponse,
).value
if ranking_output is not None:
choice = int(RankerResponse.model_validate_json(ranking_output).best_choice)
print(f"selected {choice}")
assert 0 < choice < len(attempts)
res = attempts[str(choice - 1)]
assert isinstance(res, ModelOutputThunk) or isinstance(res, SamplingResult)
return res
else:
return None