La evolución de los modelos de lenguaje grandes (LLM) ha expuesto una divergencia fundamental en los requisitos de hardware para sus dos fases principales: entrenamiento e inferencia. Mientras que el entrenamiento demanda un throughput masivo de cómputo y ancho de banda de memoria para procesar vastos datasets y actualizar pesos de modelos, la inferencia, especialmente con arquitecturas como Mixture-of-Experts (MoE), se vuelve intensiva en latencia y ancho de banda de memoria para el streaming de pesos y la gestión de caches de clave-valor. Este artículo explora cómo Google aborda este problema de optimización dual con su octava generación de Tensor Processing Units (TPU), la TPU 8t para entrenamiento y la TPU 8i para inferencia, integrándolas en un ecosistema de cómputo y red diseñado para la escala de hyperscaler.
La necesidad de esta especialización surge de la ley de Amdahl aplicada a los cuellos de botella específicos de cada fase. En el entrenamiento, la capacidad de escalar miles de aceleradores con mínima degradación de rendimiento es crítica, lo que exige redes de interconexión de baja latencia y alto ancho de banda. Para la inferencia, la naturaleza auto-regresiva de la generación de tokens hace que la latencia de acceso a memoria y la eficiencia de las comunicaciones colectivas sean los factores dominantes. Google, al igual que otros proveedores de nube, reconoce que una solución 'one-size-fits-all' es subóptima, y que la especialización a nivel de chip y de infraestructura es esencial para la eficiencia económica y el rendimiento a escala.
Históricamente, los aceleradores de propósito general como las GPUs han servido para ambas tareas, pero la creciente complejidad y tamaño de los LLM empujan los límites de esta aproximación. La estrategia de Google con TPU 8 representa una maduración en el diseño de hardware de IA, donde la co-diseño de chips, interconexiones y software se convierte en el estándar para superar los desafíos de rendimiento y costo en la era de la IA generativa.
Arquitectura del Sistema
La arquitectura de Google para TPU 8 se bifurca en dos diseños específicos: TPU 8t para entrenamiento y TPU 8i para inferencia. Ambas variantes se integran con CPUs Arm-based Axion, reemplazando la dependencia de x86 como host, una decisión que busca optimizar la eficiencia energética y el rendimiento general del sistema. Esta integración es fundamental para la plataforma 'AI Hypercomputer' de Google.
La TPU 8t, orientada al entrenamiento, optimiza la mezcla de aceleradores de vector, multiplicación de matrices y SparseCore para embeddings, maximizando el throughput de punto flotante. Cada chip incorpora 216 GB de HBM con 6.5 TB/s de ancho de banda y 128 MB de SRAM on-chip. La escalabilidad masiva se logra mediante Optical-Circuit Switches que conectan hasta 9,600 aceleradores en un solo 'pod', y múltiples pods se interconectan a través de la red Virgo. Esta red utiliza switches de paquetes con alta densidad de puertos en una topología 'flat-ish two-tier all-to-all', capaz de conectar hasta 134,000 TPUs por datacenter y hasta un millón globalmente. Un sistema de almacenamiento Managed Lustre complementa esta arquitectura, entregando 10 TB/s de datos directamente a la memoria de los aceleradores, crucial para mantener un 'goodput' del 97% en el entrenamiento.
La TPU 8i, diseñada para inferencia, prioriza la latencia y el ancho de banda de memoria. Intercambia FLOPS por una SRAM on-chip más grande (384 MB) y 288 GB de HBM con 8.6 TB/s de ancho de banda. La SRAM adicional es clave para mantener el Key-Value Cache de los modelos residente en el chip, reduciendo la latencia de acceso a datos. La TPU 8i reemplaza los SparseCores por un Collective Acceleration Engine (CAE) que descarga y acelera operaciones de comunicación colectiva como all-reduce y all-gather. Esto es vital para arquitecturas MoE, donde la comunicación entre expertos distribuidos puede ser impredecible y un cuello de botella significativo. La topología de red Boardfly, similar a Dragonfly, conecta 1,152 chips usando Optical Circuit Switches, reduciendo la latencia máxima chip-a-chip de 16 a 7 saltos, un factor crítico para modelos MoE y de razonamiento.
Flujo de Entrenamiento Distribuido con TPU 8t
- 1 Managed Lustre Entrega 10 TB/s de datos de entrenamiento.
- 2 TPU 8t (Axion Host) Procesa datos, realiza multiplicación de matrices y embeddings.
- 3 Optical-Circuit Switches Interconecta hasta 9,600 TPUs en un pod.
- 4 Virgo Network Conecta múltiples pods en una topología all-to-all.
- 5 Sincronización Global Actualización de pesos del modelo a través de miles de TPUs.
- 6 Checkpointing Almacenamiento de estado del modelo para RAS.
Flujo de Inferencia de LLM con TPU 8i (MoE)
- 1 Solicitud de Inferencia Entrada de token al modelo.
- 2 TPU 8i (Axion Host) Procesamiento inicial, acceso a Key-Value Cache en SRAM.
- 3 Boardfly Network Interconexión de 1,152 chips con 7 saltos max.
- 4 Collective Acceleration Engine Offload de comunicaciones colectivas (all-reduce, all-gather) entre expertos.
- 5 Activación de Expertos Selección y cómputo en submodelos distribuidos.
- 6 Generación de Token Salida del token inferido.
| Capa | Tecnología | Justificación |
|---|---|---|
| compute | TPU 8t | Acelerador de IA optimizado para entrenamiento de modelos a gran escala, con énfasis en throughput de punto flotante y ancho de banda de memoria. vs Nvidia Rubin GPUs, AWS Trainium 216 GB HBM (6.5 TB/s), 128 MB SRAM, SparseCore, 12.6 petaFLOPS FP4 |
| compute | TPU 8i | Acelerador de IA optimizado para inferencia de LLM, priorizando baja latencia, gran SRAM on-chip para Key-Value Cache y eficiencia en comunicaciones colectivas. vs Nvidia Blackwell GPUs, AWS Inferentia 288 GB HBM (8.6 TB/s), 384 MB SRAM, Collective Acceleration Engine (CAE), 10.1 petaFLOPS FP4 |
| compute | Google Axion CPUs | CPUs Arm-based que actúan como host para las TPUs, reemplazando x86 para mejorar la eficiencia y el rendimiento del sistema completo. vs x86 CPUs (Intel/AMD), AWS Graviton |
| networking | Optical-Circuit Switches | Tecnología de conmutación opto-mecánica para interconectar hasta 9,600 TPUs en un pod unificado para entrenamiento, y 1,152 chips en Boardfly para inferencia, reduciendo la latencia de saltos. vs Ethernet, InfiniBand, NVLink Conexión de hasta 9,600 TPUs en un pod (8t), 1,152 chips en Boardfly (8i) |
| networking | Virgo Network | Red de interconexión para unir múltiples pods de TPU 8t, utilizando switches de paquetes de alta densidad en una topología 'flat-ish two-tier all-to-all' para escalar a millones de TPUs. vs Redes de interconexión HPC tradicionales (e.g., Fat-Tree, Torus) Hasta 134,000 TPUs por datacenter, hasta 1 millón globalmente |
| networking | Boardfly Topology | Topología de red optimizada para inferencia con TPU 8i, similar a Dragonfly, que reduce la latencia chip-a-chip a 7 saltos para modelos MoE. vs 3D Torus, Packet switched fabric (AWS) 1,152 chips conectados, 7 saltos max de latencia |
| storage | Managed Lustre | Sistema de almacenamiento de alto rendimiento capaz de entregar 10 TB/s de datos directamente a la memoria de los aceleradores para el entrenamiento. vs GPFS, Ceph, NFS distribuido 10 TB/s de throughput agregado |
Trade-offs
Ganancias
- ▲ Rendimiento de entrenamiento
- ▲ Rendimiento/costo de inferencia
- ▲▲ Escalabilidad de clústeres de entrenamiento
- ▲ Latencia chip-a-chip para inferencia MoE
- ▲ Eficiencia de comunicaciones colectivas
Costes
- △ Complejidad de la infraestructura
- △ Costo de desarrollo de hardware especializado
Fundamentos Teóricos
La especialización de hardware para cargas de trabajo de IA resuena con principios fundamentales de la arquitectura de computadoras y la teoría de algoritmos. La distinción entre TPU 8t (entrenamiento) y TPU 8i (inferencia) es un ejemplo clásico de la ley de Amdahl, que postula que la mejora máxima de un sistema está limitada por la fracción de tiempo que el componente mejorado es utilizado. En este caso, los cuellos de botella para entrenamiento (throughput de cómputo y ancho de banda de red para sincronización masiva) difieren de los de inferencia (latencia de memoria para acceso a pesos y eficiencia de comunicación colectiva para MoE).
La adopción de topologías de red como Boardfly (similar a Dragonfly) y la red Virgo para interconexión masiva de TPUs se basa en décadas de investigación en redes de interconexión para supercomputación y HPC. Trabajos seminales como los de Dally y Seitz sobre redes de interconexión (1987) o los estudios sobre topologías como Fat-Tree y Dragonfly (Kim et al., 2008) han explorado cómo minimizar la latencia y maximizar el ancho de banda agregado en sistemas distribuidos a gran escala. La elección de Optical-Circuit Switches, que operan más como un conmutador telefónico que un conmutador de paquetes, sugiere una aproximación a la conmutación de circuitos o 'circuit switching' para garantizar rutas dedicadas y predecibles, un concepto explorado en la década de 1970 para redes telefónicas y ahora re-emergente en el contexto de las redes de centros de datos de IA para reducir la variabilidad de la latencia (jitter) y mejorar el 'goodput'.
Finalmente, la optimización de la gestión de memoria con SRAM on-chip para el Key-Value Cache en la TPU 8i y la aceleración de comunicaciones colectivas con el CAE, se conecta con la investigación en jerarquías de memoria y algoritmos de comunicación paralela. La importancia de la localidad de datos y la minimización de movimientos de datos entre diferentes niveles de memoria es un pilar de la optimización de rendimiento, como se discute en trabajos sobre el 'memory wall' y el diseño de caches. Los algoritmos de comunicación colectiva (all-reduce, all-gather) son fundamentales en el paralelismo de datos y modelos, y su optimización es un campo activo de investigación en computación de alto rendimiento, con papers que datan de los inicios de la computación paralela masiva.