| |
|
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import torch |
| |
|
| |
|
def load_rag_benchmark_tester_ds():
    """Load the llmware RAG instruct benchmark and return its samples.

    Pulls the ``llmware/rag_instruct_benchmark_tester`` dataset via
    ``datasets.load_dataset``, prints a one-line status message, and
    copies the ``"train"`` split into a plain list.

    Returns:
        A list holding every sample of the ``"train"`` split, in order.
    """
    # Imported here rather than at module level, as in the original script.
    from datasets import load_dataset

    dataset = load_dataset("llmware/rag_instruct_benchmark_tester")
    print("update: loading RAG Benchmark test dataset - ", dataset)

    # Materialise the split so callers get an ordinary list of samples.
    return list(dataset["train"])
| |
|
| |
|
def run_test(model_name, test_ds):
    """Run every benchmark sample through a causal LM and print the results.

    Loads the model and tokenizer named by ``model_name``, builds a
    ``<human>: ... <bot>:`` prompt from each entry's ``context`` and
    ``query`` fields, samples up to 100 new tokens, and prints the cleaned
    response next to the entry's ``answer`` field.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.
        test_ds: Iterable of mappings with string ``"context"`` and
            ``"query"`` values and an ``"answer"`` value.

    Returns:
        Always ``0``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE: the banner text is fixed; it is not derived from len(test_ds).
    print("\nRAG Performance Test - 200 questions")
    print("update: model - ", model_name)
    print("update: device - ", device)

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for i, entries in enumerate(test_ds):

        # Prompt wrapper: context, then the question, then the bot cue.
        new_prompt = "<human>: " + entries["context"] + "\n" + entries["query"] + "\n" + "<bot>:"

        inputs = tokenizer(new_prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)

        # generate() returns prompt + continuation; remember where the
        # continuation starts so only new tokens are decoded.
        start_of_output = input_ids.shape[1]

        # pad_token_id is set to the EOS id below, so generate() cannot
        # infer the attention mask from the ids. Forward the tokenizer's
        # mask explicitly whenever it produced one.
        extra_kwargs = {}
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            extra_kwargs["attention_mask"] = attention_mask.to(device)

        outputs = model.generate(
            input_ids,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.3,
            max_new_tokens=100,
            **extra_kwargs,
        )

        output_only = tokenizer.decode(outputs[0][start_of_output:], skip_special_tokens=True)

        # Drop anything from a literal end-of-text marker onwards.
        output_only = output_only.partition("<|endoftext|>")[0]

        # If the model echoed the bot cue, keep only what follows it.
        _, bot_marker, after_bot = output_only.partition("<bot>:")
        if bot_marker:
            output_only = after_bot

        print("\n")
        print(i, "llm_response - ", output_only)
        print(i, "gold_answer - ", entries["answer"])

    return 0
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: evaluate one model against the benchmark samples.
    model_name = "llmware/bling-1.4b-0.1"
    test_ds = load_rag_benchmark_tester_ds()
    output = run_test(model_name, test_ds)
| |
|
| |
|
| |
|