Safemotion Lib
Loading...
Searching...
No Matches
formatting.py
Go to the documentation of this file.
1from typing import Sequence
2
3import torch
4import numpy as np
5
6def to_tensor(data):
7 """
8 데이터를 텐서로 변환하는 기능
9 args:
10 data : 다양한 타입의 데이터
11 return: Tensor로 변환된 데이터
12 """
13 if isinstance(data, torch.Tensor):
14 return data
15 elif isinstance(data, np.ndarray):
16 return torch.from_numpy(data)
17 elif isinstance(data, Sequence) and not isinstance(data, str):
18 return torch.tensor(data)
19 elif isinstance(data, int):
20 return torch.LongTensor([data])
21 elif isinstance(data, float):
22 return torch.FloatTensor([data])
23 elif isinstance(data, dict):
24 for key in data.keys():
25 data[key] = to_tensor(data[key])
26 else:
27 raise TypeError(f'to_tensor : type {type(data)} cannot be converted to tensor.')
28
29class ToTensor(object):
30 """
31 딕셔너리 타입의 데이터에서 원하는 키의 데이터만 텐서로 변환하기 위한 클래스
32 args:
33 keys (list[str]) : 텐서 변환을 원하는 키, None이면 모든 키에 대해서 Tensor 변환
34 """
35 def __init__(self, keys=None):
36 self.keys = keys
37
38 def __call__(self, sample):
39
40 if self.keys is None:
41 sample = to_tensor(sample)
42 else:
43 for key in keys:
44 assert key in sample, \
45 f'ToTensor : not found key in sample : {key}'
46
47 for key in self.keys:
48 samples[key] = to_tensor(samples[key])
49
50 return sample
51
52class ImageToTensor(object):
53 """
54 이미지를 텐서로 변환하기 위한 클래스
55 args:
56 keys (list[str]): 입력 데이터(dict)에서 이미지에 해당하는 키
57 """
58 def __init__(self, keys):
59 self.keys = keys
60
61 def transform(self, sample) -> dict:
62 """
63 입력 데이터의 이미지를 텐서로 변환하는 기능
64 args:
65 sample (dict): 데이터
66 return: 텐서로 변환된 이미지가 포함된 데이터
67 """
68 #예외처리, 입력데이터에 설정된 키가 없으면 종료
69 for key in keys:
70 assert key in sample, \
71 f'ImageToTensor : not found key in sample : {key}'
72
73 #텐서 변환
74 for key in self.keys:
75 img = sample[key] #이미지
76 if len(img.shape) < 3: #이미지가 2차원일 경우(흑백)
77 img = np.expand_dims(img, -1) #끝에 한차원 늘려줘서 차원을 맞춰줌, (H, W, C(1))
78
79 #텐서 변환 및 인덱스 구조 변경, H, W, C -> C, H, W
80 sample[key] = (ToTensor.to_tensor(img.transpose(2, 0, 1))).contiguous()
81
82 return sample
83
84class CollectKeys(object):
85 """
86 데이터에서 원하는 데이터만 수집하기위한 클래스
87 args:
88 collect_keys (list[str]): 수집을 원하는 데이터의 키
89 """
90 def __init__(self, collect_keys):
91 self.collect_keys = collect_keys
92
93 def __call__(self, sample):
94
95 #예외처리, 입력데이터에 설정된 키가 없으면 종료
96 for key in collect_keys:
97 assert key in sample, \
98 f'CollectKeys : not found key in sample : {key}'
99
100 #데이터 수집
101 pack_data = {}
102 for key in collect_keys:
103 pack_data[key] = sample[key]
104
105 return pack_data
106
107
__call__(self, sample)
Definition formatting.py:93
__init__(self, collect_keys)
Definition formatting.py:90
dict transform(self, sample)
Definition formatting.py:61
__call__(self, sample)
Definition formatting.py:38
__init__(self, keys=None)
Definition formatting.py:35
to_tensor(data)
Definition formatting.py:6