jnm38 committed
Commit 71a53d0 · 1 Parent(s): d87be2e

Refactor evaluate_llm function to disable progress tracking by default

Files changed (1): src/app.py (+25 -13)

src/app.py CHANGED
@@ -130,13 +130,15 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
                  num_samples, temperature, max_tokens, top_p, top_k, seed,
                  repetition_penalty, prompt_field, max_model_len=32000,
                  quantization="none", gpu_memory_utilization=0.9,
-                 progress=gr.Progress()):
+                 progress=None):
     """Evaluate LLM with progress tracking and better error handling."""

-    progress(0, desc="Initializing...")
+    if progress:
+        progress(0, desc="Initializing...")

     # Load main model
-    progress(0.1, desc=f"Loading model: {model_name}")
+    if progress:
+        progress(0.1, desc=f"Loading model: {model_name}")
     model_tuple, error = get_or_load_model(model_name, max_model_len, quantization, gpu_memory_utilization)
     if model_tuple is None:
         return [{"error": error}], "", None, None
@@ -147,7 +149,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
     warnings_list.append(error)

     # Load judge model
-    progress(0.2, desc=f"Loading judge model: {judge_model_name}")
+    if progress:
+        progress(0.2, desc=f"Loading judge model: {judge_model_name}")
     if judge_model_name == model_name:
         judge_model, judge_tokenizer = model, tokenizer
     else:
@@ -160,7 +163,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,

     try:
         # Load dataset
-        progress(0.3, desc="Loading dataset...")
+        if progress:
+            progress(0.3, desc="Loading dataset...")
         if config:
             dataset = load_dataset(dataset_name, config)
         else:
@@ -177,7 +181,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
         samples = dataset[selected_split].select(range(total_samples))

         # Prepare prompts
-        progress(0.4, desc="Preparing prompts...")
+        if progress:
+            progress(0.4, desc="Preparing prompts...")
         prompts = []

         # Validate prompt field
@@ -191,10 +196,12 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
             return [{"error": f"Field '{prompt_field}' not found in dataset. Available fields: {list(example.keys())}"}], "", None, None

         # Generate responses
-        progress(0.5, desc=f"Generating responses (0/{total_samples})...")
+        if progress:
+            progress(0.5, desc=f"Generating responses (0/{total_samples})...")
         outputs = []
         for i, prompt in enumerate(prompts):
-            progress(0.5 + (i / total_samples) * 0.2, desc=f"Generating responses ({i+1}/{total_samples})...")
+            if progress:
+                progress(0.5 + (i / total_samples) * 0.2, desc=f"Generating responses ({i+1}/{total_samples})...")

             inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_model_len)
             inputs = {k: v.to(model.device) for k, v in inputs.items()}
@@ -225,7 +232,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
             outputs.append({"text": generated_text})

         # Prepare results and judge prompts
-        progress(0.7, desc="Preparing judge evaluation...")
+        if progress:
+            progress(0.7, desc="Preparing judge evaluation...")
         results = []
         judge_prompts = []
         for i, output in enumerate(outputs):
@@ -238,10 +246,12 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
             judge_prompts.append(prepare_judge_prompt(prompts[i], output["text"]))

         # Judge the responses
-        progress(0.8, desc="Evaluating safety...")
+        if progress:
+            progress(0.8, desc="Evaluating safety...")
         judge_outputs = []
         for i, judge_prompt in enumerate(judge_prompts):
-            progress(0.8 + (i / total_samples) * 0.1, desc=f"Judging responses ({i+1}/{total_samples})...")
+            if progress:
+                progress(0.8 + (i / total_samples) * 0.1, desc=f"Judging responses ({i+1}/{total_samples})...")

             inputs = judge_tokenizer(judge_prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_model_len)
             inputs = {k: v.to(judge_model.device) for k, v in inputs.items()}
@@ -259,7 +269,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
             judge_text = judge_tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
             judge_outputs.append(judge_text)

-        progress(0.9, desc="Processing results...")
+        if progress:
+            progress(0.9, desc="Processing results...")
         for i, judge_text in enumerate(judge_outputs):
             judge_text = judge_text.strip()
             is_safe, score, reason = parse_judge_output(judge_text)
@@ -317,7 +328,8 @@ def evaluate_llm(model_name, judge_model_name, dataset_name, config, split,
             "results": results
         }

-        progress(1.0, desc="Complete!")
+        if progress:
+            progress(1.0, desc="Complete!")
         return results, stats_text, df, export_data

     except Exception as e:
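
With progress=None as the default, every progress(...) call is guarded, so evaluate_llm can run outside a Gradio event handler (scripts, tests, batch jobs) without constructing a UI object. A minimal usage sketch of both calling styles under that assumption; the model names, dataset name, and prompt field below are hypothetical placeholders, not values from this repo:

import gradio as gr

# Headless call: no progress kwarg is passed, so every "if progress:" guard
# in evaluate_llm is skipped and no Gradio tracker is ever created.
results, stats_text, df, export_data = evaluate_llm(
    model_name="org/some-chat-model",         # placeholder
    judge_model_name="org/some-judge-model",  # placeholder
    dataset_name="org/some-safety-dataset",   # placeholder
    config=None,
    split="train",
    num_samples=10,
    temperature=0.7,
    max_tokens=512,
    top_p=0.9,
    top_k=50,
    seed=42,
    repetition_penalty=1.1,
    prompt_field="prompt",                    # placeholder field name
)

# UI call: Gradio injects a live tracker for any event-handler parameter whose
# default is gr.Progress(), so a thin hypothetical wrapper can restore the old
# in-app behavior without re-coupling evaluate_llm itself to the UI.
def evaluate_llm_ui(*args, progress=gr.Progress()):
    return evaluate_llm(*args, progress=progress)

One design note: guarding with "if progress:" rather than keeping gr.Progress() as the default also avoids instantiating a Gradio object once at function-definition time, the usual mutable-default pitfall.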