1+1=3. Wait, no, 1+1=2. How to have GPT sanity check itself.

A Python walkthrough of getting GPT to double-check its own answers, using GPT-J in Colab

Self-correcting prompt format:

    Initial Answer: 6394 + 250 = 6643
    Double Checking: 6643 looks wrong. 6394 + 250 = 6644
    Initial Answer: {a} + {b} =

Pure few-shot format:

    6394 + 250 = 6644
    {a} + {b} =

Figure 1: A self-correcting prompt has lower mean absolute percent error (left) and higher exact match (right) than regular few-shot prompting on 4-digit arithmetic using GPT-J.

Setting up the query function

import numpy as np

# network, tokenizer, seq, and total_batch come from the GPT-J
# (mesh-transformer-jax) Colab setup cells, which are omitted here.
def query(context, top_p=0, temp=0, gen_len=50):
    tokens = tokenizer.encode(context)
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx
    # Left-pad the prompt to the model's fixed sequence length.
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})
    samples = []
    decoded_tokens = output[1][0]
    for o in decoded_tokens[:, :, 0]:
        samples.append(tokenizer.decode(o))
    return samples[0]
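With the Colab globals in place, a quick smoke test might look like the snippet below; the exact completion will vary with the model snapshot, but it should start with the sum on the first line.

    # Ask for a single sum with the greedy defaults (top_p=0, temp=0).
    print(query('6394 + 250 ='))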
import pandas as pd, json, random

random.seed(42)

# Build 30 training examples (for the prompts) and 100 test examples.
fourDigitDict = []
fourDigitTest = []
for i in range(30):
    a = int(random.random() * 10**4)
    b = int(random.random() * 10**4)
    fourDigitDict.append({'first': a, 'second': b, 'sum': a + b})
fourDigitTrainDF = pd.DataFrame(fourDigitDict)

for i in range(100):
    a = int(random.random() * 10**4)
    b = int(random.random() * 10**4)
    fourDigitTest.append({'first': a, 'second': b, 'sum': a + b})
fourDigitTestDF = pd.DataFrame(fourDigitTest)

# Corrupt roughly half of the training "initial answers" by up to +/-1000,
# so the prompt demonstrates both confirming and correcting.
fourDigitTrainDF['randomlyWrong'] = fourDigitTrainDF['sum'].apply(lambda x: random.random() < .5)
fourDigitTrainDF['offset'] = fourDigitTrainDF['randomlyWrong'].apply(lambda x: .5 - random.random() if x else 0)
fourDigitTrainDF['offset'] = fourDigitTrainDF['offset'] * 2000
fourDigitTrainDF['originalSum'] = fourDigitTrainDF['sum'] + fourDigitTrainDF['offset']
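It's worth eyeballing a few rows to confirm that about half the "initial answers" really did get perturbed:

    # originalSum should differ from sum on roughly half the rows.
    print(fourDigitTrainDF[['first', 'second', 'sum', 'randomlyWrong', 'originalSum']].head())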
correctionPrompts = {}
purePrompts = {}
# Build prompts with 1 to 29 few-shot examples in both styles.
for i in range(1, 30):
    correctionPrompt = ""
    purePrompt = ""
    for _, row in fourDigitTrainDF[:i].iterrows():
        correctionPrompt += 'Initial Answer: {} + {} = {}'.format(row['first'], row['second'], int(row['originalSum']))
        correctionPrompt += '\n'
        interjection = 'looks correct.' if not row['randomlyWrong'] else 'looks off by a bit.'
        correctionPrompt += 'Double Checking: {} {} {} + {} = {}'.format(int(row['originalSum']), interjection, row['first'], row['second'], row['sum'])
        correctionPrompt += '\n\n'
        purePrompt += '{} + {} = {}'.format(row['first'], row['second'], row['sum']) + '\n'

    correctionPrompt += 'Initial Answer: '
    correctionPrompts[i] = correctionPrompt
    purePrompts[i] = purePrompt
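Printing one of the smaller prompts is a cheap way to check the formatting before burning Colab time; with the seed above, the numbers are whatever random.seed(42) produced:

    # A two-shot self-correcting prompt; it should end with 'Initial Answer: '.
    print(correctionPrompts[2])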
# Format the test problems the same way the prompts end.
fourDigitTestDF['formatted'] = fourDigitTestDF.apply(lambda x: "{} + {} =".format(x['first'], x['second']), axis=1)
correctionResults = {}
pureResults = {}
# For each prompt size from 1 to 29 examples, run both styles on the test set.
for trainSize in range(1, 30):
    if trainSize not in correctionResults:
        print(trainSize)
        correctionResults[trainSize] = []
        pureResults[trainSize] = []
        for example in fourDigitTestDF.formatted:
            correctionResults[trainSize].append(
                query(correctionPrompts[trainSize] + example, gen_len=50))
            pureResults[trainSize].append(
                query(purePrompts[trainSize] + example, gen_len=50))

with open('correctionResults.json', 'w') as fh:
    json.dump(correctionResults, fh)
with open('pureResults.json', 'w') as fh:
    json.dump(pureResults, fh)
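One gotcha if you reload these files in a later session: JSON stringifies the integer keys, so convert them back before reusing the scoring code below.

    # JSON keys are always strings; restore the int prompt sizes on load.
    with open('correctionResults.json') as fh:
        correctionResults = {int(k): v for k, v in json.load(fh).items()}
    with open('pureResults.json') as fh:
        pureResults = {int(k): v for k, v in json.load(fh).items()}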
def parsePureInt(x):
    # The pure few-shot completion puts the sum on the first line.
    base = x.split('\n')[0]
    try:
        return int(base)
    except ValueError:
        return 0

def parseCorrectedInt(x):
    # The self-correcting completion ends its first double-newline block with
    # the revised sum, so take the last space-separated token of that block.
    base = x.split('\n\n')[0].split(' ')[-1]
    try:
        return int(base)
    except ValueError:
        return 0
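A couple of hand-written completions, shaped like what GPT-J actually returns for these prompts, confirm the parsers pull out the right token:

    # The pure completion starts with the sum, then rambles onto new lines.
    assert parsePureInt(' 6644\n6394 + 250 = 6644') == 6644
    # The self-correcting completion ends its first block with the revised sum.
    assert parseCorrectedInt(' 6643\nDouble Checking: 6643 looks off by a bit. '
                             '6394 + 250 = 6644\n\nInitial Answer:') == 6644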
# Score the pure few-shot results: parse each completion into an int column.
test = fourDigitTestDF.copy()
for key in pureResults.keys():
    test[key] = pureResults[key]
    test[key] = test[key].apply(lambda x: parsePureInt(x))

pureMape = pd.DataFrame()
for col in pureResults.keys():
    pureMape[col] = (abs(test[col] - test['sum'])) / test['sum']

pureEM = pd.DataFrame()
for col in pureResults.keys():
    pureEM[col] = test[col] == test['sum']

# Score the self-correcting results the same way, overwriting the per-size columns.
for key in correctionResults.keys():
    test[key] = correctionResults[key]
    test[key] = test[key].apply(lambda x: parseCorrectedInt(x))

correctedMape = pd.DataFrame()
for col in correctionResults.keys():
    correctedMape[col] = (abs(test[col] - test['sum'])) / test['sum']

correctedEM = pd.DataFrame()
for col in correctionResults.keys():
    correctedEM[col] = test[col] == test['sum']
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(12, 8), facecolor='white')
correctedMape.mean().plot(label='self correcting', ax=axs[0])
pureMape.mean().plot(ax=axs[0], label='pure few shot')
axs[0].legend()
axs[0].set_xlabel('# of examples')
axs[0].set_title('Mean Error')
correctedEM.sum().plot(label='self correcting', ax=axs[1])
pureEM.sum().plot(ax=axs[1], label='pure few shot')
axs[1].legend()
axs[1].set_xlabel('# of examples')
axs[1].set_title('Exact Match')
fig.suptitle('4 Digit Arithmetic on 100 Tests')
Figure 2. Wait, no, figure 1 again.
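If you'd rather have numbers than a plot, the same frames summarize directly; these calls just aggregate what the figure shows:

    # Compare both methods at the largest prompt size (29 examples, 100 tests).
    print('pure few shot:   MAPE {:.3f}, exact match {}'.format(pureMape[29].mean(), pureEM[29].sum()))
    print('self correcting: MAPE {:.3f}, exact match {}'.format(correctedMape[29].mean(), correctedEM[29].sum()))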
Also, if you run the same experiment in GPT-3 with comma-formatted numbers, you get something similar, depending on the model size and fine-tune. Not sure what happened there at 17 examples, though.
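For reference, "with commas" just means formatting the operands with thousands separators; a hypothetical variant of the test prompts (the formattedCommas column name is made up here) would look like:

    # Comma-formatted variant of the test prompts for GPT-3.
    fourDigitTestDF['formattedCommas'] = fourDigitTestDF.apply(
        lambda x: '{:,} + {:,} ='.format(x['first'], x['second']), axis=1)
    print(fourDigitTestDF['formattedCommas'].iloc[0])  # e.g. '1,234 + 5,678 ='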
