""" Benchmark script for Ollama models. Runs a quick generation benchmark on a local Ollama server, measuring token throughput for code and reasoning prompts. Results are printed as a markdown table and optionally saved to a file. # See: Usage: python bench.py [--host URL] [--models MODEL1,MODEL2] [--out PATH] """ from __future__ import annotations import argparse import json import sys import time import urllib.error import urllib.request from dataclasses import dataclass from typing import List, Tuple # ---------------------------------------------------------------------- # Constants # ---------------------------------------------------------------------- DEFAULT_HOST = "http://127.0.0.1:11434" DEFAULT_MODELS = ["gpt-oss:120b", "gpt-oss:20b", "gemma4"] CODE_PROMPT = ( "Write a Python function called parse_passwd that reads /etc/passwd and returns a dict " "mapping each username to {uid:int, gid:int, home:str, shell:str}. Skip empty lines and " "lines starting with #. Include a one-line docstring. Output ONLY the function code, " "no explanation." ) REASONING_PROMPT = ( "Plan a strategy for safely resetting a colleague work laptop before they leave. " "List (1) questions to ask first with reasoning, (2) order of operations, " "(3) explicit human-authorization gates." ) # ---------------------------------------------------------------------- # Data structures # ---------------------------------------------------------------------- @dataclass class BenchmarkResult: model: str task: str gen_tps: float prompt_tps: float total_s: float output_tokens: int # ---------------------------------------------------------------------- # Helper functions # ---------------------------------------------------------------------- def post_generate( host: str, model: str, prompt: str, temperature: float, num_predict: int = 4000, ) -> dict: """POST a generation request to Ollama and return the parsed JSON response.""" url = f"{host.rstrip('/')}/api/generate" payload = { "model": model, "prompt": prompt, "stream": False, "options": { "num_ctx": 8192, "temperature": temperature, "num_predict": num_predict, }, } data = json.dumps(payload).encode("utf-8") req = urllib.request.Request( url, data=data, headers={"Content-Type": "application/json"}, method="POST", ) try: with urllib.request.urlopen(req) as resp: if resp.status != 200: raise RuntimeError(f"HTTP {resp.status} from Ollama") body = resp.read().decode("utf-8") return json.loads(body) except urllib.error.URLError as e: raise RuntimeError(f"Failed to contact Ollama at {host}: {e}") from e def warm_up(host: str, model: str) -> None: """Perform a minimal request to load the model into memory.""" try: post_generate(host, model, "hi", temperature=0.0, num_predict=1) except Exception as e: raise RuntimeError(f"Warm‑up failed for model '{model}': {e}") from e def compute_metrics(resp: dict) -> Tuple[float, float, float, int]: """ Extract timing information from Ollama's response. 


def compute_metrics(resp: dict) -> Tuple[float, float, float, int]:
    """Extract timing information from Ollama's response.

    Returns:
        (eval_tps, prompt_tps, total_seconds, output_tokens)
    """
    eval_count = resp.get("eval_count", 0)
    eval_duration_ns = resp.get("eval_duration", 0)
    prompt_eval_count = resp.get("prompt_eval_count", 0)
    prompt_eval_duration_ns = resp.get("prompt_eval_duration", 0)
    total_duration_ns = resp.get("total_duration", 0)

    # Durations arrive in nanoseconds; guard against zero or missing values
    # so a partial response never raises ZeroDivisionError.
    eval_seconds = eval_duration_ns / 1e9 if eval_duration_ns else 0.0
    prompt_seconds = prompt_eval_duration_ns / 1e9 if prompt_eval_duration_ns else 0.0
    total_seconds = total_duration_ns / 1e9 if total_duration_ns else 0.0

    eval_tps = eval_count / eval_seconds if eval_seconds else 0.0
    prompt_tps = prompt_eval_count / prompt_seconds if prompt_seconds else 0.0

    return eval_tps, prompt_tps, total_seconds, eval_count


def format_markdown(results: List[BenchmarkResult]) -> str:
    """Render results as a markdown table."""
    header = "| model | task | gen_tps | prompt_tps | total_s | output_tokens |\n"
    separator = "|---|---|---:|---:|---:|---:|\n"
    rows = [
        f"| {r.model} | {r.task} | {r.gen_tps:.2f} | {r.prompt_tps:.2f} "
        f"| {r.total_s:.2f} | {r.output_tokens} |\n"
        for r in results
    ]
    return header + separator + "".join(rows)


# ----------------------------------------------------------------------
# Main execution
# ----------------------------------------------------------------------


def main() -> None:
    parser = argparse.ArgumentParser(description="Ollama benchmark script")
    parser.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help="Base URL of the Ollama server (default: %(default)s)",
    )
    parser.add_argument(
        "--models",
        default=",".join(DEFAULT_MODELS),
        help="Comma-separated list of model names (default: %(default)s)",
    )
    parser.add_argument(
        "--out",
        help="Path to write the markdown table (optional)",
    )
    args = parser.parse_args()

    host: str = args.host.rstrip("/")
    models: List[str] = [m.strip() for m in args.models.split(",") if m.strip()]
    out_path: str | None = args.out

    # (task name, prompt, sampling temperature)
    tasks: List[Tuple[str, str, float]] = [
        ("code", CODE_PROMPT, 0.2),
        ("reasoning", REASONING_PROMPT, 0.3),
    ]

    total_tests = len(models) * len(tasks)
    results: List[BenchmarkResult] = []
    test_index = 1

    for model in models:
        # Warm up each model once so load time does not skew the first run.
        try:
            warm_up(host, model)
        except RuntimeError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)

        for task_name, prompt, temperature in tasks:
            print(
                f"[{test_index}/{total_tests}] {model} {task_name} task...",
                file=sys.stderr,
            )
            try:
                resp = post_generate(host, model, prompt, temperature)
                eval_tps, prompt_tps, total_s, out_tokens = compute_metrics(resp)
                results.append(
                    BenchmarkResult(
                        model=model,
                        task=task_name,
                        gen_tps=eval_tps,
                        prompt_tps=prompt_tps,
                        total_s=total_s,
                        output_tokens=out_tokens,
                    )
                )
            except RuntimeError as e:
                print(f"Error during benchmark: {e}", file=sys.stderr)
                sys.exit(1)
            test_index += 1

    markdown = format_markdown(results)
    print(markdown)

    if out_path:
        try:
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(markdown)
        except OSError as e:
            print(f"Failed to write output file '{out_path}': {e}", file=sys.stderr)
            sys.exit(1)


if __name__ == "__main__":
    main()
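
# Example invocations (assumes an Ollama server is already running and the
# requested models have been pulled; the remote host below is hypothetical):
#
#   python bench.py
#   python bench.py --models gpt-oss:20b,gemma3 --out results.md
#   python bench.py --host http://192.168.1.50:11434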