'''
Cluster unit deaths from dumped StarCraft replays into "battles".

I do something like:
find path/unit_loc_dump/10 -name 'TL_*.rep' |xargs -n1 -I% -P10 python cluster.py -i % -o path/battles/cluster/10/ -x 200 -y 200 -t 20 -b 0.5
'''
import os
from os import path
from glob import glob
import argparse
import multiprocessing
import time
import traceback
import math

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from itertools import cycle

parser = argparse.ArgumentParser(description='Cluster some starcraft dumped replays. Each t is actually 3 frames combined')
parser.add_argument('-i', '--input', required=True, help='input glob')
parser.add_argument('-o', '--output', required=True, help='output folder')
# for -x, -y, -t, if you have none of --mrel, --trel, or --unit, then it behaves
# like division: x /= (args.x) is done. If you do have one of the --*rel or
# --unit options, instead it does x *= (args.x)
parser.add_argument('-x', '--x_scale', default=1, type=float,
                    help='scale x axis, x /= x_scale')
parser.add_argument('-y', '--y_scale', default=1, type=float,
                    help='scale y ayis, y /= y_scale')
parser.add_argument('-t', '--t_scale', default=1, type=float,
                    help='scale t atis, t /= t_scale')
parser.add_argument('-b', '--bandwidth', default=-1, type=float,
                    help='Bandwidth for mean shift, use a negative number to force autodetection')
parser.add_argument('--mrel', action='store_true', default=False,
                    help='x_scale and y_scale accepts a decimal, and x and y are scaled to 1')
parser.add_argument('--trel', action='store_true', default=False,
                    help='t_scale accepts a decimal, and t is scaled to 1')
parser.add_argument('--unit', action='store_true', default=False,
                    help='Use N(0, 1) normalization, incomaptible with mrel and trel.'
                    '--{x,y,t}_scale are used as decimals')
parser.add_argument('--min_deaths', default=3, type=int,
                    help='How many deaths in each cluster is a "battle"')
# This parameter is pretty sensitive, I found that 0.5 is too high. I haven't
# experimented whether it's too low or the averaging centers approach is too
# heavy-handed.
parser.add_argument('--merge_sim', default=0.4, type=float,
                    help='If two bounding boxes are more than this similar (via Jaccard), merge them')
parser.add_argument('--bound_with_deaths', default=False, action='store_true',
                    help='Build bounding boxes with just deaths')
parser.add_argument('--t_padding', default=2, type=float,
                    help='Seconds before and after deaths to pad to')
parser.add_argument('-s', '--show', action='store_true', default=False,
                    help='Whether to show plot or not')
args = parser.parse_args()

# The backend has to be chosen before pyplot is imported, and the choice
# depends on --show, hence these imports come after argument parsing.
import matplotlib
if not args.show:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (kept from the original; not referenced by name)

# Only the combat units
id2unit = [
    ("Terran_Marine", "0"),
    ("Terran_Ghost", "1"),
    ("Terran_Vulture", "2"),
    ("Terran_Goliath", "3"),
    ("Terran_Siege_Tank_Tank_Mode", "5"),
    ("Terran_SCV", "7"),
    ("Terran_Wraith", "8"),
    ("Terran_Science_Vessel", "9"),
    ("Terran_Dropship", "11"),
    ("Terran_Battlecruiser", "12"),
    ("Terran_Vulture_Spider_Mine", "13"),
    ("Terran_Nuclear_Missile", "14"),
    ("Terran_Civilian", "15"),
    ("Terran_Siege_Tank_Siege_Mode", "30"),
    ("Terran_Firebat", "32"),
    #("Spell_Scanner_Sweep", "33"),
    ("Terran_Medic", "34"),
    # ("Zerg_Larva", "35"),
    # ("Zerg_Egg", "36"),
    ("Zerg_Zergling", "37"),
    ("Zerg_Hydralisk", "38"),
    ("Zerg_Ultralisk", "39"),
    ("Zerg_Broodling", "40"),
    ("Zerg_Drone", "41"),
    ("Zerg_Overlord", "42"),
    ("Zerg_Mutalisk", "43"),
    ("Zerg_Guardian", "44"),
    ("Zerg_Queen", "45"),
    ("Zerg_Defiler", "46"),
    ("Zerg_Scourge", "47"),
    ("Zerg_Infested_Terran", "50"),
    ("Terran_Valkyrie", "58"),
    ("Zerg_Cocoon", "59"),
    ("Protoss_Corsair", "60"),
    ("Protoss_Dark_Templar", "61"),
    ("Zerg_Devourer", "62"),
    ("Protoss_Dark_Archon", "63"),
    ("Protoss_Probe", "64"),
    ("Protoss_Zealot", "65"),
    ("Protoss_Dragoon", "66"),
    ("Protoss_High_Templar", "67"),
    ("Protoss_Archon", "68"),
    ("Protoss_Shuttle", "69"),
    ("Protoss_Scout", "70"),
    ("Protoss_Arbiter", "71"),
    ("Protoss_Carrier", "72"),
    # ("Protoss_Interceptor", "73"),
    ("Protoss_Reaver", "83"),
    ("Protoss_Observer", "84"),
    # ("Protoss_Scarab", "85"),
    # ("Critter_Rhynadon", "89"),
    # ("Critter_Bengalaas", "90"),
    # ("Critter_Scantid", "93"),
    # ("Critter_Kakaru", "94"),
    # ("Critter_Ragnasaur", "95"),
    # ("Critter_Ursadon", "96"),
    ("Zerg_Lurker_Egg", "97"),
    ("Zerg_Lurker", "103"),
    # ("Spell_Disruption_Web", "105"),
    # ("Terran_Command_Center", "106"),
    # ("Terran_Comsat_Station", "107"),
    # ("Terran_Nuclear_Silo", "108"),
    # ("Terran_Supply_Depot", "109"),
    # ("Terran_Refinery", "110"),
    # ("Terran_Barracks", "111"),
    # ("Terran_Academy", "112"),
    # ("Terran_Factory", "113"),
    # ("Terran_Starport", "114"),
    # ("Terran_Control_Tower", "115"),
    # ("Terran_Science_Facility", "116"),
    # ("Terran_Covert_Ops", "117"),
    # ("Terran_Physics_Lab", "118"),
    # ("Terran_Machine_Shop", "120"),
    # ("Terran_Engineering_Bay", "122"),
    # ("Terran_Armory", "123"),
    ("Terran_Missile_Turret", "124"),
    ("Terran_Bunker", "125"),
    # ("Zerg_Infested_Command_Center", "130"),
    # ("Zerg_Hatchery", "131"),
    # ("Zerg_Lair", "132"),
    # ("Zerg_Hive", "133"),
    # ("Zerg_Nydus_Canal", "134"),
    # ("Zerg_Hydralisk_Den", "135"),
    # ("Zerg_Defiler_Mound", "136"),
    # ("Zerg_Greater_Spire", "137"),
    # ("Zerg_Queens_Nest", "138"),
    # ("Zerg_Evolution_Chamber", "139"),
    # ("Zerg_Ultralisk_Cavern", "140"),
    # ("Zerg_Spire", "141"),
    # ("Zerg_Spawning_Pool", "142"),
    ("Zerg_Creep_Colony", "143"),
    ("Zerg_Spore_Colony", "144"),
    ("Zerg_Sunken_Colony", "146"),
    # ("Zerg_Extractor", "149"),
    # ("Protoss_Nexus", "154"),
    # ("Protoss_Robotics_Facility", "155"),
    # ("Protoss_Pylon", "156"),
    # ("Protoss_Assimilator", "157"),
    # ("Protoss_Observatory", "159"),
    # ("Protoss_Gateway", "160"),
    # ("Protoss_Photon_Cannon", "162"),
    # ("Protoss_Citadel_of_Adun", "163"),
    # ("Protoss_Cybernetics_Core", "164"),
    # ("Protoss_Templar_Archives", "165"),
    # ("Protoss_Forge", "166"),
    # ("Protoss_Stargate", "167"),
    # ("Protoss_Fleet_Beacon", "169"),
    # ("Protoss_Arbiter_Tribunal", "170"),
    # ("Protoss_Robotics_Support_Bay", "171"),
    ("Protoss_Shield_Battery", "172"),
    # ("Resource_Mineral_Field", "176"),
    # ("Resource_Mineral_Field_Type_2", "177"),
    # ("Resource_Mineral_Field_Type_3", "178"),
    # ("Resource_Vespene_Geyser", "188"),
    # ("Spell_Dark_Swarm", "202"),
]
# Map integer unit-type id -> unit name.
id2unit = {int(b): a for a, b in id2unit}
# NOTE: despite the name, these are the unit types that are KEPT by parse_file.
badunits = np.array(list(id2unit.keys()))


def parse_file(fn):
    '''
    Parse one dumped replay and locate every unit death in (x, y, t) space.

    The first line of the file holds three integers read as max_y, max_x,
    max_t. Every following non-blank line is one time step: a flat list of
    integers reshaped to rows of 5 columns, used here as
    (team, unit id, unit type, x, y). Rows with team == -1 or with a unit
    type that is not a key of `id2unit` are dropped.

    A unit "dies" at step t when its id is present at step t and absent at
    step t + 1.

    Returns a 6-tuple
        (data, xyt, transform, untransform, valid, maxes)
    where `data` is the list of filtered per-step arrays, `xyt` the scaled
    death coordinates (an empty list when there are no deaths), `transform`
    and `untransform` the scaling function and its inverse, `valid` whether
    the replay should be processed (False when there are no deaths or more
    than 3 distinct team values), and `maxes` is (max_x, max_y, max_t), or
    None when there are no deaths.

    Raises ValueError when --unit is combined with --mrel or --trel.
    '''
    with open(fn) as infile:
        data = infile.readlines()
    max_y, max_x, max_t = [int(x) for x in data[0].split(' ')]
    data = [np.fromstring(d, dtype='int32', sep=' ').reshape(-1, 5)
            for d in data[1:] if d.strip() != ""]
    data = [d[d[:, 0] != -1] for d in data]
    # Unit types repeat within a frame, so the membership test must not be
    # told that its inputs are unique.
    data = [d[np.isin(d[:, 2], badunits)] for d in data]

    ids = [d[:, 1] for d in data]
    deaths = [np.setdiff1d(x, y, assume_unique=True)
              for x, y in zip(ids[:-1], ids[1:])]
    xyt = []
    for t, (d, death) in enumerate(zip(data[:-1], deaths)):
        dead = d[np.isin(d[:, 1], death)]
        if dead.size > 0:
            # Size the time column from the rows actually selected.
            stamp = t * np.ones((dead.shape[0], 1))
            xyt.append(np.concatenate([dead[:, 3:], stamp], axis=1))
    if len(xyt) == 0:
        return data, xyt, lambda x: x, lambda x: x, False, None
    xyt = np.concatenate(xyt, axis=0)

    if args.unit:
        if args.mrel or args.trel:
            raise ValueError("Cannot supply both --unit and one of --mrel or --trel")
        mean = xyt.mean(axis=0)
        std = xyt.std(axis=0) + 1e-2

    xs, ys, ts = args.x_scale, args.y_scale, args.t_scale
    if args.mrel:
        xs /= max_x
        ys /= max_y
    elif not args.unit:
        xs = 1 / xs
        ys = 1 / ys
    if args.trel:
        ts /= max_t
    elif not args.unit:
        ts = 0.042 * 3 / ts  # convert frames to seconds
    scalar = np.array([xs, ys, ts])

    def transform(xyt):
        if args.unit:
            xyt = (xyt - mean) / std
        return xyt * scalar

    def untransform(xyt):
        if args.unit:
            xyt = xyt * std + mean
        return xyt / scalar

    teams = np.unique(np.concatenate([d[:, 0] for d in data]))

    return (data, transform(xyt), transform, untransform,
            teams.size <= 3, (max_x, max_y, max_t))


def drawbox(ax, rectangle, color='b', alpha=0.2):
    '''
    Draw the six faces of an axis-aligned box on a 3D axes.

    `rectangle` is (xmin, xmax, ymin, ymax, zmin, zmax).
    '''
    x = rectangle[0:2]
    y = rectangle[2:4]
    z = rectangle[4:6]
    for i in x:
        Y, Z = np.meshgrid(y, z)
        ax.plot_surface(i, Y, Z, alpha=alpha, color=color)
    for i in y:
        X, Z = np.meshgrid(x, z)
        ax.plot_surface(X, i, Z, alpha=alpha, color=color)
    for i in z:
        X, Y = np.meshgrid(x, y)
        ax.plot_surface(X, Y, i, alpha=alpha, color=color)


def cluster(arg):
    '''
    Pool-friendly entry point: `arg` is an (input file, output path) pair.

    The last four characters of the output path (the extension, e.g.
    ".rep") are stripped to form the output prefix.
    '''
    infn, outfn = arg
    outfn = outfn[:-4]
    _cluster(infn, outfn)
    # Disabled retry loop kept from the original for reference:
    #
    # for i in range(10):
    #     try:
    #         _cluster(infn, outfn)
    #         return
    #     except:
    #         traceback.print_exc()
    #         try:
    #             os.remove(outfn + '.lock')
    #         except:
    #             pass
    #         time.sleep(1)
    # print("FAILED {} => {}".format(infn, outfn))


def _cluster(infn, outfn):
    '''
    Cluster the deaths of one replay and write `outfn`.txt and `outfn`.png.

    Nothing is done when `outfn`.lock or `outfn`.txt already exists. While
    working, `outfn`.lock exists; it is removed when this function exits,
    whether normally, early, or through an exception.
    '''
    lock_fn = outfn + '.lock'
    if path.exists(lock_fn) or path.exists(outfn + '.txt'):
        return
    try:
        # 'x' fails if the file already exists, so two workers that pass
        # the check above at the same time cannot both take the lock.
        open(lock_fn, 'x').close()
    except FileExistsError:
        return

    try:
        print("doing " + infn)
        data, xyt, transform, untransform, valid, maxes = parse_file(infn)
        if not valid:
            return

        bandwidth = args.bandwidth
        if bandwidth < 0:
            bandwidth = estimate_bandwidth(xyt, quantile=0.2, n_samples=500)

        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        ms.fit(xyt)

        # Back to the original coordinates for output and plotting.
        centers = untransform(ms.cluster_centers_)
        xyt = untransform(xyt)
        labels = ms.labels_
        few = np.bincount(labels) < args.min_deaths

        extract_battles(outfn + '.txt', data, ms, maxes, xyt,
                        transform, untransform)

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
            for k, too_few in enumerate(few):
                my_members = labels == k
                cluster_center = centers[k]
                # Clusters below --min_deaths are drawn in black.
                col = 'black' if too_few else next(colors)
                ax.scatter(cluster_center[0], cluster_center[1],
                           cluster_center[2], marker='o', c=col, s=100)
                ax.scatter(xyt[my_members, 0], xyt[my_members, 1],
                           xyt[my_members, 2], c=col)
            # Save before showing: the figure may be gone once the
            # interactive window has been closed.
            fig.savefig(outfn + ".png")
            if args.show:
                plt.show()
        finally:
            plt.close(fig)
    finally:
        try:
            os.remove(lock_fn)
        except OSError:
            pass


def radius_to_rect(center, x_radius, y_radius, maxes, before, after):
    '''
    Build an integer bounding box around `center`, clamped to the map.

    The box spans `x_radius` / `y_radius` on each side of the centre and is
    shifted (not shrunk) to stay inside [0, x_max] x [0, y_max] where
    possible. `before` and `after` are clamped to [0, t_max] and rounded
    outwards.

    Returns (xmin, xmax, ymin, ymax, before, after).
    '''
    x_max, y_max, t_max = maxes
    before = math.floor(max(0, before))
    after = math.ceil(min(t_max, after))

    xmin = center[0] - x_radius
    xmin = max(0, min(xmin, x_max - 2 * x_radius))
    xmax = center[0] + x_radius
    xmax = min(x_max, max(2 * x_radius, xmax))

    ymin = center[1] - y_radius
    ymin = max(0, min(ymin, y_max - 2 * y_radius))
    ymax = center[1] + y_radius
    ymax = min(y_max, max(2 * y_radius, ymax))

    xmin, xmax, ymin, ymax = [int(v) for v in [xmin, xmax, ymin, ymax]]
    return (xmin, xmax, ymin, ymax, before, after)


def filter_rectangle(units, rectangle, maxes):
    '''
    Return every unit row that lies inside `rectangle`.

    `units` is the per-step list of arrays from parse_file; steps
    `before` (inclusive) to `after` (exclusive) are searched, and rows are
    kept when columns 3 and 4 (x, y) fall inside the box, bounds included.
    `maxes` is accepted for interface compatibility and not used.
    '''
    xmin, xmax, ymin, ymax, before, after = rectangle
    frames = units[before:after]
    if not frames:
        # np.concatenate refuses an empty sequence; an empty time window
        # simply contains no units.
        return np.empty((0, 5), dtype='int32')
    units = np.concatenate(frames)
    fx = units[:, 3]
    fy = units[:, 4]
    inside = (fx >= xmin) & (fx <= xmax) & (fy >= ymin) & (fy <= ymax)
    return units[inside]


def _jaccard_3d(rect_a, rect_b):
    '''Jaccard similarity (intersection over union) of two x/y/t boxes.'''
    xmin1, xmax1, ymin1, ymax1, tmin1, tmax1 = rect_a
    xmin2, xmax2, ymin2, ymax2, tmin2, tmax2 = rect_b
    vol_a = (xmax1 - xmin1) * (ymax1 - ymin1) * (tmax1 - tmin1)
    vol_b = (xmax2 - xmin2) * (ymax2 - ymin2) * (tmax2 - tmin2)
    vol_both = (
        max(0, min(xmax1, xmax2) - max(xmin1, xmin2)) *
        max(0, min(ymax1, ymax2) - max(ymin1, ymin2)) *
        max(0, min(tmax1, tmax2) - max(tmin1, tmin2))
    )
    union = vol_a + vol_b - vol_both
    if union <= 0:
        # Degenerate (zero-volume) boxes: treat as not similar rather than
        # dividing by zero.
        return 0.0
    return vol_both / float(union)


def _format_unit_counts(team_units):
    '''
    Return "name: count" pairs, comma separated, for one team's units.

    Each unit id is counted once (one row per id is kept), then the rows
    are tallied by unit type; types missing from `id2unit` are skipped.
    '''
    ordered = team_units[team_units[:, 1].argsort()]
    # True on the last row of every run of equal ids.
    last_of_run = np.append(np.diff(ordered[:, 1], axis=0) > 0, True)
    counts = np.bincount(ordered[last_of_run][:, 2])
    return ",".join(
        "{}: {}".format(id2unit[unit_type], count)
        for unit_type, count in zip(np.nonzero(counts)[0], counts[counts > 0])
        if unit_type in id2unit)


def extract_battles(outfn, data, ms, maxes, deaths, transform, untransform,
                    x_radius=100,
                    y_radius=100,
                    ):
    '''
    Outputs a text file to `outfn`
    Each battle is
        xmin, xmax, ymin, ymax, tmin, tmax
        list of units and counts for player 0
        list of units and counts ids for player 1
    repeated for every battle that occurs.

    `ms` is the fitted MeanShift model, `deaths` the death coordinates in
    original (untransformed) space, and `transform` / `untransform` the
    pair returned by parse_file. Returns the list of
    (centre in transformed space, rectangle) pairs left after merging.
    '''
    if args.bound_with_deaths:
        predict_with = deaths
    else:
        cdata = [np.concatenate([d, t * np.ones((d.shape[0], 1))], axis=1)
                 for t, d in enumerate(data)]
        cdata = np.concatenate(cdata)
        predict_with = cdata[:, 3:]
    # Both sources are in original coordinates; the model was fitted on
    # transformed ones.
    predict_with = transform(predict_with)
    x = ms.predict(predict_with)

    few = np.bincount(ms.labels_) < args.min_deaths

    out = []
    rects = []
    for k, too_few in enumerate(few):
        if too_few:
            continue
        center = ms.cluster_centers_[k]
        times = predict_with[x == k][:, -1]
        if times.size == 0:
            continue
        unnormalized = untransform(((0, 0, times.min()),
                                    (0, 0, times.max())))
        start = int(unnormalized[0, 2])
        end = int(unnormalized[1, 2])

        # x seconds before first death and x after last death
        before = args.t_padding * 8
        after = args.t_padding * 8
        rectangle = radius_to_rect(untransform(center), x_radius, y_radius,
                                   maxes, start - before, end + after)
        rects.append((center, rectangle))

    # Merge highly similar rectangles by averaging centers greedily
    # Only stop when there are no possible merges left
    i = 0
    while i < len(rects) - 1:
        length = len(rects)
        j = i + 1
        while j < len(rects):
            c1, rect1 = rects[i]
            c2, rect2 = rects[j]
            if _jaccard_3d(rect1, rect2) > args.merge_sim:
                tmin = min(rect1[4], rect2[4])
                tmax = max(rect1[5], rect2[5])
                c = (c1 + c2) / 2  # weight this by unit number if results are bad
                # Centres are stored in transformed space; the rectangle
                # lives in map coordinates, as for the initial rectangles.
                rect = radius_to_rect(untransform(c), x_radius, y_radius,
                                      maxes, tmin, tmax)
                rects[i] = (c, rect)
                del rects[j]
            else:
                j += 1
        # Any merge may enable new ones, so start over after a change.
        i = i + 1 if len(rects) == length else 0

    for _, rectangle in rects:
        filtered_units = filter_rectangle(data, rectangle, maxes)
        team0 = filtered_units[:, 0] == 0
        if not team0.any() or team0.all():
            # bad cluster
            continue

        out.append(",".join(str(v) for v in rectangle))
        out.append(_format_unit_counts(filtered_units[team0]))
        out.append(_format_unit_counts(filtered_units[np.logical_not(team0)]))

    with open(outfn, 'w') as f:
        f.write("\n".join(out))
    return rects


if __name__ == "__main__":
    files = glob(args.input)
    os.makedirs(path.abspath(args.output), exist_ok=True)
    jobs = [(fn, path.join(args.output, path.basename(fn))) for fn in files]
    # Serial alternative for debugging:
    # [cluster(job) for job in jobs]
    with multiprocessing.Pool() as pool:
        pool.map(cluster, jobs)