Fine-tune a causal language model to produce a LoRA or aLoRA adapter. Loads a JSONL dataset of item/label pairs, applies an 80/20 train/validation split, and trains using HuggingFace PEFT and TRL’s SFTTrainer — saving the checkpoint with the lowest validation loss. Supports CUDA, MPS (macOS, PyTorch ≥ 2.8), and CPU device selection, and handles the alora_invocation_tokens configuration required for aLoRA training.

Functions

FUNC load_dataset_from_json

load_dataset_from_json(json_path, tokenizer, invocation_prompt)
Load a JSONL dataset and format it for SFT training. Reads item/label pairs from a JSONL file and builds a HuggingFace Dataset with input and target columns. Each input is formatted as "{item}\nRequirement: <|end_of_text|>\n{invocation_prompt}". Args:
  • json_path: Path to the JSONL file containing item/label pairs.
  • tokenizer: HuggingFace tokenizer instance (currently unused, reserved for future tokenization steps).
  • invocation_prompt: Invocation string appended to each input prompt.
Returns:
  • A HuggingFace Dataset with "input" and "target" string columns.
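The loading and formatting behavior described above can be sketched in plain Python. This is a hypothetical illustration, not the actual implementation: the JSONL field names "item" and "label" are assumed from the docstring, the unused tokenizer argument is omitted, and plain lists stand in for the HuggingFace Dataset.

```python
import json

def load_pairs_from_jsonl(json_path, invocation_prompt):
    """Read item/label pairs from a JSONL file and apply the input template."""
    inputs, targets = [], []
    with open(json_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            # Each input follows the documented template exactly.
            inputs.append(
                f"{record['item']}\nRequirement: <|end_of_text|>\n{invocation_prompt}"
            )
            targets.append(record["label"])
    return {"input": inputs, "target": targets}
```

A dict in this shape could then be wrapped with `datasets.Dataset.from_dict` to produce the Dataset the real function returns.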

FUNC formatting_prompts_func

formatting_prompts_func(example)
Concatenate input and target columns for SFT prompt formatting. Args:
  • example: A batch dict with "input" and "target" list fields, as produced by HuggingFace Dataset.map in batched mode.
Returns:
  • A list of strings, each formed by concatenating the input and target values for a single example in the batch.
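Because Dataset.map in batched mode passes a dict of column name to list, the function likely amounts to a per-example concatenation. A minimal sketch consistent with the description:

```python
def formatting_prompts_func(example):
    """Concatenate the input and target of each example in a batched dict."""
    return [
        inp + tgt
        for inp, tgt in zip(example["input"], example["target"])
    ]
```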

FUNC train_model

train_model(dataset_path: str, base_model: str, output_file: str, prompt_file: str | None = None, adapter: str = 'alora', device: str = 'auto', run_name: str = 'multiclass_run', epochs: int = 6, learning_rate: float = 6e-06, batch_size: int = 2, max_length: int = 1024, grad_accum: int = 4)
Fine-tune a causal language model to produce a LoRA or aLoRA adapter. Loads the JSONL dataset, applies an 80/20 train/validation split, configures PEFT with the specified adapter type, trains using SFTTrainer with a best-checkpoint callback, saves the adapter weights, and removes the PEFT-generated README.md from the output directory. Args:
  • dataset_path: Path to the JSONL training dataset file.
  • base_model: Hugging Face model ID or local path to the base model.
  • output_file: Destination path for the trained adapter weights.
  • prompt_file: Optional path to a JSON config file with an "invocation_prompt" key. Defaults to the aLoRA invocation token.
  • adapter: Adapter type to train — "alora" (default) or "lora".
  • device: Device selection — "auto", "cpu", "cuda", or "mps".
  • run_name: Name of the training run (passed to SFTConfig).
  • epochs: Number of training epochs.
  • learning_rate: Optimizer learning rate.
  • batch_size: Per-device training batch size.
  • max_length: Maximum token sequence length.
  • grad_accum: Gradient accumulation steps.
Raises:
  • ValueError: If device is not one of "auto", "cpu", "cuda", or "mps".
  • RuntimeError: If the GPU has insufficient VRAM to load the model (wraps NotImplementedError for meta tensor errors).
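The device argument accepts "auto", "cpu", "cuda", or "mps", and anything else raises ValueError. A hypothetical sketch of that selection rule is below; availability is passed in as flags so the sketch stays torch-free, whereas the real code would consult torch.cuda.is_available() and torch.backends.mps.is_available().

```python
def resolve_device(device, cuda_available=False, mps_available=False):
    """Validate an explicit device choice, or pick the best available one."""
    valid = {"auto", "cpu", "cuda", "mps"}
    if device not in valid:
        raise ValueError(f"device must be one of {sorted(valid)}, got {device!r}")
    if device != "auto":
        return device  # honor an explicit choice as-is
    if cuda_available:
        return "cuda"
    if mps_available:
        # Per the module description, MPS requires macOS with PyTorch >= 2.8.
        return "mps"
    return "cpu"
```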

Classes

CLASS SaveBestModelCallback

HuggingFace Trainer callback that saves the adapter at its best validation loss. Attributes:
  • best_eval_loss: Lowest evaluation loss seen so far across all evaluation steps. Initialised to float("inf").
Methods:

FUNC on_evaluate

on_evaluate(self, args, state, control, **kwargs)
Save the adapter weights if the current evaluation loss is a new best. Called automatically by the HuggingFace Trainer after each evaluation step. Compares the current eval_loss from metrics against best_eval_loss and, if lower, updates the stored best and saves the model to args.output_dir. Args:
  • args: TrainingArguments instance with training configuration, including output_dir.
  • state: TrainerState instance with the current training state.
  • control: TrainerControl instance for controlling training flow.
  • **kwargs: Additional keyword arguments provided by the Trainer, including "model" (the current PEFT model) and "metrics" (a dict containing "eval_loss").
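The best-checkpoint logic described above can be illustrated in isolation. This is a simplified sketch, not the real callback: TrainerCallback and the Trainer plumbing are replaced with plain objects, and only the compare-and-save rule is shown.

```python
class BestLossTracker:
    """Sketch of SaveBestModelCallback's core rule: save on a new best loss."""

    def __init__(self):
        self.best_eval_loss = float("inf")

    def on_evaluate(self, output_dir, model, metrics):
        eval_loss = metrics.get("eval_loss")
        if eval_loss is not None and eval_loss < self.best_eval_loss:
            self.best_eval_loss = eval_loss
            # Persist the adapter only when this evaluation improves on the best.
            model.save_pretrained(output_dir)
```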

CLASS SafeSaveTrainer

SFTTrainer subclass that always saves models with safe serialization enabled.
Methods:

FUNC save_model

save_model(self, output_dir: str | None = None, _internal_call: bool = False)
Save the model and tokenizer with safe serialization always enabled. Overrides SFTTrainer.save_model to call save_pretrained with safe_serialization=True, ensuring weights are saved in safetensors format rather than the legacy pickle-based format. Args:
  • output_dir: Directory to save the model into. If None, the trainer’s configured output_dir is used.
  • _internal_call: Internal flag passed through from the Trainer base class; not used by this override.
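The override pattern can be shown with a stub standing in for SFTTrainer; the only point of the sketch is forwarding safe_serialization=True so weights land in safetensors format. The StubTrainer base and its attribute names are illustrative, not the real TRL API.

```python
class StubTrainer:
    """Stand-in for SFTTrainer, holding just what the override needs."""

    def __init__(self, model, output_dir):
        self.model = model
        self.args_output_dir = output_dir

class SafeSaveTrainer(StubTrainer):
    def save_model(self, output_dir=None, _internal_call=False):
        target = output_dir or self.args_output_dir
        # Always request safetensors output instead of pickle-based .bin files.
        self.model.save_pretrained(target, safe_serialization=True)
```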