1+1=3. Wait, no, 1+1=2. How to have GPT sanity check itself.
A Python walkthrough of using GPT to double-check its answer using GPT-J in Colab
Large language models have a problem where they tend to just make stuff up. Sometimes it's the training data, sometimes it's the prompt, sometimes it's plain ambiguity. One way to mitigate this is to engineer the prompt so GPT sanity checks its own output, and it works with both GPT-3 and GPT-J (the latter of which you can use for free).
What do I mean by a sanity check? Well, it turns out by setting up your prompts to double check the output using something like:
Initial Answer: 6394 + 250 = 6643
Double Checking: 6643 looks wrong. 6394 + 250 = 6644
Initial Answer: {a} + {b} =
you can improve performance compared to using a traditional prompt that looks more like:
6394 + 250 = 6644
{a} + {b} =
Also, for those like me who don’t read good, you can skip this post and just run the code yourself with GPT-J in the linked Colab.
For background, there are a few other methods that people use to mitigate this sort of problem. The most popular is just manually curating for answers you like. That’s boring, and doesn’t really help if you’re trying to automate processes. There’s another method which involves having a model show its work through intermediary steps which can improve reasoning, which is cool. However, I figure it’s worth taking a look at self correcting to see how well GPT can look at what it spat out and decide whether it needs improving because it doesn’t seem like people have looked into it much (although there were some cool tweets about anecdotally guess-and-checking random algebraic solutions last year).
For a toy example, I went with arithmetic. Why arithmetic? Well, first of all, we can measure if an answer’s correct which is often good to know. Second, while there’s an argument that the reason language models don’t do great at arithmetic is due to the way GPT’s tokens are encoded, I suspect that there’s a search issue with multi-digit arithmetic where you have to add from right to left to carry digits, and when you’re just predicting ahead it’s hard to guess where the digits are gonna get carried. GPT isn’t great at that type of search, and if search is contributing to this issue, then doing a double check should help it.
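To make that concrete, here's a tiny sketch (mine, not from the original argument) of schoolbook addition: the carry out of each column isn't known until the lower digits have been added, but a left-to-right generator has to commit to the leading digits of the answer first.

# Schoolbook addition: carries propagate right-to-left, so the leading digits
# of the answer depend on work a left-to-right generator hasn't done yet.
def add_with_carries(a, b):
    a_digits, b_digits = str(a)[::-1], str(b)[::-1]
    carry, out = 0, []
    for i in range(max(len(a_digits), len(b_digits))):
        da = int(a_digits[i]) if i < len(a_digits) else 0
        db = int(b_digits[i]) if i < len(b_digits) else 0
        carry, digit = divmod(da + db + carry, 10)
        out.append(str(digit))
    if carry:
        out.append(str(carry))
    return int(''.join(reversed(out)))

print(add_with_carries(6394, 250))  # 6644 -- the 6 in the hundreds place only
                                    # appears after the carry from the tens column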
Anyway, in this post I’ll show you how to generate contexts with between 1–30 examples, run their completions using a standard few-shot prompt and an improved self-checking few-shot prompt, and then we’ll plot the errors (figure 1) and see that the self-checking prompts do indeed have lower error and higher exact match than the traditional few shot. Here’s what the graph will look like (note that for exact match we top out around 16% accuracy, whereas GPT-3 large with comma-formatted numbers can get over 90%, so there’s still quite a gap).
Setting up the query function
This works with GPT-J as well as the GPT-3 models; you’ll just have to set up your query function differently, which is where we’ll start. When we query a completion for this sort of task, we want to use 0 temperature, because we want the most likely tokens instead of randomly throwing in other numbers. Here’s what my GPT-J query function looks like (borrowed from the infer function in the GPT-J demo notebook, which is pretty cool):
def query(context, top_p=0, temp=0, gen_len=50):
    # Tokenize the prompt and left-pad it out to the model's sequence length.
    tokens = tokenizer.encode(context)
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
    # top_p and temp both default to 0 so we get the most likely tokens.
    output = network.generate(batched_tokens, length, gen_len,
                              {"top_p": np.ones(total_batch) * top_p,
                               "temp": np.ones(total_batch) * temp})
    # Decode the generated tokens and return the first sample from the batch.
    samples = []
    decoded_tokens = output[1][0]
    for o in decoded_tokens[:, :, 0]:
        samples.append(tokenizer.decode(o))
    return samples[0]
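Assuming you’ve already run the setup cells from the GPT-J demo notebook (so that network, tokenizer, np, seq, and total_batch all exist), a quick smoke test looks something like this:

# Smoke test: ask for a short completion of a single addition prompt.
# (Assumes the GPT-J Colab setup above has already been run.)
completion = query("6394 + 250 =", gen_len=10)
print(completion)  # raw continuation text; we'll parse the number out later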
Setting up the Prompt
Now we’ll go ahead and set up our prompt. We’ll create 30 random examples to be our prompt and 100 examples as the test set.
import pandas as pd, json, random

random.seed(42)

fourDigitDict = []
fourDigitTest = []
for i in range(30):
    a = int(random.random()*10**4)
    b = int(random.random()*10**4)
    fourDigitDict.append({'first': a, "second": b, 'sum': a+b})
fourDigitTrainDF = pd.DataFrame(fourDigitDict)

for i in range(100):
    a = int(random.random()*10**4)
    b = int(random.random()*10**4)
    fourDigitTest.append({'first': a, "second": b, 'sum': a+b})
fourDigitTestDF = pd.DataFrame(fourDigitTest)
For the pure/traditional case we’d just line the examples up as the prompt, but for the self-correcting case we’ll randomly make some of them wrong. We’ll call the answer shown first (wrong or right) originalSum, so that we can have the model correct it to the right answer later.
# Flag roughly half the examples as wrong and nudge their sums by up to ±1000.
fourDigitTrainDF['randomlyWrong'] = fourDigitTrainDF['sum'].apply(lambda x: random.random() < .5)
fourDigitTrainDF['offset'] = fourDigitTrainDF['randomlyWrong'].apply(lambda x: .5 - random.random() if x == True else 0)
fourDigitTrainDF['offset'] = fourDigitTrainDF['offset'] * 2000
fourDigitTrainDF['originalSum'] = fourDigitTrainDF['sum'] + fourDigitTrainDF['offset']
Next we just create our prompts. We’ll have one set of prompts for the pure few shots and one set of prompts for the corrected ones. And to see how many examples we need, we’ll just do a grid search and literally try running from 1–30 examples.
correctionPrompts = {}
purePrompts = {}
for i in range(1, 30):
    correctionPrompt = ""
    purePrompt = ""
    for row in fourDigitTrainDF[:i].iterrows():
        correctionPrompt += 'Initial Answer: {} + {} = {}'.format(row[1]['first'], row[1]['second'], int(row[1]['originalSum']))
        correctionPrompt += '\n'
        interjection = 'looks correct.' if not row[1]['randomlyWrong'] else 'looks off by a bit.'
        correctionPrompt += 'Double Checking: {} {} {} + {} = {}'.format(int(row[1]['originalSum']), interjection, row[1]['first'], row[1]['second'], row[1]['sum'])
        correctionPrompt += '\n\n'
        purePrompt += '{} + {} = {}'.format(row[1]['first'], row[1]['second'], row[1]['sum']) + '\n'
    correctionPrompt += 'Initial Answer: '
    correctionPrompts[i] = correctionPrompt
    purePrompts[i] = purePrompt
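If you want to eyeball what the model actually sees, printing one of the shorter self-correcting prompts is an easy spot check (this step is just for inspection; it isn’t needed for the run):

# Peek at the 2-example self-correcting prompt; it should end with a dangling
# "Initial Answer: " that the test query gets appended to.
print(correctionPrompts[2])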
Now we’ve got all of our prompts set up, time to try it against the test set!
Running the Test
To run the test, we’ll set up the actual query (a + b =), append it to the end of each of our prompts (self-correcting or not), then run the whole thing and sit back while it does the roughly 6,000 queries. We’ll also dump the results to JSON each time we finish a pass over the test set, so we can recover if things break.
import json

fourDigitTestDF['formatted'] = fourDigitTestDF.apply(lambda x: "{} + {} =".format(x['first'], x['second']), axis=1)

correctionResults = {}
pureResults = {}

# for each size of example length in 1-30, run on the test set
for trainSize in range(1, 30):
    if trainSize not in correctionResults:
        print(trainSize)
        correctionResults[trainSize] = []
        pureResults[trainSize] = []
        for example in fourDigitTestDF.formatted:
            correctionResults[trainSize].append(
                query(correctionPrompts[trainSize] + example, gen_len=50))
            pureResults[trainSize].append(
                query(purePrompts[trainSize] + example, gen_len=50))
        # Checkpoint after each pass over the test set so we can recover from crashes.
        with open('correctionResults.json', 'w') as fh:
            json.dump(correctionResults, fh)
        with open('pureResults.json', 'w') as fh:
            json.dump(pureResults, fh)
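If the runtime dies partway through, you can pick back up from those JSON dumps. One gotcha (this recovery snippet is a sketch, not part of the original run): json turns the integer trainSize keys into strings, so cast them back on load before re-entering the loop.

# Recover partial results after a crash; json stores dict keys as strings,
# so convert them back to ints to match the range(1, 30) loop above.
with open('correctionResults.json') as fh:
    correctionResults = {int(k): v for k, v in json.load(fh).items()}
with open('pureResults.json') as fh:
    pureResults = {int(k): v for k, v in json.load(fh).items()}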
Evaluation
K, so now that we’ve got our ~6,000 completions, we’ll evaluate them. I’ll start by renaming the test set to test because… why not. Since I haven’t figured out how to do a stop sequence with GPT-J (my degree’s in journalism, math’s hard), we’ll have to parse out the answer we’re trying to get. For the traditional method, that’s just the first ‘word’ out the door. For the self-correcting method, it’s the last ‘word’ of the first block of output.
def parsePureInt(x):
    # Pure few shot: the answer is the first generated line.
    base = x.split('\n')[0]
    try:
        return int(base)
    except:
        return 0

def parseCorrectedInt(x):
    # Self correcting: the answer is the last "word" before the blank line
    # that ends the first generated example.
    base = x.split('\n\n')[0].split(' ')[-1]
    try:
        return int(base)
    except:
        return 0
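Here’s roughly how those parsers behave on a couple of made-up completions (illustrative strings, not actual model output):

# The pure parser takes the first generated line; the corrected parser takes the
# last "word" before the blank line that ends the first generated example.
print(parsePureInt(" 6644\n1234 + 5678 = 6912"))  # 6644
print(parseCorrectedInt(" 6643\nDouble Checking: 6643 looks off by a bit. 6394 + 250 = 6644\n\n"))  # 6644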
Now we’ll apply this to all of our results and calculate the errors.
test = fourDigitTestDF  # rename the test set to test, as promised above

for key in pureResults.keys():
    test[key] = pureResults[key]
    test[key] = test[key].apply(lambda x: parsePureInt(x))

# columns[4:] skips first, second, sum, and formatted, leaving the per-size result columns
pureMape = pd.DataFrame()
for col in test.columns[4:]:
    pureMape[col] = (abs(test[col] - test['sum'])) / test['sum']

pureEM = pd.DataFrame()
for col in test.columns[4:]:
    pureEM[col] = test[col] == test['sum']

for key in correctionResults.keys():
    test[key] = correctionResults[key]
    test[key] = test[key].apply(lambda x: parseCorrectedInt(x))

correctedMape = pd.DataFrame()
for col in test.columns[4:]:
    correctedMape[col] = (abs(test[col] - test['sum'])) / test['sum']

correctedEM = pd.DataFrame()
for col in test.columns[4:]:
    correctedEM[col] = test[col] == test['sum']
And now we just plot the results:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(12, 8), facecolor='white')

correctedMape.mean().plot(label='self correcting', ax=axs[0])
pureMape.mean().plot(ax=axs[0], label='pure few shot')
axs[0].legend()
axs[0].set_xlabel('# of examples')
axs[0].set_title('Mean Error')

correctedEM.sum().plot(label='self correcting', ax=axs[1])
pureEM.sum().plot(ax=axs[1], label='pure few shot')
axs[1].legend()
axs[1].set_xlabel('# of examples')
axs[1].set_title('Exact Match')

fig.suptitle('4 Digit Arithmetic on 100 Tests')
In Closing
You can run the above code with GPT-3 as well as GPT-J and you’ll get the same sort of improvement. I leave it as an exercise to the reader to do it with commas as well (you can just use f”{x:,}” to format things with commas). Pretty cool that it works, though.
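For reference, the comma variant only changes how the prompt is formatted and how the answer gets parsed back; something like this (an untested sketch) is the idea:

# Comma variant: format operands and sums with f"{x:,}", then strip the commas
# back out before comparing against the true sum.
prompt_line = f"{6394:,} + {1250:,} = {6394 + 1250:,}"  # "6,394 + 1,250 = 7,644"
parsed = int("7,644".replace(",", ""))                   # back to 7644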
I think what’s going on here is that since the model can’t predict ahead, it sometimes starts making up random digits until it gets to the end of the number, and that first pass then gives it the context to check what the first couple of digits should have been the next time through. I’m not entirely sure, though. These models are weird.