This is the code for researcher.py. The pipeline has four steps: a guardian safety check on the input documents, per-document summarization, outline generation with constrained decoding, and a final report that merges the summaries and the outline:
from collections.abc import Callable
from functools import cache

from pydantic import BaseModel, ValidationError

from docs.examples.helper import w
from docs.examples.mini_researcher import RAGDocument
from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.requirement import Requirement, simple_validate
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult

# #############################
# Helper functions
# #############################


# @cache memoizes these session factories, so repeated calls reuse a single
# MelleaSession per model.
@cache
def get_session():
    """Get the M session (change the model here)."""
    return MelleaSession(backend=OllamaModelBackend(model_ids.IBM_GRANITE_3_3_8B))


@cache
def get_guardian_session():
    """Get the M session for the guardian (safety) model."""
    return MelleaSession(
        backend=OllamaModelBackend(model_ids.IBM_GRANITE_GUARDIAN_3_0_2B)
    )


def is_a_true_subset_of_b(a: list[str], b: list[str]) -> bool:
    """Check that every element of a also appears in b (i.e., a is a subset of b)."""
    return all(e in b for e in a)
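
# For example, is_a_true_subset_of_b(["Introduction"], ["Introduction", "Results"])
# returns True, while is_a_true_subset_of_b(["References"], ["Results"]) returns False.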


def create_check_word_count(max_words: int) -> Callable[[str], bool]:
    """Generate a maximum-word-count validation function."""

    def check_word_count(s: str) -> bool:
        return len(s.split()) <= max_words

    return check_word_count
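
# For example, create_check_word_count(3)("only three words") returns True,
# while create_check_word_count(3)("now four words here") returns False.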


# ########################################
# Functions for each step of the pipeline
# ########################################


def step_is_input_safe(guardian_session: MelleaSession, docs: list[str]) -> bool:
    """Check that none of the documents contain harmful content."""
    is_safe = True
    for i_doc, doc in enumerate(docs):
        print(f"\nChecking Doc {i_doc + 1}/{len(docs)}", end="...")
        inspect = guardian_session.chat(doc)
        # Granite Guardian answers "Yes" when it flags a risk in the input.
        if str(inspect).upper().startswith("YES"):
            is_safe = False
            print("FAILED")
            break
        else:
            print("OK")
    return is_safe


def step_summarize_docs(
    s: MelleaSession, docs: list[str], user_args: dict
) -> list[str]:
    """generate a task-specific document summary for each doc."""
    summaries = []
    for i_doc, doc in enumerate(docs):  # type: ignore
        print(f"\nSummarizing doc {i_doc + 1}/{len(docs)}", end="...")
        summary = s.instruct(
            f"Summarize the following document to answer the question: ' How {{current_subtopic}} impacts {{main_topic}}?' \n Document: {doc}",
            requirements=["Use maximal 3 sentences."],
            user_variables=user_args,
        )
        summaries.append(str(summary))
        print("done.")
    return summaries


def step_generate_outline(
    s: MelleaSession, user_args: dict, context: list[RAGDocument]
) -> list[str]:
    """Generate a report outline using constraint decoding (formatted output)."""

    print("\n Generating outline", end="...")

    class SectionTitles(BaseModel):
        section_titles: list[str]

    def must_have_sections(out: str) -> bool:
        try:
            stt = SectionTitles.model_validate_json(out)
            return is_a_true_subset_of_b(
                ["Introduction", "Conclusion", "References"], stt.section_titles
            )
        except ValidationError:
            return False

    def max_sub_sections(out: str) -> bool:
        try:
            stt = SectionTitles.model_validate_json(out)
            return (
                len(stt.section_titles) <= 3 + user_args["max_subsections"]
            )  # x + Intro, Conclusion, Ref
        except ValidationError:
            return False
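
    # For instance, '{"section_titles": ["Introduction", "Acquisition History",
    # "Conclusion", "References"]}' passes both validators: the three standard
    # sections are present and the count stays within 3 + max_subsections titles.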

    ## Define Requirements
    req_outline = Requirement(
        "In addition to the main body of the report, the report should also include these standard sections: Introduction, Conclusion, and References.",
        validation_fn=simple_validate(must_have_sections),
    )

    req_num_sections = Requirement(
        f"Limit the number of subsections to a maximum of {user_args['max_subsections']}.",
        validation_fn=simple_validate(max_sub_sections),
    )

    # Expose each context document's title and source to the model as
    # grounding context.
    outline_context = {
        f"Document {i + 1}": f"## Title: {d.title}, ## Source: {d.source}"
        for i, d in enumerate(context)
    }

    ## Generate
    # format=SectionTitles constrains decoding to JSON matching the schema;
    # the listed requirements are then checked via rejection sampling.
    outline_result = s.instruct(
        description="Create an outline for a report on how {{current_subtopic}} impacts {{main_topic}}. Use the Context Documents provided as a guideline for the sections.",
        # output_prefix="# Introduction",
        requirements=[req_outline, req_num_sections],
        user_variables=user_args,
        grounding_context=outline_context,
        strategy=RejectionSamplingStrategy(loop_budget=2),
        return_sampling_results=True,
        format=SectionTitles,
    )

    st = SectionTitles.model_validate_json(outline_result.value)

    # If rejection sampling ran out of budget, show which requirements the
    # selected sample failed.
    if isinstance(outline_result, SamplingResult):
        if not outline_result.success:
            for i_r, r in enumerate(outline_result.sample_generations):
                if r == outline_result.result:
                    print("Validation Results:")
                    for v in outline_result.sample_validations[i_r]:
                        print(f"\t{v[1]} <- {v[0].description}")

    print("done.")
    return st.section_titles


def step_write_full_report(
    m: MelleaSession,
    max_words: int,
    user_args: dict,
    summaries: list[str],
    outline: list[str],
) -> str:
    """Merge summaries and outline into a single report."""

    print("\nWriting full report", end="...")

    ## Define Requirements
    req_focus = Requirement("Stay focused on the topic, avoid unrelated information.")
    req_language = Requirement(f"Write the report in {user_args['language']} language.")
    req_tone = Requirement("Use a {{tone}} tone throughout the report.")
    req_length = Requirement(
        f"The report should have a maximum length of {max_words} words.",
        validation_fn=simple_validate(create_check_word_count(max_words=max_words)),
    )
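
    # Stash the joined summaries and the bulleted outline in user_args so the
    # instruction template below can reference them as {{context}} and {{outline}}.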

    user_args.update(
        {
            "context": "\n".join(summaries),
            "outline": "\n".join([f"* {o}" for o in outline]),
        }
    )

    ## Generate
    report_result = m.instruct(
        description="Context:\n{{context}}\nSummarize the relevant information available into a detailed report on how {{current_subtopic}} impacts {{main_topic}}.\n\nFollow this outline:\n{{outline}}",
        requirements=[req_focus, req_length, req_language, req_tone],
        user_variables=user_args,
        strategy=RejectionSamplingStrategy(loop_budget=2, requirements=[req_length]),
        return_sampling_results=True,
    )

    # As above, surface the failed requirements if sampling ran out of budget.
    if isinstance(report_result, SamplingResult):
        if not report_result.success:
            for i_r, r in enumerate(report_result.sample_generations):
                if r == report_result.result:
                    print("Validation Results:")
                    for v in report_result.sample_validations[i_r]:
                        print(f"\t{v[1]} <- {v[0].description}")

    print("done.")
    return report_result.value


def research_subtopic(main_topic: str, subtopic: str, context: list[RAGDocument]):
    """Start MiniResearcher here."""
    user_args = {
        "context_docs": context,
        "current_subtopic": subtopic,
        "main_topic": main_topic,
        "max_subsections": 5,
        "existing_headers": "",
        "relevant_written_contents": "",
        "date": "April 26, 2025",
        "language": "English",
        "total_words": 1000,
        "tone": "professional",
    }

    m = get_session()
    guardian_session = get_guardian_session()

    # Step 0: check the input documents for harmful content
    safe_input = step_is_input_safe(guardian_session, docs=[c.content for c in context])
    if not safe_input:
        return {"error": "Input not safe"}

    # Step 1: Summarize each doc
    summaries = step_summarize_docs(
        m, docs=[c.content for c in context], user_args=user_args
    )
    print(f"Summaries: \n{'\n\n'.join([w(s) for s in summaries])}")

    # Step 2: Generate Outline
    outline = step_generate_outline(m, user_args=user_args, context=context)
    print(f"Outline:\n{'\n'.join(outline)}")

    # Step 3: Merge all for the final report
    full_report = step_write_full_report(
        m,
        user_args=user_args,
        max_words=user_args["total_words"],
        summaries=summaries,
        outline=outline,
    )
    return full_report


if __name__ == "__main__":
    # data from an external file
    from docs.examples.mini_researcher.context_docs import documents as ctx_docs

    report = research_subtopic(
        "IBM earnings outlook", "Recent IBM acquisitions", context=ctx_docs
    )

    print(f"\nFull Report:\n\n{w(report)}")