101_example.py
:
import time
from mellea import LinearContext, MelleaSession
from mellea.backends.aloras.huggingface.granite_aloras import (
HFConstraintAlora,
add_granite_aloras,
)
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.huggingface import LocalHFBackend
from mellea.stdlib.base import GenerateLog
from mellea.stdlib.requirement import Requirement, req
# Define a backend and add the constraint aLora
backend = LocalHFBackend(
model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5)
)
backend.add_alora(
HFConstraintAlora(
name="custom_construant",
path_or_model_id="my_uploaded_model/goes_here", # can also be the checkpoint path
generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
backend=backend,
)
)
# Create M session
m = MelleaSession(backend, ctx=LinearContext())
# define a requirement
failure_check = req("The failure mode should not be none.")
# run instruction with requirement attached on the base model
res = m.instruct(
"Write triage summaries based on technician note.", requirements=[failure_check]
)
print("==== Generation =====")
print(f"Model Output: {res}")
print(
f"Generation Prompt: {m.last_prompt()}"
) # retrieve prompt information from session context
def validate_reqs(reqs: list[Requirement]):
"""Validate the requirements against the last output in the session."""
print("==== Validation =====")
print(
"using aLora"
if backend.default_to_constraint_checking_alora
else "using NO alora"
)
# helper to collect validation prompts (because validation calls never get added to session contexts).
logs: list[GenerateLog] = [] # type: ignore
# Run the validation. No output needed, because the last output in "m" will be used. Timing added.
start_time = time.time()
val_res = m.validate(reqs, generate_logs=logs)
end_time = time.time()
delta_t = end_time - start_time
print(f"Validation took {delta_t} seconds.")
print("Validation Results:")
# Print list of requirements and validation results
for i, r in enumerate(reqs):
print(f"- [{val_res[i]}]: {r.description}")
# Print prompts using the logs list
print("Prompts:")
for log in logs:
if isinstance(log, GenerateLog):
print(f" - {{prompt: {log.prompt}\n raw result: {log.result.value} }}") # type: ignore
return end_time - start_time, val_res
# run with aLora -- which is the default if the constraint alora is added to a model
validate_reqs([failure_check])
# force to run without alora
backend.default_to_constraint_checking_alora = False
validate_reqs([failure_check])
backend.default_to_constraint_checking_alora = True