Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs — ROCm Blogs

摘要

本文提供了一份详细的指南,介绍如何在JAX实现的生成AI模型中测量和比较各种算法的性能。利用JAX Profiler和统计分析,本文展示了如何可靠地评估关键步骤并比较AMD GPU上算法的性能。

引言

在GPU加速计算的动态领域,追求最佳性能和效率需要有效的性能分析技术。性能分析通过仔细检查执行时间、内存利用率和内核占用率等指标,提供了对基于GPU的应用程序行为和性能特征的全面了解。这对于大规模生成AI模型尤为重要,因为优化性能可以显著提升最终用户体验和收入来源。通过利用性能分析技术,开发人员可以找出低效之处,深入了解运行时行为,并最终优先考虑战略性优化工作,从而带来显著的性能提升。
JAX是谷歌的一款开源数值计算库(尽管不是官方的谷歌产品),由于其能够利用硬件加速器和自动微分的能力,正在生成AI领域引起广泛关注。最初用于高性能机器学习研究,JAX的函数式编程方法和对GPU及TPU的支持使其成为构建和部署大型语言模型(LLMs)和其他前沿生成AI应用的首选。值得注意的是,像 X.AI这样的公司利用JAX开发开源模型如Grok-1,进一步推动了该库在生成AI领域的流行。凭借其性能、灵活性及其适合先进AI模型开发和部署的特点,JAX继续在受欢迎程度上不断攀升。

ROCm博客系列此前已探索过各种性能分析工具,如 *rocprof*,可以用于在AMD GPU上分析模型性能,还有针对TensorFlow和PyTorch的框架特定性能分析工具。尽管JAX的官方页面涵盖了其性能分析工具的基本用法,本教程深入探讨了更高级的技术。例如,它解释了在评估算法时,如何在考虑到大量随机噪声的情况下确定一种算法是否显著优于另一种算法。本文通过统计分析和假设检验,展示了如何可靠地测量和比较在大型语言模型中执行相同步骤的不同算法的性能。具体而言,它比较了在JAX-based生成预训练变换器(GPT)模型的`CausalSelfAttention`组件中,使用`einsum`与`matmul`实现两个矩阵乘法步骤的性能。(参见博客中关于在JAX中实现GPT模型的文章)。要了解更多关于`einsum`的信息,请访问这篇博客。 

实现

要实现此代码示例,请首先设置ROCm环境,并安装必要的软件包和Python脚本。值得注意的是,该代码示例是平台无关的,这意味着只要加速计算平台和Python包配置正确,它就兼容AMD GPU以及其他GPU或TPU。

环境设置

按照以下步骤为本教程设置运行环境:

1. 在Linux shell中使用下面的代码拉取并运行docker容器:

docker run -it --ipc=host --network=host --device=/dev/kfd --device=/dev/dri \
           --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
           --name=nanogpt rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2 /bin/bash

2. 在docker容器内运行以下代码,以安装必要的Python包并配置XLA环境变量:

python3 -m pip install --upgrade pip
pip install optax==0.2.2 flax==0.8.2 transformers==4.38.2 tiktoken==0.6.0 datasets==2.17.1 perfetto==0.7.0 matplotlib==3.8.4 scipy==1.13.0
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/releases/download/jaxlib-v0.4.26/jaxlib-0.4.26+rocm610-cp310-cp310-manylinux2014_x86_64.whl
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/archive/refs/tags/jaxlib-v0.4.26.tar.gz
pip install numpy==1.22.0
export XLA_FLAGS="--xla_gpu_autotune_level=0"

3. 使用以下命令从 ROCm/rocm-blogs GitHub 存储库下载用于该博客的文件。

git clone https://github.com/ROCm/rocm-blogs.git
cd rocm-blogs/blogs/artificial-intelligence/nanoGPT-JAX

4. 将`nanoGPT-JAX`文件夹中的`model.py`和`sample.py`脚本替换为当前博客在GitHub上*src*文件夹中的对应文件。具体参考此链接

特别需要注意的是,对`model.py`文件的修改如下面代码块所示。新添加的两行代码使用`jax.named_scope`为两个矩阵乘法步骤注释唯一名称,这是一个将用户指定名称纳入JAX名称堆栈的上下文管理器。程序随后使用指定名称提取这些步骤的相关性能数据。该技巧对于快速将同类型操作的日志映射到应用程序或模型中的每个步骤非常宝贵,因为默认的日志名称可能会在同类型操作之间非常相似或令人困惑。下面的代码块封装了两个不同的矩阵乘法步骤,并分别为它们指派了不同的范围名称`attn_q_k`和`attn_att_v`。

class CausalSelfAttention(nn.Module):
    config: GPTConfig

    @nn.compact
    def __call__(self, x, train=False, rng1=None, rng2=None):
        assert self.config.n_embd % self.config.n_head == 0
        B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = jnp.split(nn.Dense(self.config.n_embd * 3, name="c_attn")(x), 3, axis=-1)
        k = k.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        q = q.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        v = v.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
+       with jax.named_scope("attn_q_k"):
+           att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
-       att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
        mask = jnp.tril(jnp.ones((T, T))).reshape((1, 1, T, T))
        att = jnp.where(mask == 0, float('-inf'), att)
        att = nn.softmax(att, axis=-1)
        att = nn.Dropout(self.config.dropout, name='attn_dropout', deterministic=not train)(att, rng=rng1)
+       with jax.named_scope("attn_att_v"):
+           y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v)   # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-       y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v)   # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.swapaxes(1, 2).reshape(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = nn.Dense(self.config.n_embd, name='c_proj')(y)
        y = nn.Dropout(self.config.dropout, name='resid_dropout', deterministic=not train)(y, rng=rng2)

        return y

对于`sample.py`文件的主要修改包括使用`jax.profiler.start_trace()`和`jax.profiler.stop_trace()`包裹负责运行基于JAX的GPT模型的推理的函数。这将记录每个生成样本的跟踪信息。或者,您可以使用`jax.profiler.trace()`上下文管理器来捕获跟踪,具体可参见本指南。每个样本的性能分析输出将存储在单独的文件夹中,这使得分析个别跟踪更为方便。

for i in range(num_samples): 
+   jax.profiler.start_trace(profile_dir+f'_{i}')
    output = generate([jnp.array(start_ids)], seed+i)
+   jax.profiler.stop_trace()
    print(f'\nGenerated output __{i}__: \n__________________________________\n{decode(output[0].tolist())}\n__________________________________')

使用不同的矩阵乘法算法对GPT模型进行性能分析

为了演示性能分析,本示例比较了`einsum`和`matmul`这两种在注意力计算步骤中执行矩阵乘法的内置方法。`use_einsum`标志控制了是选择`einsum`还是`matmul`进行矩阵乘法。运行以下命令以收集这两种不同算法的性能分析输出:

# Generate profiling output using matmul
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_matmul"

# Generate profiling output using einsum
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_einsum" --override_args="{'use_einsum':True}"

每条命令调用`sample.py`文件生成10个样本,每个样本包含最多50个新生成的tokens。这会生成20个文件夹(每种算法10个文件夹,每个生成的样本一个文件夹),这些文件夹包含性能分析的输出。在每个文件夹中,性能分析输出存储在一个压缩的`.gz`文件中。在docker终端中运行以下命令来解压输出: 

for i in {0..9}; do
    gzip -d trace_file_einsum_$i/plugins/profile/202*/*.json.gz
    gzip -d trace_file_matmul_$i/plugins/profile/202*/*.json.gz
done

统计分析与两种算法的性能测试

现在,你可以读取剖析数据并进行统计分析。对于每个迭代(对应于每种算法生成的一个样本),程序比较两种算法在矩阵乘法执行时间(以纳秒为单位)分布上的差异。可以使用箱线图来直观地检查差异。Wilcoxon秩和检验(Mann-Whitney U检验)用来确定位置参数(如均值和中位数)是否显著不同。较短的执行时间表示更好的性能。

下面的代码块导入了分析所需的包,并定义了绘制箱线图的函数。

import glob
from perfetto.trace_processor import TraceProcessor
from scipy.stats import ranksums
import matplotlib.pyplot as plt


def plot_boxplot(df1, df2, columns1, columns2=None, df1_lab='matmul', df2_lab='einsum'):
    """
    Plot boxplots for specified columns in two DataFrames. This function will 
    be used to compare the distribution of running time for the two algorithms
    we profiled.

    Args:
    df1 (pandas.DataFrame): First DataFrame.
    df2 (pandas.DataFrame): Second DataFrame.
    columns1 (list): List of column names from the first DataFrame to plot.
    columns2 (list): List of column names from the second DataFrame to plot.
    df1_lab (string): Label for df1 in the plot.
    df2_lab (string): Label for df2 in the plot.
    """
    if columns2 is None:
        columns2 = columns1
    # Combine data from both DataFrames
    data = [df1[col] for col in columns1] + [df2[col] for col in columns2]
    
    # Create labels for boxplots
    labels = [df1_lab + '_' + col for col in columns1] + [df2_lab + '_' + col for col in columns2]
    
    # Plot boxplots
    plt.figure(figsize=(10, 6))
    plt.boxplot(data, labels=labels)
    plt.xlabel('Algorithms')
    plt.ylabel('Time in nanoseconds')
    plt.title('Performance comparison on the scale of nanoseconds')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.show()

程序随后比较了每次样本生成迭代中两种算法的执行时间。它在SQL查询中使用`where display_value like "%attn_q_k%"来过滤在第一个named_scope`中的操作。你可以修改SQL查询以探索不同的列并计算感兴趣的指标。

程序省略了第一次迭代,因为第一次迭代包括编译时间,这会使比较失真。它打印了每种算法的执行时间的均值和标准偏差,以及数据框的形状,以确保剖析器和SQL查询捕获了所有事件。例如,对于包含12层的模型,并且每个样本最多生成50个新的token(导致对模型的最多50次函数调用),应捕获最多`12*50=600`次矩阵乘法事件。

最后,程序打印了Wilcoxon秩和检验的统计量和p值,该检验评估两种算法的执行时间分布的位置参数(如均值和中位数)是否显著不同。尽管t检验广泛用于检验两种总体的均值是否相等,但由于样本中存在许多异常值,因此示例使用秩基非参数检验。这些异常值可能显著降低t检验的可靠性。

for i in range(1, 10):
    # Process the profiling data for matmul
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json'))
    # SQL query to get the operations enclosed by the named_scope
    query_text='''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value, dur
        FROM _slice_with_thread_and_process_info
        INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_q_k%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json'))
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    print(f'###########i={i}###########')
    print('#'*30)
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')

下面是两次迭代的截断输出,所有九次迭代都观察到了相同的模式。

###########i=1###########
##############################
Matmul: Mean=6461.875, std. dev.=504.8818364954699, shape of df:(600, 3)
Einsum: Mean=5813.346666666666, std. dev.=455.80420754410954, shape of df:(600, 3)
Test statistic=20.22982266255362, p_val=5.349499343834845e-91

###########i=2###########
##############################
Matmul: Mean=6293.076666666667, std. dev.=514.1309448993132, shape of df:(600, 3)
Einsum: Mean=5797.615, std. dev.=397.86885546863283, shape of df:(600, 3)
Test statistic=16.932946075063718, p_val=2.5717953759559878e-64

基于结果,可以明显看出,对于矩阵乘法算法`einsum`比`matmul`在计算`query`和`key`矩阵之间的矩阵乘法时显著更快。但对于在`attention`和`value`矩阵之间的矩阵乘法时,`matmul`如何表现呢?结果显示在下面的代码块中:

for i in range(1, 10):
    # Process the profiling data for matmul
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json'))
    # SQL query to get the operations enclosed by the named_scope
    query_text='''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value,dur
        FROM _slice_with_thread_and_process_info
        INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_att_v%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json'))
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    print(f'###########i={i}###########')
    print('#'*30)
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')

下面是两次迭代的截断输出,所有九次迭代都观察到了相同的模式。

###########i=1###########
##############################
Matmul: Mean=5204.543333333333, std. dev.=882.6151202759834, shape of df:(600, 3)
Einsum: Mean=6360.556666666666, std. dev.=373.461514250933, shape of df:(600, 3)
Test statistic=-21.986424230986046, p_val=3.884153635651101e-107

###########i=2###########
##############################
Matmul: Mean=5145.61, std. dev.=876.5247080600369, shape of df:(600, 3)
Einsum: Mean=6396.01, std. dev.=381.7892458942073, shape of df:(600, 3)
Test statistic=-22.450480914300588, p_val=1.2659476932444539e-111

这次,令人惊讶的是,`matmul`显著比`einsum`更快。这表明一种矩阵乘法算法并不总是优于另一种。矩阵的大小、形状和其他操作(如矩阵转置)等因素可能会影响速度。这突显了在应用或模型关键步骤中选择最佳算法时使用剖析技术的重要性。另外,如果你检查同一算法在箱线图中的数据点范围,可能会注意到许多异常值。这就是为什么在得出有效结论时,统计分析和适当的方法是如此重要的原因。本例中也使用了秩基检验而非经典的t检验,因为后者通常对异常值敏感。

总结

在剖析应用或模型性能时应应用稳健的统计分析和测试,以确保随机噪声的影响不会损害我们结论的有效性。 

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐