KG-to-Text with Llama 2

Beating the current SOTA Knowledge Graph-to-Text (KG-to-Text) model on the WebNLG (Constrained) dataset with a fine-tuned Llama 2 7B Chat model.

Ahmed Ismail
Dec 27, 2023

LLMs are evolving rapidly and getting better at performing complex tasks. In parallel, the use of Knowledge Graphs is increasing, due to their ability to render a semantic understanding of data, which can be leveraged by machine learning algorithms. These two technologies make the task of knowledge graph-to-text generation (KG-to-Text) more relevant today than ever.

What is KG-to-Text?

KG-to-text is a Natural Language Generation (NLG) task in which the input is a knowledge graph, and the expected output is a coherent piece of text that is consistent with the information represented in the knowledge graph.

While much of the advancement in AI is centered around LLMs, their input is plain text, and raw text lacks an explicit semantic representation of entities, their attributes, and their relationships. Knowledge graphs excel at providing exactly this kind of information. Retrieving relevant excerpts from a knowledge graph and turning them into text carries the semantic structure of that linked data into a form that can also serve as better input to generative AI.

KG-to-text is beneficial for many applications with a knowledge graph backend. For example, Wu et al. retrieve knowledge graph triples that are relevant to a question and use KG-to-text to generate the answer. Similarly, Hao et al. retrieve information that is relevant to a user’s post from a knowledge graph and use KG-to-text to generate replies.

One commonly used dataset in KG-to-Text is WebNLG. It consists of a set of small knowledge graphs and their corresponding lexicalizations (texts containing the information represented in the knowledge graph). The current state-of-the-art (SOTA) on the WebNLG (Constrained) dataset is the JointGT model, which fine-tunes pre-trained transformer models (BART or T5) on three tasks: (1) reconstructing a masked version of the text from the knowledge graph, (2) reconstructing a masked version of the knowledge graph from the text, and (3) aligning the graph and text embeddings using optimal transport. JointGT also uses a structure-aware self-attention layer to inject information about the graph structure into each transformer layer.

This post explores how a Llama 2 7B Chat LLM can be fine-tuned for KG-to-Text. Instead of modifying the model or adding components that enhance its knowledge graph awareness, we simply use a well-crafted prompt and fine-tune the LLM directly with QLoRA.

The code is also available in the following Jupyter Notebook:

KG-to-Text with Llama 2

Data Preparation

This post uses version 2.1 of the constrained dataset, which contains three JSON files, each corresponding to a subset of the data (train, dev, and test).

The following command clones the WebNLG repo.


git clone https://gitlab.com/shimorina/webnlg-dataset.git
ls webnlg-dataset/release_v2.1_constrained/json

# Output: webnlg_release_v2.1_constrained_dev.json webnlg_release_v2.1_constrained_train.json webnlg_release_v2.1_constrained_test.json
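
To see what the data looks like, one of the JSON files can be loaded directly (a small sketch; it reads the training split and prints its first entry):


# Peek at the first entry of the training split.
import json

with open("webnlg-dataset/release_v2.1_constrained/json/webnlg_release_v2.1_constrained_train.json") as f:
    data = json.load(f)
print(data["entries"][0])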

A sample entry in the dataset looks as follows:


{'1': {'category': 'Airport',
  'lexicalisations': [{'comment': 'good',
    'lex': 'Abilene, Texas is served by the Abilene regional airport.',
    'xml_id': 'Id1'},
   {'comment': 'good',
    'lex': 'Abilene Regional Airport serves the city of Abilene in Texas.',
    'xml_id': 'Id2'}],
  'modifiedtripleset': [{'object': 'Abilene,_Texas',
    'property': 'cityServed',
    'subject': 'Abilene_Regional_Airport'}],
  'originaltriplesets': {'originaltripleset': [[{'object': 'Abilene,_Texas',
      'property': 'cityServed',
      'subject': 'Abilene_Regional_Airport'}],
    [{'object': 'Abilene,_Texas',
      'property': 'city',
      'subject': 'Abilene_Regional_Airport'}]]},
  'shape': '(X (X))',
  'shape_type': 'NA',
  'size': '1',
  'xml_id': 'Id1'}
}

Each example in the dataset is a JSON object which contains, among other fields:

  • A modifiedtripleset field: The set of triples in the knowledge graph.
  • A lexicalisations field: Multiple reference texts that correspond to the triples.

For this exercise, all the other fields are ignored.
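
As a minimal illustration (assuming the training file was loaded into the data variable as in the earlier snippet), these are the only two fields read from each entry:


# Illustration: the two fields used from the sample entry shown earlier.
sample_entry = data["entries"][0]["1"]
sample_triples = sample_entry["modifiedtripleset"]                  # list of {subject, property, object} dicts
sample_texts = [l["lex"] for l in sample_entry["lexicalisations"]]  # reference lexicalizations
print(sample_triples, sample_texts)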

Prompting

The prompt for the sample entry above should look as follows:


Following is a set of knowledge graph triples delimited by triple backticks, each on a separate line, in the format: subject | predicate | object.

```
Abilene_Regional_Airport | cityServed | Abilene,_Texas
```

Generate a coherent piece of text that contains all of the information in the triples. Only use information from the provided triples.
After you finish writing the piece of text, write triple dollar signs (i.e.: $$$).

Because the pad token is set to the same token as the end-of-sequence token (see the tokenizer setup below), it is challenging for the fine-tuned Llama 2 7B Chat to learn when to stop, so we prompt it to write triple dollar signs at the end of the text. Later, in the modeling phase, a stopping criterion halts generation as soon as these tokens are produced.

The prompt format for Llama 2 chat models defines instructions and user messages as follows:

<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]

Based on this format, this code defines the prompt template.


prompt_text = """[INST] Following is a set of knowledge graph triples delimited by triple backticks, each on a separate line, in the format: subject | predicate | object.

```
{triples}
```

Generate a coherent piece of text that contains all of the information in the triples. Only use information from the provided triples.
After you finish writing the piece of text, write triple dollar signs (i.e.: $$$).[/INST]"""

The formatted triples will replace the {triples} placeholder.
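
As a quick illustration, rendering the template with the single triple from the sample entry produces the full Llama 2 prompt for that graph:


# Render the full prompt for the sample entry (single triple) shown earlier.
sample_triples = "Abilene_Regional_Airport | cityServed | Abilene,_Texas"
print(prompt_text.format(triples=sample_triples))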

Creating the Hugging Face Dataset

Hugging Face datasets can be loaded from JSON Lines files, so the data first needs to be converted to this format. The following code creates the JSON Lines files and uses them to build the Hugging Face datasets.


!pip install -q datasets jsonlines

import json
import jsonlines
import os

from datasets import load_dataset
from tqdm import tqdm

def format_triplets(triplets):
    """Helper function to format triples."""
    return '\n'.join([f"{triplet['subject']} | {triplet['property']} | {triplet['object']}" for triplet in triplets])


dataset_dir_path = "webnlg-dataset/release_v2.1_constrained/json"
data_subsets_file_names = {
    "train": "webnlg_release_v2.1_constrained_train.json",
    "dev": "webnlg_release_v2.1_constrained_dev.json",
    "test": "webnlg_release_v2.1_constrained_test.json"
}

data_subsets_file_paths = {k: os.path.join(dataset_dir_path, v) for k, v in data_subsets_file_names.items()}

for data_subset_name, data_subset_file_path in data_subsets_file_paths.items():
    # After the loop ends, all_responses holds the references of the last subset processed
    # (the test set), which is used later for evaluation.
    all_responses = []
    with open(data_subset_file_path, 'r') as data_subset_file:
        data_subset_dict = json.load(data_subset_file)

    with jsonlines.open(f"{data_subset_name}.jsonl", mode='w') as writer:
        for i, entry in enumerate(tqdm(data_subset_dict["entries"])):
            triples = format_triplets(entry[str(i+1)]['modifiedtripleset'])
            responses = [l["lex"] for l in entry[str(i+1)]['lexicalisations']]
            all_responses.append(responses)
            lexicalizations = entry[str(i+1)]['lexicalisations']
            # Train/dev: keep every lexicalization marked "good"; test: write one example per graph.
            good_responses = [l["lex"] for l in lexicalizations if l["comment"] == "good"] \
                if data_subset_name != "test" else [lexicalizations[0]["lex"]]
            for response in good_responses:
                prompt = prompt_text.format(triples=triples)
                writer.write({"prompt": prompt, "response": response + "$$$ </s>"})

# Load the Hugging Face datasets
train_dataset = load_dataset('json', data_files='train.jsonl', split="train")

# Preprocess the Hugging Face datasets
train_dataset = train_dataset.map(
    lambda examples: {'text': [f"{prompt} {response}"
                               for prompt, response in zip(examples['prompt'], examples['response'])]},
    batched=True
)

test_dataset = load_dataset('json', data_files='test.jsonl', split="train")

Here is a sample entry from the final Hugging Face dataset:


{
    "prompt": "[INST] Following is a set of knowledge graph triples delimited by triple backticks, each on a separate line, in the format: subject | predicate | object.\n\n```\nAbilene_Regional_Airport | cityServed | Abilene,_Texas\n```\n\nGenerate a coherent piece of text that contains all of the information in the triples. Only use information from the provided triples.\nAfter you finish writing the piece of text, write triple dollar signs (i.e.: $$$).[/INST]",
    "response": "Abilene, Texas is served by the Abilene regional airport.$$$ </s>",
    "text": "[INST] Following is a set of knowledge graph triples delimited by triple backticks, each on a separate line, in the format: subject | predicate | object.\n\n```\nAbilene_Regional_Airport | cityServed | Abilene,_Texas\n```\n\nGenerate a coherent piece of text that contains all of the information in the triples. Only use information from the provided triples.\nAfter you finish writing the piece of text, write triple dollar signs (i.e.: $$$).[/INST] Abilene, Texas is served by the Abilene regional airport.$$$ </s>"
}
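
A quick optional sanity check confirms that the datasets loaded as expected (not required for training):


# Optional sanity check: dataset sizes and a peek at the first training example.
print(len(train_dataset), len(test_dataset))
print(train_dataset[0]["text"][:300])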

Modeling

The following packages are used in the fine-tuning process:

  • transformers: The Hugging Face library, used to obtain the Llama 2 7B Chat model.
  • peft: Short for Parameter-Efficient Fine-Tuning, used to fine-tune the LLM without needing to modify all the model parameters.
  • bitsandbytes: Used to load the model in 4-bit precision (quantization).
  • trl: Short for Transformer Reinforcement Learning; its SFTTrainer performs the supervised fine-tuning.

The following code defines the model configurations and runs the fine-tuning process.


!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7

import torch

from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

base_model_name = "meta-llama/Llama-2-7b-chat-hf"
output_dir = "./results_chat_7b_3_epoch/final_checkpoint"

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

device_map = {"": 0}

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False
)

base_model = AutoModelForCausalLM.from_pretrained(  ## If it fails at this line, restart the runtime and try again.
    base_model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    trust_remote_code=True,
    use_auth_token=True
)
base_model.config.use_cache = False

# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1 

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    save_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    logging_steps=10,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_args,
)

trainer.train()
trainer.model.save_pretrained(output_dir)
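
Note that save_pretrained stores only the LoRA adapter weights. If a standalone model is preferred for deployment, the adapter can optionally be merged into the base weights afterwards (a sketch, not part of the original pipeline; merging requires loading the model without 4-bit quantization, and the output path is hypothetical):


# Optional: merge the LoRA adapter into the base weights for standalone inference.
from peft import AutoPeftModelForCausalLM

merged_model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir, device_map=device_map, torch_dtype=torch.bfloat16
).merge_and_unload()
merged_model.save_pretrained("./results_chat_7b_3_epoch/merged")  # hypothetical output path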

Results

To generate the results, the following code loads the fine-tuned model and creates a Hugging Face pipeline with a stopping criterion that halts generation once the triple dollar signs are produced.


from peft import AutoPeftModelForCausalLM
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList

class StopOnTripleDollarSigns(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated tokens end with the "$$$" delimiter.
        if ''.join(tokenizer.convert_ids_to_tokens(input_ids[0][-3:])).endswith("$$$"):
            return True
        return False

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map=device_map, torch_dtype=torch.bfloat16, load_in_8bit=True)

stopping_criteria = StoppingCriteriaList([StopOnTripleDollarSigns()])
text_gen = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=500, stopping_criteria=stopping_criteria)


Finally, the following code generates the text, cleans it by removing the original prompt and the triple dollar signs, and saves it along with the reference texts for evaluation.


from transformers.pipelines.pt_utils import KeyDataset

def clean_result_text(result_text, test_prompt):
    result_text_without_prompt = result_text[len(test_prompt):]
    delimiter_index = result_text_without_prompt.find('$$$')
    cleaned_result_text = result_text_without_prompt[:delimiter_index].strip()
    return cleaned_result_text

results = []
for out in tqdm(text_gen(KeyDataset(test_dataset, "prompt"), top_k=1)):
    results.append(out)

results_texts = [result[0]["generated_text"] for result in results]

cleaned_results_texts = [clean_result_text(result_text, test_prompt)
    for result_text, test_prompt in zip(results_texts, test_dataset["prompt"])]

results_file_text = '\n'.join(cleaned_results_texts)
references_file_text = '\n\n'.join(['\n'.join(responses) for responses in all_responses])

with open('results.txt', 'w') as results_file:
    results_file.write(results_file_text)

with open('references.txt', 'w') as references_file:
    references_file.write(references_file_text)

Evaluation

To evaluate the results, we use Data-to-text-Evaluation-Metric (the library used by JointGT).

To run the evaluation process, clone this repo and install the required software, then change the current directory to the location of the repo on your device and run the measure_scores.py script.


cd Data-to-text-Evaluation-Metric/
python measure_scores.py <path_to_references_file> <path_to_results_file>

Replace <path_to_references_file> and <path_to_results_file> with the paths to the references.txt and results.txt files created earlier, respectively.
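
Before running the official script, a rough corpus BLEU can optionally be computed directly from the in-memory lists as a sanity check (this uses the Hugging Face evaluate library, which is not part of the original setup, and is not the metric reported below):


# Optional rough check (not the official Data-to-text-Evaluation-Metric scores).
# Requires: pip install evaluate
import evaluate

bleu = evaluate.load("bleu")
print(bleu.compute(predictions=cleaned_results_texts, references=all_responses))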

The following table lists the values for the metrics we obtained as compared to the results listed on paperswithcode.com for the current SOTA model:

Evaluation results of fine-tuned Llama 2 compared to JointGT (T5) on the WebNLG (Constrained) test set.

This approach scores higher than the SOTA model while using nothing more than a textual representation of the knowledge graph in the prompt.

Conclusion

The use of KG-to-Text enables you to tap into the semantic understanding of a knowledge graph when using LLMs for generative AI. In this post, we fine-tuned Llama 2 7B Chat using QLoRA to perform KG-to-Text, achieving results that surpass the current SOTA on the WebNLG (Constrained) dataset.