| from typing import Optional |
|
|
| from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs |
| from inseq.commands.attribute_context.attribute_context_helpers import ( |
| AttributeContextOutput, |
| filter_rank_tokens, |
| get_filtered_tokens, |
| ) |
| from inseq.models import HuggingfaceModel |
|
|
|
|
def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Builds a flat list of ``(text, tag)`` tuples suitable for tagged
    rendering: ``tag`` is ``None`` for plain text, ``"Context sensitive"``
    for the context-sensitive generated token, and ``"Influential context"``
    for context tokens whose attribution scores pass the filtering
    thresholds in ``args``.

    Args:
        model: Model used to tokenize and filter the generated texts.
        args: Attribution arguments controlling score filtering
            (``attribution_std_threshold``, ``attribution_topk``), special
            token handling, and which contexts (input/output) are present.
        output: Attribution output holding the generated texts and the
            per-example CCI scores.

    Returns:
        The list of ``(text, tag)`` tuples for all examples, separated by
        blank lines and numbered headers.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        # Tag the tokens of one context (input or output) whose attribution
        # scores are ranked as influential. Thresholding uses the combined
        # score distribution across both contexts when both are present.
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        scores = context_scores
        if has_other_context:
            # Concatenate into a new list instead of using `+=`: in-place
            # extension would mutate the caller's list (context_scores
            # aliases cci_out.*_context_scores, which is reused for the
            # other context's call and stored on the output object).
            scores = context_scores + other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        # Mark the context-sensitive token identified by the CTI step.
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
|
|