跳转至

API

single_bar

Functions:

Name Description
plot_one_group_bar_figure

绘制单组柱状图,包含散点、误差条和统计显著性标记。

plot_one_group_violin_figure

绘制单组小提琴图,可选散点叠加、渐变填色和统计显著性标注。

plot_one_group_bar_figure

plot_one_group_bar_figure(data: Sequence[Sequence[Num] | ndarray], ax: Axes | None = None, labels_name: list[str] | None = None, colors: list[str] | None = None, edgecolor: str | None = None, gradient_color: bool = False, colors_start: list[str] | None = None, colors_end: list[str] | None = None, show_dots: bool = True, dots_color: list[list[str]] | None = None, width: Num = 0.5, color_alpha: Num = 1, dots_size: Num = 35, errorbar_type: str = 'sd', title_name: str = '', title_fontsize: Num = 12, title_pad: Num = 10, x_label_name: str = '', x_label_ha: str = 'center', x_label_fontsize: Num = 12, x_tick_fontsize: Num = 12, x_tick_rotation: Num = 0, y_label_name: str = '', y_label_fontsize: Num = 12, y_tick_fontsize: Num = 8, y_tick_rotation: Num = 0, y_lim: tuple[float, float] | None = None, statistic: bool = False, test_method: list[str] = ['ttest_ind'], p_list: list[float] | None = None, popmean: Num = 0, statistical_line_color: str = '0.5', asterisk_fontsize: Num = 10, asterisk_color: str = 'k', y_base: float | None = None, interval: float | None = None, ax_bottom_is_0: bool = False, y_max_tick_is_1: bool = False, math_text: bool = True, one_decimal_place: bool = False, percentage: bool = False) -> Axes | None

绘制单组柱状图,包含散点、误差条和统计显著性标记。

Parameters:

Name Type Description Default
data ndarray | Sequence[Sequence[Num] | ndarray]

输入数据,可以是二维numpy数组或嵌套序列,每个子序列代表一个柱状图的数据点

required
ax Axes | None

matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.

None
labels_name list[str] | None

柱状图的标签名称列表. Defaults to None.

None
colors list[str] | None

柱状图的颜色列表. Defaults to None.

None
edgecolor str | None

柱状图边缘颜色. Defaults to None.

None
gradient_color bool

是否使用渐变颜色填充柱状图. Defaults to False.

False
colors_start list[str] | None

渐变色的起始颜色列表. Defaults to None.

None
colors_end list[str] | None

渐变色的结束颜色列表. Defaults to None.

None
show_dots bool

是否显示散点. Defaults to True.

True
dots_color list[list[str]] | None

散点的颜色列表. Defaults to None.

None
width Num

柱状图的宽度. Defaults to 0.5.

0.5
color_alpha Num

柱状图颜色的透明度. Defaults to 1.

1
dots_size Num

散点的大小. Defaults to 35.

35
errorbar_type str

误差条类型,可选 "sd"(标准差) 或 "se"(标准误). Defaults to "sd".

'sd'
title_name str

图表标题. Defaults to "".

''
title_fontsize Num

标题字体大小. Defaults to 12.

12
title_pad Num

标题与图表的间距. Defaults to 10.

10
x_label_name str

X轴标签名称. Defaults to "".

''
x_label_ha str

X轴标签的水平对齐方式. Defaults to "center".

'center'
x_label_fontsize Num

X轴标签字体大小. Defaults to 12.

12
x_tick_fontsize Num

X轴刻度字体大小. Defaults to 12.

12
x_tick_rotation Num

X轴刻度旋转角度. Defaults to 0.

0
y_label_name str

Y轴标签名称. Defaults to "".

''
y_label_fontsize Num

Y轴标签字体大小. Defaults to 12.

12
y_tick_fontsize Num

Y轴刻度字体大小. Defaults to 8.

8
y_tick_rotation Num

Y轴刻度旋转角度. Defaults to 0.

0
y_lim tuple[Num, Num] | None

Y轴的范围限制. Defaults to None.

None
statistic bool

是否进行统计显著性分析. Defaults to False.

False
test_method list[str]

统计检验方法列表,包括 1. ttest_ind, 2. ttest_rel, 3. ttest_1samp, 4. mannwhitneyu, 5. external. Defaults to ["ttest_ind"].

['ttest_ind']
p_list list[float] | None

预计算的p值列表,用于显著性标记. Defaults to None.

None
popmean Num

单样本t检验的假设均值. Defaults to 0.

0
statistical_line_color str

显著性标记线的颜色. Defaults to "0.5".

'0.5'
asterisk_fontsize Num

显著性星号的字体大小. Defaults to 10.

10
asterisk_color str

显著性星号的颜色. Defaults to "k".

'k'
y_base float | None

显著性连线的起始Y轴位置(高度)。如果为None,则使用内部算法自动计算一个合适的位置。Defaults to None.

None
interval float | None

相邻显著性连线之间的垂直距离(Y轴增量)。如果为None,则使用内部算法根据图表范围和比较对数自动计算。Defaults to None.

None
ax_bottom_is_0 bool

Y轴是否从0开始. Defaults to False.

False
y_max_tick_is_1 bool

Y轴最大刻度是否限制为1. Defaults to False.

False
math_text bool

是否将Y轴显示为科学计数法格式. Defaults to True.

True
one_decimal_place bool

Y轴刻度是否只保留一位小数. Defaults to False.

False
percentage bool

是否将Y轴显示为百分比格式. Defaults to False.

False

Raises:

Type Description
ValueError

当data数据格式无效时抛出

ValueError

当errorbar_type不是"sd"或"se"时抛出

Returns:

Type Description
Axes | None

Axes | None: 返回matplotlib的坐标轴对象或None

Source code in src/plotfig/single_bar.py
def plot_one_group_bar_figure(
    data: Sequence[Sequence[Num] | np.ndarray],
    ax: Axes | None = None,
    labels_name: list[str] | None = None,
    colors: list[str] | None = None,
    edgecolor: str | None = None,
    gradient_color: bool = False,
    colors_start: list[str] | None = None,
    colors_end: list[str] | None = None,
    show_dots: bool = True,
    dots_color: list[list[str]] | None = None,
    width: Num = 0.5,
    color_alpha: Num = 1,
    dots_size: Num = 35,
    errorbar_type: str = "sd",
    title_name: str = "",
    title_fontsize: Num = 12,
    title_pad: Num = 10,
    x_label_name: str = "",
    x_label_ha: str = "center",
    x_label_fontsize: Num = 12,
    x_tick_fontsize: Num = 12,
    x_tick_rotation: Num = 0,
    y_label_name: str = "",
    y_label_fontsize: Num = 12,
    y_tick_fontsize: Num = 8,
    y_tick_rotation: Num = 0,
    y_lim: tuple[float, float] | None = None,
    statistic: bool = False,
    test_method: list[str] = ["ttest_ind"],
    p_list: list[float] | None = None,
    popmean: Num = 0,
    statistical_line_color: str = "0.5",
    asterisk_fontsize: Num = 10,
    asterisk_color: str = "k",
    y_base: float | None = None,
    interval: float | None = None,
    ax_bottom_is_0: bool = False,
    y_max_tick_is_1: bool = False,
    math_text: bool = True,
    one_decimal_place: bool = False,
    percentage: bool = False,
) -> Axes | None:
    """绘制单组柱状图,包含散点、误差条和统计显著性标记。

    Args:
        data (np.ndarray | Sequence[Sequence[Num] | np.ndarray]):
            输入数据,可以是二维numpy数组或嵌套序列,每个子序列代表一个柱状图的数据点
        ax (Axes | None, optional):
            matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.
        labels_name (list[str] | None, optional):
            柱状图的标签名称列表. Defaults to None.
        colors (list[str] | None, optional):
            柱状图的颜色列表. Defaults to None.
        edgecolor (str | None, optional):
            柱状图边缘颜色. Defaults to None.
        gradient_color (bool, optional):
            是否使用渐变颜色填充柱状图. Defaults to False.
        colors_start (list[str] | None, optional):
            渐变色的起始颜色列表. Defaults to None.
        colors_end (list[str] | None, optional):
            渐变色的结束颜色列表. Defaults to None.
        show_dots (bool, optional):
            是否显示散点. Defaults to True.
        dots_color (list[list[str]] | None, optional):
            散点的颜色列表. Defaults to None.
        width (Num, optional):
            柱状图的宽度. Defaults to 0.5.
        color_alpha (Num, optional):
            柱状图颜色的透明度. Defaults to 1.
        dots_size (Num, optional):
            散点的大小. Defaults to 35.
        errorbar_type (str, optional):
            误差条类型,可选 "sd"(标准差) 或 "se"(标准误). Defaults to "sd".
        title_name (str, optional):
            图表标题. Defaults to "".
        title_fontsize (Num, optional):
            标题字体大小. Defaults to 12.
        title_pad (Num, optional):
            标题与图表的间距. Defaults to 10.
        x_label_name (str, optional):
            X轴标签名称. Defaults to "".
        x_label_ha (str, optional):
            X轴标签的水平对齐方式. Defaults to "center".
        x_label_fontsize (Num, optional):
            X轴标签字体大小. Defaults to 12.
        x_tick_fontsize (Num, optional):
            X轴刻度字体大小. Defaults to 12.
        x_tick_rotation (Num, optional):
            X轴刻度旋转角度. Defaults to 0.
        y_label_name (str, optional):
            Y轴标签名称. Defaults to "".
        y_label_fontsize (Num, optional):
            Y轴标签字体大小. Defaults to 12.
        y_tick_fontsize (Num, optional):
            Y轴刻度字体大小. Defaults to 8.
        y_tick_rotation (Num, optional):
            Y轴刻度旋转角度. Defaults to 0.
        y_lim (tuple[Num, Num] | None, optional):
            Y轴的范围限制. Defaults to None.
        statistic (bool, optional):
            是否进行统计显著性分析. Defaults to False.
        test_method (list[str], optional):
            统计检验方法列表,包括
            1. `ttest_ind`,
            2. `ttest_rel`,
            3. `ttest_1samp`,
            4. `mannwhitneyu`,
            5. `external`.
            Defaults to ["ttest_ind"].
        p_list (list[float] | None, optional):
            预计算的p值列表,用于显著性标记. Defaults to None.
        popmean (Num, optional):
            单样本t检验的假设均值. Defaults to 0.
        statistical_line_color (str, optional):
            显著性标记线的颜色. Defaults to "0.5".
        asterisk_fontsize (Num, optional):
            显著性星号的字体大小. Defaults to 10.
        asterisk_color (str, optional):
            显著性星号的颜色. Defaults to "k".
        y_base (float | None, optional):
            显著性连线的起始Y轴位置(高度)。如果为None,则使用内部算法自动计算一个合适的位置。Defaults to None.
        interval (float | None, optional):
            相邻显著性连线之间的垂直距离(Y轴增量)。如果为None,则使用内部算法根据图表范围和比较对数自动计算。Defaults to None.
        ax_bottom_is_0 (bool, optional):
            Y轴是否从0开始. Defaults to False.
        y_max_tick_is_1 (bool, optional):
            Y轴最大刻度是否限制为1. Defaults to False.
        math_text (bool, optional):
            是否将Y轴显示为科学计数法格式. Defaults to True.
        one_decimal_place (bool, optional):
            Y轴刻度是否只保留一位小数. Defaults to False.
        percentage (bool, optional):
            是否将Y轴显示为百分比格式. Defaults to False.

    Raises:
        ValueError: 当data数据格式无效时抛出
        ValueError: 当errorbar_type不是"sd"或"se"时抛出

    Returns:
        Axes | None: 返回matplotlib的坐标轴对象或None
    """
    # 处理None值
    if not _is_valid_data(data):
        raise ValueError("无效的 data")
    ax = ax or plt.gca()
    labels_name = labels_name or [str(i) for i in range(len(data))]
    colors = colors or ["gray"] * len(data)
    # 统一参数型
    width = float(width)
    color_alpha = float(color_alpha)
    dots_size = float(dots_size)
    title_fontsize = float(title_fontsize)
    title_pad = float(title_pad)
    x_label_fontsize = float(x_label_fontsize)
    x_tick_fontsize = float(x_tick_fontsize)
    x_tick_rotation = float(x_tick_rotation)
    y_label_fontsize = float(y_label_fontsize)
    y_tick_fontsize = float(y_tick_fontsize)
    y_tick_rotation = float(y_tick_rotation)
    popmean = float(popmean)
    asterisk_fontsize = float(asterisk_fontsize)

    x_positions = np.arange(len(labels_name))
    means, sds, ses = [], [], []
    scatter_positions = []
    for i, d in enumerate(data):
        mean, sd, se = compute_summary(np.array(d))
        means.append(mean)
        sds.append(sd)
        ses.append(se)
        # 创建随机数生成器
        rng = np.random.default_rng(seed=42)
        scatter_x = rng.normal(i, 0.1, len(d))
        scatter_positions.append(scatter_x)
    if errorbar_type == "sd":
        error_values = sds
    elif errorbar_type == "se":
        error_values = ses
    else:
        raise ValueError("errorbar_type 只能是 'sd' 或者 'se'")

    # 绘制柱子
    if gradient_color:
        if colors_start is None:  # 默认颜色
            colors_start = ["#e38a48"] * len(x_positions)  # 左边颜色
        if colors_end is None:  # 默认颜色
            colors_end = ["#4573a5"] * len(x_positions)  # 右边颜色
        for x, h, c1, c2 in zip(x_positions, means, colors_start, colors_end):
            # 生成线性渐变 colormap
            cmap = LinearSegmentedColormap.from_list("grad_cmap", [c1, "white", c2])
            gradient = np.linspace(0, 1, 100).reshape(1, -1)  # 横向渐变
            # 计算渐变矩形位置:跟bar完全对齐
            extent = (float(x - width / 2), float(x + width / 2), 0, h)
            # 叠加渐变矩形(imshow)
            ax.imshow(gradient, aspect="auto", cmap=cmap, extent=extent, zorder=0)
    else:
        ax.bar(
            x_positions,
            means,
            width=width,
            color=colors,
            alpha=color_alpha,
            edgecolor=edgecolor,
        )

    ax.errorbar(
        x_positions,
        means,
        error_values,
        fmt="none",
        linewidth=1,
        capsize=3,
        color="black",
    )

    # 绘制散点
    if show_dots:
        for i, d in enumerate(data):
            if dots_color is None:
                _add_scatter(ax, scatter_positions[i], d, ["gray"] * len(d), dots_size)
            else:
                _add_scatter(ax, scatter_positions[i], d, dots_color[i], dots_size)

    # 美化
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title(
        title_name,
        fontsize=title_fontsize,
        pad=float(title_pad),
    )
    # x轴
    ax.set_xlim(np.min(x_positions) - 0.5, np.max(x_positions) + 0.5)
    ax.set_xlabel(x_label_name, fontsize=x_label_fontsize)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(
        labels_name,
        fontsize=x_tick_fontsize,
        rotation=x_tick_rotation,
        ha=x_label_ha,
        rotation_mode="anchor",
    )
    # y轴
    ax.tick_params(
        axis="y",
        labelsize=y_tick_fontsize,
        rotation=y_tick_rotation,
    )
    ax.set_ylabel(y_label_name, fontsize=y_label_fontsize)
    all_values = np.concatenate([np.asarray(x) for x in data]).ravel()
    set_yaxis(
        ax,
        all_values,
        y_lim=y_lim,
        ax_bottom_is_0=ax_bottom_is_0,
        y_max_tick_is_1=y_max_tick_is_1,
        math_text=math_text,
        one_decimal_place=one_decimal_place,
        percentage=percentage,
    )

    # 添加统计显著性标记
    if statistic:
        _statistics(
            data,
            test_method,
            p_list,
            popmean,
            ax,
            all_values,
            statistical_line_color,
            asterisk_fontsize,
            asterisk_color,
            y_base,
            interval,
        )
    return ax

plot_one_group_violin_figure

plot_one_group_violin_figure(data: Sequence[list[float] | NDArray[float64]], ax: Axes | None = None, labels_name: list[str] | None = None, width: Num = 0.8, colors: list[str] | None = None, color_alpha: Num = 1, gradient_color: bool = False, colors_start: list[str] | None = None, colors_end: list[str] | None = None, show_dots: bool = False, dots_size: Num = 35, title_name: str = '', title_fontsize: Num = 12, title_pad: Num = 10, x_label_name: str = '', x_label_ha: str = 'center', x_label_fontsize: Num = 10, x_tick_fontsize: Num = 8, x_tick_rotation: Num = 0, y_label_name: str = '', y_label_fontsize: Num = 10, y_tick_fontsize: Num = 8, y_tick_rotation: Num = 0, y_lim: tuple[float, float] | None = None, statistic: bool = False, test_method: list[str] = ['ttest_ind'], popmean: Num = 0, p_list: list[float] | None = None, statistical_line_color: str = '0.5', asterisk_fontsize: Num = 10, asterisk_color: str = 'k', y_base: float | None = None, interval: float | None = None, ax_bottom_is_0: bool = False, y_max_tick_is_1: bool = False, math_text: bool = True, one_decimal_place: bool = False, percentage: bool = False) -> Axes | None

绘制单组小提琴图,可选散点叠加、渐变填色和统计显著性标注。

Parameters:

Name Type Description Default
data Sequence[list[float] | NDArray[float64]]

输入数据,可以是二维numpy数组或嵌套序列,每个子序列代表一个小提琴的数据点

required
ax Axes | None

matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.

None
labels_name list[str] | None

小提琴图的标签名称列表. Defaults to None.

None
width Num

小提琴图的宽度. Defaults to 0.8.

0.8
colors list[str] | None

小提琴图的颜色列表. Defaults to None.

None
color_alpha Num

小提琴图颜色的透明度. Defaults to 1.

1
gradient_color bool

是否使用渐变颜色填充小提琴图. Defaults to False.

False
colors_start list[str] | None

渐变色的起始颜色列表. Defaults to None.

None
colors_end list[str] | None

渐变色的结束颜色列表. Defaults to None.

None
show_dots bool

是否显示散点. Defaults to False.

False
dots_size Num

散点的大小. Defaults to 35.

35
title_name str

图表标题. Defaults to "".

''
title_fontsize Num

标题字体大小. Defaults to 12.

12
title_pad Num

标题与图表的间距. Defaults to 10.

10
x_label_name str

X轴标签名称. Defaults to "".

''
x_label_ha str

X轴标签的水平对齐方式. Defaults to "center".

'center'
x_label_fontsize Num

X轴标签字体大小. Defaults to 10.

10
x_tick_fontsize Num

X轴刻度字体大小. Defaults to 8.

8
x_tick_rotation Num

X轴刻度旋转角度. Defaults to 0.

0
y_label_name str

Y轴标签名称. Defaults to "".

''
y_label_fontsize Num

Y轴标签字体大小. Defaults to 10.

10
y_tick_fontsize Num

Y轴刻度字体大小. Defaults to 8.

8
y_tick_rotation Num

Y轴刻度旋转角度. Defaults to 0.

0
y_lim tuple[Num, Num] | None

Y轴的范围限制. Defaults to None.

None
statistic bool

是否进行统计显著性分析. Defaults to False.

False
test_method list[str]

统计检验方法列表. Defaults to ["ttest_ind"].

['ttest_ind']
popmean Num

单样本t检验的假设均值. Defaults to 0.

0
p_list list[float] | None

预计算的p值列表,用于显著性标记. Defaults to None.

None
statistical_line_color str

显著性标记线的颜色. Defaults to "0.5".

'0.5'
asterisk_fontsize Num

显著性星号的字体大小. Defaults to 10.

10
asterisk_color str

显著性星号的颜色. Defaults to "k".

'k'
y_base float | None

显著性连线的起始Y轴位置(高度)。如果为None,则使用内部算法自动计算一个合适的位置。Defaults to None.

None
interval float | None

相邻显著性连线之间的垂直距离(Y轴增量)。如果为None,则使用内部算法根据图表范围和比较对数自动计算。Defaults to None.

None
ax_bottom_is_0 bool

Y轴是否从0开始. Defaults to False.

False
y_max_tick_is_1 bool

Y轴最大刻度是否限制为1. Defaults to False.

False
math_text bool

是否将Y轴显示为科学计数法格式. Defaults to True.

True
one_decimal_place bool

Y轴刻度是否只保留一位小数. Defaults to False.

False
percentage bool

是否将Y轴显示为百分比格式. Defaults to False.

False

Raises:

Type Description
ValueError

当data数据格式无效时抛出

Returns:

Type Description
Axes | None

Axes | None: 返回matplotlib的坐标轴对象或None

Source code in src/plotfig/single_bar.py
def plot_one_group_violin_figure(
    data: Sequence[list[float] | NDArray[np.float64]],
    ax: Axes | None = None,
    labels_name: list[str] | None = None,
    width: Num = 0.8,
    colors: list[str] | None = None,
    color_alpha: Num = 1,
    gradient_color: bool = False,
    colors_start: list[str] | None = None,
    colors_end: list[str] | None = None,
    show_dots: bool = False,
    dots_size: Num = 35,
    title_name: str = "",
    title_fontsize: Num = 12,
    title_pad: Num = 10,
    x_label_name: str = "",
    x_label_ha: str = "center",
    x_label_fontsize: Num = 10,
    x_tick_fontsize: Num = 8,
    x_tick_rotation: Num = 0,
    y_label_name: str = "",
    y_label_fontsize: Num = 10,
    y_tick_fontsize: Num = 8,
    y_tick_rotation: Num = 0,
    y_lim: tuple[float, float] | None = None,
    statistic: bool = False,
    test_method: list[str] = ["ttest_ind"],
    popmean: Num = 0,
    p_list: list[float] | None = None,
    statistical_line_color: str = "0.5",
    asterisk_fontsize: Num = 10,
    asterisk_color: str = "k",
    y_base: float | None = None,
    interval: float | None = None,
    ax_bottom_is_0: bool = False,
    y_max_tick_is_1: bool = False,
    math_text: bool = True,
    one_decimal_place: bool = False,
    percentage: bool = False,
) -> Axes | None:
    """绘制单组小提琴图,可选散点叠加、渐变填色和统计显著性标注。

    Args:
        data (Sequence[list[float] | NDArray[np.float64]]):
            输入数据,可以是二维numpy数组或嵌套序列,每个子序列代表一个小提琴的数据点
        ax (Axes | None, optional):
            matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.
        labels_name (list[str] | None, optional):
            小提琴图的标签名称列表. Defaults to None.
        width (Num, optional):
            小提琴图的宽度. Defaults to 0.8.
        colors (list[str] | None, optional):
            小提琴图的颜色列表. Defaults to None.
        color_alpha (Num, optional):
            小提琴图颜色的透明度. Defaults to 1.
        gradient_color (bool, optional):
            是否使用渐变颜色填充小提琴图. Defaults to False.
        colors_start (list[str] | None, optional):
            渐变色的起始颜色列表. Defaults to None.
        colors_end (list[str] | None, optional):
            渐变色的结束颜色列表. Defaults to None.
        show_dots (bool, optional):
            是否显示散点. Defaults to False.
        dots_size (Num, optional):
            散点的大小. Defaults to 35.
        title_name (str, optional):
            图表标题. Defaults to "".
        title_fontsize (Num, optional):
            标题字体大小. Defaults to 12.
        title_pad (Num, optional):
            标题与图表的间距. Defaults to 10.
        x_label_name (str, optional):
            X轴标签名称. Defaults to "".
        x_label_ha (str, optional):
            X轴标签的水平对齐方式. Defaults to "center".
        x_label_fontsize (Num, optional):
            X轴标签字体大小. Defaults to 10.
        x_tick_fontsize (Num, optional):
            X轴刻度字体大小. Defaults to 8.
        x_tick_rotation (Num, optional):
            X轴刻度旋转角度. Defaults to 0.
        y_label_name (str, optional):
            Y轴标签名称. Defaults to "".
        y_label_fontsize (Num, optional):
            Y轴标签字体大小. Defaults to 10.
        y_tick_fontsize (Num, optional):
            Y轴刻度字体大小. Defaults to 8.
        y_tick_rotation (Num, optional):
            Y轴刻度旋转角度. Defaults to 0.
        y_lim (tuple[Num, Num] | None, optional):
            Y轴的范围限制. Defaults to None.
        statistic (bool, optional):
            是否进行统计显著性分析. Defaults to False.
        test_method (list[str], optional):
            统计检验方法列表. Defaults to ["ttest_ind"].
        popmean (Num, optional):
            单样本t检验的假设均值. Defaults to 0.
        p_list (list[float] | None, optional):
            预计算的p值列表,用于显著性标记. Defaults to None.
        statistical_line_color (str, optional):
            显著性标记线的颜色. Defaults to "0.5".
        asterisk_fontsize (Num, optional):
            显著性星号的字体大小. Defaults to 10.
        asterisk_color (str, optional):
            显著性星号的颜色. Defaults to "k".
        y_base (float | None, optional):
            显著性连线的起始Y轴位置(高度)。如果为None,则使用内部算法自动计算一个合适的位置。Defaults to None.
        interval (float | None, optional):
            相邻显著性连线之间的垂直距离(Y轴增量)。如果为None,则使用内部算法根据图表范围和比较对数自动计算。Defaults to None.
        ax_bottom_is_0 (bool, optional):
            Y轴是否从0开始. Defaults to False.
        y_max_tick_is_1 (bool, optional):
            Y轴最大刻度是否限制为1. Defaults to False.
        math_text (bool, optional):
            是否将Y轴显示为科学计数法格式. Defaults to True.
        one_decimal_place (bool, optional):
            Y轴刻度是否只保留一位小数. Defaults to False.
        percentage (bool, optional):
            是否将Y轴显示为百分比格式. Defaults to False.

    Raises:
        ValueError: 当data数据格式无效时抛出

    Returns:
        Axes | None: 返回matplotlib的坐标轴对象或None
    """
    # 处理None值
    if not _is_valid_data(data):
        raise ValueError("无效的 data")
    ax = ax or plt.gca()
    labels_name = labels_name or [str(i) for i in range(len(data))]
    colors = colors or ["gray"] * len(data)
    # 统一参数型
    width = float(width)
    color_alpha = float(color_alpha)
    dots_size = float(dots_size)
    title_fontsize = float(title_fontsize)
    title_pad = float(title_pad)
    x_label_fontsize = float(x_label_fontsize)
    x_tick_fontsize = float(x_tick_fontsize)
    x_tick_rotation = float(x_tick_rotation)
    y_label_fontsize = float(y_label_fontsize)
    y_tick_fontsize = float(y_tick_fontsize)
    y_tick_rotation = float(y_tick_rotation)
    popmean = float(popmean)
    asterisk_fontsize = float(asterisk_fontsize)

    def _draw_gradient_violin(ax, data, pos, width, c1, c2, color_alpha):
        # KDE估计
        kde = stats.gaussian_kde(data)
        buffer = (max(data) - min(data)) / 5
        y = np.linspace(min(data) - buffer, max(data) + buffer, 300)
        ymax = max(data) + buffer
        ymin = min(data) - buffer
        density = kde(y)
        density = density / density.max() * (width / 2)  # 控制violin宽度
        # violin左右边界
        x_left = pos - density
        x_right = pos + density
        # 组合封闭边界
        verts = np.concatenate(
            [np.stack([x_left, y], axis=1), np.stack([x_right[::-1], y[::-1]], axis=1)]
        )
        # 构建渐变图像
        grad_width = 200
        grad_height = 300
        gradient = np.linspace(0, 1, grad_width)
        if c1 == c2:
            rgba = to_rgba(c1, alpha=color_alpha)
            cmap = LinearSegmentedColormap.from_list("cmap", [rgba, rgba])
            gradient_rgb = plt.get_cmap(cmap)(gradient)
        else:
            cmap = LinearSegmentedColormap.from_list("cmap", [c1, "white", c2])
            gradient_rgb = plt.get_cmap(cmap)(gradient)[..., :3]
        gradient_img = np.tile(gradient_rgb, (grad_height, 1, 1))
        # 显示图像并裁剪成violin形状
        im = ax.imshow(
            gradient_img,
            extent=[pos - width / 2, pos + width / 2, y.min(), y.max()],
            origin="lower",
            aspect="auto",
            zorder=1,
        )
        # 添加边界线并作为clip
        poly = Polygon(
            verts,
            closed=True,
            facecolor="none",
            edgecolor="black",
            linewidth=1.2,
            zorder=2,
        )
        ax.add_patch(poly)
        im.set_clip_path(poly)
        # 添加 box 元素
        q1 = np.percentile(data, 25)
        q3 = np.percentile(data, 75)
        median = np.median(data)
        # 添加 IQR box(黑色矩形)
        ax.add_patch(
            Rectangle(
                (pos - width / 16, q1),  # 左下角坐标
                float(width / 8),  # 宽度
                q3 - q1,  # 高度
                facecolor="black",
                alpha=0.7,
            )
        )
        # 添加白色中位数点
        ax.plot(pos, median, "o", color="white", markersize=5, zorder=3)
        return ymax, ymin

    ymax_lst, ymin_lst = [], []
    for i, d in enumerate(data):
        if gradient_color:
            if colors_start is None:
                colors_start = ["#e38a48"] * len(data)
            if colors_end is None:  # 默认颜色
                colors_end = ["#4573a5"] * len(data)
            c1 = colors_start[i]
            c2 = colors_end[i]
        else:
            c1 = c2 = colors[i]
        ymax, ymin = _draw_gradient_violin(ax, d, i, width, c1, c2, color_alpha)

        ymax_lst.append(ymax)
        ymin_lst.append(ymin)
    ymax = max(ymax_lst)
    ymin = min(ymin_lst)

    # 绘制散点(复用现有函数)
    if show_dots:
        # 创建随机数生成器
        rng = np.random.default_rng(seed=42)
        scatter_positions = [rng.normal(i, 0.1, len(d)) for i, d in enumerate(data)]
        for i, d in enumerate(data):
            _add_scatter(ax, scatter_positions[i], d, colors[i], dots_size)

    # 美化
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title(title_name, fontsize=title_fontsize, pad=title_pad)
    # x轴
    ax.set_xlim(-0.5, len(data) - 0.5)
    ax.set_xlabel(x_label_name, fontsize=x_label_fontsize)
    ax.set_xticks(np.arange(len(data)))
    ax.set_xticklabels(
        labels_name,
        fontsize=x_tick_fontsize,
        rotation=x_tick_rotation,
        ha=x_label_ha,
        rotation_mode="anchor",
    )
    # y轴
    ax.tick_params(
        axis="y",
        labelsize=y_tick_fontsize,
        rotation=y_tick_rotation,
    )
    ax.set_ylabel(y_label_name, fontsize=y_label_fontsize)
    all_values = [ymin, ymax]
    set_yaxis(
        ax,
        all_values,
        y_lim=y_lim,
        ax_bottom_is_0=ax_bottom_is_0,
        y_max_tick_is_1=y_max_tick_is_1,
        math_text=math_text,
        one_decimal_place=one_decimal_place,
        percentage=percentage,
    )

    # 添加统计标记(复用现有函数)
    if statistic:
        _statistics(
            data,
            test_method,
            p_list,
            popmean,
            ax,
            all_values,
            statistical_line_color,
            asterisk_fontsize,
            asterisk_color,
            y_base,
            interval,
        )

    return ax

multi_bars

Functions:

Name Description
plot_multi_group_bar_figure

绘制多组分组条形图,支持误差线、散点叠加和统计显著性标注。

plot_multi_group_bar_figure

plot_multi_group_bar_figure(data: Sequence[Sequence[Sequence[float]]], ax: Axes | None = None, group_labels: list[str] | None = None, bar_labels: list[str] | None = None, bar_width: Num = 0.2, bar_gap: Num = 0.1, bar_color: list[str] | None = None, errorbar_type: str = 'sd', dots_color: str = 'gray', dots_size: int = 35, legend: bool = True, legend_position: tuple[Num, Num] = (1.2, 1), title_name: str = '', title_fontsize=12, title_pad=10, x_label_name: str = '', x_label_ha='center', x_label_fontsize=10, x_tick_fontsize=8, x_tick_rotation=0, y_label_name: str = '', y_label_fontsize=10, y_tick_fontsize=8, y_tick_rotation=0, y_lim: tuple[float, float] | None = None, statistic: bool = False, test_method: str = 'external', p_list: list[list[Num]] | None = None, line_color='0.5', asterisk_fontsize=10, asterisk_color='k', y_base: float | None = None, interval: float | None = None, ax_bottom_is_0: bool = False, y_max_tick_is_1: bool = False, math_text: bool = True, one_decimal_place: bool = False, percentage: bool = False) -> Axes

绘制多组分组条形图,支持误差线、散点叠加和统计显著性标注。

该函数用于可视化多组数据的比较,每组包含多个柱子,每个柱子显示均值、误差线 和原始数据点。特别适用于认知神经科学中的组间比较分析。

Parameters:

Name Type Description Default
data Sequence[Sequence[Sequence[float]]]

三层嵌套的数据结构 - 第一层:组 (groups) - 第二层:每组内的柱子 (bars),所有组的柱子数量必须一致 - 第三层:每个柱子内的数据点 (points),数量可以不同

required
ax Axes | None

matplotlib 的 Axes 对象。如果为 None,使用当前活动的 Axes。

None
group_labels list[str] | None

每个组的标签。如果为 None,自动生成 "Group 1", "Group 2" 等。

None
bar_labels list[str] | None

每个柱子的标签,用于图例。如果为 None,自动生成 "Bar 1", "Bar 2" 等。

None
bar_width Num

柱子的宽度。默认为 0.2。

0.2
bar_gap Num

同一组内柱子之间的间隔。默认为 0.1。

0.1
bar_color list[str] | None

每个柱子的颜色列表。如果为 None,所有柱子使用灰色。

None
errorbar_type str

误差线类型,'sd' 表示标准差,'se' 表示标准误。默认为 'sd'。

'sd'
dots_color str

散点的颜色。默认为 'gray'。

'gray'
dots_size int

散点的大小。默认为 35。

35
legend bool

是否显示图例。默认为 True。

True
legend_position tuple[Num, Num]

图例位置,使用 bbox_to_anchor 坐标。默认为 (1.2, 1)。

(1.2, 1)
title_name str

图表标题。默认为空字符串。

''
title_fontsize int

标题字体大小。默认为 12。

12
title_pad int

标题与图表的间距。默认为 10。

10
x_label_name str

x 轴标签文本。默认为空字符串。

''
x_label_ha str

x 轴刻度标签的水平对齐方式。默认为 'center'。

'center'
x_label_fontsize int

x 轴标签字体大小。默认为 10。

10
x_tick_fontsize int

x 轴刻度字体大小。默认为 8。

8
x_tick_rotation int

x 轴刻度旋转角度。默认为 0。

0
y_label_name str

y 轴标签文本。默认为空字符串。

''
y_label_fontsize int

y 轴标签字体大小。默认为 10。

10
y_tick_fontsize int

y 轴刻度字体大小。默认为 8。

8
y_tick_rotation int

y 轴刻度旋转角度。默认为 0。

0
y_lim tuple[float, float] | None

手动指定的 y 轴范围 (y_min, y_max)。 如果为 None,根据数据自动计算。

None
statistic bool

是否添加统计显著性标注。默认为 False。

False
test_method str

统计检验方法。当前仅支持 'external'(使用外部提供的 p 值)。

'external'
p_list list[list[Num]] | None

外部提供的 p 值列表。 结构为 [组1的p值列表, 组2的p值列表, ...],每个组的 p 值列表对应该组内所有两两比较。 当 statistic=True 且 test_method='external' 时必须提供。

None
line_color str

显著性标注连线的颜色。默认为 '0.5'(中灰色)。

'0.5'
asterisk_fontsize int

显著性星号的字体大小。默认为 10。

10
asterisk_color str

显著性星号的颜色。默认为 'k'(黑色)。

'k'
y_base float | None

显著性标注的起始 y 坐标。如果为 None,自动计算为数据最大值。

None
interval float | None

多个显著性标注之间的垂直间隔。 如果为 None,自动计算为 (y_max - 数据最大值) / (比较数量 + 1)。

None
ax_bottom_is_0 bool

是否将 y 轴底部固定为 0。默认为 False。

False
y_max_tick_is_1 bool

是否将最大刻度限制为 1。默认为 False。

False
math_text bool

是否使用科学计数法格式。默认为 True。

True
one_decimal_place bool

是否将刻度格式化为一位小数。默认为 False。

False
percentage bool

是否将刻度格式化为百分比形式。默认为 False。

False

Returns:

Name Type Description
Axes Axes

包含绘制内容的 matplotlib Axes 对象。

Raises:

Type Description
ValueError

当 data 不是三层嵌套结构时抛出。

ValueError

当所有组的柱子数量不一致时抛出。

ValueError

当 errorbar_type 不是 'sd' 或 'se' 时抛出。

ValueError

当 statistic=True 且 test_method='external' 但 p_list 为 None 时抛出。

Examples:

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from plotfig import plot_multi_group_bar_figure
>>>
>>> # 创建示例数据:2 组,每组 3 个柱子
>>> data = [
...     [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
...     [[2, 3, 4], [5, 6, 7], [8, 9, 10]]
... ]
>>>
>>> # 绘制基本图表
>>> fig, ax = plt.subplots(figsize=(8, 6))
>>> ax = plot_multi_group_bar_figure(
...     data,
...     ax=ax,
...     group_labels=['Control', 'Treatment'],
...     bar_labels=['Condition A', 'Condition B', 'Condition C'],
...     title_name='Multi-Group Comparison'
... )
>>> fig.show()
Notes
  • 显著性星号规则:* (p≤0.05), ** (p≤0.01), *** (p≤0.001)
Source code in src/plotfig/multi_bars.py
def plot_multi_group_bar_figure(
    data: Sequence[Sequence[Sequence[float]]],
    ax: Axes | None = None,
    group_labels: list[str] | None = None,
    bar_labels: list[str] | None = None,
    bar_width: Num = 0.2,
    bar_gap: Num = 0.1,
    bar_color: list[str] | None = None,
    errorbar_type: str = "sd",
    dots_color: str = "gray",
    dots_size: int = 35,
    legend: bool = True,
    legend_position: tuple[Num, Num] = (1.2, 1),
    title_name: str = "",
    title_fontsize=12,
    title_pad=10,
    x_label_name: str = "",
    x_label_ha="center",
    x_label_fontsize=10,
    x_tick_fontsize=8,
    x_tick_rotation=0,
    y_label_name: str = "",
    y_label_fontsize=10,
    y_tick_fontsize=8,
    y_tick_rotation=0,
    y_lim: tuple[float, float] | None = None,
    statistic: bool = False,
    test_method: str = "external",
    p_list: list[list[Num]] | None = None,
    line_color="0.5",
    asterisk_fontsize=10,
    asterisk_color="k",
    y_base: float | None = None,
    interval: float | None = None,
    ax_bottom_is_0: bool = False,
    y_max_tick_is_1: bool = False,
    math_text: bool = True,
    one_decimal_place: bool = False,
    percentage: bool = False,
) -> Axes:
    """绘制多组分组条形图,支持误差线、散点叠加和统计显著性标注。

    该函数用于可视化多组数据的比较,每组包含多个柱子,每个柱子显示均值、误差线
    和原始数据点。特别适用于认知神经科学中的组间比较分析。

    Args:
        data (Sequence[Sequence[Sequence[float]]]): 三层嵌套的数据结构
            - 第一层:组 (groups)
            - 第二层:每组内的柱子 (bars),所有组的柱子数量必须一致
            - 第三层:每个柱子内的数据点 (points),数量可以不同
        ax (Axes | None): matplotlib 的 Axes 对象。如果为 None,使用当前活动的 Axes。
        group_labels (list[str] | None): 每个组的标签。如果为 None,自动生成 "Group 1", "Group 2" 等。
        bar_labels (list[str] | None): 每个柱子的标签,用于图例。如果为 None,自动生成 "Bar 1", "Bar 2" 等。
        bar_width (Num): 柱子的宽度。默认为 0.2。
        bar_gap (Num): 同一组内柱子之间的间隔。默认为 0.1。
        bar_color (list[str] | None): 每个柱子的颜色列表。如果为 None,所有柱子使用灰色。
        errorbar_type (str): 误差线类型,'sd' 表示标准差,'se' 表示标准误。默认为 'sd'。
        dots_color (str): 散点的颜色。默认为 'gray'。
        dots_size (int): 散点的大小。默认为 35。
        legend (bool): 是否显示图例。默认为 True。
        legend_position (tuple[Num, Num]): 图例位置,使用 bbox_to_anchor 坐标。默认为 (1.2, 1)。
        title_name (str): 图表标题。默认为空字符串。
        title_fontsize (int): 标题字体大小。默认为 12。
        title_pad (int): 标题与图表的间距。默认为 10。
        x_label_name (str): x 轴标签文本。默认为空字符串。
        x_label_ha (str): x 轴刻度标签的水平对齐方式。默认为 'center'。
        x_label_fontsize (int): x 轴标签字体大小。默认为 10。
        x_tick_fontsize (int): x 轴刻度字体大小。默认为 8。
        x_tick_rotation (int): x 轴刻度旋转角度。默认为 0。
        y_label_name (str): y 轴标签文本。默认为空字符串。
        y_label_fontsize (int): y 轴标签字体大小。默认为 10。
        y_tick_fontsize (int): y 轴刻度字体大小。默认为 8。
        y_tick_rotation (int): y 轴刻度旋转角度。默认为 0。
        y_lim (tuple[float, float] | None): 手动指定的 y 轴范围 (y_min, y_max)。
            如果为 None,根据数据自动计算。
        statistic (bool): 是否添加统计显著性标注。默认为 False。
        test_method (str): 统计检验方法。当前仅支持 'external'(使用外部提供的 p 值)。
        p_list (list[list[Num]] | None): 外部提供的 p 值列表。
            结构为 [组1的p值列表, 组2的p值列表, ...],每个组的 p 值列表对应该组内所有两两比较。
            当 statistic=True 且 test_method='external' 时必须提供。
        line_color (str): 显著性标注连线的颜色。默认为 '0.5'(中灰色)。
        asterisk_fontsize (int): 显著性星号的字体大小。默认为 10。
        asterisk_color (str): 显著性星号的颜色。默认为 'k'(黑色)。
        y_base (float | None): 显著性标注的起始 y 坐标。如果为 None,自动计算为数据最大值。
        interval (float | None): 多个显著性标注之间的垂直间隔。
            如果为 None,自动计算为 (y_max - 数据最大值) / (比较数量 + 1)。
        ax_bottom_is_0 (bool): 是否将 y 轴底部固定为 0。默认为 False。
        y_max_tick_is_1 (bool): 是否将最大刻度限制为 1。默认为 False。
        math_text (bool): 是否使用科学计数法格式。默认为 True。
        one_decimal_place (bool): 是否将刻度格式化为一位小数。默认为 False。
        percentage (bool): 是否将刻度格式化为百分比形式。默认为 False。

    Returns:
        Axes: 包含绘制内容的 matplotlib Axes 对象。

    Raises:
        ValueError: 当 data 不是三层嵌套结构时抛出。
        ValueError: 当所有组的柱子数量不一致时抛出。
        ValueError: 当 errorbar_type 不是 'sd' 或 'se' 时抛出。
        ValueError: 当 statistic=True 且 test_method='external' 但 p_list 为 None 时抛出。

    Examples:
        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from plotfig import plot_multi_group_bar_figure
        >>>
        >>> # 创建示例数据:2 组,每组 3 个柱子
        >>> data = [
        ...     [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        ...     [[2, 3, 4], [5, 6, 7], [8, 9, 10]]
        ... ]
        >>>
        >>> # 绘制基本图表
        >>> fig, ax = plt.subplots(figsize=(8, 6))
        >>> ax = plot_multi_group_bar_figure(
        ...     data,
        ...     ax=ax,
        ...     group_labels=['Control', 'Treatment'],
        ...     bar_labels=['Condition A', 'Condition B', 'Condition C'],
        ...     title_name='Multi-Group Comparison'
        ... )
        >>> fig.show()

    Notes:
        - 显著性星号规则:* (p≤0.05), ** (p≤0.01), *** (p≤0.001)
    """
    _data_valiation(data)

    n_groups = len(data)
    n_bars = len(data[0])

    ax = ax or plt.gca()
    group_labels = group_labels or [f"Group {i + 1}" for i in range(len(data))]
    bar_labels = bar_labels or [f"Bar {i + 1}" for i in range(n_bars)]
    bar_color = bar_color or ["gray"] * n_bars

    # 把所有子列表展开成一个大列表
    all_values = [x for sublist1 in data for sublist2 in sublist1 for x in sublist2]

    x_positions_all = []
    for index_group, group_data in enumerate(data):
        x_positions = (
            np.arange(n_bars) * (bar_width + bar_gap)
            + bar_width / 2
            + index_group
            - (n_bars * bar_width + (n_bars - 1) * bar_gap) / 2
        )
        x_positions_all.append(x_positions)

        # 计算均值、标准差、标准误
        means = [compute_summary(group_data[i])[0] for i in range(n_bars)]
        sds = [compute_summary(group_data[i])[1] for i in range(n_bars)]
        ses = [compute_summary(group_data[i])[2] for i in range(n_bars)]
        if errorbar_type == "sd":
            error_values = sds
        elif errorbar_type == "se":
            error_values = ses
        else:
            raise ValueError("errorbar_type 只能是 'sd' 或者 'se'")
        # 绘制柱子
        bars = ax.bar(
            x_positions, means, width=bar_width, color=bar_color, alpha=1, edgecolor="k"
        )
        ax.errorbar(
            x_positions,
            means,
            error_values,
            fmt="none",
            linewidth=1,
            capsize=3,
            color="black",
        )
        # 绘制散点
        for index_bar, dot in enumerate(group_data):
            # 创建随机数生成器
            rng = np.random.default_rng(seed=42)
            dot_x_pos = rng.normal(
                x_positions[index_bar], scale=bar_width / 7, size=len(dot)
            )
            ax.scatter(
                dot_x_pos,
                dot,
                c=dots_color,
                s=dots_size,
                edgecolors="white",
                linewidths=1,
                alpha=0.5,
            )

    if legend:
        ax.legend(bars, bar_labels, bbox_to_anchor=legend_position)

    # 美化
    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title(
        title_name,
        fontsize=title_fontsize,
        pad=title_pad,
    )
    # x轴
    ax.set_xlabel(x_label_name, fontsize=x_label_fontsize)
    ax.set_xticks(np.arange(n_groups))
    ax.set_xticklabels(
        group_labels,
        ha=x_label_ha,
        rotation_mode="anchor",
        fontsize=x_tick_fontsize,
        rotation=x_tick_rotation,
    )
    # y轴
    ax.tick_params(
        axis="y",
        labelsize=y_tick_fontsize,
        rotation=y_tick_rotation,
    )
    ax.set_ylabel(y_label_name, fontsize=y_label_fontsize)
    set_yaxis(
        ax,
        all_values,
        y_lim,
        ax_bottom_is_0,
        y_max_tick_is_1,
        math_text,
        one_decimal_place,
        percentage,
    )

    # 添加统计显著性标记
    if statistic:
        for index_group, group_data in enumerate(data):
            x_positions = x_positions_all[index_group]
            comparisons = []
            idx = 0
            for i in range(len(group_data)):
                for j in range(i + 1, len(group_data)):
                    if test_method == "external":
                        if p_list is None:
                            raise ValueError("p_list不能为空")
                        p = p_list[index_group][idx]
                        idx += 1
                    else:
                        raise ValueError("多组数据统计测试方法暂时仅支持 external方法")
                    if p <= 0.05:
                        comparisons.append((x_positions[i], x_positions[j], p))
            y_max = ax.get_ylim()[1]
            y_base = y_base or np.max(all_values)
            interval = interval or (y_max - np.max(all_values)) / (len(comparisons) + 1)

            annotate_significance(
                ax,
                comparisons,
                y_base,
                interval,
                line_color=line_color,
                star_offset=interval / 5,
                fontsize=asterisk_fontsize,
                color=asterisk_color,
            )

    return ax

correlation

Functions:

Name Description
plot_correlation_figure

绘制两个数据集之间的相关性图,支持线性回归、置信区间和统计方法(Spearman 或 Pearson)。

plot_correlation_figure

plot_correlation_figure(data1: list[Num] | ndarray, data2: list[Num] | ndarray, ax: Axes | None = None, stats_method: str = 'spearman', ci: bool = False, ci_color: str = 'gray', dots_color: str | list[str] = 'steelblue', dots_size: int | float = 10, line_color: str = 'r', title_name: str = '', title_fontsize: int = 12, title_pad: int = 10, x_label_name: str = '', x_label_fontsize: int = 10, x_tick_fontsize: int = 8, x_tick_rotation: int = 0, x_major_locator: float | None = None, x_max_tick_to_value: float | None = None, x_format: str = 'normal', y_label_name: str = '', y_label_fontsize: int = 10, y_tick_fontsize: int = 8, y_tick_rotation: int = 0, y_major_locator: float | None = None, y_max_tick_to_value: float | None = None, y_format: str = 'sci', asterisk_fontsize: int = 10, show_p_value: bool = False, hexbin: bool = False, hexbin_cmap: LinearSegmentedColormap | None = None, hexbin_gridsize: int = 50, xlim: list[Num] | tuple[Num, Num] | None = None, ylim: list[Num] | tuple[Num, Num] | None = None) -> Axes

绘制两个数据集之间的相关性图,支持线性回归、置信区间和统计方法(Spearman 或 Pearson)。

Parameters:

Name Type Description Default
data1 list[Num] | ndarray

第一个数据集,可以是整数或浮点数列表或数组。

required
data2 list[Num] | ndarray

第二个数据集,可以是整数或浮点数列表或数组。

required
ax Axes | None

matplotlib 的 Axes 对象,用于绘图。默认为 None,使用当前 Axes。

None
stats_method str

相关性统计方法,支持 "spearman" 和 "pearson"。默认为 "spearman"。

'spearman'
ci bool

是否绘制置信区间带。默认为 False。

False
ci_color str

置信区间带颜色。默认为 "salmon"。

'gray'
dots_color str

散点的颜色。默认为 "steelblue"。

'steelblue'
dots_size int | float

散点的大小。默认为 1。

10
line_color str

回归线的颜色。默认为 "r"(红色)。

'r'
title_name str

图形标题。默认为空字符串。

''
title_fontsize int

标题字体大小。默认为 10。

12
title_pad int

标题与图形之间的间距。默认为 10。

10
x_label_name str

X 轴标签名称。默认为空字符串。

''
x_label_fontsize int

X 轴标签字体大小。默认为 10。

10
x_tick_fontsize int

X 轴刻度标签字体大小。默认为 10。

8
x_tick_rotation int

X 轴刻度标签旋转角度。默认为 0。

0
x_major_locator float | None

设置 X 轴主刻度间隔。默认为 None。

None
x_max_tick_to_value float | None

设置 X 轴最大显示刻度值。默认为 None。

None
x_format str

X 轴格式化方式,支持 "normal", "sci", "1f", "percent"。默认为 "normal"。

'normal'
y_label_name str

Y 轴标签名称。默认为空字符串。

''
y_label_fontsize int

Y 轴标签字体大小。默认为 10。

10
y_tick_fontsize int

Y 轴刻度标签字体大小。默认为 10。

8
y_tick_rotation int

Y 轴刻度标签旋转角度。默认为 0。

0
y_major_locator float | None

设置 Y 轴主刻度间隔。默认为 None。

None
y_max_tick_to_value float | None

设置 Y 轴最大显示刻度值。默认为 None。

None
y_format str

Y 轴格式化方式,支持 "normal", "sci", "1f", "percent"。默认为 "normal"。

'sci'
asterisk_fontsize int

显著性星号字体大小。默认为 10。

10
show_p_value bool

是否显示 p 值。默认为 True。

False
hexbin bool

是否使用六边形箱图。默认为 False。

False
hexbin_cmap LinearSegmentedColormap | None

六边形箱图的颜色映射。默认为 None。

None
hexbin_gridsize int

六边形箱图的网格大小。默认为 50。

50
xlim list[Num] | tuple[Num, Num] | None

X 轴范围限制。默认为 None。

None
ylim list[Num] | tuple[Num, Num] | None

Y 轴范围限制。默认为 None。

None

Returns:

Type Description
Axes

None

Source code in src/plotfig/correlation.py
def plot_correlation_figure(
    data1: list[Num] | np.ndarray,
    data2: list[Num] | np.ndarray,
    ax: Axes | None = None,
    stats_method: str = "spearman",
    ci: bool = False,
    ci_color: str = "gray",
    dots_color: str | list[str] = "steelblue",
    dots_size: int | float = 10,
    line_color: str = "r",
    title_name: str = "",
    title_fontsize: int = 12,
    title_pad: int = 10,
    x_label_name: str = "",
    x_label_fontsize: int = 10,
    x_tick_fontsize: int = 8,
    x_tick_rotation: int = 0,
    x_major_locator: float | None = None,
    x_max_tick_to_value: float | None = None,
    x_format: str = "normal",  # 支持 "normal", "sci", "1f", "percent"
    y_label_name: str = "",
    y_label_fontsize: int = 10,
    y_tick_fontsize: int = 8,
    y_tick_rotation: int = 0,
    y_major_locator: float | None = None,
    y_max_tick_to_value: float | None = None,
    y_format: str = "sci",  # 支持 "normal", "sci", "1f", "percent"
    asterisk_fontsize: int = 10,
    show_p_value: bool = False,
    hexbin: bool = False,
    hexbin_cmap: LinearSegmentedColormap | None = None,
    hexbin_gridsize: int = 50,
    xlim: list[Num] | tuple[Num, Num] | None = None,
    ylim: list[Num] | tuple[Num, Num] | None = None,
) -> Axes:
    """
    绘制两个数据集之间的相关性图,支持线性回归、置信区间和统计方法(Spearman 或 Pearson)。

    Args:
        data1 (list[Num] | np.ndarray): 第一个数据集,可以是整数或浮点数列表或数组。
        data2 (list[Num] | np.ndarray): 第二个数据集,可以是整数或浮点数列表或数组。
        ax (plt.Axes | None, optional): matplotlib 的 Axes 对象,用于绘图。默认为 None,使用当前 Axes。
        stats_method (str, optional): 相关性统计方法,支持 "spearman" 和 "pearson"。默认为 "spearman"。
        ci (bool, optional): 是否绘制置信区间带。默认为 False。
        ci_color (str, optional): 置信区间带颜色。默认为 "salmon"。
        dots_color (str, optional): 散点的颜色。默认为 "steelblue"。
        dots_size (int | float, optional): 散点的大小。默认为 1。
        line_color (str, optional): 回归线的颜色。默认为 "r"(红色)。
        title_name (str, optional): 图形标题。默认为空字符串。
        title_fontsize (int, optional): 标题字体大小。默认为 10。
        title_pad (int, optional): 标题与图形之间的间距。默认为 10。
        x_label_name (str, optional): X 轴标签名称。默认为空字符串。
        x_label_fontsize (int, optional): X 轴标签字体大小。默认为 10。
        x_tick_fontsize (int, optional): X 轴刻度标签字体大小。默认为 10。
        x_tick_rotation (int, optional): X 轴刻度标签旋转角度。默认为 0。
        x_major_locator (float | None, optional): 设置 X 轴主刻度间隔。默认为 None。
        x_max_tick_to_value (float | None, optional): 设置 X 轴最大显示刻度值。默认为 None。
        x_format (str, optional): X 轴格式化方式,支持 "normal", "sci", "1f", "percent"。默认为 "normal"。
        y_label_name (str, optional): Y 轴标签名称。默认为空字符串。
        y_label_fontsize (int, optional): Y 轴标签字体大小。默认为 10。
        y_tick_fontsize (int, optional): Y 轴刻度标签字体大小。默认为 10。
        y_tick_rotation (int, optional): Y 轴刻度标签旋转角度。默认为 0。
        y_major_locator (float | None, optional): 设置 Y 轴主刻度间隔。默认为 None。
        y_max_tick_to_value (float | None, optional): 设置 Y 轴最大显示刻度值。默认为 None。
        y_format (str, optional): Y 轴格式化方式,支持 "normal", "sci", "1f", "percent"。默认为 "normal"。
        asterisk_fontsize (int, optional): 显著性星号字体大小。默认为 10。
        show_p_value (bool, optional): 是否显示 p 值。默认为 True。
        hexbin (bool, optional): 是否使用六边形箱图。默认为 False。
        hexbin_cmap (LinearSegmentedColormap | None, optional): 六边形箱图的颜色映射。默认为 None。
        hexbin_gridsize (int, optional): 六边形箱图的网格大小。默认为 50。
        xlim (list[Num] | tuple[Num, Num] | None, optional): X 轴范围限制。默认为 None。
        ylim (list[Num] | tuple[Num, Num] | None, optional): Y 轴范围限制。默认为 None。

    Returns:
        None
    """

    def set_axis(
        ax,
        axis,
        label,
        labelsize,
        ticksize,
        rotation,
        locator,
        max_tick_value,
        fmt,
        lim,
    ):
        if axis == "x":
            set_label = ax.set_xlabel
            get_ticks = ax.get_xticks
            set_ticks = ax.set_xticks
            axis_formatter = ax.xaxis.set_major_formatter
            axis_major_locator = ax.xaxis.set_major_locator
        else:
            set_label = ax.set_ylabel
            get_ticks = ax.get_yticks
            set_ticks = ax.set_yticks
            axis_formatter = ax.yaxis.set_major_formatter
            axis_major_locator = ax.yaxis.set_major_locator

        # 设置轴范围
        if lim is not None:
            if axis == "x":
                ax.set_xlim(lim)
            else:
                ax.set_ylim(lim)

        set_label(label, fontsize=labelsize)
        ax.tick_params(axis=axis, which="major", labelsize=ticksize, rotation=rotation)
        if locator is not None:
            axis_major_locator(MultipleLocator(locator))
        if max_tick_value is not None:
            set_ticks([i for i in get_ticks() if i <= max_tick_value])

        if fmt == "sci":
            formatter = ScalarFormatter(useMathText=True)
            formatter.set_powerlimits((-2, 2))
            axis_formatter(formatter)
        elif fmt == "1f":
            axis_formatter(FormatStrFormatter("%.1f"))
        elif fmt == "percent":
            axis_formatter(FuncFormatter(lambda x, pos: f"{x:.0%}"))

    if ax is None:
        ax = plt.gca()

    A = np.asarray(data1)
    B = np.asarray(data2)

    slope, intercept, r_value, p_value, _ = stats.linregress(A, B)
    x_seq = np.linspace(A.min(), A.max(), 100)
    y_pred = slope * x_seq + intercept

    if hexbin:
        if hexbin_cmap is None:
            hexbin_cmap = LinearSegmentedColormap.from_list(
                "custom", ["#ffffff", "#4573a5"]
            )
        hb = ax.hexbin(A, B, gridsize=hexbin_gridsize, cmap=hexbin_cmap)
    else:
        ax.scatter(A, B, c=dots_color, s=dots_size)
    ax.plot(x_seq, y_pred, line_color, lw=1)

    if ci:
        n = len(A)
        dof = n - 2
        t_val = stats.t.ppf(0.975, dof)
        x_mean = A.mean()
        residuals = B - (slope * A + intercept)
        s_err = np.sqrt(np.sum(residuals**2) / dof)
        SSxx = np.sum((A - x_mean) ** 2)
        conf_interval = t_val * s_err * np.sqrt(1 / n + (x_seq - x_mean) ** 2 / SSxx)
        ax.fill_between(
            x_seq,
            y_pred - conf_interval,
            y_pred + conf_interval,
            color=ci_color,
            alpha=0.3,
        )

    ax.spines[["top", "right"]].set_visible(False)
    ax.set_title(title_name, fontsize=title_fontsize, pad=title_pad)

    set_axis(
        ax,
        "x",
        x_label_name,
        x_label_fontsize,
        x_tick_fontsize,
        x_tick_rotation,
        x_major_locator,
        x_max_tick_to_value,
        x_format,
        xlim,
    )
    set_axis(
        ax,
        "y",
        y_label_name,
        y_label_fontsize,
        y_tick_fontsize,
        y_tick_rotation,
        y_major_locator,
        y_max_tick_to_value,
        y_format,
        ylim,
    )

    # 标注r值或rho值
    if stats_method == "spearman":
        s, p = stats.spearmanr(A, B)
        label = r"$\rho$"
    elif stats_method == "pearson":
        s, p = stats.pearsonr(A, B)
        label = "r"
    else:
        print(f"没有统计方法 {stats_method},请检查拼写。更换为默认的 spearman 方法。")
        s, p = stats.spearmanr(A, B)
        label = r"$\rho$"

    if show_p_value:
        asterisk = f" p={p:.3f}"
    else:
        asterisk = (
            " ***" if p < 0.001 else " **" if p < 0.01 else " *" if p < 0.05 else ""
        )
    x_start, x_end = ax.get_xlim()
    y_start, y_end = ax.get_ylim()
    ax.text(
        x_start + (x_end - x_start) * 0.1,
        y_start + (y_end - y_start) * 0.9,
        f"{label}={s:.3f}{asterisk}",
        va="center",
        fontsize=asterisk_fontsize,
    )
    if hexbin:
        return hb
    return ax

matrix

Functions:

Name Description
plot_matrix_figure

将矩阵绘制为热图,可选显示标签、颜色条和标题。

plot_matrix_figure

plot_matrix_figure(data: ndarray, ax: Axes | None = None, row_labels_name: Sequence[str] | None = None, col_labels_name: Sequence[str] | None = None, cmap: str = 'bwr', vmin: Num | None = None, vmax: Num | None = None, aspect: str = 'equal', colorbar: bool = True, colorbar_label_name: str = '', colorbar_pad: Num = 0.1, colorbar_label_fontsize: Num = 10, colorbar_tick_fontsize: Num = 10, colorbar_tick_rotation: Num = 0, row_labels_fontsize: Num = 10, col_labels_fontsize: Num = 10, x_rotation: Num = 60, title_name: str = '', title_fontsize: Num = 15, title_pad: Num = 20, diag_border: bool = False, xlabel: str | None = None, ylabel: str | None = None, **imshow_kwargs: Any) -> Axes

将矩阵绘制为热图,可选显示标签、颜色条和标题。

Parameters:

Name Type Description Default
data ndarray

形状为 (N, M) 的二维数组,用于显示矩阵。

required
ax Axes | None

要绘图的 Matplotlib 坐标轴。如果为 None,则使用当前坐标轴。

None
row_labels_name Sequence[str] | None

行标签列表。

None
col_labels_name Sequence[str] | None

列标签列表。

None
cmap str

矩阵使用的颜色映射。

'bwr'
vmin Num | None

颜色缩放的最小值,默认使用 data.min()。

None
vmax Num | None

颜色缩放的最大值,默认使用 data.max()。

None
aspect str

图像的纵横比,通常为 "equal" 或 "auto"。

'equal'
colorbar bool

是否显示颜色条。

True
colorbar_label_name str

颜色条的标签。

''
colorbar_pad Num

颜色条与矩阵之间的间距。

0.1
colorbar_label_fontsize Num

颜色条标签的字体大小。

10
colorbar_tick_fontsize Num

颜色条刻度的字体大小。

10
colorbar_tick_rotation Num

颜色条刻度标签的旋转角度。

0
row_labels_fontsize Num

行标签的字体大小。

10
col_labels_fontsize Num

列标签的字体大小。

10
x_rotation Num

x 轴(列)标签的旋转角度。

60
title_name Num

图表标题。

''
title_fontsize Num

标题的字体大小。

15
title_pad Num

标题上方的间距。

20
diag_border bool

是否绘制对角线单元格边框。

False
xlabel str | None

X轴的整体标签名称。

None
ylabel str | None

Y轴的整体标签名称。

None
**imshow_kwargs Any

传递给 imshow() 的其他关键字参数。

{}

Returns:

Name Type Description
Axes Axes

绘图的坐标轴对象。

Source code in src/plotfig/matrix.py
def plot_matrix_figure(
    data: np.ndarray,
    ax: Axes | None = None,
    row_labels_name: Sequence[str] | None = None,
    col_labels_name: Sequence[str] | None = None,
    cmap: str = "bwr",
    vmin: Num | None = None,
    vmax: Num | None = None,
    aspect: str = "equal",
    colorbar: bool = True,
    colorbar_label_name: str = "",
    colorbar_pad: Num = 0.1,
    colorbar_label_fontsize: Num = 10,
    colorbar_tick_fontsize: Num = 10,
    colorbar_tick_rotation: Num = 0,
    row_labels_fontsize: Num = 10,
    col_labels_fontsize: Num = 10,
    x_rotation: Num = 60,
    title_name: str = "",
    title_fontsize: Num = 15,
    title_pad: Num = 20,
    diag_border: bool = False,
    xlabel: str | None = None,
    ylabel: str | None = None,
    **imshow_kwargs: Any,
) -> Axes:
    """
    将矩阵绘制为热图,可选显示标签、颜色条和标题。

    Args:
        data (np.ndarray): 形状为 (N, M) 的二维数组,用于显示矩阵。
        ax (Axes | None): 要绘图的 Matplotlib 坐标轴。如果为 None,则使用当前坐标轴。
        row_labels_name (Sequence[str] | None): 行标签列表。
        col_labels_name (Sequence[str] | None): 列标签列表。
        cmap (str): 矩阵使用的颜色映射。
        vmin (Num | None): 颜色缩放的最小值,默认使用 data.min()。
        vmax (Num | None): 颜色缩放的最大值,默认使用 data.max()。
        aspect (str): 图像的纵横比,通常为 "equal" 或 "auto"。
        colorbar (bool): 是否显示颜色条。
        colorbar_label_name (str): 颜色条的标签。
        colorbar_pad (Num): 颜色条与矩阵之间的间距。
        colorbar_label_fontsize (Num): 颜色条标签的字体大小。
        colorbar_tick_fontsize (Num): 颜色条刻度的字体大小。
        colorbar_tick_rotation (Num): 颜色条刻度标签的旋转角度。
        row_labels_fontsize (Num): 行标签的字体大小。
        col_labels_fontsize (Num): 列标签的字体大小。
        x_rotation (Num): x 轴(列)标签的旋转角度。
        title_name (Num): 图表标题。
        title_fontsize (Num): 标题的字体大小。
        title_pad (Num): 标题上方的间距。
        diag_border (bool): 是否绘制对角线单元格边框。
        xlabel (str | None): X轴的整体标签名称。
        ylabel (str | None): Y轴的整体标签名称。
        **imshow_kwargs (Any): 传递给 `imshow()` 的其他关键字参数。

    Returns:
        Axes: 绘图的坐标轴对象。
    """

    ax = ax or plt.gca()
    vmin = vmin if vmin is not None else np.min(data)
    vmax = vmax if vmax is not None else np.max(data)

    im = ax.imshow(
        data, cmap=cmap, vmin=vmin, vmax=vmax, aspect=aspect, **imshow_kwargs
    )
    ax.set_title(title_name, fontsize=title_fontsize, pad=title_pad)

    # 设置X轴和Y轴标签
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    if diag_border:
        for i in range(data.shape[0]):
            ax.add_patch(
                plt.Rectangle(
                    (i - 0.5, i - 0.5), 1, 1, fill=False, edgecolor="black", lw=0.5
                )
            )

    if col_labels_name is not None:
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_xticklabels(
            col_labels_name,
            fontsize=col_labels_fontsize,
            rotation=x_rotation,
            ha="right",
            rotation_mode="anchor",
        )

    if row_labels_name is not None:
        ax.set_yticks(np.arange(data.shape[0]))
        ax.set_yticklabels(row_labels_name, fontsize=row_labels_fontsize)

    if colorbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=colorbar_pad)
        cbar = ax.figure.colorbar(im, cax=cax)
        cbar.ax.set_ylabel(
            colorbar_label_name,
            rotation=-90,
            va="bottom",
            fontsize=colorbar_label_fontsize,
        )
        cbar.ax.tick_params(
            labelsize=colorbar_tick_fontsize, rotation=colorbar_tick_rotation
        )
        # Match colorbar height to the main plot
        ax_pos = ax.get_position()
        cax.set_position(
            [cax.get_position().x0, ax_pos.y0, cax.get_position().width, ax_pos.height]
        )

    return ax

brain_surface

Functions:

Name Description
plot_brain_surface_figure

在大脑皮层表面绘制数值数据的函数。

plot_brain_surface_figure

plot_brain_surface_figure(data: Mapping[str, Num], species: str = 'human', atlas: str = 'glasser', surf: str = 'veryinflated', ax: Axes | None = None, vmin: Num | None = None, vmax: Num | None = None, cmap: str = 'viridis', colorbar: bool = True, colorbar_location: str = 'right', colorbar_label_name: str = '', colorbar_label_rotation: int = 0, colorbar_decimals: int = 1, colorbar_fontsize: int = 8, colorbar_nticks: int = 2, colorbar_shrink: float = 0.15, colorbar_aspect: int = 8, colorbar_draw_border: bool = False, title_name: str = '', title_fontsize: int = 12, as_outline: bool = False) -> Axes

在大脑皮层表面绘制数值数据的函数。

Parameters:

Name Type Description Default
data dict[str, float]

包含脑区名称和对应数值的字典,键为脑区名称(如"lh_bankssts"),值为数值

required
species str

物种名称,支持"human"、"chimpanzee"、"macaque". Defaults to "human".

'human'
atlas str

脑图集名称,根据物种不同可选不同图集。人上包括"glasser"、"bna",黑猩猩上包括"bna",猕猴上包括"charm5"、"charm6"、"bna"以及"d99". Defaults to "glasser".

'glasser'
surf str

大脑皮层表面类型,如"inflated"、"veryinflated"、"midthickness"等. Defaults to "veryinflated".

'veryinflated'
ax Axes | None

matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.

None
vmin Num | None

颜色映射的最小值,None表示使用数据中的最小值. Defaults to None.

None
vmax Num | None

颜色映射的最大值,None表示使用数据中的最大值. Defaults to None.

None
cmap str

颜色映射方案,如"viridis"、"Blues"、"Reds"等. Defaults to "viridis".

'viridis'
colorbar bool

是否显示颜色条. Defaults to True.

True
colorbar_location str

颜色条位置,可选"left"、"right"、"top"、"bottom". Defaults to "right".

'right'
colorbar_label_name str

颜色条标签名称. Defaults to "".

''
colorbar_label_rotation int

颜色条标签旋转角度. Defaults to 0.

0
colorbar_decimals int

颜色条刻度标签的小数位数. Defaults to 1.

1
colorbar_fontsize int

颜色条字体大小. Defaults to 8.

8
colorbar_nticks int

颜色条刻度数量. Defaults to 2.

2
colorbar_shrink float

颜色条收缩比例. Defaults to 0.15.

0.15
colorbar_aspect int

颜色条宽高比. Defaults to 8.

8
colorbar_draw_border bool

是否绘制颜色条边框. Defaults to False.

False
title_name str

图形标题. Defaults to "".

''
title_fontsize int

标题字体大小. Defaults to 12.

12
as_outline bool

是否以轮廓线形式显示. Defaults to False.

False

Raises:

Type Description
ValueError

当指定的物种不支持时抛出

ValueError

当指定的图集不支持时抛出

ValueError

当数据为空时抛出

ValueError

当vmin大于vmax时抛出

Returns:

Name Type Description
Axes Axes

包含绘制图像的matplotlib坐标轴对象

Source code in src/plotfig/brain_surface.py
def plot_brain_surface_figure(
    data: Mapping[str, Num],
    species: str = "human",
    atlas: str = "glasser",
    surf: str = "veryinflated",
    ax: Axes | None = None,
    vmin: Num | None = None,
    vmax: Num | None = None,
    cmap: str = "viridis",
    colorbar: bool = True,
    colorbar_location: str = "right",
    colorbar_label_name: str = "",
    colorbar_label_rotation: int = 0,
    colorbar_decimals: int = 1,
    colorbar_fontsize: int = 8,
    colorbar_nticks: int = 2,
    colorbar_shrink: float = 0.15,
    colorbar_aspect: int = 8,
    colorbar_draw_border: bool = False,
    title_name: str = "",
    title_fontsize: int = 12,
    as_outline: bool = False,
) -> Axes:
    """在大脑皮层表面绘制数值数据的函数。

    Args:
        data (dict[str, float]): 包含脑区名称和对应数值的字典,键为脑区名称(如"lh_bankssts"),值为数值
        species (str, optional): 物种名称,支持"human"、"chimpanzee"、"macaque". Defaults to "human".
        atlas (str, optional): 脑图集名称,根据物种不同可选不同图集。人上包括"glasser"、"bna",黑猩猩上包括"bna",猕猴上包括"charm5"、"charm6"、"bna"以及"d99". Defaults to "glasser".
        surf (str, optional): 大脑皮层表面类型,如"inflated"、"veryinflated"、"midthickness"等. Defaults to "veryinflated".
        ax (Axes | None, optional): matplotlib的坐标轴对象,如果为None则使用当前坐标轴. Defaults to None.
        vmin (Num | None, optional): 颜色映射的最小值,None表示使用数据中的最小值. Defaults to None.
        vmax (Num | None, optional): 颜色映射的最大值,None表示使用数据中的最大值. Defaults to None.
        cmap (str, optional): 颜色映射方案,如"viridis"、"Blues"、"Reds"等. Defaults to "viridis".
        colorbar (bool, optional): 是否显示颜色条. Defaults to True.
        colorbar_location (str, optional): 颜色条位置,可选"left"、"right"、"top"、"bottom". Defaults to "right".
        colorbar_label_name (str, optional): 颜色条标签名称. Defaults to "".
        colorbar_label_rotation (int, optional): 颜色条标签旋转角度. Defaults to 0.
        colorbar_decimals (int, optional): 颜色条刻度标签的小数位数. Defaults to 1.
        colorbar_fontsize (int, optional): 颜色条字体大小. Defaults to 8.
        colorbar_nticks (int, optional): 颜色条刻度数量. Defaults to 2.
        colorbar_shrink (float, optional): 颜色条收缩比例. Defaults to 0.15.
        colorbar_aspect (int, optional): 颜色条宽高比. Defaults to 8.
        colorbar_draw_border (bool, optional): 是否绘制颜色条边框. Defaults to False.
        title_name (str, optional): 图形标题. Defaults to "".
        title_fontsize (int, optional): 标题字体大小. Defaults to 12.
        as_outline (bool, optional): 是否以轮廓线形式显示. Defaults to False.

    Raises:
        ValueError: 当指定的物种不支持时抛出
        ValueError: 当指定的图集不支持时抛出
        ValueError: 当数据为空时抛出
        ValueError: 当vmin大于vmax时抛出

    Returns:
        Axes: 包含绘制图像的matplotlib坐标轴对象
    """
    # 获取或创建坐标轴对象
    ax = ax or plt.gca()

    # 提取所有数值用于确定vmin和vmax
    values = list(data.values())
    if not values:
        raise ValueError("data 不能为空")
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    if vmin == vmax:
        vmin, vmax = min(0, vmin), max(0, vmax)
    if vmin > vmax:
        raise ValueError("vmin必须小于等于vmax")

    # 设置数据文件路径
    # 定义不同物种、表面类型和图集的文件路径信息
    atlas_info = {
        "human": {
            "surf": {
                "lh": f"surfaces/human_fsLR/tpl-fsLR_den-32k_hemi-L_{surf}.surf.gii",
                "rh": f"surfaces/human_fsLR/tpl-fsLR_den-32k_hemi-R_{surf}.surf.gii",
            },
            "atlas": {
                "glasser": {
                    "lh": "atlases/human_Glasser/fsaverage.L.Glasser.32k_fs_LR.label.gii",
                    "rh": "atlases/human_Glasser/fsaverage.R.Glasser.32k_fs_LR.label.gii",
                },
                "bna": {
                    "lh": "atlases/human_BNA/fsaverage.L.BNA.32k_fs_LR.label.gii",
                    "rh": "atlases/human_BNA/fsaverage.R.BNA.32k_fs_LR.label.gii",
                },
            },
            "sulc": {
                "lh": "surfaces/human_fsLR/100206.L.sulc.32k_fs_LR.shape.gii",
                "rh": "surfaces/human_fsLR/100206.R.sulc.32k_fs_LR.shape.gii",
            },
        },
        "chimpanzee": {
            "surf": {
                "lh": f"surfaces/chimpanzee_BNA/ChimpYerkes29_v1.2.L.{surf}.32k_fs_LR.surf.gii",
                "rh": f"surfaces/chimpanzee_BNA/ChimpYerkes29_v1.2.R.{surf}.32k_fs_LR.surf.gii",
            },
            "atlas": {
                "bna": {
                    "lh": "atlases/chimpanzee_BNA/ChimpBNA.L.32k_fs_LR.label.gii",
                    "rh": "atlases/chimpanzee_BNA/ChimpBNA.R.32k_fs_LR.label.gii",
                },
            },
        },
        "macaque": {
            "surf": {
                "lh": f"surfaces/macaque_BNA/civm.L.{surf}.32k_fs_LR.surf.gii",
                "rh": f"surfaces/macaque_BNA/civm.R.{surf}.32k_fs_LR.surf.gii",
            },
            "atlas": {
                "charm5": {
                    "lh": "atlases/macaque_CHARM5/L.charm5.label.gii",
                    "rh": "atlases/macaque_CHARM5/R.charm5.label.gii",
                },
                "charm6": {
                    "lh": "atlases/macaque_CHARM6/L.charm6.label.gii",
                    "rh": "atlases/macaque_CHARM6/R.charm6.label.gii",
                },
                "bna": {
                    "lh": "atlases/macaque_BNA/MBNA_124_32k_L.label.gii",
                    "rh": "atlases/macaque_BNA/MBNA_124_32k_R.label.gii",
                },
                "d99": {
                    "lh": "atlases/macaque_D99/L.d99.label.gii",
                    "rh": "atlases/macaque_D99/R.d99.label.gii",
                },
            },
            "sulc": {
                "lh": "surfaces/macaque_BNA/SC_06018.L.sulc.32k_fs_LR.shape.gii",
                "rh": "surfaces/macaque_BNA/SC_06018.R.sulc.32k_fs_LR.shape.gii",
            },
        },
    }

    # 检查物种是否支持
    if species not in atlas_info:
        raise ValueError(
            f"不支持的物种:{species}。支持的物种列表为:{list(atlas_info.keys())}"
        )
    else:
        # 检查指定物种的图集是否支持
        if atlas not in atlas_info[species]["atlas"]:
            raise ValueError(f"不支持的图集:{atlas}。支持的图集列表为:{list(atlas_info[species]['atlas'].keys())}")

    # 创建Plot对象,用于绘制大脑皮层
    if surf != "flat":
        p = Plot(
            NEURODATA / atlas_info[species]["surf"]["lh"],
            NEURODATA / atlas_info[species]["surf"]["rh"],
        )
    else:
        # NOTE: 目前只有人和猕猴具有flat surface,暂时不支持黑猩猩
        p = Plot(
            NEURODATA / atlas_info[species]["surf"]["lh"],
            NEURODATA / atlas_info[species]["surf"]["rh"],
            views="dorsal",
            zoom=1.2,
        )
        lh_sulc_file = NEURODATA / atlas_info[species]["sulc"]["lh"]
        rh_sulc_file = NEURODATA / atlas_info[species]["sulc"]["rh"]
        p.add_layer(
            {
                "left": nib.load(lh_sulc_file).darrays[0].data,
                "right": nib.load(rh_sulc_file).darrays[0].data,
            },
            cmap="Grays_r",
            cbar=False,
        )

    # 分离左半球和右半球的数据
    hemisphere_data = {}
    for hemi in ["lh", "rh"]:
        hemi_data = {k: v for k, v in data.items() if k.startswith(f"{hemi}_")}
        hemi_parc = _map_labels_to_values(
            hemi_data, NEURODATA / atlas_info[species]["atlas"][atlas][hemi]
        )
        hemisphere_data[hemi] = hemi_parc

    # 画图
    # colorbar参数设置(用列表统一管理,便于维护)
    colorbar_params = [
        ("location", colorbar_location),
        ("label_direction", colorbar_label_rotation),
        ("decimals", colorbar_decimals),
        ("fontsize", colorbar_fontsize),
        ("n_ticks", colorbar_nticks),
        ("shrink", colorbar_shrink),
        ("aspect", colorbar_aspect),
        ("draw_border", colorbar_draw_border),
    ]
    colorbar_kws = {k: v for k, v in colorbar_params}
    # 添加图层到绘图对象
    p.add_layer(
        {"left": hemisphere_data["lh"], "right": hemisphere_data["rh"]},
        cbar=colorbar,
        cmap=cmap,
        color_range=(vmin, vmax),
        cbar_label=colorbar_label_name,
        zero_transparent=False,
        as_outline=as_outline,
    )
    # 构建坐标轴并应用颜色条设置
    ax = p.build_axis(ax=ax, cbar_kws=colorbar_kws)
    # 设置图形标题
    ax.set_title(title_name, fontsize=title_fontsize)

    return ax

circos

Functions:

Name Description
plot_circos_figure

绘制脑连接组的环形图(Circos plot)。

plot_circos_figure

plot_circos_figure(connectome: NDArray, ax: Axes | None = None, symmetric: bool = True, node_names: list[str] | None = None, node_colors: list[str] | None = None, node_space: float = 0.0, node_label_fontsize: int = 10, node_label_orientation: Literal['vertical', 'horizontal'] = 'horizontal', vmin: float | None = None, vmax: float | None = None, cmap: str | None = None, edge_color: str = 'red', edge_alpha: float = 1.0, colorbar: bool = True, colorbar_orientation: Literal['vertical', 'horizontal'] = 'vertical', colorbar_label: str = '') -> Figure | PolarAxes

绘制脑连接组的环形图(Circos plot)。

Parameters:

Name Type Description Default
connectome NDArray

脑连接矩阵,必须为对称方阵。形状为(n, n),其中n为脑区数量

required
ax Axes | None

matplotlib的极坐标轴对象,如果提供则在此轴上绘图。默认为None

None
symmetric bool

是否为对称布局(用于左右脑半球数据)。默认为True

True
node_names list[str] | None

脑区名称列表,长度应与connectome的维度一致。默认为None时自动生成"Node_1", "Node_2"...格式的名称

None
node_colors list[str] | None

脑区颜色列表,长度应与脑区数量一致。默认为None时自动生成颜色

None
node_space float

脑区间间隔角度(度)。默认为0.0

0.0
node_label_fontsize int

脑区标签字体大小。默认为10

10
node_label_orientation Literal['vertical', 'horizontal']

脑区标签方向。默认为"horizontal"

'horizontal'
vmin float | None

连接强度颜色映射的最小值。默认为None时根据数据自动确定

None
vmax float | None

连接强度颜色映射的最大值。默认为None时根据数据自动确定

None
cmap str | None

颜色映射表名称。默认为None时根据edge_color生成

None
edge_color str

连线颜色,当cmap为None时使用此颜色生成颜色映射。默认为"red"

'red'
edge_alpha float

连线透明度,范围0-1。默认为1.0(不透明)

1.0
colorbar bool

是否显示颜色条。默认为True

True
colorbar_orientation Literal['vertical', 'horizontal']

颜色条方向。默认为"vertical"

'vertical'
colorbar_label str

颜色条标签文本。默认为空字符串

''

Raises:

Type Description
ValueError

当connectome不是对称矩阵时抛出

ValueError

当vmin大于vmax时抛出

TypeError

当提供的ax不是PolarAxes类型时抛出

Returns:

Type Description
Figure | PolarAxes

Figure | Axes: 如果ax为None则返回Figure对象,否则返回Axes对象

Source code in src/plotfig/circos.py
def plot_circos_figure(
    connectome: NDArray,
    ax: Axes | None = None,
    symmetric: bool = True,
    node_names: list[str] | None = None,
    node_colors: list[str] | None = None,
    node_space: float = 0.0,
    node_label_fontsize: int = 10,
    node_label_orientation: Literal["vertical", "horizontal"] = "horizontal",
    vmin: float | None = None,
    vmax: float | None = None,
    cmap: str | None = None,
    edge_color: str = "red",
    edge_alpha: float = 1.0,
    colorbar: bool = True,
    colorbar_orientation: Literal["vertical", "horizontal"] = "vertical",
    colorbar_label: str = "",
) -> Figure | PolarAxes:
    """绘制脑连接组的环形图(Circos plot)。

    Args:
        connectome (NDArray): 脑连接矩阵,必须为对称方阵。形状为(n, n),其中n为脑区数量
        ax (Axes | None, optional): matplotlib的极坐标轴对象,如果提供则在此轴上绘图。默认为None
        symmetric (bool, optional): 是否为对称布局(用于左右脑半球数据)。默认为True
        node_names (list[str] | None, optional): 脑区名称列表,长度应与connectome的维度一致。默认为None时自动生成"Node_1", "Node_2"...格式的名称
        node_colors (list[str] | None, optional): 脑区颜色列表,长度应与脑区数量一致。默认为None时自动生成颜色
        node_space (float, optional): 脑区间间隔角度(度)。默认为0.0
        node_label_fontsize (int, optional): 脑区标签字体大小。默认为10
        node_label_orientation (Literal["vertical", "horizontal"], optional): 脑区标签方向。默认为"horizontal"
        vmin (float | None, optional): 连接强度颜色映射的最小值。默认为None时根据数据自动确定
        vmax (float | None, optional): 连接强度颜色映射的最大值。默认为None时根据数据自动确定
        cmap (str | None, optional): 颜色映射表名称。默认为None时根据edge_color生成
        edge_color (str, optional): 连线颜色,当cmap为None时使用此颜色生成颜色映射。默认为"red"
        edge_alpha (float, optional): 连线透明度,范围0-1。默认为1.0(不透明)
        colorbar (bool, optional): 是否显示颜色条。默认为True
        colorbar_orientation (Literal["vertical", "horizontal"], optional): 颜色条方向。默认为"vertical"
        colorbar_label (str, optional): 颜色条标签文本。默认为空字符串

    Raises:
        ValueError: 当connectome不是对称矩阵时抛出
        ValueError: 当vmin大于vmax时抛出
        TypeError: 当提供的ax不是PolarAxes类型时抛出

    Returns:
        Figure | Axes: 如果ax为None则返回Figure对象,否则返回Axes对象
    """

    # 检查输入矩阵,指定cmap
    if not is_symmetric_square(connectome):
        raise ValueError("connectome 不是对称矩阵")
    if np.all(connectome == 0):
        logger.warning("connectome 矩阵所有元素均为0,可能没有有效连接数据")
        vmax = float(0 if vmax is None else vmax)
        vmin = float(0 if vmin is None else vmin)
        colormap = (
            gen_white_to_color_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
        )
    elif np.any(connectome < 0):
        logger.warning(
            "由于 connectome 存在负值,连线颜色无法自定义,只能正值显示红色,负值显示蓝色"
        )
        max_strength = np.abs(connectome[connectome != 0]).max()
        vmax = float(max_strength if vmax is None else vmax)
        vmin = float(-max_strength if vmin is None else vmin)
        colormap = plt.get_cmap("bwr")
    else:
        vmin = float(connectome.min() if vmin is None else vmin)
        vmax = float(connectome.max() if vmax is None else vmax)
        colormap = (
            gen_white_to_color_cmap(edge_color) if cmap is None else plt.get_cmap(cmap)
        )
    if vmin > vmax:
        raise ValueError(f"目前{vmin=},而{vmax=}。但是vmin不得大于vmax,请检查数据")
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # 获取数据信息
    node_num = connectome.shape[0]
    node_names = _gen_node_name(connectome) if node_names is None else node_names

    # 由于pycirclize的特性,sector顺序只能为顺时针,因此需要将数据进行翻转
    connectome = np.flip(connectome)
    node_names = node_names[::-1]
    if symmetric:
        # 画对称图需额外做半球翻转处理
        node_colors = (
            gen_hex_colors(int(node_num / 2)) * 2
            if node_colors is None
            else node_colors[::-1]
        )
        connectome, node_names, node_colors = _process_sym(
            connectome, node_names, node_colors
        )
        sectors = _gen_sym_sectors(node_names)
    else:
        node_colors = (
            gen_hex_colors(node_num) if node_colors is None else node_colors[::-1]
        )
        sectors = {node_name: 1 for node_name in node_names}

    edges = _gen_edges(connectome)
    name2color = {
        node_name: node_color for node_name, node_color in zip(node_names, node_colors)
    }
    circos = Circos(sectors, space=node_space)

    # 设置扇区
    for sector in circos.sectors:
        if sector.name.startswith("_gap"):
            continue
        sector.text(
            sector.name, size=node_label_fontsize, orientation=node_label_orientation
        )
        track = sector.add_track((95, 100))
        track.axis(fc=name2color[sector.name])

    # 设置连接
    for edge in edges:
        color = value_to_hex(edge[2], colormap, norm)
        circos.link(
            (node_names[edge[0]], 0.45, 0.55),
            (node_names[edge[1]], 0.55, 0.45),
            color=color,
            alpha=edge_alpha,
        )

    # colorbar
    if colorbar:
        if colorbar_orientation == "vertical":
            orientation = "vertical"
            bounds = (1.1, 0.29, 0.02, 0.4)
            label_kws = dict(size=12, rotation=270, labelpad=20)
        else:
            orientation = "horizontal"
            bounds = (0.3, -0.1, 0.4, 0.03)
            label_kws = dict(size=12)
        circos.colorbar(
            bounds=bounds,
            orientation=orientation,
            vmin=vmin,
            vmax=vmax,
            cmap=colormap,
            label=colorbar_label,
            label_kws=label_kws,
            tick_kws=dict(labelsize=12),
        )

    # 画图
    if ax is None:
        fig = circos.plotfig()
        return fig
    else:
        if isinstance(ax, PolarAxes):
            circos.plotfig(ax=ax)
            return ax
        else:
            raise ValueError("ax 不是 PolarAxes 类型")

brain_connection

Functions:

Name Description
batch_crop_images

批量裁剪指定目录下的图像文件。

create_gif_from_images

从指定文件夹中的图片生成 GIF 文件。

plot_brain_connection_figure

绘制大脑连接图,保存在指定的html文件中。

save_brain_connection_frames

生成不同角度的静态图片帧,可用于制作旋转大脑连接图的 GIF。

batch_crop_images

batch_crop_images(directory_path: Path, suffix: str = '_cropped', left_frac: float = 0.25, top_frac: float = 0.25, right_frac: float = 0.25, bottom_frac: float = 0.25)

批量裁剪指定目录下的图像文件。

Parameters:

Name Type Description Default
directory_path Path

图像文件所在的目录路径。

required
suffix str

新文件名后缀。默认为 "_cropped"。

'_cropped'
left_frac float

左侧裁剪比例(0-1)。默认为 0.2。

0.25
top_frac float

上侧裁剪比例(0-1)。默认为 0.15。

0.25
right_frac float

右侧裁剪比例(0-1)。默认为 0.2。

0.25
bottom_frac float

下侧裁剪比例(0-1)。默认为 0.15。

0.25
Notes

支持常见图像格式 (PNG, JPG, JPEG, BMP, TIFF)。 裁剪后的图像将保存在原目录中,并添加指定后缀,原图像保持不变。 所有裁剪均基于图像尺寸的百分比计算,无绝对像素值。

Source code in src/plotfig/brain_connection.py
def batch_crop_images(
    directory_path: Path,
    suffix: str = "_cropped",
    left_frac: float = 0.25,
    top_frac: float = 0.25,
    right_frac: float = 0.25,
    bottom_frac: float = 0.25,
):
    """
    批量裁剪指定目录下的图像文件。

    Args:
        directory_path (Path): 图像文件所在的目录路径。
        suffix (str, optional): 新文件名后缀。默认为 "_cropped"。
        left_frac (float, optional): 左侧裁剪比例(0-1)。默认为 0.2。
        top_frac (float, optional): 上侧裁剪比例(0-1)。默认为 0.15。
        right_frac (float, optional): 右侧裁剪比例(0-1)。默认为 0.2。
        bottom_frac (float, optional): 下侧裁剪比例(0-1)。默认为 0.15。

    Notes:
        支持常见图像格式 (PNG, JPG, JPEG, BMP, TIFF)。
        裁剪后的图像将保存在原目录中,并添加指定后缀,原图像保持不变。
        所有裁剪均基于图像尺寸的百分比计算,无绝对像素值。
    """
    supported_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"}

    for image_path in directory_path.rglob("*"):
        if (
            not image_path.is_file()
            or image_path.suffix.lower() not in supported_extensions
        ):
            continue

        new_file_name = image_path.stem + suffix + image_path.suffix

        try:
            figure = Image.open(image_path)
            width, height = figure.size
            print(f"图像宽度:{width},高度:{height}")

            left = int(width * left_frac)
            right = int(width * (1 - right_frac))
            top = int(height * top_frac)
            bottom = int(height * (1 - bottom_frac))

            # 裁切图像
            cropped_fig = figure.crop((left, top, right, bottom))
            # 保存裁切后的图像
            cropped_fig.save(image_path.parent / new_file_name)

            figure.close()
            cropped_fig.close()
        except Exception as e:
            print(f"处理文件 {image_path.name} 时出错: {e}")

create_gif_from_images

create_gif_from_images(folder_path: str | Path, output_name: str = 'output.gif', fps: int = 10) -> None

从指定文件夹中的图片生成 GIF 文件。

Parameters:

Name Type Description Default
folder_path str | Path

图片所在文件夹路径

required
output_name str

输出 GIF 文件名,默认为 "output.gif"

'output.gif'
fps int

GIF 帧率,默认为 10

10
Source code in src/plotfig/brain_connection.py
def create_gif_from_images(
    folder_path: str | Path,
    output_name: str = "output.gif",
    fps: int = 10,
) -> None:
    """
    从指定文件夹中的图片生成 GIF 文件。

    Args:
        folder_path (str | Path): 图片所在文件夹路径
        output_name (str, optional): 输出 GIF 文件名,默认为 "output.gif"
        fps (int, optional): GIF 帧率,默认为 10
    """
    folder = Path(folder_path)
    if not folder.exists() or not folder.is_dir():
        raise ValueError(f"{folder} 不是有效的文件夹路径。")

    # 获取文件夹下指定扩展名的文件,并排序
    extensions = (".png", ".jpg", ".jpeg")
    figures_path = sorted(
        [f for f in folder.iterdir() if f.suffix.lower() in extensions]
    )

    if not figures_path:
        raise ValueError(f"文件夹 {folder} 中没有找到符合 {extensions} 的图片文件。")

    figures = [Image.open(f) for f in figures_path]

    # 输出 GIF 路径
    output_path = folder / output_name

    # 创建 GIF
    with imageio.get_writer(output_path, mode="I", fps=fps, loop=0) as writer:
        for figure in figures:
            writer.append_data(figure.convert("RGB"))

    logger.info(f"GIF 已保存到: {output_path}")

plot_brain_connection_figure

plot_brain_connection_figure(connectome: NDArray, lh_surfgii_file: str | Path, rh_surfgii_file: str | Path, niigz_file: str | Path, output_file: str | Path | None = None, show_all_nodes: bool = False, nodes_size: Sequence[Num] | NDArray | None = None, nodes_name: list[str] | None = None, nodes_color: list[str] | None = None, scale_method: Literal['', 'width', 'color', 'width_color', 'color_width'] = '', line_width: Num = 10, line_color: str = 'red') -> Figure

绘制大脑连接图,保存在指定的html文件中。

Parameters:

Name Type Description Default
connectome NDArray

大脑连接矩阵,形状为 (n, n),其中 n 是脑区数量。 矩阵中的值表示脑区之间的连接强度,正值表示正相关连接,负值表示负相关连接,0表示无连接。

required
lh_surfgii_file str | Path

左半脑表面几何文件路径 (.surf.gii 格式),用于绘制左半脑表面

required
rh_surfgii_file str | Path

右半脑表面几何文件路径 (.surf.gii 格式),用于绘制右半脑表面

required
niigz_file str | Path

NIfTI格式的脑区图谱文件路径 (.nii.gz 格式),用于定位脑区节点的三维坐标

required
output_file str | Path | None

输出HTML文件路径。如果未指定,则使用当前时间戳生成文件名。默认为None

None
show_all_nodes bool

是否显示所有脑区节点。如果为False,则只显示有连接的节点。默认为False

False
nodes_size Sequence[Num] | NDArray | None

每个节点的大小,长度应与脑区数量一致。默认为None,即所有节点大小为5

None
nodes_name list[str] | None

每个节点的名称标签,长度应与脑区数量一致。默认为None,即不显示名称

None
nodes_color list[str] | None

每个节点的颜色,长度应与脑区数量一致。默认为None,即所有节点为白色

None
scale_method Literal['', 'width', 'color', 'width_color', 'color_width']

连接线的缩放方法: - "" : 所有连接线宽度和颜色固定 - "width" : 根据连接强度调整线宽,正连接为红色,负连接为蓝色 - "color" : 根据连接强度调整颜色(使用蓝白红颜色映射),线宽固定 - "width_color" or "color_width" : 同时根据连接强度调整线宽和颜色 默认为 ""

''
line_width Num

连接线的基本宽度。当scale_method包含"width"时,此值作为最大宽度参考。默认为10

10
line_color str

连接线的基本颜色。当scale_method不包含"color"时生效。默认为"#ff0000"(红色)

'red'

Returns:

Type Description
Figure

go.Figure: Plotly图形对象,包含绘制的大脑连接图

Source code in src/plotfig/brain_connection.py
def plot_brain_connection_figure(
    connectome: npt.NDArray,
    lh_surfgii_file: str | Path,
    rh_surfgii_file: str | Path,
    niigz_file: str | Path,
    output_file: str | Path | None = None,
    show_all_nodes: bool = False,
    nodes_size: Sequence[Num] | npt.NDArray | None = None,
    nodes_name: list[str] | None = None,
    nodes_color: list[str] | None = None,
    scale_method: Literal["", "width", "color", "width_color", "color_width"] = "",
    line_width: Num = 10,
    line_color: str = "red",
) -> go.Figure:
    """绘制大脑连接图,保存在指定的html文件中。

    Args:
        connectome (npt.NDArray):
            大脑连接矩阵,形状为 (n, n),其中 n 是脑区数量。
            矩阵中的值表示脑区之间的连接强度,正值表示正相关连接,负值表示负相关连接,0表示无连接。
        lh_surfgii_file (str | Path):
            左半脑表面几何文件路径 (.surf.gii 格式),用于绘制左半脑表面
        rh_surfgii_file (str | Path):
            右半脑表面几何文件路径 (.surf.gii 格式),用于绘制右半脑表面
        niigz_file (str | Path):
            NIfTI格式的脑区图谱文件路径 (.nii.gz 格式),用于定位脑区节点的三维坐标
        output_file (str | Path | None, optional):
            输出HTML文件路径。如果未指定,则使用当前时间戳生成文件名。默认为None
        show_all_nodes (bool, optional):
            是否显示所有脑区节点。如果为False,则只显示有连接的节点。默认为False
        nodes_size (Sequence[Num] | npt.NDArray | None, optional):
            每个节点的大小,长度应与脑区数量一致。默认为None,即所有节点大小为5
        nodes_name (list[str] | None, optional):
            每个节点的名称标签,长度应与脑区数量一致。默认为None,即不显示名称
        nodes_color (list[str] | None, optional):
            每个节点的颜色,长度应与脑区数量一致。默认为None,即所有节点为白色
        scale_method (Literal["", "width", "color", "width_color", "color_width"], optional):
            连接线的缩放方法:
            - "" : 所有连接线宽度和颜色固定
            - "width" : 根据连接强度调整线宽,正连接为红色,负连接为蓝色
            - "color" : 根据连接强度调整颜色(使用蓝白红颜色映射),线宽固定
            - "width_color" or "color_width" : 同时根据连接强度调整线宽和颜色
            默认为 ""
        line_width (Num, optional):
            连接线的基本宽度。当scale_method包含"width"时,此值作为最大宽度参考。默认为10
        line_color (str, optional):
            连接线的基本颜色。当scale_method不包含"color"时生效。默认为"#ff0000"(红色)

    Returns:
        go.Figure: Plotly图形对象,包含绘制的大脑连接图
    """
    _validate_connectome(connectome)

    if np.any(connectome < 0):
        logger.warning(
            "由于 connectome 存在负值,连线颜色无法自定义,只能正值显示红色,负值显示蓝色"
        )
        line_color = "#ff0000"

    nodes_num = connectome.shape[0]
    nodes_name = nodes_name or [""] * nodes_num
    nodes_color = nodes_color or ["white"] * nodes_num
    nodes_size = nodes_size or [5] * nodes_num

    if output_file is None:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_file = Path(f"{timestamp}.html")
        logger.info(f"未指定保存路径,默认保存在当前文件夹下的{output_file}中。")

    node_indices = _get_node_indices(connectome, show_all_nodes)
    vertices_L, faces_L = _load_surface(lh_surfgii_file)
    vertices_R, faces_R = _load_surface(rh_surfgii_file)

    mesh_L = _create_mesh(vertices_L, faces_L, "Left Hemisphere")
    mesh_R = _create_mesh(vertices_R, faces_R, "Right Hemisphere")

    fig = go.Figure(data=[mesh_L, mesh_R])

    centroids_real = _get_centroids_real(niigz_file)
    _add_nodes_to_fig(
        fig, centroids_real, node_indices, nodes_name, nodes_size, nodes_color
    )
    _add_edges_to_fig(
        fig,
        connectome,
        centroids_real,
        nodes_name,
        scale_method,
        line_width,
        line_color,
    )
    _finalize_figure(fig)

    fig.write_html(output_file)
    return fig

save_brain_connection_frames

save_brain_connection_frames(fig: Figure, output_dir: str | Path, n_frames: int = 36) -> None

生成不同角度的静态图片帧,可用于制作旋转大脑连接图的 GIF。

Parameters:

Name Type Description Default
fig Figure

Plotly 的 Figure 对象,包含大脑表面和连接图。

required
output_dir str

图片保存的文件夹路径,若文件夹不存在则自动创建。

required
n_frames int

旋转帧的数量。默认 36,即每 10 度一帧。

36
Source code in src/plotfig/brain_connection.py
def save_brain_connection_frames(
    fig: go.Figure, output_dir: str | Path, n_frames: int = 36
) -> None:
    """
    生成不同角度的静态图片帧,可用于制作旋转大脑连接图的 GIF。

    Args:
        fig (go.Figure): Plotly 的 Figure 对象,包含大脑表面和连接图。
        output_dir (str): 图片保存的文件夹路径,若文件夹不存在则自动创建。
        n_frames (int, optional): 旋转帧的数量。默认 36,即每 10 度一帧。
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    angles = np.linspace(0, 360, n_frames, endpoint=False)
    for i, angle in tqdm(enumerate(angles), total=len(angles)):
        camera = dict(
            eye=dict(
                x=2 * np.cos(np.radians(angle)), y=2 * np.sin(np.radians(angle)), z=0.7
            )
        )
        fig.update_layout(scene_camera=camera)
        pio.write_image(fig, f"{output_dir}/frame_{i:03d}.png", width=800, height=800)
    logger.info(f"保存了 {n_frames} 张图片在 {output_dir}")

utils

Modules:

Name Description
bar
color
matrix

Functions:

Name Description
gen_hex_colors

生成指定数量的随机十六进制颜色代码。

gen_symmetric_matrix

生成一个对称方阵,可以指定元素范围和稀疏度。

gen_white_to_color_cmap

生成从白色到指定颜色的线性渐变色图。

is_symmetric_square

判断一个矩阵是否为对称方阵。

value_to_hex

将数值通过色图和归一化映射为十六进制颜色字符串。

gen_hex_colors

gen_hex_colors(n: int, seed: int = 42) -> list[str]

生成指定数量的随机十六进制颜色代码。

该函数使用 NumPy 的随机数生成器创建随机 RGB 颜色值,并转换为十六进制格式。 通过固定随机种子确保结果可重复,适用于需要一致配色方案的科学可视化。

Parameters:

Name Type Description Default
n int

需要生成的颜色数量,必须为正整数。

required
seed int

随机种子,用于确保结果可重复。默认为 42。

42

Returns:

Type Description
list[str]

list[str]: 包含 n 个十六进制颜色字符串的列表,格式为 "#RRGGBB"。

Examples:

>>> # 生成 3 个随机颜色
>>> colors = gen_hex_colors(3)
>>> print(colors)
['#66a0a9', '#8b7d3a', '#c94f6d']
>>> # 使用不同的随机种子
>>> colors = gen_hex_colors(5, seed=123)
>>> len(colors)
5
>>> # 在绘图中使用
>>> import matplotlib.pyplot as plt
>>> colors = gen_hex_colors(10)
>>> for i, color in enumerate(colors):
...     plt.bar(i, i+1, color=color)
Notes
  • RGB 值范围为 [0, 255]
  • 相同的 n 和 seed 参数总是生成相同的颜色序列
  • 生成的颜色是完全随机的,可能包含对比度较低的颜色
Source code in src/plotfig/utils/color.py
def gen_hex_colors(n: int, seed: int = 42) -> list[str]:
    """生成指定数量的随机十六进制颜色代码。

    该函数使用 NumPy 的随机数生成器创建随机 RGB 颜色值,并转换为十六进制格式。
    通过固定随机种子确保结果可重复,适用于需要一致配色方案的科学可视化。

    Args:
        n (int): 需要生成的颜色数量,必须为正整数。
        seed (int): 随机种子,用于确保结果可重复。默认为 42。

    Returns:
        list[str]: 包含 n 个十六进制颜色字符串的列表,格式为 "#RRGGBB"。

    Examples:
        >>> # 生成 3 个随机颜色
        >>> colors = gen_hex_colors(3)
        >>> print(colors)
        ['#66a0a9', '#8b7d3a', '#c94f6d']

        >>> # 使用不同的随机种子
        >>> colors = gen_hex_colors(5, seed=123)
        >>> len(colors)
        5

        >>> # 在绘图中使用
        >>> import matplotlib.pyplot as plt
        >>> colors = gen_hex_colors(10)
        >>> for i, color in enumerate(colors):
        ...     plt.bar(i, i+1, color=color)

    Notes:
        - RGB 值范围为 [0, 255]
        - 相同的 n 和 seed 参数总是生成相同的颜色序列
        - 生成的颜色是完全随机的,可能包含对比度较低的颜色
    """

    RNG = np.random.default_rng(seed=seed)
    rgb = RNG.integers(0, 256, size=(n, 3))  # n×3 的整数矩阵
    colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb]
    return colors

gen_symmetric_matrix

gen_symmetric_matrix(n, mode='nonneg', sparsity: float = 1.0, seed: int = 42) -> NDArray

生成一个对称方阵,可以指定元素范围和稀疏度。

Parameters:

Name Type Description Default
n int

方阵的维度。

required
mode str

元素类型,"nonneg" 表示非负,"all" 表示可为负。默认为 "nonneg"。

'nonneg'
sparsity float

稀疏度,取值范围 [0, 1],1 表示全密集,0 表示全零。默认为 1.0。

1.0
seed int

随机种子,默认为 42。

42

Raises:

Type Description
ValueError

如果 mode 不是 "nonneg" 或 "all"。

Returns:

Name Type Description
NDArray NDArray

生成的对称方阵。

Source code in src/plotfig/utils/matrix.py
def gen_symmetric_matrix(
    n, mode="nonneg", sparsity: float = 1.0, seed: int = 42
) -> NDArray:
    """
    生成一个对称方阵,可以指定元素范围和稀疏度。

    Args:
        n (int): 方阵的维度。
        mode (str, optional): 元素类型,"nonneg" 表示非负,"all" 表示可为负。默认为 "nonneg"。
        sparsity (float, optional): 稀疏度,取值范围 [0, 1],1 表示全密集,0 表示全零。默认为 1.0。
        seed (int, optional): 随机种子,默认为 42。

    Raises:
        ValueError: 如果 mode 不是 "nonneg" 或 "all"。

    Returns:
        NDArray: 生成的对称方阵。
    """
    # 创建随机数生成器
    RNG = np.random.default_rng(seed=seed)
    # 生成权重矩阵上三角
    if mode == "nonneg":
        upper = np.triu(RNG.random((n, n)), k=1)
    elif mode == "all":
        upper = np.triu(RNG.uniform(-1, 1, size=(n, n)), k=1)
    else:
        raise ValueError("mode must be 'nonneg' or 'all'")
    # 稀疏化:随机生成mask
    if sparsity < 1.0:
        mask = RNG.random((n, n)) < sparsity
        mask = np.triu(mask, k=1)  # 上三角mask
        upper *= mask
    # 构造对称矩阵
    mat = upper + upper.T
    np.fill_diagonal(mat, 0.0)
    return mat

gen_white_to_color_cmap

gen_white_to_color_cmap(color: str = 'red') -> Colormap

生成从白色到指定颜色的线性渐变色图。

该函数创建一个双色线性渐变色图,从白色(最小值)平滑过渡到指定颜色(最大值)。 适用于热力图、相关矩阵、脑连接图等需要表示连续数值强度的可视化场景。

Parameters:

Name Type Description Default
color str

渐变的目标颜色,支持 matplotlib 颜色名称(如 "red", "blue")、 十六进制格式(如 "#FF0000")或 RGB 元组。默认为 "red"。

'red'

Returns:

Name Type Description
Colormap Colormap

matplotlib 的线性渐变色图对象,可直接用于绘图函数。

Examples:

>>> # 生成红色渐变色图
>>> cmap = gen_cmap("red")
>>> 
>>> # 在热力图中使用
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> data = np.random.rand(10, 10)
>>> plt.imshow(data, cmap=gen_cmap("blue"))
>>> plt.colorbar()
>>> 
>>> # 使用十六进制颜色
>>> cmap = gen_cmap("#FF5733")
>>> 
>>> # 在散点图中使用
>>> x = np.random.rand(100)
>>> y = np.random.rand(100)
>>> c = np.random.rand(100)
>>> plt.scatter(x, y, c=c, cmap=gen_cmap("purple"))
Notes
  • 色图名称固定为 "white_to_color"
  • 渐变是线性的,白色对应数值范围的最小值,指定颜色对应最大值
  • 支持所有 matplotlib 认可的颜色格式
  • 常用于表示正值数据,如相关系数的绝对值、连接强度等
See Also

value_to_hex : 将数值映射为十六进制颜色 matplotlib.colors.LinearSegmentedColormap : 底层实现类

Source code in src/plotfig/utils/color.py
def gen_white_to_color_cmap(color: str = "red") -> Colormap:
    """生成从白色到指定颜色的线性渐变色图。

    该函数创建一个双色线性渐变色图,从白色(最小值)平滑过渡到指定颜色(最大值)。
    适用于热力图、相关矩阵、脑连接图等需要表示连续数值强度的可视化场景。

    Args:
        color (str): 渐变的目标颜色,支持 matplotlib 颜色名称(如 "red", "blue")、
            十六进制格式(如 "#FF0000")或 RGB 元组。默认为 "red"。

    Returns:
        Colormap: matplotlib 的线性渐变色图对象,可直接用于绘图函数。

    Examples:
        >>> # 生成红色渐变色图
        >>> cmap = gen_cmap("red")
        >>> 
        >>> # 在热力图中使用
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> data = np.random.rand(10, 10)
        >>> plt.imshow(data, cmap=gen_cmap("blue"))
        >>> plt.colorbar()
        >>> 
        >>> # 使用十六进制颜色
        >>> cmap = gen_cmap("#FF5733")
        >>> 
        >>> # 在散点图中使用
        >>> x = np.random.rand(100)
        >>> y = np.random.rand(100)
        >>> c = np.random.rand(100)
        >>> plt.scatter(x, y, c=c, cmap=gen_cmap("purple"))

    Notes:
        - 色图名称固定为 "white_to_color"
        - 渐变是线性的,白色对应数值范围的最小值,指定颜色对应最大值
        - 支持所有 matplotlib 认可的颜色格式
        - 常用于表示正值数据,如相关系数的绝对值、连接强度等

    See Also:
        value_to_hex : 将数值映射为十六进制颜色
        matplotlib.colors.LinearSegmentedColormap : 底层实现类
    """

    cmap = LinearSegmentedColormap.from_list("white_to_color", ["white", color])
    return cmap

is_symmetric_square

is_symmetric_square(matrix: NDArray, tol: float = 1e-08) -> bool

判断一个矩阵是否为对称方阵。

Parameters:

Name Type Description Default
matrix NDArray

待判断的矩阵。

required
tol float

判断对称性的容差,默认为 1e-8。

1e-08

Returns:

Name Type Description
bool bool

如果是对称方阵则返回 True,否则返回 False。

Source code in src/plotfig/utils/matrix.py
def is_symmetric_square(matrix: NDArray, tol: float = 1e-8) -> bool:
    """
    判断一个矩阵是否为对称方阵。

    Args:
        matrix (NDArray): 待判断的矩阵。
        tol (float, optional): 判断对称性的容差,默认为 1e-8。

    Returns:
        bool: 如果是对称方阵则返回 True,否则返回 False。
    """
    # 1. 检查是否为方阵
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False

    # 2. 检查是否对称
    return np.allclose(matrix, matrix.T, atol=tol)

value_to_hex

value_to_hex(value: float, cmap: Colormap, norm: Normalize) -> str

将数值通过色图和归一化映射为十六进制颜色字符串。

该函数实现了从数值到颜色的完整映射流程:首先使用归一化对象将数值映射到 [0, 1] 区间, 然后通过色图将归一化值转换为 RGBA 颜色,最后转换为十六进制格式。常用于根据数据强度 动态生成颜色,如脑连接图中根据连接强度着色连线。

Parameters:

Name Type Description Default
value float

需要映射的原始数值,可以是任意范围的浮点数。

required
cmap Colormap

matplotlib 色图对象,定义了归一化值到颜色的映射关系。

required
norm Normalize

matplotlib 归一化对象,定义了原始数值到 [0, 1] 的映射规则。 常用类型包括 Normalize(线性)、LogNorm(对数)等。

required

Returns:

Name Type Description
str str

十六进制颜色字符串,格式为 "#RRGGBB"。

Examples:

>>> import matplotlib.pyplot as plt
>>> from matplotlib.colors import Normalize
>>> 
>>> # 创建色图和归一化对象
>>> cmap = gen_cmap("red")
>>> norm = Normalize(vmin=0, vmax=100)
>>> 
>>> # 将数值映射为颜色
>>> color1 = value_to_hex(25, cmap, norm)  # 浅红色
>>> color2 = value_to_hex(75, cmap, norm)  # 深红色
>>> print(color1, color2)
'#ffbfbf' '#bf4040'
>>> 
>>> # 在脑连接图中使用
>>> connection_strengths = [0.1, 0.5, 0.9]
>>> norm = Normalize(vmin=0, vmax=1)
>>> colors = [value_to_hex(s, cmap, norm) for s in connection_strengths]
>>> 
>>> # 使用对数归一化
>>> from matplotlib.colors import LogNorm
>>> norm_log = LogNorm(vmin=1, vmax=1000)
>>> color = value_to_hex(100, cmap, norm_log)
Notes
  • 归一化对象决定了数值如何映射到 [0, 1] 区间
  • 色图决定了 [0, 1] 区间的值如何映射到具体颜色
  • 返回的十六进制颜色可直接用于 matplotlib 或 plotly 绘图
  • 如果 value 超出归一化对象的范围,会被裁剪到边界值
See Also

gen_cmap : 生成线性渐变色图 matplotlib.colors.Normalize : 线性归一化 matplotlib.colors.LogNorm : 对数归一化

Source code in src/plotfig/utils/color.py
def value_to_hex(value: float, cmap: Colormap, norm: Normalize) -> str:
    """将数值通过色图和归一化映射为十六进制颜色字符串。

    该函数实现了从数值到颜色的完整映射流程:首先使用归一化对象将数值映射到 [0, 1] 区间,
    然后通过色图将归一化值转换为 RGBA 颜色,最后转换为十六进制格式。常用于根据数据强度
    动态生成颜色,如脑连接图中根据连接强度着色连线。

    Args:
        value (float): 需要映射的原始数值,可以是任意范围的浮点数。
        cmap (Colormap): matplotlib 色图对象,定义了归一化值到颜色的映射关系。
        norm (Normalize): matplotlib 归一化对象,定义了原始数值到 [0, 1] 的映射规则。
            常用类型包括 Normalize(线性)、LogNorm(对数)等。

    Returns:
        str: 十六进制颜色字符串,格式为 "#RRGGBB"。

    Examples:
        >>> import matplotlib.pyplot as plt
        >>> from matplotlib.colors import Normalize
        >>> 
        >>> # 创建色图和归一化对象
        >>> cmap = gen_cmap("red")
        >>> norm = Normalize(vmin=0, vmax=100)
        >>> 
        >>> # 将数值映射为颜色
        >>> color1 = value_to_hex(25, cmap, norm)  # 浅红色
        >>> color2 = value_to_hex(75, cmap, norm)  # 深红色
        >>> print(color1, color2)
        '#ffbfbf' '#bf4040'
        >>> 
        >>> # 在脑连接图中使用
        >>> connection_strengths = [0.1, 0.5, 0.9]
        >>> norm = Normalize(vmin=0, vmax=1)
        >>> colors = [value_to_hex(s, cmap, norm) for s in connection_strengths]
        >>> 
        >>> # 使用对数归一化
        >>> from matplotlib.colors import LogNorm
        >>> norm_log = LogNorm(vmin=1, vmax=1000)
        >>> color = value_to_hex(100, cmap, norm_log)

    Notes:
        - 归一化对象决定了数值如何映射到 [0, 1] 区间
        - 色图决定了 [0, 1] 区间的值如何映射到具体颜色
        - 返回的十六进制颜色可直接用于 matplotlib 或 plotly 绘图
        - 如果 value 超出归一化对象的范围,会被裁剪到边界值

    See Also:
        gen_cmap : 生成线性渐变色图
        matplotlib.colors.Normalize : 线性归一化
        matplotlib.colors.LogNorm : 对数归一化
    """

    rgba = cmap(norm(value))  # 得到 RGBA
    return mcolors.to_hex(rgba)