import os  # 添加 os 模块
import pandas as pd
import json

# 读取 new_train.csv 和 new_val.csv 文件
train_data = pd.read_csv(os.path.join('/home/heweihong/mini-imagenet', 'new_train.csv'))
val_data = pd.read_csv(os.path.join('/home/heweihong/mini-imagenet', 'new_val.csv'))

# 合并训练集和验证集，获取所有的类名
all_data = pd.concat([train_data, val_data])
class_names = sorted(all_data['label'].unique())  # 获取所有唯一的标签

# 生成 class index 字典，将类名映射到 [0, len(class_names)-1]
class_index_dict = {i: [class_name] for i, class_name in enumerate(class_names)}

# 保存为新的 imagenet_class_index.json 文件
with open(os.path.join('/home/heweihong/mini-imagenet', 'imagenet_class_index_new.json'), 'w') as f:
    json.dump(class_index_dict, f, indent=4)

print("生成了新的 imagenet_class_index.json 文件")
