From 549941b2db768eb86f11afe6e95e0d132b121d41 Mon Sep 17 00:00:00 2001
From: gwern <gwern@gwern.net>
Date: Sun, 22 Dec 2019 22:08:13 -0500
Subject: [PATCH] Preference learning for ABC music generation: misc fixes and
 tweaks (see https://gwern.net/GPT-2-preference-learning for detailed
 discussion)

---
 launch.py                            | 33 ++++++++++++++--------------
 lm_human_preferences/train_policy.py |  3 ++-
 lm_human_preferences/train_reward.py |  2 +-
 sample.py                            |  8 ++++---
 4 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/launch.py b/launch.py
index a63b4b7..e432b6d 100755
--- a/launch.py
+++ b/launch.py
@@ -6,16 +6,16 @@ from lm_human_preferences import train_policy, train_reward
 
 
 books_task = combos(
-    bind('query_length', 64),
+    bind('query_length', 2), # must be a minimum of 2 (but why?)
     bind('query_dataset', 'books'),
-    bind('response_length', 24),
-    bind('start_text', '.'), # Start the context at the beginning of a sentence
+    bind('response_length', 256),
+    bind('start_text', ''), # no conditioning aside from 'X:' in sample.py
     bind('end_text', '.'), # End the context at the end of a sentence.
     bind('truncate_token', 13), # Encoding of '.' -- end completions at the end of a sentence.
     bind('truncate_after', 16), # Make sure completions are at least 16 tokens long.
 
-    bind('policy.temperature', 0.7),
-    bind('policy.initial_model', '124M'),
+    bind('policy.temperature', 1.0),
+    bind('policy.initial_model', '117M-irish'),
 )
 
 summarize_cnndm_task = combos(
@@ -48,7 +48,7 @@ summarize_tldr_task = combos(
 
 def get_train_reward_experiments():
     _shared = combos(
-        bind('labels.type', 'best_of_4'),
+        bind('labels.type', 'best_of_2'),
         bind('normalize_after', True),
         bind('normalize_before', True),
         bind('normalize_samples', 256),
@@ -58,9 +58,9 @@ def get_train_reward_experiments():
     _books_task = combos(
         bind_nested('task', books_task),
         _shared,
-        bind('batch_size', 32),
-        bind('lr', 5e-5),
-        bind('rollout_batch_size', 512),
+        bind('batch_size', 10),
+        bind('lr', 5e-5), # original: 5e-5
+        bind('rollout_batch_size', 226),
     )
 
     sentiment = combos(
@@ -75,8 +75,8 @@ def get_train_reward_experiments():
     descriptiveness = combos(
         _books_task,
 
-        bind('labels.source', 'gs://lm-human-preferences/labels/descriptiveness/offline_5k.json'),
-        bind('labels.num_train', 4_992),
+        bind('labels.source', 'irish.json'),
+        bind('labels.num_train', 16900), # poems: 5306; irish:
         bind('run.seed', 1)
     )
 
@@ -112,16 +112,15 @@ def get_train_reward_experiments():
 
     return locals()
 
-
 def get_experiments():
     train_reward_experiments = get_train_reward_experiments()
 
     _books_task = combos(
         bind_nested('task', books_task),
 
-        bind('ppo.lr', 1e-5),
-        bind('ppo.total_episodes', 1_000_000),
-        bind('ppo.batch_size', 512),
+        bind('ppo.lr', 1e-6), # original: 5e-5
+        bind('ppo.total_episodes', 1_000_000), # original: 1_000_000; note, this is *episodes*, not *steps*; each step consists of _n_ episodes
+        bind('ppo.batch_size', 18), # original: 512
     )
 
     sentiment = combos(
@@ -139,9 +138,9 @@ def get_experiments():
 
     descriptiveness = combos(
         _books_task,
-        bind('rewards.kl_coef', 0.15),
+        bind('rewards.kl_coef', 0.02),
         bind('rewards.adaptive_kl', 'on'),
-        bind('rewards.adaptive_kl.target', 6.0),
+        bind('rewards.adaptive_kl.target', 25.0),
 
         bind('rewards.train_new_model', 'on'),
         bind_nested('rewards.train_new_model', train_reward_experiments['descriptiveness']),
diff --git a/lm_human_preferences/train_policy.py b/lm_human_preferences/train_policy.py
index db02c98..b349717 100644
--- a/lm_human_preferences/train_policy.py
+++ b/lm_human_preferences/train_policy.py
@@ -282,6 +282,7 @@ class PPOTrainer():
         step_started_at = time.time()
 
         queries = self.sample_queries()
+        queries = np.tile([55,25], (queries.shape[0],1)) # Irish ABC prefix: 'X:' (ie for the initial numeric ID)
         rollouts = self.policy.respond(queries, length=self.hparams.task.response_length)
 
         responses = rollouts['responses']
@@ -398,7 +399,7 @@ def make_score_fn(hparams, score_model):
 
     def score_fn(queries, responses):
         responses = postprocess(responses)
-        score = penalize(responses, unpenalized_score_fn(queries, responses))
+        score = unpenalized_score_fn(queries, responses)
         return score, responses, dict(score=score)
     score_fn.stat_schemas = dict(score=Schema(tf.float32, (None,)))
     return score_fn
diff --git a/lm_human_preferences/train_reward.py b/lm_human_preferences/train_reward.py
index ab1d09f..7cd0243 100755
--- a/lm_human_preferences/train_reward.py
+++ b/lm_human_preferences/train_reward.py
@@ -79,7 +79,7 @@ def download_labels(source, label_type, question_schemas, total_labels, comm):
 
     # TODO: download on just one rank?  then do: labels = utils.mpi_bcast_tensor_dict(labels, comm=comm)
     if source != 'test':
-        with open(gcs.download_file_cached(source, comm=comm)) as f:
+        with open(source) as f:
             results = json.load(f)
             print('Num labels found in source:', len(results))
     else:
diff --git a/sample.py b/sample.py
index e65f701..8f2185c 100755
--- a/sample.py
+++ b/sample.py
@@ -12,6 +12,7 @@ from lm_human_preferences.policy import Policy
 from lm_human_preferences.language import trained_models
 from lm_human_preferences import lm_tasks
 from lm_human_preferences import train_policy
+import numpy as np
 
 def sample_policy(save_dir=None, savescope='policy', temperature=1.0, seed=None, batch_size=4, nsamples=0):
     hparams = train_policy.HParams()
@@ -56,10 +57,11 @@ def sample_policy(save_dir=None, savescope='policy', temperature=1.0, seed=None,
             generated = 0
             while nsamples_per_rank == 0 or generated < nsamples_per_rank:
                 queries = sample_queries()
-                rollouts = policy.respond(queries, length=task.response_length)
-                assert len(queries.tolist()) == batch_size
+                queries = np.tile([55,25], (queries.shape[0],1)) # 'X:'
+                rollouts = policy.respond(queries, length=1024)
+                assert len(queries) == batch_size
                 assert len(rollouts['responses'].tolist()) == batch_size
-                for q, r in zip(queries.tolist(), rollouts['responses'].tolist()):
+                for q, r in zip(queries, rollouts['responses'].tolist()):
                     print('=' * 80)
                     print(encoder.decode(q).replace("\n", "⏎"))
                     print(encoder.decode(r).replace("\n", "⏎"))
-- 
2.17.1

