From c918c77e2d980164c27486f83ce72a0e3d48585a Mon Sep 17 00:00:00 2001 From: Gogs Date: Fri, 30 Jan 2026 08:33:25 -0600 Subject: [PATCH] Initial release: Yuuki training pipeline --- LICENSE | 160 ++++++------------- TRAINING.md | 166 ++++++++++++++++++++ train_yuuki.py | 413 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 627 insertions(+), 112 deletions(-) create mode 100644 TRAINING.md create mode 100644 train_yuuki.py diff --git a/LICENSE b/LICENSE index 261eeb9..5c00bfd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,9 @@ - Apache License +Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ +Copyright 2026 OpceanAI + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. @@ -63,130 +65,64 @@ on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. + 2. Grant of Copyright License. + Subject to the terms and conditions of this License, each Contributor + hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, + royalty-free, irrevocable copyright license to reproduce, prepare + Derivative Works of, publicly display, publicly perform, sublicense, + and distribute the Work and such Derivative Works in Source or Object + form. - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. + 3. Grant of Patent License. + Subject to the terms and conditions of this License, each Contributor + hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, + royalty-free, irrevocable (except as stated in this section) patent + license to make, have made, use, offer to sell, sell, import, and + otherwise transfer the Work, where such license applies only to those + patent claims licensable by such Contributor that are necessarily + infringed by their Contribution(s) alone or by combination of their + Contribution(s) with the Work to which such Contribution(s) was + submitted. If You institute patent litigation against any entity + alleging that the Work or a Contribution constitutes patent + infringement, then any patent licenses granted under this License + shall terminate as of the date such litigation is filed. - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: + 4. Redistribution. + You may reproduce and distribute copies of the Work or Derivative + Works thereof in any medium, with or without modifications, provided + that You meet the following conditions: - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and + (a) You must give recipients a copy of this License; and + (b) You must cause modified files to carry prominent notices stating + that You changed the files; and + (c) You must retain all copyright, patent, trademark, and attribution + notices; and + (d) Any NOTICE file must be included if present. - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and + 5. Submission of Contributions. + Unless You explicitly state otherwise, any Contribution submitted + shall be under the terms of this License. - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and + 6. Trademarks. + This License does not grant permission to use the trade names, + trademarks, or service marks of the Licensor. 
- (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. + 7. Disclaimer of Warranty. + The Work is provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR + CONDITIONS OF ANY KIND. - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. + 8. Limitation of Liability. + In no event shall any Contributor be liable for damages arising from + the use of the Work. - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. 
- Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. 
While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. + 9. Accepting Warranty or Additional Liability. + You may offer support or warranty only on Your own behalf. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] + Copyright 2026 OpceanAI Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/TRAINING.md b/TRAINING.md new file mode 100644 index 0000000..88856e9 --- /dev/null +++ b/TRAINING.md @@ -0,0 +1,166 @@ +
+ +# Yuuki Training Code + +**Official training pipeline for Yuuki, an experimental small-scale language model for source code generation.** + +[![Model](https://img.shields.io/badge/HuggingFace-Yuuki--82M-yellow)](https://huggingface.co/OpceanAI/Yuuki-82M) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue)](LICENSE) +[![Python](https://img.shields.io/badge/Python-3.8+-green)](https://python.org) + +
+ +--- + +## Abstract + +This repository contains the official training implementation for **Yuuki**, a compact causal language model optimized for source code understanding and generation. The system is designed with an emphasis on simplicity, reproducibility, and accessibility across heterogeneous computing environments, including CPU-only systems, cloud notebooks (Colab, Kaggle), and resource-constrained platforms such as Termux on mobile devices. + +--- + +## Model Specification + +| Attribute | Description | +|-----------|-------------| +| **Architecture** | GPT-style autoregressive transformer | +| **Base Model** | `distilgpt2` | +| **Domain** | Source code (multi-language) | +| **Training Corpus** | `bigcode/the-stack-smol-xl` | +| **Parameter Count** | ~82M | +| **Design Principles** | Minimal dependencies, transparent implementation, full reproducibility | + +--- + +## Repository Structure + +### Included Components + +| File | Description | +|------|-------------| +| `train_yuuki.py` | Complete, self-contained training script | +| `LICENSE` | Apache 2.0 License | + +### Excluded Artifacts + +The following components are intentionally omitted to maintain repository portability and encourage local reproducibility: + +- Pre-trained model weights and checkpoints +- Tokenized datasets and Arrow cache files +- Training logs and metrics +- Experimental or proprietary scripts +- Auxiliary datasets from subsequent experiments + +All artifacts should be generated locally by executing the provided training script. + +--- + +## Configuration Parameters + +Training behavior is controlled exclusively through environment variables, enabling seamless adaptation across diverse execution environments. 
+
+### Default Configuration
+
+| Parameter | Default Value | Description |
+|-----------|---------------|-------------|
+| `MODEL_NAME` | `distilgpt2` | Pre-trained model identifier for initialization |
+| `DATASET_ID` | `bigcode/the-stack-smol-xl` | HuggingFace dataset identifier |
+| `SPLIT` | `train` | Dataset partition for training |
+| `OUTPUT_DIR` | `./yuuki_model` | Output directory for model artifacts |
+| `TOKENIZED_CACHE_DIR` | `~/yuuki/tokenized_cache` | Cache location for tokenized sequences |
+| `MAX_LENGTH` | `256` | Maximum input sequence length |
+| `EPOCHS` | `2` | Number of passes over the training dataset |
+| `BATCH_SIZE` | `1` | Samples per gradient update |
+
+### Implementation Notes
+
+- **Sequence Length (`MAX_LENGTH=256`)**: Selected to optimize memory utilization and training throughput on constrained hardware.
+- **Batch Size (`BATCH_SIZE=1`)**: Configured for compatibility with low-memory execution environments.
+- **Tokenization Caching**: Optional but recommended for iterative training workflows.
+
+---
+
+## Execution
+
+### Standard Invocation
+
+```bash
+python train_yuuki.py
+```
+
+### Custom Configuration Example
+
+```bash
+MODEL_NAME=distilgpt2 \
+MAX_LENGTH=256 \
+EPOCHS=3 \
+BATCH_SIZE=2 \
+python train_yuuki.py
+```
+
+The training script performs automatic hardware detection and configures CUDA acceleration when available.
+
+---
+
+## Design Rationale
+
+Yuuki is not intended to compete with large-scale foundation models.
The project objectives are: + +| Principle | Description | +|-----------|-------------| +| **Interpretability** | Prioritizes readable, maintainable code over abstraction layers | +| **Accessibility** | Executable without specialized hardware infrastructure | +| **Transparency** | No hidden procedures or undocumented dependencies | +| **Educational Utility** | Serves as a reference implementation for language model training | + +--- + +## Pre-trained Model + +The model trained using this pipeline is publicly available: + +
+ +**[Yuuki-82M on HuggingFace](https://huggingface.co/OpceanAI/Yuuki-82M)** + +
+ +--- + +## Limitations and Disclaimer + +This software is provided for research and educational purposes. The model may produce: + +- Syntactically or semantically incorrect code +- Incomplete or truncated outputs +- Potentially unsafe or nonsensical suggestions + +**This system is not suitable for production deployment.** Users assume full responsibility for any application of the generated outputs. + +--- + +## License + +This project is distributed under the **Apache License 2.0**. See the [LICENSE](LICENSE) file for complete terms. + +Under this license, you are permitted to: + +- Use, copy, and distribute the software +- Modify and create derivative works +- Use for commercial and non-commercial purposes + +Subject to the conditions of attribution and license preservation as specified in the Apache 2.0 terms. + +--- + +## Contact + +For inquiries, collaboration proposals, or technical discussions regarding Yuuki, please submit an Issue or initiate a Discussion in this repository. + +--- + +
+ +**Developed by [OpceanAI](https://huggingface.co/OpceanAI)** + +
#!/usr/bin/env python3
"""
train_yuuki_refined.py

Refined training script for Yuuki (mobile-first).

Features:
- Detects whether CUDA is available and adjusts some args.
- Skips tokenization if dataset already tokenized (caches arrows).
- ResourceMonitorCallback: autosaves when RAM or CPU thresholds hit.
- Autosave every SAVE_STEPS and keeps only SAVE_TOTAL_LIMIT checkpoints.
- Pretty progress bar (tqdm) and lightweight logging.
- Graceful SIGINT handler that forces a manual save before exiting.

Designed for Termux / mobile but also works on desktop (it will auto-detect device).

Run: python train_yuuki_refined.py

Note: optional dependencies: psutil and tqdm. Install if missing: pip install psutil tqdm
"""

import logging
import math
import os
import shutil
import signal
import sys
from pathlib import Path

# Optional imports — both are conveniences; the script degrades gracefully
# when either is missing.
try:
    import psutil
except Exception:
    psutil = None

try:
    from tqdm import tqdm
except Exception:
    tqdm = None

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# -----------------------
# Config (tweak for your device).  Every knob is an environment variable so
# the same script runs unchanged on Termux, Colab, Kaggle or a desktop.
# -----------------------

MODEL_NAME = os.environ.get("MODEL_NAME", "distilgpt2")
DATASET_ID = os.environ.get("DATASET_ID", "bigcode/the-stack-smol-xl")
SPLIT = os.environ.get("SPLIT", "train")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./yuuki_model")
TOKENIZED_CACHE_DIR = os.environ.get("TOKENIZED_CACHE_DIR", os.path.expanduser("~/yuuki/tokenized_cache"))
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "256"))  # v0.1 — roughly 4x faster than 512
EPOCHS = int(os.environ.get("EPOCHS", "2"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
GRADIENT_ACCUMULATION = int(os.environ.get("GRADIENT_ACCUMULATION", "4"))  # reduced from 8 to 4 for faster optimizer steps

# Effective batch will be printed at start: BATCH_SIZE * GRADIENT_ACCUMULATION

# Autosave frequently: on mobile each step is slow and interruptions are common.
SAVE_STEPS = int(os.environ.get("SAVE_STEPS", "50"))  # checkpoint every 50 steps (very frequent, for mobile)
SAVE_TOTAL_LIMIT = int(os.environ.get("SAVE_TOTAL_LIMIT", "5"))  # keep only the newest 5 checkpoints (older ones auto-deleted)
LOGGING_STEPS = int(os.environ.get("LOGGING_STEPS", "10"))  # log every 10 steps

# Resource thresholds for ResourceMonitorCallback.
CHECK_RESOURCES_EVERY_N_STEPS = int(os.environ.get("CHECK_RESOURCES_EVERY_N_STEPS", "50"))
MEMORY_THRESHOLD = float(os.environ.get("MEMORY_THRESHOLD", "0.12"))  # minimum fraction of RAM that must remain available
CPU_THRESHOLD = int(os.environ.get("CPU_THRESHOLD", "95"))  # percent CPU at which we checkpoint

# Map batch for tokenization (reduce if memory issues)
MAP_BATCH_SIZE = int(os.environ.get("MAP_BATCH_SIZE", "128"))

# Safety limits (don't change unless you know what you do).
# NOTE(review): this constant is currently informational only — nothing
# enforces it; confirm whether it should gate the ResourceMonitorCallback.
MIN_FREE_RAM_MB = 80  # try to keep at least this free RAM

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger("train_yuuki")

# -----------------------
# Utility helpers
# -----------------------

def has_cuda():
    """Return True when torch is importable and reports a usable CUDA device."""
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        # torch missing or broken — treat as CPU-only.
        return False

def human_size(num_bytes: int) -> str:
    """Format a byte count as a short human-readable string (e.g. '3.2MB')."""
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if abs(num_bytes) < 1024.0:
            return f"{num_bytes:3.1f}{unit}"
        num_bytes /= 1024.0
    return f"{num_bytes:.1f}PB"

def _list_checkpoints(output_dir):
    """Return [(step, path)] for every 'checkpoint-<step>' dir, newest first.

    Shared by get_last_checkpoint() and cleanup_old_checkpoints() so the
    directory-scanning logic lives in one place.
    """
    found = []
    if not os.path.exists(output_dir):
        return found
    for entry in os.listdir(output_dir):
        path = os.path.join(output_dir, entry)
        if os.path.isdir(path) and entry.startswith("checkpoint-"):
            try:
                found.append((int(entry.split("-")[-1]), path))
            except ValueError:
                # Not a numbered checkpoint directory — ignore it.
                continue
    found.sort(key=lambda item: item[0], reverse=True)
    return found

def get_last_checkpoint(output_dir):
    """Find the most recent checkpoint directory in output_dir (None if absent)."""
    checkpoints = _list_checkpoints(output_dir)
    return checkpoints[0][1] if checkpoints else None

def cleanup_old_checkpoints(output_dir, max_checkpoints=5):
    """Ensure we never have more than max_checkpoints. Delete oldest ones if exceeded."""
    for _step_num, checkpoint_path in _list_checkpoints(output_dir)[max_checkpoints:]:
        logger.info(f"Eliminando checkpoint antiguo: {checkpoint_path}")
        try:
            shutil.rmtree(checkpoint_path)
        except Exception as e:
            logger.warning(f"No se pudo eliminar {checkpoint_path}: {e}")

# -----------------------
# Callbacks
# -----------------------

class ResourceMonitorCallback(TrainerCallback):
    """Checks memory and CPU every N steps and requests a checkpoint save when
    thresholds are exceeded (low available RAM or very high CPU load)."""

    def __init__(self, check_every_n=50, mem_threshold=0.12, cpu_threshold=95):
        self.check_every_n = check_every_n
        self.mem_threshold = mem_threshold  # minimum fraction of RAM that must stay free
        self.cpu_threshold = cpu_threshold  # percent CPU at which we ask for a save
        self._step = 0
        self.psutil = psutil  # None when psutil is not installed → checks are skipped

    def on_step_end(self, args, state, control, **kwargs):
        self._step += 1
        if self._step % self.check_every_n != 0 or self.psutil is None:
            return control

        try:
            vm = self.psutil.virtual_memory()
            avail_frac = vm.available / vm.total if vm.total else 1.0
            cpu = int(self.psutil.cpu_percent(interval=None))
            logger.debug(f"Resource check @ step {state.global_step}: avail_frac={avail_frac:.2f}, cpu={cpu}%")
            if avail_frac < self.mem_threshold or cpu >= self.cpu_threshold:
                logger.info(f"Resource threshold exceeded (mem={avail_frac:.2f}, cpu={cpu}%). Requesting save.")
                # Asking the Trainer to save is the whole point: checkpoint
                # before the OS kills us for memory pressure.
                control.should_save = True
                control.should_log = True
        except (PermissionError, OSError) as e:
            # Termux may not have permission to read /proc/stat — skip the check.
            logger.debug(f"Cannot check resources (permission denied): {e}")

        return control

class TqdmProgressCallback(TrainerCallback):
    """Simple TQDM progress visualizer that prints loss when available.

    Falls back to basic logging if tqdm is not installed.
    """

    def __init__(self, total_steps=None):
        self.total_steps = total_steps  # fallback total when state.max_steps is unset
        self.pbar = None
        self._last_log_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        if tqdm is None:
            logger.info("tqdm not available — using default logs")
            return
        total = int(state.max_steps) if state.max_steps is not None and state.max_steps > 0 else self.total_steps
        self.pbar = tqdm(total=total, desc="Training", unit="step")

    def on_step_end(self, args, state, control, **kwargs):
        if not self.pbar:
            return
        # Jump the bar straight to the trainer's authoritative global step.
        self.pbar.n = int(state.global_step)
        # Show the most recently logged loss, if any has been recorded yet.
        last = None
        if getattr(state, "log_history", None):
            last = next((entry["loss"] for entry in reversed(state.log_history) if "loss" in entry), None)
        self.pbar.set_postfix({"loss": f"{last:.4f}" if last is not None else "-"})
        self.pbar.refresh()

    def on_train_end(self, args, state, control, **kwargs):
        if self.pbar:
            self.pbar.close()

class CheckpointCleanupCallback(TrainerCallback):
    """Cleans up old checkpoints after each save to maintain max limit."""

    def __init__(self, output_dir, max_checkpoints=5):
        self.output_dir = output_dir
        self.max_checkpoints = max_checkpoints

    def on_save(self, args, state, control, **kwargs):
        """Called after a checkpoint is saved; prune anything beyond the limit."""
        cleanup_old_checkpoints(self.output_dir, max_checkpoints=self.max_checkpoints)
        return control

# -----------------------
# Signal handler for graceful save
# -----------------------

_save_requested = False

def _signal_handler(sig, frame):
    """SIGINT handler: log, flag the save, then re-raise KeyboardInterrupt.

    BUGFIX: the previous version only set the (never-read) flag.  Installing
    a custom SIGINT handler suppresses Python's default KeyboardInterrupt, so
    the graceful-save `except KeyboardInterrupt` block in main() could never
    run and Ctrl-C silently did nothing.  Raising KeyboardInterrupt here
    restores interruption while keeping the warning message.
    """
    global _save_requested
    logger.warning("SIGINT received — will request a graceful save and stop after current step.")
    _save_requested = True
    raise KeyboardInterrupt

signal.signal(signal.SIGINT, _signal_handler)

# -----------------------
# Main
# -----------------------

def main():
    """End-to-end training: load data, tokenize (with cache), train, save."""
    device = "cuda" if has_cuda() else "cpu"
    logger.info(f"Device detected: {device}")

    if device == "cpu":
        logger.warning("⚠️ Entrenando en CPU - esto será MUY LENTO")
        logger.warning("⚠️ Considera reducir MAX_LENGTH o usar un modelo más pequeño")
        logger.warning("⚠️ Tiempo estimado: ~5 minutos por paso")

    effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION
    logger.info(f"Per-device batch_size={BATCH_SIZE}, gradient_accumulation_steps={GRADIENT_ACCUMULATION}, effective batch={effective_batch}")

    # Create tokenized cache directory if it doesn't exist.
    os.makedirs(TOKENIZED_CACHE_DIR, exist_ok=True)
    logger.info(f"Cache de tokenización: {TOKENIZED_CACHE_DIR}")

    # Check for an existing checkpoint to resume from.
    last_checkpoint = get_last_checkpoint(OUTPUT_DIR)
    if last_checkpoint:
        logger.info(f"¡Checkpoint encontrado! Reanudando desde: {last_checkpoint}")
    else:
        logger.info("No se encontró checkpoint previo. Iniciando entrenamiento desde cero.")

    # Prune any checkpoints beyond the retention limit before we start.
    cleanup_old_checkpoints(OUTPUT_DIR, max_checkpoints=SAVE_TOTAL_LIMIT)

    # Load dataset (HF caches it on disk after the first download).
    logger.info("Cargando dataset (puede tardar)...")
    dataset = load_dataset(DATASET_ID, split=SPLIT)

    # If the dataset already carries input_ids we can skip tokenization.
    tokenized_already = "input_ids" in dataset.column_names

    logger.info("Cargando tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        # GPT-2-family tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(batch):
        # Prefer the conventional code columns; fall back to the first column.
        key = "code" if "code" in batch else ("content" if "content" in batch else list(batch.keys())[0])
        toks = tokenizer(batch[key], truncation=True, padding="max_length", max_length=MAX_LENGTH)
        toks["labels"] = toks["input_ids"].copy()  # causal LM: labels mirror inputs
        return toks

    if not tokenized_already:
        logger.info("Tokenizando dataset (esto puede tardar; usa batched=True)...")
        dataset = dataset.map(
            tokenize_function,
            batched=True,
            batch_size=MAP_BATCH_SIZE,
            remove_columns=dataset.column_names,  # drop raw text columns
            cache_file_name=os.path.join(TOKENIZED_CACHE_DIR, "tokenized_dataset.arrow"),
        )
    else:
        logger.info("Dataset ya tokenizado — saltando tokenización.")

    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    logger.info("Cargando modelo...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    try:
        model.gradient_checkpointing_enable()  # trade compute for memory
    except Exception:
        logger.debug("Gradient checkpointing not available for this model")

    # Collator for causal LM (mlm=False → no token masking).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Estimate total optimizer steps (used only for the progress bar).
    try:
        total_examples = len(dataset)
        steps_per_epoch = math.ceil(total_examples / (BATCH_SIZE * GRADIENT_ACCUMULATION))
        max_steps = steps_per_epoch * EPOCHS
        logger.info(f"Total examples: {total_examples}, steps/epoch: {steps_per_epoch}, max_steps: {max_steps}")
    except Exception:
        steps_per_epoch = None
        max_steps = None

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=False,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        fp16=(device == "cuda"),  # mixed precision only makes sense on GPU
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=SAVE_TOTAL_LIMIT,
        logging_steps=LOGGING_STEPS,
        dataloader_num_workers=0,  # worker processes are pure overhead on mobile
        dataloader_pin_memory=False,  # pinning is useless (or harmful) on CPU
        remove_unused_columns=False,
        report_to=[],
        resume_from_checkpoint=last_checkpoint,  # auto-resume from last checkpoint
        # Extra mobile-oriented settings:
        gradient_checkpointing=True,
        optim="adamw_torch",
        max_grad_norm=1.0,
    )

    resource_cb = ResourceMonitorCallback(
        check_every_n=CHECK_RESOURCES_EVERY_N_STEPS,
        mem_threshold=MEMORY_THRESHOLD,
        cpu_threshold=CPU_THRESHOLD,
    )
    progress_cb = TqdmProgressCallback(total_steps=max_steps)
    cleanup_cb = CheckpointCleanupCallback(output_dir=OUTPUT_DIR, max_checkpoints=SAVE_TOTAL_LIMIT)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
        callbacks=[resource_cb, progress_cb, cleanup_cb],
    )

    if last_checkpoint:
        logger.info(f"Inicio de entrenamiento — {EPOCHS} epochs configuradas (reanudando desde checkpoint).")
    else:
        logger.info(f"Inicio de entrenamiento — {EPOCHS} epochs configuradas.")

    def manual_save(tag: str = None):
        """Save model + tokenizer to OUTPUT_DIR (or a 'manual-<tag>' subdir)."""
        dest = OUTPUT_DIR if not tag else os.path.join(OUTPUT_DIR, f"manual-{tag}")
        logger.info(f"Guardando modelo manualmente en: {dest}")
        trainer.save_model(dest)
        tokenizer.save_pretrained(dest)

    try:
        # SIGINT now unwinds here as KeyboardInterrupt (see _signal_handler),
        # so the manual-save path below actually executes on Ctrl-C.
        trainer.train(resume_from_checkpoint=last_checkpoint)
    except KeyboardInterrupt:
        logger.warning("Interrupción por el usuario detectada. Guardando checkpoint...")
        try:
            manual_save(tag=f"step{trainer.state.global_step}")
        except Exception as e:
            logger.exception("Error al guardar el checkpoint durante interrupción: %s", e)
        logger.info("Guardado finalizado. Saliendo.")
        sys.exit(0)

    # Final save after a full, uninterrupted run.
    logger.info("Entrenamiento terminado. Guardando modelo final...")
    manual_save(tag="final")
    logger.info("Entrenamiento completado. Modelo guardado en: %s", OUTPUT_DIR)

if __name__ == "__main__":
    main()