如何优雅地为 TensorRT-LLM 添加新模型

本文涉及的产品
多模态交互后付费免费试用,全链路、全Agent
简介: 本指南详细介绍如何在TensorRT-LLM中优雅集成新大语言模型,涵盖模型配置、定义、权重加载与注册全流程,支持作为核心模块或独立扩展集成,助力高效推理部署。(238字)

目录导航

开篇寄语 {#开篇寄语}

本指南通过分步骤的方式,帮助你在 TensorRT-LLM 的 PyTorch 后端中集成一个新的大语言模型。

📍 代码路径说明
本文档中涉及的所有代码路径均为 TensorRT-LLM 仓库内的相对路径。核心模块主要分布在以下目录:

  • 模型定义: tensorrt_llm/_torch/models/
  • 核心模块: tensorrt_llm/_torch/modules/
  • 配置文件: tensorrt_llm/_torch/model_config.pytensorrt_llm/_torch/attention_backend.py
  • 示例代码: examples/pytorch/out_of_tree_example/

如需查找具体文件,请在 TensorRT-LLM 仓库 中按照文中提及的路径定位。

准备工作 {#出发前的准备}

开始前请确保环境已正确配置:

  • 已安装 TensorRT-LLM。可参考安装指南完成安装。

分步实战指南 {#分步实战指南}

模型配置定义 {#模型配置}

首先需要为新模型(记为 MyModel)定义配置类。

情况一:模型已在 HuggingFace 注册

如果模型已在 HuggingFace 的 transformers 库中,可以直接复用其配置类。例如 tensorrt_llm/_torch/models/modeling_llama.py 中就直接继承了 HuggingFace 的 LlamaConfig

# 路径: tensorrt_llm/_torch/models/modeling_llama.py
from transformers import LlamaConfig

情况二:模型尚未在 HuggingFace 注册

需要自己实现配置类,放在 configuration_mymodel.py 中。可参考 HuggingFace 的 configuration_llama.py 的实现方式:

from transformers.configuration_utils import PretrainedConfig

class MyConfig(PretrainedConfig):
    def __init__(self, ...):
        ...

模型定义 {#模型定义}

接下来需要实现推理模型。删除训练阶段的冗余代码,专注于推理路径的 PyTorch 模块实现。对于标准的 Transformer 解码器模型,modeling_mymodel.py 的结构如下:

from typing import Optional

import torch
from torch import nn
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModel, DecoderModelForCausalLM
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer

from configuration_mymodel import MyConfig


class MyAttention(Attention):
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
        # 使用 model_config 初始化注意力模块
        super().__init__(...)


class MyDecoderLayer(DecoderLayer):
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        # 使用 model_config 初始化子模块
        self.input_layernorm = ...
        self.self_attn = MyAttention(model_config, layer_idx)
        self.post_attention_layernorm = ...
        self.mlp = ...

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
        # 定义解码器单层的前向计算
        ...


class MyModel(DecoderModel):
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(model_config)
        # 使用 model_config 初始化子模块
        self.embed_tokens = ...
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, layer_idx) for layer_idx in range(model_config.pretrained_config.num_hidden_layers)
        ])

    def forward(self,
                attn_metadata: AttentionMetadata,
                input_ids: Optional[torch.IntTensor] = None,
                position_ids: Optional[torch.IntTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None):
        # 定义模型的前向计算
        ...


class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(MyModel(model_config),
                         config=model_config,
                         hidden_size=model_config.pretrained_config.hidden_size,
                         vocab_size=model_config.pretrained_config.vocab_size)

关键设计要点:

MyAttention 继承自 TensorRT-LLM 提供的 Attention 基类(位于 tensorrt_llm/_torch/modules/attention.py)。这样可以保证注意力计算与 PyTorch 运行时完全兼容。同时需要适配以下输入:

  • attn_metadata:存储批次元数据和 KV 缓存信息,由运行时创建并传递,需确保正确传递给注意力层。
  • 输入张量(input_idsposition_idshidden_states 等):采用压缩打包格式,第一个维度表示整个批次中的令牌总数。

MyDecoderLayerMyModelMyModelForCausalLM 分别继承对应的基类 DecoderLayerDecoderModelDecoderModelForCausalLM,这些基类定义了标准接口并提供了模型层、权重加载等通用实现。

性能优化选项:

可用 TensorRT-LLM 优化的模块替换原生 PyTorch 实现,以获得更好的性能和新增功能:

  • Linear(位于 tensorrt_llm/_torch/modules/linear.py):支持张量并行和量化。
  • Embedding(位于 tensorrt_llm/_torch/modules/embedding.py):为嵌入层带来张量并行。
  • RotaryEmbedding(位于 tensorrt_llm/_torch/modules/rotary_embedding.py):高性能的旋转嵌入实现。
  • RMSNorm(位于 tensorrt_llm/_torch/modules/rms_norm.py):高性能的 RMS 规范化。

想看个活生生的例子?去看看 tensorrt_llm/_torch/models/modeling_llama.py 吧,那里有完整的参考实现。

权重加载 {#权重加载}

基类 DecoderModelForCausalLM 提供了默认的 load_weights 方法,可从检查点文件加载权重并分配到各层。若默认实现不适配你的模型,可按以下方式自定义:

class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        # 定义权重加载的逻辑
        ...

以 LLaMA 为例说明。HuggingFace 的 LLaMA 在注意力部分使用三个独立的线性层处理 Q/K/V 投影,检查点中对应三个权重张量:

>>> weights
{
   
    ...,
    "model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    ...,
}

但 TensorRT-LLM 的 LLaMA 实现采用了融合策略,将三个线性层合并为一个:

>>> llama.model.layers[0].self_attn.qkv_proj.weight.data
torch.Tensor([hidden_size * 3, hidden_size])

因此 load_weights 需要从原始检查点收集这三个权重张量,将其拼接后加载到融合的线性层中。考虑张量并行和量化等高级特性,建议在模型级 load_weights 中充分利用模块级的 load_weights 实现(如 LinearEmbedding)。

核心目标是实现检查点权重到模型结构的映射,确保模型的前向计算与原始模型等价。

模型注册 {#模型注册}

模型定义完成后,需要将其注册到 TensorRT-LLM 系统中,使用 register_auto_model 装饰器:

from tensorrt_llm._torch.models.modeling_utils import register_auto_model

@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    def __init__(self, model_config: ModelConfig[MyConfig]):
       ...

作为核心模型集成 {#核心模型}

modeling_mymodel.py(及 configuration_mymodel.py)放在 tensorrt_llm/_torch/models 目录中,并在 tensorrt_llm/_torch/models/__init__.py 中导入:

from .modeling_mymodel import MyModelForCausalLM

__all__ = [
    ...,
    "MyModelForCausalLM",
]

作为独立扩展集成 {#独立扩展模型}

另一个选项是以独立扩展的形式注册模型,无需修改 TensorRT-LLM 源码。将 modeling_mymodel.py(及 configuration_mymodel.py)放在工作目录中,在脚本中导入即可:

from tensorrt_llm import LLM
import modeling_mymodel

def main():
    llm = LLM(...)

if __name__ == '__main__':
    main()

完整的独立扩展示例可参考 examples/pytorch/out_of_tree_example 目录,其中 modeling_opt.py 提供了参考实现,运行方式如下:

python examples/pytorch/out_of_tree_example/main.py
目录
相关文章
|
6天前
|
存储 缓存 负载均衡
TensorRT LLM 中的并行策略
TensorRT LLM提供多种GPU并行策略,支持大模型在显存与性能受限时的高效部署。涵盖张量、流水线、数据、专家及上下文并行,并推出宽专家并行(Wide-EP)应对大规模MoE模型的负载不均与通信挑战,结合智能负载均衡与优化通信核心,提升推理效率与可扩展性。
273 154
|
6天前
|
并行计算 测试技术 异构计算
Qwen3 Next 在 TensorRT LLM 上的部署指南
本指南介绍如何在TensorRT LLM框架上部署Qwen3-Next-80B-A3B-Thinking模型,基于默认配置实现快速部署。涵盖环境准备、Docker容器启动、服务器配置与性能测试,支持BF16精度及MoE模型优化,适用于NVIDIA Hopper/Blackwell架构GPU。
284 154
|
6天前
|
缓存 PyTorch API
TensorRT-LLM 推理服务实战指南
`trtllm-serve` 是 TensorRT-LLM 官方推理服务工具,支持一键部署兼容 OpenAI API 的生产级服务,提供模型查询、文本与对话补全等接口,并兼容多模态及分布式部署,助力高效推理。
249 155
|
22天前
|
机器学习/深度学习 算法 前端开发
别再用均值填充了!MICE算法教你正确处理缺失数据
MICE是一种基于迭代链式方程的缺失值插补方法,通过构建后验分布并生成多个完整数据集,有效量化不确定性。相比简单填补,MICE利用变量间复杂关系,提升插补准确性,适用于多变量关联、缺失率高的场景。本文结合PMM与线性回归,详解其机制并对比效果,验证其在统计推断中的优势。
516 11
别再用均值填充了!MICE算法教你正确处理缺失数据
|
16天前
|
缓存 物联网 PyTorch
使用TensorRT LLM构建和运行Qwen模型
本文档介绍如何在单GPU和单节点多GPU上使用TensorRT LLM构建和运行Qwen模型,涵盖模型转换、引擎构建、量化推理及LoRA微调等操作,并提供详细的代码示例与支持矩阵。
208 2
|
14天前
|
SQL 人工智能 关系型数据库
AI Agent的未来之争:任务规划,该由人主导还是AI自主?——阿里云RDS AI助手的最佳实践
AI Agent的规划能力需权衡自主与人工。阿里云RDS AI助手实践表明:开放场景可由大模型自主规划,高频垂直场景则宜采用人工SOP驱动,结合案例库与混合架构,实现稳定、可解释的企业级应用,推动AI从“能聊”走向“能用”。
474 33
AI Agent的未来之争:任务规划,该由人主导还是AI自主?——阿里云RDS AI助手的最佳实践
|
17天前
|
SQL 关系型数据库 MySQL
开源新发布|PolarDB-X v2.4.2开源生态适配升级
PolarDB-X v2.4.2开源发布,重点完善生态能力:新增客户端驱动、开源polardbx-proxy组件,支持读写分离与高可用;强化DDL变更、扩缩容等运维能力,并兼容MySQL主备复制及MCP AI生态。
开源新发布|PolarDB-X v2.4.2开源生态适配升级
|
6天前
|
存储 弹性计算 固态存储
阿里云新用户优惠:个人、学生和企业购买云服务器配置价格整理
2025阿里云服务器配置全解析:个人用户选200M轻量服务器,68元/年起;企业选2核4G ECS,199元/年,续费同价。详解CPU、内存、带宽及实例类型选择,助力高效上云。
126 9