Ndarray

MultiArrayNumpyFile dataclass

Source code in mbodied/types/ndarray.py
@dataclass(frozen=True)
class MultiArrayNumpyFile:
    path: FilePath
    key: str
    cached_load: bool = False

    def load(self) -> npt.NDArray:
        """Load the NDArray stored at the given path under the given key.

        Returns:
        -------
        NDArray
        """
        loaded = _cached_np_array_load(self.path) if self.cached_load else np.load(self.path)
        try:
            return loaded[self.key]
        except IndexError as e:
            msg = f"The given path points to an uncompressed numpy file, which only has one array in it: {self.path}"
            raise AttributeError(msg) from e

load()

Load the NDArray stored at the given path under the given key.

Returns:

NDArray

Source code in mbodied/types/ndarray.py
def load(self) -> npt.NDArray:
    """Load the NDArray stored at the given path under the given key.

    Returns:
    -------
    NDArray
    """
    loaded = _cached_np_array_load(self.path) if self.cached_load else np.load(self.path)
    try:
        return loaded[self.key]
    except IndexError as e:
        msg = f"The given path points to an uncompressed numpy file, which only has one array in it: {self.path}"
        raise AttributeError(msg) from e
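
As a quick orientation, a minimal round trip might look like the following sketch. The module import path is assumed from the source header above, and the file name and array keys are purely illustrative:

import numpy as np
from mbodied.types.ndarray import MultiArrayNumpyFile

# Write a multi-array .npz file, then load one named array back out.
np.savez("observations.npz", rgb=np.zeros((4, 4)), depth=np.ones((4, 4)))
depth = MultiArrayNumpyFile(path="observations.npz", key="depth").load()
assert depth.shape == (4, 4)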

NumpyArray

Bases: Generic[T], NDArray[Any]

Pydantic validation for shape and dtype. Specify shape with a tuple of integers, "*" or Any for any size.

If the last dimension is a type (e.g. np.uint8), it will validate the dtype as well.

Examples:

  • NumpyArray[1, 2, 3] will validate a 3D array with shape (1, 2, 3).
  • NumpyArray[Any, "*", Any] will validate a 3D array with any shape.
  • NumpyArray[3, 224, 224, np.uint8] will validate an array with shape (3, 224, 224) and dtype np.uint8.

Lazy loading and caching by default.

Usage:

>>> from pydantic import BaseModel
>>> from embdata.ndarray import NumpyArray
>>> class MyModel(BaseModel):
...     uint8_array: NumpyArray[np.uint8]
...     must_have_exact_shape: NumpyArray[1, 2, 3]
...     must_be_3d: NumpyArray["*", "*", "*"]  # NumpyArray[Any, Any, Any] also works.
...     must_be_1d: NumpyArray["*",]  # NumpyArray[Any,] also works.
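
To make the validation behavior concrete, here is a minimal sketch. The Frame model name is hypothetical, the import path is assumed from the source header, and the exact error details may differ:

import numpy as np
from pydantic import BaseModel, ValidationError
from mbodied.types.ndarray import NumpyArray

class Frame(BaseModel):
    pixels: NumpyArray[2, 2, np.uint8]

# A matching array passes validation.
frame = Frame(pixels=np.zeros((2, 2), dtype=np.uint8))

# A mismatched shape (or dtype) is rejected at validation time.
try:
    Frame(pixels=np.zeros((3, 3), dtype=np.uint8))
except ValidationError as err:
    print(err)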

Source code in mbodied/types/ndarray.py
class NumpyArray(Generic[T], NDArray[Any]):
    """Pydantic validation for shape and dtype. Specify shape with a tuple of integers, "*" or `Any` for any size.

    If the last dimension is a type (e.g. np.uint8), it will validate the dtype as well.

    Examples:
        - NumpyArray[1, 2, 3] will validate a 3D array with shape (1, 2, 3).
        - NumpyArray[Any, "*", Any] will validate a 3D array with any shape.
        - NumpyArray[3, 224, 224, np.uint8] will validate an array with shape (3, 224, 224) and dtype np.uint8.

    Lazy loading and caching by default.

    Usage:
    >>> from pydantic import BaseModel
    >>> from embdata.ndarray import NumpyArray
    >>> class MyModel(BaseModel):
    ...     uint8_array: NumpyArray[np.uint8]
    ...     must_have_exact_shape: NumpyArray[1, 2, 3]
    ...     must_be_3d: NumpyArray["*", "*", "*"]  # NumpyArray[Any, Any, Any] also works.
    ...     must_be_1d: NumpyArray["*",]  # NumpyArray[Any,] also works.
    """

    shape: ClassVar[Tuple[PositiveInt, ...] | None] = None
    dtype: ClassVar[SupportedDTypes | None] = None
    labels: ClassVar[Tuple[str, ...] | None] = None

    def __repr__(self) -> str:
        class_params = ", ".join(str(s) for s in self.shape) if self.shape is not None else "*"
        dtype = f", {self.dtype.__name__}" if self.dtype is not None else ", Any"
        if self.labels:
            class_params = ",".join([f"{l}={s}" for l, s in zip(self.labels, self.shape, strict=False)])
        return f"NumpyArray[{class_params}{dtype}]"

    def __str__(self) -> str:
        return repr(self)

    @classmethod
    def __class_getitem__(cls, params=None) -> Any:
        _shape = None
        _dtype = None
        _labels = None
        if params is None or params in ("*", Any, (Any,)):
            params = ("*",)
        if not isinstance(params, tuple):
            params = (params,)
        if len(params) == 1:
            if isinstance(params[0], type):
                _dtype = params[0]
        else:
            *_shape, _dtype = params
            _shape = tuple(s if s not in ("*", Any) else -1 for s in _shape)

        _labels = []
        if isinstance(_dtype, int) or _dtype == "*":
            _shape += (_dtype,)
            _dtype = Any
        _shape = _shape or ()
        for s in _shape:
            if isinstance(s, str):
                if s.isnumeric():
                    _labels.append(int(s))
                elif s in ("*", Any):
                    _labels.append(-1)
                elif "=" in s:
                    s = s.split("=")[1]  # noqa: PLW2901
                    if not s.isnumeric():
                        msg = f"Invalid shape parameter: {s}"
                        raise ValueError(msg)
                    _labels.append(int(s))
                else:
                    msg = f"Invalid shape parameter: {s}"
                    raise ValueError(msg)
        if _dtype is int:
            _dtype: SupportedDTypes | None = np.int64
        elif _dtype is float:
            _dtype = np.float64
        elif _dtype is not None and _dtype not in ("*", Any) and isinstance(_dtype, type):
            _dtype = np.dtype(_dtype).type

        if _shape == ():
            _shape = None

        class ParameterizedNumpyArray(cls):
            shape = _shape
            dtype = _dtype
            labels = _labels or None

            __str__ = cls.__str__
            __doc__ = cls.__doc__

            def __repr__(self):
                if self.shape is None and self.dtype is None:
                    return "NumpyArray"
                if self.shape is Any and self.dtype and self.dtype is not Any:
                    return f"NumpyArray[Any, {self.dtype.__name__}]"
                return (
                    f"NumpyArray[{', '.join(str(s) for s in self.shape)}"
                    + (f", {self.dtype.__name__}" if self.dtype else "")
                    + "]"
                )

        return Annotated[np.ndarray | FilePath | MultiArrayNumpyFile, ParameterizedNumpyArray]

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        np_array_validator = create_array_validator(cls.shape, cls.dtype, cls.labels)
        np_array_schema = core_schema.no_info_plain_validator_function(np_array_validator)

        return core_schema.json_or_python_schema(
            python_schema=core_schema.chain_schema(
                [
                    core_schema.union_schema(
                        [
                            core_schema.is_instance_schema(np.ndarray),
                            core_schema.is_instance_schema(list),
                            core_schema.is_instance_schema(tuple),
                            core_schema.is_instance_schema(dict),
                        ],
                    ),
                    _common_numpy_array_validator,
                    np_array_schema,
                ],
            ),
            json_schema=core_schema.chain_schema(
                [
                    core_schema.union_schema(
                        [
                            core_schema.list_schema(),
                            core_schema.dict_schema(),
                        ],
                    ),
                    np_array_schema,
                ],
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                array_to_data_dict_serializer,
                when_used="json-unless-none",
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        field_core_schema: core_schema.CoreSchema,
        handler: GetJsonSchemaHandler,
    ) -> JsonSchemaValue:
        return get_numpy_json_schema(field_core_schema, handler, cls.shape, cls.dtype, cls.labels)
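
Because the core schema above installs both a validator chain and a plain serializer, a JSON round trip is plausible along these lines. This is a sketch only: the Pose model is hypothetical, and it assumes array_to_data_dict_serializer and _common_numpy_array_validator (not shown here) are mutually inverse:

import numpy as np
from pydantic import BaseModel
from mbodied.types.ndarray import NumpyArray

class Pose(BaseModel):
    xyz: NumpyArray[3, np.float64]

p = Pose(xyz=(0.0, 1.0, 2.0))            # tuples/lists are coerced to np.ndarray
payload = p.model_dump_json()             # serialized by array_to_data_dict_serializer
restored = Pose.model_validate_json(payload)
assert np.allclose(p.xyz, restored.xyz)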

NumpyModel

Bases: BaseModel

Source code in mbodied/types/ndarray.py
class NumpyModel(BaseModel):
    _dump_compression: ClassVar[str] = "lz4"
    _dump_numpy_savez_file_name: ClassVar[str] = "arrays.npz"
    _dump_non_array_file_stem: ClassVar[str] = "object_info"

    _directory_suffix: ClassVar[str] = ".pdnp"

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, BaseModel):
            return NotImplemented  # delegate to the other item in the comparison

        self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
        other_type = other.__pydantic_generic_metadata__["origin"] or other.__class__

        if not (
            self_type == other_type
            and getattr(self, "__pydantic_private__", None) == getattr(other, "__pydantic_private__", None)
            and self.__pydantic_extra__ == other.__pydantic_extra__
        ):
            return False

        if isinstance(other, NumpyModel):
            self_ndarray_field_to_array, self_other_field_to_value = self._dump_numpy_split_dict()
            other_ndarray_field_to_array, other_other_field_to_value = other._dump_numpy_split_dict()

            return self_other_field_to_value == other_other_field_to_value and _compare_np_array_dicts(
                self_ndarray_field_to_array,
                other_ndarray_field_to_array,
            )

        # Self is NumpyModel, other is not; likely unequal; checking anyway.
        return super().__eq__(other)

    @classmethod
    @validate_call
    def model_directory_path(cls, output_directory: DirectoryPath, object_id: str) -> DirectoryPath:
        return output_directory / f"{object_id}.{cls.__name__}{cls._directory_suffix}"

    @classmethod
    @validate_call
    def load(
        cls,
        output_directory: DirectoryPath,
        object_id: str,
        *,
        pre_load_modifier: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
    ):
        """Load NumpyModel instance.

        Parameters
        ----------
        output_directory: DirectoryPath
            The root directory where all model instances of interest are stored
        object_id: String
            The ID of the model instance
        pre_load_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None
            Optional function that modifies the loaded arrays

        Returns:
        -------
        NumpyModel instance
        """
        object_directory_path = cls.model_directory_path(output_directory, object_id)

        npz_file = np.load(object_directory_path / cls._dump_numpy_savez_file_name)

        other_path: FilePath
        if (other_path := object_directory_path / cls._dump_compressed_pickle_file_name).exists():  # pyright: ignore
            other_field_to_value = compress_pickle.load(other_path)
        elif (other_path := object_directory_path / cls._dump_pickle_file_name).exists():  # pyright: ignore
            with open(other_path, "rb") as in_pickle:
                other_field_to_value = pickle_pkg.load(in_pickle)
        elif (other_path := object_directory_path / cls._dump_non_array_yaml_name).exists():  # pyright: ignore
            with open(other_path) as in_yaml:
                other_field_to_value = yaml.load(in_yaml)
        else:
            other_field_to_value = {}

        field_to_value = {**npz_file, **other_field_to_value}
        if pre_load_modifier:
            field_to_value = pre_load_modifier(field_to_value)

        return cls(**field_to_value)

    @validate_call
    def dump(
        self,
        output_directory: Path,
        object_id: str,
        *,
        compress: bool = True,
        pickle: bool = False,
    ) -> DirectoryPath:
        assert "arbitrary_types_allowed" not in self.model_config or (
            self.model_config["arbitrary_types_allowed"] and pickle
        ), "Arbitrary types are only supported in pickle mode"

        dump_directory_path = self.model_directory_path(output_directory, object_id)
        dump_directory_path.mkdir(parents=True, exist_ok=True)

        ndarray_field_to_array, other_field_to_value = self._dump_numpy_split_dict()

        if ndarray_field_to_array:
            (np.savez_compressed if compress else np.savez)(
                dump_directory_path / self._dump_numpy_savez_file_name,
                **ndarray_field_to_array,
            )

        if other_field_to_value:
            if pickle:
                if compress:
                    compress_pickle.dump(
                        other_field_to_value,
                        dump_directory_path / self._dump_compressed_pickle_file_name,  # pyright: ignore
                        compression=self._dump_compression,
                    )
                else:
                    with open(dump_directory_path / self._dump_pickle_file_name, "wb") as out_pickle:  # pyright: ignore
                        pickle_pkg.dump(other_field_to_value, out_pickle)

            else:
                with open(dump_directory_path / self._dump_non_array_yaml_name, "w") as out_yaml:  # pyright: ignore
                    yaml.dump(other_field_to_value, out_yaml)

        return dump_directory_path

    def _dump_numpy_split_dict(self) -> tuple[dict, dict]:
        ndarray_field_to_array = {}
        other_field_to_value = {}

        for k, v in self.model_dump().items():
            if isinstance(v, np.ndarray):
                ndarray_field_to_array[k] = v
            elif v:
                other_field_to_value[k] = v

        return ndarray_field_to_array, other_field_to_value

    @classmethod  # type: ignore[misc]
    @computed_field(return_type=str)
    @property
    def _dump_compressed_pickle_file_name(cls) -> str:
        return f"{cls._dump_non_array_file_stem}.pickle.{cls._dump_compression}"

    @classmethod  # type: ignore[misc]
    @computed_field(return_type=str)
    @property
    def _dump_pickle_file_name(cls) -> str:
        return f"{cls._dump_non_array_file_stem}.pickle"

    @classmethod  # type: ignore[misc]
    @computed_field(return_type=str)
    @property
    def _dump_non_array_yaml_name(cls) -> str:
        return f"{cls._dump_non_array_file_stem}.yaml"
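
A minimal dump/load round trip, assuming the import path above; the Recording model is hypothetical. Note that model_directory_path validates output_directory as an existing DirectoryPath, so the root directory must exist before dumping:

from pathlib import Path

import numpy as np
from mbodied.types.ndarray import NumpyArray, NumpyModel

class Recording(NumpyModel):
    frames: NumpyArray[np.float64]
    name: str

root = Path("/tmp/recordings")
root.mkdir(parents=True, exist_ok=True)   # must exist: DirectoryPath validation

rec = Recording(frames=np.random.rand(4, 2), name="episode-0")
rec.dump(root, "episode-0")               # writes arrays.npz + object_info.yaml in a .pdnp directory
assert Recording.load(root, "episode-0") == rec   # array-aware __eq__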

load(output_directory, object_id, *, pre_load_modifier=None) classmethod

Load NumpyModel instance.

Parameters

output_directory: DirectoryPath
    The root directory where all model instances of interest are stored
object_id: String
    The ID of the model instance
pre_load_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None
    Optional function that modifies the loaded arrays

Returns:

NumpyModel instance

Source code in mbodied/types/ndarray.py
@classmethod
@validate_call
def load(
    cls,
    output_directory: DirectoryPath,
    object_id: str,
    *,
    pre_load_modifier: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
):
    """Load NumpyModel instance.

    Parameters
    ----------
    output_directory: DirectoryPath
        The root directory where all model instances of interest are stored
    object_id: String
        The ID of the model instance
    pre_load_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None
        Optional function that modifies the loaded arrays

    Returns:
    -------
    NumpyModel instance
    """
    object_directory_path = cls.model_directory_path(output_directory, object_id)

    npz_file = np.load(object_directory_path / cls._dump_numpy_savez_file_name)

    other_path: FilePath
    if (other_path := object_directory_path / cls._dump_compressed_pickle_file_name).exists():  # pyright: ignore
        other_field_to_value = compress_pickle.load(other_path)
    elif (other_path := object_directory_path / cls._dump_pickle_file_name).exists():  # pyright: ignore
        with open(other_path, "rb") as in_pickle:
            other_field_to_value = pickle_pkg.load(in_pickle)
    elif (other_path := object_directory_path / cls._dump_non_array_yaml_name).exists():  # pyright: ignore
        with open(other_path) as in_yaml:
            other_field_to_value = yaml.load(in_yaml)
    else:
        other_field_to_value = {}

    field_to_value = {**npz_file, **other_field_to_value}
    if pre_load_modifier:
        field_to_value = pre_load_modifier(field_to_value)

    return cls(**field_to_value)
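
pre_load_modifier runs on the merged field dictionary before model construction. Continuing the hypothetical Recording sketch above:

doubled = Recording.load(
    root,
    "episode-0",
    pre_load_modifier=lambda fields: {**fields, "frames": fields["frames"] * 2},
)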

create_array_validator(shape, dtype, labels)

Creates a validator function for NumPy arrays with a specified shape and data type.

Source code in mbodied/types/ndarray.py
def create_array_validator(
    shape: Tuple[int, ...] | None,
    dtype: SupportedDTypes | None,
    labels: List[str] | None,
) -> Callable[[Any], npt.NDArray]:
    """Creates a validator function for NumPy arrays with a specified shape and data type."""
    return partial(array_validator, shape=shape, dtype=dtype, labels=labels)
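
The returned callable can also be used directly. A small sketch, assuming array_validator (not shown here) returns the validated array and raises on a shape or dtype mismatch:

import numpy as np

validate = create_array_validator((2, 2), np.float64, None)
arr = validate(np.zeros((2, 2)))   # passes: shape (2, 2), dtype float64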

get_numpy_json_schema(_field_core_schema, _handler, shape=None, data_type=None, labels=None)

Generates a JSON schema for a NumPy array field within a Pydantic model.

This function constructs a JSON schema definition compatible with Pydantic models that are intended to validate NumPy array inputs. It supports specifying the data type and dimensions of the NumPy array, which are used to construct a schema that ensures input data matches the expected structure and type.

Parameters

_field_core_schema : core_schema.CoreSchema
    The core schema component of the Pydantic model, used for building basic schema structures.
_handler : GetJsonSchemaHandler
    A handler function or object responsible for converting Python types to JSON schema components.
shape : Optional[List[PositiveInt]], optional
    The expected shape of the NumPy array. If specified, the schema will enforce that the input
    array matches this shape, by default None.
data_type : Optional[SupportedDTypes], optional
    The expected data type of the NumPy array elements. If specified, the schema will enforce that the
    input array's data type is compatible with this. If None, any data type is allowed, by default None.

Returns:

JsonSchemaValue
    A dictionary representing the JSON schema for a NumPy array field within a Pydantic model.
    This schema includes details about the expected array dimensions and data type.

Source code in mbodied/types/ndarray.py
def get_numpy_json_schema(
    _field_core_schema: core_schema.CoreSchema,
    _handler: GetJsonSchemaHandler,
    shape: List[PositiveInt] | None = None,
    data_type: SupportedDTypes | None = None,
    labels: List[str] | None = None,
) -> JsonSchemaValue:
    """Generates a JSON schema for a NumPy array field within a Pydantic model.

    This function constructs a JSON schema definition compatible with Pydantic models
    that are intended to validate NumPy array inputs. It supports specifying the data type
    and dimensions of the NumPy array, which are used to construct a schema that ensures
    input data matches the expected structure and type.

    Parameters
    ----------
    _field_core_schema : core_schema.CoreSchema
        The core schema component of the Pydantic model, used for building basic schema structures.
    _handler : GetJsonSchemaHandler
        A handler function or object responsible for converting Python types to JSON schema components.
    shape : Optional[List[PositiveInt]], optional
        The expected shape of the NumPy array. If specified, the schema will enforce that the input
        array matches this shape, by default None.
    data_type : Optional[SupportedDTypes], optional
        The expected data type of the NumPy array elements. If specified, the schema will enforce
        that the input array's data type is compatible with this. If `None`, any data type is allowed,
        by default None.

    Returns:
    -------
    JsonSchemaValue
        A dictionary representing the JSON schema for a NumPy array field within a Pydantic model.
        This schema includes details about the expected array dimensions and data type.
    """
    array_shape = shape if shape else "Any"
    if data_type:
        array_data_type = data_type.__name__
        item_schema = core_schema.list_schema(
            items_schema=core_schema.any_schema(metadata=f"Must be compatible with numpy.dtype: {array_data_type}"),
        )
    else:
        array_data_type = "Any"
        item_schema = core_schema.list_schema(items_schema=core_schema.any_schema())

    if shape:
        data_schema = core_schema.list_schema(items_schema=item_schema, min_length=shape[0], max_length=shape[0])
    else:
        data_schema = item_schema

    return {
        "title": "Numpy Array",
        "type": f"np.ndarray[{array_shape}, np.dtype[{array_data_type}]]",
        "required": ["data_type", "data"],
        "properties": {
            "data_type": {"title": "dtype", "default": array_data_type, "type": "string"},
            "shape": {"title": "shape", "default": array_shape, "type": "array"},
            "data": data_schema,
        },
    }
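
In practice this hook is invoked through Pydantic's schema machinery rather than called directly. A quick way to inspect the result, assuming the NumpyArray annotation from above (the Frame model is hypothetical):

import numpy as np
from pydantic import BaseModel
from mbodied.types.ndarray import NumpyArray

class Frame(BaseModel):
    pixels: NumpyArray[2, 2, np.float64]

# The "pixels" entry should carry the data_type/shape/data properties described above.
print(Frame.model_json_schema())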

model_agnostic_load(output_directory, object_id, models, not_found_error=False, **load_kwargs)

Given an iterable of candidate models and the directory where they have been dumped, load the first model instance that matches the provided object ID.

Parameters

output_directory: DirectoryPath
    The root directory where all model instances of interest are stored
object_id: String
    The ID of the model instance
models: Iterable[type[NumpyModel]]
    All NumpyModel classes of interest; note that they should have differing names
not_found_error: bool
    If True, raise an error when the respective model instance was not found
load_kwargs
    Keyword arguments to pass to the load function

Returns:

NumpyModel instance if found

Source code in mbodied/types/ndarray.py
def model_agnostic_load(
    output_directory: DirectoryPath,
    object_id: str,
    models: Iterable[type[NumpyModel]],
    not_found_error: bool = False,
    **load_kwargs,
) -> Optional[NumpyModel]:
    """Given an iterable of candidate models and the directory where they have been dumped,
    load the first model instance that matches the provided object ID.

    Parameters
    ----------
    output_directory: DirectoryPath
        The root directory where all model instances of interest are stored
    object_id: String
        The ID of the model instance
    models: Iterable[type[NumpyModel]]
        All NumpyModel classes of interest; note that they should have differing names
    not_found_error: bool
        If True, raise an error when the respective model instance was not found
    load_kwargs
        Keyword arguments to pass to the load function

    Returns:
    -------
    NumpyModel instance if found
    """
    for model in models:
        if model.model_directory_path(output_directory, object_id).exists():
            return model.load(output_directory, object_id, **load_kwargs)

    if not_found_error:
        msg = (
            f"Could not find NumpyModel with {object_id} in {output_directory}. "
            f"Tried the following classes:\n{', '.join(model.__name__ for model in models)}"
        )
        raise FileNotFoundError(
            msg,
        )

    return None
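
A sketch of disambiguating between two candidate model classes; the Recording and Calibration models and the directory layout are hypothetical, reusing the dump sketch from earlier:

from pathlib import Path

import numpy as np
from mbodied.types.ndarray import NumpyArray, NumpyModel, model_agnostic_load

class Recording(NumpyModel):
    frames: NumpyArray[np.float64]
    name: str

class Calibration(NumpyModel):
    intrinsics: NumpyArray[3, 3, np.float64]

# Returns an instance of the first class whose dump directory exists for this ID.
root = Path("/tmp/recordings")
loaded = model_agnostic_load(root, "episode-0", models=[Calibration, Recording])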

np_general_all_close(arr_a, arr_b, rtol=1e-05, atol=1e-08)

Data-type-agnostic function to determine whether two numpy arrays have elements that are close.

Parameters

arr_a: npt.NDArray
arr_b: npt.NDArray
rtol: float
    See np.allclose
atol: float
    See np.allclose

Returns:

Bool

Source code in mbodied/types/ndarray.py
def np_general_all_close(arr_a: npt.NDArray, arr_b: npt.NDArray, rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    """Data-type-agnostic function to determine whether two numpy arrays have elements that are close.

    Parameters
    ----------
    arr_a: npt.NDArray
    arr_b: npt.NDArray
    rtol: float
        See np.allclose
    atol: float
        See np.allclose

    Returns:
    -------
    Bool
    """
    return _np_general_all_close(arr_a, arr_b, rtol, atol)
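
For example, with the default tolerances:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = a + 1e-9                      # within the default atol/rtol tolerances
assert np_general_all_close(a, b)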

validate_multi_array_numpy_file(v)

Validation function for loading a numpy array from a name-mapped (multi-array) numpy file.

Parameters

v: MultiArrayNumpyFile
    MultiArrayNumpyFile to load

Returns:

NDArray from MultiArrayNumpyFile

Source code in mbodied/types/ndarray.py
def validate_multi_array_numpy_file(v: MultiArrayNumpyFile) -> npt.NDArray:
    """Validation function for loading a numpy array from a name-mapped (multi-array) numpy file.

    Parameters
    ----------
    v: MultiArrayNumpyFile
        MultiArrayNumpyFile to load

    Returns:
    -------
    NDArray from MultiArrayNumpyFile
    """
    return v.load()

validate_numpy_array_file(v)

Validate a file path to a numpy file by loading it and returning the respective numpy array.

Source code in mbodied/types/ndarray.py
def validate_numpy_array_file(v: FilePath) -> npt.NDArray:
    """Validate a file path to a numpy file by loading it and returning the respective numpy array."""
    result = np.load(v)

    if isinstance(result, NpzFile):
        files = result.files
        if len(files) > 1:
            msg = (
                f"The provided file path is a multi array NpzFile, which is not supported; "
                f"convert to single array NpzFiles.\n"
                f"Path to multi array file: {v}\n"
                f"Array keys: {', '.join(result.files)}\n"
                f"Use embdata.ndarray.{MultiArrayNumpyFile.__name__} instead of a PathLike alone"
            )
            raise PydanticNumpyMultiArrayNumpyFileOnFilePathError(msg)
        result = result[files[0]]

    return result
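
A small sketch of both paths through this validator; the file names are illustrative:

import numpy as np

np.save("single.npy", np.arange(3))
arr = validate_numpy_array_file("single.npy")        # a plain .npy file loads directly

np.savez("multi.npz", a=np.arange(3), b=np.arange(3))
try:
    validate_numpy_array_file("multi.npz")           # multi-array .npz files are rejected
except Exception as err:                             # PydanticNumpyMultiArrayNumpyFileOnFilePathError
    print(err)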