
Finetuning RWKV 14B with QLoRA in 4-bit

It was surprisingly easy to get this working, and I think that's a good thing.

First I looked at existing LoRA implementations for RWKV, which I discovered through the very helpful RWKV Discord. A link there landed me at "How to Train Your Raven"; shout out to the author Nana. From that blog post I found the general LoRA implementation from Blealtan.

Starting on line 153 and going up to line 192 you can see the modules they've chosen to augment with LoRA: the linear attention and feed-forward portions of the network. More specifically, the LoRA layers are later applied to the linear projections for receptance, key, and value.

import math
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F

# LORA_CONFIG is a module-level dict in Blealtan's source holding r, alpha,
# dropout, and the parts ("att", "ffn") that LoRA should be applied to.

class LoraLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        assert bias == False, "Biased LoraLinear not supported"

        r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
            "alpha"], LORA_CONFIG["dropout"]
        # Low-rank factors: lora_B @ lora_A has shape (out_features, in_features).
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # lora_B starts at zero so the adapter is a no-op before training.
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return (
            F.linear(x, self.weight) + self.scaling *
            F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))


@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
    if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
    if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)

After looking at the RWKV paper and the model architecture, and combining that with this LoRA implementation, I figured these were the correct modules to target.
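As a rough sketch of how those factories end up being wired in (this is hypothetical, not Blealtan's exact code; names like n_embd and dim_ffn are illustrative and only the projection construction is shown), the time-mix and channel-mix blocks build their projections through them, so receptance, key, and value become LoraLinear layers whenever the corresponding part is enabled:

class TimeMix(nn.Module):
    def __init__(self, n_embd: int):
        super().__init__()
        # Attention-side projections go through the "att" factory.
        self.receptance = make_linear_att(n_embd, n_embd, bias=False)
        self.key = make_linear_att(n_embd, n_embd, bias=False)
        self.value = make_linear_att(n_embd, n_embd, bias=False)


class ChannelMix(nn.Module):
    def __init__(self, n_embd: int, dim_ffn: int):
        super().__init__()
        # Feed-forward-side projections go through the "ffn" factory.
        self.receptance = make_linear_ffn(n_embd, n_embd, bias=False)
        self.key = make_linear_ffn(n_embd, dim_ffn, bias=False)
        self.value = make_linear_ffn(dim_ffn, n_embd, bias=False)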

I recently implemented my own fine-tuning library, iantbutler01/ditty, on top of Transformers, Accelerate, and PEFT, so the rest of this was simply integrating QLoRA into Ditty under an experimental flag.

Implementing QLoRA in your Transformers-based pipeline using PEFT

You can install the development versions of Transformers, PEFT, and Accelerate with the snippet below. I recommend doing this in a separate virtualenv so that potential breaking changes don't interfere with any other work you may be doing!

pip install -U git+https://github.com/huggingface/transformers.git 
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/accelerate.git
pip install --upgrade bitsandbytes
Development versions are needed because this isn't stable yet!
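To confirm the development builds are actually the ones being picked up inside that virtualenv, a quick version check doesn't hurt (just a sanity-check sketch):

# Sanity check: make sure the freshly installed development versions are in use.
import transformers, peft, accelerate, bitsandbytes

print(transformers.__version__, peft.__version__, accelerate.__version__, bitsandbytes.__version__)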

After this you can follow the excellent blog post from Hugging Face introducing the method. I'll show my implementation of it in Ditty as well, with the few changes needed to support RWKV.

After checking some flags set in my library, I enable 4-bit loading and pass the necessary config to bitsandbytes, per the HF post:

        if self.l8bit and self.l4bit:
            raise ValueError("Cannot set both l8bit and l4bit to True.")

        if self.l4bit and experimental:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        elif self.l4bit and not experimental:
            raise ValueError("To use 4bit, `experimental` must be set to True.")
        elif self.l8bit:
            self.bnb_config = BitsAndBytesConfig(
                load_in_8bit=l8bit, llm_int8_enable_fp32_cpu_offload=fp32_cpu_offload
            )
I pass the BNB config directly to the model loading step:

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            device_map="auto",
            quantization_config=self.bnb_config,
        )

To target the right RWKV modules, I first printed out the model to see how they're named in the Transformers version, just in case they differ.

RwkvForCausalLM(
  (rwkv): RwkvModel(
    (embeddings): Embedding(50277, 5120)
    (blocks): ModuleList(
      (0): RwkvBlock(
        (pre_ln): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (ln1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attention): RwkvSelfAttention(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (value): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (receptance): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (output): Linear4bit(in_features=5120, out_features=5120, bias=False)
        )
        (feed_forward): RwkvFeedForward(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear4bit(in_features=5120, out_features=20480, bias=False)
          (receptance): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (value): Linear4bit(in_features=20480, out_features=5120, bias=False)
        )
      )
      (1-39): 39 x RwkvBlock(
        (ln1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attention): RwkvSelfAttention(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (value): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (receptance): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (output): Linear4bit(in_features=5120, out_features=5120, bias=False)
        )
        (feed_forward): RwkvFeedForward(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear4bit(in_features=5120, out_features=20480, bias=False)
          (receptance): Linear4bit(in_features=5120, out_features=5120, bias=False)
          (value): Linear4bit(in_features=20480, out_features=5120, bias=False)
        )
      )
    )
    (ln_out): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  )
  (head): Linear(in_features=5120, out_features=50277, bias=False)
)
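If you'd rather not eyeball the whole tree, a short standalone sketch that filters for the quantized linear layers gives the same answer (assuming model holds the loaded RwkvForCausalLM):

# List just the 4-bit linear layers, i.e. the candidate LoRA target modules.
import bitsandbytes as bnb

for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        print(name)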

So in this case we're targeting 'key', 'value', and 'receptance'. Then I added a few small changes to my existing pipeline in Ditty:

        if "gpt-neox" in self.model_name_or_path:
            target_modules = ["query_key_value", "xxx"]

        if "rwkv" in self.model_name_or_path:
            target_modules = ["key", "value", "receptance"]

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=target_modules,
            inference_mode=False,
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none"
        )


        if self.l4bit:
            from peft import prepare_model_for_kbit_training
            self.model = prepare_model_for_kbit_training(
                self.model, use_gradient_checkpointing=self.gradient_checkpointing
            )
        elif self.l8bit:
            self.model = prepare_model_for_int8_training(
                self.model, use_gradient_checkpointing=self.gradient_checkpointing
            )
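After the model is prepared, the LoRA adapters get attached with PEFT's get_peft_model. Ditty does this internally, so treat the following as a sketch of the standard PEFT step rather than my exact code:

        from peft import get_peft_model

        # Wrap the quantized base model with the LoRA adapters from peft_config.
        self.model = get_peft_model(self.model, peft_config)
        # Only the LoRA matrices should be trainable, a tiny fraction of the 14B parameters.
        self.model.print_trainable_parameters()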

It's worth calling out that I disabled gradient checkpointing; the first time I tried to run with it enabled it errored out, so I left it off. I think it's worth investigating how it can be enabled, however.
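For anyone who wants to dig into that, the usual recipe for combining gradient checkpointing with a frozen k-bit base looks like the snippet below. I haven't verified it against RWKV's custom kernels, so treat it as an untested starting point:

        # Untested sketch: the standard pattern for gradient checkpointing with
        # a frozen quantized base while training LoRA adapters.
        self.model.gradient_checkpointing_enable()
        # Ensure the (frozen) embedding outputs require grad so activations
        # flow back through the checkpointed blocks to the LoRA layers.
        self.model.enable_input_require_grads()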

That's really all there is to it. I also have a few things like bfloat16 enabled for the training pipeline; you can see my configuration using a subclassed Ditty pipeline here:

if __name__ == "__main__":
    pipeline = RWKVPipeline(
        dataset_name="databricks/databricks-dolly-15k",
        model_name_or_path="RWKV/rwkv-raven-14b",
        gradient_checkpointing=False,
        block_size=512,
        grad_accum=32,
        batch_size=1,
        l4bit=True,
        l8bit=False,
        experimental=True,
        fp16=True,
        use_bfloat16=True,
    )

    pipeline.run()

The link to the implementation can be found here.

At the time of writing, the fine-tuning process has finished but I have not yet evaluated the model for quality. I will update this post when I do!
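When I get to evaluation, loading the adapter back onto the 4-bit base should look roughly like this; the adapter path is a placeholder and the quantization config mirrors the one from earlier:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b", device_map="auto", quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-raven-14b")

# "path/to/lora-adapter" is a placeholder for wherever the trained adapter was saved.
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")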


Thanks for reading ❤️

You're pretty cool!

🕶️

If it's 👌 with you click subscribe and drop me your email 📧 and you'll get an update whenever I post something new ✨
