Legal Agent Router v4

Multi-head legal intake routing classifier. Given a user's description of their legal problem, predicts:

  • Practice area (11-way softmax)
  • Escalation flags (3 binary sigmoids)
  • Auxiliary decisions (attorney review needed, source docs needed, prediction uncertain)

Domestic violence is out of scope: the model does not detect or flag it.

Architecture

  • Encoder: legal-bert-small-uncased (6 layers, 512 hidden, ~35M params)
  • Route head: 11-way softmax; picks the single most likely practice area
  • Flag heads: 3 independent sigmoids; multi-label (several flags can be active at once)
  • Aux heads: 3 independent sigmoids
  • Dropout: 0.15

Outputs (17 dimensions)

Practice Areas (softmax)

criminal_law, civil_litigation, corporate_law, family_law, immigration_law, intellectual_property, employment_labor, real_estate, consumer_finance, health_benefits, traffic

Escalation Flags (sigmoid, multi-label)

criminal_exposure, immigration_consequence, imminent_deadline

Auxiliary (sigmoid)

attorney_review_required, source_required, is_uncertain
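
The 17 output values can be read as three contiguous slices. The sketch below assumes the ordering used by the Usage code in this card (11 route probabilities, then 3 flag logits, then 3 auxiliary logits). The helper name `split_outputs` and the example row are hypothetical and only illustrate the layout.

```python
import math

ROUTES = ["criminal_law", "civil_litigation", "corporate_law", "family_law",
          "immigration_law", "intellectual_property", "employment_labor", "real_estate",
          "consumer_finance", "health_benefits", "traffic"]
FLAGS = ["criminal_exposure", "immigration_consequence", "imminent_deadline"]
AUX = ["attorney_review_required", "source_required", "is_uncertain"]


def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def split_outputs(vec):
    """Split one flat 17-value output row into named route, flag, and aux scores."""
    if len(vec) != len(ROUTES) + len(FLAGS) + len(AUX):
        raise ValueError("expected 17 values")
    route_probs = vec[:11]     # already softmax probabilities
    flag_logits = vec[11:14]   # raw logits, need a sigmoid
    aux_logits = vec[14:17]    # raw logits, need a sigmoid
    best = max(range(len(ROUTES)), key=lambda i: route_probs[i])
    return {
        "route": ROUTES[best],
        "route_confidence": route_probs[best],
        "flags": {name: sigmoid(z) for name, z in zip(FLAGS, flag_logits)},
        "aux": {name: sigmoid(z) for name, z in zip(AUX, aux_logits)},
    }


# Made-up row: "traffic" has the highest route probability,
# and only the imminent_deadline logit is positive.
example = [0.01] * 10 + [0.90] + [-2.0, -3.0, 1.5] + [0.5, -1.0, -2.0]
decoded = split_outputs(example)
print(decoded["route"], round(decoded["flags"]["imminent_deadline"], 3))
```

Keeping the slice boundaries in one helper avoids repeating the magic indices 11 and 14 wherever the output is consumed.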

Training

Parameter           Value
Training samples    436
Validation samples  93
Test samples        94
Base model          legal-bert-small-uncased
Epochs              15
Learning rate       3e-5, cosine schedule
Batch size          8 × 2 accumulation = 16
Loss                CE (routes) + BCE (flags ×2, aux)
Hardware            A10G (24 GB)
Training time       ~33 s
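
The loss row above combines one cross-entropy term with two binary cross-entropy terms. The following is a minimal standard-library sketch of that combination for a single example, not the actual training code. It assumes that "flags ×2" means the flag BCE term is weighted by 2 and that each BCE term is averaged over its outputs; the function names are hypothetical.

```python
import math


def cross_entropy(logits, target_idx):
    """Softmax cross-entropy for one example, using a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target_idx]


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over independent sigmoid outputs."""
    total = 0.0
    for z, y in zip(logits, targets):
        # max(z, 0) - y*z + log(1 + exp(-|z|)) is a numerically stable
        # form of -[y*log(s) + (1 - y)*log(1 - s)] with s = sigmoid(z).
        total += max(z, 0.0) - y * z + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


FLAG_WEIGHT = 2.0  # assumption: "flags ×2" read as a loss weight


def combined_loss(route_logits, route_target,
                  flag_logits, flag_targets,
                  aux_logits, aux_targets):
    """CE on the route head plus weighted BCE on flag and aux heads."""
    return (cross_entropy(route_logits, route_target)
            + FLAG_WEIGHT * bce_with_logits(flag_logits, flag_targets)
            + bce_with_logits(aux_logits, aux_targets))


# With all-zero logits every head is maximally unsure:
# CE = log(11), each BCE = log(2).
loss = combined_loss([0.0] * 11, 10, [0.0] * 3, [0, 0, 1], [0.0] * 3, [1, 0, 0])
print(round(loss, 4))
```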

Evaluation (synthetic test set, 94 samples)

Metric                    Value
Route accuracy            100%
Route macro F1            1.00
Flag macro F1             0.59
Missed escalation rate    20.0%
Attorney review accuracy  97.9%
Source required accuracy  96.8%
Uncertainty accuracy      94.7%

Per-Flag F1

Flag                     F1
imminent_deadline        0.87
immigration_consequence  0.46
criminal_exposure        0.44
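
The reported flag macro F1 of 0.59 is consistent with the unweighted mean of the three per-flag scores. The short sketch below reproduces that arithmetic and shows the standard F1 formula from confusion counts. The helper name `f1_from_counts` is hypothetical, and the card does not report the underlying counts.

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when there are no positives."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Per-flag F1 values as reported in this card.
per_flag_f1 = {
    "imminent_deadline": 0.87,
    "immigration_consequence": 0.46,
    "criminal_exposure": 0.44,
}

# Macro F1 is the plain average, so the two weak flags pull it down
# even though imminent_deadline scores well.
macro_f1 = sum(per_flag_f1.values()) / len(per_flag_f1)
print(f"{macro_f1:.2f}")  # 0.59
```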

Manual Test (12 hand-picked cases)

Accuracy: 9/12 = 75%

Note: Route accuracy on synthetic test data is 100% because the test split comes from the same template distribution. Real-world performance will be lower. See Limitations.

Inference Performance

Metric           Value
Average latency  0.5 ms
P50 latency      0.5 ms
P99 latency      0.8 ms
Hardware         T4 (16 GB)
Batch size       8

This classifier is ~500-2000× faster than an LLM-based routing approach and uses 14× fewer parameters.
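
The card does not name the LLM baseline behind these ratios. As a rough sanity check, the arithmetic below derives what the stated ratios would imply about such a baseline, taking the 0.5 ms average latency and ~35M parameter count from this card as given. The results are derived figures, not measurements.

```python
classifier_latency_ms = 0.5      # average latency reported above
classifier_params_m = 35         # ~35M parameters, from the Architecture section
speedup_range = (500, 2000)      # "~500-2000× faster"
param_ratio = 14                 # "14× fewer parameters"

# Latency and size an LLM router would need to have for the ratios to hold.
implied_llm_latency_ms = tuple(classifier_latency_ms * s for s in speedup_range)
implied_llm_params_m = classifier_params_m * param_ratio

print(implied_llm_latency_ms)  # (250.0, 1000.0)
print(implied_llm_params_m)    # 490
```

Under these assumptions the comparison point is a model of roughly half a billion parameters that answers in about 0.25 to 1 second.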

Usage

import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertConfig, BertModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

ROUTES = ["criminal_law", "civil_litigation", "corporate_law", "family_law",
          "immigration_law", "intellectual_property", "employment_labor", "real_estate",
          "consumer_finance", "health_benefits", "traffic"]
FLAGS = ["criminal_exposure", "immigration_consequence", "imminent_deadline"]
MODEL_ID = "narcolepticchicken/legal-agent-router-v4"

class LegalRouterSoftmax(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = BertModel(config)
        h = config.hidden_size
        self.dropout = nn.Dropout(0.15)
        self.route_head = nn.Linear(h, 11)
        self.flag_head = nn.Linear(h, 3)
        self.attorney_head = nn.Linear(h, 1)
        self.source_head = nn.Linear(h, 1)
        self.uncertainty_head = nn.Linear(h, 1)
    
    def forward(self, input_ids, attention_mask):
        # Use the [CLS] token embedding as the pooled sentence representation.
        p = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        p = p.last_hidden_state[:, 0, :]
        p = self.dropout(p)
        route_probs = torch.softmax(self.route_head(p), dim=-1)  # (B, 11) probabilities
        fl = self.flag_head(p)         # (B, 3) raw logits
        al = self.attorney_head(p)     # (B, 1) raw logit
        sl = self.source_head(p)       # (B, 1) raw logit
        ul = self.uncertainty_head(p)  # (B, 1) raw logit
        # Layout: [0:11] route probabilities, [11:14] flag logits, [14:17] aux logits.
        return torch.cat([route_probs, fl, al, sl, ul], dim=-1)

# Load
config = BertConfig.from_pretrained(MODEL_ID)
model = LegalRouterSoftmax(config)
sd = load_file(hf_hub_download(MODEL_ID, "model.safetensors"))
model.load_state_dict({k.replace('model.', ''): v for k, v in sd.items()})
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

def predict(text, context=""):
    prompt = f"Context: {context} Request: {text}" if context else f"Request: {text}"
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=384)
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    
    # Routes are already probabilities; flags and aux outputs are logits.
    route_idx = out[0, :11].argmax().item()
    flag_probs = torch.sigmoid(out[0, 11:14])
    aux_probs = torch.sigmoid(out[0, 14:])
    
    return {
        "route": ROUTES[route_idx],
        "route_confidence": out[0, route_idx].item(),
        "flags": {f: flag_probs[i].item() for i, f in enumerate(FLAGS)},
        "attorney_review_required": bool(aux_probs[0] > 0.5),
        "source_required": bool(aux_probs[1] > 0.5),
        "is_uncertain": bool(aux_probs[2] > 0.5),
    }

result = predict("I got a speeding ticket and my court date is next week")
print(result)

Limitations

  • Synthetic training data only: all 436 training samples were template-generated. Per-route coverage is ~30-45 examples. Performance on real legal intake text (different vocabulary, real-world patterns) will degrade.
  • Template memorization risk: the 100% test accuracy reflects template-level generalization, not out-of-distribution robustness.
  • Weak on rare flags: criminal_exposure and immigration_consequence appear in only ~34 training samples each (F1 ~0.45).
  • Single-label routes: the model predicts exactly one practice area. Multi-area cases (e.g., immigration + criminal) are assigned to one dominant label.
  • English only.
  • Not for legal advice: experimental routing aid only.
  • No DV detection: this model does not detect or flag domestic violence.
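
Given these limitations, one cautious way to consume the output is to send anything doubtful to a person and not trust the route alone. The sketch below takes a dict shaped like the return value of `predict()` in the Usage section. The function name `triage` and both thresholds are hypothetical and untuned; lowering the flag threshold below 0.5 is only an assumed way to trade extra reviews for fewer missed escalations.

```python
def triage(result, flag_threshold=0.3, min_route_confidence=0.6):
    """Return (action, reasons) for a predict()-style result dict."""
    reasons = []
    # Any escalation flag at or above the threshold triggers review.
    for name, prob in result["flags"].items():
        if prob >= flag_threshold:
            reasons.append(f"flag:{name}")
    if result["attorney_review_required"]:
        reasons.append("attorney_review_required")
    if result["is_uncertain"]:
        reasons.append("is_uncertain")
    if result["route_confidence"] < min_route_confidence:
        reasons.append("low_route_confidence")
    action = "human_review" if reasons else "auto_route"
    return action, reasons


# Made-up result in the same shape that predict() returns.
sample = {
    "route": "traffic",
    "route_confidence": 0.93,
    "flags": {"criminal_exposure": 0.08,
              "immigration_consequence": 0.02,
              "imminent_deadline": 0.41},
    "attorney_review_required": False,
    "source_required": False,
    "is_uncertain": False,
}
action, reasons = triage(sample)
print(action, reasons)  # human_review ['flag:imminent_deadline']
```

With the default 0.5 cut-off the same sample would be routed automatically, which is the kind of borderline case a lower threshold is meant to catch.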

Dataset

Trained on narcolepticchicken/legal-agent-routing-v3 (private), a set of 623 synthetic legal intake scenarios split 436/93/94 into training, validation, and test samples.

License

Apache 2.0


Generated 2026-05-09 21:49 UTC

Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
