| import torch |
| import transformers |
| from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig |
| import openai |
| from openai import OpenAI |
|
|
def hide(original_input, hide_model, tokenizer):
    """Paraphrase ``original_input`` with the hide model.

    The text is interpolated into a fixed "Paraphrase the text:" prompt.
    The prompt is encoded and decoded with beam search (no sampling).
    Only the tokens produced after the prompt are decoded back to text.

    Args:
        original_input: Text to paraphrase. It is inserted into the prompt
            with ``%``-formatting.
        hide_model: Generation model. Its ``device`` attribute decides
            where the encoded prompt is placed, and its ``generate`` method
            is called with a ``GenerationConfig``.
        tokenizer: Used both to encode the prompt and to decode the
            generated token ids.

    Returns:
        The decoded continuation, with special tokens skipped.
    """
    # NOTE(review): the template embeds a literal "<s>". Depending on the
    # tokenizer, a BOS token may also be added automatically. Presumably
    # this matches the prompt format the model was trained on; confirm
    # before changing it.
    template = """<s>Paraphrase the text:%s\n\n"""
    prompt = template % original_input
    encoded = tokenizer(prompt, return_tensors='pt').to(hide_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # Cap the output length at 1.3x the prompt length (truncated to int).
    config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,         # no sampling: plain 3-beam search
        num_beams=3,
        repetition_penalty=5.0,  # strongly discourage repeated tokens
    )
    output_ids = hide_model.generate(**encoded, generation_config=config)

    # Take the first returned sequence and drop the echoed prompt tokens.
    continuation = output_ids.cpu()[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
|
|
def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Ask the seek model to convert ``original_input``, guided by an example pair.

    The prompt is built in three parts:
    1. "Convert the text:" followed by ``hide_input``.
    2. ``hide_output``.
    3. "Convert the text:" again, followed by ``original_input``.

    The model's continuation, decoded without special tokens, is returned.

    Args:
        hide_input: Text shown after the first "Convert the text:" header.
        hide_output: Text shown right after ``hide_input``.
        original_input: Text shown after the second "Convert the text:"
            header. This is the text the model is asked to convert.
        seek_model: Generation model. Its ``device`` attribute decides where
            the encoded prompt is placed, and its ``generate`` method is
            called with a ``GenerationConfig``.
        tokenizer: Used both to encode the prompt and to decode the
            generated token ids.

    Returns:
        The decoded continuation, with special tokens skipped.
    """
    template = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    prompt = template % (hide_input, hide_output, original_input)
    encoded = tokenizer(prompt, return_tensors='pt').to(seek_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # Same length budget as the prompt scaled by 1.3. Unlike hide(), no
    # repetition penalty is applied here.
    config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,  # no sampling: plain 3-beam search
        num_beams=3,
    )
    output_ids = seek_model.generate(**encoded, generation_config=config)

    # Keep only what was generated after the prompt, for the first sequence.
    continuation = output_ids.cpu()[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
|
|
def get_gpt_output(prompt, api_key=None, *, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user chat message and return the reply text.

    Args:
        prompt: Content of the single ``"user"`` message sent to the model.
        api_key: OpenAI API key. It is required: any falsy value (``None``,
            ``""``) is rejected before a client is created.
        model: Chat-completions model name. Defaults to ``"gpt-3.5-turbo"``,
            the model this function always used before the parameter existed.

    Returns:
        ``message.content`` of the first choice in the completion.

    Raises:
        ValueError: If ``api_key`` is missing or empty.
    """
    if not api_key:
        # Reject a missing key up front rather than constructing a client
        # without one.
        raise ValueError('an OpenAI API key is needed for this function')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content