Pickle DataSet 활용하기


서론

이 장은 Pickle로 구현된 Dataset을 어떻게 다루고 사용하는지 설명한다.

목차

  1. 사전지식
  2. Pickle File
  3. 전체 이미지 다루기

사전지식

1 Dimension으로 구성된 48개의 배열을 만들어주자.

temp = []
for i in range(0, 48) :
    temp.append(i)

arr1 = np.array(temp)
'''
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
'''
arr1 = arr1.reshape(3, 4, 4)
print(arr1)

그 후 (3, 4, 4) reshape을 하면 결과는 아래와 같다.

보면 먼저 3개의 차원을 만들고 각 차원 당 순서대로 element가 들어가며 4x4를 구성한다.

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]
  [24 25 26 27]
  [28 29 30 31]]

 [[32 33 34 35]
  [36 37 38 39]
  [40 41 42 43]
  [44 45 46 47]]]

그럼 이번엔 (4, 3, 4)로 reshape하면 어떻게 될까?

이번엔 4개의 차원을 만들고 3x4의 배열이 구성된다.

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]

 [[24 25 26 27]
  [28 29 30 31]
  [32 33 34 35]]

 [[36 37 38 39]
  [40 41 42 43]
  [44 45 46 47]]]

위에서 (3, 4, 4)를 Transpose한 결과도 봐보자.

arr1 = arr1.reshape(3, 4, 4)
arr1 = arr1.transpose(1, 2, 0) # (4, 4, 3)
print(arr1)

아래의 결과 처럼 나타난다. 즉 첫번째 행은 R, 두번째가 G, 세번째가 B가 된다.

[[[ 0 16 32]
  [ 1 17 33]
  [ 2 18 34]
  [ 3 19 35]]

 [[ 4 20 36]
  [ 5 21 37]
  [ 6 22 38]
  [ 7 23 39]]

 [[ 8 24 40]
  [ 9 25 41]
  [10 26 42]
  [11 27 43]]

 [[12 28 44]
  [13 29 45]
  [14 30 46]
  [15 31 47]]]

Pickle File

Pickle 파일은 보통 Dictionary 형태로 구성되어 있다.

Cifar-10에서 받아온 Pickle 파일을 들여다보자.

# pickle 파일을 불러오는 함수
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

data = unpickle("c:/Users/JJ/Code/Python/Swin/temp/cifar-10-batches-py/data_batch_1")

data가 어떤 attribute를 갖는지 확인하기 위해 아래 코드의 출력을 봐보자

# data의 key값을 확인
print(data.keys()) 
'''
출력결과
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
'''
# data의 key값들의 속성을 확인
for item in data : 
    print(item, type(data[item]))
'''
출력결과 : 
b'batch_label' <class 'bytes'>
b'labels' <class 'list'>
b'data' <class 'numpy.ndarray'>
b'filenames' <class 'list'>
'''

여기서 labels들의 값을 한번 확인해보자 숫자(index)로 구성되어 있는 것을 볼 수 있다.

print("Labels:", set(data[b'labels'])) 
# Label 확인 -> {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

만약 숫자가 아닌 문자로 작성된 label을 원한다면 batch.meta를 사용하면 된다.

아래 코드를 참조하자

# Load Meta File
meta_file = "c:/Users/JJ/Code/Python/Swin/temp/cifar-10-batches-py/batches.meta"
meta_data = unpickle(meta_file)
print(type(meta_data)) 
print(meta_data.keys()) 
print("label Names:",meta_data[b'label_names'])

'''
dict
dict_keys([b'num_cases_per_batch', b'label_names', b'num_vis'])
label Names: [b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck']
'''

이번엔 데이터가 어떻게 구성되어 있는지 봐보자

print("Data shape:",data[b'data'].shape)
'''
출력결과
Data shape: (10000, 3072)
'''

data는 3072개로 구성되어 있는 1 dimension이 10000개 있다는 뜻이다.

만약 data[b’data’][0]의 코드를 실행시킨다면 (3072,)의 출력이 나올 것이다.

따라서 이렇게 나온 3072개 중 처음 1024개의 항목은 Red channel, 두번째 1024개의 항목은 Grean channel, 나머지 1024개에 대한 항목은 Blue channel로 구성되어 있다.

이를 시각화 하기 위해서는 (3, 32, 32) 로 만들어주고 최종적으로 (32, 32, 3)으로 만들어준다.

image = data[b'data'][0]
image = image.reshape(3, 32, 32)
image = image.transpose(1, 2, 0)
print(image.shape)

전체 이미지 다루기

아래 코드는 전체 이미지에 대해 reshape해주는 코드다

print("Shape before reshape:", X_train.shape) 
# Reshape the whole image data
X_train = X_train.reshape(len(X_train),3,32,32)
print("Shape after reshape and before transpose:", X_train.shape)
# Transpose the whole data
X_train = X_train.transpose(0,2,3,1)
print("Shape after reshape and transpose:", X_train.shape)

'''
출력결과
Shape before reshape: (10000, 3072)
Shape after reshape and before transpose: (10000, 3, 32, 32)
Shape after reshape and transpose: (10000, 32, 32, 3)
'''

[참조]

https://www.binarystudy.com/2021/09/how-to-load-preprocess-visualize-CIFAR-10-and-CIFAR-100.html