import random
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from torch.utils.data import Dataset
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score, precision_score, recall_score
from pathlib import Path

# Parameters for tokenization and chunking
max_len = 512  # Maximum sequence length for BERT-style models

# Define label mapping with zero-based indexing
category_mapping = {
    0: "equal protection",
    1: "education adequacy",
    2: "both",
    3: "other"
}

# Create a reverse mapping for convenience (if needed)
reverse_category_mapping = {v: k for k, v in category_mapping.items()}

# Load the prepared data (assumes the Court_Cases_Doc DataFrame is already in scope)
data = Court_Cases_Doc
# Ensure the 'documents' column contains strings
data['documents'] = data['documents'].astype(str)

# Ensure 'Constitution_Type' is zero-based and within the expected label range
assert data['Constitution_Type'].min() >= 0, "Constitution_Type should be zero-based."
assert data['Constitution_Type'].max() < len(category_mapping), "Constitution_Type is outside the label range."

# Create 'labels' column based on 'Constitution_Type' as integers
data['labels'] = data['Constitution_Type']

# Optionally, create a copy of 'labels' as 'original_labels' with category names for reference
data['original_labels'] = data['labels'].map(category_mapping)
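
# Quick look at the label distribution (illustrative, commented out):
# print(data['original_labels'].value_counts())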

# Loss logger callback
class LossLoggerCallback(TrainerCallback):
    def __init__(self, log_file='loss_log.txt'):
        self.train_losses = []
        self.val_losses = []
        self.log_file = log_file
        with open(self.log_file, 'w') as f:
            f.write("Step\tTrain Loss\tVal Loss\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            if 'loss' in logs:
                self.train_losses.append(logs['loss'])
            if 'eval_loss' in logs:
                self.val_losses.append(logs['eval_loss'])
            with open(self.log_file, 'a') as f:
                f.write(f"{state.global_step}\t{logs.get('loss', 'N/A')}\t{logs.get('eval_loss', 'N/A')}\n")

# Self-training pass: replace a label with the model's prediction only when
# the prediction's confidence exceeds `threshold`; mutates `data` in place
def update_categories(data, model, tokenizer, device, threshold=0.7):
    model.eval()
    updated_labels = []
    for idx, text in enumerate(data['documents']):
        inputs = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=max_len
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            pred = torch.softmax(logits, dim=1)
            confidence, predicted_label = torch.max(pred, dim=1)
            if confidence.item() > threshold:  # Only update if confident
                updated_labels.append(predicted_label.item())
            else:
                updated_labels.append(data['labels'].iloc[idx])  # Keep original label based on the current index
    data['labels'] = updated_labels
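
# Illustrative, hypothetical standalone use of the self-training pass above
# (commented out; assumes model, tokenizer, and device are initialized, as they
# are later in this script; 'pseudo_data' is a placeholder name):
# pseudo_data = data.copy()
# update_categories(pseudo_data, model, tokenizer, device, threshold=0.7)
# print("Relabeled:", (pseudo_data['labels'] != data['labels']).sum())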

# Dataset class that chunks long documents and encodes the first chunk
class SentenceBasedDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def chunk_tokens(self, tokens, chunk_size):
        """ Split tokens into chunks based on chunk size. """
        return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize and split into chunks, reserving room for the [CLS]/[SEP]
        # special tokens added during encoding
        tokens = self.tokenizer.tokenize(text)
        chunk_size = self.max_len - 2
        if len(tokens) > chunk_size:
            token_chunks = self.chunk_tokens(tokens, chunk_size)
        else:
            token_chunks = [tokens]

        # Encode only the first chunk for now; converting back to a string
        # avoids re-tokenizing wordpiece tokens, which is_split_into_words=True
        # would do
        chunk_text = self.tokenizer.convert_tokens_to_string(token_chunks[0])
        encoding = self.tokenizer(
            chunk_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
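
# Quick shape check (illustrative, commented out; assumes the tokenizer is
# already loaded): every item should be padded/truncated to exactly max_len ids.
# demo_ds = SentenceBasedDataset(data['documents'].head(2).tolist(),
#                                data['labels'].head(2).tolist(), tokenizer, max_len)
# print(demo_ds[0]['input_ids'].shape)  # expected: torch.Size([512])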

# Define base directories using pathlib
base_dir = Path('/content/gdrive/My Drive/WonJay/State_Court_Cases/Code_Lines/Python/')
results_dir = base_dir / 'results_fold'
logs_dir = base_dir / 'logs_fold'
results_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

# Load pre-trained model and tokenizer
model_path = base_dir / 'trained_legalbert_model'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(category_mapping))

# Chunking and prediction function: classify each chunk of a long document
# and average the softmax probabilities across chunks
def average_predictions(text, tokenizer, model, max_len, device, confidence_threshold=0.7):
    model.eval()
    tokens = tokenizer.tokenize(text)
    total_tokens = len(tokens)

    # Reserve room for the [CLS]/[SEP] tokens added during encoding
    chunk_size = max_len - 2
    if total_tokens > chunk_size:
        chunks = [tokens[i:i + chunk_size] for i in range(0, total_tokens, chunk_size)]
    else:
        chunks = [tokens]

    predictions = []
    confidences = []
    selected_tokens = 0

    for chunk in chunks:
        chunk_text = tokenizer.convert_tokens_to_string(chunk)
        inputs = tokenizer(
            chunk_text,
            return_tensors='pt',
            truncation=True,
            max_length=max_len,
            padding='max_length'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            pred = torch.softmax(logits, dim=1)
            predictions.append(pred.cpu())
            max_conf, _ = torch.max(pred, dim=1)
            confidences.extend(max_conf.cpu().numpy())

            # Count tokens if confidence exceeds threshold
            if max_conf.item() > confidence_threshold:
                selected_tokens += len(chunk)

    avg_predictions = torch.mean(torch.stack(predictions), dim=0)
    avg_confidence = np.mean(confidences)

    return avg_predictions, avg_confidence, selected_tokens, total_tokens
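
# Illustrative, hypothetical usage of average_predictions on a single document
# (commented out; assumes model, tokenizer, and device are initialized):
# avg_pred, avg_conf, n_confident, n_tokens = average_predictions(
#     data['documents'].iloc[0], tokenizer, model, max_len, device)
# print(category_mapping[torch.argmax(avg_pred).item()], f"confidence={avg_conf:.3f}")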

# Perform 5-fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
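
# Optional sketch (an assumption, not part of the original pipeline): a
# compute_metrics function like this could be passed to the Trainer so that
# accuracy and macro F1 appear alongside eval_loss in the per-epoch logs.
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {
        'accuracy': float((preds == labels).mean()),
        'f1_macro': f1_score(labels, preds, average='macro'),
    }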

# Initialize columns for predictions
data['predicted_categories'] = None
data['confidence_scores'] = None
data['total_number_of_tokens'] = None

# Initialize best metrics
best_loss = float('inf')
best_accuracy = 0
best_model_state = None
base_output_dir = results_dir  # Using pathlib Path object

# Initialize an empty DataFrame to collect loss data from all folds
loss_df = pd.DataFrame(columns=['Fold', 'Epoch', 'Train Loss', 'Eval Loss'])

num_folds = 5
splits = list(kf.split(data))

for fold in range(1, num_folds + 1):
    train_index, val_index = splits[fold - 1]
    print(f"Starting fold {fold}...")

    # Reset the per-fold lists to avoid carryover between folds
    cv_train_losses = []
    cv_eval_losses = []
    cv_epochs = []
    fold_numbers = []

    # Define fold-specific output and logging directories
    output_dir = results_dir / f"fold{fold}"
    output_dir.mkdir(parents=True, exist_ok=True)
    
    fold_logging_dir = logs_dir / f"fold{fold}"
    fold_logging_dir.mkdir(parents=True, exist_ok=True)

    # Re-initialize the model from the pre-trained checkpoint for this fold so
    # that no fine-tuned weights leak across folds
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, num_labels=len(category_mapping)
    )
    model.to(device)

    # Split into train and validation sets
    train_texts = data['documents'].iloc[train_index].tolist()
    val_texts = data['documents'].iloc[val_index].tolist()
    train_labels = data['labels'].iloc[train_index].tolist()
    val_labels = data['labels'].iloc[val_index].tolist()

    # Prepare datasets
    train_dataset = SentenceBasedDataset(train_texts, train_labels, tokenizer, max_len)
    val_dataset = SentenceBasedDataset(val_texts, val_labels, tokenizer, max_len)

    # Training configuration for this fold
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        evaluation_strategy="epoch",
        save_strategy="epoch",  # Save a checkpoint at the end of each epoch
        per_device_train_batch_size=32,
        per_device_eval_batch_size=64,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_dir=str(fold_logging_dir),
        logging_steps=1,
        fp16=torch.cuda.is_available(),  # Mixed precision only when a GPU is present
        gradient_accumulation_steps=2,  # Effective train batch size: 32 * 2 = 64
        metric_for_best_model='eval_loss',  # Track evaluation loss for best model
        greater_is_better=False,  # Lower loss is better
        load_best_model_at_end=True,  # Load the best checkpoint at the end
        learning_rate=1e-5
    )

    # Define CategoryUpdateCallback: periodically re-label the training data
    # with confident model predictions (self-training) and sync the new labels
    # into the dataset the Trainer is reading from
    class CategoryUpdateCallback(TrainerCallback):
        def __init__(self, update_every, train_data, train_dataset, model, tokenizer, device, threshold):
            super().__init__()
            self.update_every = update_every
            self.train_data = train_data
            self.train_dataset = train_dataset
            self.model = model
            self.tokenizer = tokenizer
            self.device = device
            self.threshold = threshold

        def on_epoch_end(self, args, state, control, **kwargs):
            # state.epoch counts completed epochs (1.0, 2.0, ...) at epoch end
            completed_epoch = int(round(state.epoch))
            if completed_epoch % self.update_every == 0:
                print(f"Updating categories after epoch {completed_epoch}...")
                update_categories(
                    self.train_data,
                    self.model,
                    self.tokenizer,
                    self.device,
                    threshold=self.threshold
                )
                # Propagate the updated labels to the Trainer's dataset
                self.train_dataset.labels = self.train_data['labels'].tolist()

    # Instantiate callbacks
    update_callback = CategoryUpdateCallback(
        update_every=2,  # Update labels every 2 epochs
        train_data=data.iloc[train_index].copy(),  # Work on a copy of the fold's slice
        train_dataset=train_dataset,
        model=model,
        tokenizer=tokenizer,
        device=device,
        threshold=0.7
    )

    loss_logger_callback = LossLoggerCallback(log_file=str(fold_logging_dir / 'loss_log.txt'))

    # Trainer setup with early stopping and custom callbacks
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[loss_logger_callback, update_callback, EarlyStoppingCallback(early_stopping_patience=5)]
    )

    # Use a different random seed for each fold
    random_seed = 42 + fold
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Train the model
    trainer.train()

    # Retrieve loss history from the trainer. With logging_steps=1 the 'loss'
    # entries in log_history are per optimizer step while 'eval_loss' appears
    # once per epoch, so pair each eval loss with the last train loss logged
    # in that epoch.
    train_losses = []
    eval_losses = []
    last_train_loss = None
    for log in trainer.state.log_history:
        if 'loss' in log:
            last_train_loss = log['loss']
        if 'eval_loss' in log:
            train_losses.append(last_train_loss)
            eval_losses.append(log['eval_loss'])

    # Number of epochs actually completed (early stopping may end training early)
    num_epochs = len(eval_losses)

    # Append loss data to the per-fold lists
    cv_train_losses.extend(train_losses)
    cv_eval_losses.extend(eval_losses)
    cv_epochs.extend(range(1, num_epochs + 1))
    fold_numbers.extend([fold] * num_epochs)

    # Print the number of epochs completed for this fold
    print(f"Fold {fold} completed with {num_epochs} epochs.")

    # Create a DataFrame for the current fold's loss data
    fold_loss_df = pd.DataFrame({
        'Fold': fold_numbers,
        'Epoch': cv_epochs,
        'Train Loss': cv_train_losses,
        'Eval Loss': cv_eval_losses
    })

    # Print the DataFrame to confirm it contains data
    print(f"Data in fold_loss_df for fold {fold}:", fold_loss_df)

    # Define the output file path for the current fold and make sure the
    # directory exists before writing
    base_dir2 = Path('/content/gdrive/My Drive/WonJay/State_Court_Cases/')
    results_output_dir = base_dir2 / 'Results' / 'Python'
    results_output_dir.mkdir(parents=True, exist_ok=True)
    excel_file_path = results_output_dir / f"loss_results_fold_{fold}.xlsx"

    # Save the current fold's loss data to an Excel file
    fold_loss_df.to_excel(excel_file_path, index=False)

    # Append the fold's loss data to the main loss_df
    loss_df = pd.concat([loss_df, fold_loss_df], ignore_index=True)

    # Evaluate on the validation set to obtain the fold's validation loss
    val_loss = trainer.evaluate()['eval_loss']

    # Predictions on the validation set
    predictions = trainer.predict(val_dataset)
    predicted_labels = np.argmax(predictions.predictions, axis=1)

    # Calculate metrics
    val_accuracy = float((predicted_labels == np.array(val_labels)).mean())
    val_precision = precision_score(val_labels, predicted_labels, average='macro')
    val_recall = recall_score(val_labels, predicted_labels, average='macro')
    val_f1 = f1_score(val_labels, predicted_labels, average='macro')

    # Update the best model across folds based on validation loss
    if val_loss < best_loss:
        best_loss = val_loss
        best_accuracy = val_accuracy
        best_model_state = model.state_dict()
        torch.save(best_model_state, base_output_dir / f"best_model_fold{fold}.pth")

    # Print metrics
    print(f"Fold {fold} Validation Loss: {val_loss:.4f}")
    print(f"Fold {fold} Validation Accuracy: {val_accuracy:.4f}")
    print(f"Fold {fold} Validation Precision: {val_precision:.4f}")
    print(f"Fold {fold} Validation Recall: {val_recall:.4f}")
    print(f"Fold {fold} Validation F1-score: {val_f1:.4f}")

# Define results directory
base_dir2 = Path('/content/gdrive/My Drive/WonJay/State_Court_Cases/')
results_output_dir = base_dir2 / 'Results' / 'Python'
results_output_dir.mkdir(parents=True, exist_ok=True)

# Save DataFrame to Excel file
excel_file_path2 = results_output_dir / 'legal_sentiment_bert_loss_cv.xlsx'
loss_df.to_excel(excel_file_path2, index=False)

print(f"Loss data saved to {excel_file_path}")

# Make predictions using the best model on the entire dataset
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    model.to(device)

    # Initialize lists to collect true and predicted labels
    all_true_labels = []
    all_predicted_labels = []

    # Iterate through all documents for final predictions
    print("\nStarting final predictions on the entire dataset...")
    for idx, row in data.iterrows():
        text = row['documents']
        true_label = row['labels']

        avg_pred, avg_conf, selected_tokens, total_tokens = average_predictions(
            text, tokenizer, model, max_len, device
        )
        predicted_label = torch.argmax(avg_pred).item()

        # Append to the lists
        all_true_labels.append(true_label)
        all_predicted_labels.append(predicted_label)

        # Assign predicted categories and confidence scores
        data.at[idx, 'predicted_categories'] = category_mapping.get(predicted_label, 'unknown')
        data.at[idx, 'confidence_scores'] = avg_conf
        data.at[idx, 'total_number_of_tokens'] = total_tokens

    # Compute overall metrics
    overall_precision = precision_score(all_true_labels, all_predicted_labels, average='macro')
    overall_recall = recall_score(all_true_labels, all_predicted_labels, average='macro')
    overall_f1 = f1_score(all_true_labels, all_predicted_labels, average='macro')

    print("\nOverall Metrics on the Entire Dataset:")
    print(f"Precision: {overall_precision:.4f}")
    print(f"Recall:    {overall_recall:.4f}")
    print(f"F1-Score:  {overall_f1:.4f}")

    # Save overall metrics to a separate Excel file
    overall_metrics_df = pd.DataFrame({
        'Metric': ['Precision', 'Recall', 'F1-Score'],
        'Score': [overall_precision, overall_recall, overall_f1]
    })

    overall_metrics_file_path = results_output_dir / 'legal_sentiment_bert_overall_metrics.xlsx'
    overall_metrics_df.to_excel(overall_metrics_file_path, index=False)

    print(f"Overall metrics saved to {overall_metrics_file_path}")

# Define final CSV file path
csv_file_path = results_output_dir / 'legal_sentiment_bert_cv.csv'
final_data = data[['file_name', 'Constitution_Type', 'original_labels', 'predicted_categories',
                   'confidence_scores', 'total_number_of_tokens']]
final_data.to_csv(csv_file_path, index=False)

print("\nFinal results saved to:")
print(f"- Predictions: {csv_file_path}")
print(f"- Overall Metrics: {overall_metrics_file_path}")

 
