For a project I am making some 3D scatter plots with the three corresponding projections under them. I use different colors to indicate a fourth parameter. First I plot data in one color and then I overplot other data in a different color, so that in the end the drawing order lets me see everything the way I want:

This is what I want!

In the beginning this worked fine, but when I try the same thing with slightly different data, the colors get messed up. The colors shown in the projections are correct, but some of them are missing from the 3D plot, so the 3D plot and the projections no longer match:

Colors messed up :(

When I rotate the 3D plot in a funny way, the colors are recovered and I can see them as they were supposed to be:

Funny rotation

However, I don’t want a 3D plot that is rotated in a funny way, because the axes get messed up and it’s impossible to read it properly like that.

I found one solution to the problem here:
plotting 3d scatter in matplotlib.
It basically says that I should replace my ax.scatter(X,Y) with ax.plot(X,Y,'o'). When I do this, the colors are shown the way they are supposed to be, but the plot is much messier and uglier.
Basically I just want to be able to do this with a scatter plot.

Does anyone know how to solve this?

Here’s a minimal example of my code, for only two colors:

from mpl_toolkits.mplot3d import art3d
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import gridspec

def _no_depth_shading(*args):
    # Stand-in for art3d.zalpha: return the first argument unchanged instead
    # of the depth-faded version, i.e. switch off depth shading globally.
    # NOTE(review): the first argument is presumably the colour array, and
    # newer matplotlib may no longer consult the public art3d.zalpha at all;
    # Axes3D.scatter(..., depthshade=False) is the supported switch -- confirm
    # for the installed version.
    return args[0]


art3d.zalpha = _no_depth_shading

# Number of palette colours / legend entries.
numcols = 20
# Percentage prefix of the .npy files read by load().
percentage = 50

def load(Td, pc):
    """Load the three arrays stored for one dust temperature.

    Reads, relative to the current working directory and in this order:
    ``<pc>pctTemperaturesTd=<Td>.npy``, ``<pc>pctDensitiesTd=<Td>.npy`` and
    ``<pc>pctColDensitiesTd=<Td>.npy``.

    Returns the tuple ``(temperatures, densities, column_densities)``.
    """
    def read(quantity):
        return np.load('{}pct{}Td={}.npy'.format(pc, quantity, Td))

    return tuple(read(q) for q in ('Temperatures', 'Densities', 'ColDensities'))

def colors(ax, n=None):
    """Build a palette from the ``gist_rainbow`` colormap and install it as
    the colour cycle of *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose colour cycle is replaced by the palette.
    n : int, optional
        Number of colours; defaults to the module-level ``numcols``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, 4)``; row ``i`` is the colormap's RGBA
        value at ``i / n``.
    """
    if n is None:
        n = numcols
    cmap = plt.get_cmap('gist_rainbow')
    cycle = [cmap(1. * i / n) for i in range(n)]
    # Axes.set_color_cycle was deprecated in matplotlib 1.5 and later removed;
    # set_prop_cycle is its replacement. Keep the old call only as a fallback.
    if hasattr(ax, 'set_prop_cycle'):
        ax.set_prop_cycle(color=cycle)
    else:
        ax.set_color_cycle(cycle)
    # reshape keeps the (n, 4) shape even when n == 0.
    return np.array(cycle, dtype=float).reshape(n, 4)

# LOAD DATA
T10, D10, CD10 = load(10, percentage)
T200, D200, CD200 = load(200, percentage)

# One entry per dust temperature: (Td, (Tg/100, log(nH), log(colDen))).
# Entries are drawn in list order, so later entries end up on top of earlier
# ones -- here Td = 10 is drawn over Td = 200 in every panel.
series = [
    (200, (T200 / 100., np.log10(D200), np.log10(CD200))),
    (10, (T10 / 100., np.log10(D10), np.log10(CD10))),
]


def _td_colour(colours, td):
    """Return the palette entry that the legend below assigns to ``Td = td``.

    The legend labels ``colours[k]`` as ``'Td = {}'.format(10 * (k + 1))``,
    i.e. Td = 10 -> colours[0], ..., Td = 200 -> colours[19]. Using the same
    mapping for the data keeps markers and legend consistent (a hard-coded
    colours[10] for Td = 200 would be the legend's "Td = 110" colour).
    """
    return colours[td // 10 - 1]


def _td_label(td):
    """Return the label text used for dust temperature *td*."""
    return 'Td = {}'.format(td)


# 3D PLOT
fig = plt.figure(1)
gs = gridspec.GridSpec(4, 4)

ax = fig.add_subplot(gs[:-1, :-1], projection='3d')
colours = colors(ax)

# By default mplot3d re-orders collections (e.g. scatter output) by their
# projected depth, so a whole series can end up hidden behind another one
# depending on the view angle. Matplotlib >= 3.5 exposes
# Axes3D.computed_zorder; switching it off keeps the drawing order of
# `series`, which makes scatter usable. Older versions fall back to ax.plot
# with a '*' marker, which does not suffer from that re-ordering.
keep_draw_order = hasattr(ax, 'computed_zorder')
if keep_draw_order:
    ax.computed_zorder = False

for td, (x, y, z) in series:
    c = _td_colour(colours, td)
    if keep_draw_order:
        # scatter's s is an area in points**2, so s=100 ~ markersize=10.
        ax.scatter(x, y, z, marker='*', s=100, color=c, edgecolor=c,
                   depthshade=False, label=_td_label(td), alpha=1)
    else:
        ax.plot(x, y, z, '*', markersize=10, color=c, mec=c,
                label=_td_label(td), alpha=1)

ax.set_xlabel('\nTg/100', fontsize="x-large")
ax.set_ylabel('\nlog(nH)', fontsize="x-large")
ax.set_zlabel('\nlog(colDen)', fontsize="x-large")
ax.set_xlim(0, 5)
ax.set_ylim(2, 6)

# PROJECTIONS
# (gridspec cell, x coordinate index, y coordinate index, x label, y label,
#  x limits or None); indices refer to the coordinate tuples in `series`.
projections = [
    (gs[3, 0], 0, 1, 'Tg/100', 'log(nH)', (0, 6)),
    (gs[3, 1], 0, 2, 'Tg/100', 'log(colDen)', (0, 6)),
    (gs[3, 2], 1, 2, 'log(nH)', 'log(colDen)', None),
]
for cell, xi, yi, xlabel, ylabel, xlim in projections:
    pax = fig.add_subplot(cell)
    for td, coords in series:
        c = _td_colour(colours, td)
        pax.scatter(coords[xi], coords[yi], marker="*", s=10, color=c,
                    label=_td_label(td), alpha=1, edgecolor=c)
    pax.set_xlabel(xlabel)
    pax.set_ylabel(ylabel)
    if xlim is not None:
        pax.set_xlim(*xlim)

# LEGEND
# One star per palette entry: row 0 (bottom) is Td = 10 * numcols, the top
# row is Td = 10. Sized from numcols so it cannot drift from the palette.
legend = fig.add_subplot(gs[:, 3])
for row in range(numcols):
    k = numcols - row - 1
    # color= (not c=) so a single RGBA row is never mistaken for data values.
    legend.scatter(0, row, marker="*", s=100, color=colours[k],
                   edgecolor=colours[k])
    legend.text(0.3, row - 0.25, _td_label(10 * (k + 1)))
legend.set_xlim(-0.5, 2.5)
legend.set_ylim(-1, numcols)

legend.get_xaxis().set_visible(False)
legend.get_yaxis().set_visible(False)

gs.tight_layout(fig)

plt.show()

Rather than using ax.plot(x, y, 'o'), try ax.plot(x, y, '.') or ax.plot(x, y, '*'). The 'o' specifies the marker to use, and the 'o' marker is a large filled circle, which is why your plot looks ugly.