1+1=3. Wait, no, 1+1=2. How to have GPT sanity-check itself.

A Python walkthrough, run in Colab, of getting GPT-J to double-check its own answers.

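Self-correcting prompt. Each example gives an initial answer, then a double check that either confirms it or fixes it. The code below words the check as "looks correct." or "looks off by a bit.":

Initial Answer: 6394 + 250 = 6643
Double Checking: 6643 looks wrong. 6394 + 250 = 6644
Initial Answer: {a} + {b} =

Regular few-shot prompt:

6394 + 250 = 6644
{a} + {b} =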
Figure 1: On 4-digit addition with GPT-J, a self-correcting prompt gives a lower mean absolute percentage error (left) and more exact matches (right) than a regular few-shot prompt.

Setting up the query function

def query(context, top_p=0, temp=0, gen_len=50):
    """Return GPT-J's generated continuation of ``context`` as a string.

    Relies on notebook globals defined elsewhere (not visible here):
    ``tokenizer``, ``network``, ``seq`` (presumably the model's fixed context
    length in tokens) and ``total_batch`` (presumably the number of parallel
    sequences ``network.generate`` expects) -- confirm against the setup cell.

    Args:
        context: Prompt text to complete.
        top_p: Nucleus-sampling value passed to ``network.generate`` for every
            batch slot.
        temp: Sampling temperature passed to ``network.generate`` for every
            batch slot.
        gen_len: Number of tokens to generate.

    Returns:
        The decoded output for the first batch slot. The downstream parsers
        assume this is only the new continuation, not the prompt -- confirm
        against ``network.generate``'s output format.

    Raises:
        ValueError: If the encoded prompt is longer than ``seq`` tokens.
    """
    tokens = tokenizer.encode(context)
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx
    if pad_amount < 0:
        # np.pad would also reject a negative width, but with an opaque message.
        raise ValueError(
            "prompt is {} tokens but only {} fit in the context".format(provided_ctx, seq)
        )
    # Zero-pad on the left so every prompt fills exactly `seq` positions.
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    # The same prompt is copied into every batch slot.
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
    output = network.generate(
        batched_tokens,
        length,
        gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )
    decoded_tokens = output[1][0]
    # The token array is rank 3. Only the first slot along axis 0 is returned,
    # so only that one is decoded.
    return tokenizer.decode(decoded_tokens[0, :, 0])
import json
import random

import pandas as pd

random.seed(42)  # fixed seed so the generated problems are reproducible


def _random_addition():
    """Return one addition problem whose operands are integers in [0, 9999]."""
    a = int(random.random() * 10**4)
    b = int(random.random() * 10**4)
    return {'first': a, 'second': b, 'sum': a + b}


# 30 problems to draw few-shot examples from, then 100 held-out test problems.
# Generation order (train first, then test) is unchanged, so the seed yields
# the same data as before.
fourDigitDict = [_random_addition() for _ in range(30)]
fourDigitTrainDF = pd.DataFrame(fourDigitDict)
fourDigitTest = [_random_addition() for _ in range(100)]
fourDigitTestDF = pd.DataFrame(fourDigitTest)

# Each training example gets a deliberately wrong "initial answer" with
# probability 0.5.
fourDigitTrainDF['randomlyWrong'] = fourDigitTrainDF['sum'].apply(lambda x: random.random() < .5)
# Wrong answers are shifted by a uniform float offset in (-1000, 1000]; right
# ones by 0. NOTE: a very small offset can truncate back to the true sum when
# the prompt applies int(), and for small sums originalSum can go negative.
fourDigitTrainDF['offset'] = fourDigitTrainDF['randomlyWrong'].apply(
    lambda wrong: .5 - random.random() if wrong else 0
)
fourDigitTrainDF['offset'] = fourDigitTrainDF['offset'] * 2000
fourDigitTrainDF['originalSum'] = fourDigitTrainDF['sum'] + fourDigitTrainDF['offset']
# Few-shot prompts keyed by example count i (1..29): the first i training rows
# rendered in the self-correcting format and in the plain format.
correctionPrompts = {}
purePrompts = {}
for i in range(1, 30):
    correctionParts = []
    pureParts = []
    # itertuples keeps each column's dtype, so integer sums never print as
    # floats (iterrows upcasts a row to one common dtype).
    for row in fourDigitTrainDF[:i].itertuples(index=False):
        interjection = 'looks correct.' if not row.randomlyWrong else 'looks off by a bit.'
        shownAnswer = int(row.originalSum)
        correctionParts.append(
            'Initial Answer: {} + {} = {}'.format(row.first, row.second, shownAnswer)
            + '\n'
            + 'Double Checking: {} {} {} + {} = {}'.format(
                shownAnswer, interjection, row.first, row.second, row.sum
            )
            + '\n\n'
        )
        pureParts.append('{} + {} = {}'.format(row.first, row.second, row.sum) + '\n')

    # End mid-pattern so the test problem appended later reads as the next
    # "Initial Answer".
    correctionPrompts[i] = ''.join(correctionParts) + 'Initial Answer: '
    purePrompts[i] = ''.join(pureParts)
import json

# Prompt stems such as "6394 + 250 =" that the model must complete.
fourDigitTestDF['formatted'] = (
    fourDigitTestDF['first'].astype(str) + ' + ' + fourDigitTestDF['second'].astype(str) + ' ='
)
correctionResults = {}
pureResults = {}
# For each few-shot size in 1..29 (the range end is exclusive), run every test
# problem through both prompt styles.
for trainSize in range(1, 30):
    if trainSize in correctionResults:
        continue
    # Collect into locals and publish only when the size is complete, so an
    # interrupted size never leaves half-filled result lists behind.
    correctionRun = []
    pureRun = []
    for example in fourDigitTestDF.formatted:
        correctionRun.append(query(correctionPrompts[trainSize] + example, gen_len=50))
        pureRun.append(query(purePrompts[trainSize] + example, gen_len=50))
    correctionResults[trainSize] = correctionRun
    pureResults[trainSize] = pureRun

    # Checkpoint after every size so a disconnect doesn't lose finished work.
    # Note: json turns the integer keys into strings ("1", "2", ...).
    with open('correctionResults.json', 'w', encoding='utf-8') as fh:
        json.dump(correctionResults, fh)
    with open('pureResults.json', 'w', encoding='utf-8') as fh:
        json.dump(pureResults, fh)
def parsePureInt(x):
    """Parse the model's answer from a plain few-shot completion.

    Only the first line of ``x`` is read; surrounding whitespace is ignored.

    Args:
        x: Completion text returned for a "{a} + {b} =" prompt.

    Returns:
        The integer on the first line, or 0 if that line is not an integer.
    """
    base = x.split('\n')[0]
    try:
        return int(base)
    except ValueError:
        # Unparseable output is recorded as 0 instead of aborting the run.
        return 0
def parseCorrectedInt(x):
    """Parse the final (double-checked) answer from a self-correcting completion.

    Reads the text before the first blank line (the current example) and takes
    its last whitespace-separated token. For a completion that follows the
    prompt's pattern, that is the number after "Double Checking: ... =".

    Args:
        x: Completion text returned for an "Initial Answer: {a} + {b} =" prompt.

    Returns:
        That trailing integer, or 0 if there is none or it is not an integer.
    """
    words = x.split('\n\n')[0].split()  # split() also drops trailing newlines
    if not words:
        return 0
    try:
        return int(words[-1])
    except ValueError:
        # Unparseable output is recorded as 0 instead of aborting the run.
        return 0
# Ground-truth columns of the test set. Predictions are kept in separate frames
# keyed by few-shot size, never mixed in positionally.
test = fourDigitTestDF[['first', 'second', 'sum']].copy()


def _scoreResults(results, parser):
    """Score raw completions against ``test['sum']``.

    Args:
        results: Mapping of few-shot size -> list of completions, one per row
            of ``test``, in the same order.
        parser: Function turning one completion into an int prediction.

    Returns:
        ``(relativeError, exactMatch)`` DataFrames with one column per few-shot
        size and one row per test problem. ``relativeError`` is
        |prediction - sum| / sum, so its column mean is the MAPE (as a fraction).
    """
    truth = test['sum']
    predictions = pd.DataFrame(
        {size: [parser(x) for x in outputs] for size, outputs in results.items()},
        index=test.index,
    )
    relativeError = predictions.sub(truth, axis=0).abs().div(truth, axis=0)
    exactMatch = predictions.eq(truth, axis=0)
    return relativeError, exactMatch


pureMape, pureEM = _scoreResults(pureResults, parsePureInt)
correctedMape, correctedeEM = _scoreResults(correctionResults, parseCorrectedInt)
# Left panel: mean relative error per few-shot size. Right panel: number of
# exact matches per few-shot size.
# NOTE(review): `plt` (matplotlib.pyplot) is not imported in the visible code;
# presumably an earlier notebook cell imports it.
fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(12, 8), facecolor='white')
correctedMape.mean().plot(label='self correcting', ax=axs[0])
pureMape.mean().plot(ax=axs[0], label='pure few shot')
axs[0].set_xlabel('# of examples')
axs[0].set_title('Mean Error')
axs[0].legend()  # the label= values above only appear once a legend is drawn
correctedeEM.sum().plot(label='self correcting', ax=axs[1])
pureEM.sum().plot(ax=axs[1], label='pure few shot')
axs[1].set_xlabel('# of examples')
axs[1].set_title('Exact Match')
axs[1].legend()
# One row per test problem, so this reads "100 Tests" for the 100-problem set.
fig.suptitle('4 Digit Arithmetic on {} Tests'.format(len(correctedeEM)))
Figure 2. Wait, no — it's Figure 1 again.
If you run the same experiment on GPT-3 and write the numbers with thousands separators (commas), you get results more like this, depending on the model size and fine-tune. I'm not sure what happened at 17 examples, though.

